00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023 #include <algorithm>
00024 #include "TranslationOptionCollection.h"
00025 #include "Sentence.h"
00026 #include "DecodeStep.h"
00027 #include "LM/Base.h"
00028 #include "PhraseDictionaryMemory.h"
00029 #include "FactorCollection.h"
00030 #include "InputType.h"
00031 #include "LexicalReordering.h"
00032 #include "Util.h"
00033 #include "StaticData.h"
00034 #include "DecodeStepTranslation.h"
00035 #include "DecodeGraph.h"
00036
00037 using namespace std;
00038
00039 namespace Moses
00040 {
00042 bool CompareTranslationOption(const TranslationOption *a, const TranslationOption *b)
00043 {
00044 return a->GetFutureScore() > b->GetFutureScore();
00045 }
00046
00050 TranslationOptionCollection::TranslationOptionCollection(const TranslationSystem* system,
00051 InputType const& src, size_t maxNoTransOptPerCoverage, float translationOptionThreshold)
00052 : m_system(system),
00053 m_source(src)
00054 ,m_futureScore(src.GetSize())
00055 ,m_maxNoTransOptPerCoverage(maxNoTransOptPerCoverage)
00056 ,m_translationOptionThreshold(translationOptionThreshold)
00057 {
00058
00059 size_t size = src.GetSize();
00060 for (size_t startPos = 0 ; startPos < size ; ++startPos) {
00061 m_collection.push_back( vector< TranslationOptionList >() );
00062
00063 size_t maxSize = size - startPos;
00064 size_t maxSizePhrase = StaticData::Instance().GetMaxPhraseLength();
00065 maxSize = std::min(maxSize, maxSizePhrase);
00066
00067 for (size_t endPos = 0 ; endPos < maxSize ; ++endPos) {
00068 m_collection[startPos].push_back( TranslationOptionList() );
00069 }
00070 }
00071 }
00072
00074 TranslationOptionCollection::~TranslationOptionCollection()
00075 {
00076 RemoveAllInColl(m_unksrcs);
00077 }
00078
00079 void TranslationOptionCollection::Prune()
00080 {
00081
00082 if (m_maxNoTransOptPerCoverage == 0 && m_translationOptionThreshold == -std::numeric_limits<float>::infinity())
00083 return;
00084
00085
00086 size_t total = 0;
00087 size_t totalPruned = 0;
00088
00089
00090 size_t size = m_source.GetSize();
00091 for (size_t startPos = 0 ; startPos < size; ++startPos) {
00092 size_t maxSize = size - startPos;
00093 size_t maxSizePhrase = StaticData::Instance().GetMaxPhraseLength();
00094 maxSize = std::min(maxSize, maxSizePhrase);
00095
00096 for (size_t endPos = startPos ; endPos < startPos + maxSize ; ++endPos) {
00097
00098 TranslationOptionList &fullList = GetTranslationOptionList(startPos, endPos);
00099 total += fullList.size();
00100
00101
00102 if (m_maxNoTransOptPerCoverage > 0 &&
00103 fullList.size() > m_maxNoTransOptPerCoverage) {
00104
00105 nth_element(fullList.begin(), fullList.begin() + m_maxNoTransOptPerCoverage, fullList.end(), CompareTranslationOption);
00106 totalPruned += fullList.size() - m_maxNoTransOptPerCoverage;
00107
00108
00109 for (size_t i = m_maxNoTransOptPerCoverage ; i < fullList.size() ; ++i) {
00110 delete fullList.Get(i);
00111 }
00112 fullList.resize(m_maxNoTransOptPerCoverage);
00113 }
00114
00115
00116 if (fullList.size() > 1 && m_translationOptionThreshold != -std::numeric_limits<float>::infinity()) {
00117
00118 float bestScore = -std::numeric_limits<float>::infinity();
00119 for (size_t i=0; i < fullList.size() ; ++i) {
00120 if (fullList.Get(i)->GetFutureScore() > bestScore)
00121 bestScore = fullList.Get(i)->GetFutureScore();
00122 }
00123
00124
00125 for (size_t i=0; i < fullList.size() ; ++i) {
00126 if (fullList.Get(i)->GetFutureScore() < bestScore + m_translationOptionThreshold) {
00127
00128 delete fullList.Get(i);
00129 fullList.Remove(i);
00130 total--;
00131 totalPruned++;
00132 i--;
00133 }
00134
00135
00136
00137
00138 }
00139 }
00140 }
00141 }
00142
00143 VERBOSE(2," Total translation options: " << total << std::endl
00144 << "Total translation options pruned: " << totalPruned << std::endl);
00145 }
00146
00163 void TranslationOptionCollection::ProcessUnknownWord()
00164 {
00165 const vector<DecodeGraph*>& decodeGraphList = m_system->GetDecodeGraphs();
00166 size_t size = m_source.GetSize();
00167
00168 for (size_t graph = 0 ; graph < decodeGraphList.size() ; graph++) {
00169 const DecodeGraph &decodeGraph = *decodeGraphList[graph];
00170 for (size_t pos = 0 ; pos < size ; ++pos) {
00171 TranslationOptionList &fullList = GetTranslationOptionList(pos, pos);
00172 size_t numTransOpt = fullList.size();
00173 if (numTransOpt == 0) {
00174 CreateTranslationOptionsForRange(decodeGraph, pos, pos, false);
00175 }
00176 }
00177 }
00178
00179 bool alwaysCreateDirectTranslationOption = StaticData::Instance().IsAlwaysCreateDirectTranslationOption();
00180
00181 for (size_t pos = 0 ; pos < size ; ++pos) {
00182 TranslationOptionList &fullList = GetTranslationOptionList(pos, pos);
00183 if (fullList.size() == 0 || alwaysCreateDirectTranslationOption)
00184 ProcessUnknownWord(pos);
00185 }
00186 }
00187
00201 void TranslationOptionCollection::ProcessOneUnknownWord(const Word &sourceWord,size_t sourcePos, size_t length, const Scores *inputScores)
00202
00203 {
00204
00205 FactorCollection &factorCollection = FactorCollection::Instance();
00206
00207 size_t isDigit = 0;
00208
00209 const Factor *f = sourceWord[0];
00210 const string &s = f->GetString();
00211 bool isEpsilon = (s=="" || s==EPSILON);
00212 if (StaticData::Instance().GetDropUnknown()) {
00213
00214
00215 isDigit = s.find_first_of("0123456789");
00216 if (isDigit == 1)
00217 isDigit = 1;
00218 else
00219 isDigit = 0;
00220
00221 }
00222
00223 Phrase* m_unksrc = new Phrase(1);
00224 m_unksrc->AddWord() = sourceWord;
00225 m_unksrcs.push_back(m_unksrc);
00226
00227 TranslationOption *transOpt;
00228 TargetPhrase targetPhrase(Output);
00229 targetPhrase.SetSourcePhrase(m_unksrc);
00230 if (inputScores != NULL) {
00231 targetPhrase.SetScore(m_system,*inputScores);
00232 } else {
00233 targetPhrase.SetScore(m_system);
00234 }
00235
00236 if (!(StaticData::Instance().GetDropUnknown() || isEpsilon) || isDigit) {
00237
00238
00239 Word &targetWord = targetPhrase.AddWord();
00240
00241 for (unsigned int currFactor = 0 ; currFactor < MAX_NUM_FACTORS ; currFactor++) {
00242 FactorType factorType = static_cast<FactorType>(currFactor);
00243
00244 const Factor *sourceFactor = sourceWord[currFactor];
00245 if (sourceFactor == NULL)
00246 targetWord[factorType] = factorCollection.AddFactor(Output, factorType, UNKNOWN_FACTOR);
00247 else
00248 targetWord[factorType] = factorCollection.AddFactor(Output, factorType, sourceFactor->GetString());
00249 }
00250
00251
00252 targetPhrase.SetAlignmentInfo("0-0");
00253
00254 } else {
00255
00256
00257
00258
00259 }
00260 transOpt = new TranslationOption(WordsRange(sourcePos, sourcePos + length - 1), targetPhrase, m_source
00261 , m_system->GetUnknownWordPenaltyProducer());
00262 transOpt->CalcScore(m_system);
00263 Add(transOpt);
00264 }
00265
00270 void TranslationOptionCollection::CalcFutureScore()
00271 {
00272
00273 size_t size = m_source.GetSize();
00274
00275 for(size_t row=0; row<size; row++) {
00276 for(size_t col=row; col<size; col++) {
00277 m_futureScore.SetScore(row, col, -numeric_limits<float>::infinity());
00278 }
00279 }
00280
00281
00282 for (size_t startPos = 0 ; startPos < size ; ++startPos) {
00283 size_t maxSize = m_source.GetSize() - startPos;
00284 size_t maxSizePhrase = StaticData::Instance().GetMaxPhraseLength();
00285 maxSize = std::min(maxSize, maxSizePhrase);
00286
00287 for (size_t endPos = startPos ; endPos < startPos + maxSize ; ++endPos) {
00288 TranslationOptionList &transOptList = GetTranslationOptionList(startPos, endPos);
00289
00290 TranslationOptionList::const_iterator iterTransOpt;
00291 for(iterTransOpt = transOptList.begin() ; iterTransOpt != transOptList.end() ; ++iterTransOpt) {
00292 const TranslationOption &transOpt = **iterTransOpt;
00293 float score = transOpt.GetFutureScore();
00294 if (score > m_futureScore.GetScore(startPos, endPos))
00295 m_futureScore.SetScore(startPos, endPos, score);
00296 }
00297 }
00298 }
00299
00300
00301
00302
00303
00304
00305
00306
00307 for(size_t colstart = 1; colstart < size ; colstart++) {
00308 for(size_t diagshift = 0; diagshift < size-colstart ; diagshift++) {
00309 size_t startPos = diagshift;
00310 size_t endPos = colstart+diagshift;
00311 for(size_t joinAt = startPos; joinAt < endPos ; joinAt++) {
00312 float joinedScore = m_futureScore.GetScore(startPos, joinAt)
00313 + m_futureScore.GetScore(joinAt+1, endPos);
00314
00315
00316
00317
00318 if (joinedScore > m_futureScore.GetScore(startPos, endPos))
00319 m_futureScore.SetScore(startPos, endPos, joinedScore);
00320 }
00321 }
00322 }
00323
00324 IFVERBOSE(3) {
00325 int total = 0;
00326 for(size_t row=0; row<size; row++) {
00327 size_t maxSize = size - row;
00328 size_t maxSizePhrase = StaticData::Instance().GetMaxPhraseLength();
00329 maxSize = std::min(maxSize, maxSizePhrase);
00330
00331 for(size_t col=row; col<row+maxSize; col++) {
00332 int count = GetTranslationOptionList(row, col).size();
00333 TRACE_ERR( "translation options spanning from "
00334 << row <<" to "<< col <<" is "
00335 << count <<endl);
00336 total += count;
00337 }
00338 }
00339 TRACE_ERR( "translation options generated in total: "<< total << endl);
00340
00341 for(size_t row=0; row<size; row++)
00342 for(size_t col=row; col<size; col++)
00343 TRACE_ERR( "future cost from "<< row <<" to "<< col <<" is "<< m_futureScore.GetScore(row, col) <<endl);
00344 }
00345 }
00346
00347
00348
00355 void TranslationOptionCollection::CreateTranslationOptions()
00356 {
00357
00358
00359
00360
00361
00362
00363 const vector <DecodeGraph*> &decodeGraphList = m_system->GetDecodeGraphs();
00364 const vector <size_t> &decodeGraphBackoff = m_system->GetDecodeGraphBackoff();
00365
00366
00367 size_t size = m_source.GetSize();
00368
00369
00370 for (size_t graph = 0 ; graph < decodeGraphList.size() ; graph++) {
00371 if (decodeGraphList.size() > 1) {
00372 VERBOSE(3,"Creating translation options from decoding graph " << graph << endl);
00373 }
00374
00375 const DecodeGraph &decodeGraph = *decodeGraphList[graph];
00376
00377 for (size_t startPos = 0 ; startPos < size; startPos++) {
00378 size_t maxSize = size - startPos;
00379 size_t maxSizePhrase = StaticData::Instance().GetMaxPhraseLength();
00380 maxSize = std::min(maxSize, maxSizePhrase);
00381
00382
00383 for (size_t endPos = startPos ; endPos < startPos + maxSize ; endPos++) {
00384 if (graph > 0 &&
00385 decodeGraphBackoff[graph] != 0 &&
00386 (endPos-startPos+1 >= decodeGraphBackoff[graph] ||
00387 m_collection[startPos][endPos-startPos].size() > 0)) {
00388 VERBOSE(3,"No backoff to graph " << graph << " for span [" << startPos << ";" << endPos << "]" << endl);
00389
00390 continue;
00391 }
00392
00393
00394 CreateTranslationOptionsForRange( decodeGraph, startPos, endPos, true);
00395 }
00396 }
00397 }
00398
00399 VERBOSE(2,"Translation Option Collection\n " << *this << endl);
00400
00401 ProcessUnknownWord();
00402
00403
00404 Prune();
00405
00406 Sort();
00407
00408
00409 CalcFutureScore();
00410
00411
00412 CacheLexReordering();
00413 }
00414
00415 void TranslationOptionCollection::Sort()
00416 {
00417 size_t size = m_source.GetSize();
00418 for (size_t startPos = 0 ; startPos < size; ++startPos) {
00419 size_t maxSize = size - startPos;
00420 size_t maxSizePhrase = StaticData::Instance().GetMaxPhraseLength();
00421 maxSize = std::min(maxSize, maxSizePhrase);
00422
00423 for (size_t endPos = startPos ; endPos < startPos + maxSize; ++endPos) {
00424 TranslationOptionList &transOptList = GetTranslationOptionList(startPos, endPos);
00425 std::sort(transOptList.begin(), transOptList.end(), CompareTranslationOption);
00426 }
00427 }
00428 }
00429
00430
00439 void TranslationOptionCollection::CreateTranslationOptionsForRange(
00440 const DecodeGraph &decodeGraph
00441 , size_t startPos
00442 , size_t endPos
00443 , bool adhereTableLimit)
00444 {
00445 if ((StaticData::Instance().GetXmlInputType() != XmlExclusive) || !HasXmlOptionsOverlappingRange(startPos,endPos)) {
00446 Phrase *sourcePhrase = NULL;
00447
00448
00449 bool skipTransOptCreation = false
00450 , useCache = StaticData::Instance().GetUseTransOptCache();
00451 if (useCache) {
00452 const WordsRange wordsRange(startPos, endPos);
00453 sourcePhrase = new Phrase(m_source.GetSubString(wordsRange));
00454
00455 const TranslationOptionList *transOptList = StaticData::Instance().FindTransOptListInCache(decodeGraph, *sourcePhrase);
00456
00457 if (transOptList != NULL) {
00458 skipTransOptCreation = true;
00459 TranslationOptionList::const_iterator iterTransOpt;
00460 for (iterTransOpt = transOptList->begin() ; iterTransOpt != transOptList->end() ; ++iterTransOpt) {
00461 TranslationOption *transOpt = new TranslationOption(**iterTransOpt, wordsRange);
00462 Add(transOpt);
00463 }
00464 }
00465 }
00466
00467 if (!skipTransOptCreation) {
00468
00469 PartialTranslOptColl* oldPtoc = new PartialTranslOptColl;
00470 size_t totalEarlyPruned = 0;
00471
00472
00473 list <const DecodeStep* >::const_iterator iterStep = decodeGraph.begin();
00474 const DecodeStep &decodeStep = **iterStep;
00475
00476 static_cast<const DecodeStepTranslation&>(decodeStep).ProcessInitialTranslation
00477 (m_system, m_source, *oldPtoc
00478 , startPos, endPos, adhereTableLimit );
00479
00480
00481 int indexStep = 0;
00482 for (++iterStep ; iterStep != decodeGraph.end() ; ++iterStep) {
00483 const DecodeStep &decodeStep = **iterStep;
00484 PartialTranslOptColl* newPtoc = new PartialTranslOptColl;
00485
00486
00487 const vector<TranslationOption*>& partTransOptList = oldPtoc->GetList();
00488 vector<TranslationOption*>::const_iterator iterPartialTranslOpt;
00489 for (iterPartialTranslOpt = partTransOptList.begin() ; iterPartialTranslOpt != partTransOptList.end() ; ++iterPartialTranslOpt) {
00490 TranslationOption &inputPartialTranslOpt = **iterPartialTranslOpt;
00491 decodeStep.Process(m_system, inputPartialTranslOpt
00492 , decodeStep
00493 , *newPtoc
00494 , this
00495 , adhereTableLimit);
00496 }
00497
00498 totalEarlyPruned += newPtoc->GetPrunedCount();
00499 delete oldPtoc;
00500 oldPtoc = newPtoc;
00501 indexStep++;
00502 }
00503
00504
00505 PartialTranslOptColl &lastPartialTranslOptColl = *oldPtoc;
00506 const vector<TranslationOption*>& partTransOptList = lastPartialTranslOptColl.GetList();
00507 vector<TranslationOption*>::const_iterator iterColl;
00508 for (iterColl = partTransOptList.begin() ; iterColl != partTransOptList.end() ; ++iterColl) {
00509 TranslationOption *transOpt = *iterColl;
00510 transOpt->CalcScore(m_system);
00511 Add(transOpt);
00512 }
00513
00514
00515 if (useCache) {
00516 if (partTransOptList.size() > 0) {
00517 TranslationOptionList &transOptList = GetTranslationOptionList(startPos, endPos);
00518 StaticData::Instance().AddTransOptListToCache(decodeGraph, *sourcePhrase, transOptList);
00519 }
00520 }
00521
00522 lastPartialTranslOptColl.DetachAll();
00523 totalEarlyPruned += oldPtoc->GetPrunedCount();
00524 delete oldPtoc;
00525
00526 }
00527
00528 if (useCache)
00529 delete sourcePhrase;
00530 }
00531
00532 if ((StaticData::Instance().GetXmlInputType() != XmlPassThrough) && HasXmlOptionsOverlappingRange(startPos,endPos)) {
00533 CreateXmlOptionsForRange(startPos, endPos);
00534 }
00535 }
00536
00544 bool TranslationOptionCollection::HasXmlOptionsOverlappingRange(size_t, size_t) const
00545 {
00546 return false;
00547
00548 }
00549
00555 void TranslationOptionCollection::CreateXmlOptionsForRange(size_t, size_t)
00556 {
00557
00558 };
00559
00560
00561
00562
00565 void TranslationOptionCollection::Add(TranslationOption *translationOption)
00566 {
00567 const WordsRange &coverage = translationOption->GetSourceWordsRange();
00568 CHECK(coverage.GetEndPos() - coverage.GetStartPos() < m_collection[coverage.GetStartPos()].size());
00569 m_collection[coverage.GetStartPos()][coverage.GetEndPos() - coverage.GetStartPos()].Add(translationOption);
00570 }
00571
00572 TO_STRING_BODY(TranslationOptionCollection);
00573
00574 std::ostream& operator<<(std::ostream& out, const TranslationOptionCollection& coll)
00575 {
00576 size_t size = coll.GetSize();
00577 for (size_t startPos = 0 ; startPos < size ; ++startPos) {
00578 size_t maxSize = size - startPos;
00579 size_t maxSizePhrase = StaticData::Instance().GetMaxPhraseLength();
00580 maxSize = std::min(maxSize, maxSizePhrase);
00581
00582 for (size_t endPos = startPos ; endPos < startPos + maxSize ; ++endPos) {
00583 const TranslationOptionList& fullList = coll.GetTranslationOptionList(startPos, endPos);
00584 size_t sizeFull = fullList.size();
00585 for (size_t i = 0; i < sizeFull; i++) {
00586 out << *fullList.Get(i) << std::endl;
00587 }
00588 }
00589 }
00590
00591
00592
00593
00594
00595
00596
00597 return out;
00598 }
00599
00600 void TranslationOptionCollection::CacheLexReordering()
00601 {
00602 const vector<LexicalReordering*> &lexReorderingModels = m_system->GetReorderModels();
00603 std::vector<LexicalReordering*>::const_iterator iterLexreordering;
00604
00605 size_t size = m_source.GetSize();
00606 for (iterLexreordering = lexReorderingModels.begin() ; iterLexreordering != lexReorderingModels.end() ; ++iterLexreordering) {
00607 LexicalReordering &lexreordering = **iterLexreordering;
00608
00609 for (size_t startPos = 0 ; startPos < size ; startPos++) {
00610 size_t maxSize = size - startPos;
00611 size_t maxSizePhrase = StaticData::Instance().GetMaxPhraseLength();
00612 maxSize = std::min(maxSize, maxSizePhrase);
00613
00614 for (size_t endPos = startPos ; endPos < startPos + maxSize; endPos++) {
00615 TranslationOptionList &transOptList = GetTranslationOptionList( startPos, endPos);
00616 TranslationOptionList::iterator iterTransOpt;
00617 for(iterTransOpt = transOptList.begin() ; iterTransOpt != transOptList.end() ; ++iterTransOpt) {
00618 TranslationOption &transOpt = **iterTransOpt;
00619
00620 const Phrase *sourcePhrase = transOpt.GetSourcePhrase();
00621 if (sourcePhrase) {
00622 Scores score = lexreordering.GetProb(*sourcePhrase
00623 , transOpt.GetTargetPhrase());
00624 if (!score.empty())
00625 transOpt.CacheScores(lexreordering, score);
00626 }
00627 }
00628 }
00629 }
00630 }
00631 }
00633 TranslationOptionList &TranslationOptionCollection::GetTranslationOptionList(size_t startPos, size_t endPos)
00634 {
00635 size_t maxSize = endPos - startPos;
00636 size_t maxSizePhrase = StaticData::Instance().GetMaxPhraseLength();
00637 maxSize = std::min(maxSize, maxSizePhrase);
00638
00639 CHECK(maxSize < m_collection[startPos].size());
00640 return m_collection[startPos][maxSize];
00641 }
00642 const TranslationOptionList &TranslationOptionCollection::GetTranslationOptionList(size_t startPos, size_t endPos) const
00643 {
00644 size_t maxSize = endPos - startPos;
00645 size_t maxSizePhrase = StaticData::Instance().GetMaxPhraseLength();
00646 maxSize = std::min(maxSize, maxSizePhrase);
00647
00648 CHECK(maxSize < m_collection[startPos].size());
00649 return m_collection[startPos][maxSize];
00650 }
00651
00652 }
00653