00001 #ifndef LM_SEARCH_TRIE_H
00002 #define LM_SEARCH_TRIE_H
00003
00004 #include "lm/config.hh"
00005 #include "lm/model_type.hh"
00006 #include "lm/return.hh"
00007 #include "lm/trie.hh"
00008 #include "lm/weights.hh"
00009
00010 #include "util/file.hh"
00011 #include "util/file_piece.hh"
00012
00013 #include <vector>
00014 #include <cstdlib>
00015 #include <cassert>
00016
00017 namespace lm {
00018 namespace ngram {
// Forward declarations: this header only refers to these types by reference,
// so their full definitions are not needed here.
class SortedVocabulary;
class BinaryFormat;
00021 namespace trie {
00022
00023 template <class Quant, class Bhiksha> class TrieSearch;
00024 class SortedFiles;
00025 template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing);
00026
00027 template <class Quant, class Bhiksha> class TrieSearch {
00028 public:
00029 typedef NodeRange Node;
00030
00031 typedef ::lm::ngram::trie::UnigramPointer UnigramPointer;
00032 typedef typename Quant::MiddlePointer MiddlePointer;
00033 typedef typename Quant::LongestPointer LongestPointer;
00034
00035 static const bool kDifferentRest = false;
00036
00037 static const ModelType kModelType = static_cast<ModelType>(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd);
00038
00039 static const unsigned int kVersion = 1;
00040
00041 static void UpdateConfigFromBinary(const BinaryFormat &file, const std::vector<uint64_t> &counts, uint64_t offset, Config &config) {
00042 Quant::UpdateConfigFromBinary(file, offset, config);
00043
00044 if (counts.size() > 2)
00045 Bhiksha::UpdateConfigFromBinary(file, offset + Quant::Size(counts.size(), config) + Unigram::Size(counts[0]), config);
00046 }
00047
00048 static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) {
00049 uint64_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]);
00050 for (unsigned char i = 1; i < counts.size() - 1; ++i) {
00051 ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1], config);
00052 }
00053 return ret + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);
00054 }
00055
00056 TrieSearch() : middle_begin_(NULL), middle_end_(NULL) {}
00057
00058 ~TrieSearch() { FreeMiddles(); }
00059
00060 uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config);
00061
00062 void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, BinaryFormat &backing);
00063
00064 unsigned char Order() const {
00065 return middle_end_ - middle_begin_ + 2;
00066 }
00067
00068 ProbBackoff &UnknownUnigram() { return unigram_.Unknown(); }
00069
00070 UnigramPointer LookupUnigram(WordIndex word, Node &next, bool &independent_left, uint64_t &extend_left) const {
00071 extend_left = static_cast<uint64_t>(word);
00072 UnigramPointer ret(unigram_.Find(word, next));
00073 independent_left = (next.begin == next.end);
00074 return ret;
00075 }
00076
00077 MiddlePointer Unpack(uint64_t extend_pointer, unsigned char extend_length, Node &node) const {
00078 return MiddlePointer(quant_, extend_length - 2, middle_begin_[extend_length - 2].ReadEntry(extend_pointer, node));
00079 }
00080
00081 MiddlePointer LookupMiddle(unsigned char order_minus_2, WordIndex word, Node &node, bool &independent_left, uint64_t &extend_left) const {
00082 util::BitAddress address(middle_begin_[order_minus_2].Find(word, node, extend_left));
00083 independent_left = (address.base == NULL) || (node.begin == node.end);
00084 return MiddlePointer(quant_, order_minus_2, address);
00085 }
00086
00087 LongestPointer LookupLongest(WordIndex word, const Node &node) const {
00088 return LongestPointer(quant_, longest_.Find(word, node));
00089 }
00090
00091 bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
00092 assert(begin != end);
00093 bool independent_left;
00094 uint64_t ignored;
00095 LookupUnigram(*begin, node, independent_left, ignored);
00096 for (const WordIndex *i = begin + 1; i < end; ++i) {
00097 if (independent_left || !LookupMiddle(i - begin - 1, *i, node, independent_left, ignored).Found()) return false;
00098 }
00099 return true;
00100 }
00101
00102 private:
00103 friend void BuildTrie<Quant, Bhiksha>(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing);
00104
00105
00106 void FreeMiddles() {
00107 for (const Middle *i = middle_begin_; i != middle_end_; ++i) {
00108 i->~Middle();
00109 }
00110 std::free(middle_begin_);
00111 }
00112
00113 typedef trie::BitPackedMiddle<Bhiksha> Middle;
00114
00115 typedef trie::BitPackedLongest Longest;
00116 Longest longest_;
00117
00118 Middle *middle_begin_, *middle_end_;
00119 Quant quant_;
00120
00121 typedef ::lm::ngram::trie::Unigram Unigram;
00122 Unigram unigram_;
00123 };
00124
00125 }
00126 }
00127 }
00128
00129 #endif // LM_SEARCH_TRIE_H