00001 #include "lm/trie.hh"
00002
00003 #include "lm/bhiksha.hh"
00004 #include "util/bit_packing.hh"
00005 #include "util/exception.hh"
00006 #include "util/sorted_uniform.hh"
00007
00008 #include <cassert>
00009
00010 namespace lm {
00011 namespace ngram {
00012 namespace trie {
00013 namespace {
00014
00015 class KeyAccessor {
00016 public:
00017 KeyAccessor(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits)
00018 : base_(reinterpret_cast<const uint8_t*>(base)), key_mask_(key_mask), key_bits_(key_bits), total_bits_(total_bits) {}
00019
00020 typedef uint64_t Key;
00021
00022 Key operator()(uint64_t index) const {
00023 return util::ReadInt57(base_, index * static_cast<uint64_t>(total_bits_), key_bits_, key_mask_);
00024 }
00025
00026 private:
00027 const uint8_t *const base_;
00028 const WordIndex key_mask_;
00029 const uint8_t key_bits_, total_bits_;
00030 };
00031
00032 bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, const uint64_t max_vocab, const uint64_t key, uint64_t &at_index) {
00033 KeyAccessor accessor(base, key_mask, key_bits, total_bits);
00034 if (!util::BoundedSortedUniformFind<uint64_t, KeyAccessor, util::PivotSelect<sizeof(WordIndex)>::T>(accessor, begin_index - 1, (uint64_t)0, end_index, max_vocab, key, at_index)) return false;
00035 return true;
00036 }
00037 }
00038
00039 uint64_t BitPacked::BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits) {
00040 uint8_t total_bits = util::RequiredBits(max_vocab) + remaining_bits;
00041
00042
00043
00044
00045 return ((1 + entries) * total_bits + 7) / 8 + sizeof(uint64_t);
00046 }
00047
00048 void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits) {
00049 util::BitPackingSanity();
00050 word_bits_ = util::RequiredBits(max_vocab);
00051 word_mask_ = (1ULL << word_bits_) - 1ULL;
00052 if (word_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, word indices more than " << (1ULL << 57) << " are not implemented. Edit util/bit_packing.hh and fix the bit packing functions.");
00053 total_bits_ = word_bits_ + remaining_bits;
00054
00055 base_ = static_cast<uint8_t*>(base);
00056 insert_index_ = 0;
00057 max_vocab_ = max_vocab;
00058 }
00059
00060 template <class Bhiksha> uint64_t BitPackedMiddle<Bhiksha>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) {
00061 return Bhiksha::Size(entries + 1, max_ptr, config) + BaseSize(entries, max_vocab, quant_bits + Bhiksha::InlineBits(entries + 1, max_ptr, config));
00062 }
00063
00064 template <class Bhiksha> BitPackedMiddle<Bhiksha>::BitPackedMiddle(void *base, uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config) :
00065 BitPacked(),
00066 quant_bits_(quant_bits),
00067
00068 bhiksha_(base, entries + 1, max_next, config),
00069 next_source_(&next_source) {
00070 if (entries + 1 >= (1ULL << 57) || (max_next >= (1ULL << 57))) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions.");
00071 BaseInit(reinterpret_cast<uint8_t*>(base) + Bhiksha::Size(entries + 1, max_next, config), max_vocab, quant_bits_ + bhiksha_.InlineBits());
00072 }
00073
00074 template <class Bhiksha> util::BitAddress BitPackedMiddle<Bhiksha>::Insert(WordIndex word) {
00075 assert(word <= word_mask_);
00076 uint64_t at_pointer = insert_index_ * total_bits_;
00077
00078 util::WriteInt57(base_, at_pointer, word_bits_, word);
00079 at_pointer += word_bits_;
00080 util::BitAddress ret(base_, at_pointer);
00081 at_pointer += quant_bits_;
00082 uint64_t next = next_source_->InsertIndex();
00083 bhiksha_.WriteNext(base_, at_pointer, insert_index_, next);
00084 ++insert_index_;
00085 return ret;
00086 }
00087
00088 template <class Bhiksha> util::BitAddress BitPackedMiddle<Bhiksha>::Find(WordIndex word, NodeRange &range, uint64_t &pointer) const {
00089 uint64_t at_pointer;
00090 if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) {
00091 return util::BitAddress(NULL, 0);
00092 }
00093 pointer = at_pointer;
00094 at_pointer *= total_bits_;
00095 at_pointer += word_bits_;
00096 bhiksha_.ReadNext(base_, at_pointer + quant_bits_, pointer, total_bits_, range);
00097
00098 return util::BitAddress(base_, at_pointer);
00099 }
00100
00101 template <class Bhiksha> void BitPackedMiddle<Bhiksha>::FinishedLoading(uint64_t next_end, const Config &config) {
00102
00103 uint64_t last_next_write = insert_index_ * total_bits_ +
00104
00105 (total_bits_ - bhiksha_.InlineBits());
00106 bhiksha_.WriteNext(base_, last_next_write, insert_index_, next_end);
00107 bhiksha_.FinishedLoading(config);
00108 }
00109
00110 util::BitAddress BitPackedLongest::Insert(WordIndex index) {
00111 assert(index <= word_mask_);
00112 uint64_t at_pointer = insert_index_ * total_bits_;
00113 util::WriteInt57(base_, at_pointer, word_bits_, index);
00114 at_pointer += word_bits_;
00115 ++insert_index_;
00116 return util::BitAddress(base_, at_pointer);
00117 }
00118
00119 util::BitAddress BitPackedLongest::Find(WordIndex word, const NodeRange &range) const {
00120 uint64_t at_pointer;
00121 if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return util::BitAddress(NULL, 0);
00122 at_pointer = at_pointer * total_bits_ + word_bits_;
00123 return util::BitAddress(base_, at_pointer);
00124 }
00125
00126 template class BitPackedMiddle<DontBhiksha>;
00127 template class BitPackedMiddle<ArrayBhiksha>;
00128
00129 }
00130 }
00131 }