00001 #include "TargetNgramFeature.h"
00002 #include "moses/Phrase.h"
00003 #include "moses/TargetPhrase.h"
00004 #include "moses/Hypothesis.h"
00005 #include "moses/ScoreComponentCollection.h"
00006 #include "moses/ChartHypothesis.h"
00007 #include "util/exception.hh"
00008 #include "util/string_piece_hash.hh"
00009
00010 namespace Moses
00011 {
00012
00013 using namespace std;
00014
00015 size_t TargetNgramState::hash() const
00016 {
00017 std::size_t ret = boost::hash_range(m_words.begin(), m_words.end());
00018 return ret;
00019 }
00020
00021 bool TargetNgramState::operator==(const FFState& other) const
00022 {
00023 const TargetNgramState& rhs = static_cast<const TargetNgramState&>(other);
00024 bool result;
00025 if (m_words.size() == rhs.m_words.size()) {
00026 for (size_t i = 0; i < m_words.size(); ++i) {
00027 result = m_words[i] == rhs.m_words[i];
00028 if (!result) return false;
00029 }
00030 return true;
00031 } else if (m_words.size() < rhs.m_words.size()) {
00032 for (size_t i = 0; i < m_words.size(); ++i) {
00033 result = m_words[i] == rhs.m_words[i];
00034 if (!result) return false;
00035 }
00036 return true;
00037 } else {
00038 for (size_t i = 0; i < rhs.m_words.size(); ++i) {
00039 result = m_words[i] == rhs.m_words[i];
00040 if (!result) return false;
00041 }
00042 return true;
00043 }
00044 }
00045
00047 TargetNgramFeature::TargetNgramFeature(const std::string &line)
00048 :StatefulFeatureFunction(0, line)
00049 {
00050 std::cerr << "Initializing target ngram feature.." << std::endl;
00051
00052 ReadParameters();
00053
00054 FactorCollection& factorCollection = FactorCollection::Instance();
00055 const Factor* bosFactor = factorCollection.AddFactor(Output,m_factorType,BOS_);
00056 m_bos.SetFactor(m_factorType,bosFactor);
00057
00058 m_baseName = GetScoreProducerDescription();
00059 m_baseName.append("_");
00060 }
00061
00062 void TargetNgramFeature::SetParameter(const std::string& key, const std::string& value)
00063 {
00064 if (key == "factor") {
00065 m_factorType = Scan<FactorType>(value);
00066 } else if (key == "n") {
00067 m_n = Scan<size_t>(value);
00068 } else if (key == "lower-ngrams") {
00069 m_lower_ngrams = Scan<bool>(value);
00070 } else if (key == "file") {
00071 m_file = value;
00072 } else {
00073 StatefulFeatureFunction::SetParameter(key, value);
00074 }
00075 }
00076
00077 void TargetNgramFeature::Load(AllOptions::ptr const& opts)
00078 {
00079 m_options = opts;
00080 if (m_file == "") return;
00081
00082 if (m_file == "*") return;
00083 ifstream inFile(m_file.c_str());
00084 if (!inFile) {
00085 UTIL_THROW(util::Exception, "Couldn't open file" << m_file);
00086 }
00087
00088 std::string line;
00089 m_vocab.insert(BOS_);
00090 m_vocab.insert(EOS_);
00091 while (getline(inFile, line)) {
00092 m_vocab.insert(line);
00093 cerr << "ADD TO VOCAB: '" << line << "'" << endl;
00094 }
00095
00096 inFile.close();
00097 return;
00098 }
00099
00100 const FFState* TargetNgramFeature::EmptyHypothesisState(const InputType &) const
00101 {
00102 vector<Word> bos(1,m_bos);
00103 return new TargetNgramState(bos);
00104 }
00105
// Scores all target-side ngrams (order m_n, and all lower orders when
// m_lower_ngrams is set) that become complete when this hypothesis
// extends the translation; one sparse feature per ngram.  Returns the
// outgoing state: the trailing target words needed as context for the
// next extension.
FFState* TargetNgramFeature::EvaluateWhenApplied(const Hypothesis& cur_hypo,
    const FFState* prev_state,
    ScoreComponentCollection* accumulator) const
{
  const TargetNgramState* tnState = static_cast<const TargetNgramState*>(prev_state);
  assert(tnState);

  // Empty target phrase: nothing to score, context unchanged.
  const Phrase& targetPhrase = cur_hypo.GetCurrTargetPhrase();
  if (targetPhrase.GetSize() == 0) return new TargetNgramState(*tnState);

  vector<Word> prev_words(tnState->GetWords());
  util::StringStream curr_ngram;
  bool skip = false;

  // Smallest ngram order to emit (1 when lower-order ngrams are on).
  size_t smallest_n = m_n;
  if (m_lower_ngrams) smallest_n = 1;

  // Outer loop: ngram order; inner loop: each new target word is the
  // *last* word of an ngram of that order.
  for (size_t n = m_n; n >= smallest_n; --n) {
    for (size_t i = 0; i < targetPhrase.GetSize(); ++i) {

      const StringPiece curr_w = targetPhrase.GetWord(i).GetString(m_factorType);

      // Skip ngrams ending in a word outside the restricting vocabulary
      // (when one was loaded).
      if (m_vocab.size() && (FindStringPiece(m_vocab, curr_w) == m_vocab.end())) continue;

      if (n > 1) {
        // Too close to the sentence start to form an order-n ngram.
        size_t pos_in_translation = cur_hypo.GetSize() - targetPhrase.GetSize() + i;
        if (pos_in_translation < n - 2) continue;

        // How many context words must come from the previous state.
        int from_prev_state = n - (i+1);
        skip = false;
        if (from_prev_state > 0) {
          if (prev_words.size() < from_prev_state) {
            // Previous context is still too short: nothing scorable yet;
            // carry the full history (old context + this phrase) forward.
            vector<Word> new_prev_words;
            for (size_t i = 0; i < prev_words.size(); ++i)
              new_prev_words.push_back(prev_words[i]);
            for (size_t i = 0; i < targetPhrase.GetSize(); ++i)
              new_prev_words.push_back(targetPhrase.GetWord(i));
            return new TargetNgramState(new_prev_words);
          }

          // Context words taken from the tail of the previous state.
          for (size_t j = prev_words.size()-from_prev_state; j < prev_words.size() && !skip; ++j)
            appendNgram(prev_words[j], skip, curr_ngram);
        }

        // Context words taken from the current target phrase itself.
        int start = i - n + 1;
        if (start < 0) start = 0;
        for (size_t j = start; j < i && !skip; ++j)
          appendNgram(targetPhrase.GetWord(j), skip, curr_ngram);
      }

      // skip is set by appendNgram when a context word is OOV.
      if (!skip) {
        curr_ngram << curr_w;
        accumulator->PlusEquals(this,curr_ngram.str(),1);
      }
      curr_ngram.str("");
    }
  }

  if (cur_hypo.GetWordsBitmap().IsComplete()) {
    // Sentence fully translated: also score the ngrams ending in </s>.
    for (size_t n = m_n; n >= smallest_n; --n) {
      util::StringStream last_ngram;
      skip = false;
      for (size_t i = cur_hypo.GetSize() - n + 1; i < cur_hypo.GetSize() && !skip; ++i)
        appendNgram(cur_hypo.GetWord(i), skip, last_ngram);

      if (n > 1 && !skip) {
        last_ngram << EOS_;
        accumulator->PlusEquals(this, last_ngram.str(), 1);
      }
    }
    // No further extension possible; an empty state suffices.
    return new TargetNgramState();
  }

  // Build the outgoing context: the last (m_n - 1) target words so far.
  vector<Word> new_prev_words;
  if (targetPhrase.GetSize() >= m_n-1) {
    // The current phrase alone provides enough context.
    for (size_t i = targetPhrase.GetSize() - m_n + 1; i < targetPhrase.GetSize(); ++i)
      new_prev_words.push_back(targetPhrase.GetWord(i));
  } else {
    // Take the tail of the previous context, then the current phrase.
    int from_prev_state = m_n - 1 - targetPhrase.GetSize();
    for (size_t i = prev_words.size()-from_prev_state; i < prev_words.size(); ++i)
      new_prev_words.push_back(prev_words[i]);
    for (size_t i = 0; i < targetPhrase.GetSize(); ++i)
      new_prev_words.push_back(targetPhrase.GetWord(i));
  }
  return new TargetNgramState(new_prev_words);
}
00206
00207 void TargetNgramFeature::appendNgram(const Word& word, bool& skip, util::StringStream &ngram) const
00208 {
00209
00210 const StringPiece w = word.GetString(m_factorType);
00211 if (m_vocab.size() && (FindStringPiece(m_vocab, w) == m_vocab.end())) skip = true;
00212 else {
00213 ngram << w;
00214 ngram << ":";
00215 }
00216 }
00217
00218 FFState* TargetNgramFeature::EvaluateWhenApplied(const ChartHypothesis& cur_hypo, int featureId, ScoreComponentCollection* accumulator) const
00219 {
00220 vector<const Word*> contextFactor;
00221 contextFactor.reserve(m_n);
00222
00223
00224 const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
00225 cur_hypo.GetCurrTargetPhrase().GetAlignNonTerm().GetNonTermIndexMap();
00226
00227
00228 bool makePrefix = false;
00229 bool makeSuffix = false;
00230 bool collectForPrefix = true;
00231 size_t prefixTerminals = 0;
00232 size_t suffixTerminals = 0;
00233 bool onlyTerminals = true;
00234 bool prev_is_NT = false;
00235 size_t prev_subPhraseLength = 0;
00236 for (size_t phrasePos = 0; phrasePos < cur_hypo.GetCurrTargetPhrase().GetSize(); phrasePos++) {
00237
00238 const Word &word = cur_hypo.GetCurrTargetPhrase().GetWord(phrasePos);
00239
00240
00241
00242 if (!word.IsNonTerminal()) {
00243 contextFactor.push_back(&word);
00244 prev_is_NT = false;
00245
00246 if (phrasePos==0)
00247 makePrefix = true;
00248 if (phrasePos==cur_hypo.GetCurrTargetPhrase().GetSize()-1 || prev_is_NT)
00249 makeSuffix = true;
00250
00251
00252 StringPiece factorZero = word.GetString(0);
00253 if (factorZero.compare("<s>") == 0)
00254 prefixTerminals++;
00255
00256 else if (factorZero.compare("</s>") == 0)
00257 suffixTerminals++;
00258
00259 else {
00260 util::StringStream ngram;
00261 ngram << m_baseName;
00262 if (m_factorType == 0)
00263 ngram << factorZero;
00264 else
00265 ngram << word.GetString(m_factorType);
00266 accumulator->SparsePlusEquals(ngram.str(), 1);
00267
00268 if (collectForPrefix)
00269 prefixTerminals++;
00270 else
00271 suffixTerminals++;
00272 }
00273 }
00274
00275
00276 else if (m_n > 1) {
00277
00278 size_t nonTermIndex = nonTermIndexMap[phrasePos];
00279 const ChartHypothesis *prevHypo = cur_hypo.GetPrevHypo(nonTermIndex);
00280
00281 const TargetNgramChartState* prevState =
00282 static_cast<const TargetNgramChartState*>(prevHypo->GetFFState(featureId));
00283 size_t subPhraseLength = prevState->GetNumTargetTerminals();
00284
00285
00286 if (phrasePos == 0) {
00287 if (subPhraseLength == 1) {
00288 makePrefix = true;
00289 ++prefixTerminals;
00290
00291 const Word &word = prevState->GetSuffix().GetWord(0);
00292
00293 contextFactor.push_back(&word);
00294 } else {
00295 onlyTerminals = false;
00296 collectForPrefix = false;
00297 int suffixPos = prevState->GetSuffix().GetSize() - (m_n-1);
00298 if (suffixPos < 0) suffixPos = 0;
00299 for(; (size_t)suffixPos < prevState->GetSuffix().GetSize(); suffixPos++) {
00300 const Word &word = prevState->GetSuffix().GetWord(suffixPos);
00301
00302 contextFactor.push_back(&word);
00303 }
00304 }
00305 }
00306
00307
00308 else {
00309
00310 for(size_t prefixPos = 0; prefixPos < m_n-1
00311 && prefixPos < subPhraseLength; prefixPos++) {
00312 const Word &word = prevState->GetPrefix().GetWord(prefixPos);
00313
00314 contextFactor.push_back(&word);
00315 }
00316
00317 if (subPhraseLength==1) {
00318 if (collectForPrefix)
00319 ++prefixTerminals;
00320 else
00321 ++suffixTerminals;
00322
00323 if (phrasePos == cur_hypo.GetCurrTargetPhrase().GetSize()-1)
00324 makeSuffix = true;
00325 } else {
00326 onlyTerminals = false;
00327 collectForPrefix = true;
00328
00329
00330 bool wordFollowing = (phrasePos < cur_hypo.GetCurrTargetPhrase().GetSize() - 1)? true : false;
00331
00332
00333 if (wordFollowing && subPhraseLength > m_n - 1) {
00334
00335 MakePrefixNgrams(contextFactor, accumulator, prefixTerminals);
00336 contextFactor.clear();
00337 makePrefix = false;
00338 makeSuffix = true;
00339 collectForPrefix = false;
00340 prefixTerminals = 0;
00341 suffixTerminals = 0;
00342
00343
00344 size_t remainingWords = (remainingWords > m_n-1) ? m_n-1 : subPhraseLength - (m_n-1);
00345 for(size_t suffixPos = 0; suffixPos < prevState->GetSuffix().GetSize(); suffixPos++) {
00346 const Word &word = prevState->GetSuffix().GetWord(suffixPos);
00347
00348 contextFactor.push_back(&word);
00349 }
00350 }
00351
00352 else if (wordFollowing && subPhraseLength == m_n - 1) {
00353
00354 MakePrefixNgrams(contextFactor, accumulator, prefixTerminals);
00355 makePrefix = false;
00356 makeSuffix = true;
00357 collectForPrefix = false;
00358 prefixTerminals = 0;
00359 suffixTerminals = 0;
00360 } else if (prev_is_NT && prev_subPhraseLength > 1 && subPhraseLength > 1) {
00361
00362 MakePrefixNgrams(contextFactor, accumulator, 1, m_n-2);
00363 MakeSuffixNgrams(contextFactor, accumulator, 1, m_n-2);
00364 makePrefix = false;
00365 makeSuffix = false;
00366 collectForPrefix = false;
00367 prefixTerminals = 0;
00368 suffixTerminals = 0;
00369
00370
00371 util::StringStream curr_ngram;
00372 curr_ngram << m_baseName;
00373 curr_ngram << (*contextFactor[m_n-2]).GetString(m_factorType);
00374 curr_ngram << ":";
00375 curr_ngram << (*contextFactor[m_n-1]).GetString(m_factorType);
00376 accumulator->SparseMinusEquals(curr_ngram.str(),1);
00377 }
00378 }
00379 }
00380 prev_is_NT = true;
00381 prev_subPhraseLength = subPhraseLength;
00382 }
00383 }
00384
00385 if (m_n > 1) {
00386 if (onlyTerminals) {
00387 MakePrefixNgrams(contextFactor, accumulator, prefixTerminals-1);
00388 } else {
00389 if (makePrefix)
00390 MakePrefixNgrams(contextFactor, accumulator, prefixTerminals);
00391 if (makeSuffix)
00392 MakeSuffixNgrams(contextFactor, accumulator, suffixTerminals);
00393
00394
00395 size_t size = contextFactor.size();
00396 if (makePrefix && makeSuffix && (size <= m_n)) {
00397 util::StringStream curr_ngram;
00398 curr_ngram << m_baseName;
00399 for (size_t i = 0; i < size; ++i) {
00400 curr_ngram << (*contextFactor[i]).GetString(m_factorType);
00401 if (i < size-1)
00402 curr_ngram << ":";
00403 }
00404 accumulator->SparseMinusEquals(curr_ngram.str(), 1);
00405 }
00406 }
00407 }
00408
00409
00410 return new TargetNgramChartState(cur_hypo, featureId, m_n);
00411 }
00412
00413 void TargetNgramFeature::MakePrefixNgrams(std::vector<const Word*> &contextFactor, ScoreComponentCollection* accumulator, size_t numberOfStartPos, size_t offset) const
00414 {
00415 util::StringStream ngram;
00416 size_t size = contextFactor.size();
00417 for (size_t k = 0; k < numberOfStartPos; ++k) {
00418 size_t max_end = (size < m_n+k+offset)? size: m_n+k+offset;
00419 for (size_t end_pos = 1+k+offset; end_pos < max_end; ++end_pos) {
00420 ngram << m_baseName;
00421 for (size_t i=k+offset; i <= end_pos; ++i) {
00422 if (i > k+offset)
00423 ngram << ":";
00424 StringPiece factorZero = (*contextFactor[i]).GetString(0);
00425 if (m_factorType == 0 || factorZero.compare("<s>") == 0 || factorZero.compare("</s>") == 0)
00426 ngram << factorZero;
00427 else
00428 ngram << (*contextFactor[i]).GetString(m_factorType);
00429 const Word w = *contextFactor[i];
00430 }
00431
00432 accumulator->SparsePlusEquals(ngram.str(), 1);
00433 ngram.str("");
00434 }
00435 }
00436 }
00437
00438 void TargetNgramFeature::MakeSuffixNgrams(std::vector<const Word*> &contextFactor, ScoreComponentCollection* accumulator, size_t numberOfEndPos, size_t offset) const
00439 {
00440 util::StringStream ngram;
00441 for (size_t k = 0; k < numberOfEndPos; ++k) {
00442 size_t end_pos = contextFactor.size()-1-k-offset;
00443 for (int start_pos=end_pos-1; (start_pos >= 0) && (end_pos-start_pos < m_n); --start_pos) {
00444 ngram << m_baseName;
00445 for (size_t j=start_pos; j <= end_pos; ++j) {
00446 StringPiece factorZero = (*contextFactor[j]).GetString(0);
00447 if (m_factorType == 0 || factorZero.compare("<s>") == 0 || factorZero.compare("</s>") == 0)
00448 ngram << factorZero;
00449 else
00450 ngram << (*contextFactor[j]).GetString(m_factorType);
00451 if (j < end_pos)
00452 ngram << ":";
00453 }
00454
00455 accumulator->SparsePlusEquals(ngram.str(), 1);
00456 ngram.str("");
00457 }
00458 }
00459 }
00460
00461 bool TargetNgramFeature::IsUseable(const FactorMask &mask) const
00462 {
00463 bool ret = mask[m_factorType];
00464 return ret;
00465 }
00466
00467 }
00468