00001 #ifndef INC_DYNAMICLM_H
00002 #define INC_DYNAMICLM_H
00003
00004 #include <algorithm>
00005 #include <vector>
00006 #include "perfectHash.h"
00007 #include "RandLMCache.h"
00008 #include "types.h"
00009 #include "vocab.h"
00010
00011
00012
00013
00014 using randlm::BitFilter;
00015 using randlm::Cache;
00016
// Read by OnlineRLM<T>::getProb(): when true, the scan over progressively
// longer n-gram suffixes stops at the first suffix whose queried count is not
// positive; when false (the value set here) the scan continues to longer ones.
00017 const bool strict_checks_ = false;
00018
00020 template<typename T>
00021 class OnlineRLM: public PerfectHash<T>
00022 {
00023 public:
00024 OnlineRLM(uint16_t MBs, int width, int bucketRange, count_t order,
00025 Moses::Vocab* v, float qBase = 8): PerfectHash<T>(MBs, width, bucketRange, qBase),
00026 vocab_(v), bAdapting_(false), order_(order), corpusSize_(0), alpha_(0) {
00027 UTIL_THROW_IF2(vocab_ == 0, "Vocab object not set");
00028
00029 cache_ = new randlm::Cache<float>(8888.8888, 9999.9999);
00030 alpha_ = new float[order_ + 1];
00031 for(count_t i = 0; i <= order_; ++i)
00032 alpha_[i] = i * log10(0.4);
00033 std::cerr << "Initialzing auxillary bit filters...\n";
00034 bPrefix_ = new randlm::BitFilter(this->cells_);
00035 bHit_ = new randlm::BitFilter(this->cells_);
00036 }
00037 OnlineRLM(Moses::FileHandler* fin, count_t order):
00038 PerfectHash<T>(fin), bAdapting_(true), order_(order), corpusSize_(0) {
00039 load(fin);
00040 cache_ = new randlm::Cache<float>(8888.8888, 9999.9999);
00041 alpha_ = new float[order_ + 1];
00042 for(count_t i = 0; i <= order_; ++i)
00043 alpha_[i] = i * log10(0.4);
00044 }
00045 ~OnlineRLM() {
00046 delete[] alpha_;
00047 if(bAdapting_) delete vocab_;
00048 else vocab_ = NULL;
00049 delete cache_;
00050 delete bPrefix_;
00051 delete bHit_;
00052 }
00053 float getProb(const wordID_t* ngram, int len, const void** state);
00054
00055 bool insert(const std::vector<std::string>& ngram, const int value);
00056 bool update(const std::vector<std::string>& ngram, const int value);
00057 int query(const wordID_t* IDs, const int len);
00058 int sbsqQuery(const std::vector<std::string>& ngram, int* len,
00059 bool bStrict = false);
00060 int sbsqQuery(const wordID_t* IDs, const int len, int* codes,
00061 bool bStrict = false);
00062 void remove(const std::vector<std::string>& ngram);
00063 count_t heurDelete(count_t num2del, count_t order = 5);
00064 uint64_t corpusSize() {
00065 return corpusSize_;
00066 }
00067 void corpusSize(uint64_t c) {
00068 corpusSize_ = c;
00069 }
00070 void clearCache() {
00071 if(cache_) cache_->clear();
00072 }
00073 void save(Moses::FileHandler* fout);
00074 void load(Moses::FileHandler* fin);
00075 void randDelete(int num2del);
00076 int countHits();
00077 int countPrefixes();
00078 int cleanUpHPD();
00079 void clearMarkings();
00080 void removeNonMarked();
00081 Moses::Vocab* vocab_;
00082 protected:
00083 void markQueried(const uint64_t& index);
00084 void markQueried(hpdEntry_t& value);
00085 bool markPrefix(const wordID_t* IDs, const int len, bool bSet);
00086 private:
00087 const void* getContext(const wordID_t* ngram, int len);
00088 const bool bAdapting_;
00089 const count_t order_;
00090 uint64_t corpusSize_;
00091 float* alpha_;
00092 randlm::Cache<float>* cache_;
00093 randlm::BitFilter* bPrefix_;
00094 randlm::BitFilter* bHit_;
00095 };
00096
00097 template<typename T>
00098 bool OnlineRLM<T>::insert(const std::vector<std::string>& ngram, const int value)
00099 {
00100 int len = ngram.size();
00101 wordID_t wrdIDs[len];
00102 uint64_t index(this->cells_ + 1);
00103 for(int i = 0; i < len; ++i)
00104 wrdIDs[i] = vocab_->GetWordID(ngram[i]);
00105 index = PerfectHash<T>::insert(wrdIDs, len, value);
00106 if(value > 1 && len < order_)
00107 markPrefix(wrdIDs, ngram.size(), true);
00108
00109 if(ngram.size() == 1 && (!bAdapting_))
00110 corpusSize_ += (wrdIDs[0] != vocab_->GetBOSWordID()) ? value : 0;
00111 if(bAdapting_ && (index < this->cells_))
00112 markQueried(index);
00113 return true;
00114 }
00115
00116 template<typename T>
00117 bool OnlineRLM<T>::update(const std::vector<std::string>& ngram, const int value)
00118 {
00119 int len = ngram.size();
00120 std::vector<wordID_t> wrdIDs(len);
00121 uint64_t index(this->cells_ + 1);
00122 hpdEntry_t hpdItr;
00123 vocab_->MakeOpen();
00124 for(int i = 0; i < len; ++i)
00125 wrdIDs[i] = vocab_->GetWordID(ngram[i]);
00126
00127 bool bIncluded(true);
00128 if(value > 1 && len < (int)order_)
00129 bIncluded = markPrefix(&wrdIDs[0], ngram.size(), true);
00130 if(bIncluded) {
00131 bIncluded = PerfectHash<T>::update2(&wrdIDs[0], len, value, hpdItr, index);
00132 if(index < this->cells_) {
00133 markQueried(index);
00134 } else if(hpdItr != this->dict_.end()) markQueried(hpdItr);
00135 }
00136
00137 return bIncluded;
00138 }
00139 template<typename T>
00140 int OnlineRLM<T>::query(const wordID_t* IDs, int len)
00141 {
00142 uint64_t filterIdx = 0;
00143 hpdEntry_t hpdItr;
00144 int value(0);
00145 value = PerfectHash<T>::query(IDs, len, hpdItr, filterIdx);
00146 if(value != -1) {
00147 if(hpdItr != this->dict_.end()) {
00148
00149 value -= ((value & this->hitMask_) != 0) ? this->hitMask_ : 0;
00150 } else {
00151 UTIL_THROW_IF2(filterIdx >= this->cells_,
00152 "Out of bound: " << filterIdx);
00153
00154 }
00155 }
00156 return value > 0 ? value : 0;
00157 }
00158
00159 template<typename T>
00160 bool OnlineRLM<T>::markPrefix(const wordID_t* IDs, const int len, bool bSet)
00161 {
00162 if(len <= 1) return true;
00163 static randlm::Cache<int> pfCache(-1, -1);
00164 int code(0);
00165 if(!pfCache.checkCacheNgram(IDs, len - 1, &code, NULL)) {
00166 hpdEntry_t hpdItr;
00167 uint64_t filterIndex(0);
00168 code = PerfectHash<T>::query(IDs, len - 1, hpdItr, filterIndex);
00169 if(code == -1) {
00170 std::cerr << "WARNING: markPrefix(). The O-RLM is *not* well-formed.\n";
00171
00172 return false;
00173 }
00174 if(filterIndex != this->cells_ + 1) {
00175 UTIL_THROW_IF2(hpdItr != this->dict_.end(), "Error");
00176 if(bSet) bPrefix_->setBit(filterIndex);
00177 else bPrefix_->clearBit(filterIndex);
00178 } else {
00179 UTIL_THROW_IF2(filterIndex != this->cells_ + 1, "Error");
00180
00181 }
00182 if(pfCache.nodes() > 10000) pfCache.clear();
00183 pfCache.setCacheNgram(IDs, len - 1, code, NULL);
00184 }
00185 return true;
00186 }
00187
00188 template<typename T>
00189 void OnlineRLM<T>::markQueried(const uint64_t& index)
00190 {
00191 bHit_->setBit(index);
00192
00193 }
00194
00195 template<typename T>
00196 void OnlineRLM<T>::markQueried(hpdEntry_t& value)
00197 {
00198
00199 value->second |= this->hitMask_;
00200 }
00201
00202 template<typename T>
00203 void OnlineRLM<T>::remove(const std::vector<std::string>& ngram)
00204 {
00205 wordID_t IDs[ngram.size()];
00206 for(count_t i = 0; i < ngram.size(); ++i)
00207 IDs[i] = vocab_->GetWordID(ngram[i]);
00208 PerfectHash<T>::remove(IDs, ngram.size());
00209 }
00210
00211 template<typename T>
00212 count_t OnlineRLM<T>::heurDelete(count_t num2del, count_t order)
00213 {
00214 count_t deleted = 0;
00215 std::cout << "Deleting " << num2del << " of order "<< order << std::endl;
00216
00217 int full = *std::max_element(this->idxTracker_, this->idxTracker_
00218 + this->totBuckets_);
00219 for(; full > 0; --full)
00220 for(int bk = 0; bk < this->totBuckets_; ++bk) {
00221 if(deleted >= num2del) break;
00222 if(this->idxTracker_[bk] == full) {
00223 uint64_t first = bk * this->bucketRange_,
00224 last = first + this->bucketRange_;
00225 for(uint64_t row = first; row < last; ++row) {
00226 if(!(bHit_->testBit(row) || bPrefix_->testBit(row) )) {
00227 if(this->filter_->read(row) != 0) {
00228 PerfectHash<T>::remove(row);
00229 ++deleted;
00230 }
00231 }
00232 }
00233 }
00234 }
00235 if(deleted < num2del) {
00236
00237 std::cerr << "TODO! HPD deletions\n";
00238 }
00239 std::cerr << "Total deleted = " << deleted << std::endl;
00240 return deleted;
00241 }
00242
00243 template<typename T>
00244 int OnlineRLM<T>::sbsqQuery(const std::vector<std::string>& ngram, int* codes,
00245 bool bStrict)
00246 {
00247 wordID_t IDs[ngram.size()];
00248 for(count_t i = 0; i < ngram.size(); ++i)
00249 IDs[i] = vocab_->GetWordID(ngram[i]);
00250 return sbsqQuery(IDs, ngram.size(), codes, bStrict);
00251 }
00252
00253 template<typename T>
00254 int OnlineRLM<T>::sbsqQuery(const wordID_t* IDs, const int len, int* codes,
00255 bool bStrict)
00256 {
00257 uint64_t filterIdx = 0;
00258 int val(0), fnd(0);
00259 hpdEntry_t hpdItr;
00260 for(int i = len - 1; i >= 0; --i) {
00261
00262 val = PerfectHash<T>::query(&IDs[i], len - i, hpdItr, filterIdx);
00263 if(val != -1) {
00264 fnd = len - i;
00265 if(hpdItr != this->dict_.end()) {
00266 val -= ((val & this->hitMask_) != 0) ? this->hitMask_ : 0;
00267 }
00268 } else if(bStrict) {
00269 break;
00270 }
00271
00272 codes[i] = val > 0 ? val : 0;
00273 }
00274 while(bStrict && (fnd > 1)) {
00275 val = PerfectHash<T>::query(&IDs[len - fnd], fnd - 1, hpdItr, filterIdx);
00276 if(val != -1) break;
00277 else --fnd;
00278 }
00279
00280 return fnd;
00281 }
00282
00283 template<typename T>
00284 float OnlineRLM<T>::getProb(const wordID_t* ngram, int len,
00285 const void** state)
00286 {
00287 static const float oovprob = log10(1.0 / (static_cast<float>(vocab_->Size()) - 1));
00288 float logprob(0);
00289 const void* context = (state) ? *state : 0;
00290
00291 if(!cache_->checkCacheNgram(ngram, len, &logprob, &context)) {
00292
00293 int num_fnd(0), den_val(0);
00294 int *in = new int[len];
00295 for(int i = 0; i < len; ++i) in[i] = 0;
00296 for(int i = len - 1; i >= 0; --i) {
00297 if(ngram[i] == vocab_->GetkOOVWordID()) break;
00298 in[i] = query(&ngram[i], len - i);
00299 if(in[i] > 0) {
00300 num_fnd = len - i;
00301 } else if(strict_checks_) break;
00302 }
00303 while(num_fnd > 1) {
00304
00305 den_val = query(&ngram[len - num_fnd], num_fnd - 1);
00306 if((den_val > 0) &&
00307 (den_val >= in[len - num_fnd]) && (in[len - num_fnd] > 0)) {
00308 break;
00309 } else --num_fnd;
00310 }
00311 if(num_fnd == 1 && (in[len - 1] < 1))
00312 num_fnd = 0;
00313 switch(num_fnd) {
00314 case 0:
00315 logprob = alpha_[len] + oovprob;
00316 break;
00317 case 1:
00318 UTIL_THROW_IF2(in[len - 1] <= 0, "Error");
00319 logprob = alpha_[len - 1] + (corpusSize_ > 0 ?
00320 log10(static_cast<float>(in[len - 1]) / static_cast<float>(corpusSize_)) : 0);
00321
00322
00323 break;
00324 default:
00325 UTIL_THROW_IF2(den_val <= 0, "Error");
00326
00327 logprob = alpha_[len - num_fnd] +
00328 log10(static_cast<float>(in[len - num_fnd]) / static_cast<float>(den_val));
00329 break;
00330 }
00331
00332 context = getContext(&ngram[len - num_fnd], num_fnd);
00333
00334 cache_->setCacheNgram(ngram, len, logprob, context);
00335 }
00336 return logprob;
00337 }
00338
00339 template<typename T>
00340 const void* OnlineRLM<T>::getContext(const wordID_t* ngram, int len)
00341 {
00342 int dummy(0);
00343 float**addresses = new float*[len];
00344 UTIL_THROW_IF2(cache_->getCache2(ngram, len, &addresses[0], &dummy) != len,
00345 "Error");
00346
00347
00348 float *addr0 = addresses[0];
00349 free( addresses );
00350 return (const void*)addr0;
00351 }
00352
00353 template<typename T>
00354 void OnlineRLM<T>::randDelete(int num2del)
00355 {
00356 int deleted = 0;
00357 for(uint64_t i = 0; i < this->cells_; i++) {
00358 if(this->filter_->read(i) != 0) {
00359 PerfectHash<T>::remove(i);
00360 ++deleted;
00361 }
00362 if(deleted >= num2del) break;
00363 }
00364 }
00365
00366 template<typename T>
00367 int OnlineRLM<T>::countHits()
00368 {
00369 int hit(0);
00370 for(uint64_t i = 0; i < this->cells_; ++i)
00371 if(bHit_->testBit(i)) ++hit;
00372 iterate(this->dict_, itr)
00373 if((itr->second & this->hitMask_) != 0)
00374 ++hit;
00375 std::cerr << "Hit count = " << hit << std::endl;
00376 return hit;
00377 }
00378
00379 template<typename T>
00380 int OnlineRLM<T>::countPrefixes()
00381 {
00382 int pfx(0);
00383 for(uint64_t i = 0; i < this->cells_; ++i)
00384 if(bPrefix_->testBit(i)) ++pfx;
00385
00386 std::cerr << "Prefix count (in filter) = " << pfx << std::endl;
00387 return pfx;
00388 }
00389
00390 template<typename T>
00391 int OnlineRLM<T>::cleanUpHPD()
00392 {
00393 std::cerr << "HPD size before = " << this->dict_.size() << std::endl;
00394 std::vector<std::string> vDel, vtmp;
00395 iterate(this->dict_, itr) {
00396 if(((itr->second & this->hitMask_) == 0) &&
00397 (Utils::splitToStr(itr->first, vtmp, "¬") >= 3)) {
00398 vDel.push_back(itr->first);
00399 }
00400 }
00401 iterate(vDel, vitr)
00402 this->dict_.erase(*vitr);
00403 std::cerr << "HPD size after = " << this->dict_.size() << std::endl;
00404 return vDel.size();
00405 }
00406
00407 template<typename T>
00408 void OnlineRLM<T>::clearMarkings()
00409 {
00410 std::cerr << "clearing all event hits\n";
00411 bHit_->reset();
00412 count_t* value(0);
00413 iterate(this->dict_, itr) {
00414 value = &itr->second;
00415 *value -= ((*value & this->hitMask_) != 0) ? this->hitMask_ : 0;
00416 }
00417 }
00418
00419 template<typename T>
00420 void OnlineRLM<T>::save(Moses::FileHandler* fout)
00421 {
00422 std::cerr << "Saving ORLM...\n";
00423
00424 vocab_->Save(fout);
00425 fout->write((char*)&corpusSize_, sizeof(corpusSize_));
00426 fout->write((char*)&order_, sizeof(order_));
00427 bPrefix_->save(fout);
00428 bHit_->save(fout);
00429
00430 PerfectHash<T>::save(fout);
00431 std::cerr << "Finished saving ORLM." << std::endl;
00432 }
00433
00434 template<typename T>
00435 void OnlineRLM<T>::load(Moses::FileHandler* fin)
00436 {
00437 std::cerr << "Loading ORLM...\n";
00438
00439 vocab_ = new Moses::Vocab(fin);
00440 UTIL_THROW_IF2(vocab_ == 0, "Vocab object not set");
00441 fin->read((char*)&corpusSize_, sizeof(corpusSize_));
00442 std::cerr << "\tCorpus size = " << corpusSize_ << std::endl;
00443 fin->read((char*)&order_, sizeof(order_));
00444 std::cerr << "\tModel order = " << order_ << std::endl;
00445 bPrefix_ = new randlm::BitFilter(fin);
00446 bHit_ = new randlm::BitFilter(fin);
00447
00448 PerfectHash<T>::load(fin);
00449 }
00450
00451 template<typename T>
00452 void OnlineRLM<T>::removeNonMarked()
00453 {
00454 std::cerr << "deleting all unused events\n";
00455 int deleted(0);
00456 for(uint64_t i = 0; i < this->cells_; ++i) {
00457 if(!(bHit_->testBit(i) || bPrefix_->testBit(i))
00458 && (this->filter_->read(i) != 0)) {
00459 PerfectHash<T>::remove(i);
00460 ++deleted;
00461 }
00462 }
00463 deleted += cleanUpHPD();
00464 std::cerr << "total removed from ORLM = " << deleted << std::endl;
00465 }
00466
00467
00468
00469
00470
00471
00472
00473
00474
00475
00476
00477
00478
00479
00480
00481
00482
00483
00484
00485
00486
00487
00488
00489
00490
00491
00492
00493
00494
00495
00497
00498
00499
00500
00501
00502
00503
00504
00505
00506
00507
00508
00509
00510
00511
00512
00513
00514
00515
00516
00517
00518
00519
00520
00521
00522
00523
00524
00525
00526
00527
00528
00529
00530
00531
00532
00533
00534
00535
00536
00537
00538
00539
00540
00541 #endif
00542