00001 #include "Manager.h"
00002 #include "Timer.h"
00003 #include "SearchNormal.h"
00004
00005 using namespace std;
00006
00007 namespace Moses
00008 {
00015 SearchNormal::SearchNormal(Manager& manager, const InputType &source, const TranslationOptionCollection &transOptColl)
00016 :Search(manager)
00017 ,m_source(source)
00018 ,m_hypoStackColl(source.GetSize() + 1)
00019 ,m_initialTargetPhrase(source.m_initialTargetPhrase)
00020 ,m_start(clock())
00021 ,interrupted_flag(0)
00022 ,m_transOptColl(transOptColl)
00023 {
00024 VERBOSE(1, "Translating: " << m_source << endl);
00025 const StaticData &staticData = StaticData::Instance();
00026
00027 if (m_initialTargetPhrase.GetSize() > 0) {
00028 VERBOSE(1, "Search extends partial output: " << m_initialTargetPhrase<<endl);
00029 }
00030
00031
00032 long sentenceID = source.GetTranslationId();
00033 m_constraint = staticData.GetConstrainingPhrase(sentenceID);
00034 if (m_constraint) {
00035 VERBOSE(1, "Search constraint to output: " << *m_constraint<<endl);
00036 }
00037
00038
00039 std::vector < HypothesisStackNormal >::iterator iterStack;
00040 for (size_t ind = 0 ; ind < m_hypoStackColl.size() ; ++ind) {
00041 HypothesisStackNormal *sourceHypoColl = new HypothesisStackNormal(m_manager);
00042 sourceHypoColl->SetMaxHypoStackSize(staticData.GetMaxHypoStackSize(),staticData.GetMinHypoStackDiversity());
00043 sourceHypoColl->SetBeamWidth(staticData.GetBeamWidth());
00044
00045 m_hypoStackColl[ind] = sourceHypoColl;
00046 }
00047 }
00048
00049 SearchNormal::~SearchNormal()
00050 {
00051 RemoveAllInColl(m_hypoStackColl);
00052 }
00053
00058 void SearchNormal::ProcessSentence()
00059 {
00060 const StaticData &staticData = StaticData::Instance();
00061 SentenceStats &stats = m_manager.GetSentenceStats();
00062 clock_t t=0;
00063
00064
00065 Hypothesis *hypo = Hypothesis::Create(m_manager,m_source, m_initialTargetPhrase);
00066 m_hypoStackColl[0]->AddPrune(hypo);
00067
00068
00069 std::vector < HypothesisStack* >::iterator iterStack;
00070 for (iterStack = m_hypoStackColl.begin() ; iterStack != m_hypoStackColl.end() ; ++iterStack) {
00071
00072 double _elapsed_time = GetUserTime();
00073 if (_elapsed_time > staticData.GetTimeoutThreshold()) {
00074 VERBOSE(1,"Decoding is out of time (" << _elapsed_time << "," << staticData.GetTimeoutThreshold() << ")" << std::endl);
00075 interrupted_flag = 1;
00076 return;
00077 }
00078 HypothesisStackNormal &sourceHypoColl = *static_cast<HypothesisStackNormal*>(*iterStack);
00079
00080
00081 VERBOSE(3,"processing hypothesis from next stack");
00082 IFVERBOSE(2) {
00083 t = clock();
00084 }
00085 sourceHypoColl.PruneToSize(staticData.GetMaxHypoStackSize());
00086 VERBOSE(3,std::endl);
00087 sourceHypoColl.CleanupArcList();
00088 IFVERBOSE(2) {
00089 stats.AddTimeStack( clock()-t );
00090 }
00091
00092
00093 HypothesisStackNormal::const_iterator iterHypo;
00094 for (iterHypo = sourceHypoColl.begin() ; iterHypo != sourceHypoColl.end() ; ++iterHypo) {
00095 Hypothesis &hypothesis = **iterHypo;
00096 ProcessOneHypothesis(hypothesis);
00097 }
00098
00099 IFVERBOSE(2) {
00100 OutputHypoStackSize();
00101 }
00102
00103
00104 actual_hypoStack = &sourceHypoColl;
00105 }
00106
00107
00108 IFVERBOSE(2) {
00109 m_manager.GetSentenceStats().SetTimeTotal( clock()-m_start );
00110 }
00111 VERBOSE(2, m_manager.GetSentenceStats());
00112 }
00113
00114
00120 void SearchNormal::ProcessOneHypothesis(const Hypothesis &hypothesis)
00121 {
00122
00123 int maxDistortion = StaticData::Instance().GetMaxDistortion();
00124 bool isWordLattice = StaticData::Instance().GetInputType() == WordLatticeInput;
00125
00126
00127 if (maxDistortion < 0) {
00128 const WordsBitmap hypoBitmap = hypothesis.GetWordsBitmap();
00129 const size_t hypoFirstGapPos = hypoBitmap.GetFirstGapPos()
00130 , sourceSize = m_source.GetSize();
00131
00132 for (size_t startPos = hypoFirstGapPos ; startPos < sourceSize ; ++startPos) {
00133 size_t maxSize = sourceSize - startPos;
00134 size_t maxSizePhrase = StaticData::Instance().GetMaxPhraseLength();
00135 maxSize = (maxSize < maxSizePhrase) ? maxSize : maxSizePhrase;
00136
00137 for (size_t endPos = startPos ; endPos < startPos + maxSize ; ++endPos) {
00138
00139
00140 if (m_transOptColl.GetTranslationOptionList(WordsRange(startPos, endPos)).size() == 0 ||
00141
00142 hypoBitmap.Overlap(WordsRange(startPos, endPos)) ||
00143
00144 !m_source.GetReorderingConstraint().Check( hypoBitmap, startPos, endPos ) ) {
00145 continue;
00146 }
00147
00148
00149 ExpandAllHypotheses(hypothesis, startPos, endPos);
00150 }
00151 }
00152
00153 return;
00154 }
00155
00156
00157
00158 const WordsBitmap hypoBitmap = hypothesis.GetWordsBitmap();
00159 const size_t hypoFirstGapPos = hypoBitmap.GetFirstGapPos()
00160 , sourceSize = m_source.GetSize();
00161
00162
00163 for (size_t startPos = hypoFirstGapPos ; startPos < sourceSize ; ++startPos) {
00164
00165 if(hypoBitmap.GetValue(startPos))
00166 continue;
00167
00168 WordsRange prevRange = hypothesis.GetCurrSourceWordsRange();
00169
00170 size_t maxSize = sourceSize - startPos;
00171 size_t maxSizePhrase = StaticData::Instance().GetMaxPhraseLength();
00172 maxSize = (maxSize < maxSizePhrase) ? maxSize : maxSizePhrase;
00173 size_t closestLeft = hypoBitmap.GetEdgeToTheLeftOf(startPos);
00174 if (isWordLattice) {
00175
00176
00177
00178
00179 if (closestLeft != 0 && closestLeft != startPos && !m_source.CanIGetFromAToB(closestLeft, startPos)) {
00180 continue;
00181 }
00182 if (prevRange.GetStartPos() != NOT_FOUND &&
00183 prevRange.GetStartPos() > startPos && !m_source.CanIGetFromAToB(startPos, prevRange.GetStartPos())) {
00184 continue;
00185 }
00186 }
00187
00188 WordsRange currentStartRange(startPos, startPos);
00189 if(m_source.ComputeDistortionDistance(prevRange, currentStartRange) > maxDistortion)
00190 continue;
00191
00192 for (size_t endPos = startPos ; endPos < startPos + maxSize ; ++endPos) {
00193
00194 WordsRange extRange(startPos, endPos);
00195
00196 if (m_transOptColl.GetTranslationOptionList(extRange).size() == 0 ||
00197
00198 hypoBitmap.Overlap(extRange) ||
00199
00200 !m_source.GetReorderingConstraint().Check( hypoBitmap, startPos, endPos ) ||
00201
00202 (isWordLattice && !m_source.IsCoveragePossible(extRange))) {
00203 continue;
00204 }
00205
00206
00207
00208
00209
00210 bool leftMostEdge = (hypoFirstGapPos == startPos);
00211
00212
00213 size_t closestRight = hypoBitmap.GetEdgeToTheRightOf(endPos);
00214 if (isWordLattice) {
00215
00216 if (closestRight != endPos && ((closestRight + 1) < sourceSize) && !m_source.CanIGetFromAToB(endPos + 1, closestRight + 1)) {
00217 continue;
00218 }
00219 }
00220
00221
00222 if (leftMostEdge) {
00223 ExpandAllHypotheses(hypothesis, startPos, endPos);
00224 }
00225
00226 else {
00227
00228
00229
00230
00231
00232
00233
00234 WordsRange bestNextExtension(hypoFirstGapPos, hypoFirstGapPos);
00235 int required_distortion =
00236 m_source.ComputeDistortionDistance(extRange, bestNextExtension);
00237
00238 if (required_distortion > maxDistortion) {
00239 continue;
00240 }
00241
00242
00243 ExpandAllHypotheses(hypothesis, startPos, endPos);
00244
00245 }
00246 }
00247 }
00248 }
00249
00250
00258 void SearchNormal::ExpandAllHypotheses(const Hypothesis &hypothesis, size_t startPos, size_t endPos)
00259 {
00260
00261
00262 float expectedScore = 0.0f;
00263 if (StaticData::Instance().UseEarlyDiscarding()) {
00264
00265 expectedScore = hypothesis.GetScore();
00266
00267
00268 expectedScore += m_transOptColl.GetFutureScore().CalcFutureScore( hypothesis.GetWordsBitmap(), startPos, endPos );
00269 }
00270
00271
00272 const TranslationOptionList &transOptList = m_transOptColl.GetTranslationOptionList(WordsRange(startPos, endPos));
00273 TranslationOptionList::const_iterator iter;
00274 for (iter = transOptList.begin() ; iter != transOptList.end() ; ++iter) {
00275 ExpandHypothesis(hypothesis, **iter, expectedScore);
00276 }
00277 }
00278
00288 void SearchNormal::ExpandHypothesis(const Hypothesis &hypothesis, const TranslationOption &transOpt, float expectedScore)
00289 {
00290 const StaticData &staticData = StaticData::Instance();
00291 SentenceStats &stats = m_manager.GetSentenceStats();
00292 clock_t t=0;
00293
00294 Hypothesis *newHypo;
00295 if (! staticData.UseEarlyDiscarding()) {
00296
00297 IFVERBOSE(2) {
00298 t = clock();
00299 }
00300 newHypo = hypothesis.CreateNext(transOpt, m_constraint);
00301 IFVERBOSE(2) {
00302 stats.AddTimeBuildHyp( clock()-t );
00303 }
00304 if (newHypo==NULL) return;
00305 newHypo->CalcScore(m_transOptColl.GetFutureScore());
00306 } else
00307
00308 {
00309
00310 size_t wordsTranslated = hypothesis.GetWordsBitmap().GetNumWordsCovered() + transOpt.GetSize();
00311 float allowedScore = m_hypoStackColl[wordsTranslated]->GetWorstScore();
00312 if (staticData.GetMinHypoStackDiversity()) {
00313 WordsBitmapID id = hypothesis.GetWordsBitmap().GetIDPlus(transOpt.GetStartPos(), transOpt.GetEndPos());
00314 float allowedScoreForBitmap = m_hypoStackColl[wordsTranslated]->GetWorstScoreForBitmap( id );
00315 allowedScore = std::min( allowedScore, allowedScoreForBitmap );
00316 }
00317 allowedScore += staticData.GetEarlyDiscardingThreshold();
00318
00319
00320 expectedScore += transOpt.GetFutureScore();
00321
00322
00323
00324 if (expectedScore < allowedScore) {
00325 IFVERBOSE(2) {
00326 stats.AddNotBuilt();
00327 }
00328 return;
00329 }
00330
00331
00332 IFVERBOSE(2) {
00333 t = clock();
00334 }
00335 newHypo = hypothesis.CreateNext(transOpt, m_constraint);
00336 if (newHypo==NULL) return;
00337 IFVERBOSE(2) {
00338 stats.AddTimeBuildHyp( clock()-t );
00339 }
00340
00341
00342 expectedScore = newHypo->CalcExpectedScore( m_transOptColl.GetFutureScore() );
00343
00344 if (expectedScore < allowedScore) {
00345 IFVERBOSE(2) {
00346 stats.AddEarlyDiscarded();
00347 }
00348 FREEHYPO( newHypo );
00349 return;
00350 }
00351
00352
00353 newHypo->CalcRemainingScore();
00354
00355 }
00356
00357
00358 IFVERBOSE(3) {
00359 newHypo->PrintHypothesis();
00360 }
00361
00362
00363 size_t wordsTranslated = newHypo->GetWordsBitmap().GetNumWordsCovered();
00364 IFVERBOSE(2) {
00365 t = clock();
00366 }
00367 m_hypoStackColl[wordsTranslated]->AddPrune(newHypo);
00368 IFVERBOSE(2) {
00369 stats.AddTimeStack( clock()-t );
00370 }
00371 }
00372
00373 const std::vector < HypothesisStack* >& SearchNormal::GetHypothesisStacks() const
00374 {
00375 return m_hypoStackColl;
00376 }
00377
00382 const Hypothesis *SearchNormal::GetBestHypothesis() const
00383 {
00384 if (interrupted_flag == 0) {
00385 const HypothesisStackNormal &hypoColl = *static_cast<HypothesisStackNormal*>(m_hypoStackColl.back());
00386 return hypoColl.GetBestHypothesis();
00387 } else {
00388 const HypothesisStackNormal &hypoColl = *actual_hypoStack;
00389 return hypoColl.GetBestHypothesis();
00390 }
00391 }
00392
00396 void SearchNormal::OutputHypoStackSize()
00397 {
00398 std::vector < HypothesisStack* >::const_iterator iterStack = m_hypoStackColl.begin();
00399 TRACE_ERR( "Stack sizes: " << (int)(*iterStack)->size());
00400 for (++iterStack; iterStack != m_hypoStackColl.end() ; ++iterStack) {
00401 TRACE_ERR( ", " << (int)(*iterStack)->size());
00402 }
00403 TRACE_ERR( endl);
00404 }
00405
00406 }