#include "lm/model.hh"

#include "lm/blank.hh"
#include "lm/lm_exception.hh"
#include "lm/search_hashed.hh"
#include "lm/search_trie.hh"
#include "lm/read_arpa.hh"
#include "util/have.hh"
#include "util/murmur_hash.hh"

#include <algorithm>
#include <functional>
#include <numeric>
#include <cmath>
#include <limits>

namespace lm {
namespace ngram {
namespace detail {

template <class Search, class VocabularyT> const ModelType GenericModel<Search, VocabularyT>::kModelType = Search::kModelType;

template <class Search, class VocabularyT> uint64_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) {
  return VocabularyT::Size(counts[0], config) + Search::Size(counts, config);
}

template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::SetupMemory(void *base, const std::vector<uint64_t> &counts, const Config &config) {
  size_t goal_size = util::CheckOverflow(Size(counts, config));
  uint8_t *start = static_cast<uint8_t*>(base);
  size_t allocated = VocabularyT::Size(counts[0], config);
  vocab_.SetupMemory(start, allocated, counts[0], config);
  start += allocated;
  start = search_.SetupMemory(start, counts, config);
  if (static_cast<std::size_t>(start - static_cast<uint8_t*>(base)) != goal_size) UTIL_THROW(FormatLoadException, "The data structures took " << (start - static_cast<uint8_t*>(base)) << " but Size says they should take " << goal_size);
}

namespace {
void ComplainAboutARPA(const Config &config, ModelType model_type) {
  if (config.write_mmap || !config.messages) return;
  if (config.arpa_complain == Config::ALL) {
    *config.messages << "Loading the LM will be faster if you build a binary file." << std::endl;
  } else if (config.arpa_complain == Config::EXPENSIVE &&
             (model_type == TRIE || model_type == QUANT_TRIE || model_type == ARRAY_TRIE || model_type == QUANT_ARRAY_TRIE)) {
    *config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive. Save time by building a binary format." << std::endl;
  }
}

void CheckCounts(const std::vector<uint64_t> &counts) {
  UTIL_THROW_IF(counts.size() > KENLM_MAX_ORDER, FormatLoadException, "This model has order " << counts.size() << " but KenLM was compiled to support up to " << KENLM_MAX_ORDER << ". " << KENLM_ORDER_MESSAGE);
  if (sizeof(uint64_t) > sizeof(std::size_t)) {
    for (std::vector<uint64_t>::const_iterator i = counts.begin(); i != counts.end(); ++i) {
      UTIL_THROW_IF(*i > static_cast<uint64_t>(std::numeric_limits<size_t>::max()), util::OverflowException, "This model has " << *i << " " << (i - counts.begin() + 1) << "-grams, which is too many for a 32-bit machine.");
    }
  }
}

} // namespace

template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &init_config) : backing_(init_config) {
  util::scoped_fd fd(util::OpenReadOrThrow(file));
  if (IsBinaryFormat(fd.get())) {
    Parameters parameters;
    int fd_shallow = fd.release();
    backing_.InitializeBinary(fd_shallow, kModelType, kVersion, parameters);
    CheckCounts(parameters.counts);

    Config new_config(init_config);
    new_config.probing_multiplier = parameters.fixed.probing_multiplier;
    Search::UpdateConfigFromBinary(backing_, parameters.counts, VocabularyT::Size(parameters.counts[0], new_config), new_config);
    UTIL_THROW_IF(new_config.enumerate_vocab && !parameters.fixed.has_vocabulary, FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary.");

    SetupMemory(backing_.LoadBinary(Size(parameters.counts, new_config)), parameters.counts, new_config);
    vocab_.LoadedBinary(parameters.fixed.has_vocabulary, fd_shallow, new_config.enumerate_vocab, backing_.VocabStringReadingOffset());
  } else {
    ComplainAboutARPA(init_config, kModelType);
    InitializeFromARPA(fd.release(), file, init_config);
  }

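  // g++ prints warnings unless these are fully initialized.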
  State begin_sentence = State();
  begin_sentence.length = 1;
  begin_sentence.words[0] = vocab_.BeginSentence();
  typename Search::Node ignored_node;
  bool ignored_independent_left;
  uint64_t ignored_extend_left;
  begin_sentence.backoff[0] = search_.LookupUnigram(begin_sentence.words[0], ignored_node, ignored_independent_left, ignored_extend_left).Backoff();
  State null_context = State();
  null_context.length = 0;
  P::Init(begin_sentence, null_context, vocab_, search_.Order());
}

template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(int fd, const char *file, const Config &config) {
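  // The backing file is the ARPA itself.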
  util::FilePiece f(fd, file, config.ProgressMessages());
  try {
    std::vector<uint64_t> counts;
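    // File counts do not include pruned n-grams that extend to higher orders; these will be fixed up by search_.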
    ReadARPACounts(f, counts);
    CheckCounts(counts);
    if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");
    if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0");

    std::size_t vocab_size = util::CheckOverflow(VocabularyT::Size(counts[0], config));
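    // Set up the binary file for writing the vocab lookup table.  The search_ is responsible for growing the binary file to its needs.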
    vocab_.SetupMemory(backing_.SetupJustVocab(vocab_size, counts.size()), vocab_size, counts[0], config);

    if (config.write_mmap && config.include_vocab) {
      WriteWordsWrapper wrap(config.enumerate_vocab);
      vocab_.ConfigureEnumerate(&wrap, counts[0]);
      search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
      void *vocab_rebase, *search_rebase;
      backing_.WriteVocabWords(wrap.Buffer(), vocab_rebase, search_rebase);
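      // Writing at the end of the file may have caused mmap to relocate the data, so remap.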
      vocab_.Relocate(vocab_rebase);
      search_.SetupMemory(reinterpret_cast<uint8_t*>(search_rebase), counts, config);
    } else {
      vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]);
      search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
    }

    if (!vocab_.SawUnk()) {
      assert(config.unknown_missing != THROW_UP);
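      // Default probabilities for unknown.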
      search_.UnknownUnigram().backoff = 0.0;
      search_.UnknownUnigram().prob = config.unknown_missing_logprob;
    }
    backing_.FinishFile(config, kModelType, kVersion, counts);
  } catch (util::Exception &e) {
    e << " Byte: " << f.Offset();
    throw;
  }
}

template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const {
  FullScoreReturn ret = ScoreExceptBackoff(in_state.words, in_state.words + in_state.length, new_word, out_state);
  for (const float *i = in_state.backoff + ret.ngram_length - 1; i < in_state.backoff + in_state.length; ++i) {
    ret.prob += *i;
  }
  return ret;
}

template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const {
  context_rend = std::min(context_rend, context_rbegin + P::Order() - 1);
  FullScoreReturn ret = ScoreExceptBackoff(context_rbegin, context_rend, new_word, out_state);

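  // Add the backoff weights for n-grams of order start to (context_rend - context_rbegin).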
  unsigned char start = ret.ngram_length;
  if (context_rend - context_rbegin < static_cast<std::ptrdiff_t>(start)) return ret;

  bool independent_left;
  uint64_t extend_left;
  typename Search::Node node;
  if (start <= 1) {
    ret.prob += search_.LookupUnigram(*context_rbegin, node, independent_left, extend_left).Backoff();
    start = 2;
  } else if (!search_.FastMakeNode(context_rbegin, context_rbegin + start - 1, node)) {
    return ret;
  }
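  // i is the order of the backoff we're looking for.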
  unsigned char order_minus_2 = start - 2;
  for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i, ++order_minus_2) {
    typename Search::MiddlePointer p(search_.LookupMiddle(order_minus_2, *i, node, independent_left, extend_left));
    if (!p.Found()) break;
    ret.prob += p.Backoff();
  }
  return ret;
}

template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const {
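  // Generate a state from context.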
  context_rend = std::min(context_rend, context_rbegin + P::Order() - 1);
  if (context_rend == context_rbegin) {
    out_state.length = 0;
    return;
  }
  typename Search::Node node;
  bool independent_left;
  uint64_t extend_left;
  out_state.backoff[0] = search_.LookupUnigram(*context_rbegin, node, independent_left, extend_left).Backoff();
  out_state.length = HasExtension(out_state.backoff[0]) ? 1 : 0;
  float *backoff_out = out_state.backoff + 1;
  unsigned char order_minus_2 = 0;
  for (const WordIndex *i = context_rbegin + 1; i < context_rend; ++i, ++backoff_out, ++order_minus_2) {
    typename Search::MiddlePointer p(search_.LookupMiddle(order_minus_2, *i, node, independent_left, extend_left));
    if (!p.Found()) {
      std::copy(context_rbegin, context_rbegin + out_state.length, out_state.words);
      return;
    }
    *backoff_out = p.Backoff();
    if (HasExtension(*backoff_out)) out_state.length = i - context_rbegin + 1;
  }
  std::copy(context_rbegin, context_rbegin + out_state.length, out_state.words);
}

template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::ExtendLeft(
    const WordIndex *add_rbegin, const WordIndex *add_rend,
    const float *backoff_in,
    uint64_t extend_pointer,
    unsigned char extend_length,
    float *backoff_out,
    unsigned char &next_use) const {
  FullScoreReturn ret;
  typename Search::Node node;
  if (extend_length == 1) {
    typename Search::UnigramPointer ptr(search_.LookupUnigram(static_cast<WordIndex>(extend_pointer), node, ret.independent_left, ret.extend_left));
    ret.rest = ptr.Rest();
    ret.prob = ptr.Prob();
    assert(!ret.independent_left);
  } else {
    typename Search::MiddlePointer ptr(search_.Unpack(extend_pointer, extend_length, node));
    ret.rest = ptr.Rest();
    ret.prob = ptr.Prob();
    ret.extend_left = extend_pointer;
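    // If this function is called, then it does depend on left words.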
    ret.independent_left = false;
  }
  float subtract_me = ret.rest;
  ret.ngram_length = extend_length;
  next_use = extend_length;
  ResumeScore(add_rbegin, add_rend, extend_length - 1, node, backoff_out, next_use, ret);
  next_use -= extend_length;
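  // Charge backoffs.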
  for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b;
  ret.prob -= subtract_me;
  ret.rest -= subtract_me;
  return ret;
}

namespace {
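// Do a paranoid copy of history, assuming new_word has already been copied
// (hence the -1).  out_state.length could be zero so std::copy is avoided.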
void CopyRemainingHistory(const WordIndex *from, State &out_state) {
  WordIndex *out = out_state.words + 1;
  const WordIndex *in_end = from + static_cast<ptrdiff_t>(out_state.length) - 1;
  for (const WordIndex *in = from; in < in_end; ++in, ++out) *out = *in;
}
} // namespace

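/* Ugly optimized function.  Produce a score excluding backoff.
 * The search goes in increasing order of ngram length.
 * Context goes backward, so context_rbegin is the word immediately preceding
 * new_word.
 */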
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::ScoreExceptBackoff(
    const WordIndex *const context_rbegin,
    const WordIndex *const context_rend,
    const WordIndex new_word,
    State &out_state) const {
  assert(new_word < vocab_.Bound());
  FullScoreReturn ret;
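  // ret.ngram_length contains the last known non-blank ngram length.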
  ret.ngram_length = 1;

  typename Search::Node node;
  typename Search::UnigramPointer uni(search_.LookupUnigram(new_word, node, ret.independent_left, ret.extend_left));
  out_state.backoff[0] = uni.Backoff();
  ret.prob = uni.Prob();
  ret.rest = uni.Rest();

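  // This is the length of the context that should be used for continuation to the right.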
  out_state.length = HasExtension(out_state.backoff[0]) ? 1 : 0;
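  // We'll write the word anyway since it will probably be used and does no harm being there.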
  out_state.words[0] = new_word;
  if (context_rbegin == context_rend) return ret;

  ResumeScore(context_rbegin, context_rend, 0, node, out_state.backoff + 1, out_state.length, ret);
  CopyRemainingHistory(context_rbegin, out_state);
  return ret;
}

template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::ResumeScore(const WordIndex *hist_iter, const WordIndex *const context_rend, unsigned char order_minus_2, typename Search::Node &node, float *backoff_out, unsigned char &next_use, FullScoreReturn &ret) const {
  for (; ; ++order_minus_2, ++hist_iter, ++backoff_out) {
    if (hist_iter == context_rend) return;
    if (ret.independent_left) return;
    if (order_minus_2 == P::Order() - 2) break;

    typename Search::MiddlePointer pointer(search_.LookupMiddle(order_minus_2, *hist_iter, node, ret.independent_left, ret.extend_left));
    if (!pointer.Found()) return;
    *backoff_out = pointer.Backoff();
    ret.prob = pointer.Prob();
    ret.rest = pointer.Rest();
    ret.ngram_length = order_minus_2 + 2;
    if (HasExtension(*backoff_out)) {
      next_use = ret.ngram_length;
    }
  }
  ret.independent_left = true;
  typename Search::LongestPointer longest(search_.LookupLongest(*hist_iter, node));
  if (longest.Found()) {
    ret.prob = longest.Prob();
    ret.rest = ret.prob;
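    // There is no blank in longest_.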
    ret.ngram_length = P::Order();
  }
}

template <class Search, class VocabularyT> float GenericModel<Search, VocabularyT>::InternalUnRest(const uint64_t *pointers_begin, const uint64_t *pointers_end, unsigned char first_length) const {
  float ret;
  typename Search::Node node;
  if (first_length == 1) {
    if (pointers_begin >= pointers_end) return 0.0;
    bool independent_left;
    uint64_t extend_left;
    typename Search::UnigramPointer ptr(search_.LookupUnigram(static_cast<WordIndex>(*pointers_begin), node, independent_left, extend_left));
    ret = ptr.Prob() - ptr.Rest();
    ++first_length;
    ++pointers_begin;
  } else {
    ret = 0.0;
  }
  for (const uint64_t *i = pointers_begin; i < pointers_end; ++i, ++first_length) {
    typename Search::MiddlePointer ptr(search_.Unpack(*i, first_length, node));
    ret += ptr.Prob() - ptr.Rest();
  }
  return ret;
}

template class GenericModel<HashedSearch<BackoffValue>, ProbingVocabulary>;
template class GenericModel<HashedSearch<RestValue>, ProbingVocabulary>;
template class GenericModel<trie::TrieSearch<DontQuantize, trie::DontBhiksha>, SortedVocabulary>;
template class GenericModel<trie::TrieSearch<DontQuantize, trie::ArrayBhiksha>, SortedVocabulary>;
template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::DontBhiksha>, SortedVocabulary>;
template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::ArrayBhiksha>, SortedVocabulary>;

} // namespace detail

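// Instantiate the appropriate concrete model, detecting the type of a binary file if one is given.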
base::Model *LoadVirtual(const char *file_name, const Config &config, ModelType model_type) {
  RecognizeBinary(file_name, model_type);
  switch (model_type) {
    case PROBING:
      return new ProbingModel(file_name, config);
    case REST_PROBING:
      return new RestProbingModel(file_name, config);
    case TRIE:
      return new TrieModel(file_name, config);
    case QUANT_TRIE:
      return new QuantTrieModel(file_name, config);
    case ARRAY_TRIE:
      return new ArrayTrieModel(file_name, config);
    case QUANT_ARRAY_TRIE:
      return new QuantArrayTrieModel(file_name, config);
    default:
      UTIL_THROW(FormatLoadException, "Confused by model type " << model_type);
  }
}

} // namespace ngram
} // namespace lm