00001 #include "OxLM.h"
00002
00003 #include <boost/archive/binary_iarchive.hpp>
00004 #include <boost/archive/binary_oarchive.hpp>
00005 #include <boost/filesystem.hpp>
00006 #include <boost/functional/hash.hpp>
00007
00008 #include "moses/FactorCollection.h"
00009 #include "moses/InputType.h"
00010 #include "moses/TranslationTask.h"
00011
00012 using namespace std;
00013 using namespace oxlm;
00014
00015 namespace Moses
00016 {
00017
00018 template<class Model>
00019 OxLM<Model>::OxLM(const string &line)
00020 : LanguageModelSingleFactor(line), normalized(true),
00021 posBackOff(false), posFactorType(1),
00022 persistentCache(false)
00023 {
00024 ReadParameters();
00025
00026 FactorCollection &factorCollection = FactorCollection::Instance();
00027
00028
00029 m_sentenceStart = factorCollection.AddFactor(Output, m_factorType, BOS_);
00030 m_sentenceStartWord[m_factorType] = m_sentenceStart;
00031
00032 m_sentenceEnd = factorCollection.AddFactor(Output, m_factorType, EOS_);
00033 m_sentenceEndWord[m_factorType] = m_sentenceEnd;
00034
00035 cacheHits = totalHits = 0;
00036 }
00037
00038
00039 template<class Model>
00040 OxLM<Model>::~OxLM()
00041 {
00042 if (persistentCache) {
00043 if (cache.get()) {
00044 string cache_file = m_filePath + ".phrases.cache.bin";
00045 savePersistentCache(cache_file);
00046 }
00047
00048 double cache_hit_ratio = 100.0 * cacheHits / totalHits;
00049 cerr << "Cache hit ratio: " << cache_hit_ratio << endl;
00050 }
00051 }
00052
00053
00054 template<class Model>
00055 void OxLM<Model>::SetParameter(const string& key, const string& value)
00056 {
00057 if (key == "normalized") {
00058 normalized = Scan<bool>(value);
00059 } else if (key == "persistent-cache") {
00060 persistentCache = Scan<bool>(value);
00061 } else if (key == "normalized") {
00062 normalized = Scan<bool>(value);
00063 } else if (key == "pos-back-off") {
00064 posBackOff = Scan<bool>(value);
00065 } else if (key == "pos-factor-type") {
00066 posFactorType = Scan<FactorType>(value);
00067 } else {
00068 LanguageModelSingleFactor::SetParameter(key, value);
00069 }
00070 }
00071
00072 template<class Model>
00073 void OxLM<Model>::Load(AllOptions::ptr const& opts)
00074 {
00075 model.load(m_filePath);
00076
00077 boost::shared_ptr<Vocabulary> vocab = model.getVocab();
00078 mapper = boost::make_shared<OxLMMapper>(vocab, posBackOff, posFactorType);
00079
00080 kSTART = vocab->convert("<s>");
00081 kSTOP = vocab->convert("</s>");
00082 kUNKNOWN = vocab->convert("<unk>");
00083
00084 size_t ngram_order = model.getConfig()->ngram_order;
00085 UTIL_THROW_IF2(
00086 m_nGramOrder != ngram_order,
00087 "Wrong order for OxLM: LM has " << ngram_order << ", but Moses expects " << m_nGramOrder);
00088 }
00089
00090 template<class Model>
00091 double OxLM<Model>::GetScore(int word, const vector<int>& context) const
00092 {
00093 if (normalized) {
00094 return model.getLogProb(word, context);
00095 } else {
00096 return model.getUnnormalizedScore(word, context);
00097 }
00098 }
00099
00100 template<class Model>
00101 LMResult OxLM<Model>::GetValue(
00102 const vector<const Word*> &contextFactor, State* finalState) const
00103 {
00104 if (!cache.get()) {
00105 cache.reset(new QueryCache());
00106 string cache_file = m_filePath + ".phrases.cache.bin";
00107 loadPersistentCache(cache_file);
00108 }
00109
00110 vector<int> context;
00111 int word;
00112 mapper->convert(contextFactor, context, word);
00113
00114 size_t context_width = m_nGramOrder - 1;
00115 if (!context.empty() && context.back() == kSTART) {
00116 context.resize(context_width, kSTART);
00117 } else {
00118 context.resize(context_width, kUNKNOWN);
00119 }
00120
00121 double score;
00122 if (persistentCache) {
00123 ++totalHits;
00124 NGram query(word, context);
00125 pair<double, bool> ret = cache->get(query);
00126 if (ret.second) {
00127 score = ret.first;
00128 ++cacheHits;
00129 } else {
00130 score = GetScore(word, context);
00131 cache->put(query, score);
00132 }
00133 } else {
00134 score = GetScore(word, context);
00135 }
00136
00137 LMResult ret;
00138 ret.score = score;
00139 ret.unknown = (word == kUNKNOWN);
00140
00141
00142 size_t seed = 0;
00143 boost::hash_combine(seed, word);
00144 for (size_t i = 0; i < context.size() && i < context_width - 1; ++i) {
00145 int id = context[i];
00146 boost::hash_combine(seed, id);
00147 }
00148
00149 (*finalState) = (State*) seed;
00150 return ret;
00151 }
00152
00153 template<class Model>
00154 void OxLM<Model>::loadPersistentCache(const string& cache_file) const
00155 {
00156 if (boost::filesystem::exists(cache_file)) {
00157 ifstream f(cache_file);
00158 boost::archive::binary_iarchive iar(f);
00159 cerr << "Loading n-gram probability cache from " << cache_file << endl;
00160 iar >> *cache;
00161 cerr << "Done loading " << cache->size()
00162 << " n-gram probabilities..." << endl;
00163 } else {
00164 cerr << "Cache file not found" << endl;
00165 }
00166 }
00167
00168 template<class Model>
00169 void OxLM<Model>::savePersistentCache(const string& cache_file) const
00170 {
00171 ofstream f(cache_file);
00172 boost::archive::binary_oarchive oar(f);
00173 cerr << "Saving persistent cache to " << cache_file << endl;
00174 oar << *cache;
00175 cerr << "Done saving " << cache->size()
00176 << " n-gram probabilities..." << endl;
00177 }
00178
00179 template<class Model>
00180 void OxLM<Model>::InitializeForInput(ttasksptr const& ttask)
00181 {
00182 const InputType& source = *ttask->GetSource();
00183 LanguageModelSingleFactor::InitializeForInput(ttask);
00184
00185 if (persistentCache) {
00186 if (!cache.get()) {
00187 cache.reset(new QueryCache());
00188 }
00189
00190 int sentence_id = source.GetTranslationId();
00191 string cache_file = m_filePath + "." + to_string(sentence_id) + ".cache.bin";
00192 loadPersistentCache(cache_file);
00193 }
00194 }
00195
00196 template<class Model>
00197 void OxLM<Model>::CleanUpAfterSentenceProcessing(const InputType& source)
00198 {
00199
00200 model.clearCache();
00201
00202 if (persistentCache) {
00203 int sentence_id = source.GetTranslationId();
00204 string cache_file = m_filePath + "." + to_string(sentence_id) + ".cache.bin";
00205 savePersistentCache(cache_file);
00206
00207 cache->clear();
00208 }
00209
00210 LanguageModelSingleFactor::CleanUpAfterSentenceProcessing(source);
00211 }
00212
00213 template class OxLM<LM>;
00214 template class OxLM<FactoredLM>;
00215 template class OxLM<FactoredMaxentLM>;
00216 template class OxLM<FactoredTreeLM>;
00217
00218 }
00219
00220
00221