00001 #include "lm/builder/adjust_counts.hh"
00002 #include "lm/common/ngram_stream.hh"
00003 #include "lm/builder/payload.hh"
00004 #include "util/stream/timer.hh"
00005
00006 #include <algorithm>
00007 #include <iostream>
00008 #include <limits>
00009
00010 namespace lm { namespace builder {
00011
// Out-of-line (empty) definitions for the exception declared in adjust_counts.hh.
BadDiscountException::BadDiscountException() throw() {}
BadDiscountException::~BadDiscountException() throw() {}
00014
00015 namespace {
00016
00017 const WordIndex* FindDifference(const NGram<BuildingPayload> &full, const NGram<BuildingPayload> &lower_last) {
00018 const WordIndex *cur_word = full.end() - 1;
00019 const WordIndex *pre_word = lower_last.end() - 1;
00020
00021 for (; pre_word >= lower_last.begin() && *pre_word == *cur_word; --cur_word, --pre_word) {}
00022 return cur_word;
00023 }
00024
// Accumulates, per order, the total number of n-grams, the number that
// survive pruning, and the count-of-counts n[1..4] required for Kneser-Ney
// smoothing.  CalculateDiscounts() then converts those statistics into
// Discount structures, writing results into caller-owned vectors.
class StatCollector {
  public:
    // The output vectors (counts, counts_pruned, discounts) are owned by the
    // caller and filled in by CalculateDiscounts().  orders_ holds one
    // OrderStat per order; full_ aliases the highest-order entry.
    StatCollector(std::size_t order, std::vector<uint64_t> &counts, std::vector<uint64_t> &counts_pruned, std::vector<Discount> &discounts)
      : orders_(order), full_(orders_.back()), counts_(counts), counts_pruned_(counts_pruned), discounts_(discounts) {
      // OrderStat is a plain struct of uint64_t fields, so memset-zero is well defined.
      memset(&orders_[0], 0, sizeof(OrderStat) * order);
    }

    ~StatCollector() {}

    // Copies accumulated totals into counts_/counts_pruned_ and computes the
    // modified Kneser-Ney discounts per order:
    //   Y   = n[1] / (n[1] + 2*n[2])
    //   D_j = j - (j+1) * Y * n[j+1] / n[j]   for j in {1, 2, 3}
    // (Chen & Goodman).  Orders covered by config.overwrite keep the
    // user-supplied discounts; on failure, config.bad_action decides whether
    // to rethrow, warn and fall back, or silently fall back.
    void CalculateDiscounts(const DiscountConfig &config) {
      counts_.resize(orders_.size());
      counts_pruned_.resize(orders_.size());
      for (std::size_t i = 0; i < orders_.size(); ++i) {
        const OrderStat &s = orders_[i];
        counts_[i] = s.count;
        counts_pruned_[i] = s.count_pruned;
      }

      discounts_ = config.overwrite;
      discounts_.resize(orders_.size());
      for (std::size_t i = config.overwrite.size(); i < orders_.size(); ++i) {
        const OrderStat &s = orders_[i];
        try {
          // The formulas below divide by n[1..3]; a zero anywhere makes them undefined.
          for (unsigned j = 1; j < 4; ++j) {

            UTIL_THROW_IF(s.n[j] == 0, BadDiscountException, "Could not calculate Kneser-Ney discounts for "
                << (i+1) << "-grams with adjusted count " << (j+1) << " because we didn't observe any "
                << (i+1) << "-grams with adjusted count " << j << "; Is this small or artificial data?\n"
                << "Try deduplicating the input. To override this error for e.g. a class-based model, rerun with --discount_fallback\n");
          }

          // Discount for adjusted count 0 is defined to be 0.
          discounts_[i].amount[0] = 0.0;
          float y = static_cast<float>(s.n[1]) / static_cast<float>(s.n[1] + 2.0 * s.n[2]);
          for (unsigned j = 1; j < 4; ++j) {
            discounts_[i].amount[j] = static_cast<float>(j) - static_cast<float>(j + 1) * y * static_cast<float>(s.n[j+1]) / static_cast<float>(s.n[j]);
            // A valid discount for adjusted count j must lie in [0, j].
            UTIL_THROW_IF(discounts_[i].amount[j] < 0.0 || discounts_[i].amount[j] > j, BadDiscountException, "ERROR: " << (i+1) << "-gram discount out of range for adjusted count " << j << ": " << discounts_[i].amount[j]);
          }
        } catch (const BadDiscountException &e) {
          switch (config.bad_action) {
            case THROW_UP:
              throw;
            case COMPLAIN:
              std::cerr << "Substituting fallback discounts for order " << i << ": D1=" << config.fallback.amount[1] << " D2=" << config.fallback.amount[2] << " D3+=" << config.fallback.amount[3] << std::endl;
              // Intentional fall-through: after complaining, proceed as SILENT.
            case SILENT:
              break;
          }
          // Either way (COMPLAIN or SILENT), substitute the fallback discounts.
          discounts_[i] = config.fallback;
        }
      }
    }

    // Record one n-gram of order (order_minus_1 + 1) with the given adjusted
    // count; pruned entries are excluded from the post-pruning total.
    void Add(std::size_t order_minus_1, uint64_t count, bool pruned = false) {
      OrderStat &stat = orders_[order_minus_1];
      ++stat.count;
      if (!pruned)
        ++stat.count_pruned;
      if (count < 5) ++stat.n[count];
    }

    // Record one n-gram of the highest order.
    void AddFull(uint64_t count, bool pruned = false) {
      ++full_.count;
      if (!pruned)
        ++full_.count_pruned;
      if (count < 5) ++full_.n[count];
    }

  private:
    struct OrderStat {
      // n[j] = number of n-grams with adjusted count j, for j in [1,4];
      // n[0] is never read by the discount formulas.
      uint64_t n[5];
      uint64_t count;
      uint64_t count_pruned;
    };

    std::vector<OrderStat> orders_;
    // Alias of orders_.back(); safe because orders_ is never resized.
    OrderStat &full_;

    // Output locations owned by the caller.
    std::vector<uint64_t> &counts_;
    std::vector<uint64_t> &counts_pruned_;
    std::vector<Discount> &discounts_;
};
00107
00108
00109
00110
00111
00112 class CollapseStream {
00113 public:
00114 CollapseStream(const util::stream::ChainPosition &position, uint64_t prune_threshold, const std::vector<bool>& prune_words) :
00115 current_(NULL, NGram<BuildingPayload>::OrderFromSize(position.GetChain().EntrySize())),
00116 prune_threshold_(prune_threshold),
00117 prune_words_(prune_words),
00118 block_(position) {
00119 StartBlock();
00120 }
00121
00122 const NGram<BuildingPayload> &operator*() const { return current_; }
00123 const NGram<BuildingPayload> *operator->() const { return ¤t_; }
00124
00125 operator bool() const { return block_; }
00126
00127 CollapseStream &operator++() {
00128 assert(block_);
00129
00130 if (current_.begin()[1] == kBOS && current_.Base() < copy_from_) {
00131 memcpy(current_.Base(), copy_from_, current_.TotalSize());
00132 UpdateCopyFrom();
00133
00134
00135 if(current_.Value().count <= prune_threshold_) {
00136 current_.Value().Mark();
00137 }
00138
00139 if(!prune_words_.empty()) {
00140 for(WordIndex* i = current_.begin(); i != current_.end(); i++) {
00141 if(prune_words_[*i]) {
00142 current_.Value().Mark();
00143 break;
00144 }
00145 }
00146 }
00147
00148 }
00149
00150 current_.NextInMemory();
00151 uint8_t *block_base = static_cast<uint8_t*>(block_->Get());
00152 if (current_.Base() == block_base + block_->ValidSize()) {
00153 block_->SetValidSize(copy_from_ + current_.TotalSize() - block_base);
00154 ++block_;
00155 StartBlock();
00156 }
00157
00158
00159 if(current_.Value().count <= prune_threshold_) {
00160 current_.Value().Mark();
00161 }
00162
00163 if(!prune_words_.empty()) {
00164 for(WordIndex* i = current_.begin(); i != current_.end(); i++) {
00165 if(prune_words_[*i]) {
00166 current_.Value().Mark();
00167 break;
00168 }
00169 }
00170 }
00171
00172 return *this;
00173 }
00174
00175 private:
00176 void StartBlock() {
00177 for (; ; ++block_) {
00178 if (!block_) return;
00179 if (block_->ValidSize()) break;
00180 }
00181 current_.ReBase(block_->Get());
00182 copy_from_ = static_cast<uint8_t*>(block_->Get()) + block_->ValidSize();
00183 UpdateCopyFrom();
00184
00185
00186 if(current_.Value().count <= prune_threshold_) {
00187 current_.Value().Mark();
00188 }
00189
00190 if(!prune_words_.empty()) {
00191 for(WordIndex* i = current_.begin(); i != current_.end(); i++) {
00192 if(prune_words_[*i]) {
00193 current_.Value().Mark();
00194 break;
00195 }
00196 }
00197 }
00198
00199 }
00200
00201
00202 void UpdateCopyFrom() {
00203 for (copy_from_ -= current_.TotalSize(); copy_from_ >= current_.Base(); copy_from_ -= current_.TotalSize()) {
00204 if (NGram<BuildingPayload>(copy_from_, current_.Order()).begin()[1] != kBOS) break;
00205 }
00206 }
00207
00208 NGram<BuildingPayload> current_;
00209
00210
00211 uint8_t *copy_from_;
00212 uint64_t prune_threshold_;
00213 const std::vector<bool>& prune_words_;
00214 util::stream::Link block_;
00215 };
00216
00217 }
00218
/* Reads the suffix-sorted highest-order n-gram stream and, in one pass:
 *  1. derives adjusted counts for every lower order (the number of unique
 *     single-word left extensions, per Kneser-Ney), emitting them on
 *     streams[0 .. order-2];
 *  2. applies count-threshold and word-based pruning marks (thresholds are
 *     compared against real, unadjusted counts tracked in actual_counts);
 *  3. feeds StatCollector, which computes the discounts at the end.
 */
void AdjustCounts::Run(const util::stream::ChainPositions &positions) {
  UTIL_TIMER("(%w s) Adjusted counts\n");

  const std::size_t order = positions.size();
  StatCollector stats(order, counts_, counts_pruned_, discounts_);
  // Unigram-only model: counts need no adjustment, just prune and tally.
  if (order == 1) {

    for (NGramStream<BuildingPayload> full(positions[0]); full; ++full) {

      // Never prune ids <= 2 — presumably the special tokens <unk>/<s>/</s>
      // (kUNK and kBOS are used below; TODO confirm the id assignment).
      if(*full->begin() > 2) {
        if(full->Value().count <= prune_thresholds_[0])
          full->Value().Mark();

        if(!prune_words_.empty() && prune_words_[*full->begin()])
          full->Value().Mark();
      }

      stats.AddFull(full->Value().UnmarkedCount(), full->Value().IsMarked());
    }

    stats.CalculateDiscounts(discount_config_);
    return;
  }

  // One output stream per order below the maximum.
  NGramStreams<BuildingPayload> streams;
  streams.Init(positions, positions.size() - 1);

  // The top-order stream, with <s>-second-word records collapsed away.
  CollapseStream full(positions[positions.size() - 1], prune_thresholds_.back(), prune_words_);

  // Stack of in-progress lower-order n-grams; lower_valid points at the
  // highest order currently holding a valid (not yet flushed) entry.
  NGramStream<BuildingPayload> *lower_valid = streams.begin();
  const NGramStream<BuildingPayload> *const streams_begin = streams.begin();
  // Reserve the first two unigram slots for <unk> and <s> with count 0.
  streams[0]->Value().count = 0;
  *streams[0]->begin() = kUNK;
  stats.Add(0, 0);
  (++streams[0])->Value().count = 0;
  *streams[0]->begin() = kBOS;

  // Real (unadjusted) count accumulated per order, used for pruning
  // decisions; index i corresponds to order i+1.
  std::vector<uint64_t> actual_counts(positions.size(), 0);
  // Unigram entries are never pruned via this path; max() guarantees the
  // threshold comparison below can never mark them.
  actual_counts[0] = std::numeric_limits<uint64_t>::max();

  for (; full; ++full) {
    const WordIndex *different = FindDifference(*full, **lower_valid);
    // Number of trailing words shared with the current lower-order stack.
    std::size_t same = full->end() - 1 - different;

    // Flush every stacked n-gram whose suffix no longer matches the
    // incoming n-gram, applying pruning marks as each one is emitted.
    for (; lower_valid >= streams.begin() + same; --lower_valid) {
      uint64_t order_minus_1 = lower_valid - streams_begin;
      if(actual_counts[order_minus_1] <= prune_thresholds_[order_minus_1])
        (*lower_valid)->Value().Mark();

      if(!prune_words_.empty()) {
        for(WordIndex* i = (*lower_valid)->begin(); i != (*lower_valid)->end(); i++) {
          if(prune_words_[*i]) {
            (*lower_valid)->Value().Mark();
            break;
          }
        }
      }

      stats.Add(order_minus_1, (*lower_valid)->Value().UnmarkedCount(), (*lower_valid)->Value().IsMarked());
      ++*lower_valid;
    }

    // Accumulate the real count into every order whose suffix still matches.
    for (std::size_t i = 0; i < same; ++i) {
      actual_counts[i] += full->Value().UnmarkedCount();
    }

    // The longest matching suffix gains one more unique left extension.
    if (same) ++streams[same - 1]->Value().count;

    // Open new stack entries for the orders that changed.  Walk leftward
    // from the point of difference until hitting <s> (or the start):
    // suffixes not crossing <s> start with adjusted count 1.
    const WordIndex *full_end = full->end();

    const WordIndex *bos;
    for (bos = different; (bos > full->begin()) && (*bos != kBOS); --bos) {
      NGramStream<BuildingPayload> &to = *++lower_valid;
      std::copy(bos, full_end, to->begin());
      to->Value().count = 1;
      actual_counts[lower_valid - streams_begin] = full->Value().UnmarkedCount();
    }

    if (bos != full->begin()) {
      // The walk stopped at <s>: an n-gram beginning with <s> keeps the
      // real count rather than an adjusted count of 1.
      NGramStream<BuildingPayload> &to = *++lower_valid;
      std::copy(bos, full_end, to->begin());

      to->Value().count = full->Value().UnmarkedCount();
      actual_counts[lower_valid - streams_begin] = full->Value().UnmarkedCount();
    } else {
      // The full n-gram matched the stack completely; tally it at top order.
      stats.AddFull(full->Value().UnmarkedCount(), full->Value().IsMarked());
    }
    assert(lower_valid >= &streams[0]);
  }

  // Input exhausted: flush whatever remains on the lower-order stack,
  // applying the same pruning logic as the in-loop flush above.
  for (NGramStream<BuildingPayload> *s = streams.begin(); s <= lower_valid; ++s) {
    uint64_t lower_count = actual_counts[(*s)->Order() - 1];
    if(lower_count <= prune_thresholds_[(*s)->Order() - 1])
      (*s)->Value().Mark();

    if(!prune_words_.empty()) {
      for(WordIndex* i = (*s)->begin(); i != (*s)->end(); i++) {
        if(prune_words_[*i]) {
          (*s)->Value().Mark();
          break;
        }
      }
    }

    stats.Add(s - streams.begin(), lower_count, (*s)->Value().IsMarked());
    ++*s;
  }

  // Poison all output streams to signal end-of-stream to consumers.
  for (NGramStream<BuildingPayload> *s = streams.begin(); s != streams.end(); ++s)
    s->Poison();

  stats.CalculateDiscounts(discount_config_);

}
00352
00353 }}