00001 #include "lm/binary_format.hh"
00002
00003 #include "lm/lm_exception.hh"
00004 #include "util/file.hh"
00005 #include "util/file_piece.hh"
00006
00007 #include <cstddef>
00008 #include <cstring>
00009 #include <limits>
00010 #include <string>
00011
00012 #include <stdint.h>
00013
00014 namespace lm {
00015 namespace ngram {
00016 namespace {
// Prefix common to the magic string of every format version.  IsBinaryFormat
// parses the decimal version number that follows this prefix in the file.
const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version";

// Format version this implementation reads and writes.  The same number is
// spelled out inside kMagicBytes, so the two must be changed together.
const long kMagicVersion = 5;

// Magic string at the start of a finished file.  The explicit "\0" is part of
// the literal, so sizeof(kMagicBytes) counts two trailing NUL bytes.
const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 5\n\0";

// Placeholder written at the start of a file while it is being built
// (see SetupJustVocab); IsBinaryFormat rejects files that still begin with it.
const char kMagicIncomplete[] = "mmap lm http://kheafield.com/code incomplete\n";
00022
00023
00024
00025 struct OldSanity {
00026 char magic[sizeof(kMagicBytes)];
00027 float zero_f, one_f, minus_half_f;
00028 WordIndex one_word_index, max_word_index;
00029 uint64_t one_uint64;
00030
00031 void SetToReference() {
00032 std::memset(this, 0, sizeof(OldSanity));
00033 std::memcpy(magic, kMagicBytes, sizeof(magic));
00034 zero_f = 0.0; one_f = 1.0; minus_half_f = -0.5;
00035 one_word_index = 1;
00036 max_word_index = std::numeric_limits<WordIndex>::max();
00037 one_uint64 = 1;
00038 }
00039 };
00040
00041
00042
00043 struct Sanity {
00044 char magic[ALIGN8(sizeof(kMagicBytes))];
00045 float zero_f, one_f, minus_half_f;
00046 WordIndex one_word_index, max_word_index, padding_to_8;
00047 uint64_t one_uint64;
00048
00049 void SetToReference() {
00050 std::memset(this, 0, sizeof(Sanity));
00051 std::memcpy(magic, kMagicBytes, sizeof(kMagicBytes));
00052 zero_f = 0.0; one_f = 1.0; minus_half_f = -0.5;
00053 one_word_index = 1;
00054 max_word_index = std::numeric_limits<WordIndex>::max();
00055 padding_to_8 = 0;
00056 one_uint64 = 1;
00057 }
00058 };
00059
// Human-readable model names used in messages; indexed by ModelType value.
const char *kModelNames[6] = {
  "probing hash tables",
  "probing hash tables with rest costs",
  "trie",
  "trie with quantization",
  "trie with array-compressed pointers",
  "trie with quantization and array-compressed pointers"
};
00061
00062 std::size_t TotalHeaderSize(unsigned char order) {
00063 return ALIGN8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order);
00064 }
00065
00066 void WriteHeader(void *to, const Parameters ¶ms) {
00067 Sanity header = Sanity();
00068 header.SetToReference();
00069 std::memcpy(to, &header, sizeof(Sanity));
00070 char *out = reinterpret_cast<char*>(to) + sizeof(Sanity);
00071
00072 *reinterpret_cast<FixedWidthParameters*>(out) = params.fixed;
00073 out += sizeof(FixedWidthParameters);
00074
00075 uint64_t *counts = reinterpret_cast<uint64_t*>(out);
00076 for (std::size_t i = 0; i < params.counts.size(); ++i) {
00077 counts[i] = params.counts[i];
00078 }
00079 }
00080
00081 }
00082
00083 uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing) {
00084 if (config.write_mmap) {
00085 std::size_t total = TotalHeaderSize(order) + memory_size;
00086 backing.file.reset(util::CreateOrThrow(config.write_mmap));
00087 if (config.write_method == Config::WRITE_MMAP) {
00088 backing.vocab.reset(util::MapZeroedWrite(backing.file.get(), total), total, util::scoped_memory::MMAP_ALLOCATED);
00089 } else {
00090 util::ResizeOrThrow(backing.file.get(), 0);
00091 util::MapAnonymous(total, backing.vocab);
00092 }
00093 strncpy(reinterpret_cast<char*>(backing.vocab.get()), kMagicIncomplete, TotalHeaderSize(order));
00094 return reinterpret_cast<uint8_t*>(backing.vocab.get()) + TotalHeaderSize(order);
00095 } else {
00096 util::MapAnonymous(memory_size, backing.vocab);
00097 return reinterpret_cast<uint8_t*>(backing.vocab.get());
00098 }
00099 }
00100
00101 uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing) {
00102 std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad;
00103 if (config.write_mmap) {
00104
00105 try {
00106 util::ResizeOrThrow(backing.file.get(), adjusted_vocab + memory_size);
00107 } catch (util::ErrnoException &e) {
00108 e << " for file " << config.write_mmap;
00109 throw e;
00110 }
00111
00112 if (config.write_method == Config::WRITE_AFTER) {
00113 util::MapAnonymous(memory_size, backing.search);
00114 return reinterpret_cast<uint8_t*>(backing.search.get());
00115 }
00116
00117
00118 std::size_t page_size = util::SizePage();
00119 std::size_t alignment_cruft = adjusted_vocab % page_size;
00120 backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED);
00121 return reinterpret_cast<uint8_t*>(backing.search.get()) + alignment_cruft;
00122 } else {
00123 util::MapAnonymous(memory_size, backing.search);
00124 return reinterpret_cast<uint8_t*>(backing.search.get());
00125 }
00126 }
00127
00128 void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing) {
00129 if (!config.write_mmap) return;
00130 switch (config.write_method) {
00131 case Config::WRITE_MMAP:
00132 util::SyncOrThrow(backing.vocab.get(), backing.vocab.size());
00133 util::SyncOrThrow(backing.search.get(), backing.search.size());
00134 break;
00135 case Config::WRITE_AFTER:
00136 util::SeekOrThrow(backing.file.get(), 0);
00137 util::WriteOrThrow(backing.file.get(), backing.vocab.get(), backing.vocab.size());
00138 util::SeekOrThrow(backing.file.get(), backing.vocab.size() + vocab_pad);
00139 util::WriteOrThrow(backing.file.get(), backing.search.get(), backing.search.size());
00140 util::FSyncOrThrow(backing.file.get());
00141 break;
00142 }
00143
00144 Parameters params = Parameters();
00145 params.counts = counts;
00146 params.fixed.order = counts.size();
00147 params.fixed.probing_multiplier = config.probing_multiplier;
00148 params.fixed.model_type = model_type;
00149 params.fixed.has_vocabulary = config.include_vocab;
00150 params.fixed.search_version = search_version;
00151 WriteHeader(backing.vocab.get(), params);
00152 if (config.write_method == Config::WRITE_AFTER) {
00153 util::SeekOrThrow(backing.file.get(), 0);
00154 util::WriteOrThrow(backing.file.get(), backing.vocab.get(), TotalHeaderSize(counts.size()));
00155 }
00156 }
00157
00158 namespace detail {
00159
00160 bool IsBinaryFormat(int fd) {
00161 const uint64_t size = util::SizeFile(fd);
00162 if (size == util::kBadSize || (size <= static_cast<uint64_t>(sizeof(Sanity)))) return false;
00163
00164 util::scoped_memory memory;
00165 try {
00166 util::MapRead(util::LAZY, fd, 0, sizeof(Sanity), memory);
00167 } catch (const util::Exception &e) {
00168 return false;
00169 }
00170 Sanity reference_header = Sanity();
00171 reference_header.SetToReference();
00172 if (!memcmp(memory.get(), &reference_header, sizeof(Sanity))) return true;
00173 if (!memcmp(memory.get(), kMagicIncomplete, strlen(kMagicIncomplete))) {
00174 UTIL_THROW(FormatLoadException, "This binary file did not finish building");
00175 }
00176 if (!memcmp(memory.get(), kMagicBeforeVersion, strlen(kMagicBeforeVersion))) {
00177 char *end_ptr;
00178 const char *begin_version = static_cast<const char*>(memory.get()) + strlen(kMagicBeforeVersion);
00179 long int version = strtol(begin_version, &end_ptr, 10);
00180 if ((end_ptr != begin_version) && version != kMagicVersion) {
00181 UTIL_THROW(FormatLoadException, "Binary file has version " << version << " but this implementation expects version " << kMagicVersion << " so you'll have to use the ARPA to rebuild your binary");
00182 }
00183
00184 OldSanity old_sanity = OldSanity();
00185 old_sanity.SetToReference();
00186 UTIL_THROW_IF(!memcmp(memory.get(), &old_sanity, sizeof(OldSanity)), FormatLoadException, "Looks like this is an old 32-bit format. The old 32-bit format has been removed so that 64-bit and 32-bit files are exchangeable.");
00187 UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Try rebuilding the binary format LM using the same code revision, compiler, and architecture");
00188 }
00189 return false;
00190 }
00191
00192 void ReadHeader(int fd, Parameters &out) {
00193 util::SeekOrThrow(fd, sizeof(Sanity));
00194 util::ReadOrThrow(fd, &out.fixed, sizeof(out.fixed));
00195 if (out.fixed.probing_multiplier < 1.0)
00196 UTIL_THROW(FormatLoadException, "Binary format claims to have a probing multiplier of " << out.fixed.probing_multiplier << " which is < 1.0.");
00197
00198 out.counts.resize(static_cast<std::size_t>(out.fixed.order));
00199 if (out.fixed.order) util::ReadOrThrow(fd, &*out.counts.begin(), sizeof(uint64_t) * out.fixed.order);
00200 }
00201
00202 void MatchCheck(ModelType model_type, unsigned int search_version, const Parameters ¶ms) {
00203 if (params.fixed.model_type != model_type) {
00204 if (static_cast<unsigned int>(params.fixed.model_type) >= (sizeof(kModelNames) / sizeof(const char *)))
00205 UTIL_THROW(FormatLoadException, "The binary file claims to be model type " << static_cast<unsigned int>(params.fixed.model_type) << " but this is not implemented for in this inference code.");
00206 UTIL_THROW(FormatLoadException, "The binary file was built for " << kModelNames[params.fixed.model_type] << " but the inference code is trying to load " << kModelNames[model_type]);
00207 }
00208 UTIL_THROW_IF(search_version != params.fixed.search_version, FormatLoadException, "The binary file has " << kModelNames[params.fixed.model_type] << " version " << params.fixed.search_version << " but this code expects " << kModelNames[params.fixed.model_type] << " version " << search_version);
00209 }
00210
00211 void SeekPastHeader(int fd, const Parameters ¶ms) {
00212 util::SeekOrThrow(fd, TotalHeaderSize(params.counts.size()));
00213 }
00214
00215 uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, uint64_t memory_size, Backing &backing) {
00216 const uint64_t file_size = util::SizeFile(backing.file.get());
00217
00218 std::size_t total_map = util::CheckOverflow(TotalHeaderSize(params.counts.size()) + memory_size);
00219 if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map)
00220 UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map);
00221
00222 util::MapRead(config.load_method, backing.file.get(), 0, total_map, backing.search);
00223
00224 if (config.enumerate_vocab && !params.fixed.has_vocabulary)
00225 UTIL_THROW(FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary.");
00226
00227
00228 util::SeekOrThrow(backing.file.get(), total_map);
00229 return reinterpret_cast<uint8_t*>(backing.search.get()) + TotalHeaderSize(params.counts.size());
00230 }
00231
00232 void ComplainAboutARPA(const Config &config, ModelType model_type) {
00233 if (config.write_mmap || !config.messages) return;
00234 if (config.arpa_complain == Config::ALL) {
00235 *config.messages << "Loading the LM will be faster if you build a binary file." << std::endl;
00236 } else if (config.arpa_complain == Config::EXPENSIVE &&
00237 (model_type == TRIE || model_type == QUANT_TRIE || model_type == ARRAY_TRIE || model_type == QUANT_ARRAY_TRIE)) {
00238 *config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive. Save time by building a binary format." << std::endl;
00239 }
00240 }
00241
00242 }
00243
00244 bool RecognizeBinary(const char *file, ModelType &recognized) {
00245 util::scoped_fd fd(util::OpenReadOrThrow(file));
00246 if (!detail::IsBinaryFormat(fd.get())) return false;
00247 Parameters params;
00248 detail::ReadHeader(fd.get(), params);
00249 recognized = params.fixed.model_type;
00250 return true;
00251 }
00252
00253 }
00254 }