00001 #include <boost/algorithm/string.hpp>
00002
00003 #include "PhrasePairFeature.h"
00004 #include "moses/AlignmentInfo.h"
00005 #include "moses/TargetPhrase.h"
00006 #include "moses/Hypothesis.h"
00007 #include "moses/TranslationOption.h"
00008 #include "moses/InputPath.h"
00009 #include "util/string_piece_hash.hh"
00010 #include "util/string_stream.hh"
00011 #include "util/exception.hh"
00012
00013 using namespace std;
00014
00015 namespace Moses
00016 {
00017
00018 PhrasePairFeature::PhrasePairFeature(const std::string &line)
00019 :StatelessFeatureFunction(0, line)
00020 ,m_unrestricted(false)
00021 ,m_simple(true)
00022 ,m_sourceContext(false)
00023 ,m_domainTrigger(false)
00024 ,m_ignorePunctuation(false)
00025 {
00026 VERBOSE(1, "Initializing feature " << GetScoreProducerDescription() << " ...");
00027 ReadParameters();
00028
00029 if (m_simple == 1) VERBOSE(1, " Using simple phrase pairs.");
00030 if (m_sourceContext == 1) VERBOSE(1, " Using source context.");
00031 if (m_domainTrigger == 1) VERBOSE(1, " Using domain triggers.");
00032
00033
00034 if (m_ignorePunctuation) {
00035 VERBOSE(1, " Ignoring punctuation for triggers.");
00036 char punctuation[] = "\"'!?¿·()#_,.:;•&@‑/\\0123456789~=";
00037 for (size_t i=0; i < sizeof(punctuation)-1; ++i) {
00038 m_punctuationHash[punctuation[i]] = 1;
00039 }
00040 }
00041
00042 VERBOSE(1, " Done." << std::endl);
00043 }
00044
00045 void PhrasePairFeature::SetParameter(const std::string& key, const std::string& value)
00046 {
00047 if (key == "input-factor") {
00048 m_sourceFactorId = Scan<FactorType>(value);
00049 } else if (key == "output-factor") {
00050 m_targetFactorId = Scan<FactorType>(value);
00051 } else if (key == "unrestricted") {
00052 m_unrestricted = Scan<bool>(value);
00053 } else if (key == "simple") {
00054 m_simple = Scan<bool>(value);
00055 } else if (key == "source-context") {
00056 m_sourceContext = Scan<bool>(value);
00057 } else if (key == "domain-trigger") {
00058 m_domainTrigger = Scan<bool>(value);
00059 } else if (key == "ignore-punctuation") {
00060 m_ignorePunctuation = Scan<bool>(value);
00061 } else if (key == "path") {
00062 m_filePathSource = value;
00063 } else {
00064 StatelessFeatureFunction::SetParameter(key, value);
00065 }
00066 }
00067
00068 void PhrasePairFeature::Load(AllOptions::ptr const& opts)
00069 {
00070 m_options = opts;
00071 if (m_domainTrigger) {
00072
00073 ifstream inFileSource(m_filePathSource.c_str());
00074 UTIL_THROW_IF2(!inFileSource, "could not open file " << m_filePathSource);
00075
00076 std::string line;
00077 while (getline(inFileSource, line)) {
00078 std::set<std::string> terms;
00079 vector<string> termVector;
00080 boost::split(termVector, line, boost::is_any_of("\t "));
00081 for (size_t i=0; i < termVector.size(); ++i)
00082 terms.insert(termVector[i]);
00083
00084
00085 m_vocabDomain.push_back(terms);
00086 }
00087
00088 inFileSource.close();
00089 } else if (!m_unrestricted) {
00090
00091 ifstream inFileSource(m_filePathSource.c_str());
00092 UTIL_THROW_IF2(!inFileSource, "could not open file " << m_filePathSource);
00093
00094 std::string line;
00095 while (getline(inFileSource, line)) {
00096 m_vocabSource.insert(line);
00097 }
00098
00099 inFileSource.close();
00100
00101
00102
00103
00104
00105
00106
00107
00108
00109
00110
00111
00112
00113
00114 }
00115 }
00116
00117 void PhrasePairFeature::EvaluateWithSourceContext(const InputType &input
00118 , const InputPath &inputPath
00119 , const TargetPhrase &targetPhrase
00120 , const StackVec *stackVec
00121 , ScoreComponentCollection &scoreBreakdown
00122 , ScoreComponentCollection *estimatedScores) const
00123 {
00124 const Phrase& source = inputPath.GetPhrase();
00125 if (m_domainTrigger) {
00126 const Sentence& isnt = static_cast<const Sentence&>(input);
00127 const bool use_topicid = isnt.GetUseTopicId();
00128 const bool use_topicid_prob = isnt.GetUseTopicIdAndProb();
00129
00130
00131 util::StringStream pair;
00132
00133 pair << ReplaceTilde( source.GetWord(0).GetFactor(m_sourceFactorId)->GetString() );
00134 for (size_t i = 1; i < source.GetSize(); ++i) {
00135 const Factor* sourceFactor = source.GetWord(i).GetFactor(m_sourceFactorId);
00136 pair << "~";
00137 pair << ReplaceTilde( sourceFactor->GetString() );
00138 }
00139 pair << "~~";
00140 pair << ReplaceTilde( targetPhrase.GetWord(0).GetFactor(m_targetFactorId)->GetString() );
00141 for (size_t i = 1; i < targetPhrase.GetSize(); ++i) {
00142 const Factor* targetFactor = targetPhrase.GetWord(i).GetFactor(m_targetFactorId);
00143 pair << "~";
00144 pair << ReplaceTilde( targetFactor->GetString() );
00145 }
00146
00147 if (use_topicid || use_topicid_prob) {
00148 if(use_topicid) {
00149
00150 const long topicid = isnt.GetTopicId();
00151 util::StringStream feature;
00152
00153 feature << m_description << "_";
00154 if (topicid == -1)
00155 feature << "unk";
00156 else
00157 feature << topicid;
00158
00159 feature << "_";
00160 feature << pair.str();
00161 scoreBreakdown.SparsePlusEquals(feature.str(), 1);
00162 } else {
00163
00164 const vector<string> &topicid_prob = *(isnt.GetTopicIdAndProb());
00165 if (atol(topicid_prob[0].c_str()) == -1) {
00166 util::StringStream feature;
00167 feature << m_description << "_unk_";
00168 feature << pair.str();
00169 scoreBreakdown.SparsePlusEquals(feature.str(), 1);
00170 } else {
00171 for (size_t i=0; i+1 < topicid_prob.size(); i+=2) {
00172 util::StringStream feature;
00173 feature << m_description << "_";
00174 feature << topicid_prob[i];
00175 feature << "_";
00176 feature << pair.str();
00177 scoreBreakdown.SparsePlusEquals(feature.str(), atof((topicid_prob[i+1]).c_str()));
00178 }
00179 }
00180 }
00181 } else {
00182
00183 const long docid = isnt.GetDocumentId();
00184 for (set<string>::const_iterator p = m_vocabDomain[docid].begin(); p != m_vocabDomain[docid].end(); ++p) {
00185 string sourceTrigger = *p;
00186 util::StringStream namestr;
00187 namestr << m_description << "_";
00188 namestr << sourceTrigger;
00189 namestr << "_";
00190 namestr << pair.str();
00191 scoreBreakdown.SparsePlusEquals(namestr.str(),1);
00192 }
00193 }
00194 }
00195 if (m_sourceContext) {
00196 const Sentence& isnt = static_cast<const Sentence&>(input);
00197
00198
00199 for(size_t contextIndex = 0; contextIndex < isnt.GetSize(); contextIndex++ ) {
00200 StringPiece sourceTrigger = isnt.GetWord(contextIndex).GetFactor(m_sourceFactorId)->GetString();
00201 if (m_ignorePunctuation) {
00202
00203 char firstChar = sourceTrigger[0];
00204 CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar );
00205 if(charIterator != m_punctuationHash.end())
00206 continue;
00207 }
00208
00209 bool sourceTriggerExists = false;
00210 if (!m_unrestricted)
00211 sourceTriggerExists = FindStringPiece(m_vocabSource, sourceTrigger ) != m_vocabSource.end();
00212
00213 if (m_unrestricted || sourceTriggerExists) {
00214 util::StringStream namestr;
00215 namestr << m_description << "_";
00216 namestr << sourceTrigger;
00217 namestr << "~";
00218 namestr << ReplaceTilde( source.GetWord(0).GetFactor(m_sourceFactorId)->GetString() );
00219 for (size_t i = 1; i < source.GetSize(); ++i) {
00220 const Factor* sourceFactor = source.GetWord(i).GetFactor(m_sourceFactorId);
00221 namestr << "~";
00222 namestr << ReplaceTilde( sourceFactor->GetString() );
00223 }
00224 namestr << "~~";
00225 namestr << ReplaceTilde( targetPhrase.GetWord(0).GetFactor(m_targetFactorId)->GetString() );
00226 for (size_t i = 1; i < targetPhrase.GetSize(); ++i) {
00227 const Factor* targetFactor = targetPhrase.GetWord(i).GetFactor(m_targetFactorId);
00228 namestr << "~";
00229 namestr << ReplaceTilde( targetFactor->GetString() );
00230 }
00231
00232 scoreBreakdown.SparsePlusEquals(namestr.str(),1);
00233 }
00234 }
00235 }
00236 }
00237
00238 void PhrasePairFeature::EvaluateInIsolation(const Phrase &source
00239 , const TargetPhrase &targetPhrase
00240 , ScoreComponentCollection &scoreBreakdown
00241 , ScoreComponentCollection &estimatedScores) const
00242 {
00243 if (m_simple) {
00244 util::StringStream namestr;
00245 namestr << m_description << "_";
00246 namestr << ReplaceTilde( source.GetWord(0).GetFactor(m_sourceFactorId)->GetString() );
00247 for (size_t i = 1; i < source.GetSize(); ++i) {
00248 const Factor* sourceFactor = source.GetWord(i).GetFactor(m_sourceFactorId);
00249 namestr << "~";
00250 namestr << ReplaceTilde( sourceFactor->GetString() );
00251 }
00252 namestr << "~~";
00253 namestr << ReplaceTilde( targetPhrase.GetWord(0).GetFactor(m_targetFactorId)->GetString() );
00254 for (size_t i = 1; i < targetPhrase.GetSize(); ++i) {
00255 const Factor* targetFactor = targetPhrase.GetWord(i).GetFactor(m_targetFactorId);
00256 namestr << "~";
00257 namestr << ReplaceTilde( targetFactor->GetString() );
00258 }
00259 scoreBreakdown.SparsePlusEquals(namestr.str(),1);
00260 }
00261 }
00262
00263 bool PhrasePairFeature::IsUseable(const FactorMask &mask) const
00264 {
00265 bool ret = mask[m_targetFactorId];
00266 return ret;
00267 }
00268
00269 }