00001
00002 #ifndef INC_PERFECTHASH_H
00003 #define INC_PERFECTHASH_H
00004
00005 #include <map>
00006 #include <stdint.h>
00007 #include "hash.h"
00008 #include "RandLMFilter.h"
00009 #include "quantizer.h"
00010
00015 using randlm::Filter;
00016 using randlm::BitFilter;
00017 typedef std::map<std::string, count_t> hpDict_t;
00018 typedef hpDict_t::iterator hpdEntry_t;
00019 static count_t collisions_ = 0;
00020
00021
00022 template<typename T>
00023 class PerfectHash
00024 {
00025 public:
00026 PerfectHash(uint16_t MBs, int width, int bucketRange, float qBase);
00027 PerfectHash(Moses::FileHandler* fin) {
00028 UTIL_THROW_IF2(fin == 0, "Invalid file handle");
00029 }
00030 virtual ~PerfectHash();
00031 void analyze();
00032 count_t hpDictMemUse();
00033 count_t bucketsMemUse();
00034 protected:
00035 Filter<T>* filter_;
00036 Filter<T>* values_;
00037 hpDict_t dict_;
00038 uint64_t cells_;
00039 count_t hitMask_;
00040 int totBuckets_;
00041 uint8_t bucketRange_;
00042 uint8_t* idxTracker_;
00043 uint64_t insert(const wordID_t* IDs, const int len, const count_t value);
00044 bool update(const wordID_t* IDs, const int len, const count_t value,
00045 hpdEntry_t& hpdAddr, uint64_t& filterIdx);
00046 bool update2(const wordID_t* IDs, const int len, const count_t value,
00047 hpdEntry_t& hpdAddr, uint64_t& filterIdx);
00048 int query(const wordID_t* IDs, const int len,
00049 hpdEntry_t& hpdAddr, uint64_t& filterIdx);
00050 virtual void remove(const wordID_t* IDs, const int len);
00051 void remove(uint64_t index);
00052 void save(Moses::FileHandler* fout);
00053 void load(Moses::FileHandler* fin);
00054 virtual void markQueried(const uint64_t&)=0;
00055
00056 virtual void markQueried(hpdEntry_t&)=0;
00057 private:
00058 T nonZeroSignature(const wordID_t* IDs, const int len, count_t bucket);
00059 std::string hpDictKeyValue(const wordID_t* IDs, const int len);
00060 uint64_t memBound_;
00061 uint16_t cellWidth_;
00062 UnivHash_linear<count_t>* bucketHash_;
00063 UnivHash_linear<T>* fingerHash_;
00064 LogQtizer* qtizer_;
00065 };
00066
00067 template<typename T>
00068 PerfectHash<T>::PerfectHash(uint16_t MBs, int width, int bucketRange,
00069 float qBase): hitMask_(1 << 31), memBound_(MBs * (1ULL << 20)),
00070 cellWidth_(width)
00071 {
00072 bucketRange_ = static_cast<uint8_t>(bucketRange);
00073 if(bucketRange > 255) {
00074 std::cerr << "ERROR: Max bucket range is > 2^8\n";
00075 exit(1);
00076 }
00077 qtizer_ = new LogQtizer(qBase);
00078 int valBits = (int)ceil(log2((float)qtizer_->maxcode()));
00079 std::cerr << "BITS FOR VALUES ARRAY = " << valBits << std::endl;
00080 uint64_t totalBits = memBound_ << 3;
00081 cells_ = (uint64_t) ceil((float)totalBits / (float)(cellWidth_ + valBits));
00082 cells_ += (cells_ % bucketRange_);
00083 totBuckets_ = (cells_ / bucketRange_) - 1;
00084 filter_ = new Filter<T>(cells_, cellWidth_);
00085 values_ = new Filter<T>(cells_, valBits);
00086 idxTracker_ = new uint8_t[totBuckets_];
00087 for(int i=0; i < totBuckets_; ++i) idxTracker_[i] = 0;
00088
00089 bucketHash_ = new UnivHash_linear<count_t>(totBuckets_, 1, PRIME);
00090 fingerHash_ = new UnivHash_linear<T>(pow(2.0f, cellWidth_), MAX_HASH_FUNCS, PRIME);
00091 }
00092
00093 template<typename T>
00094 PerfectHash<T>::~PerfectHash()
00095 {
00096 delete[] idxTracker_;
00097 delete filter_;
00098 filter_ = NULL;
00099 delete fingerHash_;
00100 delete bucketHash_;
00101 delete qtizer_;
00102 delete values_;
00103 }
00104
00105 template<typename T>
00106 uint64_t PerfectHash<T>::insert(const wordID_t* IDs, const int len,
00107 const count_t value)
00108 {
00109 count_t bucket = (bucketHash_->size() > 1 ? bucketHash_->hash(IDs, len, len) : bucketHash_->hash(IDs, len, 0));
00110 if(idxTracker_[bucket] < (int)bucketRange_) {
00111
00112 T fp = nonZeroSignature(IDs, len, (bucket % MAX_HASH_FUNCS));
00113 uint64_t emptyidx = cells_ + 1;
00114 uint64_t index = bucket * bucketRange_,
00115 lastrow = index + bucketRange_;
00116 while(index < lastrow) {
00117 T filterVal = filter_->read(index);
00118 if((filterVal == 0) && (emptyidx == cells_ + 1)) {
00119 emptyidx = index;
00120 } else if(filterVal == fp) {
00121 ++collisions_;
00122 dict_[hpDictKeyValue(IDs, len)] = value;
00123 return cells_ + 1;
00124 }
00125 ++index;
00126 }
00127 UTIL_THROW_IF2((emptyidx >= index) || (filter_->read(emptyidx) != 0), "Error");
00128 T code = (T)qtizer_->code(value);
00129 filter_->write(emptyidx, fp);
00130 values_->write(emptyidx, code);
00131 ++idxTracker_[bucket];
00132 return emptyidx;
00133 } else {
00134 dict_[hpDictKeyValue(IDs, len)] = value;
00135 return cells_ + 1;
00136 }
00137 }
00138
00139 template<typename T>
00140 bool PerfectHash<T>::update(const wordID_t* IDs, const int len,
00141 const count_t value, hpdEntry_t& hpdAddr, uint64_t& filterIdx)
00142 {
00143
00144 filterIdx = cells_ + 1;
00145 std::string skey = hpDictKeyValue(IDs, len);
00146 if((hpdAddr = dict_.find(skey)) != dict_.end()) {
00147 hpdAddr->second = value;
00148 return true;
00149 }
00150
00151
00152 count_t bucket = (bucketHash_->size() > 1 ? bucketHash_->hash(IDs, len, len) : bucketHash_->hash(IDs, len, 0));
00153
00154 T fp = nonZeroSignature(IDs, len, (bucket % MAX_HASH_FUNCS));
00155 uint64_t index = bucket * bucketRange_,
00156 lastrow = index + bucketRange_;
00157 while(index < lastrow) {
00158 T filterVal = filter_->read(index);
00159 if(filterVal == fp) {
00160 values_->write(index, (T)qtizer_->code(value));
00161 filterIdx = index;
00162 return true;
00163 }
00164 ++index;
00165 }
00166
00167 return false;
00168 }
00169
00170 template<typename T>
00171 int PerfectHash<T>::query(const wordID_t* IDs, const int len,
00172 hpdEntry_t& hpdAddr, uint64_t& filterIdx)
00173 {
00174
00175 std::string skey = hpDictKeyValue(IDs, len);
00176 if((hpdAddr = dict_.find(skey)) != dict_.end()) {
00177 filterIdx = cells_ + 1;
00178 return(hpdAddr->second);
00179 } else {
00180
00181
00182 count_t bucket = (bucketHash_->size() > 1 ? bucketHash_->hash(IDs, len, len) : bucketHash_->hash(IDs, len, 0));
00183
00184 T fp = nonZeroSignature(IDs, len, (bucket % MAX_HASH_FUNCS));
00185
00186 uint64_t index = bucket * bucketRange_,
00187 lastrow = index + bucketRange_;
00188 for(; index < lastrow; ++index) {
00189 if(filter_->read(index) == fp) {
00190
00191
00192 filterIdx = index;
00193 hpdAddr = dict_.end();
00194 return (int)qtizer_->value(values_->read(index));
00195 }
00196 }
00197 }
00198 return -1;
00199 }
00200
00201 template<typename T>
00202 void PerfectHash<T>::remove(const wordID_t* IDs, const int len)
00203 {
00204
00205 std::string skey = hpDictKeyValue(IDs, len);
00206 if(dict_.find(skey) != dict_.end())
00207 dict_.erase(skey);
00208 else {
00209
00210
00211 count_t bucket = (bucketHash_->size() > 1? bucketHash_->hash(IDs, len, len) : bucketHash_->hash(IDs, len, 0));
00212
00213 T fp = nonZeroSignature(IDs, len, (bucket % MAX_HASH_FUNCS));
00214
00215 uint64_t index = bucket * bucketRange_,
00216 lastrow = index + bucketRange_;
00217 for(; index < lastrow; ++index) {
00218 if(filter_->read(index) == fp) {
00219 filter_->write(index, 0);
00220 values_->write(index, 0);
00221 --idxTracker_[bucket];
00222 break;
00223 }
00224 }
00225 }
00226 }
00227
00228 template<typename T>
00229 void PerfectHash<T>::remove(uint64_t index)
00230 {
00231 UTIL_THROW_IF2(index >= cells_, "Out of bound: " << index);
00232 UTIL_THROW_IF2(filter_->read(index) == 0, "Error");
00233 filter_->write(index, 0);
00234 values_->write(index, 0);
00235
00236 count_t bucket = index / bucketRange_;
00237 --idxTracker_[bucket];
00238 }
00239
00240 template<typename T>
00241 T PerfectHash<T>::nonZeroSignature(const wordID_t* IDs, const int len,
00242 count_t bucket)
00243 {
00244 count_t h = bucket;
00245 T fingerprint(0);
00246 do {
00247 fingerprint = fingerHash_->hash(IDs, len, h);
00248 h += (h < fingerHash_->size() - 1 ? 1 : -h);
00249 } while((fingerprint == 0) && (h != bucket));
00250 if(fingerprint == 0)
00251 std::cerr << "WARNING: Unable to find non-zero signature for ngram\n" << std::endl;
00252 return fingerprint;
00253 }
00254
00255 template<typename T>
00256 std::string PerfectHash<T>::hpDictKeyValue(const wordID_t* IDs, const int len)
00257 {
00258 std::string skey(" ");
00259 for(int i = 0; i < len; ++i)
00260 skey += Utils::IntToStr(IDs[i]) + "¬";
00261 Utils::trim(skey);
00262 return skey;
00263 }
00264
00265 template<typename T>
00266 count_t PerfectHash<T>::hpDictMemUse()
00267 {
00268
00269 return (count_t) sizeof(hpDict_t::value_type)* dict_.size() >> 20;
00270 }
00271
00272 template<typename T>
00273 count_t PerfectHash<T>::bucketsMemUse()
00274 {
00275
00276 return (count_t) (filter_->size() + values_->size());
00277 }
00278
00279 template<typename T>
00280 void PerfectHash<T>::save(Moses::FileHandler* fout)
00281 {
00282 UTIL_THROW_IF2(fout == 0, "Invalid file handle");
00283 std::cerr << "\tSaving perfect hash parameters...\n";
00284 fout->write((char*)&hitMask_, sizeof(hitMask_));
00285 fout->write((char*)&memBound_, sizeof(memBound_));
00286 fout->write((char*)&cellWidth_, sizeof(cellWidth_));
00287 fout->write((char*)&cells_, sizeof(cells_));
00288 fout->write((char*)&totBuckets_, sizeof(totBuckets_));
00289 fout->write((char*)&bucketRange_, sizeof(bucketRange_));
00290 fout->write((char*)idxTracker_, totBuckets_ * sizeof(idxTracker_[0]));
00291 qtizer_->save(fout);
00292 std::cerr << "\tSaving hash functions...\n";
00293 fingerHash_->save(fout);
00294 bucketHash_->save(fout);
00295 std::cerr << "\tSaving bit filter...\n";
00296 filter_->save(fout);
00297 values_->save(fout);
00298 std::cerr << "\tSaving high performance dictionary...\n";
00299 count_t size = dict_.size();
00300 fout->write((char*)&size, sizeof(count_t));
00301 *fout << std::endl;
00302 iterate(dict_, t)
00303 *fout << t->first << "\t" << t->second << "\n";
00304 }
00305
00306 template<typename T>
00307 void PerfectHash<T>::load(Moses::FileHandler* fin)
00308 {
00309 UTIL_THROW_IF2(fin == 0, "Invalid file handle");
00310 std::cerr << "\tLoading perfect hash parameters...\n";
00311 fin->read((char*)&hitMask_, sizeof(hitMask_));
00312 fin->read((char*)&memBound_, sizeof(memBound_));
00313 fin->read((char*)&cellWidth_, sizeof(cellWidth_));
00314 fin->read((char*)&cells_, sizeof(cells_));
00315 fin->read((char*)&totBuckets_, sizeof(totBuckets_));
00316 fin->read((char*)&bucketRange_, sizeof(bucketRange_));
00317 idxTracker_ = new uint8_t[totBuckets_];
00318 fin->read((char*)idxTracker_, totBuckets_ * sizeof(idxTracker_[0]));
00319 qtizer_ = new LogQtizer(fin);
00320 std::cerr << "\tLoading hash functions...\n";
00321 fingerHash_ = new UnivHash_linear<T>(fin);
00322 bucketHash_ = new UnivHash_linear<count_t>(fin);
00323 std::cerr << "\tLoading bit filter...\n";
00324 filter_ = new Filter<T>(fin);
00325 values_ = new Filter<T>(fin);
00326 std::cerr << "\tLoading HPD...\n";
00327 count_t size = 0;
00328 fin->read((char*)&size, sizeof(count_t));
00329 fin->ignore(256, '\n');
00330 std::string line;
00331 hpDict_t::key_type key;
00332 hpDict_t::mapped_type val;
00333 for(count_t i=0; i < size; ++i) {
00334 getline(*fin, line);
00335 Utils::trim(line);
00336 std::istringstream ss(line.c_str());
00337 ss >> key, ss >> val;
00338 dict_[key] = val;
00339 }
00340 std::cerr << "\tHPD size=" << dict_.size() << std::endl;
00341 std::cerr << "Finished loading ORLM." << std::endl;
00342 }
00343
00344 template<typename T>
00345 void PerfectHash<T>::analyze()
00346 {
00347 std::cerr << "Analyzing Dynamic Bloomier Filter...\n";
00348
00349 uint8_t* bucketCnt = new uint8_t[totBuckets_];
00350 unsigned largestBucket = 0, totalCellsSet = 0,
00351 smallestBucket = bucketRange_, totalZeroes = 0;
00352 int curBucket = -1, fullBuckets(0);
00353 for(int i = 0; i < totBuckets_; ++i) bucketCnt[i] = 0;
00354 for(uint64_t i =0; i < cells_; ++i) {
00355 if(i % bucketRange_ == 0) ++curBucket;
00356 if(filter_->read(i) != 0) {
00357 ++bucketCnt[curBucket];
00358 ++totalCellsSet;
00359 } else ++totalZeroes;
00360 }
00361 count_t bi = 0, si = 0;
00362 for(int i = 0; i < totBuckets_; ++i) {
00363 if(bucketCnt[i] > largestBucket) {
00364 largestBucket = bucketCnt[i];
00365 bi = i;
00366 } else if(bucketCnt[i] < smallestBucket) {
00367 smallestBucket = bucketCnt[i];
00368 si = i;
00369 }
00370 }
00371 count_t trackerCells(0);
00372 for(int i = 0; i < totBuckets_; i++) {
00373 trackerCells += idxTracker_[i];
00374 if(idxTracker_[i] == bucketRange_)
00375 ++fullBuckets;
00376 }
00377 for(int i = 0; i < totBuckets_; ++i) {
00378 if(bucketCnt[i] != idxTracker_[i])
00379 std::cerr << "bucketCnt[" << i << "] = " << (int)bucketCnt[i] <<
00380 "\tidxTracker_[" << i << "] = " << (int)idxTracker_[i] << std::endl;
00381 }
00382 std::cerr << "total cells= " << cells_ << std::endl;
00383 std::cerr << "total buckets= " << totBuckets_ << std::endl;
00384 std::cerr << "bucket range= " << (int)bucketRange_ << std::endl;
00385 std::cerr << "fingerprint bits= " << cellWidth_ << std::endl;
00386 std::cerr << "total cells set= " << totalCellsSet;
00387 std::cerr << " (idxTracker set = " << trackerCells << ")" << std::endl;
00388 std::cerr << "total zeroes=" << totalZeroes;
00389 std::cerr << " (idxTracker zeros = " << cells_ - trackerCells << ")" << std::endl;
00390 std::cerr << "largest bucket (" << bi << ") size= " << largestBucket << std::endl;
00391 std::cerr << "smallest bucket (" << si << ") size= " << smallestBucket << std::endl;
00392 std::cerr << "last bucket size= " << (int)bucketCnt[totBuckets_ - 1] <<
00393 " (idxTracker last bucket size = " << (int)idxTracker_[totBuckets_ - 1] << ")" << std::endl;
00394 std::cerr << "total buckets full = " << fullBuckets << std::endl;
00395 std::cerr << "total collision errors= " << collisions_ << std::endl;
00396 std::cerr << "high performance dictionary size= " << dict_.size() << std::endl;
00397 std::cerr << "high performance dictionary MBs= " << hpDictMemUse() << std::endl;
00398 std::cerr << "filter MBs= " << filter_->size() << std::endl;
00399 std::cerr << "values MBs= " << values_->size() << std::endl;
00400 delete[] bucketCnt;
00401 }
00402
00403 template<typename T>
00404 bool PerfectHash<T>::update2(const wordID_t* IDs, const int len,
00405 const count_t value, hpdEntry_t& hpdAddr, uint64_t& filterIdx)
00406 {
00407
00408 filterIdx = cells_ + 1;
00409 std::string skey = hpDictKeyValue(IDs, len);
00410 if((hpdAddr = dict_.find(skey)) != dict_.end()) {
00411 hpdAddr->second += value;
00412 return true;
00413 }
00414
00415
00416 count_t bucket = (bucketHash_->size() > 1 ? bucketHash_->hash(IDs, len, len) : bucketHash_->hash(IDs, len, 0));
00417
00418 T fp = nonZeroSignature(IDs, len, (bucket % MAX_HASH_FUNCS));
00419 uint64_t index = bucket * bucketRange_,
00420 lastrow = index + bucketRange_;
00421 while(index < lastrow) {
00422 T filterVal = filter_->read(index);
00423 if(filterVal == fp) {
00424 int oldval = (int)qtizer_->value(values_->read(index));
00425 values_->write(index, (T)qtizer_->code(oldval + value));
00426 filterIdx = index;
00427 return true;
00428 }
00429 ++index;
00430 }
00431
00432 insert(IDs, len, value);
00433 return false;
00434 }
00435
00436 #endif
00437