00001 #ifndef LM_VOCAB__
00002 #define LM_VOCAB__
00003
00004 #include "lm/enumerate_vocab.hh"
00005 #include "lm/lm_exception.hh"
00006 #include "lm/virtual_interface.hh"
00007 #include "util/pool.hh"
00008 #include "util/probing_hash_table.hh"
00009 #include "util/sorted_uniform.hh"
00010 #include "util/string_piece.hh"
00011
00012 #include <limits>
00013 #include <string>
00014 #include <vector>
00015
00016 namespace lm {
00017 struct ProbBackoff;
00018 class EnumerateVocab;
00019
00020 namespace ngram {
00021 struct Config;
00022
00023 namespace detail {
00024 uint64_t HashForVocab(const char *str, std::size_t len);
00025 inline uint64_t HashForVocab(const StringPiece &str) {
00026 return HashForVocab(str.data(), str.length());
00027 }
00028 struct ProbingVocabularyHeader;
00029 }
00030
00031 class WriteWordsWrapper : public EnumerateVocab {
00032 public:
00033 WriteWordsWrapper(EnumerateVocab *inner);
00034
00035 ~WriteWordsWrapper();
00036
00037 void Add(WordIndex index, const StringPiece &str);
00038
00039 void Write(int fd, uint64_t start);
00040
00041 private:
00042 EnumerateVocab *inner_;
00043
00044 std::string buffer_;
00045 };
00046
00047
00048 class SortedVocabulary : public base::Vocabulary {
00049 public:
00050 SortedVocabulary();
00051
00052 WordIndex Index(const StringPiece &str) const {
00053 const uint64_t *found;
00054 if (util::BoundedSortedUniformFind<const uint64_t*, util::IdentityAccessor<uint64_t>, util::Pivot64>(
00055 util::IdentityAccessor<uint64_t>(),
00056 begin_ - 1, 0,
00057 end_, std::numeric_limits<uint64_t>::max(),
00058 detail::HashForVocab(str), found)) {
00059 return found - begin_ + 1;
00060 } else {
00061 return 0;
00062 }
00063 }
00064
00065
00066 static uint64_t Size(uint64_t entries, const Config &config);
00067
00068
00069 WordIndex Bound() const { return bound_; }
00070
00071
00072 void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config);
00073
00074 void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries);
00075
00076 WordIndex Insert(const StringPiece &str);
00077
00078
00079 void FinishedLoading(ProbBackoff *reorder_vocab);
00080
00081
00082 std::size_t UnkCountChangePadding() const { return SawUnk() ? 0 : sizeof(uint64_t); }
00083
00084 bool SawUnk() const { return saw_unk_; }
00085
00086 void LoadedBinary(bool have_words, int fd, EnumerateVocab *to);
00087
00088 private:
00089 uint64_t *begin_, *end_;
00090
00091 WordIndex bound_;
00092
00093 WordIndex highest_value_;
00094
00095 bool saw_unk_;
00096
00097 EnumerateVocab *enumerate_;
00098
00099
00100 util::Pool string_backing_;
00101
00102 std::vector<StringPiece> strings_to_enumerate_;
00103 };
00104
00105 #pragma pack(push)
00106 #pragma pack(4)
00107 struct ProbingVocabuaryEntry {
00108 uint64_t key;
00109 WordIndex value;
00110
00111 typedef uint64_t Key;
00112 uint64_t GetKey() const {
00113 return key;
00114 }
00115
00116 static ProbingVocabuaryEntry Make(uint64_t key, WordIndex value) {
00117 ProbingVocabuaryEntry ret;
00118 ret.key = key;
00119 ret.value = value;
00120 return ret;
00121 }
00122 };
00123 #pragma pack(pop)
00124
00125
00126 class ProbingVocabulary : public base::Vocabulary {
00127 public:
00128 ProbingVocabulary();
00129
00130 WordIndex Index(const StringPiece &str) const {
00131 Lookup::ConstIterator i;
00132 return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0;
00133 }
00134
00135 static uint64_t Size(uint64_t entries, const Config &config);
00136
00137
00138 WordIndex Bound() const { return bound_; }
00139
00140
00141 void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config);
00142
00143 void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries);
00144
00145 WordIndex Insert(const StringPiece &str);
00146
00147 template <class Weights> void FinishedLoading(Weights * ) {
00148 InternalFinishedLoading();
00149 }
00150
00151 std::size_t UnkCountChangePadding() const { return 0; }
00152
00153 bool SawUnk() const { return saw_unk_; }
00154
00155 void LoadedBinary(bool have_words, int fd, EnumerateVocab *to);
00156
00157 private:
00158 void InternalFinishedLoading();
00159
00160 typedef util::ProbingHashTable<ProbingVocabuaryEntry, util::IdentityHash> Lookup;
00161
00162 Lookup lookup_;
00163
00164 WordIndex bound_;
00165
00166 bool saw_unk_;
00167
00168 EnumerateVocab *enumerate_;
00169
00170 detail::ProbingVocabularyHeader *header_;
00171 };
00172
00173 void MissingUnknown(const Config &config) throw(SpecialWordMissingException);
00174 void MissingSentenceMarker(const Config &config, const char *str) throw(SpecialWordMissingException);
00175
00176 template <class Vocab> void CheckSpecials(const Config &config, const Vocab &vocab) throw(SpecialWordMissingException) {
00177 if (!vocab.SawUnk()) MissingUnknown(config);
00178 if (vocab.BeginSentence() == vocab.NotFound()) MissingSentenceMarker(config, "<s>");
00179 if (vocab.EndSentence() == vocab.NotFound()) MissingSentenceMarker(config, "</s>");
00180 }
00181
00182 }
00183 }
00184
00185 #endif // LM_VOCAB__