00001 #include "lm/binary_format.hh"
00002
00003 #include "lm/lm_exception.hh"
00004 #include "util/file.hh"
00005 #include "util/file_piece.hh"
00006
00007 #include <cstddef>
00008 #include <cstring>
00009 #include <limits>
00010 #include <string>
00011 #include <cstdlib>
00012
00013 #include <stdint.h>
00014
00015 namespace lm {
00016 namespace ngram {
00017
// Human-readable names for each model type, used in error messages.
// MatchCheck indexes this array with a ModelType value.
const char *kModelNames[6] = {
  "probing hash tables",
  "probing hash tables with rest costs",
  "trie",
  "trie with quantization",
  "trie with array-compressed pointers",
  "trie with quantization and array-compressed pointers"
};
00019
00020 namespace {
// Every binary file begins with this prefix, followed by the format version number.
const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version";
// Full magic of a finished file in the current format.  The explicit "\0" is part of
// the on-disk magic, so sizeof() counts it in addition to the implicit terminator.
const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 5\n\0";
// The version number spelled out inside kMagicBytes; keep the two in sync.
const long int kMagicVersion = 5;

// Placed at the start of a file while it is still being built; FinishFile later
// overwrites it with the real header.
const char kMagicIncomplete[] = "mmap lm http://kheafield.com/code incomplete\n";
00026
00027
00028
00029 struct OldSanity {
00030 char magic[sizeof(kMagicBytes)];
00031 float zero_f, one_f, minus_half_f;
00032 WordIndex one_word_index, max_word_index;
00033 uint64_t one_uint64;
00034
00035 void SetToReference() {
00036 std::memset(this, 0, sizeof(OldSanity));
00037 std::memcpy(magic, kMagicBytes, sizeof(magic));
00038 zero_f = 0.0; one_f = 1.0; minus_half_f = -0.5;
00039 one_word_index = 1;
00040 max_word_index = std::numeric_limits<WordIndex>::max();
00041 one_uint64 = 1;
00042 }
00043 };
00044
00045
00046
00047 struct Sanity {
00048 char magic[ALIGN8(sizeof(kMagicBytes))];
00049 float zero_f, one_f, minus_half_f;
00050 WordIndex one_word_index, max_word_index, padding_to_8;
00051 uint64_t one_uint64;
00052
00053 void SetToReference() {
00054 std::memset(this, 0, sizeof(Sanity));
00055 std::memcpy(magic, kMagicBytes, sizeof(kMagicBytes));
00056 zero_f = 0.0; one_f = 1.0; minus_half_f = -0.5;
00057 one_word_index = 1;
00058 max_word_index = std::numeric_limits<WordIndex>::max();
00059 padding_to_8 = 0;
00060 one_uint64 = 1;
00061 }
00062 };
00063
00064 std::size_t TotalHeaderSize(unsigned char order) {
00065 return ALIGN8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order);
00066 }
00067
00068 void WriteHeader(void *to, const Parameters ¶ms) {
00069 Sanity header = Sanity();
00070 header.SetToReference();
00071 std::memcpy(to, &header, sizeof(Sanity));
00072 char *out = reinterpret_cast<char*>(to) + sizeof(Sanity);
00073
00074 *reinterpret_cast<FixedWidthParameters*>(out) = params.fixed;
00075 out += sizeof(FixedWidthParameters);
00076
00077 uint64_t *counts = reinterpret_cast<uint64_t*>(out);
00078 for (std::size_t i = 0; i < params.counts.size(); ++i) {
00079 counts[i] = params.counts[i];
00080 }
00081 }
00082
00083 }
00084
00085 bool IsBinaryFormat(int fd) {
00086 const uint64_t size = util::SizeFile(fd);
00087 if (size == util::kBadSize || (size <= static_cast<uint64_t>(sizeof(Sanity)))) return false;
00088
00089 util::scoped_memory memory;
00090 try {
00091 util::MapRead(util::LAZY, fd, 0, sizeof(Sanity), memory);
00092 } catch (const util::Exception &e) {
00093 return false;
00094 }
00095 Sanity reference_header = Sanity();
00096 reference_header.SetToReference();
00097 if (!std::memcmp(memory.get(), &reference_header, sizeof(Sanity))) return true;
00098 if (!std::memcmp(memory.get(), kMagicIncomplete, strlen(kMagicIncomplete))) {
00099 UTIL_THROW(FormatLoadException, "This binary file did not finish building");
00100 }
00101 if (!std::memcmp(memory.get(), kMagicBeforeVersion, strlen(kMagicBeforeVersion))) {
00102 char *end_ptr;
00103 const char *begin_version = static_cast<const char*>(memory.get()) + strlen(kMagicBeforeVersion);
00104 long int version = std::strtol(begin_version, &end_ptr, 10);
00105 if ((end_ptr != begin_version) && version != kMagicVersion) {
00106 UTIL_THROW(FormatLoadException, "Binary file has version " << version << " but this implementation expects version " << kMagicVersion << " so you'll have to use the ARPA to rebuild your binary");
00107 }
00108
00109 OldSanity old_sanity = OldSanity();
00110 old_sanity.SetToReference();
00111 UTIL_THROW_IF(!std::memcmp(memory.get(), &old_sanity, sizeof(OldSanity)), FormatLoadException, "Looks like this is an old 32-bit format. The old 32-bit format has been removed so that 64-bit and 32-bit files are exchangeable.");
00112 UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Try rebuilding the binary format LM using the same code revision, compiler, and architecture");
00113 }
00114 return false;
00115 }
00116
00117 void ReadHeader(int fd, Parameters &out) {
00118 util::SeekOrThrow(fd, sizeof(Sanity));
00119 util::ReadOrThrow(fd, &out.fixed, sizeof(out.fixed));
00120 if (out.fixed.probing_multiplier < 1.0)
00121 UTIL_THROW(FormatLoadException, "Binary format claims to have a probing multiplier of " << out.fixed.probing_multiplier << " which is < 1.0.");
00122
00123 out.counts.resize(static_cast<std::size_t>(out.fixed.order));
00124 if (out.fixed.order) util::ReadOrThrow(fd, &*out.counts.begin(), sizeof(uint64_t) * out.fixed.order);
00125 }
00126
00127 void MatchCheck(ModelType model_type, unsigned int search_version, const Parameters ¶ms) {
00128 if (params.fixed.model_type != model_type) {
00129 if (static_cast<unsigned int>(params.fixed.model_type) >= (sizeof(kModelNames) / sizeof(const char *)))
00130 UTIL_THROW(FormatLoadException, "The binary file claims to be model type " << static_cast<unsigned int>(params.fixed.model_type) << " but this is not implemented for in this inference code.");
00131 UTIL_THROW(FormatLoadException, "The binary file was built for " << kModelNames[params.fixed.model_type] << " but the inference code is trying to load " << kModelNames[model_type]);
00132 }
00133 UTIL_THROW_IF(search_version != params.fixed.search_version, FormatLoadException, "The binary file has " << kModelNames[params.fixed.model_type] << " version " << params.fixed.search_version << " but this code expects " << kModelNames[params.fixed.model_type] << " version " << search_version);
00134 }
00135
// Sentinel stored in header_size_ / vocab_size_ until they have been computed.
const std::size_t kInvalidSize = std::numeric_limits<std::size_t>::max();
00137
00138 BinaryFormat::BinaryFormat(const Config &config)
00139 : write_method_(config.write_method), write_mmap_(config.write_mmap), load_method_(config.load_method),
00140 header_size_(kInvalidSize), vocab_size_(kInvalidSize), vocab_string_offset_(kInvalidOffset) {}
00141
00142 void BinaryFormat::InitializeBinary(int fd, ModelType model_type, unsigned int search_version, Parameters ¶ms) {
00143 file_.reset(fd);
00144 write_mmap_ = NULL;
00145 ReadHeader(fd, params);
00146 MatchCheck(model_type, search_version, params);
00147 header_size_ = TotalHeaderSize(params.counts.size());
00148 }
00149
00150 void BinaryFormat::ReadForConfig(void *to, std::size_t amount, uint64_t offset_excluding_header) const {
00151 assert(header_size_ != kInvalidSize);
00152 util::ErsatzPRead(file_.get(), to, amount, offset_excluding_header + header_size_);
00153 }
00154
00155 void *BinaryFormat::LoadBinary(std::size_t size) {
00156 assert(header_size_ != kInvalidSize);
00157 const uint64_t file_size = util::SizeFile(file_.get());
00158
00159 uint64_t total_map = static_cast<uint64_t>(header_size_) + static_cast<uint64_t>(size);
00160 UTIL_THROW_IF(file_size != util::kBadSize && file_size < total_map, FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map);
00161
00162 util::MapRead(load_method_, file_.get(), 0, util::CheckOverflow(total_map), mapping_);
00163
00164 vocab_string_offset_ = total_map;
00165 return reinterpret_cast<uint8_t*>(mapping_.get()) + header_size_;
00166 }
00167
00168 void *BinaryFormat::SetupJustVocab(std::size_t memory_size, uint8_t order) {
00169 vocab_size_ = memory_size;
00170 if (!write_mmap_) {
00171 header_size_ = 0;
00172 util::HugeMalloc(memory_size, true, memory_vocab_);
00173 return reinterpret_cast<uint8_t*>(memory_vocab_.get());
00174 }
00175 header_size_ = TotalHeaderSize(order);
00176 std::size_t total = util::CheckOverflow(static_cast<uint64_t>(header_size_) + static_cast<uint64_t>(memory_size));
00177 file_.reset(util::CreateOrThrow(write_mmap_));
00178
00179 void *vocab_base = NULL;
00180 switch (write_method_) {
00181 case Config::WRITE_MMAP:
00182 mapping_.reset(util::MapZeroedWrite(file_.get(), total), total, util::scoped_memory::MMAP_ALLOCATED);
00183 util::AdviseHugePages(vocab_base, total);
00184 vocab_base = mapping_.get();
00185 break;
00186 case Config::WRITE_AFTER:
00187 util::ResizeOrThrow(file_.get(), 0);
00188 util::HugeMalloc(total, true, memory_vocab_);
00189 vocab_base = memory_vocab_.get();
00190 break;
00191 }
00192 strncpy(reinterpret_cast<char*>(vocab_base), kMagicIncomplete, header_size_);
00193 return reinterpret_cast<uint8_t*>(vocab_base) + header_size_;
00194 }
00195
00196 void *BinaryFormat::GrowForSearch(std::size_t memory_size, std::size_t vocab_pad, void *&vocab_base) {
00197 assert(vocab_size_ != kInvalidSize);
00198 vocab_pad_ = vocab_pad;
00199 std::size_t new_size = header_size_ + vocab_size_ + vocab_pad_ + memory_size;
00200 vocab_string_offset_ = new_size;
00201 if (!write_mmap_ || write_method_ == Config::WRITE_AFTER) {
00202 util::HugeMalloc(memory_size, true, memory_search_);
00203 assert(header_size_ == 0 || write_mmap_);
00204 vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get()) + header_size_;
00205 util::AdviseHugePages(memory_search_.get(), memory_size);
00206 return reinterpret_cast<uint8_t*>(memory_search_.get());
00207 }
00208
00209 assert(write_method_ == Config::WRITE_MMAP);
00210
00211
00212
00213
00214
00215 mapping_.reset();
00216 util::ResizeOrThrow(file_.get(), new_size);
00217 void *ret;
00218 MapFile(vocab_base, ret);
00219 util::AdviseHugePages(ret, new_size);
00220 return ret;
00221 }
00222
00223 void BinaryFormat::WriteVocabWords(const std::string &buffer, void *&vocab_base, void *&search_base) {
00224
00225 assert(header_size_ != kInvalidSize && vocab_size_ != kInvalidSize);
00226 if (!write_mmap_) {
00227
00228 vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get());
00229 search_base = reinterpret_cast<uint8_t*>(memory_search_.get());
00230 return;
00231 }
00232 if (write_method_ == Config::WRITE_MMAP) {
00233 mapping_.reset();
00234 }
00235 util::SeekOrThrow(file_.get(), VocabStringReadingOffset());
00236 util::WriteOrThrow(file_.get(), &buffer[0], buffer.size());
00237 if (write_method_ == Config::WRITE_MMAP) {
00238 MapFile(vocab_base, search_base);
00239 } else {
00240 vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get()) + header_size_;
00241 search_base = reinterpret_cast<uint8_t*>(memory_search_.get());
00242 }
00243 }
00244
00245 void BinaryFormat::FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts) {
00246 if (!write_mmap_) return;
00247 switch (write_method_) {
00248 case Config::WRITE_MMAP:
00249 util::SyncOrThrow(mapping_.get(), mapping_.size());
00250 break;
00251 case Config::WRITE_AFTER:
00252 util::SeekOrThrow(file_.get(), 0);
00253 util::WriteOrThrow(file_.get(), memory_vocab_.get(), memory_vocab_.size());
00254 util::SeekOrThrow(file_.get(), header_size_ + vocab_size_ + vocab_pad_);
00255 util::WriteOrThrow(file_.get(), memory_search_.get(), memory_search_.size());
00256 util::FSyncOrThrow(file_.get());
00257 break;
00258 }
00259
00260 Parameters params = Parameters();
00261 memset(¶ms, 0, sizeof(Parameters));
00262 params.counts = counts;
00263 params.fixed.order = counts.size();
00264 params.fixed.probing_multiplier = config.probing_multiplier;
00265 params.fixed.model_type = model_type;
00266 params.fixed.has_vocabulary = config.include_vocab;
00267 params.fixed.search_version = search_version;
00268 switch (write_method_) {
00269 case Config::WRITE_MMAP:
00270 WriteHeader(mapping_.get(), params);
00271 util::SyncOrThrow(mapping_.get(), mapping_.size());
00272 break;
00273 case Config::WRITE_AFTER:
00274 {
00275 std::vector<uint8_t> buffer(TotalHeaderSize(counts.size()));
00276 WriteHeader(&buffer[0], params);
00277 util::SeekOrThrow(file_.get(), 0);
00278 util::WriteOrThrow(file_.get(), &buffer[0], buffer.size());
00279 }
00280 break;
00281 }
00282 }
00283
00284 void BinaryFormat::MapFile(void *&vocab_base, void *&search_base) {
00285 mapping_.reset(util::MapOrThrow(vocab_string_offset_, true, util::kFileFlags, false, file_.get()), vocab_string_offset_, util::scoped_memory::MMAP_ALLOCATED);
00286 vocab_base = reinterpret_cast<uint8_t*>(mapping_.get()) + header_size_;
00287 search_base = reinterpret_cast<uint8_t*>(mapping_.get()) + header_size_ + vocab_size_ + vocab_pad_;
00288 }
00289
00290 bool RecognizeBinary(const char *file, ModelType &recognized) {
00291 util::scoped_fd fd(util::OpenReadOrThrow(file));
00292 if (!IsBinaryFormat(fd.get())) {
00293 return false;
00294 }
00295 Parameters params;
00296 ReadHeader(fd.get(), params);
00297 recognized = params.fixed.model_type;
00298 return true;
00299 }
00300
00301 }
00302 }