00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022 #ifdef WIN32
00023 #include <hash_set>
00024 #else
00025 #include <ext/hash_set>
00026 #endif
00027
00028 #include <algorithm>
00029 #include <limits>
00030 #include <cmath>
00031 #include "Manager.h"
00032 #include "TypeDef.h"
00033 #include "Util.h"
00034 #include "TargetPhrase.h"
00035 #include "TrellisPath.h"
00036 #include "TrellisPathCollection.h"
00037 #include "TranslationOption.h"
00038 #include "LexicalReordering.h"
00039 #include "LMList.h"
00040 #include "TranslationOptionCollection.h"
00041 #include "DummyScoreProducers.h"
00042 #ifdef HAVE_PROTOBUF
00043 #include "hypergraph.pb.h"
00044 #include "rule.pb.h"
00045 #endif
00046
00047 using namespace std;
00048
00049 namespace Moses
00050 {
00051 Manager::Manager(InputType const& source, SearchAlgorithm searchAlgorithm, const TranslationSystem* system)
00052 :m_system(system)
00053 ,m_transOptColl(source.CreateTranslationOptionCollection(system))
00054 ,m_search(Search::CreateSearch(*this, source, searchAlgorithm, *m_transOptColl))
00055 ,m_start(clock())
00056 ,interrupted_flag(0)
00057 ,m_hypoId(0)
00058 ,m_source(source)
00059 {
00060 m_system->InitializeBeforeSentenceProcessing(source);
00061 }
00062
00063 Manager::~Manager()
00064 {
00065 delete m_transOptColl;
00066 delete m_search;
00067
00068 m_system->CleanUpAfterSentenceProcessing();
00069
00070 clock_t end = clock();
00071 float et = (end - m_start);
00072 et /= (float)CLOCKS_PER_SEC;
00073 VERBOSE(1, "Translation took " << et << " seconds" << endl);
00074 VERBOSE(1, "Finished translating" << endl);
00075 }
00076
00081 void Manager::ProcessSentence()
00082 {
00083
00084 ResetSentenceStats(m_source);
00085
00086
00087 m_system->InitializeBeforeSentenceProcessing(m_source);
00088 m_transOptColl->CreateTranslationOptions();
00089
00090
00091 clock_t gotOptions = clock();
00092 float et = (gotOptions - m_start);
00093 IFVERBOSE(2) {
00094 GetSentenceStats().AddTimeCollectOpts( gotOptions - m_start );
00095 }
00096 et /= (float)CLOCKS_PER_SEC;
00097 VERBOSE(1, "Collecting options took " << et << " seconds" << endl);
00098
00099
00100 m_search->ProcessSentence();
00101 VERBOSE(1, "Search took " << ((clock()-m_start)/(float)CLOCKS_PER_SEC) << " seconds" << endl);
00102 }
00103
00109 void Manager::PrintAllDerivations(long translationId, ostream& outputStream) const
00110 {
00111 const std::vector < HypothesisStack* > &hypoStackColl = m_search->GetHypothesisStacks();
00112
00113 vector<const Hypothesis*> sortedPureHypo = hypoStackColl.back()->GetSortedList();
00114
00115 if (sortedPureHypo.size() == 0)
00116 return;
00117
00118 float remainingScore = 0;
00119 vector<const TargetPhrase*> remainingPhrases;
00120
00121
00122 vector<const Hypothesis*>::const_iterator iterBestHypo;
00123 for (iterBestHypo = sortedPureHypo.begin()
00124 ; iterBestHypo != sortedPureHypo.end()
00125 ; ++iterBestHypo) {
00126 printThisHypothesis(translationId, *iterBestHypo, remainingPhrases, remainingScore, outputStream);
00127 printDivergentHypothesis(translationId, *iterBestHypo, remainingPhrases, remainingScore, outputStream);
00128 }
00129 }
00130
00131 const TranslationOptionCollection* Manager::getSntTranslationOptions()
00132 {
00133 return m_transOptColl;
00134 }
00135
00136 void Manager::printDivergentHypothesis(long translationId, const Hypothesis* hypo, const vector <const TargetPhrase*> & remainingPhrases, float remainingScore , ostream& outputStream ) const
00137 {
00138
00139 if (hypo->GetId() > 0) {
00140 vector <const TargetPhrase*> followingPhrases;
00141 followingPhrases.push_back(& (hypo->GetCurrTargetPhrase()));
00143 followingPhrases.insert(followingPhrases.end()--, remainingPhrases.begin(), remainingPhrases.end());
00144 printDivergentHypothesis(translationId, hypo->GetPrevHypo(), followingPhrases , remainingScore + hypo->GetScore() - hypo->GetPrevHypo()->GetScore(), outputStream);
00145 }
00146
00147
00148 const ArcList *pAL = hypo->GetArcList();
00149 if (pAL) {
00150 const ArcList &arcList = *pAL;
00151
00152 ArcList::const_iterator iterArc;
00153 for (iterArc = arcList.begin() ; iterArc != arcList.end() ; ++iterArc) {
00154 const Hypothesis *loserHypo = *iterArc;
00155 const Hypothesis* loserPrevHypo = loserHypo->GetPrevHypo();
00156 float arcScore = loserHypo->GetScore() - loserPrevHypo->GetScore();
00157 vector <const TargetPhrase* > followingPhrases;
00158 followingPhrases.push_back(&(loserHypo->GetCurrTargetPhrase()));
00159 followingPhrases.insert(followingPhrases.end()--, remainingPhrases.begin(), remainingPhrases.end());
00160 printThisHypothesis(translationId, loserPrevHypo, followingPhrases, remainingScore + arcScore, outputStream);
00161 printDivergentHypothesis(translationId, loserPrevHypo, followingPhrases, remainingScore + arcScore, outputStream);
00162 }
00163 }
00164 }
00165
00166
00167 void Manager::printThisHypothesis(long translationId, const Hypothesis* hypo, const vector <const TargetPhrase*> & remainingPhrases, float remainingScore, ostream& outputStream) const
00168 {
00169
00170 outputStream << translationId << " ||| ";
00171
00172
00173 hypo->ToStream(outputStream);
00174 for (size_t p = 0; p < remainingPhrases.size(); ++p) {
00175 const TargetPhrase * phrase = remainingPhrases[p];
00176 size_t size = phrase->GetSize();
00177 for (size_t pos = 0 ; pos < size ; pos++) {
00178 const Factor *factor = phrase->GetFactor(pos, 0);
00179 outputStream << *factor;
00180 outputStream << " ";
00181 }
00182 }
00183
00184 outputStream << "||| " << hypo->GetScore() + remainingScore;
00185 outputStream << endl;
00186 }
00187
00188
00189
00190
00200 void Manager::CalcNBest(size_t count, TrellisPathList &ret,bool onlyDistinct) const
00201 {
00202 if (count <= 0)
00203 return;
00204
00205 const std::vector < HypothesisStack* > &hypoStackColl = m_search->GetHypothesisStacks();
00206
00207 vector<const Hypothesis*> sortedPureHypo = hypoStackColl.back()->GetSortedList();
00208
00209 if (sortedPureHypo.size() == 0)
00210 return;
00211
00212 TrellisPathCollection contenders;
00213
00214 set<Phrase> distinctHyps;
00215
00216
00217 vector<const Hypothesis*>::const_iterator iterBestHypo;
00218 for (iterBestHypo = sortedPureHypo.begin()
00219 ; iterBestHypo != sortedPureHypo.end()
00220 ; ++iterBestHypo) {
00221 contenders.Add(new TrellisPath(*iterBestHypo));
00222 }
00223
00224
00225 size_t nBestFactor = StaticData::Instance().GetNBestFactor();
00226 if (nBestFactor < 1) nBestFactor = 1000;
00227
00228
00229 for (size_t iteration = 0 ; (onlyDistinct ? distinctHyps.size() : ret.GetSize()) < count && contenders.GetSize() > 0 && (iteration < count * nBestFactor) ; iteration++) {
00230
00231 TrellisPath *path = contenders.pop();
00232 CHECK(path);
00233
00234 path->CreateDeviantPaths(contenders);
00235 if(onlyDistinct) {
00236 Phrase tgtPhrase = path->GetSurfacePhrase();
00237 if (distinctHyps.insert(tgtPhrase).second) {
00238 ret.Add(path);
00239 } else {
00240 delete path;
00241 path = NULL;
00242 }
00243 } else {
00244 ret.Add(path);
00245 }
00246
00247
00248 if(onlyDistinct) {
00249 const size_t nBestFactor = StaticData::Instance().GetNBestFactor();
00250 if (nBestFactor > 0)
00251 contenders.Prune(count * nBestFactor);
00252 } else {
00253 contenders.Prune(count);
00254 }
00255 }
00256 }
00257
00258 struct SGNReverseCompare {
00259 bool operator() (const SearchGraphNode& s1, const SearchGraphNode& s2) const {
00260 return s1.hypo->GetId() > s2.hypo->GetId();
00261 }
00262 };
00263
00267 void Manager::CalcLatticeSamples(size_t count, TrellisPathList &ret) const {
00268
00269 vector<SearchGraphNode> searchGraph;
00270 GetSearchGraph(searchGraph);
00271
00272
00273
00274
00275 typedef pair<int, int> Edge;
00276 map<const Hypothesis*, float> sigmas;
00277 map<Edge, float> edgeScores;
00278 map<const Hypothesis*, set<const Hypothesis*> > outgoingHyps;
00279 map<int,const Hypothesis*> idToHyp;
00280 map<int,float> fscores;
00281
00282
00283
00284
00285
00286 sort(searchGraph.begin(), searchGraph.end(), SGNReverseCompare());
00287
00288
00289 for (vector<SearchGraphNode>::const_iterator i = searchGraph.begin();
00290 i != searchGraph.end(); ++i) {
00291 const Hypothesis* hypo = i->hypo;
00292 idToHyp[hypo->GetId()] = hypo;
00293 fscores[hypo->GetId()] = i->fscore;
00294 if (hypo->GetId()) {
00295
00296 const Hypothesis* prevHypo = i->hypo->GetPrevHypo();
00297 outgoingHyps[prevHypo].insert(hypo);
00298 edgeScores[Edge(prevHypo->GetId(),hypo->GetId())] =
00299 hypo->GetScore() - prevHypo->GetScore();
00300 }
00301
00302 if (i->forward >= 0) {
00303 map<int,const Hypothesis*>::const_iterator idToHypIter = idToHyp.find(i->forward);
00304 CHECK(idToHypIter != idToHyp.end());
00305 const Hypothesis* nextHypo = idToHypIter->second;
00306 outgoingHyps[hypo].insert(nextHypo);
00307 map<int,float>::const_iterator fscoreIter = fscores.find(nextHypo->GetId());
00308 CHECK(fscoreIter != fscores.end());
00309 edgeScores[Edge(hypo->GetId(),nextHypo->GetId())] =
00310 i->fscore - fscoreIter->second;
00311 }
00312 }
00313
00314
00315
00316 for (vector<SearchGraphNode>::const_iterator i = searchGraph.begin();
00317 i != searchGraph.end(); ++i) {
00318
00319 if (i->forward == -1) {
00320 sigmas[i->hypo] = 0;
00321 } else {
00322 map<const Hypothesis*, set<const Hypothesis*> >::const_iterator outIter =
00323 outgoingHyps.find(i->hypo);
00324
00325 CHECK(outIter != outgoingHyps.end());
00326 float sigma = 0;
00327 for (set<const Hypothesis*>::const_iterator j = outIter->second.begin();
00328 j != outIter->second.end(); ++j) {
00329 map<const Hypothesis*, float>::const_iterator succIter = sigmas.find(*j);
00330 CHECK(succIter != sigmas.end());
00331 map<Edge,float>::const_iterator edgeScoreIter =
00332 edgeScores.find(Edge(i->hypo->GetId(),(*j)->GetId()));
00333 CHECK(edgeScoreIter != edgeScores.end());
00334 float term = edgeScoreIter->second + succIter->second;
00335 if (sigma == 0) {
00336 sigma = term;
00337 } else {
00338 sigma = log_sum(sigma,term);
00339 }
00340 }
00341 sigmas[i->hypo] = sigma;
00342 }
00343 }
00344
00345
00346 const Hypothesis* startHypo = searchGraph.back().hypo;
00347 CHECK(startHypo->GetId() == 0);
00348 for (size_t i = 0; i < count; ++i) {
00349 vector<const Hypothesis*> path;
00350 path.push_back(startHypo);
00351 while(1) {
00352 map<const Hypothesis*, set<const Hypothesis*> >::const_iterator outIter =
00353 outgoingHyps.find(path.back());
00354 if (outIter == outgoingHyps.end() || !outIter->second.size()) {
00355
00356 break;
00357 }
00358
00359 vector<const Hypothesis*> candidates;
00360 vector<float> candidateScores;
00361 float scoreTotal = 0;
00362 for (set<const Hypothesis*>::const_iterator j = outIter->second.begin();
00363 j != outIter->second.end(); ++j) {
00364 candidates.push_back(*j);
00365 CHECK(sigmas.find(*j) != sigmas.end());
00366 Edge edge(path.back()->GetId(),(*j)->GetId());
00367 CHECK(edgeScores.find(edge) != edgeScores.end());
00368 candidateScores.push_back(sigmas[*j] + edgeScores[edge]);
00369 if (scoreTotal == 0) {
00370 scoreTotal = candidateScores.back();
00371 } else {
00372 scoreTotal = log_sum(candidateScores.back(), scoreTotal);
00373 }
00374 }
00375
00376
00377 transform(candidateScores.begin(), candidateScores.end(), candidateScores.begin(), bind2nd(minus<float>(),scoreTotal));
00378
00379
00380
00381
00382 float random = log((float)rand()/RAND_MAX);
00383 size_t position = 1;
00384 float sum = candidateScores[0];
00385 for (; position < candidateScores.size() && sum < random; ++position) {
00386 sum = log_sum(sum,candidateScores[position]);
00387 }
00388
00389 const Hypothesis* chosen = candidates[position-1];
00390 path.push_back(chosen);
00391 }
00392
00393
00394
00395
00396
00397
00398
00399 ret.Add(new TrellisPath(path));
00400
00401 }
00402
00403 }
00404
00405
00406
00407 void Manager::CalcDecoderStatistics() const
00408 {
00409 const Hypothesis *hypo = GetBestHypothesis();
00410 if (hypo != NULL) {
00411 GetSentenceStats().CalcFinalStats(*hypo);
00412 IFVERBOSE(2) {
00413 if (hypo != NULL) {
00414 string buff;
00415 string buff2;
00416 TRACE_ERR( "Source and Target Units:"
00417 << hypo->GetInput());
00418 buff2.insert(0,"] ");
00419 buff2.insert(0,(hypo->GetCurrTargetPhrase()).ToString());
00420 buff2.insert(0,":");
00421 buff2.insert(0,(hypo->GetCurrSourceWordsRange()).ToString());
00422 buff2.insert(0,"[");
00423
00424 hypo = hypo->GetPrevHypo();
00425 while (hypo != NULL) {
00426
00427 buff.insert(0,buff2);
00428 buff2.clear();
00429 buff2.insert(0,"] ");
00430 buff2.insert(0,(hypo->GetCurrTargetPhrase()).ToString());
00431 buff2.insert(0,":");
00432 buff2.insert(0,(hypo->GetCurrSourceWordsRange()).ToString());
00433 buff2.insert(0,"[");
00434 hypo = hypo->GetPrevHypo();
00435 }
00436 TRACE_ERR( buff << endl);
00437 }
00438 }
00439 }
00440 }
00441
00442 void OutputWordGraph(std::ostream &outputWordGraphStream, const Hypothesis *hypo, size_t &linkId, const TranslationSystem* system)
00443 {
00444
00445 const Hypothesis *prevHypo = hypo->GetPrevHypo();
00446
00447
00448 outputWordGraphStream << "J=" << linkId++
00449 << "\tS=" << prevHypo->GetId()
00450 << "\tE=" << hypo->GetId()
00451 << "\ta=";
00452
00453
00454 const std::vector<PhraseDictionaryFeature*> &phraseTables = system->GetPhraseDictionaries();
00455 std::vector<PhraseDictionaryFeature*>::const_iterator iterPhraseTable;
00456 for (iterPhraseTable = phraseTables.begin() ; iterPhraseTable != phraseTables.end() ; ++iterPhraseTable) {
00457 const PhraseDictionaryFeature *phraseTable = *iterPhraseTable;
00458 vector<float> scores = hypo->GetScoreBreakdown().GetScoresForProducer(phraseTable);
00459
00460 outputWordGraphStream << scores[0];
00461 vector<float>::const_iterator iterScore;
00462 for (iterScore = ++scores.begin() ; iterScore != scores.end() ; ++iterScore) {
00463 outputWordGraphStream << ", " << *iterScore;
00464 }
00465 }
00466
00467
00468 outputWordGraphStream << "\tl=";
00469 const LMList &lmList = system->GetLanguageModels();
00470 LMList::const_iterator iterLM;
00471 for (iterLM = lmList.begin() ; iterLM != lmList.end() ; ++iterLM) {
00472 LanguageModel *lm = *iterLM;
00473 vector<float> scores = hypo->GetScoreBreakdown().GetScoresForProducer(lm);
00474
00475 outputWordGraphStream << scores[0];
00476 vector<float>::const_iterator iterScore;
00477 for (iterScore = ++scores.begin() ; iterScore != scores.end() ; ++iterScore) {
00478 outputWordGraphStream << ", " << *iterScore;
00479 }
00480 }
00481
00482
00483 outputWordGraphStream << "\tr=";
00484
00485 outputWordGraphStream << hypo->GetScoreBreakdown().GetScoreForProducer(system->GetDistortionProducer());
00486
00487
00488 const std::vector<LexicalReordering*> &lexOrderings = system->GetReorderModels();
00489 std::vector<LexicalReordering*>::const_iterator iterLexOrdering;
00490 for (iterLexOrdering = lexOrderings.begin() ; iterLexOrdering != lexOrderings.end() ; ++iterLexOrdering) {
00491 LexicalReordering *lexicalReordering = *iterLexOrdering;
00492 vector<float> scores = hypo->GetScoreBreakdown().GetScoresForProducer(lexicalReordering);
00493
00494 outputWordGraphStream << scores[0];
00495 vector<float>::const_iterator iterScore;
00496 for (iterScore = ++scores.begin() ; iterScore != scores.end() ; ++iterScore) {
00497 outputWordGraphStream << ", " << *iterScore;
00498 }
00499 }
00500
00501
00502
00503
00504
00505 outputWordGraphStream << "\tw=" << hypo->GetSourcePhraseStringRep() << "|" << hypo->GetCurrTargetPhrase();
00506
00507 outputWordGraphStream << endl;
00508 }
00509
00510 void Manager::GetWordGraph(long translationId, std::ostream &outputWordGraphStream) const
00511 {
00512 const StaticData &staticData = StaticData::Instance();
00513 string fileName = staticData.GetParam("output-word-graph")[0];
00514 bool outputNBest = Scan<bool>(staticData.GetParam("output-word-graph")[1]);
00515 const std::vector < HypothesisStack* > &hypoStackColl = m_search->GetHypothesisStacks();
00516
00517 outputWordGraphStream << "VERSION=1.0" << endl
00518 << "UTTERANCE=" << translationId << endl;
00519
00520 size_t linkId = 0;
00521 size_t stackNo = 1;
00522 std::vector < HypothesisStack* >::const_iterator iterStack;
00523 for (iterStack = ++hypoStackColl.begin() ; iterStack != hypoStackColl.end() ; ++iterStack) {
00524 cerr << endl << stackNo++ << endl;
00525 const HypothesisStack &stack = **iterStack;
00526 HypothesisStack::const_iterator iterHypo;
00527 for (iterHypo = stack.begin() ; iterHypo != stack.end() ; ++iterHypo) {
00528 const Hypothesis *hypo = *iterHypo;
00529 OutputWordGraph(outputWordGraphStream, hypo, linkId, m_system);
00530
00531 if (outputNBest) {
00532 const ArcList *arcList = hypo->GetArcList();
00533 if (arcList != NULL) {
00534 ArcList::const_iterator iterArcList;
00535 for (iterArcList = arcList->begin() ; iterArcList != arcList->end() ; ++iterArcList) {
00536 const Hypothesis *loserHypo = *iterArcList;
00537 OutputWordGraph(outputWordGraphStream, loserHypo, linkId,m_system);
00538 }
00539 }
00540 }
00541 }
00542 }
00543 }
00544
00545 void Manager::GetSearchGraph(vector<SearchGraphNode>& searchGraph) const
00546 {
00547 std::map < int, bool > connected;
00548 std::map < int, int > forward;
00549 std::map < int, double > forwardScore;
00550
00551
00552 std::vector< const Hypothesis *> connectedList;
00553 GetConnectedGraph(&connected, &connectedList);
00554
00555
00556
00557
00558 const std::vector < HypothesisStack* > &hypoStackColl = m_search->GetHypothesisStacks();
00559 const HypothesisStack &finalStack = *hypoStackColl.back();
00560 HypothesisStack::const_iterator iterHypo;
00561 for (iterHypo = finalStack.begin() ; iterHypo != finalStack.end() ; ++iterHypo) {
00562 const Hypothesis *hypo = *iterHypo;
00563 forwardScore[ hypo->GetId() ] = 0.0f;
00564 forward[ hypo->GetId() ] = -1;
00565 }
00566
00567
00568 std::vector < HypothesisStack* >::const_iterator iterStack;
00569 for (iterStack = --hypoStackColl.end() ; iterStack != hypoStackColl.begin() ; --iterStack) {
00570 const HypothesisStack &stack = **iterStack;
00571 HypothesisStack::const_iterator iterHypo;
00572 for (iterHypo = stack.begin() ; iterHypo != stack.end() ; ++iterHypo) {
00573 const Hypothesis *hypo = *iterHypo;
00574 if (connected.find( hypo->GetId() ) != connected.end()) {
00575
00576 const Hypothesis *prevHypo = hypo->GetPrevHypo();
00577 double fscore = forwardScore[ hypo->GetId() ] +
00578 hypo->GetScore() - prevHypo->GetScore();
00579 if (forwardScore.find( prevHypo->GetId() ) == forwardScore.end()
00580 || forwardScore.find( prevHypo->GetId() )->second < fscore) {
00581 forwardScore[ prevHypo->GetId() ] = fscore;
00582 forward[ prevHypo->GetId() ] = hypo->GetId();
00583 }
00584
00585 const ArcList *arcList = hypo->GetArcList();
00586 if (arcList != NULL) {
00587 ArcList::const_iterator iterArcList;
00588 for (iterArcList = arcList->begin() ; iterArcList != arcList->end() ; ++iterArcList) {
00589 const Hypothesis *loserHypo = *iterArcList;
00590
00591 const Hypothesis *loserPrevHypo = loserHypo->GetPrevHypo();
00592 double fscore = forwardScore[ hypo->GetId() ] +
00593 loserHypo->GetScore() - loserPrevHypo->GetScore();
00594 if (forwardScore.find( loserPrevHypo->GetId() ) == forwardScore.end()
00595 || forwardScore.find( loserPrevHypo->GetId() )->second < fscore) {
00596 forwardScore[ loserPrevHypo->GetId() ] = fscore;
00597 forward[ loserPrevHypo->GetId() ] = loserHypo->GetId();
00598 }
00599 }
00600 }
00601 }
00602 }
00603 }
00604
00605
00606
00607 connected[ 0 ] = true;
00608 for (iterStack = hypoStackColl.begin() ; iterStack != hypoStackColl.end() ; ++iterStack) {
00609 const HypothesisStack &stack = **iterStack;
00610 HypothesisStack::const_iterator iterHypo;
00611 for (iterHypo = stack.begin() ; iterHypo != stack.end() ; ++iterHypo) {
00612 const Hypothesis *hypo = *iterHypo;
00613 if (connected.find( hypo->GetId() ) != connected.end()) {
00614 searchGraph.push_back(SearchGraphNode(hypo,NULL,forward[hypo->GetId()],
00615 forwardScore[hypo->GetId()]));
00616
00617 const ArcList *arcList = hypo->GetArcList();
00618 if (arcList != NULL) {
00619 ArcList::const_iterator iterArcList;
00620 for (iterArcList = arcList->begin() ; iterArcList != arcList->end() ; ++iterArcList) {
00621 const Hypothesis *loserHypo = *iterArcList;
00622 searchGraph.push_back(SearchGraphNode(loserHypo,hypo,
00623 forward[hypo->GetId()], forwardScore[hypo->GetId()]));
00624 }
00625 }
00626 }
00627 }
00628 }
00629
00630 }
00631
00632 void OutputSearchNode(long translationId, std::ostream &outputSearchGraphStream,
00633 const SearchGraphNode& searchNode)
00634 {
00635 const vector<FactorType> &outputFactorOrder = StaticData::Instance().GetOutputFactorOrder();
00636 bool extendedFormat = StaticData::Instance().GetOutputSearchGraphExtended();
00637 outputSearchGraphStream << translationId;
00638
00639
00640 if ( searchNode.hypo->GetId() == 0 ) {
00641 outputSearchGraphStream << " hyp=0 stack=0";
00642 if (!extendedFormat) {
00643 outputSearchGraphStream << " forward=" << searchNode.forward << " fscore=" << searchNode.fscore;
00644 }
00645 outputSearchGraphStream << endl;
00646 return;
00647 }
00648
00649 const Hypothesis *prevHypo = searchNode.hypo->GetPrevHypo();
00650
00651
00652 if (!extendedFormat) {
00653 outputSearchGraphStream << " hyp=" << searchNode.hypo->GetId()
00654 << " stack=" << searchNode.hypo->GetWordsBitmap().GetNumWordsCovered()
00655 << " back=" << prevHypo->GetId()
00656 << " score=" << searchNode.hypo->GetScore()
00657 << " transition=" << (searchNode.hypo->GetScore() - prevHypo->GetScore());
00658
00659 if (searchNode.recombinationHypo != NULL)
00660 outputSearchGraphStream << " recombined=" << searchNode.recombinationHypo->GetId();
00661
00662 outputSearchGraphStream << " forward=" << searchNode.forward << " fscore=" << searchNode.fscore
00663 << " covered=" << searchNode.hypo->GetCurrSourceWordsRange().GetStartPos()
00664 << "-" << searchNode.hypo->GetCurrSourceWordsRange().GetEndPos()
00665 << " out=" << searchNode.hypo->GetCurrTargetPhrase().GetStringRep(outputFactorOrder)
00666 << endl;
00667 return;
00668 }
00669
00670
00671 if (searchNode.recombinationHypo != NULL)
00672 outputSearchGraphStream << " hyp=" << searchNode.recombinationHypo->GetId();
00673 else
00674 outputSearchGraphStream << " hyp=" << searchNode.hypo->GetId();
00675
00676 outputSearchGraphStream << " stack=" << searchNode.hypo->GetWordsBitmap().GetNumWordsCovered()
00677 << " back=" << prevHypo->GetId()
00678 << " score=" << searchNode.hypo->GetScore()
00679 << " transition=" << (searchNode.hypo->GetScore() - prevHypo->GetScore());
00680
00681 if (searchNode.recombinationHypo != NULL)
00682 outputSearchGraphStream << " recombined=" << searchNode.recombinationHypo->GetId();
00683
00684 outputSearchGraphStream << " forward=" << searchNode.forward << " fscore=" << searchNode.fscore
00685 << " covered=" << searchNode.hypo->GetCurrSourceWordsRange().GetStartPos()
00686 << "-" << searchNode.hypo->GetCurrSourceWordsRange().GetEndPos();
00687
00688
00689 ScoreComponentCollection scoreBreakdown = searchNode.hypo->GetScoreBreakdown();
00690 scoreBreakdown.MinusEquals( prevHypo->GetScoreBreakdown() );
00691 outputSearchGraphStream << " scores=[ ";
00692 StaticData::Instance().GetScoreIndexManager().PrintLabeledScores( outputSearchGraphStream, scoreBreakdown );
00693 outputSearchGraphStream << " ]";
00694
00695
00696 outputSearchGraphStream << " out=" << searchNode.hypo->GetSourcePhraseStringRep() << "|" <<
00697 searchNode.hypo->GetCurrTargetPhrase().GetStringRep(outputFactorOrder) << endl;
00698
00699 }
00700
00701 void Manager::GetConnectedGraph(
00702 std::map< int, bool >* pConnected,
00703 std::vector< const Hypothesis* >* pConnectedList) const
00704 {
00705 std::map < int, bool >& connected = *pConnected;
00706 std::vector< const Hypothesis *>& connectedList = *pConnectedList;
00707
00708
00709 const std::vector < HypothesisStack* > &hypoStackColl = m_search->GetHypothesisStacks();
00710 const HypothesisStack &finalStack = *hypoStackColl.back();
00711 HypothesisStack::const_iterator iterHypo;
00712 for (iterHypo = finalStack.begin() ; iterHypo != finalStack.end() ; ++iterHypo) {
00713 const Hypothesis *hypo = *iterHypo;
00714 connected[ hypo->GetId() ] = true;
00715 connectedList.push_back( hypo );
00716 }
00717
00718 for(size_t i=0; i<connectedList.size(); i++) {
00719 const Hypothesis *hypo = connectedList[i];
00720
00721
00722 const Hypothesis *prevHypo = hypo->GetPrevHypo();
00723 if (prevHypo && prevHypo->GetId() > 0
00724 && connected.find( prevHypo->GetId() ) == connected.end()) {
00725 connected[ prevHypo->GetId() ] = true;
00726 connectedList.push_back( prevHypo );
00727 }
00728
00729
00730 const ArcList *arcList = hypo->GetArcList();
00731 if (arcList != NULL) {
00732 ArcList::const_iterator iterArcList;
00733 for (iterArcList = arcList->begin() ; iterArcList != arcList->end() ; ++iterArcList) {
00734 const Hypothesis *loserHypo = *iterArcList;
00735 if (connected.find( loserHypo->GetId() ) == connected.end()) {
00736 connected[ loserHypo->GetId() ] = true;
00737 connectedList.push_back( loserHypo );
00738 }
00739 }
00740 }
00741 }
00742 }
00743
00744 void Manager::GetWinnerConnectedGraph(
00745 std::map< int, bool >* pConnected,
00746 std::vector< const Hypothesis* >* pConnectedList) const
00747 {
00748 std::map < int, bool >& connected = *pConnected;
00749 std::vector< const Hypothesis *>& connectedList = *pConnectedList;
00750
00751
00752 const std::vector < HypothesisStack* > &hypoStackColl = m_search->GetHypothesisStacks();
00753 const HypothesisStack &finalStack = *hypoStackColl.back();
00754 HypothesisStack::const_iterator iterHypo;
00755 for (iterHypo = finalStack.begin() ; iterHypo != finalStack.end() ; ++iterHypo) {
00756 const Hypothesis *hypo = *iterHypo;
00757 connected[ hypo->GetId() ] = true;
00758 connectedList.push_back( hypo );
00759 }
00760
00761
00762 for(size_t i=0; i<connectedList.size(); i++) {
00763 const Hypothesis *hypo = connectedList[i];
00764
00765
00766 const Hypothesis *prevHypo = hypo->GetPrevHypo();
00767 if (prevHypo->GetId() > 0
00768 && connected.find( prevHypo->GetId() ) == connected.end()) {
00769 connected[ prevHypo->GetId() ] = true;
00770 connectedList.push_back( prevHypo );
00771 }
00772
00773
00774 const ArcList *arcList = hypo->GetArcList();
00775 if (arcList != NULL) {
00776 ArcList::const_iterator iterArcList;
00777 for (iterArcList = arcList->begin() ; iterArcList != arcList->end() ; ++iterArcList) {
00778 const Hypothesis *loserHypo = *iterArcList;
00779 if (connected.find( loserHypo->GetPrevHypo()->GetId() ) == connected.end() && loserHypo->GetPrevHypo()->GetId() > 0) {
00780 connected[ loserHypo->GetPrevHypo()->GetId() ] = true;
00781 connectedList.push_back( loserHypo->GetPrevHypo() );
00782 }
00783 }
00784 }
00785 }
00786 }
00787
00788
00789 #ifdef HAVE_PROTOBUF
00790
00791 void SerializeEdgeInfo(const Hypothesis* hypo, hgmert::Hypergraph_Edge* edge)
00792 {
00793 hgmert::Rule* rule = edge->mutable_rule();
00794 hypo->GetCurrTargetPhrase().WriteToRulePB(rule);
00795 const Hypothesis* prev = hypo->GetPrevHypo();
00796
00797 if (!prev) return;
00798
00799
00800 const ScoreComponentCollection& scores = hypo->GetScoreBreakdown();
00801 const ScoreComponentCollection& pscores = prev->GetScoreBreakdown();
00802 for (unsigned int i = 0; i < scores.size(); ++i)
00803 edge->add_feature_values((scores[i] - pscores[i]) * -1.0);
00804 }
00805
00806 hgmert::Hypergraph_Node* GetHGNode(
00807 const Hypothesis* hypo,
00808 std::map< int, int>* i2hgnode,
00809 hgmert::Hypergraph* hg,
00810 int* hgNodeIdx)
00811 {
00812 hgmert::Hypergraph_Node* hgnode;
00813 std::map < int, int >::iterator idxi = i2hgnode->find(hypo->GetId());
00814 if (idxi == i2hgnode->end()) {
00815 *hgNodeIdx = ((*i2hgnode)[hypo->GetId()] = hg->nodes_size());
00816 hgnode = hg->add_nodes();
00817 } else {
00818 *hgNodeIdx = idxi->second;
00819 hgnode = hg->mutable_nodes(*hgNodeIdx);
00820 }
00821 return hgnode;
00822 }
00823
00824 void Manager::SerializeSearchGraphPB(
00825 long translationId,
00826 std::ostream& outputStream) const
00827 {
00828 using namespace hgmert;
00829 std::map < int, bool > connected;
00830 std::map < int, int > i2hgnode;
00831 std::vector< const Hypothesis *> connectedList;
00832 GetConnectedGraph(&connected, &connectedList);
00833 connected[ 0 ] = true;
00834 Hypergraph hg;
00835 hg.set_is_sorted(false);
00836 int num_feats = (*m_search->GetHypothesisStacks().back()->begin())->GetScoreBreakdown().size();
00837 hg.set_num_features(num_feats);
00838 StaticData::Instance().GetScoreIndexManager().SerializeFeatureNamesToPB(&hg);
00839 Hypergraph_Node* goal = hg.add_nodes();
00840 Hypergraph_Node* source = hg.add_nodes();
00841 i2hgnode[-1] = 1;
00842 const std::vector < HypothesisStack* > &hypoStackColl = m_search->GetHypothesisStacks();
00843 const HypothesisStack &finalStack = *hypoStackColl.back();
00844 for (std::vector < HypothesisStack* >::const_iterator iterStack = hypoStackColl.begin();
00845 iterStack != hypoStackColl.end() ; ++iterStack) {
00846 const HypothesisStack &stack = **iterStack;
00847 HypothesisStack::const_iterator iterHypo;
00848
00849 for (iterHypo = stack.begin() ; iterHypo != stack.end() ; ++iterHypo) {
00850 const Hypothesis *hypo = *iterHypo;
00851 bool is_goal = hypo->GetWordsBitmap().IsComplete();
00852 if (connected.find( hypo->GetId() ) != connected.end()) {
00853 int headNodeIdx;
00854 Hypergraph_Node* headNode = GetHGNode(hypo, &i2hgnode, &hg, &headNodeIdx);
00855 if (is_goal) {
00856 Hypergraph_Edge* ge = hg.add_edges();
00857 ge->set_head_node(0);
00858 ge->add_tail_nodes(headNodeIdx);
00859 ge->mutable_rule()->add_trg_words("[X,1]");
00860 }
00861 Hypergraph_Edge* edge = hg.add_edges();
00862 SerializeEdgeInfo(hypo, edge);
00863 edge->set_head_node(headNodeIdx);
00864 const Hypothesis* prev = hypo->GetPrevHypo();
00865 int tailNodeIdx = 1;
00866 if (prev)
00867 tailNodeIdx = i2hgnode.find(prev->GetId())->second;
00868 edge->add_tail_nodes(tailNodeIdx);
00869
00870 const ArcList *arcList = hypo->GetArcList();
00871 if (arcList != NULL) {
00872 ArcList::const_iterator iterArcList;
00873 for (iterArcList = arcList->begin() ; iterArcList != arcList->end() ; ++iterArcList) {
00874 const Hypothesis *loserHypo = *iterArcList;
00875 CHECK(connected[loserHypo->GetId()]);
00876 Hypergraph_Edge* edge = hg.add_edges();
00877 SerializeEdgeInfo(loserHypo, edge);
00878 edge->set_head_node(headNodeIdx);
00879 tailNodeIdx = i2hgnode.find(loserHypo->GetPrevHypo()->GetId())->second;
00880 edge->add_tail_nodes(tailNodeIdx);
00881 }
00882 }
00883 }
00884 }
00885 }
00886 hg.SerializeToOstream(&outputStream);
00887 }
00888 #endif
00889
00890 void Manager::OutputSearchGraph(long translationId, std::ostream &outputSearchGraphStream) const
00891 {
00892 vector<SearchGraphNode> searchGraph;
00893 GetSearchGraph(searchGraph);
00894 for (size_t i = 0; i < searchGraph.size(); ++i) {
00895 OutputSearchNode(translationId,outputSearchGraphStream,searchGraph[i]);
00896 }
00897 }
00898
00899 void Manager::GetForwardBackwardSearchGraph(std::map< int, bool >* pConnected,
00900 std::vector< const Hypothesis* >* pConnectedList, std::map < const Hypothesis*, set< const Hypothesis* > >* pOutgoingHyps, vector< float>* pFwdBwdScores) const
00901 {
00902 std::map < int, bool > &connected = *pConnected;
00903 std::vector< const Hypothesis *>& connectedList = *pConnectedList;
00904 std::map < int, int > forward;
00905 std::map < int, double > forwardScore;
00906
00907 std::map < const Hypothesis*, set <const Hypothesis*> > & outgoingHyps = *pOutgoingHyps;
00908 vector< float> & estimatedScores = *pFwdBwdScores;
00909
00910
00911 GetWinnerConnectedGraph(&connected, &connectedList);
00912
00913
00914
00915
00916 const std::vector < HypothesisStack* > &hypoStackColl = m_search->GetHypothesisStacks();
00917 const HypothesisStack &finalStack = *hypoStackColl.back();
00918 HypothesisStack::const_iterator iterHypo;
00919 for (iterHypo = finalStack.begin() ; iterHypo != finalStack.end() ; ++iterHypo) {
00920 const Hypothesis *hypo = *iterHypo;
00921 forwardScore[ hypo->GetId() ] = 0.0f;
00922 forward[ hypo->GetId() ] = -1;
00923 }
00924
00925
00926 std::vector < HypothesisStack* >::const_iterator iterStack;
00927 for (iterStack = --hypoStackColl.end() ; iterStack != hypoStackColl.begin() ; --iterStack) {
00928 const HypothesisStack &stack = **iterStack;
00929 HypothesisStack::const_iterator iterHypo;
00930 for (iterHypo = stack.begin() ; iterHypo != stack.end() ; ++iterHypo) {
00931 const Hypothesis *hypo = *iterHypo;
00932 if (connected.find( hypo->GetId() ) != connected.end()) {
00933
00934 const Hypothesis *prevHypo = hypo->GetPrevHypo();
00935 double fscore = forwardScore[ hypo->GetId() ] +
00936 hypo->GetScore() - prevHypo->GetScore();
00937 if (forwardScore.find( prevHypo->GetId() ) == forwardScore.end()
00938 || forwardScore.find( prevHypo->GetId() )->second < fscore) {
00939 forwardScore[ prevHypo->GetId() ] = fscore;
00940 forward[ prevHypo->GetId() ] = hypo->GetId();
00941 }
00942
00943 outgoingHyps[prevHypo].insert(hypo);
00944
00945
00946 const ArcList *arcList = hypo->GetArcList();
00947 if (arcList != NULL) {
00948 ArcList::const_iterator iterArcList;
00949 for (iterArcList = arcList->begin() ; iterArcList != arcList->end() ; ++iterArcList) {
00950 const Hypothesis *loserHypo = *iterArcList;
00951
00952 const Hypothesis *loserPrevHypo = loserHypo->GetPrevHypo();
00953 double fscore = forwardScore[ hypo->GetId() ] +
00954 loserHypo->GetScore() - loserPrevHypo->GetScore();
00955 if (forwardScore.find( loserPrevHypo->GetId() ) == forwardScore.end()
00956 || forwardScore.find( loserPrevHypo->GetId() )->second < fscore) {
00957 forwardScore[ loserPrevHypo->GetId() ] = fscore;
00958 forward[ loserPrevHypo->GetId() ] = loserHypo->GetId();
00959 }
00960
00961 outgoingHyps[loserPrevHypo].insert(hypo);
00962
00963
00964 }
00965 }
00966 }
00967 }
00968 }
00969
00970 for (std::vector< const Hypothesis *>::iterator it = connectedList.begin(); it != connectedList.end(); ++it) {
00971 float estimatedScore = (*it)->GetScore() + forwardScore[(*it)->GetId()];
00972 estimatedScores.push_back(estimatedScore);
00973 }
00974 }
00975
00976
00977 const Hypothesis *Manager::GetBestHypothesis() const
00978 {
00979 return m_search->GetBestHypothesis();
00980 }
00981
00982 int Manager::GetNextHypoId()
00983 {
00984 return m_hypoId++;
00985 }
00986
00987 void Manager::ResetSentenceStats(const InputType& source)
00988 {
00989 m_sentenceStats = std::auto_ptr<SentenceStats>(new SentenceStats(source));
00990 }
00991 SentenceStats& Manager::GetSentenceStats() const
00992 {
00993 return *m_sentenceStats;
00994
00995 }
00996
00997 }