00001
00002 #include "ProbingPT.h"
00003 #include "moses/StaticData.h"
00004 #include "moses/FactorCollection.h"
00005 #include "moses/TargetPhraseCollection.h"
00006 #include "moses/InputFileStream.h"
00007 #include "probingpt/querying.h"
00008 #include "probingpt/probing_hash_utils.h"
00009
00010 using namespace std;
00011
00012 namespace Moses
00013 {
00014 ProbingPT::ProbingPT(const std::string &line)
00015 : PhraseDictionary(line,true)
00016 ,m_engine(NULL)
00017 ,load_method(util::POPULATE_OR_READ)
00018 {
00019 ReadParameters();
00020
00021 assert(m_input.size() == 1);
00022 assert(m_output.size() == 1);
00023 }
00024
00025 ProbingPT::~ProbingPT()
00026 {
00027 delete m_engine;
00028 }
00029
00030 void ProbingPT::Load(AllOptions::ptr const& opts)
00031 {
00032 m_options = opts;
00033 SetFeaturesToApply();
00034
00035 m_engine = new probingpt::QueryEngine(m_filePath.c_str(), load_method);
00036
00037 m_unkId = 456456546456;
00038
00039 FactorCollection &vocab = FactorCollection::Instance();
00040
00041
00042 const std::map<uint64_t, std::string> &sourceVocab =
00043 m_engine->getSourceVocab();
00044 std::map<uint64_t, std::string>::const_iterator iterSource;
00045 for (iterSource = sourceVocab.begin(); iterSource != sourceVocab.end();
00046 ++iterSource) {
00047 string wordStr = iterSource->second;
00048
00049
00050 const Factor *factor = vocab.AddFactor(wordStr);
00051
00052 uint64_t probingId = iterSource->first;
00053 size_t factorId = factor->GetId();
00054
00055 if (factorId >= m_sourceVocab.size()) {
00056 m_sourceVocab.resize(factorId + 1, m_unkId);
00057 }
00058 m_sourceVocab[factorId] = probingId;
00059 }
00060
00061
00062 InputFileStream targetVocabStrme(m_filePath + "/TargetVocab.dat");
00063 string line;
00064 while (getline(targetVocabStrme, line)) {
00065 vector<string> toks = Tokenize(line, "\t");
00066 UTIL_THROW_IF2(toks.size() != 2, string("Incorrect format:") + line + "\n");
00067
00068
00069
00070 const Factor *factor = vocab.AddFactor(toks[0]);
00071 uint32_t probingId = Scan<uint32_t>(toks[1]);
00072
00073 if (probingId >= m_targetVocab.size()) {
00074 m_targetVocab.resize(probingId + 1);
00075 }
00076
00077 m_targetVocab[probingId] = factor;
00078 }
00079
00080
00081 CreateAlignmentMap(m_filePath + "/Alignments.dat");
00082
00083
00084 string filePath = m_filePath + "/TargetColl.dat";
00085 file.open(filePath.c_str());
00086 if (!file.is_open()) {
00087 throw "Couldn't open file ";
00088 }
00089
00090 data = file.data();
00091
00092
00093
00094
00095
00096 }
00097
00098 void ProbingPT::CreateAlignmentMap(const std::string path)
00099 {
00100 const std::vector< std::vector<unsigned char> > &probingAlignColl = m_engine->getAlignments();
00101 m_aligns.resize(probingAlignColl.size(), NULL);
00102
00103 for (size_t i = 0; i < probingAlignColl.size(); ++i) {
00104 AlignmentInfo::CollType aligns;
00105
00106 const std::vector<unsigned char> &probingAligns = probingAlignColl[i];
00107 for (size_t j = 0; j < probingAligns.size(); j += 2) {
00108 size_t startPos = probingAligns[j];
00109 size_t endPos = probingAligns[j+1];
00110
00111 aligns.insert(std::pair<size_t,size_t>(startPos, endPos));
00112 }
00113
00114 const AlignmentInfo *align = AlignmentInfoCollection::Instance().Add(aligns);
00115 m_aligns[i] = align;
00116
00117 }
00118 }
00119
00120 void ProbingPT::SetParameter(const std::string& key, const std::string& value)
00121 {
00122 if (key == "load") {
00123 if (value == "lazy") {
00124 load_method = util::LAZY;
00125 } else if (value == "populate_or_lazy") {
00126 load_method = util::POPULATE_OR_LAZY;
00127 } else if (value == "populate_or_read" || value == "populate") {
00128 load_method = util::POPULATE_OR_READ;
00129 } else if (value == "read") {
00130 load_method = util::READ;
00131 } else if (value == "parallel_read") {
00132 load_method = util::PARALLEL_READ;
00133 } else {
00134 UTIL_THROW2("load method not supported" << value);
00135 }
00136 } else {
00137 PhraseDictionary::SetParameter(key, value);
00138 }
00139
00140 }
00141
00142 void ProbingPT::InitializeForInput(ttasksptr const& ttask)
00143 {
00144
00145 }
00146
00147 void ProbingPT::GetTargetPhraseCollectionBatch(const InputPathList &inputPathQueue) const
00148 {
00149 InputPathList::const_iterator iter;
00150 for (iter = inputPathQueue.begin(); iter != inputPathQueue.end(); ++iter) {
00151 InputPath &inputPath = **iter;
00152 const Phrase &sourcePhrase = inputPath.GetPhrase();
00153
00154 if (sourcePhrase.GetSize() > m_options->search.max_phrase_length) {
00155 continue;
00156 }
00157
00158 TargetPhraseCollection::shared_ptr tpColl = CreateTargetPhrase(sourcePhrase);
00159 inputPath.SetTargetPhrases(*this, tpColl, NULL);
00160 }
00161 }
00162
00163 TargetPhraseCollection::shared_ptr ProbingPT::CreateTargetPhrase(const Phrase &sourcePhrase) const
00164 {
00165
00166 assert(sourcePhrase.GetSize());
00167
00168 std::pair<bool, uint64_t> keyStruct = GetKey(sourcePhrase);
00169 if (!keyStruct.first) {
00170 return TargetPhraseCollection::shared_ptr();
00171 }
00172
00173
00174 CachePb::const_iterator iter = m_cachePb.find(keyStruct.second);
00175 if (iter != m_cachePb.end()) {
00176
00177 TargetPhraseCollection *tps = iter->second;
00178 return TargetPhraseCollection::shared_ptr(tps);
00179 }
00180
00181
00182 TargetPhraseCollection *tps = CreateTargetPhrases(sourcePhrase,
00183 keyStruct.second);
00184 return TargetPhraseCollection::shared_ptr(tps);
00185 }
00186
00187 std::pair<bool, uint64_t> ProbingPT::GetKey(const Phrase &sourcePhrase) const
00188 {
00189 std::pair<bool, uint64_t> ret;
00190
00191
00192 size_t sourceSize = sourcePhrase.GetSize();
00193 assert(sourceSize);
00194
00195 uint64_t probingSource[sourceSize];
00196 GetSourceProbingIds(sourcePhrase, ret.first, probingSource);
00197 if (!ret.first) {
00198
00199
00200 } else {
00201 ret.second = m_engine->getKey(probingSource, sourceSize);
00202 }
00203
00204 return ret;
00205
00206 }
00207
00208 void ProbingPT::GetSourceProbingIds(const Phrase &sourcePhrase,
00209 bool &ok, uint64_t probingSource[]) const
00210 {
00211
00212 size_t size = sourcePhrase.GetSize();
00213 for (size_t i = 0; i < size; ++i) {
00214 const Word &word = sourcePhrase.GetWord(i);
00215 uint64_t probingId = GetSourceProbingId(word);
00216 if (probingId == m_unkId) {
00217 ok = false;
00218 return;
00219 } else {
00220 probingSource[i] = probingId;
00221 }
00222 }
00223
00224 ok = true;
00225 }
00226
00227 uint64_t ProbingPT::GetSourceProbingId(const Word &word) const
00228 {
00229 uint64_t ret = 0;
00230
00231 for (size_t i = 0; i < m_input.size(); ++i) {
00232 FactorType factorType = m_input[i];
00233 const Factor *factor = word[factorType];
00234
00235 size_t factorId = factor->GetId();
00236 if (factorId >= m_sourceVocab.size()) {
00237 return m_unkId;
00238 }
00239 ret += m_sourceVocab[factorId];
00240 }
00241
00242 return ret;
00243 }
00244
00245 TargetPhraseCollection *ProbingPT::CreateTargetPhrases(
00246 const Phrase &sourcePhrase, uint64_t key) const
00247 {
00248 TargetPhraseCollection *tps = NULL;
00249
00250
00251 std::pair<bool, uint64_t> query_result;
00252 query_result = m_engine->query(key);
00253
00254
00255 if (query_result.first) {
00256 const char *offset = data + query_result.second;
00257 uint64_t *numTP = (uint64_t*) offset;
00258
00259 tps = new TargetPhraseCollection();
00260
00261 offset += sizeof(uint64_t);
00262 for (size_t i = 0; i < *numTP; ++i) {
00263 TargetPhrase *tp = CreateTargetPhrase(offset);
00264 assert(tp);
00265 tp->EvaluateInIsolation(sourcePhrase, GetFeaturesToApply());
00266
00267 tps->Add(tp);
00268
00269 }
00270
00271 tps->Prune(true, m_tableLimit);
00272
00273 }
00274
00275 return tps;
00276
00277 }
00278
00279 TargetPhrase *ProbingPT::CreateTargetPhrase(
00280 const char *&offset) const
00281 {
00282 probingpt::TargetPhraseInfo *tpInfo = (probingpt::TargetPhraseInfo*) offset;
00283 size_t numRealWords = tpInfo->numWords / m_output.size();
00284
00285 TargetPhrase *tp = new TargetPhrase(this);
00286
00287 offset += sizeof(probingpt::TargetPhraseInfo);
00288
00289
00290 float *scores = (float*) offset;
00291
00292 size_t totalNumScores = m_engine->num_scores + m_engine->num_lex_scores;
00293
00294 if (m_engine->logProb) {
00295
00296 tp->GetScoreBreakdown().PlusEquals(this, scores);
00297
00298
00299
00300
00301
00302
00303
00304 } else {
00305
00306 float logScores[totalNumScores];
00307 for (size_t i = 0; i < totalNumScores; ++i) {
00308 logScores[i] = FloorScore(TransformScore(scores[i]));
00309 }
00310
00311
00312 tp->GetScoreBreakdown().PlusEquals(this, logScores);
00313
00314
00315
00316
00317
00318
00319
00320
00321 }
00322
00323 offset += sizeof(float) * totalNumScores;
00324
00325
00326 for (size_t targetPos = 0; targetPos < numRealWords; ++targetPos) {
00327 Word &word = tp->AddWord();
00328 for (size_t i = 0; i < m_output.size(); ++i) {
00329 FactorType factorType = m_output[i];
00330
00331 uint32_t *probingId = (uint32_t*) offset;
00332
00333 const Factor *factor = GetTargetFactor(*probingId);
00334 assert(factor);
00335
00336 word[factorType] = factor;
00337
00338 offset += sizeof(uint32_t);
00339 }
00340 }
00341
00342
00343 uint32_t alignTerm = tpInfo->alignTerm;
00344
00345 UTIL_THROW_IF2(alignTerm >= m_aligns.size(), "Unknown alignInd");
00346 tp->SetAlignTerm(m_aligns[alignTerm]);
00347
00348
00349
00350 return tp;
00351 }
00352
00354
00355
00356 ChartRuleLookupManager *ProbingPT::CreateRuleLookupManager(
00357 const ChartParser &,
00358 const ChartCellCollectionBase &,
00359 std::size_t)
00360 {
00361 abort();
00362 return NULL;
00363 }
00364
00365 TO_STRING_BODY(ProbingPT);
00366
00367
00368 ostream& operator<<(ostream& out, const ProbingPT& phraseDict)
00369 {
00370 return out;
00371 }
00372
00373 }