00001 #include "lm/vocab.hh"
00002
00003 #include "lm/binary_format.hh"
00004 #include "lm/enumerate_vocab.hh"
00005 #include "lm/lm_exception.hh"
00006 #include "lm/config.hh"
00007 #include "lm/weights.hh"
00008 #include "util/exception.hh"
00009 #include "util/file_stream.hh"
00010 #include "util/file.hh"
00011 #include "util/joint_sort.hh"
00012 #include "util/murmur_hash.hh"
00013 #include "util/probing_hash_table.hh"
00014
00015 #include <cstring>
00016 #include <string>
00017
00018 namespace lm {
00019 namespace ngram {
00020
00021 namespace detail {
00022 uint64_t HashForVocab(const char *str, std::size_t len) {
00023
00024
00025 return util::MurmurHash64A(str, len, 0);
00026 }
00027 }
00028
00029 namespace {
00030
00031 const uint64_t kUnknownHash = detail::HashForVocab("<unk>", 5);
00032
00033 const uint64_t kUnknownCapHash = detail::HashForVocab("<UNK>", 5);
00034
00035
00036 void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count, uint64_t offset) {
00037 util::SeekOrThrow(fd, offset);
00038
00039 char check_unk[6];
00040 util::ReadOrThrow(fd, check_unk, 6);
00041 UTIL_THROW_IF(
00042 memcmp(check_unk, "<unk>", 6),
00043 FormatLoadException,
00044 "Vocabulary words are in the wrong place. This could be because the binary file was built with stale gcc and old kenlm. Stale gcc, including the gcc distributed with RedHat and OS X, has a bug that ignores pragma pack for template-dependent types. New kenlm works around this, so you'll save memory but have to rebuild any binary files using the probing data structure.");
00045 if (!enumerate) return;
00046 enumerate->Add(0, "<unk>");
00047
00048
00049 const std::size_t kInitialRead = 16384;
00050 std::string buf;
00051 buf.reserve(kInitialRead + 100);
00052 buf.resize(kInitialRead);
00053 WordIndex index = 1;
00054 while (true) {
00055 std::size_t got = util::ReadOrEOF(fd, &buf[0], kInitialRead);
00056 if (got == 0) break;
00057 buf.resize(got);
00058 while (buf[buf.size() - 1]) {
00059 char next_char;
00060 util::ReadOrThrow(fd, &next_char, 1);
00061 buf.push_back(next_char);
00062 }
00063
00064 for (const char *i = buf.data(); i != buf.data() + buf.size();) {
00065 std::size_t length = strlen(i);
00066 enumerate->Add(index++, StringPiece(i, length));
00067 i += length + 1 ;
00068 }
00069 }
00070
00071 UTIL_THROW_IF(expected_count != index, FormatLoadException, "The binary file has the wrong number of words at the end. This could be caused by a truncated binary file.");
00072 }
00073
00074
00075 int SeekAndReturn(int fd, uint64_t start) {
00076 util::SeekOrThrow(fd, start);
00077 return fd;
00078 }
00079 }
00080
00081 ImmediateWriteWordsWrapper::ImmediateWriteWordsWrapper(EnumerateVocab *inner, int fd, uint64_t start)
00082 : inner_(inner), stream_(SeekAndReturn(fd, start)) {}
00083
00084 WriteWordsWrapper::WriteWordsWrapper(EnumerateVocab *inner) : inner_(inner) {}
00085
00086 void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) {
00087 if (inner_) inner_->Add(index, str);
00088 buffer_.append(str.data(), str.size());
00089 buffer_.push_back(0);
00090 }
00091
00092 void WriteWordsWrapper::Write(int fd, uint64_t start) {
00093 util::SeekOrThrow(fd, start);
00094 util::WriteOrThrow(fd, buffer_.data(), buffer_.size());
00095
00096 std::string for_swap;
00097 std::swap(buffer_, for_swap);
00098 }
00099
00100 SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {}
00101
00102 uint64_t SortedVocabulary::Size(uint64_t entries, const Config &) {
00103
00104 return sizeof(uint64_t) + sizeof(uint64_t) * entries;
00105 }
00106
00107 void SortedVocabulary::SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config) {
00108 assert(allocated >= Size(entries, config));
00109
00110 begin_ = reinterpret_cast<uint64_t*>(start) + 1;
00111 end_ = begin_;
00112 saw_unk_ = false;
00113 }
00114
00115 void SortedVocabulary::Relocate(void *new_start) {
00116 std::size_t delta = end_ - begin_;
00117 begin_ = reinterpret_cast<uint64_t*>(new_start) + 1;
00118 end_ = begin_ + delta;
00119 }
00120
00121 void SortedVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries) {
00122 enumerate_ = to;
00123 if (enumerate_) {
00124 enumerate_->Add(0, "<unk>");
00125 strings_to_enumerate_.resize(max_entries);
00126 }
00127 }
00128
00129 WordIndex SortedVocabulary::Insert(const StringPiece &str) {
00130 uint64_t hashed = detail::HashForVocab(str);
00131 if (hashed == kUnknownHash || hashed == kUnknownCapHash) {
00132 saw_unk_ = true;
00133 return 0;
00134 }
00135 *end_ = hashed;
00136 if (enumerate_) {
00137 void *copied = string_backing_.Allocate(str.size());
00138 memcpy(copied, str.data(), str.size());
00139 strings_to_enumerate_[end_ - begin_] = StringPiece(static_cast<const char*>(copied), str.size());
00140 }
00141 ++end_;
00142
00143 return end_ - begin_;
00144 }
00145
00146 void SortedVocabulary::FinishedLoading(ProbBackoff *reorder) {
00147 GenericFinished(reorder);
00148 }
00149
00150 namespace {
00151 #pragma pack(push)
00152 #pragma pack(4)
00153 struct RenumberEntry {
00154 uint64_t hash;
00155 const char *str;
00156 WordIndex old;
00157 bool operator<(const RenumberEntry &other) const {
00158 return hash < other.hash;
00159 }
00160 };
00161 #pragma pack(pop)
00162 }
00163
00164 void SortedVocabulary::ComputeRenumbering(WordIndex types, int from_words, int to_words, std::vector<WordIndex> &mapping) {
00165 mapping.clear();
00166 uint64_t file_size = util::SizeOrThrow(from_words);
00167 util::scoped_memory strings;
00168 util::MapRead(util::POPULATE_OR_READ, from_words, 0, file_size, strings);
00169 const char *const start = static_cast<const char*>(strings.get());
00170 UTIL_THROW_IF(memcmp(start, "<unk>", 6), FormatLoadException, "Vocab file does not begin with <unk> followed by null");
00171 std::vector<RenumberEntry> entries;
00172 entries.reserve(types - 1);
00173 RenumberEntry entry;
00174 entry.old = 1;
00175 for (entry.str = start + 6 ; entry.str < start + file_size; ++entry.old) {
00176 StringPiece str(entry.str, strlen(entry.str));
00177 entry.hash = detail::HashForVocab(str);
00178 entries.push_back(entry);
00179 entry.str += str.size() + 1;
00180 }
00181 UTIL_THROW_IF2(entries.size() != types - 1, "Wrong number of vocab ids. Got " << (entries.size() + 1) << " expected " << types);
00182 std::sort(entries.begin(), entries.end());
00183
00184 {
00185 util::FileStream out(to_words);
00186 out << "<unk>" << '\0';
00187 for (std::vector<RenumberEntry>::const_iterator i = entries.begin(); i != entries.end(); ++i) {
00188 out << i->str << '\0';
00189 }
00190 }
00191 strings.reset();
00192
00193 mapping.resize(types);
00194 mapping[0] = 0;
00195 for (std::vector<RenumberEntry>::const_iterator i = entries.begin(); i != entries.end(); ++i) {
00196 mapping[i->old] = i + 1 - entries.begin();
00197 }
00198 }
00199
00200 void SortedVocabulary::Populated() {
00201 saw_unk_ = true;
00202 SetSpecial(Index("<s>"), Index("</s>"), 0);
00203 bound_ = end_ - begin_ + 1;
00204 *(reinterpret_cast<uint64_t*>(begin_) - 1) = end_ - begin_;
00205 }
00206
00207 void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset) {
00208 end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1);
00209 SetSpecial(Index("<s>"), Index("</s>"), 0);
00210 bound_ = end_ - begin_ + 1;
00211 if (have_words) ReadWords(fd, to, bound_, offset);
00212 }
00213
00214 template <class T> void SortedVocabulary::GenericFinished(T *reorder) {
00215 if (enumerate_) {
00216 if (!strings_to_enumerate_.empty()) {
00217 util::PairedIterator<T*, StringPiece*> values(reorder + 1, &*strings_to_enumerate_.begin());
00218 util::JointSort(begin_, end_, values);
00219 }
00220 for (WordIndex i = 0; i < static_cast<WordIndex>(end_ - begin_); ++i) {
00221
00222 enumerate_->Add(i + 1, strings_to_enumerate_[i]);
00223 }
00224 strings_to_enumerate_.clear();
00225 string_backing_.FreeAll();
00226 } else {
00227 util::JointSort(begin_, end_, reorder + 1);
00228 }
00229 SetSpecial(Index("<s>"), Index("</s>"), 0);
00230
00231 *(reinterpret_cast<uint64_t*>(begin_) - 1) = end_ - begin_;
00232
00233 bound_ = end_ - begin_ + 1;
00234 }
00235
namespace {
// Version number written into, and checked against, the header of a probing
// vocabulary.
const unsigned int kProbingVocabularyVersion = 0;
}  // namespace
00239
00240 namespace detail {
00241 struct ProbingVocabularyHeader {
00242
00243 unsigned int version;
00244 WordIndex bound;
00245 };
00246 }
00247
00248 ProbingVocabulary::ProbingVocabulary() : enumerate_(NULL) {}
00249
00250 uint64_t ProbingVocabulary::Size(uint64_t entries, float probing_multiplier) {
00251 return ALIGN8(sizeof(detail::ProbingVocabularyHeader)) + Lookup::Size(entries, probing_multiplier);
00252 }
00253
00254 uint64_t ProbingVocabulary::Size(uint64_t entries, const Config &config) {
00255 return Size(entries, config.probing_multiplier);
00256 }
00257
00258 void ProbingVocabulary::SetupMemory(void *start, std::size_t allocated) {
00259 header_ = static_cast<detail::ProbingVocabularyHeader*>(start);
00260 lookup_ = Lookup(static_cast<uint8_t*>(start) + ALIGN8(sizeof(detail::ProbingVocabularyHeader)), allocated);
00261 bound_ = 1;
00262 saw_unk_ = false;
00263 }
00264
00265 void ProbingVocabulary::Relocate(void *new_start) {
00266 header_ = static_cast<detail::ProbingVocabularyHeader*>(new_start);
00267 lookup_.Relocate(static_cast<uint8_t*>(new_start) + ALIGN8(sizeof(detail::ProbingVocabularyHeader)));
00268 }
00269
00270 void ProbingVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t ) {
00271 enumerate_ = to;
00272 if (enumerate_) {
00273 enumerate_->Add(0, "<unk>");
00274 }
00275 }
00276
00277 WordIndex ProbingVocabulary::Insert(const StringPiece &str) {
00278 uint64_t hashed = detail::HashForVocab(str);
00279
00280 if (hashed == kUnknownHash || hashed == kUnknownCapHash) {
00281 saw_unk_ = true;
00282 return 0;
00283 } else {
00284 if (enumerate_) enumerate_->Add(bound_, str);
00285 lookup_.Insert(ProbingVocabularyEntry::Make(hashed, bound_));
00286 return bound_++;
00287 }
00288 }
00289
00290 void ProbingVocabulary::InternalFinishedLoading() {
00291 lookup_.FinishedInserting();
00292 header_->bound = bound_;
00293 header_->version = kProbingVocabularyVersion;
00294 SetSpecial(Index("<s>"), Index("</s>"), 0);
00295 }
00296
00297 void ProbingVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset) {
00298 UTIL_THROW_IF(header_->version != kProbingVocabularyVersion, FormatLoadException, "The binary file has probing version " << header_->version << " but the code expects version " << kProbingVocabularyVersion << ". Please rerun build_binary using the same version of the code.");
00299 bound_ = header_->bound;
00300 SetSpecial(Index("<s>"), Index("</s>"), 0);
00301 if (have_words) ReadWords(fd, to, bound_, offset);
00302 }
00303
00304 void MissingUnknown(const Config &config) throw(SpecialWordMissingException) {
00305 switch(config.unknown_missing) {
00306 case SILENT:
00307 return;
00308 case COMPLAIN:
00309 if (config.messages) *config.messages << "The ARPA file is missing <unk>. Substituting log10 probability " << config.unknown_missing_logprob << "." << std::endl;
00310 break;
00311 case THROW_UP:
00312 UTIL_THROW(SpecialWordMissingException, "The ARPA file is missing <unk> and the model is configured to throw an exception.");
00313 }
00314 }
00315
00316 void MissingSentenceMarker(const Config &config, const char *str) throw(SpecialWordMissingException) {
00317 switch (config.sentence_marker_missing) {
00318 case SILENT:
00319 return;
00320 case COMPLAIN:
00321 if (config.messages) *config.messages << "Missing special word " << str << "; will treat it as <unk>.";
00322 break;
00323 case THROW_UP:
00324 UTIL_THROW(SpecialWordMissingException, "The ARPA file is missing " << str << " and the model is configured to reject these models. Run build_binary -s to disable this check.");
00325 }
00326 }
00327
00328 }
00329 }