00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022 #include <deque>
00023
00024 #include "PhraseDecoder.h"
00025 #include "moses/StaticData.h"
00026
00027 using namespace std;
00028
00029 namespace Moses
00030 {
00031
00032 PhraseDecoder::PhraseDecoder(
00033 PhraseDictionaryCompact &phraseDictionary,
00034 const std::vector<FactorType>* input,
00035 const std::vector<FactorType>* output,
00036 size_t numScoreComponent
00037
00038 )
00039 : m_coding(None), m_numScoreComponent(numScoreComponent),
00040 m_containsAlignmentInfo(true), m_maxRank(0),
00041 m_symbolTree(0), m_multipleScoreTrees(false),
00042 m_scoreTrees(1), m_alignTree(0),
00043 m_phraseDictionary(phraseDictionary), m_input(input), m_output(output),
00044
00045 m_separator(" ||| ")
00046 { }
00047
00048 PhraseDecoder::~PhraseDecoder()
00049 {
00050 if(m_symbolTree)
00051 delete m_symbolTree;
00052
00053 for(size_t i = 0; i < m_scoreTrees.size(); i++)
00054 if(m_scoreTrees[i])
00055 delete m_scoreTrees[i];
00056
00057 if(m_alignTree)
00058 delete m_alignTree;
00059 }
00060
00061 inline unsigned PhraseDecoder::GetSourceSymbolId(std::string& symbol)
00062 {
00063 boost::unordered_map<std::string, unsigned>::iterator it
00064 = m_sourceSymbolsMap.find(symbol);
00065 if(it != m_sourceSymbolsMap.end())
00066 return it->second;
00067
00068 size_t idx = m_sourceSymbols.find(symbol);
00069 m_sourceSymbolsMap[symbol] = idx;
00070 return idx;
00071 }
00072
00073 inline std::string PhraseDecoder::GetTargetSymbol(unsigned idx) const
00074 {
00075 if(idx < m_targetSymbols.size())
00076 return m_targetSymbols[idx];
00077 return std::string("##ERROR##");
00078 }
00079
00080 inline size_t PhraseDecoder::GetREncType(unsigned encodedSymbol)
00081 {
00082 return (encodedSymbol >> 30) + 1;
00083 }
00084
00085 inline size_t PhraseDecoder::GetPREncType(unsigned encodedSymbol)
00086 {
00087 return (encodedSymbol >> 31) + 1;
00088 }
00089
00090 inline unsigned PhraseDecoder::GetTranslation(unsigned srcIdx, size_t rank)
00091 {
00092 size_t srcTrgIdx = m_lexicalTableIndex[srcIdx];
00093 return m_lexicalTable[srcTrgIdx + rank].second;
00094 }
00095
00096 size_t PhraseDecoder::GetMaxSourcePhraseLength()
00097 {
00098 return m_maxPhraseLength;
00099 }
00100
00101 inline unsigned PhraseDecoder::DecodeREncSymbol1(unsigned encodedSymbol)
00102 {
00103 return encodedSymbol &= ~(3 << 30);
00104 }
00105
00106 inline unsigned PhraseDecoder::DecodeREncSymbol2Rank(unsigned encodedSymbol)
00107 {
00108 return encodedSymbol &= ~(255 << 24);
00109 }
00110
00111 inline unsigned PhraseDecoder::DecodeREncSymbol2Position(unsigned encodedSymbol)
00112 {
00113 encodedSymbol &= ~(3 << 30);
00114 encodedSymbol >>= 24;
00115 return encodedSymbol;
00116 }
00117
00118 inline unsigned PhraseDecoder::DecodeREncSymbol3(unsigned encodedSymbol)
00119 {
00120 return encodedSymbol &= ~(3 << 30);
00121 }
00122
00123 inline unsigned PhraseDecoder::DecodePREncSymbol1(unsigned encodedSymbol)
00124 {
00125 return encodedSymbol &= ~(1 << 31);
00126 }
00127
00128 inline int PhraseDecoder::DecodePREncSymbol2Left(unsigned encodedSymbol)
00129 {
00130 return ((encodedSymbol >> 25) & 63) - 32;
00131 }
00132
00133 inline int PhraseDecoder::DecodePREncSymbol2Right(unsigned encodedSymbol)
00134 {
00135 return ((encodedSymbol >> 19) & 63) - 32;
00136 }
00137
00138 inline unsigned PhraseDecoder::DecodePREncSymbol2Rank(unsigned encodedSymbol)
00139 {
00140 return (encodedSymbol & 524287);
00141 }
00142
00143 size_t PhraseDecoder::Load(std::FILE* in)
00144 {
00145 size_t start = std::ftell(in);
00146 size_t read = 0;
00147
00148 read += std::fread(&m_coding, sizeof(m_coding), 1, in);
00149 read += std::fread(&m_numScoreComponent, sizeof(m_numScoreComponent), 1, in);
00150 read += std::fread(&m_containsAlignmentInfo, sizeof(m_containsAlignmentInfo), 1, in);
00151 read += std::fread(&m_maxRank, sizeof(m_maxRank), 1, in);
00152 read += std::fread(&m_maxPhraseLength, sizeof(m_maxPhraseLength), 1, in);
00153
00154 if(m_coding == REnc) {
00155 m_sourceSymbols.load(in);
00156
00157 size_t size;
00158 read += std::fread(&size, sizeof(size_t), 1, in);
00159 m_lexicalTableIndex.resize(size);
00160 read += std::fread(&m_lexicalTableIndex[0], sizeof(size_t), size, in);
00161
00162 read += std::fread(&size, sizeof(size_t), 1, in);
00163 m_lexicalTable.resize(size);
00164 read += std::fread(&m_lexicalTable[0], sizeof(SrcTrg), size, in);
00165 }
00166
00167 m_targetSymbols.load(in);
00168
00169 m_symbolTree = new CanonicalHuffman<unsigned>(in);
00170
00171 read += std::fread(&m_multipleScoreTrees, sizeof(m_multipleScoreTrees), 1, in);
00172 if(m_multipleScoreTrees) {
00173 m_scoreTrees.resize(m_numScoreComponent);
00174 for(size_t i = 0; i < m_numScoreComponent; i++)
00175 m_scoreTrees[i] = new CanonicalHuffman<float>(in);
00176 } else {
00177 m_scoreTrees.resize(1);
00178 m_scoreTrees[0] = new CanonicalHuffman<float>(in);
00179 }
00180
00181 if(m_containsAlignmentInfo)
00182 m_alignTree = new CanonicalHuffman<AlignPoint>(in);
00183
00184 size_t end = std::ftell(in);
00185 return end - start;
00186 }
00187
00188 std::string PhraseDecoder::MakeSourceKey(std::string &source)
00189 {
00190 return source + m_separator;
00191 }
00192
00193 TargetPhraseVectorPtr PhraseDecoder::CreateTargetPhraseCollection(const Phrase &sourcePhrase, bool topLevel, bool eval)
00194 {
00195
00196
00197
00198 TargetPhraseVectorPtr tpv(new TargetPhraseVector());
00199 size_t bitsLeft = 0;
00200
00201 if(m_coding == PREnc) {
00202 std::pair<TargetPhraseVectorPtr, size_t> cachedPhraseColl
00203 = m_decodingCache.Retrieve(sourcePhrase);
00204
00205
00206 if(cachedPhraseColl.first != NULL && (!topLevel || cachedPhraseColl.second == 0))
00207 return cachedPhraseColl.first;
00208
00209
00210 else if(cachedPhraseColl.first != NULL) {
00211 bitsLeft = cachedPhraseColl.second;
00212 tpv->resize(cachedPhraseColl.first->size());
00213 std::copy(cachedPhraseColl.first->begin(),
00214 cachedPhraseColl.first->end(),
00215 tpv->begin());
00216 }
00217 }
00218
00219
00220 std::string sourcePhraseString = sourcePhrase.GetStringRep(*m_input);
00221 size_t sourcePhraseId = m_phraseDictionary.m_hash[MakeSourceKey(sourcePhraseString)];
00222
00223
00224
00225
00226
00227 if(sourcePhraseId != m_phraseDictionary.m_hash.GetSize()) {
00228
00229 std::string encodedPhraseCollection;
00230 if(m_phraseDictionary.m_inMemory)
00231 encodedPhraseCollection = m_phraseDictionary.m_targetPhrasesMemory[sourcePhraseId].str();
00232 else
00233 encodedPhraseCollection = m_phraseDictionary.m_targetPhrasesMapped[sourcePhraseId].str();
00234
00235 BitWrapper<> encodedBitStream(encodedPhraseCollection);
00236 if(m_coding == PREnc && bitsLeft)
00237 encodedBitStream.SeekFromEnd(bitsLeft);
00238
00239
00240 TargetPhraseVectorPtr decodedPhraseColl =
00241 DecodeCollection(tpv, encodedBitStream, sourcePhrase, topLevel, eval);
00242
00243 return decodedPhraseColl;
00244 } else
00245 return TargetPhraseVectorPtr();
00246 }
00247
00248 TargetPhraseVectorPtr PhraseDecoder::DecodeCollection(
00249 TargetPhraseVectorPtr tpv, BitWrapper<> &encodedBitStream,
00250 const Phrase &sourcePhrase, bool topLevel, bool eval)
00251 {
00252
00253 bool extending = tpv->size();
00254 size_t bitsLeft = encodedBitStream.TellFromEnd();
00255
00256 typedef std::pair<size_t, size_t> AlignPointSizeT;
00257
00258 std::vector<int> sourceWords;
00259 if(m_coding == REnc) {
00260 for(size_t i = 0; i < sourcePhrase.GetSize(); i++) {
00261 std::string sourceWord
00262 = sourcePhrase.GetWord(i).GetString(*m_input, false);
00263 unsigned idx = GetSourceSymbolId(sourceWord);
00264 sourceWords.push_back(idx);
00265 }
00266 }
00267
00268 unsigned phraseStopSymbol = 0;
00269 AlignPoint alignStopSymbol(-1, -1);
00270
00271 std::vector<float> scores;
00272 std::set<AlignPointSizeT> alignment;
00273
00274 enum DecodeState { New, Symbol, Score, Alignment, Add } state = New;
00275
00276 size_t srcSize = sourcePhrase.GetSize();
00277
00278 TargetPhrase* targetPhrase = NULL;
00279 while(encodedBitStream.TellFromEnd()) {
00280
00281 if(state == New) {
00282
00283 tpv->push_back(TargetPhrase());
00284 targetPhrase = &tpv->back();
00285
00286 alignment.clear();
00287 scores.clear();
00288
00289 state = Symbol;
00290 }
00291
00292 if(state == Symbol) {
00293 unsigned symbol = m_symbolTree->Read(encodedBitStream);
00294 if(symbol == phraseStopSymbol) {
00295 state = Score;
00296 } else {
00297 if(m_coding == REnc) {
00298 std::string wordString;
00299 size_t type = GetREncType(symbol);
00300
00301 if(type == 1) {
00302 unsigned decodedSymbol = DecodeREncSymbol1(symbol);
00303 wordString = GetTargetSymbol(decodedSymbol);
00304 } else if (type == 2) {
00305 size_t rank = DecodeREncSymbol2Rank(symbol);
00306 size_t srcPos = DecodeREncSymbol2Position(symbol);
00307
00308 if(srcPos >= sourceWords.size())
00309 return TargetPhraseVectorPtr();
00310
00311 wordString = GetTargetSymbol(GetTranslation(sourceWords[srcPos], rank));
00312 if(m_phraseDictionary.m_useAlignmentInfo) {
00313 size_t trgPos = targetPhrase->GetSize();
00314 alignment.insert(AlignPoint(srcPos, trgPos));
00315 }
00316 } else if(type == 3) {
00317 size_t rank = DecodeREncSymbol3(symbol);
00318 size_t srcPos = targetPhrase->GetSize();
00319
00320 if(srcPos >= sourceWords.size())
00321 return TargetPhraseVectorPtr();
00322
00323 wordString = GetTargetSymbol(GetTranslation(sourceWords[srcPos], rank));
00324 if(m_phraseDictionary.m_useAlignmentInfo) {
00325 size_t trgPos = srcPos;
00326 alignment.insert(AlignPoint(srcPos, trgPos));
00327 }
00328 }
00329
00330 Word word;
00331 word.CreateFromString(Output, *m_output, wordString, false);
00332 targetPhrase->AddWord(word);
00333 } else if(m_coding == PREnc) {
00334
00335 if(GetPREncType(symbol) == 1) {
00336 unsigned decodedSymbol = DecodePREncSymbol1(symbol);
00337
00338 Word word;
00339 word.CreateFromString(Output, *m_output,
00340 GetTargetSymbol(decodedSymbol), false);
00341 targetPhrase->AddWord(word);
00342 }
00343
00344 else {
00345 int left = DecodePREncSymbol2Left(symbol);
00346 int right = DecodePREncSymbol2Right(symbol);
00347 unsigned rank = DecodePREncSymbol2Rank(symbol);
00348
00349 int srcStart = left + targetPhrase->GetSize();
00350 int srcEnd = srcSize - right - 1;
00351
00352
00353 if(0 > srcStart || srcStart > srcEnd || unsigned(srcEnd) >= srcSize)
00354 return TargetPhraseVectorPtr();
00355
00356
00357 if(m_maxRank && rank > m_maxRank)
00358 return TargetPhraseVectorPtr();
00359
00360
00361 TargetPhraseVectorPtr subTpv = tpv;
00362
00363
00364 if(unsigned(srcEnd - srcStart + 1) != srcSize) {
00365 Phrase subPhrase = sourcePhrase.GetSubString(Range(srcStart, srcEnd));
00366 subTpv = CreateTargetPhraseCollection(subPhrase, false);
00367 } else {
00368
00369 if(rank >= tpv->size()-1)
00370 return TargetPhraseVectorPtr();
00371 }
00372
00373
00374 if(subTpv != NULL && rank < subTpv->size()) {
00375
00376 TargetPhrase& subTp = subTpv->at(rank);
00377 if(m_phraseDictionary.m_useAlignmentInfo) {
00378
00379 for(AlignmentInfo::const_iterator it = subTp.GetAlignTerm().begin();
00380 it != subTp.GetAlignTerm().end(); it++) {
00381 alignment.insert(AlignPointSizeT(srcStart + it->first,
00382 targetPhrase->GetSize() + it->second));
00383 }
00384 }
00385 targetPhrase->Append(subTp);
00386 } else
00387 return TargetPhraseVectorPtr();
00388 }
00389 } else {
00390 Word word;
00391 word.CreateFromString(Output, *m_output,
00392 GetTargetSymbol(symbol), false);
00393 targetPhrase->AddWord(word);
00394 }
00395 }
00396 } else if(state == Score) {
00397 size_t idx = m_multipleScoreTrees ? scores.size() : 0;
00398 float score = m_scoreTrees[idx]->Read(encodedBitStream);
00399 scores.push_back(score);
00400
00401 if(scores.size() == m_numScoreComponent) {
00402 targetPhrase->GetScoreBreakdown().Assign(&m_phraseDictionary, scores);
00403
00404 if(m_containsAlignmentInfo)
00405 state = Alignment;
00406 else
00407 state = Add;
00408 }
00409 } else if(state == Alignment) {
00410 AlignPoint alignPoint = m_alignTree->Read(encodedBitStream);
00411 if(alignPoint == alignStopSymbol) {
00412 state = Add;
00413 } else {
00414 if(m_phraseDictionary.m_useAlignmentInfo)
00415 alignment.insert(AlignPointSizeT(alignPoint));
00416 }
00417 }
00418
00419 if(state == Add) {
00420 if(m_phraseDictionary.m_useAlignmentInfo) {
00421 size_t sourceSize = sourcePhrase.GetSize();
00422 size_t targetSize = targetPhrase->GetSize();
00423 for(std::set<AlignPointSizeT>::iterator it = alignment.begin(); it != alignment.end(); it++) {
00424 if(it->first >= sourceSize || it->second >= targetSize)
00425 return TargetPhraseVectorPtr();
00426 }
00427 targetPhrase->SetAlignTerm(alignment);
00428 }
00429
00430 if(eval) {
00431 targetPhrase->EvaluateInIsolation(sourcePhrase, m_phraseDictionary.GetFeaturesToApply());
00432 }
00433
00434 if(m_coding == PREnc) {
00435 if(!m_maxRank || tpv->size() <= m_maxRank)
00436 bitsLeft = encodedBitStream.TellFromEnd();
00437
00438 if(!topLevel && m_maxRank && tpv->size() >= m_maxRank)
00439 break;
00440 }
00441
00442 if(encodedBitStream.TellFromEnd() <= 8)
00443 break;
00444
00445 state = New;
00446 }
00447 }
00448
00449 if(m_coding == PREnc && !extending) {
00450 bitsLeft = bitsLeft > 8 ? bitsLeft : 0;
00451 m_decodingCache.Cache(sourcePhrase, tpv, bitsLeft, m_maxRank);
00452 }
00453
00454 return tpv;
00455 }
00456
00457 void PhraseDecoder::PruneCache()
00458 {
00459 m_decodingCache.Prune();
00460 }
00461
00462 }