00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023 #include <typeinfo>
00024 #include <algorithm>
00025 #include <typeinfo>
00026 #include "TranslationOptionCollection.h"
00027 #include "Sentence.h"
00028 #include "DecodeStep.h"
00029 #include "LM/Base.h"
00030 #include "FactorCollection.h"
00031 #include "InputType.h"
00032 #include "Util.h"
00033 #include "StaticData.h"
00034 #include "DecodeStepTranslation.h"
00035 #include "DecodeStepGeneration.h"
00036 #include "DecodeGraph.h"
00037 #include "InputPath.h"
00038 #include "moses/FF/UnknownWordPenaltyProducer.h"
00039 #include "moses/FF/LexicalReordering/LexicalReordering.h"
00040 #include "moses/FF/InputFeature.h"
00041 #include "TranslationTask.h"
00042 #include "util/exception.hh"
00043
00044 #include <boost/foreach.hpp>
00045 using namespace std;
00046
00047 namespace Moses
00048 {
00049
00053 TranslationOptionCollection::
00054 TranslationOptionCollection(ttasksptr const& ttask,
00055 InputType const& src)
00056 : m_ttask(ttask)
00057 , m_source(src)
00058 , m_estimatedScores(src.GetSize())
00059 , m_maxNoTransOptPerCoverage(ttask->options()->search.max_trans_opt_per_cov)
00060 , m_translationOptionThreshold(ttask->options()->search.trans_opt_threshold)
00061 , m_max_phrase_length(ttask->options()->search.max_phrase_length)
00062 , max_partial_trans_opt(ttask->options()->search.max_partial_trans_opt)
00063 {
00064
00065 size_t size = src.GetSize();
00066 for (size_t sPos = 0 ; sPos < size ; ++sPos) {
00067 m_collection.push_back( vector< TranslationOptionList >() );
00068
00069 size_t maxSize = size - sPos;
00070 maxSize = std::min(maxSize, m_max_phrase_length);
00071
00072 for (size_t ePos = 0 ; ePos < maxSize ; ++ePos) {
00073 m_collection[sPos].push_back( TranslationOptionList() );
00074 }
00075 }
00076 }
00077
00079 TranslationOptionCollection::
00080 ~TranslationOptionCollection()
00081 {
00082 RemoveAllInColl(m_inputPathQueue);
00083 }
00084
00085 void
00086 TranslationOptionCollection::
00087 Prune()
00088 {
00089 static float no_th = -std::numeric_limits<float>::infinity();
00090
00091 if (m_maxNoTransOptPerCoverage == 0 && m_translationOptionThreshold == no_th)
00092 return;
00093
00094
00095 size_t total = 0;
00096 size_t totalPruned = 0;
00097
00098
00099 size_t size = m_source.GetSize();
00100 for (size_t sPos = 0 ; sPos < size; ++sPos) {
00101 BOOST_FOREACH(TranslationOptionList& fullList, m_collection[sPos]) {
00102 total += fullList.size();
00103 totalPruned += fullList.SelectNBest(m_maxNoTransOptPerCoverage);
00104 totalPruned += fullList.PruneByThreshold(m_translationOptionThreshold);
00105 }
00106 }
00107
00108 VERBOSE(2," Total translation options: " << total << std::endl
00109 << "Total translation options pruned: " << totalPruned << std::endl);
00110 }
00111
00128 void
00129 TranslationOptionCollection::
00130 ProcessUnknownWord()
00131 {
00132 const vector<DecodeGraph*>& decodeGraphList
00133 = StaticData::Instance().GetDecodeGraphs();
00134 size_t size = m_source.GetSize();
00135
00136 for (size_t graphInd = 0 ; graphInd < decodeGraphList.size() ; graphInd++) {
00137 const DecodeGraph &decodeGraph = *decodeGraphList[graphInd];
00138 for (size_t pos = 0 ; pos < size ; ++pos) {
00139 TranslationOptionList* fullList = GetTranslationOptionList(pos, pos);
00140
00141 if (!fullList || fullList->size() == 0) {
00142 CreateTranslationOptionsForRange(decodeGraph, pos, pos, false, graphInd);
00143 }
00144 }
00145 }
00146
00147
00148
00149 bool always = m_ttask.lock()->options()->unk.always_create_direct_transopt;
00150
00151
00152 for (size_t pos = 0 ; pos < size ; ++pos) {
00153 TranslationOptionList* fullList = GetTranslationOptionList(pos, pos);
00154 if (!fullList || fullList->size() == 0 || always)
00155 ProcessUnknownWord(pos);
00156 }
00157 }
00158
00172 void
00173 TranslationOptionCollection::
00174 ProcessOneUnknownWord(const InputPath &inputPath, size_t sourcePos,
00175 size_t length, const ScorePair *inputScores)
00176 {
00177 const UnknownWordPenaltyProducer&
00178 unknownWordPenaltyProducer = UnknownWordPenaltyProducer::Instance();
00179 float unknownScore = FloorScore(TransformScore(0));
00180 const Word &sourceWord = inputPath.GetPhrase().GetWord(0);
00181
00182
00183 PhraseDictionary *firstPt = NULL;
00184 if (PhraseDictionary::GetColl().size() == 0) {
00185 firstPt = PhraseDictionary::GetColl()[0];
00186 }
00187
00188
00189 FactorCollection &factorCollection = FactorCollection::Instance();
00190
00191 size_t isDigit = 0;
00192
00193 const Factor *f = sourceWord[0];
00194 const StringPiece s = f->GetString();
00195 bool isEpsilon = (s=="" || s==EPSILON);
00196 bool dropUnk = GetTranslationTask()->options()->unk.drop;
00197 if (dropUnk) {
00198 isDigit = s.find_first_of("0123456789");
00199 if (isDigit == string::npos)
00200 isDigit = 0;
00201 else
00202 isDigit = 1;
00203
00204 }
00205
00206 TargetPhrase targetPhrase(firstPt);
00207
00208 if (!(dropUnk || isEpsilon) || isDigit) {
00209
00210
00211 Word &targetWord = targetPhrase.AddWord();
00212 targetWord.SetIsOOV(true);
00213
00214 for (unsigned int currFactor = 0 ; currFactor < MAX_NUM_FACTORS ; currFactor++) {
00215 FactorType factorType = static_cast<FactorType>(currFactor);
00216
00217 const Factor *sourceFactor = sourceWord[currFactor];
00218 if (sourceFactor == NULL)
00219 targetWord[factorType] = factorCollection.AddFactor(UNKNOWN_FACTOR);
00220 else
00221 targetWord[factorType] = factorCollection.AddFactor(sourceFactor->GetString());
00222 }
00223
00224
00225 targetPhrase.SetAlignmentInfo("0-0");
00226
00227 }
00228
00229 targetPhrase.GetScoreBreakdown().Assign(&unknownWordPenaltyProducer, unknownScore);
00230
00231
00232 const Phrase &sourcePhrase = inputPath.GetPhrase();
00233 m_unksrcs.push_back(&sourcePhrase);
00234 Range range(sourcePos, sourcePos + length - 1);
00235
00236 targetPhrase.EvaluateInIsolation(sourcePhrase);
00237
00238 TranslationOption *transOpt = new TranslationOption(range, targetPhrase);
00239 transOpt->SetInputPath(inputPath);
00240 Add(transOpt);
00241
00242
00243 }
00244
00249 void
00250 TranslationOptionCollection::
00251 CalcEstimatedScore()
00252 {
00253
00254 m_estimatedScores.InitTriangle(-numeric_limits<float>::infinity());
00255
00256
00257 size_t size = m_source.GetSize();
00258 for (size_t sPos = 0 ; sPos < size ; ++sPos) {
00259 size_t ePos = sPos;
00260 BOOST_FOREACH(TranslationOptionList& tol, m_collection[sPos]) {
00261 TranslationOptionList::const_iterator toi;
00262 for(toi = tol.begin() ; toi != tol.end() ; ++toi) {
00263 const TranslationOption& to = **toi;
00264 float score = to.GetFutureScore();
00265 if (score > m_estimatedScores.GetScore(sPos, ePos))
00266 m_estimatedScores.SetScore(sPos, ePos, score);
00267 }
00268 ++ePos;
00269 }
00270 }
00271
00272
00273
00274
00275
00276
00277
00278
00279 for(size_t colstart = 1; colstart < size ; colstart++) {
00280 for(size_t diagshift = 0; diagshift < size-colstart ; diagshift++) {
00281 size_t sPos = diagshift;
00282 size_t ePos = colstart+diagshift;
00283 for(size_t joinAt = sPos; joinAt < ePos ; joinAt++) {
00284 float joinedScore = m_estimatedScores.GetScore(sPos, joinAt)
00285 + m_estimatedScores.GetScore(joinAt+1, ePos);
00286
00287
00288
00289
00290
00291
00292
00293 if (joinedScore > m_estimatedScores.GetScore(sPos, ePos))
00294 m_estimatedScores.SetScore(sPos, ePos, joinedScore);
00295 }
00296 }
00297 }
00298
00299 IFVERBOSE(3) {
00300 int total = 0;
00301 for(size_t row = 0; row < size; row++) {
00302 size_t col = row;
00303 BOOST_FOREACH(TranslationOptionList& tol, m_collection[row]) {
00304
00305
00306
00307
00308
00309 int count = tol.size();
00310 TRACE_ERR( "translation options spanning from "
00311 << row <<" to "<< col <<" is "
00312 << count <<endl);
00313 total += count;
00314 ++col;
00315 }
00316 }
00317 TRACE_ERR( "translation options generated in total: "<< total << endl);
00318
00319 for(size_t row=0; row<size; row++)
00320 for(size_t col=row; col<size; col++)
00321 TRACE_ERR( "future cost from "<< row <<" to "<< col <<" is "
00322 << m_estimatedScores.GetScore(row, col) <<endl);
00323 }
00324 }
00325
00326
00327
00332 void
00333 TranslationOptionCollection::
00334 CreateTranslationOptions()
00335 {
00336
00337
00338
00339
00340
00341
00342 const vector <DecodeGraph*> &decodeGraphList
00343 = StaticData::Instance().GetDecodeGraphs();
00344
00345
00346 const size_t size = m_source.GetSize();
00347
00348
00349 for (size_t gidx = 0 ; gidx < decodeGraphList.size() ; gidx++) {
00350 if (decodeGraphList.size() > 1)
00351 VERBOSE(3,"Creating translation options from decoding graph " << gidx << endl);
00352
00353 const DecodeGraph& dg = *decodeGraphList[gidx];
00354 size_t backoff = dg.GetBackoff();
00355
00356 for (size_t sPos = 0 ; sPos < size; sPos++) {
00357 size_t maxSize = size - sPos;
00358
00359 maxSize = std::min(maxSize, m_max_phrase_length);
00360
00361 for (size_t ePos = sPos ; ePos < sPos + maxSize ; ePos++) {
00362 if (gidx && backoff &&
00363 (ePos-sPos+1 <= backoff ||
00364 m_collection[sPos][ePos-sPos].size() > 0)) {
00365 VERBOSE(3,"No backoff to graph " << gidx << " for span [" << sPos << ";" << ePos << "]" << endl);
00366 continue;
00367 }
00368 CreateTranslationOptionsForRange(dg, sPos, ePos, true, gidx);
00369 }
00370 }
00371 }
00372 ProcessUnknownWord();
00373 EvaluateWithSourceContext();
00374 VERBOSE(3,"Translation Option Collection\n " << *this << endl);
00375 Prune();
00376 Sort();
00377 CalcEstimatedScore();
00378 CacheLexReordering();
00379 }
00380
00381
00382 bool
00383 TranslationOptionCollection::
00384 CreateTranslationOptionsForRange
00385 (const DecodeGraph& dgraph, size_t sPos, size_t ePos,
00386 bool adhereTableLimit, size_t gidx, InputPath &inputPath)
00387 {
00388 typedef DecodeStepTranslation Tstep;
00389 typedef DecodeStepGeneration Gstep;
00390 XmlInputType xml_policy = m_ttask.lock()->options()->input.xml_policy;
00391 if ((xml_policy != XmlExclusive)
00392 || !HasXmlOptionsOverlappingRange(sPos,ePos)) {
00393
00394
00395 PartialTranslOptColl* oldPtoc = new PartialTranslOptColl(max_partial_trans_opt);
00396 size_t totalEarlyPruned = 0;
00397
00398
00399 list <const DecodeStep* >::const_iterator d = dgraph.begin();
00400 const DecodeStep &dstep = **d;
00401
00402 const PhraseDictionary &pdict = *dstep.GetPhraseDictionaryFeature();
00403 TargetPhraseCollection::shared_ptr targetPhrases = inputPath.GetTargetPhrases(pdict);
00404
00405 static_cast<const Tstep&>(dstep).ProcessInitialTranslation
00406 (m_source, *oldPtoc, sPos, ePos, adhereTableLimit, inputPath, targetPhrases);
00407
00408 SetInputScore(inputPath, *oldPtoc);
00409
00410
00411 int indexStep = 0;
00412
00413 for (++d ; d != dgraph.end() ; ++d) {
00414 const DecodeStep *dstep = *d;
00415 PartialTranslOptColl* newPtoc = new PartialTranslOptColl(m_max_phrase_length);
00416
00417
00418 const vector<TranslationOption*>& partTransOptList = oldPtoc->GetList();
00419 vector<TranslationOption*>::const_iterator pto;
00420 for (pto = partTransOptList.begin() ; pto != partTransOptList.end() ; ++pto) {
00421 TranslationOption &inputPartialTranslOpt = **pto;
00422 if (const Tstep *tstep = dynamic_cast<const Tstep*>(dstep)) {
00423 const PhraseDictionary &pdict = *tstep->GetPhraseDictionaryFeature();
00424 TargetPhraseCollection::shared_ptr targetPhrases = inputPath.GetTargetPhrases(pdict);
00425 tstep->Process(inputPartialTranslOpt, *dstep, *newPtoc,
00426 this, adhereTableLimit, targetPhrases);
00427 } else {
00428 const Gstep *genStep = dynamic_cast<const Gstep*>(dstep);
00429 UTIL_THROW_IF2(!genStep, "Decode steps must be either "
00430 << "Translation or Generation Steps!");
00431 genStep->Process(inputPartialTranslOpt, *dstep, *newPtoc,
00432 this, adhereTableLimit);
00433 }
00434 }
00435
00436
00437 totalEarlyPruned += newPtoc->GetPrunedCount();
00438 delete oldPtoc;
00439 oldPtoc = newPtoc;
00440
00441 indexStep++;
00442 }
00443
00444
00445 PartialTranslOptColl &lastPartialTranslOptColl = *oldPtoc;
00446 const vector<TranslationOption*>& partTransOptList = lastPartialTranslOptColl.GetList();
00447 vector<TranslationOption*>::const_iterator c;
00448 for (c = partTransOptList.begin() ; c != partTransOptList.end() ; ++c) {
00449 TranslationOption *transOpt = *c;
00450 if (xml_policy != XmlConstraint ||
00451 !ViolatesXmlOptionsConstraint(sPos,ePos,transOpt)) {
00452 Add(transOpt);
00453 }
00454 }
00455 lastPartialTranslOptColl.DetachAll();
00456 totalEarlyPruned += oldPtoc->GetPrunedCount();
00457 delete oldPtoc;
00458
00459 }
00460
00461 if (gidx == 0 && xml_policy != XmlPassThrough
00462 && HasXmlOptionsOverlappingRange(sPos,ePos)) {
00463 CreateXmlOptionsForRange(sPos, ePos);
00464 }
00465
00466 return true;
00467 }
00468
00469 void
00470 TranslationOptionCollection::
00471 SetInputScore(const InputPath &inputPath, PartialTranslOptColl &oldPtoc)
00472 {
00473 const ScorePair* inputScore = inputPath.GetInputScore();
00474 if (inputScore == NULL) return;
00475
00476 const InputFeature *inputFeature = InputFeature::InstancePtr();
00477
00478 const std::vector<TranslationOption*> &transOpts = oldPtoc.GetList();
00479 for (size_t i = 0; i < transOpts.size(); ++i) {
00480 TranslationOption &transOpt = *transOpts[i];
00481
00482 ScoreComponentCollection &scores = transOpt.GetScoreBreakdown();
00483 scores.PlusEquals(inputFeature, *inputScore);
00484
00485 }
00486 }
00487
00488 void
00489 TranslationOptionCollection::
00490 EvaluateWithSourceContext()
00491 {
00492 const size_t size = m_source.GetSize();
00493 for (size_t sPos = 0 ; sPos < size ; ++sPos) {
00494 BOOST_FOREACH(TranslationOptionList& tol, m_collection[sPos]) {
00495 typedef TranslationOptionList::const_iterator to_iter;
00496 for(to_iter i = tol.begin() ; i != tol.end() ; ++i)
00497 (*i)->EvaluateWithSourceContext(m_source);
00498 EvaluateTranslationOptionListWithSourceContext(tol);
00499 }
00500 }
00501 }
00502
00503 void TranslationOptionCollection::EvaluateTranslationOptionListWithSourceContext(
00504 TranslationOptionList &translationOptionList)
00505 {
00506
00507 const std::vector<FeatureFunction*> &ffs = FeatureFunction::GetFeatureFunctions();
00508 const StaticData &staticData = StaticData::Instance();
00509 for (size_t i = 0; i < ffs.size(); ++i) {
00510 const FeatureFunction &ff = *ffs[i];
00511 if (! staticData.IsFeatureFunctionIgnored(ff)) {
00512 ff.EvaluateTranslationOptionListWithSourceContext(m_source, translationOptionList);
00513 }
00514 }
00515
00516 }
00517
00518 void
00519 TranslationOptionCollection::
00520 Sort()
00521 {
00522 static TranslationOption::Better cmp;
00523 size_t size = m_source.GetSize();
00524 for (size_t sPos = 0 ; sPos < size; ++sPos) {
00525 BOOST_FOREACH(TranslationOptionList& tol, m_collection.at(sPos)) {
00526
00527
00528
00529
00530
00531
00532
00533 std::sort(tol.begin(), tol.end(), cmp);
00534 }
00535 }
00536 }
00537
00544 bool
00545 TranslationOptionCollection::
00546 HasXmlOptionsOverlappingRange(size_t, size_t) const
00547 {
00548 return false;
00549 }
00550
00557 bool
00558 TranslationOptionCollection::
00559 ViolatesXmlOptionsConstraint(size_t, size_t, TranslationOption*) const
00560 {
00561 return false;
00562 }
00563
00569 void
00570 TranslationOptionCollection::
00571 CreateXmlOptionsForRange(size_t, size_t)
00572 { }
00573
00574
00577 void
00578 TranslationOptionCollection::
00579 Add(TranslationOption *translationOption)
00580 {
00581 const Range &coverage = translationOption->GetSourceWordsRange();
00582 size_t const s = coverage.GetStartPos();
00583 size_t const e = coverage.GetEndPos();
00584 size_t const i = e - s;
00585
00586 UTIL_THROW_IF2(e >= m_source.GetSize(),
00587 "Coverage exceeds input size:" << coverage << "\n"
00588 << "translationOption=" << *translationOption);
00589
00590 vector<TranslationOptionList>& v = m_collection[s];
00591 while (i >= v.size()) v.push_back(TranslationOptionList());
00592 v[i].Add(translationOption);
00593 }
00594
00595 TO_STRING_BODY(TranslationOptionCollection);
00596
00597 std::ostream&
00598 operator<<(std::ostream& out, const TranslationOptionCollection& coll)
00599 {
00600 size_t stop = coll.m_source.GetSize();
00601 TranslationOptionList const* tol;
00602 for (size_t sPos = 0 ; sPos < stop ; ++sPos) {
00603 for (size_t ePos = sPos;
00604 (tol = coll.GetTranslationOptionList(sPos, ePos)) != NULL;
00605 ++ePos) {
00606 BOOST_FOREACH(TranslationOption const* to, *tol)
00607 out << *to << std::endl;
00608 }
00609 }
00610 return out;
00611 }
00612
00613 void
00614 TranslationOptionCollection::
00615 CacheLexReordering()
00616 {
00617 size_t const stop = m_source.GetSize();
00618 typedef StatefulFeatureFunction sfFF;
00619 BOOST_FOREACH(sfFF const* ff, sfFF::GetStatefulFeatureFunctions()) {
00620 if (typeid(*ff) != typeid(LexicalReordering)) continue;
00621 LexicalReordering const& lr = static_cast<const LexicalReordering&>(*ff);
00622 for (size_t s = 0 ; s < stop ; s++)
00623 BOOST_FOREACH(TranslationOptionList& tol, m_collection[s])
00624 lr.SetCache(tol);
00625 }
00626 }
00627
00629 TranslationOptionList*
00630 TranslationOptionCollection::
00631 GetTranslationOptionList(size_t const sPos, size_t const ePos)
00632 {
00633 UTIL_THROW_IF2(sPos >= m_collection.size(), "Out of bound access.");
00634 vector<TranslationOptionList>& tol = m_collection[sPos];
00635 size_t idx = ePos - sPos;
00636 return idx < tol.size() ? &tol[idx] : NULL;
00637 }
00638
00639 TranslationOptionList const*
00640 TranslationOptionCollection::
00641 GetTranslationOptionList(size_t sPos, size_t ePos) const
00642 {
00643 UTIL_THROW_IF2(sPos >= m_collection.size(), "Out of bound access.");
00644 vector<TranslationOptionList> const& tol = m_collection[sPos];
00645 size_t idx = ePos - sPos;
00646 return idx < tol.size() ? &tol[idx] : NULL;
00647 }
00648
00649 void
00650 TranslationOptionCollection::
00651 GetTargetPhraseCollectionBatch()
00652 {
00653 typedef DecodeStepTranslation Tstep;
00654 const vector <DecodeGraph*> &dgl = StaticData::Instance().GetDecodeGraphs();
00655 BOOST_FOREACH(DecodeGraph const* dgraph, dgl) {
00656 typedef list <const DecodeStep* >::const_iterator dsiter;
00657 for (dsiter i = dgraph->begin(); i != dgraph->end() ; ++i) {
00658 const Tstep* tstep = dynamic_cast<const Tstep *>(*i);
00659 if (tstep) {
00660 const PhraseDictionary &pdict = *tstep->GetPhraseDictionaryFeature();
00661 pdict.GetTargetPhraseCollectionBatch(m_ttask.lock(), m_inputPathQueue);
00662 }
00663 }
00664 }
00665 }
00666
00667 }
00668