00001
00002 #include "lm/search_trie.hh"
00003
00004 #include "lm/bhiksha.hh"
00005 #include "lm/binary_format.hh"
00006 #include "lm/blank.hh"
00007 #include "lm/lm_exception.hh"
00008 #include "lm/max_order.hh"
00009 #include "lm/quantize.hh"
00010 #include "lm/trie.hh"
00011 #include "lm/trie_sort.hh"
00012 #include "lm/vocab.hh"
00013 #include "lm/weights.hh"
00014 #include "lm/word_index.hh"
00015 #include "util/ersatz_progress.hh"
00016 #include "util/mmap.hh"
00017 #include "util/proxy_iterator.hh"
00018 #include "util/scoped.hh"
00019 #include "util/sized_iterator.hh"
00020
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <new>
#include <numeric>
#include <queue>
#include <vector>
00029
00030 #if defined(_WIN32) || defined(_WIN64)
00031 #include <windows.h>
00032 #endif
00033
00034 namespace lm {
00035 namespace ngram {
00036 namespace trie {
00037 namespace {
00038
// Reads exactly `size` bytes from `from` into `data`, requested as a single
// fread item of `size` bytes.
// Throws util::ErrnoException ("Short read") unless that whole item was read,
// e.g. on end of file or a stream error.
void ReadOrThrow(FILE *from, void *data, size_t size) {
  UTIL_THROW_IF(1 != std::fread(data, size, 1, from), util::ErrnoException, "Short read");
}
00042
00043 int Compare(unsigned char order, const void *first_void, const void *second_void) {
00044 const WordIndex *first = reinterpret_cast<const WordIndex*>(first_void), *second = reinterpret_cast<const WordIndex*>(second_void);
00045 const WordIndex *end = first + order;
00046 for (; first != end; ++first, ++second) {
00047 if (*first < *second) return -1;
00048 if (*first > *second) return 1;
00049 }
00050 return 0;
00051 }
00052
// Names one float slot in SRISucks's per-order value arrays.
// `array` is the blank's order minus one and `index` is its position in that
// array, so the slot is addressed as base[array][index].
// Stored by value at the tail of every BackoffMessages entry.
struct ProbPointer {
  unsigned char array;
  uint64_t index;
};
00057
00058
00059 class BackoffMessages {
00060 public:
00061 void Init(std::size_t entry_size) {
00062 current_ = NULL;
00063 allocated_ = NULL;
00064 entry_size_ = entry_size;
00065 }
00066
00067 void Add(const WordIndex *to, ProbPointer index) {
00068 while (current_ + entry_size_ > allocated_) {
00069 std::size_t allocated_size = allocated_ - (uint8_t*)backing_.get();
00070 Resize(std::max<std::size_t>(allocated_size * 2, entry_size_));
00071 }
00072 memcpy(current_, to, entry_size_ - sizeof(ProbPointer));
00073 *reinterpret_cast<ProbPointer*>(current_ + entry_size_ - sizeof(ProbPointer)) = index;
00074 current_ += entry_size_;
00075 }
00076
00077 void Apply(float *const *const base, FILE *unigrams) {
00078 FinishedAdding();
00079 if (current_ == allocated_) return;
00080 rewind(unigrams);
00081 ProbBackoff weights;
00082 WordIndex unigram = 0;
00083 ReadOrThrow(unigrams, &weights, sizeof(weights));
00084 for (; current_ != allocated_; current_ += entry_size_) {
00085 const WordIndex &cur_word = *reinterpret_cast<const WordIndex*>(current_);
00086 for (; unigram < cur_word; ++unigram) {
00087 ReadOrThrow(unigrams, &weights, sizeof(weights));
00088 }
00089 if (!HasExtension(weights.backoff)) {
00090 weights.backoff = kExtensionBackoff;
00091 UTIL_THROW_IF(fseek(unigrams, -sizeof(weights), SEEK_CUR), util::ErrnoException, "Seeking backwards to denote unigram extension failed.");
00092 util::WriteOrThrow(unigrams, &weights, sizeof(weights));
00093 }
00094 const ProbPointer &write_to = *reinterpret_cast<const ProbPointer*>(current_ + sizeof(WordIndex));
00095 base[write_to.array][write_to.index] += weights.backoff;
00096 }
00097 backing_.reset();
00098 }
00099
00100 void Apply(float *const *const base, RecordReader &reader) {
00101 FinishedAdding();
00102 if (current_ == allocated_) return;
00103
00104 WordIndex *extend_out = reinterpret_cast<WordIndex*>(current_);
00105 const unsigned char order = (entry_size_ - sizeof(ProbPointer)) / sizeof(WordIndex);
00106 for (reader.Rewind(); reader && (current_ != allocated_); ) {
00107 switch (Compare(order, reader.Data(), current_)) {
00108 case -1:
00109 ++reader;
00110 break;
00111 case 1:
00112
00113 for (const WordIndex *w = reinterpret_cast<const WordIndex *>(current_); w != reinterpret_cast<const WordIndex *>(current_) + order; ++w, ++extend_out) *extend_out = *w;
00114 current_ += entry_size_;
00115 break;
00116 case 0:
00117 float &backoff = reinterpret_cast<ProbBackoff*>((uint8_t*)reader.Data() + order * sizeof(WordIndex))->backoff;
00118 if (!HasExtension(backoff)) {
00119 backoff = kExtensionBackoff;
00120 reader.Overwrite(&backoff, sizeof(float));
00121 } else {
00122 const ProbPointer &write_to = *reinterpret_cast<const ProbPointer*>(current_ + entry_size_ - sizeof(ProbPointer));
00123 base[write_to.array][write_to.index] += backoff;
00124 }
00125 current_ += entry_size_;
00126 break;
00127 }
00128 }
00129
00130 entry_size_ = sizeof(WordIndex) * order;
00131 Resize(sizeof(WordIndex) * (extend_out - (const WordIndex*)backing_.get()));
00132 current_ = (uint8_t*)backing_.get();
00133 }
00134
00135
00136 bool Extends(unsigned char order, const WordIndex *words) {
00137 if (current_ == allocated_) return false;
00138 assert(order * sizeof(WordIndex) == entry_size_);
00139 while (true) {
00140 switch(Compare(order, words, current_)) {
00141 case 1:
00142 current_ += entry_size_;
00143 if (current_ == allocated_) return false;
00144 break;
00145 case -1:
00146 return false;
00147 case 0:
00148 return true;
00149 }
00150 }
00151 }
00152
00153 private:
00154 void FinishedAdding() {
00155 Resize(current_ - (uint8_t*)backing_.get());
00156
00157 std::sort(
00158 util::SizedIterator(util::SizedProxy(backing_.get(), entry_size_)),
00159 util::SizedIterator(util::SizedProxy(current_, entry_size_)),
00160 util::SizedCompare<EntryCompare>(EntryCompare((entry_size_ - sizeof(ProbPointer)) / sizeof(WordIndex))));
00161 current_ = (uint8_t*)backing_.get();
00162 }
00163
00164 void Resize(std::size_t to) {
00165 std::size_t current = current_ - (uint8_t*)backing_.get();
00166 backing_.call_realloc(to);
00167 current_ = (uint8_t*)backing_.get() + current;
00168 allocated_ = (uint8_t*)backing_.get() + to;
00169 }
00170
00171 util::scoped_malloc backing_;
00172
00173 uint8_t *current_, *allocated_;
00174
00175 std::size_t entry_size_;
00176 };
00177
// Sentinel for "no probability basis recorded": BlankManager fills and resets
// its basis_ slots with it, and SRISucks::Send asserts it never receives it.
const float kBadProb = std::numeric_limits<float>::infinity();
00179
00180 class SRISucks {
00181 public:
00182 SRISucks() {
00183 for (BackoffMessages *i = messages_; i != messages_ + KENLM_MAX_ORDER - 1; ++i)
00184 i->Init(sizeof(ProbPointer) + sizeof(WordIndex) * (i - messages_ + 1));
00185 }
00186
00187 void Send(unsigned char begin, unsigned char order, const WordIndex *to, float prob_basis) {
00188 assert(prob_basis != kBadProb);
00189 ProbPointer pointer;
00190 pointer.array = order - 1;
00191 pointer.index = values_[order - 1].size();
00192 for (unsigned char i = begin; i < order; ++i) {
00193 messages_[i - 1].Add(to, pointer);
00194 }
00195 values_[order - 1].push_back(prob_basis);
00196 }
00197
00198 void ObtainBackoffs(unsigned char total_order, FILE *unigram_file, RecordReader *reader) {
00199 for (unsigned char i = 0; i < KENLM_MAX_ORDER - 1; ++i) {
00200 it_[i] = values_[i].empty() ? NULL : &*values_[i].begin();
00201 }
00202 messages_[0].Apply(it_, unigram_file);
00203 BackoffMessages *messages = messages_ + 1;
00204 const RecordReader *end = reader + total_order - 2 ;
00205 for (; reader != end; ++messages, ++reader) {
00206 messages->Apply(it_, *reader);
00207 }
00208 }
00209
00210 ProbBackoff GetBlank(unsigned char total_order, unsigned char order, const WordIndex *indices) {
00211 assert(order > 1);
00212 ProbBackoff ret;
00213 ret.prob = *(it_[order - 1]++);
00214 ret.backoff = ((order != total_order - 1) && messages_[order - 1].Extends(order, indices)) ? kExtensionBackoff : kNoExtensionBackoff;
00215 return ret;
00216 }
00217
00218 const std::vector<float> &Values(unsigned char order) const {
00219 return values_[order - 1];
00220 }
00221
00222 private:
00223
00224 std::vector<float> values_[KENLM_MAX_ORDER - 1];
00225 BackoffMessages messages_[KENLM_MAX_ORDER - 1];
00226
00227 float *it_[KENLM_MAX_ORDER - 1];
00228 };
00229
00230 class FindBlanks {
00231 public:
00232 FindBlanks(unsigned char order, const ProbBackoff *unigrams, SRISucks &messages)
00233 : counts_(order), unigrams_(unigrams), sri_(messages) {}
00234
00235 float UnigramProb(WordIndex index) const {
00236 return unigrams_[index].prob;
00237 }
00238
00239 void Unigram(WordIndex ) {
00240 ++counts_[0];
00241 }
00242
00243 void MiddleBlank(const unsigned char order, const WordIndex *indices, unsigned char lower, float prob_basis) {
00244 sri_.Send(lower, order, indices + 1, prob_basis);
00245 ++counts_[order - 1];
00246 }
00247
00248 void Middle(const unsigned char order, const void * ) {
00249 ++counts_[order - 1];
00250 }
00251
00252 void Longest(const void * ) {
00253 ++counts_.back();
00254 }
00255
00256
00257 void Cleanup() {
00258 --counts_[0];
00259 }
00260
00261 const std::vector<uint64_t> &Counts() const {
00262 return counts_;
00263 }
00264
00265 private:
00266 std::vector<uint64_t> counts_;
00267
00268 const ProbBackoff *unigrams_;
00269
00270 SRISucks &sri_;
00271 };
00272
00273
00274 template <class Quant, class Bhiksha> class WriteEntries {
00275 public:
00276 WriteEntries(RecordReader *contexts, const Quant &quant, UnigramValue *unigrams, BitPackedMiddle<Bhiksha> *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) :
00277 contexts_(contexts),
00278 quant_(quant),
00279 unigrams_(unigrams),
00280 middle_(middle),
00281 longest_(longest),
00282 bigram_pack_((order == 2) ? static_cast<BitPacked&>(longest_) : static_cast<BitPacked&>(*middle_)),
00283 order_(order),
00284 sri_(sri) {}
00285
00286 float UnigramProb(WordIndex index) const { return unigrams_[index].weights.prob; }
00287
00288 void Unigram(WordIndex word) {
00289 unigrams_[word].next = bigram_pack_.InsertIndex();
00290 }
00291
00292 void MiddleBlank(const unsigned char order, const WordIndex *indices, unsigned char , float ) {
00293 ProbBackoff weights = sri_.GetBlank(order_, order, indices);
00294 typename Quant::MiddlePointer(quant_, order - 2, middle_[order - 2].Insert(indices[order - 1])).Write(weights.prob, weights.backoff);
00295 }
00296
00297 void Middle(const unsigned char order, const void *data) {
00298 RecordReader &context = contexts_[order - 1];
00299 const WordIndex *words = reinterpret_cast<const WordIndex*>(data);
00300 ProbBackoff weights = *reinterpret_cast<const ProbBackoff*>(words + order);
00301 if (context && !memcmp(data, context.Data(), sizeof(WordIndex) * order)) {
00302 SetExtension(weights.backoff);
00303 ++context;
00304 }
00305 typename Quant::MiddlePointer(quant_, order - 2, middle_[order - 2].Insert(words[order - 1])).Write(weights.prob, weights.backoff);
00306 }
00307
00308 void Longest(const void *data) {
00309 const WordIndex *words = reinterpret_cast<const WordIndex*>(data);
00310 typename Quant::LongestPointer(quant_, longest_.Insert(words[order_ - 1])).Write(reinterpret_cast<const Prob*>(words + order_)->prob);
00311 }
00312
00313 void Cleanup() {}
00314
00315 private:
00316 RecordReader *contexts_;
00317 const Quant &quant_;
00318 UnigramValue *const unigrams_;
00319 BitPackedMiddle<Bhiksha> *const middle_;
00320 BitPackedLongest &longest_;
00321 BitPacked &bigram_pack_;
00322 const unsigned char order_;
00323 SRISucks &sri_;
00324 };
00325
00326 struct Gram {
00327 Gram(const WordIndex *in_begin, unsigned char order) : begin(in_begin), end(in_begin + order) {}
00328
00329 const WordIndex *begin, *end;
00330
00331
00332 bool operator<(const Gram &other) const {
00333 return std::lexicographical_compare(other.begin, other.end, begin, end);
00334 }
00335 };
00336
// Detects missing middle n-grams ("blanks") during RecursiveInsert's sorted walk.
// been_ holds the word ids of the most recently visited path (been_length_ of
// them) and basis_[k] holds the probability last recorded for order k + 1
// (kBadProb when none is usable).  When an n-gram's context does not match that
// path, each order from the divergence point up to length - 1 has no entry of its
// own, and is reported to `doing` via MiddleBlank.
template <class Doing> class BlankManager {
  public:
    // Only the first KENLM_MAX_ORDER - 1 basis_ slots are initialized here.
    BlankManager(unsigned char total_order, Doing &doing) : total_order_(total_order), been_length_(0), doing_(doing) {
      for (float *i = basis_; i != basis_ + KENLM_MAX_ORDER - 1; ++i) *i = kBadProb;
    }

    // Visits the n-gram to[0..length-1] with probability `prob`.
    // Throws FormatLoadException if the divergence is already at the first word,
    // i.e. the blank would have to be a unigram.
    void Visit(const WordIndex *to, unsigned char length, float prob) {
      basis_[length - 1] = prob;
      // Compare the context (all but the last word) against the previous path.
      unsigned char overlap = std::min<unsigned char>(length - 1, been_length_);
      const WordIndex *cur;
      WordIndex *pre;
      for (cur = to, pre = been_; cur != to + overlap; ++cur, ++pre) {
        if (*pre != *cur) break;
      }
      if (cur == to + length - 1) {
        // The whole context is already on the path: just extend it by the last word.
        *pre = *cur;
        been_length_ = length;
        return;
      }

      // Orders blank .. length - 1 of this n-gram's prefix were never visited.
      unsigned char blank = cur - to + 1;
      UTIL_THROW_IF(blank == 1, FormatLoadException, "Missing a unigram that appears as context.");
      // Nearest lower order that still has a usable probability.
      const float *lower_basis;
      for (lower_basis = basis_ + blank - 2; *lower_basis == kBadProb; --lower_basis) {}
      unsigned char based_on = lower_basis - basis_ + 1;
      for (; cur != to + length - 1; ++blank, ++cur, ++pre) {
        assert(*lower_basis != kBadProb);
        doing_.MiddleBlank(blank, to, based_on, *lower_basis);
        *pre = *cur;
        // The blank's order has no real probability; later searches skip past it.
        basis_[blank - 1] = kBadProb;
      }
      *pre = *cur;
      been_length_ = length;
    }

  private:
    // Stored but not read within this class.
    const unsigned char total_order_;

    WordIndex been_[KENLM_MAX_ORDER];
    unsigned char been_length_;

    float basis_[KENLM_MAX_ORDER];

    Doing &doing_;
};
00383
00384 template <class Doing> void RecursiveInsert(const unsigned char total_order, const WordIndex unigram_count, RecordReader *input, std::ostream *progress_out, const char *message, Doing &doing) {
00385 util::ErsatzProgress progress(unigram_count + 1, progress_out, message);
00386 WordIndex unigram = 0;
00387 std::priority_queue<Gram> grams;
00388 grams.push(Gram(&unigram, 1));
00389 for (unsigned char i = 2; i <= total_order; ++i) {
00390 if (input[i-2]) grams.push(Gram(reinterpret_cast<const WordIndex*>(input[i-2].Data()), i));
00391 }
00392
00393 BlankManager<Doing> blank(total_order, doing);
00394
00395 while (true) {
00396 Gram top = grams.top();
00397 grams.pop();
00398 unsigned char order = top.end - top.begin;
00399 if (order == 1) {
00400 blank.Visit(&unigram, 1, doing.UnigramProb(unigram));
00401 doing.Unigram(unigram);
00402 progress.Set(unigram);
00403 if (++unigram == unigram_count + 1) break;
00404 grams.push(top);
00405 } else {
00406 if (order == total_order) {
00407 blank.Visit(top.begin, order, reinterpret_cast<const Prob*>(top.end)->prob);
00408 doing.Longest(top.begin);
00409 } else {
00410 blank.Visit(top.begin, order, reinterpret_cast<const ProbBackoff*>(top.end)->prob);
00411 doing.Middle(order, top.begin);
00412 }
00413 RecordReader &reader = input[order - 2];
00414 if (++reader) grams.push(top);
00415 }
00416 }
00417 assert(grams.empty());
00418 doing.Cleanup();
00419 }
00420
00421 void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<uint64_t> &fixed) {
00422 if (fixed[0] != initial[0]) UTIL_THROW(util::Exception, "Unigram count should be constant but initial is " << initial[0] << " and recounted is " << fixed[0]);
00423 if (fixed.back() != initial.back()) UTIL_THROW(util::Exception, "Longest count should be constant but it changed from " << initial.back() << " to " << fixed.back());
00424 for (unsigned char i = 0; i < initial.size(); ++i) {
00425 if (fixed[i] < initial[i]) UTIL_THROW(util::Exception, "Counts came out lower than expected. This shouldn't happen");
00426 }
00427 }
00428
00429 template <class Quant> void TrainQuantizer(uint8_t order, uint64_t count, const std::vector<float> &additional, RecordReader &reader, util::ErsatzProgress &progress, Quant &quant) {
00430 std::vector<float> probs(additional), backoffs;
00431 probs.reserve(count + additional.size());
00432 backoffs.reserve(count);
00433 for (reader.Rewind(); reader; ++reader) {
00434 const ProbBackoff &weights = *reinterpret_cast<const ProbBackoff*>(reinterpret_cast<const uint8_t*>(reader.Data()) + sizeof(WordIndex) * order);
00435 probs.push_back(weights.prob);
00436 if (weights.backoff != 0.0) backoffs.push_back(weights.backoff);
00437 ++progress;
00438 }
00439 quant.Train(order, probs, backoffs);
00440 }
00441
00442 template <class Quant> void TrainProbQuantizer(uint8_t order, uint64_t count, RecordReader &reader, util::ErsatzProgress &progress, Quant &quant) {
00443 std::vector<float> probs, backoffs;
00444 probs.reserve(count);
00445 for (reader.Rewind(); reader; ++reader) {
00446 const Prob &weights = *reinterpret_cast<const Prob*>(reinterpret_cast<const uint8_t*>(reader.Data()) + sizeof(WordIndex) * order);
00447 probs.push_back(weights.prob);
00448 ++progress;
00449 }
00450 quant.TrainProb(order, probs);
00451 }
00452
00453 void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &contexts, UnigramValue *unigrams) {
00454
00455 try {
00456 rewind(file);
00457 for (WordIndex i = 0; i < unigram_count; ++i) {
00458 ReadOrThrow(file, &unigrams[i].weights, sizeof(ProbBackoff));
00459 if (contexts && *reinterpret_cast<const WordIndex*>(contexts.Data()) == i) {
00460 SetExtension(unigrams[i].weights.backoff);
00461 ++contexts;
00462 }
00463 }
00464 } catch (util::Exception &e) {
00465 e << " while re-reading unigram probabilities";
00466 throw;
00467 }
00468 }
00469
00470 }
00471
00472 template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) {
00473 RecordReader inputs[KENLM_MAX_ORDER - 1];
00474 RecordReader contexts[KENLM_MAX_ORDER - 1];
00475
00476 for (unsigned char i = 2; i <= counts.size(); ++i) {
00477 inputs[i-2].Init(files.Full(i), i * sizeof(WordIndex) + (i == counts.size() ? sizeof(Prob) : sizeof(ProbBackoff)));
00478 contexts[i-2].Init(files.Context(i), (i-1) * sizeof(WordIndex));
00479 }
00480
00481 SRISucks sri;
00482 std::vector<uint64_t> fixed_counts;
00483 util::scoped_FILE unigram_file;
00484 util::scoped_fd unigram_fd(files.StealUnigram());
00485 {
00486 util::scoped_memory unigrams;
00487 MapRead(util::POPULATE_OR_READ, unigram_fd.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams);
00488 FindBlanks finder(counts.size(), reinterpret_cast<const ProbBackoff*>(unigrams.get()), sri);
00489 RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Identifying n-grams omitted by SRI", finder);
00490 fixed_counts = finder.Counts();
00491 }
00492 unigram_file.reset(util::FDOpenOrThrow(unigram_fd));
00493 for (const RecordReader *i = inputs; i != inputs + counts.size() - 2; ++i) {
00494 if (*i) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs + 2) << "-gram table did not complete reading");
00495 }
00496 SanityCheckCounts(counts, fixed_counts);
00497 counts = fixed_counts;
00498
00499 sri.ObtainBackoffs(counts.size(), unigram_file.get(), inputs);
00500
00501 out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), backing), fixed_counts, config);
00502
00503 for (unsigned char i = 2; i <= counts.size(); ++i) {
00504 inputs[i-2].Rewind();
00505 }
00506 if (Quant::kTrain) {
00507 util::ErsatzProgress progress(std::accumulate(counts.begin() + 1, counts.end(), 0),
00508 config.ProgressMessages(), "Quantizing");
00509 for (unsigned char i = 2; i < counts.size(); ++i) {
00510 TrainQuantizer(i, counts[i-1], sri.Values(i), inputs[i-2], progress, quant);
00511 }
00512 TrainProbQuantizer(counts.size(), counts.back(), inputs[counts.size() - 2], progress, quant);
00513 quant.FinishedLoading(config);
00514 }
00515
00516 UnigramValue *unigrams = out.unigram_.Raw();
00517 PopulateUnigramWeights(unigram_file.get(), counts[0], contexts[0], unigrams);
00518 unigram_file.reset();
00519
00520 for (unsigned char i = 2; i <= counts.size(); ++i) {
00521 inputs[i-2].Rewind();
00522 }
00523
00524 {
00525 WriteEntries<Quant, Bhiksha> writer(contexts, quant, unigrams, out.middle_begin_, out.longest_, counts.size(), sri);
00526 RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Writing trie", writer);
00527 }
00528
00529
00530 for (unsigned char order = 2; order <= counts.size(); ++order) {
00531 const RecordReader &context = contexts[order - 2];
00532 if (context) {
00533 FormatLoadException e;
00534 e << "A " << static_cast<unsigned int>(order) << "-gram has context";
00535 const WordIndex *ctx = reinterpret_cast<const WordIndex*>(context.Data());
00536 for (const WordIndex *i = ctx; i != ctx + order - 1; ++i) {
00537 e << ' ' << *i;
00538 }
00539 e << " so this context must appear in the model as a " << static_cast<unsigned int>(order - 1) << "-gram but it does not";
00540 throw e;
00541 }
00542 }
00543
00544
00545
00546 if (out.middle_begin_ != out.middle_end_) {
00547 for (typename TrieSearch<Quant, Bhiksha>::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) {
00548 i->FinishedLoading((i+1)->InsertIndex(), config);
00549 }
00550 (out.middle_end_ - 1)->FinishedLoading(out.longest_.InsertIndex(), config);
00551 }
00552 }
00553
00554 template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
00555 quant_.SetupMemory(start, counts.size(), config);
00556 start += Quant::Size(counts.size(), config);
00557 unigram_.Init(start);
00558 start += Unigram::Size(counts[0]);
00559 FreeMiddles();
00560 middle_begin_ = static_cast<Middle*>(malloc(sizeof(Middle) * (counts.size() - 2)));
00561 middle_end_ = middle_begin_ + (counts.size() - 2);
00562 std::vector<uint8_t*> middle_starts(counts.size() - 2);
00563 for (unsigned char i = 2; i < counts.size(); ++i) {
00564 middle_starts[i-2] = start;
00565 start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i], config);
00566 }
00567
00568 for (unsigned char i = counts.size() - 1; i >= 2; --i) {
00569 new (middle_begin_ + i - 2) Middle(
00570 middle_starts[i-2],
00571 quant_.MiddleBits(config),
00572 counts[i-1],
00573 counts[0],
00574 counts[i],
00575 (i == counts.size() - 1) ? static_cast<const BitPacked&>(longest_) : static_cast<const BitPacked &>(middle_begin_[i-1]),
00576 config);
00577 }
00578 longest_.Init(start, quant_.LongestBits(config), counts[0]);
00579 return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);
00580 }
00581
00582 template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::LoadedBinary() {
00583 unigram_.LoadedBinary();
00584 for (Middle *i = middle_begin_; i != middle_end_; ++i) {
00585 i->LoadedBinary();
00586 }
00587 longest_.LoadedBinary();
00588 }
00589
00590 template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) {
00591 std::string temporary_prefix;
00592 if (config.temporary_directory_prefix) {
00593 temporary_prefix = config.temporary_directory_prefix;
00594 } else if (config.write_mmap) {
00595 temporary_prefix = config.write_mmap;
00596 } else {
00597 temporary_prefix = file;
00598 }
00599
00600 SortedFiles sorted(config, f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_prefix, vocab);
00601
00602 BuildTrie(sorted, counts, config, *this, quant_, vocab, backing);
00603 }
00604
// Explicit instantiations of TrieSearch for every combination of quantizer
// (DontQuantize / SeparatelyQuantize) and Bhiksha policy (DontBhiksha /
// ArrayBhiksha), since the member templates are defined only in this file.
template class TrieSearch<DontQuantize, DontBhiksha>;
template class TrieSearch<DontQuantize, ArrayBhiksha>;
template class TrieSearch<SeparatelyQuantize, DontBhiksha>;
template class TrieSearch<SeparatelyQuantize, ArrayBhiksha>;
00609
00610 }
00611 }
00612 }