00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021 #include <limits>
00022 #include <iostream>
00023 #include <memory>
00024 #include <sstream>
00025
00026 #include "moses/FF/FFState.h"
00027 #include "Implementation.h"
00028 #include "ChartState.h"
00029 #include "moses/TypeDef.h"
00030 #include "moses/Util.h"
00031 #include "moses/Manager.h"
00032 #include "moses/FactorCollection.h"
00033 #include "moses/Phrase.h"
00034 #include "moses/StaticData.h"
00035 #include "moses/ChartManager.h"
00036 #include "moses/ChartHypothesis.h"
00037 #include "util/exception.hh"
00038
00039 using namespace std;
00040
00041 namespace Moses
00042 {
00043 LanguageModelImplementation::LanguageModelImplementation(const std::string &line)
00044 :LanguageModel(line)
00045 ,m_nGramOrder(NOT_FOUND)
00046 {
00047 }
00048
00049 void LanguageModelImplementation::SetParameter(const std::string& key, const std::string& value)
00050 {
00051 if (key == "order") {
00052 m_nGramOrder = Scan<size_t>(value);
00053 } else if (key == "path") {
00054 m_filePath = value;
00055 } else {
00056 LanguageModel::SetParameter(key, value);
00057 }
00058
00059 }
00060
00061 void LanguageModelImplementation::ShiftOrPush(std::vector<const Word*> &contextFactor, const Word &word) const
00062 {
00063 if (contextFactor.size() < GetNGramOrder()) {
00064 contextFactor.push_back(&word);
00065 } else if (GetNGramOrder() > 0) {
00066
00067 for (size_t currNGramOrder = 0 ; currNGramOrder < GetNGramOrder() - 1 ; currNGramOrder++) {
00068 contextFactor[currNGramOrder] = contextFactor[currNGramOrder + 1];
00069 }
00070 contextFactor[GetNGramOrder() - 1] = &word;
00071 }
00072 }
00073
00074 LMResult LanguageModelImplementation::GetValueGivenState(
00075 const std::vector<const Word*> &contextFactor,
00076 FFState &state) const
00077 {
00078 return GetValueForgotState(contextFactor, state);
00079 }
00080
00081 void LanguageModelImplementation::GetState(
00082 const std::vector<const Word*> &contextFactor,
00083 FFState &state) const
00084 {
00085 GetValueForgotState(contextFactor, state);
00086 }
00087
00088
00089 void LanguageModelImplementation::CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, size_t &oovCount) const
00090 {
00091 fullScore = 0;
00092 ngramScore = 0;
00093
00094 oovCount = 0;
00095
00096 size_t phraseSize = phrase.GetSize();
00097 if (!phraseSize) return;
00098
00099 vector<const Word*> contextFactor;
00100 contextFactor.reserve(GetNGramOrder());
00101 std::auto_ptr<FFState> state(NewState((phrase.GetWord(0) == GetSentenceStartWord()) ?
00102 GetBeginSentenceState() : GetNullContextState()));
00103 size_t currPos = 0;
00104 while (currPos < phraseSize) {
00105 const Word &word = phrase.GetWord(currPos);
00106
00107 if (word.IsNonTerminal()) {
00108
00109 if (!contextFactor.empty()) {
00110
00111 state.reset(NewState(GetNullContextState()));
00112 contextFactor.clear();
00113 }
00114 } else {
00115 ShiftOrPush(contextFactor, word);
00116 UTIL_THROW_IF2(contextFactor.size() > GetNGramOrder(),
00117 "Can only calculate LM score of phrases up to the n-gram order");
00118
00119 if (word == GetSentenceStartWord()) {
00120
00121 if (currPos != 0) {
00122 UTIL_THROW2("Either your data contains <s> in a position other than the first word or your language model is missing <s>. Did you build your ARPA using IRSTLM and forget to run add-start-end.sh?");
00123 }
00124 } else {
00125 LMResult result = GetValueGivenState(contextFactor, *state);
00126 fullScore += result.score;
00127 if (contextFactor.size() == GetNGramOrder())
00128 ngramScore += result.score;
00129 if (result.unknown) ++oovCount;
00130 }
00131 }
00132
00133 currPos++;
00134 }
00135 }
00136
00137 FFState *LanguageModelImplementation::EvaluateWhenApplied(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out) const
00138 {
00139
00140
00141
00142
00143
00144
00145 if(GetNGramOrder() <= 1)
00146 return NULL;
00147
00148
00149 if (hypo.GetCurrTargetLength() == 0)
00150 return ps ? NewState(ps) : NULL;
00151
00152 IFVERBOSE(2) {
00153 hypo.GetManager().GetSentenceStats().StartTimeCalcLM();
00154 }
00155
00156 const size_t currEndPos = hypo.GetCurrTargetWordsRange().GetEndPos();
00157 const size_t startPos = hypo.GetCurrTargetWordsRange().GetStartPos();
00158
00159
00160 vector<const Word*> contextFactor(GetNGramOrder());
00161 size_t index = 0;
00162 for (int currPos = (int) startPos - (int) GetNGramOrder() + 1 ; currPos <= (int) startPos ; currPos++) {
00163 if (currPos >= 0)
00164 contextFactor[index++] = &hypo.GetWord(currPos);
00165 else {
00166 contextFactor[index++] = &GetSentenceStartWord();
00167 }
00168 }
00169 FFState *res = NewState(ps);
00170 float lmScore = ps ? GetValueGivenState(contextFactor, *res).score : GetValueForgotState(contextFactor, *res).score;
00171
00172
00173 size_t endPos = std::min(startPos + GetNGramOrder() - 2
00174 , currEndPos);
00175 for (size_t currPos = startPos + 1 ; currPos <= endPos ; currPos++) {
00176
00177 for (size_t i = 0 ; i < GetNGramOrder() - 1 ; i++)
00178 contextFactor[i] = contextFactor[i + 1];
00179
00180
00181 contextFactor.back() = &hypo.GetWord(currPos);
00182
00183 lmScore += GetValueGivenState(contextFactor, *res).score;
00184 }
00185
00186
00187 if (hypo.IsSourceCompleted()) {
00188 const size_t size = hypo.GetSize();
00189 contextFactor.back() = &GetSentenceEndWord();
00190
00191 for (size_t i = 0 ; i < GetNGramOrder() - 1 ; i ++) {
00192 int currPos = (int)(size - GetNGramOrder() + i + 1);
00193 if (currPos < 0)
00194 contextFactor[i] = &GetSentenceStartWord();
00195 else
00196 contextFactor[i] = &hypo.GetWord((size_t)currPos);
00197 }
00198 lmScore += GetValueForgotState(contextFactor, *res).score;
00199 } else {
00200 if (endPos < currEndPos) {
00201
00202 for (size_t currPos = endPos+1; currPos <= currEndPos; currPos++) {
00203 for (size_t i = 0 ; i < GetNGramOrder() - 1 ; i++)
00204 contextFactor[i] = contextFactor[i + 1];
00205 contextFactor.back() = &hypo.GetWord(currPos);
00206 }
00207 GetState(contextFactor, *res);
00208 }
00209 }
00210 if (OOVFeatureEnabled()) {
00211 vector<float> scores(2);
00212 scores[0] = lmScore;
00213 scores[1] = 0;
00214 out->PlusEquals(this, scores);
00215 } else {
00216 out->PlusEquals(this, lmScore);
00217 }
00218
00219 IFVERBOSE(2) {
00220 hypo.GetManager().GetSentenceStats().StopTimeCalcLM();
00221 }
00222 return res;
00223 }
00224
00225 FFState* LanguageModelImplementation::EvaluateWhenApplied(const ChartHypothesis& hypo, int featureID, ScoreComponentCollection* out) const
00226 {
00227 LanguageModelChartState *ret = new LanguageModelChartState(hypo, featureID, GetNGramOrder());
00228
00229 vector<const Word*> contextFactor;
00230 contextFactor.reserve(GetNGramOrder());
00231
00232
00233 FFState *lmState = NewState( GetNullContextState() );
00234
00235
00236 float prefixScore = 0.0;
00237 float finalizedScore = 0.0;
00238
00239
00240 const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
00241 hypo.GetCurrTargetPhrase().GetAlignNonTerm().GetNonTermIndexMap();
00242
00243
00244 for (size_t phrasePos = 0, wordPos = 0;
00245 phrasePos < hypo.GetCurrTargetPhrase().GetSize();
00246 phrasePos++) {
00247
00248 const Word &word = hypo.GetCurrTargetPhrase().GetWord(phrasePos);
00249
00250
00251 if (!word.IsNonTerminal()) {
00252 ShiftOrPush(contextFactor, word);
00253
00254
00255 if (word == GetSentenceStartWord()) {
00256 UTIL_THROW_IF2(phrasePos != 0,
00257 "Sentence start symbol must be at the beginning of sentence");
00258 delete lmState;
00259 lmState = NewState( GetBeginSentenceState() );
00260 }
00261
00262 else {
00263 updateChartScore( &prefixScore, &finalizedScore, GetValueGivenState(contextFactor, *lmState).score, ++wordPos );
00264 }
00265 }
00266
00267
00268 else {
00269
00270 size_t nonTermIndex = nonTermIndexMap[phrasePos];
00271 const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermIndex);
00272
00273 const LanguageModelChartState* prevState =
00274 static_cast<const LanguageModelChartState*>(prevHypo->GetFFState(featureID));
00275
00276 size_t subPhraseLength = prevState->GetNumTargetTerminals();
00277
00278
00279 if (phrasePos == 0) {
00280
00281
00282 prefixScore = prevState->GetPrefixScore();
00283 finalizedScore = -prefixScore;
00284
00285
00286 delete lmState;
00287 lmState = NewState( prevState->GetRightContext() );
00288
00289
00290 int suffixPos = prevState->GetSuffix().GetSize() - (GetNGramOrder()-1);
00291 if (suffixPos < 0) suffixPos = 0;
00292 for(; (size_t)suffixPos < prevState->GetSuffix().GetSize(); suffixPos++) {
00293 const Word &word = prevState->GetSuffix().GetWord(suffixPos);
00294 ShiftOrPush(contextFactor, word);
00295 wordPos++;
00296 }
00297 }
00298
00299
00300 else {
00301
00302 for(size_t prefixPos = 0;
00303 prefixPos < GetNGramOrder()-1
00304 && prefixPos < subPhraseLength;
00305 prefixPos++) {
00306 const Word &word = prevState->GetPrefix().GetWord(prefixPos);
00307 ShiftOrPush(contextFactor, word);
00308 updateChartScore( &prefixScore, &finalizedScore, GetValueGivenState(contextFactor, *lmState).score, ++wordPos );
00309 }
00310
00311 finalizedScore -= prevState->GetPrefixScore();
00312
00313
00314 if (subPhraseLength > GetNGramOrder() - 1) {
00315
00316 delete lmState;
00317 lmState = NewState( prevState->GetRightContext() );
00318
00319
00320 size_t remainingWords = subPhraseLength - (GetNGramOrder()-1);
00321 if (remainingWords > GetNGramOrder()-1) {
00322
00323 remainingWords = GetNGramOrder()-1;
00324 }
00325 for(size_t suffixPos = prevState->GetSuffix().GetSize() - remainingWords;
00326 suffixPos < prevState->GetSuffix().GetSize();
00327 suffixPos++) {
00328 const Word &word = prevState->GetSuffix().GetWord(suffixPos);
00329 ShiftOrPush(contextFactor, word);
00330 }
00331 wordPos += subPhraseLength;
00332 }
00333 }
00334 }
00335 }
00336
00337
00338 if (OOVFeatureEnabled()) {
00339 vector<float> scores(2);
00340 scores[0] = prefixScore + finalizedScore - hypo.GetTranslationOption().GetScores().GetScoresForProducer(this)[0];
00341
00342 scores[1] = 0;
00343 out->PlusEquals(this, scores);
00344 } else {
00345 out->PlusEquals(this, prefixScore + finalizedScore - hypo.GetTranslationOption().GetScores().GetScoresForProducer(this)[0]);
00346 }
00347
00348 ret->Set(prefixScore, lmState);
00349 return ret;
00350 }
00351
00352 void LanguageModelImplementation::updateChartScore(float *prefixScore, float *finalizedScore, float score, size_t wordPos) const
00353 {
00354 if (wordPos < GetNGramOrder()) {
00355 *prefixScore += score;
00356 } else {
00357 *finalizedScore += score;
00358 }
00359 }
00360
00361 }