00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022 #include "ParallelBackoff.h"
00023
00024 #include <vector>
00025 #include <string>
00026 #include <sstream>
00027 #include <fstream>
00028
00029 #include "MultiFactor.h"
00030 #include "moses/Word.h"
00031 #include "moses/Factor.h"
00032 #include "moses/FactorTypeSet.h"
00033 #include "moses/FactorCollection.h"
00034 #include "moses/Phrase.h"
00035 #include "moses/TypeDef.h"
00036 #include "moses/Util.h"
00037
00038 #include "FNgramSpecs.h"
00039 #include "FNgramStats.h"
00040 #include "FactoredVocab.h"
00041 #include "FNgram.h"
00042 #include "wmatrix.h"
00043 #include "Vocab.h"
00044 #include "File.h"
00045
00046 using namespace std;
00047
00048 namespace Moses
00049 {
00050
00051 namespace
00052 {
00053 class LanguageModelParallelBackoff : public LanguageModelMultiFactor
00054 {
00055 private:
00056 std::vector<FactorType> m_factorTypesOrdered;
00057
00058 FactoredVocab *m_srilmVocab;
00059 FNgram *m_srilmModel;
00060 VocabIndex m_unknownId;
00061 VocabIndex m_wtid;
00062 VocabIndex m_wtbid;
00063 VocabIndex m_wteid;
00064 FNgramSpecs<FNgramCount>* fnSpecs;
00065
00066 std::map<size_t, VocabIndex>* lmIdMap;
00067 std::fstream* debugStream;
00068
00069 WidMatrix *widMatrix;
00070
00071 public:
00072 LanguageModelParallelBackoff(const std::string &line)
00073 :LanguageModelMultiFactor("ParallelBackoffLM", line) {
00074 }
00075
00076 ~LanguageModelParallelBackoff();
00077
00078 bool Load(const std::string &filePath, const std::vector<FactorType> &factorTypes, size_t nGramOrder);
00079
00080 VocabIndex GetLmID( const std::string &str ) const;
00081
00082 VocabIndex GetLmID( const Factor *factor, FactorType ft ) const;
00083
00084 void CreateFactors();
00085
00086 LMResult GetValueForgotState(const std::vector<const Word*> &contextFactor, FFState &outState) const;
00087 const FFState *GetNullContextState() const;
00088 const FFState *GetBeginSentenceState() const;
00089 FFState *NewState(const FFState *from) const;
00090 };
00091
00092 LanguageModelParallelBackoff::~LanguageModelParallelBackoff()
00093 {
00095 }
00096
00097
00098 bool LanguageModelParallelBackoff::Load(const std::string &filePath, const std::vector<FactorType> &factorTypes, size_t nGramOrder)
00099 {
00100
00101 cerr << "Loading Language Model Parallel Backoff!!!\n";
00102 widMatrix = new ::WidMatrix();
00103 m_factorTypes = FactorMask(factorTypes);
00104 m_srilmVocab = new ::FactoredVocab();
00105
00106
00107 fnSpecs = 0;
00108 File f(filePath.c_str(),"r");
00109 fnSpecs = new ::FNgramSpecs<FNgramCount>(f,*m_srilmVocab, 0);
00110
00111 cerr << "Loaded fnSpecs!\n";
00112
00113 m_srilmVocab->unkIsWord() = true;
00114 m_srilmVocab->nullIsWord() = true;
00115 m_srilmVocab->toLower() = false;
00116
00117 FNgramStats *factoredStats = new FNgramStats(*m_srilmVocab, *fnSpecs);
00118
00119 factoredStats->debugme(2);
00120
00121 cerr << "Factored stats\n";
00122
00123 FNgram* fngramLM = new FNgram(*m_srilmVocab,*fnSpecs);
00124
00125 cerr << "FNgram object created\n";
00126
00127 fngramLM->skipOOVs = false;
00128
00129 if (!factoredStats->read()) {
00130 cerr << "error reading in counts in factor file\n";
00131 exit(1);
00132 }
00133
00134 cerr << "Factored stats read!\n";
00135
00136 factoredStats->estimateDiscounts();
00137 factoredStats->computeCardinalityFunctions();
00138 factoredStats->sumCounts();
00139
00140 cerr << "Another three operations made!\n";
00141
00142 if (!fngramLM->read()) {
00143 cerr << "format error in lm file\n";
00144 exit(1);
00145 }
00146
00147 cerr << "fngramLM reads!\n";
00148
00149 m_filePath = filePath;
00150 m_nGramOrder= nGramOrder;
00151
00152 m_factorTypesOrdered= factorTypes;
00153
00154 m_unknownId = m_srilmVocab->unkIndex();
00155
00156 cerr << "m_unknowdId = " << m_unknownId << endl;
00157
00158 m_srilmModel = fngramLM;
00159
00160 cerr << "Create factors...\n";
00161
00162 CreateFactors();
00163
00164 cerr << "Factors created! \n";
00165
00166
00167
00168
00169
00170
00171
00172
00173
00174
00175
00176
00177
00178
00179
00180
00181
00182
00183
00184
00185
00186
00187
00188 return true;
00189 }
00190
00191 VocabIndex LanguageModelParallelBackoff::GetLmID( const std::string &str ) const
00192 {
00193 return m_srilmVocab->getIndex( str.c_str(), m_unknownId );
00194 }
00195
00196 VocabIndex LanguageModelParallelBackoff::GetLmID( const Factor *factor, size_t ft ) const
00197 {
00198
00199 size_t factorId = factor->GetId();
00200 if ( lmIdMap->find( factorId * 10 + ft ) != lmIdMap->end() ) {
00201 return lmIdMap->find( factorId * 10 + ft )->second;
00202 } else {
00203 return m_unknownId;
00204 }
00205
00206 }
00207
00208 void LanguageModelParallelBackoff::CreateFactors()
00209 {
00210
00211
00212 FactorCollection &factorCollection = FactorCollection::Instance();
00213
00214 lmIdMap = new std::map<size_t, VocabIndex>();
00215
00216
00217 VocabString str;
00218 VocabIter iter(*m_srilmVocab);
00219
00220 iter.init();
00221
00222 size_t pomFactorTypeNum = 0;
00223
00224
00225 while ( (str = iter.next()) != NULL) {
00226
00227 if ((str[0] < 'a' || str[0] > 'k') && str[0] != 'W') {
00228 continue;
00229 }
00230 VocabIndex lmId = GetLmID(str);
00231 pomFactorTypeNum = str[0] - 'a';
00232
00233 size_t factorId = factorCollection.AddFactor(Output, m_factorTypesOrdered[pomFactorTypeNum], &(str[2]) )->GetId();
00234 (*lmIdMap)[factorId * 10 + pomFactorTypeNum] = lmId;
00235 }
00236
00237 size_t factorIdStart;
00238 size_t factorIdEnd;
00239
00240
00241 for (size_t index = 0 ; index < m_factorTypesOrdered.size() ; ++index) {
00242 FactorType factorType = m_factorTypesOrdered[index];
00243 m_sentenceStartWord[index] = factorCollection.AddFactor(Output, factorType, BOS_);
00244
00245
00246 m_sentenceEndWord[index] = factorCollection.AddFactor(Output, factorType, EOS_);
00247
00248 factorIdStart = m_sentenceStartWord[index]->GetId();
00249 factorIdEnd = m_sentenceEndWord[index]->GetId();
00250
00251
00252
00253
00254
00255
00256
00257 (*lmIdMap)[factorIdStart * 10 + index] = GetLmID(BOS_);
00258 (*lmIdMap)[factorIdEnd * 10 + index] = GetLmID(EOS_);
00259
00260 cerr << "BOS_:" << GetLmID(BOS_) << ", EOS_:" << GetLmID(EOS_) << endl;
00261
00262 }
00263
00264 m_wtid = GetLmID("W-<unk>");
00265 m_wtbid = GetLmID("W-<s>");
00266 m_wteid = GetLmID("W-</s>");
00267
00268 cerr << "W-<unk> index: " << m_wtid << endl;
00269 cerr << "W-<s> index: " << m_wtbid << endl;
00270 cerr << "W-</s> index: " << m_wteid << endl;
00271
00272
00273 }
00274
00275 LMResult LanguageModelParallelBackoff::GetValueForgotState(const std::vector<const Word*> &contextFactor, FFState & ) const
00276 {
00277
00278 static WidMatrix widMatrix;
00279
00280 for (int i=0; i<contextFactor.size(); i++)
00281 ::memset(widMatrix[i],0,(m_factorTypesOrdered.size() + 1)*sizeof(VocabIndex));
00282
00283
00284 for (size_t i = 0; i < contextFactor.size(); i++) {
00285 const Word &word = *contextFactor[i];
00286
00287 for (size_t j = 0; j < m_factorTypesOrdered.size(); j++) {
00288 const Factor *factor = word[ m_factorTypesOrdered[j] ];
00289
00290 if (factor == NULL)
00291 widMatrix[i][j + 1] = 0;
00292 else
00293 widMatrix[i][j + 1] = GetLmID(factor, j);
00294 }
00295
00296 if (widMatrix[i][1] == GetLmID(m_sentenceStartWord[0], 0) ) {
00297 widMatrix[i][0] = m_wtbid;
00298 } else if (widMatrix[i][1] == GetLmID(m_sentenceEndWord[0], 0 )) {
00299 widMatrix[i][0] = m_wteid;
00300 } else {
00301 widMatrix[i][0] = m_wtid;
00302 }
00303 }
00304
00305
00306 LMResult ret;
00307 ret.score = m_srilmModel->wordProb( widMatrix, contextFactor.size() - 1, contextFactor.size() );
00308 ret.score = FloorScore(TransformLMScore(ret.score));
00309 ret.unknown = !contextFactor.empty() && (widMatrix[contextFactor.size() - 1][0] == m_unknownId);
00310 return ret;
00311
00312
00313
00314
00315
00316
00317
00318
00319
00320
00321
00322
00323
00324
00325
00326
00327
00328
00329
00330
00331
00332
00333
00334 }
00335
00336
00337 FFState *LanguageModelParallelBackoff::NewState(const FFState * ) const
00338 {
00339 return NULL;
00340 }
00341
00342 const FFState *LanguageModelParallelBackoff::GetNullContextState() const
00343 {
00344 return NULL;
00345 }
00346
00347 const FFState *LanguageModelParallelBackoff::GetBeginSentenceState() const
00348 {
00349 return NULL;
00350 }
00351
00352 }
00353
00354
00355 }
00356