00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022 #ifdef WIN32
00023 #include <hash_set>
00024 #else
00025 #include <ext/hash_set>
00026 #endif
00027
00028 #include <algorithm>
00029 #include <cmath>
00030 #include <limits>
00031 #include <map>
00032 #include <set>
00033 #include "Manager.h"
00034 #include "TypeDef.h"
00035 #include "Util.h"
00036 #include "TargetPhrase.h"
00037 #include "TrellisPath.h"
00038 #include "TrellisPathCollection.h"
00039 #include "TranslationOption.h"
00040 #include "LexicalReordering.h"
00041 #include "LMList.h"
00042 #include "TranslationOptionCollection.h"
00043 #include "DummyScoreProducers.h"
00044 #include "Timer.h"
00045
00046 #ifdef HAVE_PROTOBUF
00047 #include "hypergraph.pb.h"
00048 #include "rule.pb.h"
00049 #endif
00050
00051 #include "util/exception.hh"
00052
00053 using namespace std;
00054
00055 namespace Moses
00056 {
00057 Manager::Manager(size_t lineNumber, InputType const& source, SearchAlgorithm searchAlgorithm)
00058 :m_transOptColl(source.CreateTranslationOptionCollection())
00059 ,m_search(Search::CreateSearch(*this, source, searchAlgorithm, *m_transOptColl))
00060 ,interrupted_flag(0)
00061 ,m_hypoId(0)
00062 ,m_lineNumber(lineNumber)
00063 ,m_source(source)
00064 {
00065 StaticData::Instance().InitializeForInput(source);
00066 }
00067
00068 Manager::~Manager()
00069 {
00070 delete m_transOptColl;
00071 delete m_search;
00072
00073 StaticData::Instance().CleanUpAfterSentenceProcessing(m_source);
00074 }
00075
00080 void Manager::ProcessSentence()
00081 {
00082
00083 ResetSentenceStats(m_source);
00084
00085 Timer getOptionsTime;
00086 getOptionsTime.start();
00087 m_transOptColl->CreateTranslationOptions();
00088 VERBOSE(1, "Line "<< m_lineNumber << ": Collecting options took " << getOptionsTime << " seconds" << endl);
00089
00090
00091 IFVERBOSE(2) {
00092
00093 GetSentenceStats().AddTimeCollectOpts((clock_t) (getOptionsTime.get_elapsed_time() * CLOCKS_PER_SEC));
00094 }
00095
00096
00097 Timer searchTime;
00098 searchTime.start();
00099 m_search->ProcessSentence();
00100 VERBOSE(1, "Line " << m_lineNumber << ": Search took " << searchTime << " seconds" << endl);
00101 }
00102
00108 void Manager::PrintAllDerivations(long translationId, ostream& outputStream) const
00109 {
00110 const std::vector < HypothesisStack* > &hypoStackColl = m_search->GetHypothesisStacks();
00111
00112 vector<const Hypothesis*> sortedPureHypo = hypoStackColl.back()->GetSortedList();
00113
00114 if (sortedPureHypo.size() == 0)
00115 return;
00116
00117 float remainingScore = 0;
00118 vector<const TargetPhrase*> remainingPhrases;
00119
00120
00121 vector<const Hypothesis*>::const_iterator iterBestHypo;
00122 for (iterBestHypo = sortedPureHypo.begin()
00123 ; iterBestHypo != sortedPureHypo.end()
00124 ; ++iterBestHypo) {
00125 printThisHypothesis(translationId, *iterBestHypo, remainingPhrases, remainingScore, outputStream);
00126 printDivergentHypothesis(translationId, *iterBestHypo, remainingPhrases, remainingScore, outputStream);
00127 }
00128 }
00129
00130 const TranslationOptionCollection* Manager::getSntTranslationOptions()
00131 {
00132 return m_transOptColl;
00133 }
00134
00135 void Manager::printDivergentHypothesis(long translationId, const Hypothesis* hypo, const vector <const TargetPhrase*> & remainingPhrases, float remainingScore , ostream& outputStream ) const
00136 {
00137
00138 if (hypo->GetId() > 0) {
00139 vector <const TargetPhrase*> followingPhrases;
00140 followingPhrases.push_back(& (hypo->GetCurrTargetPhrase()));
00142 followingPhrases.insert(followingPhrases.end()--, remainingPhrases.begin(), remainingPhrases.end());
00143 printDivergentHypothesis(translationId, hypo->GetPrevHypo(), followingPhrases , remainingScore + hypo->GetScore() - hypo->GetPrevHypo()->GetScore(), outputStream);
00144 }
00145
00146
00147 const ArcList *pAL = hypo->GetArcList();
00148 if (pAL) {
00149 const ArcList &arcList = *pAL;
00150
00151 ArcList::const_iterator iterArc;
00152 for (iterArc = arcList.begin() ; iterArc != arcList.end() ; ++iterArc) {
00153 const Hypothesis *loserHypo = *iterArc;
00154 const Hypothesis* loserPrevHypo = loserHypo->GetPrevHypo();
00155 float arcScore = loserHypo->GetScore() - loserPrevHypo->GetScore();
00156 vector <const TargetPhrase* > followingPhrases;
00157 followingPhrases.push_back(&(loserHypo->GetCurrTargetPhrase()));
00158 followingPhrases.insert(followingPhrases.end()--, remainingPhrases.begin(), remainingPhrases.end());
00159 printThisHypothesis(translationId, loserPrevHypo, followingPhrases, remainingScore + arcScore, outputStream);
00160 printDivergentHypothesis(translationId, loserPrevHypo, followingPhrases, remainingScore + arcScore, outputStream);
00161 }
00162 }
00163 }
00164
00165
00166 void Manager::printThisHypothesis(long translationId, const Hypothesis* hypo, const vector <const TargetPhrase*> & remainingPhrases, float remainingScore, ostream& outputStream) const
00167 {
00168
00169 outputStream << translationId << " ||| ";
00170
00171
00172 hypo->ToStream(outputStream);
00173 for (size_t p = 0; p < remainingPhrases.size(); ++p) {
00174 const TargetPhrase * phrase = remainingPhrases[p];
00175 size_t size = phrase->GetSize();
00176 for (size_t pos = 0 ; pos < size ; pos++) {
00177 const Factor *factor = phrase->GetFactor(pos, 0);
00178 outputStream << *factor;
00179 outputStream << " ";
00180 }
00181 }
00182
00183 outputStream << "||| " << hypo->GetScore() + remainingScore;
00184 outputStream << endl;
00185 }
00186
00187
00188
00189
00199 void Manager::CalcNBest(size_t count, TrellisPathList &ret,bool onlyDistinct) const
00200 {
00201 if (count <= 0)
00202 return;
00203
00204 const std::vector < HypothesisStack* > &hypoStackColl = m_search->GetHypothesisStacks();
00205
00206 vector<const Hypothesis*> sortedPureHypo = hypoStackColl.back()->GetSortedList();
00207
00208 if (sortedPureHypo.size() == 0)
00209 return;
00210
00211 TrellisPathCollection contenders;
00212
00213 set<Phrase> distinctHyps;
00214
00215
00216 vector<const Hypothesis*>::const_iterator iterBestHypo;
00217 for (iterBestHypo = sortedPureHypo.begin()
00218 ; iterBestHypo != sortedPureHypo.end()
00219 ; ++iterBestHypo) {
00220 contenders.Add(new TrellisPath(*iterBestHypo));
00221 }
00222
00223
00224 size_t nBestFactor = StaticData::Instance().GetNBestFactor();
00225 if (nBestFactor < 1) nBestFactor = 1000;
00226
00227
00228 for (size_t iteration = 0 ; (onlyDistinct ? distinctHyps.size() : ret.GetSize()) < count && contenders.GetSize() > 0 && (iteration < count * nBestFactor) ; iteration++) {
00229
00230 TrellisPath *path = contenders.pop();
00231 CHECK(path);
00232
00233 path->CreateDeviantPaths(contenders);
00234 if(onlyDistinct) {
00235 Phrase tgtPhrase = path->GetSurfacePhrase();
00236 if (distinctHyps.insert(tgtPhrase).second) {
00237 ret.Add(path);
00238 } else {
00239 delete path;
00240 path = NULL;
00241 }
00242 } else {
00243 ret.Add(path);
00244 }
00245
00246
00247 if(onlyDistinct) {
00248 const size_t nBestFactor = StaticData::Instance().GetNBestFactor();
00249 if (nBestFactor > 0)
00250 contenders.Prune(count * nBestFactor);
00251 } else {
00252 contenders.Prune(count);
00253 }
00254 }
00255 }
00256
00257 struct SGNReverseCompare {
00258 bool operator() (const SearchGraphNode& s1, const SearchGraphNode& s2) const {
00259 return s1.hypo->GetId() > s2.hypo->GetId();
00260 }
00261 };
00262
00266 void Manager::CalcLatticeSamples(size_t count, TrellisPathList &ret) const {
00267
00268 vector<SearchGraphNode> searchGraph;
00269 GetSearchGraph(searchGraph);
00270
00271
00272
00273
00274 typedef pair<int, int> Edge;
00275 map<const Hypothesis*, float> sigmas;
00276 map<Edge, float> edgeScores;
00277 map<const Hypothesis*, set<const Hypothesis*> > outgoingHyps;
00278 map<int,const Hypothesis*> idToHyp;
00279 map<int,float> fscores;
00280
00281
00282
00283
00284
00285 sort(searchGraph.begin(), searchGraph.end(), SGNReverseCompare());
00286
00287
00288 for (vector<SearchGraphNode>::const_iterator i = searchGraph.begin();
00289 i != searchGraph.end(); ++i) {
00290 const Hypothesis* hypo = i->hypo;
00291 idToHyp[hypo->GetId()] = hypo;
00292 fscores[hypo->GetId()] = i->fscore;
00293 if (hypo->GetId()) {
00294
00295 const Hypothesis* prevHypo = i->hypo->GetPrevHypo();
00296 outgoingHyps[prevHypo].insert(hypo);
00297 edgeScores[Edge(prevHypo->GetId(),hypo->GetId())] =
00298 hypo->GetScore() - prevHypo->GetScore();
00299 }
00300
00301 if (i->forward >= 0) {
00302 map<int,const Hypothesis*>::const_iterator idToHypIter = idToHyp.find(i->forward);
00303 CHECK(idToHypIter != idToHyp.end());
00304 const Hypothesis* nextHypo = idToHypIter->second;
00305 outgoingHyps[hypo].insert(nextHypo);
00306 map<int,float>::const_iterator fscoreIter = fscores.find(nextHypo->GetId());
00307 CHECK(fscoreIter != fscores.end());
00308 edgeScores[Edge(hypo->GetId(),nextHypo->GetId())] =
00309 i->fscore - fscoreIter->second;
00310 }
00311 }
00312
00313
00314
00315 for (vector<SearchGraphNode>::const_iterator i = searchGraph.begin();
00316 i != searchGraph.end(); ++i) {
00317
00318 if (i->forward == -1) {
00319 sigmas[i->hypo] = 0;
00320 } else {
00321 map<const Hypothesis*, set<const Hypothesis*> >::const_iterator outIter =
00322 outgoingHyps.find(i->hypo);
00323
00324 CHECK(outIter != outgoingHyps.end());
00325 float sigma = 0;
00326 for (set<const Hypothesis*>::const_iterator j = outIter->second.begin();
00327 j != outIter->second.end(); ++j) {
00328 map<const Hypothesis*, float>::const_iterator succIter = sigmas.find(*j);
00329 CHECK(succIter != sigmas.end());
00330 map<Edge,float>::const_iterator edgeScoreIter =
00331 edgeScores.find(Edge(i->hypo->GetId(),(*j)->GetId()));
00332 CHECK(edgeScoreIter != edgeScores.end());
00333 float term = edgeScoreIter->second + succIter->second;
00334 if (sigma == 0) {
00335 sigma = term;
00336 } else {
00337 sigma = log_sum(sigma,term);
00338 }
00339 }
00340 sigmas[i->hypo] = sigma;
00341 }
00342 }
00343
00344
00345 const Hypothesis* startHypo = searchGraph.back().hypo;
00346 CHECK(startHypo->GetId() == 0);
00347 for (size_t i = 0; i < count; ++i) {
00348 vector<const Hypothesis*> path;
00349 path.push_back(startHypo);
00350 while(1) {
00351 map<const Hypothesis*, set<const Hypothesis*> >::const_iterator outIter =
00352 outgoingHyps.find(path.back());
00353 if (outIter == outgoingHyps.end() || !outIter->second.size()) {
00354
00355 break;
00356 }
00357
00358 vector<const Hypothesis*> candidates;
00359 vector<float> candidateScores;
00360 float scoreTotal = 0;
00361 for (set<const Hypothesis*>::const_iterator j = outIter->second.begin();
00362 j != outIter->second.end(); ++j) {
00363 candidates.push_back(*j);
00364 CHECK(sigmas.find(*j) != sigmas.end());
00365 Edge edge(path.back()->GetId(),(*j)->GetId());
00366 CHECK(edgeScores.find(edge) != edgeScores.end());
00367 candidateScores.push_back(sigmas[*j] + edgeScores[edge]);
00368 if (scoreTotal == 0) {
00369 scoreTotal = candidateScores.back();
00370 } else {
00371 scoreTotal = log_sum(candidateScores.back(), scoreTotal);
00372 }
00373 }
00374
00375
00376 transform(candidateScores.begin(), candidateScores.end(), candidateScores.begin(), bind2nd(minus<float>(),scoreTotal));
00377
00378
00379
00380
00381 float random = log((float)rand()/RAND_MAX);
00382 size_t position = 1;
00383 float sum = candidateScores[0];
00384 for (; position < candidateScores.size() && sum < random; ++position) {
00385 sum = log_sum(sum,candidateScores[position]);
00386 }
00387
00388 const Hypothesis* chosen = candidates[position-1];
00389 path.push_back(chosen);
00390 }
00391
00392
00393
00394
00395
00396
00397
00398 ret.Add(new TrellisPath(path));
00399
00400 }
00401
00402 }
00403
00404
00405
00406 void Manager::CalcDecoderStatistics() const
00407 {
00408 const Hypothesis *hypo = GetBestHypothesis();
00409 if (hypo != NULL) {
00410 GetSentenceStats().CalcFinalStats(*hypo);
00411 IFVERBOSE(2) {
00412 if (hypo != NULL) {
00413 string buff;
00414 string buff2;
00415 TRACE_ERR( "Source and Target Units:"
00416 << hypo->GetInput());
00417 buff2.insert(0,"] ");
00418 buff2.insert(0,(hypo->GetCurrTargetPhrase()).ToString());
00419 buff2.insert(0,":");
00420 buff2.insert(0,(hypo->GetCurrSourceWordsRange()).ToString());
00421 buff2.insert(0,"[");
00422
00423 hypo = hypo->GetPrevHypo();
00424 while (hypo != NULL) {
00425
00426 buff.insert(0,buff2);
00427 buff2.clear();
00428 buff2.insert(0,"] ");
00429 buff2.insert(0,(hypo->GetCurrTargetPhrase()).ToString());
00430 buff2.insert(0,":");
00431 buff2.insert(0,(hypo->GetCurrSourceWordsRange()).ToString());
00432 buff2.insert(0,"[");
00433 hypo = hypo->GetPrevHypo();
00434 }
00435 TRACE_ERR( buff << endl);
00436 }
00437 }
00438 }
00439 }
00440
00441 void OutputWordGraph(std::ostream &outputWordGraphStream, const Hypothesis *hypo, size_t &linkId)
00442 {
00443
00444 const Hypothesis *prevHypo = hypo->GetPrevHypo();
00445
00446
00447 outputWordGraphStream << "J=" << linkId++
00448 << "\tS=" << prevHypo->GetId()
00449 << "\tE=" << hypo->GetId()
00450 << "\ta=";
00451
00452
00453 const StaticData &staticData = StaticData::Instance();
00454 const std::vector<PhraseDictionary*> &phraseTables = staticData.GetPhraseDictionaries();
00455 std::vector<PhraseDictionary*>::const_iterator iterPhraseTable;
00456 for (iterPhraseTable = phraseTables.begin() ; iterPhraseTable != phraseTables.end() ; ++iterPhraseTable) {
00457 const PhraseDictionary *phraseTable = *iterPhraseTable;
00458 vector<float> scores = hypo->GetScoreBreakdown().GetScoresForProducer(phraseTable);
00459
00460 outputWordGraphStream << scores[0];
00461 vector<float>::const_iterator iterScore;
00462 for (iterScore = ++scores.begin() ; iterScore != scores.end() ; ++iterScore) {
00463 outputWordGraphStream << ", " << *iterScore;
00464 }
00465 }
00466
00467
00468 outputWordGraphStream << "\tl=";
00469 const LMList &lmList = StaticData::Instance().GetLMList();
00470
00471 LMList::const_iterator iterLM;
00472 for (iterLM = lmList.begin() ; iterLM != lmList.end() ; ++iterLM) {
00473 LanguageModel *lm = *iterLM;
00474 vector<float> scores = hypo->GetScoreBreakdown().GetScoresForProducer(lm);
00475
00476 outputWordGraphStream << scores[0];
00477 vector<float>::const_iterator iterScore;
00478 for (iterScore = ++scores.begin() ; iterScore != scores.end() ; ++iterScore) {
00479 outputWordGraphStream << ", " << *iterScore;
00480 }
00481 }
00482
00483
00484 outputWordGraphStream << "\tr=";
00485
00486 const std::vector<FeatureFunction*> &ffs = FeatureFunction::GetFeatureFunctions();
00487 std::vector<FeatureFunction*>::const_iterator iter;
00488 for (iter = ffs.begin(); iter != ffs.end(); ++iter) {
00489 const FeatureFunction *ff = *iter;
00490
00491 const DistortionScoreProducer *model = dynamic_cast<const DistortionScoreProducer*>(ff);
00492 if (model) {
00493 outputWordGraphStream << hypo->GetScoreBreakdown().GetScoreForProducer(model);
00494 }
00495 }
00496
00497
00498
00499
00500
00501
00502
00503
00504
00505
00506
00507
00508
00509
00510
00511
00512
00513
00514
00515
00516 outputWordGraphStream << "\tw=" << hypo->GetSourcePhraseStringRep() << "|" << hypo->GetCurrTargetPhrase();
00517
00518 outputWordGraphStream << endl;
00519 }
00520
00521 void Manager::GetWordGraph(long translationId, std::ostream &outputWordGraphStream) const
00522 {
00523 const StaticData &staticData = StaticData::Instance();
00524 string fileName = staticData.GetParam("output-word-graph")[0];
00525 bool outputNBest = Scan<bool>(staticData.GetParam("output-word-graph")[1]);
00526 const std::vector < HypothesisStack* > &hypoStackColl = m_search->GetHypothesisStacks();
00527
00528 outputWordGraphStream << "VERSION=1.0" << endl
00529 << "UTTERANCE=" << translationId << endl;
00530
00531 size_t linkId = 0;
00532 size_t stackNo = 1;
00533 std::vector < HypothesisStack* >::const_iterator iterStack;
00534 for (iterStack = ++hypoStackColl.begin() ; iterStack != hypoStackColl.end() ; ++iterStack) {
00535 const HypothesisStack &stack = **iterStack;
00536 HypothesisStack::const_iterator iterHypo;
00537 for (iterHypo = stack.begin() ; iterHypo != stack.end() ; ++iterHypo) {
00538 const Hypothesis *hypo = *iterHypo;
00539 OutputWordGraph(outputWordGraphStream, hypo, linkId);
00540
00541 if (outputNBest) {
00542 const ArcList *arcList = hypo->GetArcList();
00543 if (arcList != NULL) {
00544 ArcList::const_iterator iterArcList;
00545 for (iterArcList = arcList->begin() ; iterArcList != arcList->end() ; ++iterArcList) {
00546 const Hypothesis *loserHypo = *iterArcList;
00547 OutputWordGraph(outputWordGraphStream, loserHypo, linkId);
00548 }
00549 }
00550 }
00551 }
00552 }
00553 }
00554
00555 void Manager::GetSearchGraph(vector<SearchGraphNode>& searchGraph) const
00556 {
00557 std::map < int, bool > connected;
00558 std::map < int, int > forward;
00559 std::map < int, double > forwardScore;
00560
00561
00562 std::vector< const Hypothesis *> connectedList;
00563 GetConnectedGraph(&connected, &connectedList);
00564
00565
00566
00567
00568 const std::vector < HypothesisStack* > &hypoStackColl = m_search->GetHypothesisStacks();
00569 const HypothesisStack &finalStack = *hypoStackColl.back();
00570 HypothesisStack::const_iterator iterHypo;
00571 for (iterHypo = finalStack.begin() ; iterHypo != finalStack.end() ; ++iterHypo) {
00572 const Hypothesis *hypo = *iterHypo;
00573 forwardScore[ hypo->GetId() ] = 0.0f;
00574 forward[ hypo->GetId() ] = -1;
00575 }
00576
00577
00578 std::vector < HypothesisStack* >::const_iterator iterStack;
00579 for (iterStack = --hypoStackColl.end() ; iterStack != hypoStackColl.begin() ; --iterStack) {
00580 const HypothesisStack &stack = **iterStack;
00581 HypothesisStack::const_iterator iterHypo;
00582 for (iterHypo = stack.begin() ; iterHypo != stack.end() ; ++iterHypo) {
00583 const Hypothesis *hypo = *iterHypo;
00584 if (connected.find( hypo->GetId() ) != connected.end()) {
00585
00586 const Hypothesis *prevHypo = hypo->GetPrevHypo();
00587 double fscore = forwardScore[ hypo->GetId() ] +
00588 hypo->GetScore() - prevHypo->GetScore();
00589 if (forwardScore.find( prevHypo->GetId() ) == forwardScore.end()
00590 || forwardScore.find( prevHypo->GetId() )->second < fscore) {
00591 forwardScore[ prevHypo->GetId() ] = fscore;
00592 forward[ prevHypo->GetId() ] = hypo->GetId();
00593 }
00594
00595 const ArcList *arcList = hypo->GetArcList();
00596 if (arcList != NULL) {
00597 ArcList::const_iterator iterArcList;
00598 for (iterArcList = arcList->begin() ; iterArcList != arcList->end() ; ++iterArcList) {
00599 const Hypothesis *loserHypo = *iterArcList;
00600
00601 const Hypothesis *loserPrevHypo = loserHypo->GetPrevHypo();
00602 double fscore = forwardScore[ hypo->GetId() ] +
00603 loserHypo->GetScore() - loserPrevHypo->GetScore();
00604 if (forwardScore.find( loserPrevHypo->GetId() ) == forwardScore.end()
00605 || forwardScore.find( loserPrevHypo->GetId() )->second < fscore) {
00606 forwardScore[ loserPrevHypo->GetId() ] = fscore;
00607 forward[ loserPrevHypo->GetId() ] = loserHypo->GetId();
00608 }
00609 }
00610 }
00611 }
00612 }
00613 }
00614
00615
00616
00617 connected[ 0 ] = true;
00618 for (iterStack = hypoStackColl.begin() ; iterStack != hypoStackColl.end() ; ++iterStack) {
00619 const HypothesisStack &stack = **iterStack;
00620 HypothesisStack::const_iterator iterHypo;
00621 for (iterHypo = stack.begin() ; iterHypo != stack.end() ; ++iterHypo) {
00622 const Hypothesis *hypo = *iterHypo;
00623 if (connected.find( hypo->GetId() ) != connected.end()) {
00624 searchGraph.push_back(SearchGraphNode(hypo,NULL,forward[hypo->GetId()],
00625 forwardScore[hypo->GetId()]));
00626
00627 const ArcList *arcList = hypo->GetArcList();
00628 if (arcList != NULL) {
00629 ArcList::const_iterator iterArcList;
00630 for (iterArcList = arcList->begin() ; iterArcList != arcList->end() ; ++iterArcList) {
00631 const Hypothesis *loserHypo = *iterArcList;
00632 searchGraph.push_back(SearchGraphNode(loserHypo,hypo,
00633 forward[hypo->GetId()], forwardScore[hypo->GetId()]));
00634 }
00635 }
00636 }
00637 }
00638 }
00639
00640 }
00641
00642 void Manager::OutputFeatureWeightsForSLF(std::ostream &outputSearchGraphStream) const
00643 {
00644 outputSearchGraphStream.setf(std::ios::fixed);
00645 outputSearchGraphStream.precision(6);
00646
00647 const StaticData& staticData = StaticData::Instance();
00648 const vector<const StatelessFeatureFunction*>& slf = StatelessFeatureFunction::GetStatelessFeatureFunctions();
00649 const vector<const StatefulFeatureFunction*>& sff = StatefulFeatureFunction::GetStatefulFeatureFunctions();
00650 size_t featureIndex = 1;
00651 for (size_t i = 0; i < sff.size(); ++i) {
00652 featureIndex = OutputFeatureWeightsForSLF(featureIndex, sff[i], outputSearchGraphStream);
00653 }
00654 for (size_t i = 0; i < slf.size(); ++i) {
00655
00656
00657
00658
00659
00660
00661 {
00662 featureIndex = OutputFeatureWeightsForSLF(featureIndex, slf[i], outputSearchGraphStream);
00663 }
00664 }
00665 const vector<PhraseDictionary*>& pds = staticData.GetPhraseDictionaries();
00666 for( size_t i=0; i<pds.size(); i++ ) {
00667 featureIndex = OutputFeatureWeightsForSLF(featureIndex, pds[i], outputSearchGraphStream);
00668 }
00669 const vector<const GenerationDictionary*>& gds = staticData.GetGenerationDictionaries();
00670 for( size_t i=0; i<gds.size(); i++ ) {
00671 featureIndex = OutputFeatureWeightsForSLF(featureIndex, gds[i], outputSearchGraphStream);
00672 }
00673
00674 }
00675
00676 void Manager::OutputFeatureValuesForSLF(const Hypothesis* hypo, bool zeros, std::ostream &outputSearchGraphStream) const
00677 {
00678 outputSearchGraphStream.setf(std::ios::fixed);
00679 outputSearchGraphStream.precision(6);
00680
00681
00682
00683
00684
00685
00686 const StaticData& staticData = StaticData::Instance();
00687 const vector<const StatelessFeatureFunction*>& slf =StatelessFeatureFunction::GetStatelessFeatureFunctions();
00688 const vector<const StatefulFeatureFunction*>& sff = StatefulFeatureFunction::GetStatefulFeatureFunctions();
00689 size_t featureIndex = 1;
00690 for (size_t i = 0; i < sff.size(); ++i) {
00691 featureIndex = OutputFeatureValuesForSLF(featureIndex, zeros, hypo, sff[i], outputSearchGraphStream);
00692 }
00693 for (size_t i = 0; i < slf.size(); ++i) {
00694
00695
00696
00697
00698
00699
00700 {
00701 featureIndex = OutputFeatureValuesForSLF(featureIndex, zeros, hypo, slf[i], outputSearchGraphStream);
00702 }
00703 }
00704 const vector<PhraseDictionary*>& pds = staticData.GetPhraseDictionaries();
00705 for( size_t i=0; i<pds.size(); i++ ) {
00706 featureIndex = OutputFeatureValuesForSLF(featureIndex, zeros, hypo, pds[i], outputSearchGraphStream);
00707 }
00708 const vector<const GenerationDictionary*>& gds = staticData.GetGenerationDictionaries();
00709 for( size_t i=0; i<gds.size(); i++ ) {
00710 featureIndex = OutputFeatureValuesForSLF(featureIndex, zeros, hypo, gds[i], outputSearchGraphStream);
00711 }
00712
00713 }
00714
00715 void Manager::OutputFeatureValuesForHypergraph(const Hypothesis* hypo, std::ostream &outputSearchGraphStream) const
00716 {
00717 outputSearchGraphStream.setf(std::ios::fixed);
00718 outputSearchGraphStream.precision(6);
00719
00720 const StaticData& staticData = StaticData::Instance();
00721 const vector<const StatelessFeatureFunction*>& slf =StatelessFeatureFunction::GetStatelessFeatureFunctions();
00722 const vector<const StatefulFeatureFunction*>& sff = StatefulFeatureFunction::GetStatefulFeatureFunctions();
00723 size_t featureIndex = 1;
00724 for (size_t i = 0; i < sff.size(); ++i) {
00725 featureIndex = OutputFeatureValuesForHypergraph(featureIndex, hypo, sff[i], outputSearchGraphStream);
00726 }
00727 for (size_t i = 0; i < slf.size(); ++i) {
00728
00729
00730
00731
00732
00733
00734 {
00735 featureIndex = OutputFeatureValuesForHypergraph(featureIndex, hypo, slf[i], outputSearchGraphStream);
00736 }
00737 }
00738 const vector<PhraseDictionary*>& pds = staticData.GetPhraseDictionaries();
00739 for( size_t i=0; i<pds.size(); i++ ) {
00740 featureIndex = OutputFeatureValuesForHypergraph(featureIndex, hypo, pds[i], outputSearchGraphStream);
00741 }
00742 const vector<const GenerationDictionary*>& gds = staticData.GetGenerationDictionaries();
00743 for( size_t i=0; i<gds.size(); i++ ) {
00744 featureIndex = OutputFeatureValuesForHypergraph(featureIndex, hypo, gds[i], outputSearchGraphStream);
00745 }
00746
00747 }
00748
00749
00750 size_t Manager::OutputFeatureWeightsForSLF(size_t index, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const
00751 {
00752 size_t numScoreComps = ff->GetNumScoreComponents();
00753 if (numScoreComps != 0) {
00754 vector<float> values = StaticData::Instance().GetAllWeights().GetScoresForProducer(ff);
00755 for (size_t i = 0; i < numScoreComps; ++i) {
00756 outputSearchGraphStream << "# " << ff->GetScoreProducerDescription()
00757 << " " << ff->GetScoreProducerDescription()
00758 << " " << (i+1) << " of " << numScoreComps << endl
00759 << "x" << (index+i) << "scale=" << values[i] << endl;
00760 }
00761 return index+numScoreComps;
00762 } else {
00763 cerr << "Sparse features are not supported when outputting HTK standard lattice format" << endl;
00764 assert(false);
00765 return 0;
00766 }
00767 }
00768
00769 size_t Manager::OutputFeatureValuesForSLF(size_t index, bool zeros, const Hypothesis* hypo, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const
00770 {
00771
00772
00773
00774
00775
00776
00777
00778
00779
00780
00781
00782
00783
00784
00785
00786
00787
00788
00789
00790
00791
00792
00793
00794 const ScoreComponentCollection& scoreCollection = hypo->GetScoreBreakdown();
00795
00796 vector<float> featureValues = scoreCollection.GetScoresForProducer(ff);
00797 size_t numScoreComps = featureValues.size();
00798
00799
00800 for (size_t i = 0; i < numScoreComps; ++i) {
00801 outputSearchGraphStream << "x" << (index+i) << "=" << ((zeros) ? 0.0 : featureValues[i]) << " ";
00802 }
00803 return index+numScoreComps;
00804
00805
00806
00807
00808
00809 }
00810
00811 size_t Manager::OutputFeatureValuesForHypergraph(size_t index, const Hypothesis* hypo, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const
00812 {
00813 ScoreComponentCollection scoreCollection = hypo->GetScoreBreakdown();
00814 const Hypothesis *prevHypo = hypo->GetPrevHypo();
00815 if (prevHypo) {
00816 scoreCollection.MinusEquals( prevHypo->GetScoreBreakdown() );
00817 }
00818 vector<float> featureValues = scoreCollection.GetScoresForProducer(ff);
00819 size_t numScoreComps = featureValues.size();
00820
00821 if (numScoreComps > 1) {
00822 for (size_t i = 0; i < numScoreComps; ++i) {
00823 outputSearchGraphStream << ff->GetScoreProducerDescription() << i << "=" << featureValues[i] << " ";
00824 }
00825 } else {
00826 outputSearchGraphStream << ff->GetScoreProducerDescription() << "=" << featureValues[0] << " ";
00827 }
00828
00829 return index+numScoreComps;
00830 }
00831
00833 void Manager::OutputSearchGraphAsHypergraph(long translationId, std::ostream &outputSearchGraphStream) const
00834 {
00835
00836 VERBOSE(2,"Getting search graph to output as hypergraph for sentence " << translationId << std::endl)
00837
00838 vector<SearchGraphNode> searchGraph;
00839 GetSearchGraph(searchGraph);
00840
00841
00842 map<int,int> mosesIDToHypergraphID;
00843
00844 set<int> terminalNodes;
00845 multimap<int,int> hypergraphIDToArcs;
00846
00847 VERBOSE(2,"Gathering information about search graph to output as hypergraph for sentence " << translationId << std::endl)
00848
00849 long numNodes = 0;
00850 long endNode = 0;
00851 {
00852 long hypergraphHypothesisID = 0;
00853 for (size_t arcNumber = 0, size=searchGraph.size(); arcNumber < size; ++arcNumber) {
00854
00855
00856 const Hypothesis *prevHypo = searchGraph[arcNumber].hypo->GetPrevHypo();
00857 if (prevHypo!=NULL) {
00858 int mosesPrevHypothesisID = prevHypo->GetId();
00859 if (mosesIDToHypergraphID.count(mosesPrevHypothesisID) == 0) {
00860 mosesIDToHypergraphID[mosesPrevHypothesisID] = hypergraphHypothesisID;
00861
00862 hypergraphHypothesisID += 1;
00863 }
00864 }
00865
00866
00867 int mosesHypothesisID;
00868 if (searchGraph[arcNumber].recombinationHypo) {
00869 mosesHypothesisID = searchGraph[arcNumber].recombinationHypo->GetId();
00870 } else {
00871 mosesHypothesisID = searchGraph[arcNumber].hypo->GetId();
00872 }
00873
00874 if (mosesIDToHypergraphID.count(mosesHypothesisID) == 0) {
00875
00876 mosesIDToHypergraphID[mosesHypothesisID] = hypergraphHypothesisID;
00877
00878
00879 bool terminalNode = (searchGraph[arcNumber].forward == -1);
00880 if (terminalNode) {
00881
00882 terminalNodes.insert(hypergraphHypothesisID);
00883 }
00884
00885 hypergraphHypothesisID += 1;
00886 }
00887
00888
00889 hypergraphIDToArcs.insert(pair<int,int>(mosesIDToHypergraphID[mosesHypothesisID],arcNumber));
00890
00891 }
00892
00893
00894 endNode = hypergraphHypothesisID;
00895
00896 numNodes = endNode + 1;
00897
00898 }
00899
00900
00901 long numArcs = searchGraph.size() + terminalNodes.size();
00902
00903
00904 outputSearchGraphStream << numNodes << " " << numArcs << endl;
00905
00906 VERBOSE(2,"Search graph to output as hypergraph for sentence " << translationId
00907 << " contains " << numArcs << " arcs and " << numNodes << " nodes" << std::endl)
00908
00909 VERBOSE(2,"Outputting search graph to output as hypergraph for sentence " << translationId << std::endl)
00910
00911
00912 for (int hypergraphHypothesisID=0; hypergraphHypothesisID < endNode; hypergraphHypothesisID+=1) {
00913 if (hypergraphHypothesisID % 100000 == 0) {
00914 VERBOSE(2,"Processed " << hypergraphHypothesisID << " of " << numNodes << " hypergraph nodes for sentence " << translationId << std::endl);
00915 }
00916
00917 size_t count = hypergraphIDToArcs.count(hypergraphHypothesisID);
00918
00919 if (count > 0) {
00920 outputSearchGraphStream << count << "\n";
00921
00922 pair<multimap<int,int>::iterator, multimap<int,int>::iterator> range =
00923 hypergraphIDToArcs.equal_range(hypergraphHypothesisID);
00924 for (multimap<int,int>::iterator it=range.first; it!=range.second; ++it) {
00925 int lineNumber = (*it).second;
00926 const Hypothesis *thisHypo = searchGraph[lineNumber].hypo;
00927 int mosesHypothesisID;
00928 if (searchGraph[lineNumber].recombinationHypo) {
00929 mosesHypothesisID = searchGraph[lineNumber].recombinationHypo->GetId();
00930 } else {
00931 mosesHypothesisID = searchGraph[lineNumber].hypo->GetId();
00932 }
00933
00934 UTIL_THROW_IF(
00935 (hypergraphHypothesisID != mosesIDToHypergraphID[mosesHypothesisID]),
00936 util::Exception,
00937 "Error while writing search lattice as hypergraph for sentence " << translationId << ". " <<
00938 "Moses node " << mosesHypothesisID << " was expected to have hypergraph id " << hypergraphHypothesisID <<
00939 ", but actually had hypergraph id " << mosesIDToHypergraphID[mosesHypothesisID] <<
00940 ". There are " << numNodes << " nodes in the search lattice."
00941 );
00942
00943 const Hypothesis *prevHypo = thisHypo->GetPrevHypo();
00944 if (prevHypo==NULL) {
00945
00946 outputSearchGraphStream << "<s> ||| \n";
00947 } else {
00948 int startNode = mosesIDToHypergraphID[prevHypo->GetId()];
00949
00950 UTIL_THROW_IF(
00951 (startNode >= hypergraphHypothesisID),
00952 util::Exception,
00953 "Error while writing search lattice as hypergraph for sentence" << translationId << ". " <<
00954 "The nodes must be output in topological order. The code attempted to violate this restriction."
00955 );
00956
00957 const TargetPhrase &targetPhrase = thisHypo->GetCurrTargetPhrase();
00958 int targetWordCount = targetPhrase.GetSize();
00959
00960 outputSearchGraphStream << "[" << startNode << "]";
00961 for (int targetWordIndex=0; targetWordIndex<targetWordCount; targetWordIndex+=1) {
00962 outputSearchGraphStream << " " << targetPhrase.GetWord(targetWordIndex);
00963 }
00964 outputSearchGraphStream << " ||| ";
00965 OutputFeatureValuesForHypergraph(thisHypo, outputSearchGraphStream);
00966 outputSearchGraphStream << "\n";
00967 }
00968 }
00969 }
00970 }
00971
00972
00973 outputSearchGraphStream << terminalNodes.size() << "\n";
00974 for (set<int>::iterator it=terminalNodes.begin(); it!=terminalNodes.end(); ++it) {
00975 outputSearchGraphStream << "[" << (*it) << "] </s> ||| \n";
00976 }
00977
00978 }
00979
00980
00982 void Manager::OutputSearchGraphAsSLF(long translationId, std::ostream &outputSearchGraphStream) const
00983 {
00984
00985 vector<SearchGraphNode> searchGraph;
00986 GetSearchGraph(searchGraph);
00987
00988 long numArcs = 0;
00989 long numNodes = 0;
00990
00991 map<int,int> nodes;
00992 set<int> terminalNodes;
00993
00994
00995 nodes[0] = 0;
00996
00997 for (size_t arcNumber = 0; arcNumber < searchGraph.size(); ++arcNumber) {
00998
00999 int targetWordCount = searchGraph[arcNumber].hypo->GetCurrTargetPhrase().GetSize();
01000 numArcs += targetWordCount;
01001
01002 int hypothesisID = searchGraph[arcNumber].hypo->GetId();
01003 if (nodes.count(hypothesisID) == 0) {
01004
01005 numNodes += targetWordCount;
01006 nodes[hypothesisID] = numNodes;
01007
01008
01009 bool terminalNode = (searchGraph[arcNumber].forward == -1);
01010 if (terminalNode) {
01011 numArcs += 1;
01012 }
01013 }
01014
01015 }
01016 numNodes += 1;
01017
01018
01019 nodes[numNodes] = numNodes;
01020
01021 outputSearchGraphStream << "UTTERANCE=Sentence_" << translationId << endl;
01022 outputSearchGraphStream << "VERSION=1.1" << endl;
01023 outputSearchGraphStream << "base=2.71828182845905" << endl;
01024 outputSearchGraphStream << "NODES=" << (numNodes+1) << endl;
01025 outputSearchGraphStream << "LINKS=" << numArcs << endl;
01026
01027 OutputFeatureWeightsForSLF(outputSearchGraphStream);
01028
01029 for (size_t arcNumber = 0, lineNumber = 0; lineNumber < searchGraph.size(); ++lineNumber) {
01030 const Hypothesis *thisHypo = searchGraph[lineNumber].hypo;
01031 const Hypothesis *prevHypo = thisHypo->GetPrevHypo();
01032 if (prevHypo) {
01033
01034 int startNode = nodes[prevHypo->GetId()];
01035 int endNode = nodes[thisHypo->GetId()];
01036 bool terminalNode = (searchGraph[lineNumber].forward == -1);
01037 const TargetPhrase &targetPhrase = thisHypo->GetCurrTargetPhrase();
01038 int targetWordCount = targetPhrase.GetSize();
01039
01040 for (int targetWordIndex=0; targetWordIndex<targetWordCount; targetWordIndex+=1) {
01041 int x = (targetWordCount-targetWordIndex);
01042
01043 outputSearchGraphStream << "J=" << arcNumber;
01044
01045 if (targetWordIndex==0) {
01046 outputSearchGraphStream << " S=" << startNode;
01047 } else {
01048 outputSearchGraphStream << " S=" << endNode - x;
01049 }
01050
01051 outputSearchGraphStream << " E=" << endNode - (x-1)
01052 << " W=" << targetPhrase.GetWord(targetWordIndex);
01053
01054 OutputFeatureValuesForSLF(thisHypo, (targetWordIndex>0), outputSearchGraphStream);
01055
01056 outputSearchGraphStream << endl;
01057
01058 arcNumber += 1;
01059 }
01060
01061 if (terminalNode && terminalNodes.count(endNode) == 0) {
01062 terminalNodes.insert(endNode);
01063 outputSearchGraphStream << "J=" << arcNumber
01064 << " S=" << endNode
01065 << " E=" << numNodes
01066 << endl;
01067 arcNumber += 1;
01068 }
01069 }
01070 }
01071
01072 }
01073
01074 void OutputSearchNode(long translationId, std::ostream &outputSearchGraphStream,
01075 const SearchGraphNode& searchNode)
01076 {
01077 const vector<FactorType> &outputFactorOrder = StaticData::Instance().GetOutputFactorOrder();
01078 bool extendedFormat = StaticData::Instance().GetOutputSearchGraphExtended();
01079 outputSearchGraphStream << translationId;
01080
01081
01082 if ( searchNode.hypo->GetId() == 0 ) {
01083 outputSearchGraphStream << " hyp=0 stack=0";
01084 if (extendedFormat) {
01085 outputSearchGraphStream << " forward=" << searchNode.forward << " fscore=" << searchNode.fscore;
01086 }
01087 outputSearchGraphStream << endl;
01088 return;
01089 }
01090
01091 const Hypothesis *prevHypo = searchNode.hypo->GetPrevHypo();
01092
01093
01094 if (!extendedFormat) {
01095 outputSearchGraphStream << " hyp=" << searchNode.hypo->GetId()
01096 << " stack=" << searchNode.hypo->GetWordsBitmap().GetNumWordsCovered()
01097 << " back=" << prevHypo->GetId()
01098 << " score=" << searchNode.hypo->GetScore()
01099 << " transition=" << (searchNode.hypo->GetScore() - prevHypo->GetScore());
01100
01101 if (searchNode.recombinationHypo != NULL)
01102 outputSearchGraphStream << " recombined=" << searchNode.recombinationHypo->GetId();
01103
01104 outputSearchGraphStream << " forward=" << searchNode.forward << " fscore=" << searchNode.fscore
01105 << " covered=" << searchNode.hypo->GetCurrSourceWordsRange().GetStartPos()
01106 << "-" << searchNode.hypo->GetCurrSourceWordsRange().GetEndPos()
01107 << " out=" << searchNode.hypo->GetCurrTargetPhrase().GetStringRep(outputFactorOrder)
01108 << endl;
01109 return;
01110 }
01111
01112
01113
01114
01115
01116 outputSearchGraphStream << " hyp=" << searchNode.hypo->GetId();
01117
01118 outputSearchGraphStream << " stack=" << searchNode.hypo->GetWordsBitmap().GetNumWordsCovered()
01119 << " back=" << prevHypo->GetId()
01120 << " score=" << searchNode.hypo->GetScore()
01121 << " transition=" << (searchNode.hypo->GetScore() - prevHypo->GetScore());
01122
01123 if (searchNode.recombinationHypo != NULL)
01124 outputSearchGraphStream << " recombined=" << searchNode.recombinationHypo->GetId();
01125
01126 outputSearchGraphStream << " forward=" << searchNode.forward << " fscore=" << searchNode.fscore
01127 << " covered=" << searchNode.hypo->GetCurrSourceWordsRange().GetStartPos()
01128 << "-" << searchNode.hypo->GetCurrSourceWordsRange().GetEndPos();
01129
01130
01131 ScoreComponentCollection scoreBreakdown = searchNode.hypo->GetScoreBreakdown();
01132 scoreBreakdown.MinusEquals( prevHypo->GetScoreBreakdown() );
01133
01134 outputSearchGraphStream << " scores=\"" << scoreBreakdown << "\"";
01135
01136 outputSearchGraphStream << " out=\"" << searchNode.hypo->GetSourcePhraseStringRep() << "|" <<
01137 searchNode.hypo->GetCurrTargetPhrase().GetStringRep(outputFactorOrder) << "\"" << endl;
01138
01139 }
01140
01141 void Manager::GetConnectedGraph(
01142 std::map< int, bool >* pConnected,
01143 std::vector< const Hypothesis* >* pConnectedList) const
01144 {
01145 std::map < int, bool >& connected = *pConnected;
01146 std::vector< const Hypothesis *>& connectedList = *pConnectedList;
01147
01148
01149 const std::vector < HypothesisStack* > &hypoStackColl = m_search->GetHypothesisStacks();
01150 const HypothesisStack &finalStack = *hypoStackColl.back();
01151 HypothesisStack::const_iterator iterHypo;
01152 for (iterHypo = finalStack.begin() ; iterHypo != finalStack.end() ; ++iterHypo) {
01153 const Hypothesis *hypo = *iterHypo;
01154 connected[ hypo->GetId() ] = true;
01155 connectedList.push_back( hypo );
01156 }
01157
01158 for(size_t i=0; i<connectedList.size(); i++) {
01159 const Hypothesis *hypo = connectedList[i];
01160
01161
01162 const Hypothesis *prevHypo = hypo->GetPrevHypo();
01163 if (prevHypo && prevHypo->GetId() > 0
01164 && connected.find( prevHypo->GetId() ) == connected.end()) {
01165 connected[ prevHypo->GetId() ] = true;
01166 connectedList.push_back( prevHypo );
01167 }
01168
01169
01170 const ArcList *arcList = hypo->GetArcList();
01171 if (arcList != NULL) {
01172 ArcList::const_iterator iterArcList;
01173 for (iterArcList = arcList->begin() ; iterArcList != arcList->end() ; ++iterArcList) {
01174 const Hypothesis *loserHypo = *iterArcList;
01175 if (connected.find( loserHypo->GetId() ) == connected.end()) {
01176 connected[ loserHypo->GetId() ] = true;
01177 connectedList.push_back( loserHypo );
01178 }
01179 }
01180 }
01181 }
01182 }
01183
01184 void Manager::GetWinnerConnectedGraph(
01185 std::map< int, bool >* pConnected,
01186 std::vector< const Hypothesis* >* pConnectedList) const
01187 {
01188 std::map < int, bool >& connected = *pConnected;
01189 std::vector< const Hypothesis *>& connectedList = *pConnectedList;
01190
01191
01192 const std::vector < HypothesisStack* > &hypoStackColl = m_search->GetHypothesisStacks();
01193 const HypothesisStack &finalStack = *hypoStackColl.back();
01194 HypothesisStack::const_iterator iterHypo;
01195 for (iterHypo = finalStack.begin() ; iterHypo != finalStack.end() ; ++iterHypo) {
01196 const Hypothesis *hypo = *iterHypo;
01197 connected[ hypo->GetId() ] = true;
01198 connectedList.push_back( hypo );
01199 }
01200
01201
01202 for(size_t i=0; i<connectedList.size(); i++) {
01203 const Hypothesis *hypo = connectedList[i];
01204
01205
01206 const Hypothesis *prevHypo = hypo->GetPrevHypo();
01207 if (prevHypo->GetId() > 0
01208 && connected.find( prevHypo->GetId() ) == connected.end()) {
01209 connected[ prevHypo->GetId() ] = true;
01210 connectedList.push_back( prevHypo );
01211 }
01212
01213
01214 const ArcList *arcList = hypo->GetArcList();
01215 if (arcList != NULL) {
01216 ArcList::const_iterator iterArcList;
01217 for (iterArcList = arcList->begin() ; iterArcList != arcList->end() ; ++iterArcList) {
01218 const Hypothesis *loserHypo = *iterArcList;
01219 if (connected.find( loserHypo->GetPrevHypo()->GetId() ) == connected.end() && loserHypo->GetPrevHypo()->GetId() > 0) {
01220 connected[ loserHypo->GetPrevHypo()->GetId() ] = true;
01221 connectedList.push_back( loserHypo->GetPrevHypo() );
01222 }
01223 }
01224 }
01225 }
01226 }
01227
01228
01229 #ifdef HAVE_PROTOBUF
01230
01231 void SerializeEdgeInfo(const Hypothesis* hypo, hgmert::Hypergraph_Edge* edge)
01232 {
01233 hgmert::Rule* rule = edge->mutable_rule();
01234 hypo->GetCurrTargetPhrase().WriteToRulePB(rule);
01235 const Hypothesis* prev = hypo->GetPrevHypo();
01236
01237 if (!prev) return;
01238
01239
01240 const ScoreComponentCollection& scores = hypo->GetScoreBreakdown();
01241 const ScoreComponentCollection& pscores = prev->GetScoreBreakdown();
01242 for (unsigned int i = 0; i < scores.size(); ++i)
01243 edge->add_feature_values((scores[i] - pscores[i]) * -1.0);
01244 }
01245
01246 hgmert::Hypergraph_Node* GetHGNode(
01247 const Hypothesis* hypo,
01248 std::map< int, int>* i2hgnode,
01249 hgmert::Hypergraph* hg,
01250 int* hgNodeIdx)
01251 {
01252 hgmert::Hypergraph_Node* hgnode;
01253 std::map < int, int >::iterator idxi = i2hgnode->find(hypo->GetId());
01254 if (idxi == i2hgnode->end()) {
01255 *hgNodeIdx = ((*i2hgnode)[hypo->GetId()] = hg->nodes_size());
01256 hgnode = hg->add_nodes();
01257 } else {
01258 *hgNodeIdx = idxi->second;
01259 hgnode = hg->mutable_nodes(*hgNodeIdx);
01260 }
01261 return hgnode;
01262 }
01263
01264 void Manager::SerializeSearchGraphPB(
01265 long translationId,
01266 std::ostream& outputStream) const
01267 {
01268 using namespace hgmert;
01269 std::map < int, bool > connected;
01270 std::map < int, int > i2hgnode;
01271 std::vector< const Hypothesis *> connectedList;
01272 GetConnectedGraph(&connected, &connectedList);
01273 connected[ 0 ] = true;
01274 Hypergraph hg;
01275 hg.set_is_sorted(false);
01276 int num_feats = (*m_search->GetHypothesisStacks().back()->begin())->GetScoreBreakdown().size();
01277 hg.set_num_features(num_feats);
01278 StaticData::Instance().GetScoreIndexManager().SerializeFeatureNamesToPB(&hg);
01279 Hypergraph_Node* goal = hg.add_nodes();
01280 Hypergraph_Node* source = hg.add_nodes();
01281 i2hgnode[-1] = 1;
01282 const std::vector < HypothesisStack* > &hypoStackColl = m_search->GetHypothesisStacks();
01283 const HypothesisStack &finalStack = *hypoStackColl.back();
01284 for (std::vector < HypothesisStack* >::const_iterator iterStack = hypoStackColl.begin();
01285 iterStack != hypoStackColl.end() ; ++iterStack) {
01286 const HypothesisStack &stack = **iterStack;
01287 HypothesisStack::const_iterator iterHypo;
01288
01289 for (iterHypo = stack.begin() ; iterHypo != stack.end() ; ++iterHypo) {
01290 const Hypothesis *hypo = *iterHypo;
01291 bool is_goal = hypo->GetWordsBitmap().IsComplete();
01292 if (connected.find( hypo->GetId() ) != connected.end()) {
01293 int headNodeIdx;
01294 Hypergraph_Node* headNode = GetHGNode(hypo, &i2hgnode, &hg, &headNodeIdx);
01295 if (is_goal) {
01296 Hypergraph_Edge* ge = hg.add_edges();
01297 ge->set_head_node(0);
01298 ge->add_tail_nodes(headNodeIdx);
01299 ge->mutable_rule()->add_trg_words("[X,1]");
01300 }
01301 Hypergraph_Edge* edge = hg.add_edges();
01302 SerializeEdgeInfo(hypo, edge);
01303 edge->set_head_node(headNodeIdx);
01304 const Hypothesis* prev = hypo->GetPrevHypo();
01305 int tailNodeIdx = 1;
01306 if (prev)
01307 tailNodeIdx = i2hgnode.find(prev->GetId())->second;
01308 edge->add_tail_nodes(tailNodeIdx);
01309
01310 const ArcList *arcList = hypo->GetArcList();
01311 if (arcList != NULL) {
01312 ArcList::const_iterator iterArcList;
01313 for (iterArcList = arcList->begin() ; iterArcList != arcList->end() ; ++iterArcList) {
01314 const Hypothesis *loserHypo = *iterArcList;
01315 CHECK(connected[loserHypo->GetId()]);
01316 Hypergraph_Edge* edge = hg.add_edges();
01317 SerializeEdgeInfo(loserHypo, edge);
01318 edge->set_head_node(headNodeIdx);
01319 tailNodeIdx = i2hgnode.find(loserHypo->GetPrevHypo()->GetId())->second;
01320 edge->add_tail_nodes(tailNodeIdx);
01321 }
01322 }
01323 }
01324 }
01325 }
01326 hg.SerializeToOstream(&outputStream);
01327 }
01328 #endif
01329
01330 void Manager::OutputSearchGraph(long translationId, std::ostream &outputSearchGraphStream) const
01331 {
01332 vector<SearchGraphNode> searchGraph;
01333 GetSearchGraph(searchGraph);
01334 for (size_t i = 0; i < searchGraph.size(); ++i) {
01335 OutputSearchNode(translationId,outputSearchGraphStream,searchGraph[i]);
01336 }
01337 }
01338
01339 void Manager::GetForwardBackwardSearchGraph(std::map< int, bool >* pConnected,
01340 std::vector< const Hypothesis* >* pConnectedList, std::map < const Hypothesis*, set< const Hypothesis* > >* pOutgoingHyps, vector< float>* pFwdBwdScores) const
01341 {
01342 std::map < int, bool > &connected = *pConnected;
01343 std::vector< const Hypothesis *>& connectedList = *pConnectedList;
01344 std::map < int, int > forward;
01345 std::map < int, double > forwardScore;
01346
01347 std::map < const Hypothesis*, set <const Hypothesis*> > & outgoingHyps = *pOutgoingHyps;
01348 vector< float> & estimatedScores = *pFwdBwdScores;
01349
01350
01351 GetWinnerConnectedGraph(&connected, &connectedList);
01352
01353
01354
01355
01356 const std::vector < HypothesisStack* > &hypoStackColl = m_search->GetHypothesisStacks();
01357 const HypothesisStack &finalStack = *hypoStackColl.back();
01358 HypothesisStack::const_iterator iterHypo;
01359 for (iterHypo = finalStack.begin() ; iterHypo != finalStack.end() ; ++iterHypo) {
01360 const Hypothesis *hypo = *iterHypo;
01361 forwardScore[ hypo->GetId() ] = 0.0f;
01362 forward[ hypo->GetId() ] = -1;
01363 }
01364
01365
01366 std::vector < HypothesisStack* >::const_iterator iterStack;
01367 for (iterStack = --hypoStackColl.end() ; iterStack != hypoStackColl.begin() ; --iterStack) {
01368 const HypothesisStack &stack = **iterStack;
01369 HypothesisStack::const_iterator iterHypo;
01370 for (iterHypo = stack.begin() ; iterHypo != stack.end() ; ++iterHypo) {
01371 const Hypothesis *hypo = *iterHypo;
01372 if (connected.find( hypo->GetId() ) != connected.end()) {
01373
01374 const Hypothesis *prevHypo = hypo->GetPrevHypo();
01375 double fscore = forwardScore[ hypo->GetId() ] +
01376 hypo->GetScore() - prevHypo->GetScore();
01377 if (forwardScore.find( prevHypo->GetId() ) == forwardScore.end()
01378 || forwardScore.find( prevHypo->GetId() )->second < fscore) {
01379 forwardScore[ prevHypo->GetId() ] = fscore;
01380 forward[ prevHypo->GetId() ] = hypo->GetId();
01381 }
01382
01383 outgoingHyps[prevHypo].insert(hypo);
01384
01385
01386 const ArcList *arcList = hypo->GetArcList();
01387 if (arcList != NULL) {
01388 ArcList::const_iterator iterArcList;
01389 for (iterArcList = arcList->begin() ; iterArcList != arcList->end() ; ++iterArcList) {
01390 const Hypothesis *loserHypo = *iterArcList;
01391
01392 const Hypothesis *loserPrevHypo = loserHypo->GetPrevHypo();
01393 double fscore = forwardScore[ hypo->GetId() ] +
01394 loserHypo->GetScore() - loserPrevHypo->GetScore();
01395 if (forwardScore.find( loserPrevHypo->GetId() ) == forwardScore.end()
01396 || forwardScore.find( loserPrevHypo->GetId() )->second < fscore) {
01397 forwardScore[ loserPrevHypo->GetId() ] = fscore;
01398 forward[ loserPrevHypo->GetId() ] = loserHypo->GetId();
01399 }
01400
01401 outgoingHyps[loserPrevHypo].insert(hypo);
01402
01403
01404 }
01405 }
01406 }
01407 }
01408 }
01409
01410 for (std::vector< const Hypothesis *>::iterator it = connectedList.begin(); it != connectedList.end(); ++it) {
01411 float estimatedScore = (*it)->GetScore() + forwardScore[(*it)->GetId()];
01412 estimatedScores.push_back(estimatedScore);
01413 }
01414 }
01415
01416
01417 const Hypothesis *Manager::GetBestHypothesis() const
01418 {
01419 return m_search->GetBestHypothesis();
01420 }
01421
01422 int Manager::GetNextHypoId()
01423 {
01424 return m_hypoId++;
01425 }
01426
01427 void Manager::ResetSentenceStats(const InputType& source)
01428 {
01429 m_sentenceStats = std::auto_ptr<SentenceStats>(new SentenceStats(source));
01430 }
01431 SentenceStats& Manager::GetSentenceStats() const
01432 {
01433 return *m_sentenceStats;
01434
01435 }
01436
01437 }