00001 #include <cmath>
00002 #include <stdexcept>
00003
00004 #include "moses/Incremental.h"
00005
00006 #include "moses/ChartCell.h"
00007 #include "moses/ChartParserCallback.h"
00008 #include "moses/FeatureVector.h"
00009 #include "moses/StaticData.h"
00010 #include "moses/Util.h"
00011 #include "moses/LM/Base.h"
00012 #include "moses/OutputCollector.h"
00013
00014 #include "lm/model.hh"
00015 #include "search/applied.hh"
00016 #include "search/config.hh"
00017 #include "search/context.hh"
00018 #include "search/edge_generator.hh"
00019 #include "search/rule.hh"
00020 #include "search/vertex_generator.hh"
00021
00022 #include <boost/lexical_cast.hpp>
00023
00024 namespace Moses
00025 {
00026 namespace Incremental
00027 {
00028 namespace
00029 {
00030
00031
00032
00033 template <class Best> class HypothesisCallback
00034 {
00035 private:
00036 typedef search::VertexGenerator<Best> Gen;
00037 public:
00038 HypothesisCallback(search::ContextBase &context, Best &best, ChartCellLabelSet &out, boost::object_pool<search::Vertex> &vertex_pool)
00039 : context_(context), best_(best), out_(out), vertex_pool_(vertex_pool) {}
00040
00041 void NewHypothesis(search::PartialEdge partial) {
00042
00043
00044 ChartCellLabel::Stack &stack = out_.FindOrInsert(static_cast<const TargetPhrase *>(partial.GetNote().vp)->GetTargetLHS());
00045 Gen *entry = static_cast<Gen*>(stack.incr_generator);
00046 if (!entry) {
00047 entry = generator_pool_.construct(boost::ref(context_), boost::ref(*vertex_pool_.construct()), boost::ref(best_));
00048 stack.incr_generator = entry;
00049 }
00050 entry->NewHypothesis(partial);
00051 }
00052
00053 void FinishedSearch() {
00054 for (ChartCellLabelSet::iterator i(out_.mutable_begin()); i != out_.mutable_end(); ++i) {
00055 if ((*i) == NULL) {
00056 continue;
00057 }
00058 ChartCellLabel::Stack &stack = (*i)->MutableStack();
00059 Gen *gen = static_cast<Gen*>(stack.incr_generator);
00060 gen->FinishedSearch();
00061 stack.incr = &gen->Generating();
00062 }
00063 }
00064
00065 private:
00066 search::ContextBase &context_;
00067
00068 Best &best_;
00069
00070 ChartCellLabelSet &out_;
00071
00072 boost::object_pool<search::Vertex> &vertex_pool_;
00073 boost::object_pool<Gen> generator_pool_;
00074 };
00075
00076
00077
// ChartParserCallback that converts Moses rules into KenLM search edges
// for one chart cell. Edges accumulate in edges_ and are later consumed
// via Search() (inner cells) or RootSearch() (the full-span root cell).
template <class Model> class Fill : public ChartParserCallback
{
public:
  // vocab_mapping translates Moses factor ids to LM word indices;
  // oov_weight multiplies ScoreRule's OOV count (see AddPhraseOOV).
  Fill(search::Context<Model> &context, const std::vector<lm::WordIndex> &vocab_mapping, search::Score oov_weight)
    : context_(context), vocab_mapping_(vocab_mapping), oov_weight_(oov_weight) {}

  // Adds one edge per target phrase in targets (defined below).
  void Add(const TargetPhraseCollection &targets, const StackVec &nts, const Range &ignored);

  // Adds a zero-arity edge for an OOV source word (defined below).
  void AddPhraseOOV(TargetPhrase &phrase, std::list<TargetPhraseCollection::shared_ptr > &waste_memory, const Range &range);

  // Bound score of a cell's root alternative (defined below).
  float GetBestScore(const ChartCellLabel *chartCell) const;

  // True when no edges were added for this cell.
  bool Empty() const {
    return edges_.Empty();
  }

  // Searches the accumulated edges, writing result vertices into out via
  // HypothesisCallback.
  template <class Best> void Search(Best &best, ChartCellLabelSet &out, boost::object_pool<search::Vertex> &vertex_pool) {
    HypothesisCallback<Best> callback(context_, best, out, vertex_pool);
    edges_.Search(context_, callback);
  }

  // Searches the root (whole-sentence) cell into a local vertex and
  // returns its best child history.
  template <class Best> search::History RootSearch(Best &best) {
    search::Vertex vertex;
    search::RootVertexGenerator<Best> gen(vertex, best);
    edges_.Search(context_, gen);
    return vertex.BestChild();
  }

  // No-op in the incremental decoder: source-context feature evaluation
  // is not performed here.
  void EvaluateWithSourceContext(const InputType &input, const InputPath &inputPath) {
    // intentionally empty
  }
private:
  // Maps a Moses word (factor 0) to the LM vocabulary; 0 (<unk>) for
  // factors outside vocab_mapping_.
  lm::WordIndex Convert(const Word &word) const;

  search::Context<Model> &context_;

  const std::vector<lm::WordIndex> &vocab_mapping_;

  search::EdgeGenerator edges_;

  const search::Score oov_weight_;
};
00121
// Converts a collection of target phrases (all sharing the same source
// span and the same non-terminal stacks) into search edges.
template <class Model> void Fill<Model>::Add(const TargetPhraseCollection &targets, const StackVec &nts, const Range &range)
{
  std::vector<search::PartialVertex> vertices;
  vertices.reserve(nts.size());
  float below_score = 0.0;
  // Gather each child non-terminal's root alternative and sum their best
  // (bound) scores; this sum is shared by every phrase below.
  for (StackVec::const_iterator i(nts.begin()); i != nts.end(); ++i) {
    vertices.push_back((*i)->GetStack().incr->RootAlternate());
    below_score += (*i)->GetBestScore(this);
  }

  std::vector<lm::WordIndex> words;
  for (TargetPhraseCollection::const_iterator p(targets.begin()); p != targets.end(); ++p) {
    words.clear();
    const TargetPhrase &phrase = **p;
    const AlignmentInfo::NonTermIndexMap &align = phrase.GetAlignNonTerm().GetNonTermIndexMap();
    search::PartialEdge edge(edges_.AllocateEdge(nts.size()));

    // Walk the target side: each non-terminal hooks up the aligned child
    // vertex, each terminal is mapped into the LM vocabulary.
    search::PartialVertex *nt = edge.NT();
    for (size_t i = 0; i < phrase.GetSize(); ++i) {
      const Word &word = phrase.GetWord(i);
      if (word.IsNonTerminal()) {
        // align[i] gives the child index for target position i.
        *(nt++) = vertices[align[i]];
        words.push_back(search::kNonTerminal);
      } else {
        words.push_back(Convert(word));
      }
    }

    edge.SetScore(phrase.GetFutureScore() + below_score);
    // NOTE(review): ScoreRule's returned prob/oov are discarded here
    // (unlike AddPhraseOOV), presumably because GetFutureScore() already
    // contains the LM estimate; the call still fills the between-word LM
    // states on the edge. Confirm against search/rule.hh.
    search::ScoreRule(context_.LanguageModel(), words, edge.Between());

    // Stash the originating TargetPhrase for later output/extraction.
    search::Note note;
    note.vp = &phrase;
    edge.SetNote(note);
    edge.SetRange(range);

    edges_.AddEdge(edge);
  }
}
00162
00163 template <class Model> void Fill<Model>::AddPhraseOOV(TargetPhrase &phrase, std::list<TargetPhraseCollection::shared_ptr > &, const Range &range)
00164 {
00165 std::vector<lm::WordIndex> words;
00166 UTIL_THROW_IF2(phrase.GetSize() > 1,
00167 "OOV target phrase should be 0 or 1 word in length");
00168 if (phrase.GetSize())
00169 words.push_back(Convert(phrase.GetWord(0)));
00170
00171 search::PartialEdge edge(edges_.AllocateEdge(0));
00172
00173 search::ScoreRuleRet scored(search::ScoreRule(context_.LanguageModel(), words, edge.Between()));
00174 edge.SetScore(phrase.GetFutureScore() + scored.prob * context_.LMWeight() + static_cast<search::Score>(scored.oov) * oov_weight_);
00175
00176 search::Note note;
00177 note.vp = &phrase;
00178 edge.SetNote(note);
00179 edge.SetRange(range);
00180
00181 edges_.AddEdge(edge);
00182 }
00183
00184
00185 template <class Model> float Fill<Model>::GetBestScore(const ChartCellLabel *chartCell) const
00186 {
00187 search::PartialVertex vertex = chartCell->GetStack().incr->RootAlternate();
00188 UTIL_THROW_IF2(vertex.Empty(), "hypothesis with empty stack");
00189 return vertex.Bound();
00190 }
00191
00192
00193 template <class Model> lm::WordIndex Fill<Model>::Convert(const Word &word) const
00194 {
00195 std::size_t factor = word.GetFactor(0)->GetId();
00196 return (factor >= vocab_mapping_.size() ? 0 : vocab_mapping_[factor]);
00197 }
00198
00199 struct ChartCellBaseFactory {
00200 ChartCellBase *operator()(size_t startPos, size_t endPos) const {
00201 return new ChartCellBase(startPos, endPos);
00202 }
00203 };
00204
00205 }
00206
// Sets up the chart (cells_), the parser over it, and the n-best extractor
// sized from the decoder options.
// NOTE(review): the init list hands cells_ a reference to parser_ before
// parser_ is constructed; members initialize in header declaration order
// (not visible here), so this is only safe if the chart merely stores the
// reference — confirm against Incremental.h.
Manager::Manager(ttasksptr const& ttask)
  : BaseManager(ttask)
  , cells_(m_source, ChartCellBaseFactory(), parser_)
  , parser_(ttask, cells_)
  , n_best_(search::NBestConfig(StaticData::Instance().options()->nbest.nbest_size))
{ }
00213
// Empty body: all members clean up through their own destructors.
Manager::~Manager()
{ }
00216
00217
namespace
{

// ln(10): converts the log10-based Moses LM weight into the natural-log
// domain used by the KenLM search (see PopulateBest).
const float log_10 = logf(10);
}
00225
// Runs the incremental chart search over the whole source sentence.
// Builds a search context from the first LM's weights, fills and searches
// every inner chart cell bottom-up, then searches the full-span root cell
// and returns its best history (false-y if nothing was found).
// Model is the concrete KenLM model type; Best is the beam policy
// (single-best or n-best).
template <class Model, class Best>
search::History
Manager::
PopulateBest(const Model &model, const std::vector<lm::WordIndex> &words, Best &out)
{
  const LanguageModel &abstract = LanguageModel::GetFirstLM();
  const StaticData &data = StaticData::Instance();
  // Weight [0] is the LM weight; [1], if the OOV feature is enabled, the
  // OOV penalty weight.
  const float lm_weight = data.GetWeights(&abstract)[0];
  const float oov_weight = abstract.OOVFeatureEnabled() ? data.GetWeights(&abstract)[1] : 0.0;
  size_t cpl = data.options()->cube.pop_limit;
  size_t nbs = data.options()->nbest.nbest_size;
  search::Config config(lm_weight * log_10, cpl, search::NBestConfig(nbs));
  search::Context<Model> context(config, model);

  size_t size = m_source.GetSize();
  // Pre-size the vertex pool to roughly one vertex per chart cell.
  boost::object_pool<search::Vertex> vertex_pool(std::max<size_t>(size * size / 2, 32));

  // Bottom-up over spans: every width for each start position.
  for (int startPos = size-1; startPos >= 0; --startPos) {
    for (size_t width = 1; width <= size-startPos; ++width) {
      // Skip the full-span cell; it is searched separately as the root.
      if (startPos == 0 && startPos + width == size) {
        break;
      }
      Range range(startPos, startPos + width - 1);
      Fill<Model> filler(context, words, oov_weight);
      parser_.Create(range, filler);
      filler.Search(out, cells_.MutableBase(range).MutableTargetLabelSet(), vertex_pool);
    }
  }

  // Root cell: parse and search the whole sentence span.
  Range range(0, size - 1);
  Fill<Model> filler(context, words, oov_weight);
  parser_.Create(range, filler);
  return filler.RootSearch(out);
}
00261
00262 template <class Model> void Manager::LMCallback(const Model &model, const std::vector<lm::WordIndex> &words)
00263 {
00264 std::size_t nbest = StaticData::Instance().options()->nbest.nbest_size;
00265 if (nbest <= 1) {
00266 search::History ret = PopulateBest(model, words, single_best_);
00267 if (ret) {
00268 backing_for_single_.resize(1);
00269 backing_for_single_[0] = search::Applied(ret);
00270 } else {
00271 backing_for_single_.clear();
00272 }
00273 completed_nbest_ = &backing_for_single_;
00274 } else {
00275 search::History ret = PopulateBest(model, words, n_best_);
00276 if (ret) {
00277 completed_nbest_ = &n_best_.Extract(ret);
00278 } else {
00279 backing_for_single_.clear();
00280 completed_nbest_ = &backing_for_single_;
00281 }
00282 }
00283 }
00284
// Explicit instantiations for every KenLM model type that the LM's
// IncrementalCallback may dispatch with.
template void Manager::LMCallback<lm::ngram::ProbingModel>(const lm::ngram::ProbingModel &model, const std::vector<lm::WordIndex> &words);
template void Manager::LMCallback<lm::ngram::RestProbingModel>(const lm::ngram::RestProbingModel &model, const std::vector<lm::WordIndex> &words);
template void Manager::LMCallback<lm::ngram::TrieModel>(const lm::ngram::TrieModel &model, const std::vector<lm::WordIndex> &words);
template void Manager::LMCallback<lm::ngram::QuantTrieModel>(const lm::ngram::QuantTrieModel &model, const std::vector<lm::WordIndex> &words);
template void Manager::LMCallback<lm::ngram::ArrayTrieModel>(const lm::ngram::ArrayTrieModel &model, const std::vector<lm::WordIndex> &words);
template void Manager::LMCallback<lm::ngram::QuantArrayTrieModel>(const lm::ngram::QuantArrayTrieModel &model, const std::vector<lm::WordIndex> &words);
00291
// Decoding is driven by the language model: the first LM calls back into
// LMCallback with its concrete KenLM model type.
void Manager::Decode()
{
  LanguageModel::GetFirstLM().IncrementalCallback(*this);
}
00296
// Returns the n-best list produced by the last decode; completed_nbest_
// points at either the extractor's output or backing_for_single_.
// Assumes Decode() has already run — TODO confirm callers.
const std::vector<search::Applied> &Manager::GetNBest() const
{
  return *completed_nbest_;
}
00301
00302 void Manager::OutputBest(OutputCollector *collector) const
00303 {
00304 const long translationId = m_source.GetTranslationId();
00305 const std::vector<search::Applied> &nbest = GetNBest();
00306 if (!nbest.empty()) {
00307 OutputBestHypo(collector, nbest[0], translationId);
00308 } else {
00309 OutputBestNone(collector, translationId);
00310 }
00311
00312 }
00313
00314
00315 void Manager::OutputNBest(OutputCollector *collector) const
00316 {
00317 if (collector == NULL) {
00318 return;
00319 }
00320
00321 OutputNBestList(collector, *completed_nbest_, m_source.GetTranslationId());
00322 }
00323
00324 void
00325 Manager::
00326 OutputNBestList(OutputCollector *collector,
00327 std::vector<search::Applied> const& nbest,
00328 long translationId) const
00329 {
00330 const std::vector<Moses::FactorType> &outputFactorOrder
00331 = options()->output.factor_order;
00332
00333 std::ostringstream out;
00334
00335 if (collector->OutputIsCout()) {
00336 FixPrecision(out);
00337 }
00338 Phrase outputPhrase;
00339 ScoreComponentCollection features;
00340 for (std::vector<search::Applied>::const_iterator i = nbest.begin();
00341 i != nbest.end(); ++i) {
00342 Incremental::PhraseAndFeatures(*i, outputPhrase, features);
00343
00344 UTIL_THROW_IF2(outputPhrase.GetSize() < 2,
00345 "Output phrase should have contained at least 2 words "
00346 << "(beginning and end-of-sentence)");
00347
00348 outputPhrase.RemoveWord(0);
00349 outputPhrase.RemoveWord(outputPhrase.GetSize() - 1);
00350 out << translationId << " ||| ";
00351 OutputSurface(out, outputPhrase);
00352 out << " ||| ";
00353 bool with_labels = options()->nbest.include_feature_labels;
00354 features.OutputAllFeatureScores(out, with_labels);
00355 out << " ||| " << i->GetScore() << '\n';
00356 }
00357 out << std::flush;
00358 assert(collector);
00359 collector->Write(translationId, out.str());
00360 }
00361
00362 void
00363 Manager::
00364 OutputDetailedTranslationReport(OutputCollector *collector) const
00365 {
00366 if (collector && !completed_nbest_->empty()) {
00367 const search::Applied &applied = completed_nbest_->at(0);
00368 OutputDetailedTranslationReport(collector,
00369 &applied,
00370 static_cast<const Sentence&>(m_source),
00371 m_source.GetTranslationId());
00372 }
00373
00374 }
00375
00376 void Manager::OutputDetailedTranslationReport(
00377 OutputCollector *collector,
00378 const search::Applied *applied,
00379 const Sentence &sentence,
00380 long translationId) const
00381 {
00382 if (applied == NULL) {
00383 return;
00384 }
00385 std::ostringstream out;
00386 ApplicationContext applicationContext;
00387
00388 OutputTranslationOptions(out, applicationContext, applied, sentence, translationId);
00389 collector->Write(translationId, out.str());
00390 }
00391
00392 void Manager::OutputTranslationOptions(std::ostream &out,
00393 ApplicationContext &applicationContext,
00394 const search::Applied *applied,
00395 const Sentence &sentence, long translationId) const
00396 {
00397 if (applied != NULL) {
00398 OutputTranslationOption(out, applicationContext, applied, sentence, translationId);
00399 out << std::endl;
00400 }
00401
00402
00403 const search::Applied *child = applied->Children();
00404 for (size_t i = 0; i < applied->GetArity(); i++) {
00405 OutputTranslationOptions(out, applicationContext, child++, sentence, translationId);
00406 }
00407 }
00408
// Writes a single "Trans Opt" line for one applied rule: source span, the
// reconstructed source-side context, the target LHS, the target phrase,
// and the rule's accumulated score.
void Manager::OutputTranslationOption(std::ostream &out,
                                      ApplicationContext &applicationContext,
                                      const search::Applied *applied,
                                      const Sentence &sentence,
                                      long translationId) const
{
  ReconstructApplicationContext(applied, sentence, applicationContext);
  // The edge's note stores the TargetPhrase this rule came from.
  const TargetPhrase &phrase = *static_cast<const TargetPhrase*>(applied->GetNote().vp);
  out << "Trans Opt " << translationId
      << " " << applied->GetRange()
      << ": ";
  WriteApplicationContext(out, applicationContext);
  out << ": " << phrase.GetTargetLHS()
      << "->" << phrase
      << " " << applied->GetScore();
}
00425
00426
00427
// Rebuilds the source-side context of an applied rule: for every position
// in the rule's span, either the covering source terminal or the label and
// span of the child non-terminal covering it. Children() is assumed to be
// ordered by source position — confirm against the search library.
void Manager::ReconstructApplicationContext(const search::Applied *applied,
    const Sentence &sentence,
    ApplicationContext &context) const
{
  context.clear();
  const Range &span = applied->GetRange();
  const search::Applied *child = applied->Children();
  size_t i = span.GetStartPos();  // current source position
  size_t j = 0;                   // index of next unconsumed child

  while (i <= span.GetEndPos()) {
    if (j == applied->GetArity() || i < child->GetRange().GetStartPos()) {
      // Position i lies before the next child's span: a source terminal.
      const Word &symbol = sentence.GetWord(i);
      context.push_back(std::make_pair(symbol, Range(i, i)));
      ++i;
    } else {
      // Position i starts the next child's span: record the child's target
      // LHS label and skip past its whole range.
      const Word &symbol = static_cast<const TargetPhrase*>(child->GetNote().vp)->GetTargetLHS();
      const Range &range = child->GetRange();
      context.push_back(std::make_pair(symbol, range));
      i = range.GetEndPos()+1;
      ++child;
      ++j;
    }
  }
}
00455
00456 void Manager::OutputDetailedTreeFragmentsTranslationReport(OutputCollector *collector) const
00457 {
00458 if (collector == NULL || Completed().empty()) {
00459 return;
00460 }
00461
00462 const search::Applied *applied = &Completed()[0];
00463 const Sentence &sentence = static_cast<const Sentence &>(m_source);
00464 const size_t translationId = m_source.GetTranslationId();
00465
00466 std::ostringstream out;
00467 ApplicationContext applicationContext;
00468
00469 OutputTreeFragmentsTranslationOptions(out, applicationContext, applied, sentence, translationId);
00470
00471
00472
00473
00474 collector->Write(translationId, out.str());
00475
00476 }
00477
00478 void Manager::OutputTreeFragmentsTranslationOptions(std::ostream &out,
00479 ApplicationContext &applicationContext,
00480 const search::Applied *applied,
00481 const Sentence &sentence,
00482 long translationId) const
00483 {
00484
00485 if (applied != NULL) {
00486 OutputTranslationOption(out, applicationContext, applied, sentence, translationId);
00487
00488 const TargetPhrase &currTarPhr = *static_cast<const TargetPhrase*>(applied->GetNote().vp);
00489
00490 out << " ||| ";
00491 if (const PhraseProperty *property = currTarPhr.GetProperty("Tree")) {
00492 out << " " << *property->GetValueString();
00493 } else {
00494 out << " " << "noTreeInfo";
00495 }
00496 out << std::endl;
00497 }
00498
00499
00500 const search::Applied *child = applied->Children();
00501 for (size_t i = 0; i < applied->GetArity(); i++) {
00502 OutputTreeFragmentsTranslationOptions(out, applicationContext, child++, sentence, translationId);
00503 }
00504 }
00505
00506 void Manager::OutputBestHypo(OutputCollector *collector, search::Applied applied, long translationId) const
00507 {
00508 if (collector == NULL) return;
00509 std::ostringstream out;
00510 FixPrecision(out);
00511 if (options()->output.ReportHypoScore) {
00512 out << applied.GetScore() << ' ';
00513 }
00514 Phrase outPhrase;
00515 Incremental::ToPhrase(applied, outPhrase);
00516
00517 UTIL_THROW_IF2(outPhrase.GetSize() < 2,
00518 "Output phrase should have contained at least 2 words (beginning and end-of-sentence)");
00519 outPhrase.RemoveWord(0);
00520 outPhrase.RemoveWord(outPhrase.GetSize() - 1);
00521 out << outPhrase.GetStringRep(options()->output.factor_order);
00522 out << '\n';
00523 collector->Write(translationId, out.str());
00524
00525 VERBOSE(1,"BEST TRANSLATION: " << outPhrase << "[total=" << applied.GetScore() << "]" << std::endl);
00526 }
00527
00528 void
00529 Manager::
00530 OutputBestNone(OutputCollector *collector, long translationId) const
00531 {
00532 if (collector == NULL) return;
00533 if (options()->output.ReportHypoScore) {
00534 collector->Write(translationId, "0 \n");
00535 } else {
00536 collector->Write(translationId, "\n");
00537 }
00538 }
00539
namespace
{

// Visitor that ignores each visited target phrase; used when only the
// surface string is wanted (ToPhrase).
struct NoOp {
  void operator()(const TargetPhrase &) const {}
};
// Visitor that accumulates each visited phrase's feature-score breakdown;
// used by PhraseAndFeatures.
struct AccumScore {
  AccumScore(ScoreComponentCollection &out) : out_(&out) {}
  void operator()(const TargetPhrase &phrase) {
    out_->PlusEquals(phrase.GetScoreBreakdown());
  }
  ScoreComponentCollection *out_;
};
// Depth-first, pre-order walk of an applied derivation: terminals are
// appended to out in target order, non-terminals recurse into the next
// child. Children() is assumed to be ordered to match the non-terminals'
// order in the target phrase — confirm against the search library.
template <class Action> void AppendToPhrase(const search::Applied final, Phrase &out, Action action)
{
  assert(final.Valid());
  const TargetPhrase &phrase = *static_cast<const TargetPhrase*>(final.GetNote().vp);
  action(phrase);
  const search::Applied *child = final.Children();
  for (std::size_t i = 0; i < phrase.GetSize(); ++i) {
    const Word &word = phrase.GetWord(i);
    if (word.IsNonTerminal()) {
      AppendToPhrase(*child++, out, action);
    } else {
      out.AddWord(word);
    }
  }
}

}
00570
// Flattens an applied derivation into its target-side surface phrase
// (still including <s> and </s>); scores are ignored.
void ToPhrase(const search::Applied final, Phrase &out)
{
  out.Clear();
  AppendToPhrase(final, out, NoOp());
}
00576
// Recovers both the surface phrase and the accumulated feature scores of
// an applied derivation.
void PhraseAndFeatures(const search::Applied final, Phrase &phrase, ScoreComponentCollection &features)
{
  phrase.Clear();
  features.ZeroAll();
  AppendToPhrase(final, phrase, AccumScore(features));

  // The per-rule breakdowns do not carry the language model's score (it
  // was computed incrementally during search), so recompute the first LM's
  // score over the complete phrase and assign it wholesale.
  float full, ignored_ngram;
  std::size_t ignored_oov;

  const LanguageModel &model = LanguageModel::GetFirstLM();
  model.CalcScore(phrase, full, ignored_ngram, ignored_oov);

  features.Assign(&model, full);
}
00592
00593 }
00594 }