00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022 #include <cassert>
00023 #include <limits>
00024 #include <iostream>
00025 #include <fstream>
00026 #include "dictionary.h"
00027 #include "n_gram.h"
00028 #include "lmContainer.h"
00029
00030 #include "LanguageModelIRST.h"
00031 #include "TypeDef.h"
00032 #include "Util.h"
00033 #include "FactorCollection.h"
00034 #include "Phrase.h"
00035 #include "InputFileStream.h"
00036 #include "StaticData.h"
00037
00038 using namespace std;
00039
00040 namespace Moses
00041 {
00042
00043 LanguageModelIRST::LanguageModelIRST(int dub)
00044 :m_lmtb(0),m_lmtb_dub(dub)
00045 {
00046 }
00047
00048 LanguageModelIRST::~LanguageModelIRST()
00049 {
00050
00051 #ifndef WIN32
00052 TRACE_ERR( "reset mmap\n");
00053 m_lmtb->reset_mmap();
00054 #endif
00055
00056 delete m_lmtb;
00057 }
00058
00059
00060 bool LanguageModelIRST::Load(const std::string &filePath,
00061 FactorType factorType,
00062 size_t nGramOrder)
00063 {
00064 cerr << "In LanguageModelIRST::Load: nGramOrder = " << nGramOrder << "\n";
00065
00066 FactorCollection &factorCollection = FactorCollection::Instance();
00067
00068 m_factorType = factorType;
00069 m_nGramOrder = nGramOrder;
00070 m_filePath = filePath;
00071
00072
00073 m_lmtb = m_lmtb->CreateLanguageModel(m_filePath);
00074 m_lmtb->setMaxLoadedLevel(1000);
00075 m_lmtb->load(m_filePath);
00076 d=m_lmtb->getDict();
00077 d->incflag(1);
00078
00079 m_lmtb_size=m_lmtb->maxlevel();
00080
00081
00082
00083 m_unknownId = d->oovcode();
00084 m_empty = -1;
00085
00086 CreateFactors(factorCollection);
00087
00088 VERBOSE(1, "IRST: m_unknownId=" << m_unknownId << std::endl);
00089
00090
00091 m_lmtb->init_caches(m_lmtb_size>2?m_lmtb_size-1:2);
00092
00093 if (m_lmtb_dub > 0) m_lmtb->setlogOOVpenalty(m_lmtb_dub);
00094
00095 return true;
00096 }
00097
00098 void LanguageModelIRST::CreateFactors(FactorCollection &factorCollection)
00099 {
00100
00101
00102 std::map<size_t, int> lmIdMap;
00103 size_t maxFactorId = 0;
00104 m_empty = -1;
00105
00106 dict_entry *entry;
00107 dictionary_iter iter(d);
00108 while ( (entry = iter.next()) != NULL) {
00109 size_t factorId = factorCollection.AddFactor(Output, m_factorType, entry->word)->GetId();
00110 lmIdMap[factorId] = entry->code;
00111 maxFactorId = (factorId > maxFactorId) ? factorId : maxFactorId;
00112 }
00113
00114 size_t factorId;
00115
00116 m_sentenceStart = factorCollection.AddFactor(Output, m_factorType, BOS_);
00117 factorId = m_sentenceStart->GetId();
00118 m_lmtb_sentenceStart=lmIdMap[factorId] = GetLmID(BOS_);
00119 maxFactorId = (factorId > maxFactorId) ? factorId : maxFactorId;
00120 m_sentenceStartArray[m_factorType] = m_sentenceStart;
00121
00122 m_sentenceEnd = factorCollection.AddFactor(Output, m_factorType, EOS_);
00123 factorId = m_sentenceEnd->GetId();
00124 m_lmtb_sentenceEnd=lmIdMap[factorId] = GetLmID(EOS_);
00125 maxFactorId = (factorId > maxFactorId) ? factorId : maxFactorId;
00126 m_sentenceEndArray[m_factorType] = m_sentenceEnd;
00127
00128
00129 m_lmIdLookup.resize(maxFactorId+1);
00130 fill(m_lmIdLookup.begin(), m_lmIdLookup.end(), m_empty);
00131
00132 map<size_t, int>::iterator iterMap;
00133 for (iterMap = lmIdMap.begin() ; iterMap != lmIdMap.end() ; ++iterMap) {
00134 m_lmIdLookup[iterMap->first] = iterMap->second;
00135 }
00136 }
00137
00138 int LanguageModelIRST::GetLmID( const std::string &str ) const
00139 {
00140 return d->encode( str.c_str() );
00141 }
00142
00143 int LanguageModelIRST::GetLmID( const Factor *factor ) const
00144 {
00145 size_t factorId = factor->GetId();
00146
00147 if ((factorId >= m_lmIdLookup.size()) || (m_lmIdLookup[factorId] == m_empty)) {
00148 if (d->incflag()==1) {
00149 std::string s = factor->GetString();
00150 int code = d->encode(s.c_str());
00151
00160
00162
00179
00180
00181 if (factorId >= m_lmIdLookup.size()){
00182
00183
00184 m_lmIdLookup.resize(factorId+10, m_empty);
00185 }
00186
00187
00188 m_lmIdLookup[factorId] = code;
00189 return code;
00190
00191 } else {
00192 return m_unknownId;
00193 }
00194 } else {
00195 return m_lmIdLookup[factorId];
00196 }
00197 }
00198
00199 LMResult LanguageModelIRST::GetValue(const vector<const Word*> &contextFactor, State* finalState) const
00200 {
00201 FactorType factorType = GetFactorType();
00202
00203
00204 size_t count = contextFactor.size();
00205 if (count < 0) {
00206 cerr << "ERROR count < 0\n";
00207 exit(100);
00208 };
00209
00210
00211 int codes[MAX_NGRAM_SIZE];
00212
00213 size_t idx=0;
00214
00215
00216 if (count < (size_t) (m_lmtb_size-1)) codes[idx++] = m_lmtb_sentenceEnd;
00217 if (count < (size_t) m_lmtb_size) codes[idx++] = m_lmtb_sentenceStart;
00218
00219 for (size_t i = 0 ; i < count ; i++) {
00220 codes[idx++] = GetLmID((*contextFactor[i])[factorType]);
00221 }
00222 LMResult result;
00223 result.unknown = (codes[idx - 1] == m_unknownId);
00224
00225 char* msp = NULL;
00226 unsigned int ilen;
00227 result.score = m_lmtb->clprob(codes,idx,NULL,NULL,&msp,&ilen);
00228
00229 if (finalState) *finalState=(State *) msp;
00230
00231 result.score = TransformLMScore(result.score);
00232 return result;
00233 }
00234
00235
/**
 * Decide whether the LM caches should be reset now.
 *
 * @param sentences_done              number of sentences processed so far;
 *                                    the value (size_t)-1 always requests a
 *                                    cleanup.
 * @param m_lmcache_cleanup_threshold cleanup period in sentences; 0 disables
 *                                    periodic cleanup.
 * @return true when sentences_done is (size_t)-1, or when the threshold is
 *         non-zero and sentences_done is a multiple of it; false otherwise.
 */
bool LMCacheCleanup(size_t sentences_done, size_t m_lmcache_cleanup_threshold)
{
  // -1 converts to the largest size_t value and acts as a "clean now" marker.
  if (sentences_done == static_cast<size_t>(-1)) {
    return true;
  }

  // A zero threshold means "never" (and must not reach the modulo below).
  if (m_lmcache_cleanup_threshold == 0) {
    return false;
  }

  return (sentences_done % m_lmcache_cleanup_threshold) == 0;
}
00244
00245
00246 void LanguageModelIRST::CleanUpAfterSentenceProcessing()
00247 {
00248 const StaticData &staticData = StaticData::Instance();
00249 static int sentenceCount = 0;
00250 sentenceCount++;
00251
00252 size_t lmcache_cleanup_threshold = staticData.GetLMCacheCleanupThreshold();
00253
00254 if (LMCacheCleanup(sentenceCount, lmcache_cleanup_threshold)) {
00255 TRACE_ERR( "reset caches\n");
00256 m_lmtb->reset_caches();
00257 }
00258 }
00259
00260 void LanguageModelIRST::InitializeBeforeSentenceProcessing()
00261 {
00262
00263 #ifdef TRACE_CACHE
00264 m_lmtb->sentence_id++;
00265 #endif
00266 }
00267
00268 }
00269