00001 #include <sstream>
00002 #include <boost/algorithm/string.hpp>
00003 #include "WordTranslationFeature.h"
00004 #include "moses/Phrase.h"
00005 #include "moses/TargetPhrase.h"
00006 #include "moses/Hypothesis.h"
00007 #include "moses/ChartHypothesis.h"
00008 #include "moses/ScoreComponentCollection.h"
00009 #include "moses/TranslationOption.h"
00010 #include "moses/InputPath.h"
00011 #include "util/string_piece_hash.hh"
00012 #include "util/exception.hh"
00013
00014 using namespace std;
00015
00016 namespace Moses
00017 {
00018
00019 WordTranslationFeature::WordTranslationFeature(const std::string &line)
00020 :StatelessFeatureFunction(0, line)
00021 ,m_unrestricted(true)
00022 ,m_simple(true)
00023 ,m_sourceContext(false)
00024 ,m_targetContext(false)
00025 ,m_domainTrigger(false)
00026 ,m_ignorePunctuation(false)
00027 {
00028 VERBOSE(1, "Initializing feature " << GetScoreProducerDescription() << " ...");
00029 ReadParameters();
00030
00031 if (m_simple == 1) VERBOSE(1, " Using simple word translations.");
00032 if (m_sourceContext == 1) VERBOSE(1, " Using source context.");
00033 if (m_targetContext == 1) VERBOSE(1, " Using target context.");
00034 if (m_domainTrigger == 1) VERBOSE(1, " Using domain triggers.");
00035
00036
00037 if (m_ignorePunctuation) {
00038 VERBOSE(1, " Ignoring punctuation for triggers.");
00039 char punctuation[] = "\"'!?¿·()#_,.:;•&@‑/\\0123456789~=";
00040 for (size_t i=0; i < sizeof(punctuation)-1; ++i) {
00041 m_punctuationHash[punctuation[i]] = 1;
00042 }
00043 }
00044
00045 VERBOSE(1, " Done." << std::endl);
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061 }
00062
00063 void WordTranslationFeature::SetParameter(const std::string& key, const std::string& value)
00064 {
00065 if (key == "input-factor") {
00066 m_factorTypeSource = Scan<FactorType>(value);
00067 } else if (key == "output-factor") {
00068 m_factorTypeTarget = Scan<FactorType>(value);
00069 } else if (key == "simple") {
00070 m_simple = Scan<bool>(value);
00071 } else if (key == "source-context") {
00072 m_sourceContext = Scan<bool>(value);
00073 } else if (key == "target-context") {
00074 m_targetContext = Scan<bool>(value);
00075 } else if (key == "ignore-punctuation") {
00076 m_ignorePunctuation = Scan<bool>(value);
00077 } else if (key == "domain-trigger") {
00078 m_domainTrigger = Scan<bool>(value);
00079 } else if (key == "texttype") {
00080
00081 } else if (key == "source-path") {
00082 m_filePathSource = value;
00083 } else if (key == "target-path") {
00084 m_filePathTarget = value;
00085 } else {
00086 StatelessFeatureFunction::SetParameter(key, value);
00087 }
00088 }
00089
00090 void WordTranslationFeature::Load(AllOptions::ptr const& opts)
00091 {
00092 m_options = opts;
00093
00094 if (m_filePathSource.empty()) {
00095 return;
00096 }
00097
00098 FEATUREVERBOSE(1, "Loading word translation word lists from " << m_filePathSource << " and " << m_filePathTarget << std::endl);
00099 if (m_domainTrigger) {
00100
00101 ifstream inFileSource(m_filePathSource.c_str());
00102 UTIL_THROW_IF2(!inFileSource, "could not open file " << m_filePathSource);
00103
00104 std::string line;
00105 while (getline(inFileSource, line)) {
00106 m_vocabDomain.resize(m_vocabDomain.size() + 1);
00107 vector<string> termVector;
00108 boost::split(termVector, line, boost::is_any_of("\t "));
00109 for (size_t i=0; i < termVector.size(); ++i)
00110 m_vocabDomain.back().insert(termVector[i]);
00111 }
00112
00113 inFileSource.close();
00114 } else {
00115
00116 ifstream inFileSource(m_filePathSource.c_str());
00117 UTIL_THROW_IF2(!inFileSource, "could not open file " << m_filePathSource);
00118
00119 std::string line;
00120 while (getline(inFileSource, line)) {
00121 m_vocabSource.insert(line);
00122 }
00123
00124 inFileSource.close();
00125
00126
00127 ifstream inFileTarget(m_filePathTarget.c_str());
00128 UTIL_THROW_IF2(!inFileTarget, "could not open file " << m_filePathTarget);
00129
00130 while (getline(inFileTarget, line)) {
00131 m_vocabTarget.insert(line);
00132 }
00133
00134 inFileTarget.close();
00135
00136 m_unrestricted = false;
00137 }
00138 }
00139
00140 void WordTranslationFeature::EvaluateWithSourceContext(const InputType &input
00141 , const InputPath &inputPath
00142 , const TargetPhrase &targetPhrase
00143 , const StackVec *stackVec
00144 , ScoreComponentCollection &scoreBreakdown
00145 , ScoreComponentCollection *estimatedScores) const
00146 {
00147 const Sentence& sentence = static_cast<const Sentence&>(input);
00148 const AlignmentInfo &alignment = targetPhrase.GetAlignTerm();
00149
00150
00151 for (AlignmentInfo::const_iterator alignmentPoint = alignment.begin(); alignmentPoint != alignment.end(); alignmentPoint++) {
00152 const Phrase& sourcePhrase = inputPath.GetPhrase();
00153 int sourceIndex = alignmentPoint->first;
00154 int targetIndex = alignmentPoint->second;
00155 Word ws = sourcePhrase.GetWord(sourceIndex);
00156 if (m_factorTypeSource == 0 && ws.IsNonTerminal()) continue;
00157 Word wt = targetPhrase.GetWord(targetIndex);
00158 if (m_factorTypeSource == 0 && wt.IsNonTerminal()) continue;
00159 StringPiece sourceWord = ws.GetFactor(m_factorTypeSource)->GetString();
00160 StringPiece targetWord = wt.GetFactor(m_factorTypeTarget)->GetString();
00161 if (m_ignorePunctuation) {
00162
00163 char firstChar = sourceWord[0];
00164 CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar );
00165 if(charIterator != m_punctuationHash.end())
00166 continue;
00167 firstChar = targetWord[0];
00168 charIterator = m_punctuationHash.find( firstChar );
00169 if(charIterator != m_punctuationHash.end())
00170 continue;
00171 }
00172
00173 if (!m_unrestricted) {
00174 if (FindStringPiece(m_vocabSource, sourceWord) == m_vocabSource.end())
00175 sourceWord = "OTHER";
00176 if (FindStringPiece(m_vocabTarget, targetWord) == m_vocabTarget.end())
00177 targetWord = "OTHER";
00178 }
00179
00180 if (m_simple) {
00181
00182 util::StringStream featureName;
00183 featureName << m_description << "_";
00184 featureName << sourceWord;
00185 featureName << "~";
00186 featureName << targetWord;
00187 scoreBreakdown.SparsePlusEquals(featureName.str(), 1);
00188 }
00189 if (m_domainTrigger && !m_sourceContext) {
00190 const bool use_topicid = sentence.GetUseTopicId();
00191 const bool use_topicid_prob = sentence.GetUseTopicIdAndProb();
00192 if (use_topicid || use_topicid_prob) {
00193 if(use_topicid) {
00194
00195 const long topicid = sentence.GetTopicId();
00196 util::StringStream feature;
00197 feature << m_description << "_";
00198 if (topicid == -1)
00199 feature << "unk";
00200 else
00201 feature << topicid;
00202
00203 feature << "_";
00204 feature << sourceWord;
00205 feature << "~";
00206 feature << targetWord;
00207 scoreBreakdown.SparsePlusEquals(feature.str(), 1);
00208 } else {
00209
00210 const vector<string> &topicid_prob = *(input.GetTopicIdAndProb());
00211 if (atol(topicid_prob[0].c_str()) == -1) {
00212 util::StringStream feature;
00213 feature << m_description << "_unk_";
00214 feature << sourceWord;
00215 feature << "~";
00216 feature << targetWord;
00217 scoreBreakdown.SparsePlusEquals(feature.str(), 1);
00218 } else {
00219 for (size_t i=0; i+1 < topicid_prob.size(); i+=2) {
00220 util::StringStream feature;
00221 feature << m_description << "_";
00222 feature << topicid_prob[i];
00223 feature << "_";
00224 feature << sourceWord;
00225 feature << "~";
00226 feature << targetWord;
00227 scoreBreakdown.SparsePlusEquals(feature.str(), atof((topicid_prob[i+1]).c_str()));
00228 }
00229 }
00230 }
00231 } else {
00232
00233 const long docid = input.GetDocumentId();
00234 for (boost::unordered_set<std::string>::const_iterator p = m_vocabDomain[docid].begin(); p != m_vocabDomain[docid].end(); ++p) {
00235 string sourceTrigger = *p;
00236 util::StringStream feature;
00237 feature << m_description << "_";
00238 feature << sourceTrigger;
00239 feature << "_";
00240 feature << sourceWord;
00241 feature << "~";
00242 feature << targetWord;
00243 scoreBreakdown.SparsePlusEquals(feature.str(), 1);
00244 }
00245 }
00246 }
00247 if (m_sourceContext) {
00248 size_t globalSourceIndex = inputPath.GetWordsRange().GetStartPos() + sourceIndex;
00249 if (!m_domainTrigger && globalSourceIndex == 0) {
00250
00251 util::StringStream feature;
00252 feature << m_description << "_";
00253 feature << "<s>,";
00254 feature << sourceWord;
00255 feature << "~";
00256 feature << targetWord;
00257 scoreBreakdown.SparsePlusEquals(feature.str(), 1);
00258 }
00259
00260
00261 for(size_t contextIndex = 0; contextIndex < input.GetSize(); contextIndex++ ) {
00262 if (contextIndex == globalSourceIndex) continue;
00263 StringPiece sourceTrigger = input.GetWord(contextIndex).GetFactor(m_factorTypeSource)->GetString();
00264 if (m_ignorePunctuation) {
00265
00266 char firstChar = sourceTrigger[0];
00267 CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar );
00268 if(charIterator != m_punctuationHash.end())
00269 continue;
00270 }
00271
00272 const long docid = input.GetDocumentId();
00273 bool sourceTriggerExists = false;
00274 if (m_domainTrigger)
00275 sourceTriggerExists = FindStringPiece(m_vocabDomain[docid], sourceTrigger ) != m_vocabDomain[docid].end();
00276 else if (!m_unrestricted)
00277 sourceTriggerExists = FindStringPiece(m_vocabSource, sourceTrigger ) != m_vocabSource.end();
00278
00279 if (m_domainTrigger) {
00280 if (sourceTriggerExists) {
00281 util::StringStream feature;
00282 feature << m_description << "_";
00283 feature << sourceTrigger;
00284 feature << "_";
00285 feature << sourceWord;
00286 feature << "~";
00287 feature << targetWord;
00288 scoreBreakdown.SparsePlusEquals(feature.str(), 1);
00289 }
00290 } else if (m_unrestricted || sourceTriggerExists) {
00291 util::StringStream feature;
00292 feature << m_description << "_";
00293 if (contextIndex < globalSourceIndex) {
00294 feature << sourceTrigger;
00295 feature << ",";
00296 feature << sourceWord;
00297 } else {
00298 feature << sourceWord;
00299 feature << ",";
00300 feature << sourceTrigger;
00301 }
00302 feature << "~";
00303 feature << targetWord;
00304 scoreBreakdown.SparsePlusEquals(feature.str(), 1);
00305 }
00306 }
00307 }
00308 if (m_targetContext) {
00309 throw runtime_error("Can't use target words outside current translation option in a stateless feature");
00310
00311
00312
00313
00314
00315
00316
00317
00318
00319
00320
00321
00322
00323
00324
00325
00326
00327
00328
00329
00330
00331
00332
00333
00334
00335
00336
00337
00338
00339
00340
00341
00342
00343
00344
00345
00346
00347
00348
00349 }
00350 }
00351 }
00352
00353 bool WordTranslationFeature::IsUseable(const FactorMask &mask) const
00354 {
00355 bool ret = mask[m_factorTypeTarget];
00356 return ret;
00357 }
00358
00359 }