00001 #ifndef INC_DYNAMICLM_H
00002 #define INC_DYNAMICLM_H
00003
00004 #include <algorithm>
00005 #include <vector>
00006 #include "perfectHash.h"
00007 #include "RandLMCache.h"
00008 #include "types.h"
00009 #include "vocab.h"
00010
00011
00012
00013
00014 using randlm::BitFilter;
00015 using randlm::Cache;
00016
// When true, OnlineRLM::getProb() stops extending the matched suffix at the
// first suffix whose stored count is zero; when false it keeps scanning
// longer suffixes past such gaps.
static const bool strict_checks_ = false;
00018
00020 template<typename T>
00021 class OnlineRLM: public PerfectHash<T>
00022 {
00023 public:
00024 OnlineRLM(uint16_t MBs, int width, int bucketRange, count_t order,
00025 Moses::Vocab* v, float qBase = 8): PerfectHash<T>(MBs, width, bucketRange, qBase),
00026 vocab_(v), bAdapting_(false), order_(order), corpusSize_(0), alpha_(0) {
00027 CHECK(vocab_ != 0);
00028
00029 cache_ = new Cache<float>(8888.8888, 9999.9999);
00030 alpha_ = new float[order_ + 1];
00031 for(count_t i = 0; i <= order_; ++i)
00032 alpha_[i] = i * log10(0.4);
00033 cerr << "Initialzing auxillary bit filters...\n";
00034 bPrefix_ = new BitFilter(this->cells_);
00035 bHit_ = new BitFilter(this->cells_);
00036 }
00037 OnlineRLM(FileHandler* fin, count_t order):
00038 PerfectHash<T>(fin), bAdapting_(true), order_(order), corpusSize_(0) {
00039 load(fin);
00040 cache_ = new Cache<float>(8888.8888, 9999.9999);
00041 alpha_ = new float[order_ + 1];
00042 for(count_t i = 0; i <= order_; ++i)
00043 alpha_[i] = i * log10(0.4);
00044 }
00045 ~OnlineRLM() {
00046 if(alpha_) delete[] alpha_;
00047 if(bAdapting_) delete vocab_;
00048 else vocab_ = NULL;
00049 if(cache_) delete cache_;
00050 delete bPrefix_;
00051 delete bHit_;
00052 }
00053 float getProb(const wordID_t* ngram, int len, const void** state);
00054
00055 bool insert(const std::vector<string>& ngram, const int value);
00056 bool update(const std::vector<string>& ngram, const int value);
00057 int query(const wordID_t* IDs, const int len);
00058 int sbsqQuery(const std::vector<string>& ngram, int* len,
00059 bool bStrict = false);
00060 int sbsqQuery(const wordID_t* IDs, const int len, int* codes,
00061 bool bStrict = false);
00062 void remove(const std::vector<string>& ngram);
00063 count_t heurDelete(count_t num2del, count_t order = 5);
00064 uint64_t corpusSize() {
00065 return corpusSize_;
00066 }
00067 void corpusSize(uint64_t c) {
00068 corpusSize_ = c;
00069 }
00070 void clearCache() {
00071 if(cache_) cache_->clear();
00072 }
00073 void save(FileHandler* fout);
00074 void load(FileHandler* fin);
00075 void randDelete(int num2del);
00076 int countHits();
00077 int countPrefixes();
00078 int cleanUpHPD();
00079 void clearMarkings();
00080 void removeNonMarked();
00081 Moses::Vocab* vocab_;
00082 protected:
00083 void markQueried(const uint64_t& index);
00084 void markQueried(hpdEntry_t& value);
00085 bool markPrefix(const wordID_t* IDs, const int len, bool bSet);
00086 private:
00087 const void* getContext(const wordID_t* ngram, int len);
00088 const bool bAdapting_;
00089 const count_t order_;
00090 uint64_t corpusSize_;
00091 float* alpha_;
00092 Cache<float>* cache_;
00093 BitFilter* bPrefix_;
00094 BitFilter* bHit_;
00095 };
00096
00097 template<typename T>
00098 bool OnlineRLM<T>::insert(const std::vector<string>& ngram, const int value)
00099 {
00100 int len = ngram.size();
00101 wordID_t wrdIDs[len];
00102 uint64_t index(this->cells_ + 1);
00103 for(int i = 0; i < len; ++i)
00104 wrdIDs[i] = vocab_->GetWordID(ngram[i]);
00105 index = PerfectHash<T>::insert(wrdIDs, len, value);
00106 if(value > 1 && len < order_)
00107 markPrefix(wrdIDs, ngram.size(), true);
00108
00109 if(ngram.size() == 1 && (!bAdapting_))
00110 corpusSize_ += (wrdIDs[0] != vocab_->GetBOSWordID()) ? value : 0;
00111 if(bAdapting_ && (index < this->cells_))
00112 markQueried(index);
00113 return true;
00114 }
00115
00116 template<typename T>
00117 bool OnlineRLM<T>::update(const std::vector<string>& ngram, const int value)
00118 {
00119 int len = ngram.size();
00120 std::vector<wordID_t> wrdIDs(len);
00121 uint64_t index(this->cells_ + 1);
00122 hpdEntry_t hpdItr;
00123 vocab_->MakeOpen();
00124 for(int i = 0; i < len; ++i)
00125 wrdIDs[i] = vocab_->GetWordID(ngram[i]);
00126
00127 bool bIncluded(true);
00128 if(value > 1 && len < (int)order_)
00129 bIncluded = markPrefix(&wrdIDs[0], ngram.size(), true);
00130 if(bIncluded) {
00131 bIncluded = PerfectHash<T>::update2(&wrdIDs[0], len, value, hpdItr, index);
00132 if(index < this->cells_) {
00133 markQueried(index);
00134 } else if(hpdItr != this->dict_.end()) markQueried(hpdItr);
00135 }
00136
00137 return bIncluded;
00138 }
00139 template<typename T>
00140 int OnlineRLM<T>::query(const wordID_t* IDs, int len)
00141 {
00142 uint64_t filterIdx = 0;
00143 hpdEntry_t hpdItr;
00144 int value(0);
00145 value = PerfectHash<T>::query(IDs, len, hpdItr, filterIdx);
00146 if(value != -1) {
00147 if(hpdItr != this->dict_.end()) {
00148
00149 value -= ((value & this->hitMask_) != 0) ? this->hitMask_ : 0;
00150 } else {
00151 CHECK(filterIdx < this->cells_);
00152
00153 }
00154 }
00155 return value > 0 ? value : 0;
00156 }
00157
00158 template<typename T>
00159 bool OnlineRLM<T>::markPrefix(const wordID_t* IDs, const int len, bool bSet)
00160 {
00161 if(len <= 1) return true;
00162 static Cache<int> pfCache(-1, -1);
00163 int code(0);
00164 if(!pfCache.checkCacheNgram(IDs, len - 1, &code, NULL)) {
00165 hpdEntry_t hpdItr;
00166 uint64_t filterIndex(0);
00167 code = PerfectHash<T>::query(IDs, len - 1, hpdItr, filterIndex);
00168 if(code == -1) {
00169 cerr << "WARNING: markPrefix(). The O-RLM is *not* well-formed.\n";
00170
00171 return false;
00172 }
00173 if(filterIndex != this->cells_ + 1) {
00174 CHECK(hpdItr == this->dict_.end());
00175 if(bSet) bPrefix_->setBit(filterIndex);
00176 else bPrefix_->clearBit(filterIndex);
00177 } else {
00178 CHECK(filterIndex == this->cells_ + 1);
00179
00180 }
00181 if(pfCache.nodes() > 10000) pfCache.clear();
00182 pfCache.setCacheNgram(IDs, len - 1, code, NULL);
00183 }
00184 return true;
00185 }
00186
00187 template<typename T>
00188 void OnlineRLM<T>::markQueried(const uint64_t& index)
00189 {
00190 bHit_->setBit(index);
00191
00192 }
00193
00194 template<typename T>
00195 void OnlineRLM<T>::markQueried(hpdEntry_t& value)
00196 {
00197
00198 value->second |= this->hitMask_;
00199 }
00200
00201 template<typename T>
00202 void OnlineRLM<T>::remove(const std::vector<string>& ngram)
00203 {
00204 wordID_t IDs[ngram.size()];
00205 for(count_t i = 0; i < ngram.size(); ++i)
00206 IDs[i] = vocab_->GetWordID(ngram[i]);
00207 PerfectHash<T>::remove(IDs, ngram.size());
00208 }
00209
00210 template<typename T>
00211 count_t OnlineRLM<T>::heurDelete(count_t num2del, count_t order)
00212 {
00213 count_t deleted = 0;
00214 cout << "Deleting " << num2del << " of order "<< order << endl;
00215
00216 int full = *std::max_element(this->idxTracker_, this->idxTracker_
00217 + this->totBuckets_);
00218 for(; full > 0; --full)
00219 for(int bk = 0; bk < this->totBuckets_; ++bk) {
00220 if(deleted >= num2del) break;
00221 if(this->idxTracker_[bk] == full) {
00222 uint64_t first = bk * this->bucketRange_,
00223 last = first + this->bucketRange_;
00224 for(uint64_t row = first; row < last; ++row) {
00225 if(!(bHit_->testBit(row) || bPrefix_->testBit(row) )) {
00226 if(this->filter_->read(row) != 0) {
00227 PerfectHash<T>::remove(row);
00228 ++deleted;
00229 }
00230 }
00231 }
00232 }
00233 }
00234 if(deleted < num2del) {
00235
00236 cerr << "TODO! HPD deletions\n";
00237 }
00238 cerr << "Total deleted = " << deleted << endl;
00239 return deleted;
00240 }
00241
00242 template<typename T>
00243 int OnlineRLM<T>::sbsqQuery(const std::vector<string>& ngram, int* codes,
00244 bool bStrict)
00245 {
00246 wordID_t IDs[ngram.size()];
00247 for(count_t i = 0; i < ngram.size(); ++i)
00248 IDs[i] = vocab_->GetWordID(ngram[i]);
00249 return sbsqQuery(IDs, ngram.size(), codes, bStrict);
00250 }
00251
00252 template<typename T>
00253 int OnlineRLM<T>::sbsqQuery(const wordID_t* IDs, const int len, int* codes,
00254 bool bStrict)
00255 {
00256 uint64_t filterIdx = 0;
00257 int val(0), fnd(0);
00258 hpdEntry_t hpdItr;
00259 for(int i = len - 1; i >= 0; --i) {
00260
00261 val = PerfectHash<T>::query(&IDs[i], len - i, hpdItr, filterIdx);
00262 if(val != -1) {
00263 fnd = len - i;
00264 if(hpdItr != this->dict_.end()) {
00265 val -= ((val & this->hitMask_) != 0) ? this->hitMask_ : 0;
00266 }
00267 } else if(bStrict) {
00268 break;
00269 }
00270
00271 codes[i] = val > 0 ? val : 0;
00272 }
00273 while(bStrict && (fnd > 1)) {
00274 val = PerfectHash<T>::query(&IDs[len - fnd], fnd - 1, hpdItr, filterIdx);
00275 if(val != -1) break;
00276 else --fnd;
00277 }
00278
00279 return fnd;
00280 }
00281
00282 template<typename T>
00283 float OnlineRLM<T>::getProb(const wordID_t* ngram, int len,
00284 const void** state)
00285 {
00286 static const float oovprob = log10(1.0 / (static_cast<float>(vocab_->Size()) - 1));
00287 float logprob(0);
00288 const void* context = (state) ? *state : 0;
00289
00290 if(!cache_->checkCacheNgram(ngram, len, &logprob, &context)) {
00291
00292 int num_fnd(0), den_val(0);
00293 int *in = new int[len];
00294 for(int i = 0; i < len; ++i) in[i] = 0;
00295 for(int i = len - 1; i >= 0; --i) {
00296 if(ngram[i] == vocab_->GetkOOVWordID()) break;
00297 in[i] = query(&ngram[i], len - i);
00298 if(in[i] > 0) {
00299 num_fnd = len - i;
00300 } else if(strict_checks_) break;
00301 }
00302 while(num_fnd > 1) {
00303
00304 if(((den_val = query(&ngram[len - num_fnd], num_fnd - 1)) > 0) &&
00305 (den_val >= in[len - num_fnd]) && (in[len - num_fnd] > 0)) {
00306 break;
00307 } else --num_fnd;
00308 }
00309 if(num_fnd == 1 && (in[len - 1] < 1))
00310 num_fnd = 0;
00311 switch(num_fnd) {
00312 case 0:
00313 logprob = alpha_[len] + oovprob;
00314 break;
00315 case 1:
00316 CHECK(in[len - 1] > 0);
00317 logprob = alpha_[len - 1] + (corpusSize_ > 0 ?
00318 log10(static_cast<float>(in[len - 1]) / static_cast<float>(corpusSize_)) : 0);
00319
00320
00321 break;
00322 default:
00323 CHECK(den_val > 0);
00324
00325 logprob = alpha_[len - num_fnd] +
00326 log10(static_cast<float>(in[len - num_fnd]) / static_cast<float>(den_val));
00327 break;
00328 }
00329
00330 context = getContext(&ngram[len - num_fnd], num_fnd);
00331
00332 cache_->setCacheNgram(ngram, len, logprob, context);
00333 }
00334 return logprob;
00335 }
00336
00337 template<typename T>
00338 const void* OnlineRLM<T>::getContext(const wordID_t* ngram, int len)
00339 {
00340 int dummy(0);
00341 float**addresses = new float*[len];
00342 CHECK(cache_->getCache2(ngram, len, &addresses[0], &dummy) == len);
00343
00344
00345 float *addr0 = addresses[0];
00346 free( addresses );
00347 return (const void*)addr0;
00348 }
00349
00350 template<typename T>
00351 void OnlineRLM<T>::randDelete(int num2del)
00352 {
00353 int deleted = 0;
00354 for(uint64_t i = 0; i < this->cells_; i++) {
00355 if(this->filter_->read(i) != 0) {
00356 PerfectHash<T>::remove(i);
00357 ++deleted;
00358 }
00359 if(deleted >= num2del) break;
00360 }
00361 }
00362
00363 template<typename T>
00364 int OnlineRLM<T>::countHits()
00365 {
00366 int hit(0);
00367 for(uint64_t i = 0; i < this->cells_; ++i)
00368 if(bHit_->testBit(i)) ++hit;
00369 iterate(this->dict_, itr)
00370 if((itr->second & this->hitMask_) != 0)
00371 ++hit;
00372 cerr << "Hit count = " << hit << endl;
00373 return hit;
00374 }
00375
00376 template<typename T>
00377 int OnlineRLM<T>::countPrefixes()
00378 {
00379 int pfx(0);
00380 for(uint64_t i = 0; i < this->cells_; ++i)
00381 if(bPrefix_->testBit(i)) ++pfx;
00382
00383 cerr << "Prefix count (in filter) = " << pfx << endl;
00384 return pfx;
00385 }
00386
00387 template<typename T>
00388 int OnlineRLM<T>::cleanUpHPD()
00389 {
00390 cerr << "HPD size before = " << this->dict_.size() << endl;
00391 std::vector<string> vDel, vtmp;
00392 iterate(this->dict_, itr) {
00393 if(((itr->second & this->hitMask_) == 0) &&
00394 (Utils::splitToStr(itr->first, vtmp, "¬") >= 3)) {
00395 vDel.push_back(itr->first);
00396 }
00397 }
00398 iterate(vDel, vitr)
00399 this->dict_.erase(*vitr);
00400 cerr << "HPD size after = " << this->dict_.size() << endl;
00401 return vDel.size();
00402 }
00403
00404 template<typename T>
00405 void OnlineRLM<T>::clearMarkings()
00406 {
00407 cerr << "clearing all event hits\n";
00408 bHit_->reset();
00409 count_t* value(0);
00410 iterate(this->dict_, itr) {
00411 value = &itr->second;
00412 *value -= ((*value & this->hitMask_) != 0) ? this->hitMask_ : 0;
00413 }
00414 }
00415
00416 template<typename T>
00417 void OnlineRLM<T>::save(FileHandler* fout)
00418 {
00419 cerr << "Saving ORLM...\n";
00420
00421 vocab_->Save(fout);
00422 fout->write((char*)&corpusSize_, sizeof(corpusSize_));
00423 fout->write((char*)&order_, sizeof(order_));
00424 bPrefix_->save(fout);
00425 bHit_->save(fout);
00426
00427 PerfectHash<T>::save(fout);
00428 cerr << "Finished saving ORLM." << endl;
00429 }
00430
00431 template<typename T>
00432 void OnlineRLM<T>::load(FileHandler* fin)
00433 {
00434 cerr << "Loading ORLM...\n";
00435
00436 vocab_ = new Moses::Vocab(fin);
00437 CHECK(vocab_ != 0);
00438 fin->read((char*)&corpusSize_, sizeof(corpusSize_));
00439 cerr << "\tCorpus size = " << corpusSize_ << endl;
00440 fin->read((char*)&order_, sizeof(order_));
00441 cerr << "\tModel order = " << order_ << endl;
00442 bPrefix_ = new BitFilter(fin);
00443 bHit_ = new BitFilter(fin);
00444
00445 PerfectHash<T>::load(fin);
00446 }
00447
00448 template<typename T>
00449 void OnlineRLM<T>::removeNonMarked()
00450 {
00451 cerr << "deleting all unused events\n";
00452 int deleted(0);
00453 for(uint64_t i = 0; i < this->cells_; ++i) {
00454 if(!(bHit_->testBit(i) || bPrefix_->testBit(i))
00455 && (this->filter_->read(i) != 0)) {
00456 PerfectHash<T>::remove(i);
00457 ++deleted;
00458 }
00459 }
00460 deleted += cleanUpHPD();
00461 cerr << "total removed from ORLM = " << deleted << endl;
00462 }
00463
00464
00465
00466
00467
00468
00469
00470
00471
00472
00473
00474
00475
00476
00477
00478
00479
00480
00481
00482
00483
00484
00485
00486
00487
00488
00489
00490
00491
00492
00494
00495
00496
00497
00498
00499
00500
00501
00502
00503
00504
00505
00506
00507
00508
00509
00510
00511
00512
00513
00514
00515
00516
00517
00518
00519
00520
00521
00522
00523
00524
00525
00526
00527
00528
00529
00530
00531
00532
00533
00534
00535
00536
00537
00538 #endif
00539