00001 #include <sstream>
00002 #include "CorrectionPattern.h"
00003 #include "moses/Phrase.h"
00004 #include "moses/TargetPhrase.h"
00005 #include "moses/InputPath.h"
00006 #include "moses/Hypothesis.h"
00007 #include "moses/ChartHypothesis.h"
00008 #include "moses/ScoreComponentCollection.h"
00009 #include "moses/TranslationOption.h"
00010 #include "util/string_piece_hash.hh"
00011 #include "util/exception.hh"
00012
00013 #include <functional>
00014 #include <algorithm>
00015
00016 #include <boost/foreach.hpp>
00017 #include <boost/algorithm/string.hpp>
00018
00019 #include "Diffs.h"
00020
00021 namespace Moses
00022 {
00023
00024 using namespace std;
00025
00026 std::string MakePair(const std::string &s1, const std::string &s2, bool general)
00027 {
00028 std::vector<std::string> sourceList;
00029 std::vector<std::string> targetList;
00030
00031 if(general) {
00032 Diffs diffs = CreateDiff(s1, s2);
00033
00034 size_t i = 0, j = 0;
00035 char lastType = 'm';
00036
00037 std::string source, target;
00038 std::string match;
00039
00040 int count = 1;
00041
00042 BOOST_FOREACH(Diff type, diffs) {
00043 if(type == 'm') {
00044 if(lastType != 'm') {
00045 sourceList.push_back(source);
00046 targetList.push_back(target);
00047 }
00048 source.clear();
00049 target.clear();
00050
00051 if(s1[i] == '+') {
00052 if(match.size() >= 3) {
00053 sourceList.push_back("(\\w{3,})·");
00054 std::string temp = "1";
00055 sprintf((char*)temp.c_str(), "%d", count);
00056 targetList.push_back("\\" + temp + "·");
00057 count++;
00058 } else {
00059 sourceList.push_back(match + "·");
00060 targetList.push_back(match + "·");
00061 }
00062 match.clear();
00063 } else
00064 match.push_back(s1[i]);
00065
00066 i++;
00067 j++;
00068 } else if(type == 'd') {
00069 if(s1[i] == '+')
00070 source += "·";
00071 else
00072 source.push_back(s1[i]);
00073 i++;
00074 } else if(type == 'i') {
00075 if(s2[j] == '+')
00076 target += "·";
00077 else
00078 target.push_back(s2[j]);
00079 j++;
00080 }
00081 if(type != 'm' && !match.empty()) {
00082 if(match.size() >= 3) {
00083 sourceList.push_back("(\\w{3,})");
00084 std::string temp = "1";
00085 sprintf((char*)temp.c_str(), "%d", count);
00086 targetList.push_back("\\" + temp);
00087 count++;
00088 } else {
00089 sourceList.push_back(match);
00090 targetList.push_back(match);
00091 }
00092
00093 match.clear();
00094 }
00095
00096 lastType = type;
00097 }
00098 if(lastType != 'm') {
00099 sourceList.push_back(source);
00100 targetList.push_back(target);
00101 }
00102
00103 if(!match.empty()) {
00104 if(match.size() >= 3) {
00105 sourceList.push_back("(\\w{3,})");
00106 std::string temp = "1";
00107 sprintf((char*)temp.c_str(), "%d", count);
00108 targetList.push_back("\\"+ temp);
00109 count++;
00110 } else {
00111 sourceList.push_back(match);
00112 targetList.push_back(match);
00113 }
00114 }
00115 match.clear();
00116 } else {
00117 std::string cs1 = s1;
00118 std::string cs2 = s2;
00119 boost::replace_all(cs1, "+", "·");
00120 boost::replace_all(cs2, "+", "·");
00121
00122 sourceList.push_back(cs1);
00123 targetList.push_back(cs2);
00124 }
00125
00126 std::stringstream out;
00127 out << "sub(«";
00128 out << boost::join(sourceList, "");
00129 out << "»,«";
00130 out << boost::join(targetList, "");
00131 out << "»)";
00132
00133 return out.str();
00134 }
00135
00136 std::string CorrectionPattern::CreateSinglePattern(const Tokens &s1, const Tokens &s2) const
00137 {
00138 std::stringstream out;
00139 if(s1.empty()) {
00140 out << "ins(«" << boost::join(s2, "·") << "»)";
00141 return out.str();
00142 } else if(s2.empty()) {
00143 out << "del(«" << boost::join(s1, "·") << "»)";
00144 return out.str();
00145 } else {
00146 Tokens::value_type v1 = boost::join(s1, "+");
00147 Tokens::value_type v2 = boost::join(s2, "+");
00148 out << MakePair(v1, v2, m_general);
00149 return out.str();
00150 }
00151 }
00152
00153 std::vector<std::string> GetContext(size_t pos,
00154 size_t len,
00155 size_t window,
00156 const InputType &input,
00157 const InputPath &inputPath,
00158 const std::vector<FactorType>& factorTypes,
00159 bool isRight)
00160 {
00161
00162 const Sentence& sentence = static_cast<const Sentence&>(input);
00163 const Range& range = inputPath.GetWordsRange();
00164
00165 int leftPos = range.GetStartPos() + pos - len - 1;
00166 int rightPos = range.GetStartPos() + pos;
00167
00168 std::vector<std::string> contexts;
00169
00170 for(int length = 1; length <= (int)window; ++length) {
00171 std::vector<std::string> current;
00172 if(!isRight) {
00173 for(int i = 0; i < length; i++) {
00174 if(leftPos - i >= 0) {
00175 current.push_back(sentence.GetWord(leftPos - i).GetString(factorTypes, false));
00176 } else {
00177 current.push_back("<s>");
00178 }
00179 }
00180
00181 if(current.back() == "<s>" && current.size() >= 2 && current[current.size()-2] == "<s>")
00182 continue;
00183
00184 std::reverse(current.begin(), current.end());
00185 contexts.push_back("left(«" + boost::join(current, "·") + "»)_");
00186 }
00187 if(isRight) {
00188 for(int i = 0; i < length; i++) {
00189 if(rightPos + i < (int)sentence.GetSize()) {
00190 current.push_back(sentence.GetWord(rightPos + i).GetString(factorTypes, false));
00191 } else {
00192 current.push_back("</s>");
00193 }
00194 }
00195
00196 if(current.back() == "</s>" && current.size() >= 2 && current[current.size()-2] == "</s>")
00197 continue;
00198
00199 contexts.push_back("_right(«" + boost::join(current, "·") + "»)");
00200 }
00201 }
00202 return contexts;
00203 }
00204
00205 std::vector<std::string>
00206 CorrectionPattern::CreatePattern(const Tokens &s1,
00207 const Tokens &s2,
00208 const InputType &input,
00209 const InputPath &inputPath) const
00210 {
00211
00212 Diffs diffs = CreateDiff(s1, s2);
00213 size_t i = 0, j = 0;
00214 char lastType = 'm';
00215 std::vector<std::string> patternList;
00216 Tokens source, target;
00217 BOOST_FOREACH(Diff type, diffs) {
00218 if(type == 'm') {
00219 if(lastType != 'm') {
00220 std::string pattern = CreateSinglePattern(source, target);
00221 patternList.push_back(pattern);
00222
00223 if(m_context > 0) {
00224 std::vector<std::string> leftContexts = GetContext(i, source.size(), m_context, input, inputPath, m_contextFactors, false);
00225 std::vector<std::string> rightContexts = GetContext(i, source.size(), m_context, input, inputPath, m_contextFactors, true);
00226
00227 BOOST_FOREACH(std::string left, leftContexts)
00228 patternList.push_back(left + pattern);
00229
00230 BOOST_FOREACH(std::string right, rightContexts)
00231 patternList.push_back(pattern + right);
00232
00233 BOOST_FOREACH(std::string left, leftContexts)
00234 BOOST_FOREACH(std::string right, rightContexts)
00235 patternList.push_back(left + pattern + right);
00236 }
00237 }
00238 source.clear();
00239 target.clear();
00240 if(s1[i] != s2[j]) {
00241 source.push_back(s1[i]);
00242 target.push_back(s2[j]);
00243 }
00244 i++;
00245 j++;
00246 } else if(type == 'd') {
00247 source.push_back(s1[i]);
00248 i++;
00249 } else if(type == 'i') {
00250 target.push_back(s2[j]);
00251 j++;
00252 }
00253 lastType = type;
00254 }
00255 if(lastType != 'm') {
00256 std::string pattern = CreateSinglePattern(source, target);
00257 patternList.push_back(pattern);
00258
00259 if(m_context > 0) {
00260 std::vector<std::string> leftContexts = GetContext(i, source.size(), m_context, input, inputPath, m_contextFactors, false);
00261 std::vector<std::string> rightContexts = GetContext(i, source.size(), m_context, input, inputPath, m_contextFactors, true);
00262
00263 BOOST_FOREACH(std::string left, leftContexts)
00264 patternList.push_back(left + pattern);
00265
00266 BOOST_FOREACH(std::string right, rightContexts)
00267 patternList.push_back(pattern + right);
00268
00269 BOOST_FOREACH(std::string left, leftContexts)
00270 BOOST_FOREACH(std::string right, rightContexts)
00271 patternList.push_back(left + pattern + right);
00272 }
00273 }
00274
00275 return patternList;
00276 }
00277
00278 CorrectionPattern::CorrectionPattern(const std::string &line)
00279 : StatelessFeatureFunction(0, line), m_factors(1, 0), m_general(false),
00280 m_context(0), m_contextFactors(1, 0)
00281 {
00282 std::cerr << "Initializing correction pattern feature.." << std::endl;
00283 ReadParameters();
00284 }
00285
00286 void CorrectionPattern::SetParameter(const std::string& key, const std::string& value)
00287 {
00288 if (key == "factor") {
00289 m_factors = std::vector<FactorType>(1, Scan<FactorType>(value));
00290 } else if (key == "context-factor") {
00291 m_contextFactors = std::vector<FactorType>(1, Scan<FactorType>(value));
00292 } else if (key == "general") {
00293 m_general = Scan<bool>(value);
00294 } else if (key == "context") {
00295 m_context = Scan<size_t>(value);
00296 } else {
00297 StatelessFeatureFunction::SetParameter(key, value);
00298 }
00299 }
00300
00301 void CorrectionPattern::EvaluateWithSourceContext(const InputType &input
00302 , const InputPath &inputPath
00303 , const TargetPhrase &targetPhrase
00304 , const StackVec *stackVec
00305 , ScoreComponentCollection &scoreBreakdown
00306 , ScoreComponentCollection *estimatedFutureScore) const
00307 {
00308 ComputeFeatures(input, inputPath, targetPhrase, &scoreBreakdown);
00309 }
00310
00311 void CorrectionPattern::ComputeFeatures(
00312 const InputType &input,
00313 const InputPath &inputPath,
00314 const TargetPhrase& target,
00315 ScoreComponentCollection* accumulator) const
00316 {
00317 const Phrase &source = inputPath.GetPhrase();
00318
00319 std::vector<std::string> sourceTokens;
00320 for(size_t i = 0; i < source.GetSize(); ++i)
00321 sourceTokens.push_back(source.GetWord(i).GetString(m_factors, false));
00322
00323 std::vector<std::string> targetTokens;
00324 for(size_t i = 0; i < target.GetSize(); ++i)
00325 targetTokens.push_back(target.GetWord(i).GetString(m_factors, false));
00326
00327 std::vector<std::string> patternList = CreatePattern(sourceTokens, targetTokens, input, inputPath);
00328 for(size_t i = 0; i < patternList.size(); ++i)
00329 accumulator->PlusEquals(this, patternList[i], 1);
00330
00331
00332
00333
00334
00335
00336
00337
00338
00339
00340
00341
00342 }
00343
00344 bool CorrectionPattern::IsUseable(const FactorMask &mask) const
00345 {
00346 bool ret = true;
00347 for(size_t i = 0; i < m_factors.size(); ++i)
00348 ret = ret && mask[m_factors[i]];
00349 for(size_t i = 0; i < m_contextFactors.size(); ++i)
00350 ret = ret && mask[m_contextFactors[i]];
00351 return ret;
00352 }
00353
00354 }