00001 #ifndef LM_SEARCH_HASHED__
00002 #define LM_SEARCH_HASHED__
00003
00004 #include "lm/model_type.hh"
00005 #include "lm/config.hh"
00006 #include "lm/read_arpa.hh"
00007 #include "lm/return.hh"
00008 #include "lm/weights.hh"
00009
00010 #include "util/bit_packing.hh"
00011 #include "util/probing_hash_table.hh"
00012
00013 #include <algorithm>
00014 #include <iostream>
00015 #include <vector>
00016
00017 namespace util { class FilePiece; }
00018
00019 namespace lm {
00020 namespace ngram {
00021 struct Backing;
00022 class ProbingVocabulary;
00023 namespace detail {
00024
00025 inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) {
00026 uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast<uint64_t>(1 + next) * 17894857484156487943ULL);
00027 return ret;
00028 }
00029
00030 #pragma pack(push)
00031 #pragma pack(4)
00032 struct ProbEntry {
00033 uint64_t key;
00034 Prob value;
00035 typedef uint64_t Key;
00036 typedef Prob Value;
00037 uint64_t GetKey() const {
00038 return key;
00039 }
00040 };
00041
00042 #pragma pack(pop)
00043
// Non-owning handle to the probability of a highest-order n-gram lookup.
// A default-constructed handle represents a miss.  The handle stores only the
// address of the float it was built from, so that float must outlive it.
class LongestPointer {
  public:
    // Miss: nothing to point at.
    LongestPointer() : to_(NULL) {}

    // Hit: remember where the probability lives (no copy is made).
    explicit LongestPointer(const float &to) : to_(&to) {}

    // True when this handle refers to a probability.
    bool Found() const { return NULL != to_; }

    // Reads the referenced probability.  Only valid when Found() is true;
    // there is no null check here.
    float Prob() const { return *to_; }

  private:
    const float *to_;
};
00061
00062 template <class Value> class HashedSearch {
00063 public:
00064 typedef uint64_t Node;
00065
00066 typedef typename Value::ProbingProxy UnigramPointer;
00067 typedef typename Value::ProbingProxy MiddlePointer;
00068 typedef ::lm::ngram::detail::LongestPointer LongestPointer;
00069
00070 static const ModelType kModelType = Value::kProbingModelType;
00071 static const bool kDifferentRest = Value::kDifferentRest;
00072 static const unsigned int kVersion = 0;
00073
00074
00075 static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {}
00076
00077 static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) {
00078 uint64_t ret = Unigram::Size(counts[0]);
00079 for (unsigned char n = 1; n < counts.size() - 1; ++n) {
00080 ret += Middle::Size(counts[n], config.probing_multiplier);
00081 }
00082 return ret + Longest::Size(counts.back(), config.probing_multiplier);
00083 }
00084
00085 uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config);
00086
00087 void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, Backing &backing);
00088
00089 void LoadedBinary();
00090
00091 unsigned char Order() const {
00092 return middle_.size() + 2;
00093 }
00094
00095 typename Value::Weights &UnknownUnigram() { return unigram_.Unknown(); }
00096
00097 UnigramPointer LookupUnigram(WordIndex word, Node &next, bool &independent_left, uint64_t &extend_left) const {
00098 extend_left = static_cast<uint64_t>(word);
00099 next = extend_left;
00100 UnigramPointer ret(unigram_.Lookup(word));
00101 independent_left = ret.IndependentLeft();
00102 return ret;
00103 }
00104
00105 #pragma GCC diagnostic ignored "-Wuninitialized"
00106 MiddlePointer Unpack(uint64_t extend_pointer, unsigned char extend_length, Node &node) const {
00107 node = extend_pointer;
00108 typename Middle::ConstIterator found;
00109 bool got = middle_[extend_length - 2].Find(extend_pointer, found);
00110 assert(got);
00111 (void)got;
00112 return MiddlePointer(found->value);
00113 }
00114
00115 MiddlePointer LookupMiddle(unsigned char order_minus_2, WordIndex word, Node &node, bool &independent_left, uint64_t &extend_pointer) const {
00116 node = CombineWordHash(node, word);
00117 typename Middle::ConstIterator found;
00118 if (!middle_[order_minus_2].Find(node, found)) {
00119 independent_left = true;
00120 return MiddlePointer();
00121 }
00122 extend_pointer = node;
00123 MiddlePointer ret(found->value);
00124 independent_left = ret.IndependentLeft();
00125 return ret;
00126 }
00127
00128 LongestPointer LookupLongest(WordIndex word, const Node &node) const {
00129
00130 typename Longest::ConstIterator found;
00131 if (!longest_.Find(CombineWordHash(node, word), found)) return LongestPointer();
00132 return LongestPointer(found->value.prob);
00133 }
00134
00135
00136
00137 bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
00138 assert(begin != end);
00139 node = static_cast<Node>(*begin);
00140 for (const WordIndex *i = begin + 1; i < end; ++i) {
00141 node = CombineWordHash(node, *i);
00142 }
00143 return true;
00144 }
00145
00146 private:
00147
00148 void DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn);
00149
00150 template <class Build> void ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build);
00151
00152 class Unigram {
00153 public:
00154 Unigram() {}
00155
00156 Unigram(void *start, uint64_t count, std::size_t ) :
00157 unigram_(static_cast<typename Value::Weights*>(start))
00158 #ifdef DEBUG
00159 , count_(count)
00160 #endif
00161 {}
00162
00163 static uint64_t Size(uint64_t count) {
00164 return (count + 1) * sizeof(typename Value::Weights);
00165 }
00166
00167 const typename Value::Weights &Lookup(WordIndex index) const {
00168 #ifdef DEBUG
00169 assert(index < count_);
00170 #endif
00171 return unigram_[index];
00172 }
00173
00174 typename Value::Weights &Unknown() { return unigram_[0]; }
00175
00176 void LoadedBinary() {}
00177
00178
00179 typename Value::Weights *Raw() { return unigram_; }
00180
00181 private:
00182 typename Value::Weights *unigram_;
00183 #ifdef DEBUG
00184 uint64_t count_;
00185 #endif
00186 };
00187
00188 Unigram unigram_;
00189
00190 typedef util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> Middle;
00191 std::vector<Middle> middle_;
00192
00193 typedef util::ProbingHashTable<ProbEntry, util::IdentityHash> Longest;
00194 Longest longest_;
00195 };
00196
00197 }
00198 }
00199 }
00200
00201 #endif // LM_SEARCH_HASHED__