00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022 #include "util/check.hh"
00023 #include <algorithm>
00024 #include <boost/lexical_cast.hpp>
00025 #include "util/tokenize_piece.hh"
00026
00027 #include "TargetPhrase.h"
00028 #include "PhraseDictionaryMemory.h"
00029 #include "GenerationDictionary.h"
00030 #include "LM/Base.h"
00031 #include "StaticData.h"
00032 #include "ScoreIndexManager.h"
00033 #include "LMList.h"
00034 #include "ScoreComponentCollection.h"
00035 #include "Util.h"
00036 #include "DummyScoreProducers.h"
00037 #include "AlignmentInfoCollection.h"
00038
00039 using namespace std;
00040
00041 namespace Moses
00042 {
00043 TargetPhrase::TargetPhrase( std::string out_string)
00044 :Phrase(0),m_transScore(0.0), m_fullScore(0.0), m_sourcePhrase(0)
00045 , m_alignmentInfo(&AlignmentInfoCollection::Instance().GetEmptyAlignmentInfo())
00046 {
00047
00048
00049 const StaticData &staticData = StaticData::Instance();
00050 CreateFromString(staticData.GetInputFactorOrder(), out_string, staticData.GetFactorDelimiter());
00051 }
00052
00053
00054 TargetPhrase::TargetPhrase()
00055 :Phrase(ARRAY_SIZE_INCR)
00056 , m_transScore(0.0)
00057 , m_fullScore(0.0)
00058 , m_sourcePhrase(0)
00059 , m_alignmentInfo(&AlignmentInfoCollection::Instance().GetEmptyAlignmentInfo())
00060 {
00061 }
00062
00063 TargetPhrase::TargetPhrase(const Phrase &phrase)
00064 : Phrase(phrase)
00065 , m_transScore(0.0)
00066 , m_fullScore(0.0)
00067 , m_sourcePhrase(0)
00068 , m_alignmentInfo(&AlignmentInfoCollection::Instance().GetEmptyAlignmentInfo())
00069 {
00070 }
00071
00072 TargetPhrase::~TargetPhrase()
00073 {
00074 }
00075
00076 void TargetPhrase::SetScore(const TranslationSystem* system)
00077 {
00078
00079 m_transScore = 0;
00080 m_fullScore = - system->GetWeightWordPenalty();
00081 }
00082
#ifdef HAVE_PROTOBUF
/**
 * Append this phrase's words to a protobuf rule message.
 *
 * A literal "[X,1]" entry is added first, followed by the string of
 * factor 0 of every word in order.
 *
 * @param pb  rule message to append to; dereferenced without a null check.
 */
void TargetPhrase::WriteToRulePB(hgmert::Rule* pb) const
{
  pb->add_trg_words("[X,1]");

  const size_t numWords = GetSize();
  for (size_t idx = 0; idx != numWords; ++idx) {
    // Only factor 0 of each word is exported.
    pb->add_trg_words(GetWord(idx)[0]->GetString());
  }
}
#endif
00091
00092
00093
00094 void TargetPhrase::SetScore(float score)
00095 {
00096
00097
00098
00099 const TranslationSystem& system = StaticData::Instance().GetTranslationSystem(TranslationSystem::DEFAULT);
00100 const ScoreProducer* prod = system.GetPhraseDictionaries()[0];
00101
00102
00103 unsigned int id = prod->GetScoreBookkeepingID();
00104
00105 const vector<float> &allWeights = StaticData::Instance().GetAllWeights();
00106
00107 size_t beginIndex = StaticData::Instance().GetScoreIndexManager().GetBeginIndex(id);
00108 size_t endIndex = StaticData::Instance().GetScoreIndexManager().GetEndIndex(id);
00109
00110 vector<float> weights;
00111
00112 std::copy(allWeights.begin() +beginIndex, allWeights.begin() + endIndex,std::back_inserter(weights));
00113
00114
00115 size_t numScores = prod->GetNumScoreComponents();
00116
00117
00118 vector <float> scoreVector(numScores,score/numScores);
00119
00120
00121 SetScore(prod,scoreVector,weights,system.GetWeightWordPenalty(),system.GetLanguageModels());
00122 }
00123
00128 void TargetPhrase::SetScore(const TranslationSystem* system, const Scores &scoreVector)
00129 {
00130
00131
00132 const ScoreProducer* prod = system->GetPhraseDictionaries()[0];
00133
00134
00135 unsigned int id = prod->GetScoreBookkeepingID();
00136 const vector<float> &allWeights = StaticData::Instance().GetAllWeights();
00137 size_t beginIndex = StaticData::Instance().GetScoreIndexManager().GetBeginIndex(id);
00138 size_t endIndex = StaticData::Instance().GetScoreIndexManager().GetEndIndex(id);
00139 vector<float> weights;
00140 std::copy(allWeights.begin() +beginIndex, allWeights.begin() + endIndex,std::back_inserter(weights));
00141
00142
00143 CHECK(scoreVector.size() <= prod->GetNumScoreComponents());
00144 Scores sizedScoreVector = scoreVector;
00145 sizedScoreVector.resize(prod->GetNumScoreComponents(),0.0f);
00146
00147 SetScore(prod,sizedScoreVector,weights,system->GetWeightWordPenalty(),system->GetLanguageModels());
00148 }
00149
00150 void TargetPhrase::SetScore(const ScoreProducer* translationScoreProducer,
00151 const Scores &scoreVector,
00152 const vector<float> &weightT,
00153 float weightWP, const LMList &languageModels)
00154 {
00155 CHECK(weightT.size() == scoreVector.size());
00156
00157
00158 m_transScore = std::inner_product(scoreVector.begin(), scoreVector.end(), weightT.begin(), 0.0f);
00159 m_scoreBreakdown.PlusEquals(translationScoreProducer, scoreVector);
00160
00161
00162 float totalNgramScore = 0;
00163 float totalFullScore = 0;
00164 float totalOOVScore = 0;
00165
00166 LMList::const_iterator lmIter;
00167 for (lmIter = languageModels.begin(); lmIter != languageModels.end(); ++lmIter) {
00168 const LanguageModel &lm = **lmIter;
00169
00170 if (lm.Useable(*this)) {
00171
00172 const float weightLM = lm.GetWeight();
00173 const float oovWeightLM = lm.GetOOVWeight();
00174 float fullScore, nGramScore;
00175 size_t oovCount;
00176
00177 lm.CalcScore(*this, fullScore, nGramScore, oovCount);
00178
00179 if (StaticData::Instance().GetLMEnableOOVFeature()) {
00180 vector<float> scores(2);
00181 scores[0] = nGramScore;
00182 scores[1] = oovCount;
00183 m_scoreBreakdown.Assign(&lm, scores);
00184 totalOOVScore += oovCount * oovWeightLM;
00185 } else {
00186 m_scoreBreakdown.Assign(&lm, nGramScore);
00187 }
00188
00189
00190
00191 totalNgramScore += nGramScore * weightLM;
00192 totalFullScore += fullScore * weightLM;
00193
00194 }
00195 }
00196
00197 m_fullScore = m_transScore + totalFullScore + totalOOVScore
00198 - (this->GetSize() * weightWP);
00199 }
00200
00201 void TargetPhrase::SetScoreChart(const ScoreProducer* translationScoreProducer,
00202 const Scores &scoreVector
00203 ,const vector<float> &weightT
00204 ,const LMList &languageModels
00205 ,const WordPenaltyProducer* wpProducer)
00206 {
00207
00208 CHECK(weightT.size() == scoreVector.size());
00209
00210
00211 m_transScore = std::inner_product(scoreVector.begin(), scoreVector.end(), weightT.begin(), 0.0f);
00212 m_scoreBreakdown.PlusEquals(translationScoreProducer, scoreVector);
00213
00214
00215 float totalNgramScore = 0;
00216 float totalFullScore = 0;
00217 float totalOOVScore = 0;
00218
00219 LMList::const_iterator lmIter;
00220 for (lmIter = languageModels.begin(); lmIter != languageModels.end(); ++lmIter) {
00221 const LanguageModel &lm = **lmIter;
00222
00223 if (lm.Useable(*this)) {
00224
00225 const float weightLM = lm.GetWeight();
00226 const float oovWeightLM = lm.GetOOVWeight();
00227 float fullScore, nGramScore;
00228 size_t oovCount;
00229
00230 lm.CalcScore(*this, fullScore, nGramScore, oovCount);
00231 fullScore = UntransformLMScore(fullScore);
00232 nGramScore = UntransformLMScore(nGramScore);
00233
00234 if (StaticData::Instance().GetLMEnableOOVFeature()) {
00235 vector<float> scores(2);
00236 scores[0] = nGramScore;
00237 scores[1] = oovCount;
00238 m_scoreBreakdown.Assign(&lm, scores);
00239 totalOOVScore += oovCount * oovWeightLM;
00240 } else {
00241 m_scoreBreakdown.Assign(&lm, nGramScore);
00242 }
00243
00244
00245 totalNgramScore += nGramScore * weightLM;
00246 totalFullScore += fullScore * weightLM;
00247 }
00248 }
00249
00250
00251 size_t wordCount = GetNumTerminals();
00252 m_scoreBreakdown.Assign(wpProducer, - (float) wordCount * 0.434294482);
00253
00254 m_fullScore = m_scoreBreakdown.GetWeightedScore() - totalNgramScore + totalFullScore + totalOOVScore;
00255 }
00256
00257 void TargetPhrase::SetScore(const ScoreProducer* producer, const Scores &scoreVector)
00258 {
00259
00260 m_scoreBreakdown.Assign(producer, scoreVector);
00261 m_transScore = 0;
00262 m_fullScore = m_scoreBreakdown.GetWeightedScore();
00263 }
00264
00265
00266 void TargetPhrase::SetWeights(const ScoreProducer* translationScoreProducer, const vector<float> &weightT)
00267 {
00268
00269 CHECK(StaticData::Instance().GetInputType()==SentenceInput);
00270
00271
00272
00273
00274
00275
00276 m_transScore = m_scoreBreakdown.PartialInnerProduct(translationScoreProducer, weightT);
00277 }
00278
00279 void TargetPhrase::ResetScore()
00280 {
00281 m_fullScore = 0;
00282 m_scoreBreakdown.ZeroAll();
00283 }
00284
00285 TargetPhrase *TargetPhrase::MergeNext(const TargetPhrase &inputPhrase) const
00286 {
00287 if (! IsCompatible(inputPhrase)) {
00288 return NULL;
00289 }
00290
00291
00292 TargetPhrase *clone = new TargetPhrase(*this);
00293 clone->m_sourcePhrase = m_sourcePhrase;
00294 int currWord = 0;
00295 const size_t len = GetSize();
00296 for (size_t currPos = 0 ; currPos < len ; currPos++) {
00297 const Word &inputWord = inputPhrase.GetWord(currPos);
00298 Word &cloneWord = clone->GetWord(currPos);
00299 cloneWord.Merge(inputWord);
00300
00301 currWord++;
00302 }
00303
00304 return clone;
00305 }
00306
namespace {
/**
 * Abort the process if an alignment-parsing condition does not hold.
 *
 * @param value  condition that must be true. When false, an error message
 *               is written to std::cerr and abort() is called; when true
 *               the function returns without side effects.
 */
void MosesShouldUseExceptions(bool value) {
  if (value) {
    return;
  }
  std::cerr << "Could not parse alignment info" << std::endl;
  abort();
}
}
00315
00316 void TargetPhrase::SetAlignmentInfo(const StringPiece &alignString)
00317 {
00318 set<pair<size_t,size_t> > alignmentInfo;
00319 for (util::TokenIter<util::AnyCharacter, true> token(alignString, util::AnyCharacter(" \t")); token; ++token) {
00320 util::TokenIter<util::AnyCharacter, false> dash(*token, util::AnyCharacter("-"));
00321 MosesShouldUseExceptions(dash);
00322 size_t sourcePos = boost::lexical_cast<size_t>(*dash++);
00323 MosesShouldUseExceptions(dash);
00324 size_t targetPos = boost::lexical_cast<size_t>(*dash++);
00325 MosesShouldUseExceptions(!dash);
00326
00327 alignmentInfo.insert(pair<size_t,size_t>(sourcePos, targetPos));
00328 }
00329
00330 SetAlignmentInfo(alignmentInfo);
00331 }
00332
00333 void TargetPhrase::SetAlignmentInfo(const std::set<std::pair<size_t,size_t> > &alignmentInfo)
00334 {
00335 m_alignmentInfo = AlignmentInfoCollection::Instance().Add(alignmentInfo);
00336 }
00337
00338
00339 TO_STRING_BODY(TargetPhrase);
00340
00341 std::ostream& operator<<(std::ostream& os, const TargetPhrase& tp)
00342 {
00343 os << static_cast<const Phrase&>(tp) << ":" << tp.GetAlignmentInfo();
00344 os << ": pC=" << tp.m_transScore << ", c=" << tp.m_fullScore;
00345
00346 return os;
00347 }
00348
00349 }
00350