00001
00002
00003
00004 #ifndef moses_PDTAimp_h
00005 #define moses_PDTAimp_h
00006
00007 #include "StaticData.h"
00008 #include "PhraseDictionaryTree.h"
00009 #include "UniqueObject.h"
00010 #include "InputFileStream.h"
00011 #include "PhraseDictionaryTreeAdaptor.h"
00012 #include "Util.h"
00013
00014 namespace Moses
00015 {
00016
/**
 * Adds two values given in the natural-log domain: returns log(exp(x)+exp(y)).
 *
 * The larger argument is factored out so that the exponent passed to exp()
 * is never positive. This keeps the result finite when the arguments are far
 * apart (exp() of a large positive difference would overflow to infinity) and
 * makes a -infinity argument (log of zero) act as the neutral element instead
 * of producing -inf + inf = NaN.
 *
 * @param x first log-domain value
 * @param y second log-domain value
 * @return log(exp(x)+exp(y))
 */
inline double addLogScale(double x,double y)
{
  const double hi = (x>y) ? x : y;
  const double lo = (x>y) ? y : x;
  // lo-hi <= 0, hence exp(lo-hi) lies in [0,1] and cannot overflow
  return hi+log(1.0+exp(lo-hi));
}
00022
/**
 * Plain-function wrapper around exp() for double arguments.
 *
 * Having a single, non-overloaded function makes it usable as the unary
 * operation of std::transform, which is how this file uses it when printing
 * log-domain statistics.
 *
 * @param x exponent
 * @return e raised to the power x
 */
inline double Exp(double x)
{
  const double result = exp(x);
  return result;
}
00027
00028 class PDTAimp
00029 {
00030
00031 friend class PhraseDictionaryTreeAdaptor;
00032
00033 protected:
00034 PDTAimp(PhraseDictionaryTreeAdaptor *p,unsigned nis)
00035 : m_languageModels(0),m_weightWP(0.0),m_dict(0),
00036 m_obj(p),useCache(1),m_numInputScores(nis),totalE(0),distinctE(0) {}
00037
00038 public:
00039 std::vector<float> m_weights;
00040 LMList const* m_languageModels;
00041 float m_weightWP;
00042 std::vector<FactorType> m_input,m_output;
00043 PhraseDictionaryTree *m_dict;
00044 typedef std::vector<TargetPhraseCollection const*> vTPC;
00045 mutable vTPC m_tgtColls;
00046
00047 typedef std::map<Phrase,TargetPhraseCollection const*> MapSrc2Tgt;
00048 mutable MapSrc2Tgt m_cache;
00049 PhraseDictionaryTreeAdaptor *m_obj;
00050 int useCache;
00051
00052 std::vector<vTPC> m_rangeCache;
00053 unsigned m_numInputScores;
00054
00055 UniqueObjectManager<Phrase> uniqSrcPhr;
00056
00057 size_t totalE,distinctE;
00058 std::vector<size_t> path1Best,pathExplored;
00059 std::vector<double> pathCN;
00060
00061 ~PDTAimp() {
00062 CleanUp();
00063 delete m_dict;
00064
00065 if (StaticData::Instance().GetVerboseLevel() >= 2) {
00066
00067 TRACE_ERR("tgt candidates stats: total="<<totalE<<"; distinct="
00068 <<distinctE<<" ("<<distinctE/(0.01*totalE)<<"); duplicates="
00069 <<totalE-distinctE<<" ("<<(totalE-distinctE)/(0.01*totalE)
00070 <<")\n");
00071
00072 TRACE_ERR("\npath statistics\n");
00073
00074 if(path1Best.size()) {
00075 TRACE_ERR("1-best: ");
00076 std::copy(path1Best.begin()+1,path1Best.end(),
00077 std::ostream_iterator<size_t>(std::cerr," \t"));
00078 TRACE_ERR("\n");
00079 }
00080 if(pathCN.size()) {
00081 TRACE_ERR("CN (full): ");
00082 std::transform(pathCN.begin()+1
00083 ,pathCN.end()
00084 ,std::ostream_iterator<double>(std::cerr," \t")
00085 ,Exp);
00086 TRACE_ERR("\n");
00087 }
00088 if(pathExplored.size()) {
00089 TRACE_ERR("CN (explored): ");
00090 std::copy(pathExplored.begin()+1,pathExplored.end(),
00091 std::ostream_iterator<size_t>(std::cerr," \t"));
00092 TRACE_ERR("\n");
00093 }
00094 }
00095
00096 }
00097
00098 void Factors2String(Word const& w,std::string& s) const {
00099 s=w.GetString(m_input,false);
00100 }
00101
00102 void CleanUp() {
00103 CHECK(m_dict);
00104 m_dict->FreeMemory();
00105 for(size_t i=0; i<m_tgtColls.size(); ++i) delete m_tgtColls[i];
00106 m_tgtColls.clear();
00107 m_cache.clear();
00108 m_rangeCache.clear();
00109 uniqSrcPhr.clear();
00110 }
00111
00112 TargetPhraseCollection const*
00113 GetTargetPhraseCollection(Phrase const &src) const {
00114
00115 CHECK(m_dict);
00116 if(src.GetSize()==0) return 0;
00117
00118 std::pair<MapSrc2Tgt::iterator,bool> piter;
00119 if(useCache) {
00120 piter=m_cache.insert(std::make_pair(src,static_cast<TargetPhraseCollection const*>(0)));
00121 if(!piter.second) return piter.first->second;
00122 } else if (m_cache.size()) {
00123 MapSrc2Tgt::const_iterator i=m_cache.find(src);
00124 return (i!=m_cache.end() ? i->second : 0);
00125 }
00126
00127 std::vector<std::string> srcString(src.GetSize());
00128
00129 for(size_t i=0; i<srcString.size(); ++i) {
00130 Factors2String(src.GetWord(i),srcString[i]);
00131 }
00132
00133
00134 std::vector<StringTgtCand> cands;
00135 std::vector<std::string> wacands;
00136 m_dict->GetTargetCandidates(srcString,cands,wacands);
00137 if(cands.empty()) {
00138 return 0;
00139 }
00140
00141 std::vector<TargetPhrase> tCands;
00142 tCands.reserve(cands.size());
00143 std::vector<std::pair<float,size_t> > costs;
00144 costs.reserve(cands.size());
00145
00146
00147 for(size_t i=0; i<cands.size(); ++i) {
00148 TargetPhrase targetPhrase(Output);
00149
00150 StringTgtCand::first_type const& factorStrings=cands[i].first;
00151 StringTgtCand::second_type const& probVector=cands[i].second;
00152
00153 std::vector<float> scoreVector(probVector.size());
00154 std::transform(probVector.begin(),probVector.end(),scoreVector.begin(),
00155 TransformScore);
00156 std::transform(scoreVector.begin(),scoreVector.end(),scoreVector.begin(),
00157 FloorScore);
00158
00159 CreateTargetPhrase(targetPhrase,factorStrings,scoreVector,wacands[i],&src);
00160 costs.push_back(std::make_pair(-targetPhrase.GetFutureScore(),tCands.size()));
00161 tCands.push_back(targetPhrase);
00162 }
00163
00164 TargetPhraseCollection *rv;
00165 rv=PruneTargetCandidates(tCands,costs);
00166 if(rv->IsEmpty()) {
00167 delete rv;
00168 return 0;
00169 } else {
00170 if(useCache) piter.first->second=rv;
00171 m_tgtColls.push_back(rv);
00172 return rv;
00173 }
00174
00175 }
00176
00177
00178
00179 void Create(const std::vector<FactorType> &input
00180 , const std::vector<FactorType> &output
00181 , const std::string &filePath
00182 , const std::vector<float> &weight
00183 , const LMList &languageModels
00184 , float weightWP
00185 ) {
00186
00187
00188 m_dict=new PhraseDictionaryTree(weight.size()-m_numInputScores);
00189 m_input=input;
00190 m_output=output;
00191 m_languageModels=&languageModels;
00192 m_weightWP=weightWP;
00193 m_weights=weight;
00194
00195 const StaticData &staticData = StaticData::Instance();
00196 m_dict->UseWordAlignment(staticData.UseAlignmentInfo());
00197
00198 std::string binFname=filePath+".binphr.idx";
00199 if(!FileExists(binFname.c_str())) {
00200 UserMessage::Add( "bin ttable does not exist\n");
00201 abort();
00202
00203
00204
00205 }
00206 TRACE_ERR( "reading bin ttable\n");
00207
00208 bool res=m_dict->Read(filePath);
00209 if (!res) {
00210 std::stringstream strme;
00211 strme << "bin ttable was read in a wrong way\n";
00212 UserMessage::Add(strme.str());
00213 exit(1);
00214 }
00215 }
00216
00217 typedef PhraseDictionaryTree::PrefixPtr PPtr;
00218 typedef unsigned short Position;
00219 typedef std::pair<Position,Position> Range;
00220 struct State {
00221 PPtr ptr;
00222 Range range;
00223 std::vector<float> scores;
00224 Phrase src;
00225
00226 State() : range(0,0),scores(0),src(ARRAY_SIZE_INCR) {}
00227 State(Position b,Position e,const PPtr& v,const std::vector<float>& sv=std::vector<float>(0))
00228 : ptr(v),range(b,e),scores(sv),src(ARRAY_SIZE_INCR) {}
00229 State(Range const& r,const PPtr& v,const std::vector<float>& sv=std::vector<float>(0))
00230 : ptr(v),range(r),scores(sv),src(ARRAY_SIZE_INCR) {}
00231
00232 Position begin() const {
00233 return range.first;
00234 }
00235 Position end() const {
00236 return range.second;
00237 }
00238 std::vector<float> GetScores() const {
00239 return scores;
00240 }
00241
00242 friend std::ostream& operator<<(std::ostream& out,State const& s) {
00243 out<<" R=("<<s.begin()<<","<<s.end()<<"),";
00244 for(std::vector<float>::const_iterator scoreIterator = s.GetScores().begin(); scoreIterator<s.GetScores().end(); scoreIterator++) {
00245 out<<", "<<*scoreIterator;
00246 }
00247 out<<")";
00248 return out;
00249 }
00250
00251 };
00252
00253 void CreateTargetPhrase(TargetPhrase& targetPhrase,
00254 StringTgtCand::first_type const& factorStrings,
00255 StringTgtCand::second_type const& scoreVector,
00256 const std::string& alignmentString,
00257 Phrase const* srcPtr=0) const {
00258 CreateTargetPhrase(targetPhrase, factorStrings, scoreVector, srcPtr);
00259 targetPhrase.SetAlignmentInfo(alignmentString);
00260 }
00261
00262
00263 void CreateTargetPhrase(TargetPhrase& targetPhrase,
00264 StringTgtCand::first_type const& factorStrings,
00265 StringTgtCand::second_type const& scoreVector,
00266 Phrase const* srcPtr=0) const {
00267 FactorCollection &factorCollection = FactorCollection::Instance();
00268
00269 for(size_t k=0; k<factorStrings.size(); ++k) {
00270 std::vector<std::string> factors=TokenizeMultiCharSeparator(*factorStrings[k],StaticData::Instance().GetFactorDelimiter());
00271 CHECK(factors.size()==m_output.size());
00272 Word& w=targetPhrase.AddWord();
00273 for(size_t l=0; l<m_output.size(); ++l) {
00274 w[m_output[l]]= factorCollection.AddFactor(Output, m_output[l], factors[l]);
00275 }
00276 }
00277 targetPhrase.SetScore(m_obj->GetFeature(), scoreVector, m_weights, m_weightWP, *m_languageModels);
00278 targetPhrase.SetSourcePhrase(srcPtr);
00279 }
00280
00281
00282
00283
00284 TargetPhraseCollection* PruneTargetCandidates(std::vector<TargetPhrase> const & tCands,
00285 std::vector<std::pair<float,size_t> >& costs) const {
00286
00287 TargetPhraseCollection *rv=new TargetPhraseCollection;
00288
00289
00290 std::vector<std::pair<float,size_t> >::iterator nth =
00291 costs.begin() + ((m_obj->m_tableLimit>0 &&
00292 m_obj->m_tableLimit < costs.size()) ?
00293 m_obj->m_tableLimit : costs.size());
00294
00295
00296 std::nth_element(costs.begin(),nth ,costs.end());
00297
00298
00299 for(std::vector<std::pair<float,size_t> >::iterator
00300 it = costs.begin(); it != nth; ++it)
00301 rv->Add(new TargetPhrase(tCands[it->second]));
00302
00303 return rv;
00304 }
00305
00306
00307 struct TScores {
00308 float total;
00309 StringTgtCand::second_type trans;
00310 Phrase const* src;
00311
00312 TScores() : total(0.0),src(0) {}
00313 };
00314
00315 void CacheSource(ConfusionNet const& src) {
00316 CHECK(m_dict);
00317 const size_t srcSize=src.GetSize();
00318
00319 std::vector<size_t> exploredPaths(srcSize+1,0);
00320 std::vector<double> exPathsD(srcSize+1,-1.0);
00321
00322
00323 std::vector<size_t> cnDepths(srcSize,0);
00324 for(size_t i=0; i<srcSize; ++i) cnDepths[i]=src[i].size();
00325
00326 for(size_t len=1; len<=srcSize; ++len)
00327 for(size_t i=0; i<=srcSize-len; ++i) {
00328 double pd=0.0;
00329 for(size_t k=i; k<i+len; ++k) pd+=log(1.0*cnDepths[k]);
00330 exPathsD[len]=(exPathsD[len]>=0.0 ? addLogScale(pd,exPathsD[len]) : pd);
00331 }
00332
00333
00334 if(pathCN.size()<=srcSize) pathCN.resize(srcSize+1,-1.0);
00335 for(size_t len=1; len<=srcSize; ++len)
00336 pathCN[len]=pathCN[len]>=0.0 ? addLogScale(pathCN[len],exPathsD[len]) : exPathsD[len];
00337
00338 if(path1Best.size()<=srcSize) path1Best.resize(srcSize+1,0);
00339 for(size_t len=1; len<=srcSize; ++len) path1Best[len]+=srcSize-len+1;
00340
00341
00342 if (StaticData::Instance().GetVerboseLevel() >= 2 && exPathsD.size()) {
00343 TRACE_ERR("path stats for current CN: \nCN (full): ");
00344 std::transform(exPathsD.begin()+1
00345 ,exPathsD.end()
00346 ,std::ostream_iterator<double>(std::cerr," ")
00347 ,Exp);
00348 TRACE_ERR("\n");
00349 }
00350
00351 typedef StringTgtCand::first_type sPhrase;
00352 typedef std::map<StringTgtCand::first_type,TScores> E2Costs;
00353
00354 std::map<Range,E2Costs> cov2cand;
00355 std::vector<State> stack;
00356 for(Position i=0 ; i < srcSize ; ++i)
00357 stack.push_back(State(i, i, m_dict->GetRoot(), std::vector<float>(m_numInputScores,0.0)));
00358
00359 while(!stack.empty()) {
00360 State curr(stack.back());
00361 stack.pop_back();
00362
00363 CHECK(curr.end()<srcSize);
00364 const ConfusionNet::Column &currCol=src[curr.end()];
00365
00366 for(size_t colidx=0; colidx<currCol.size(); ++colidx) {
00367 const Word& w=currCol[colidx].first;
00368 std::string s;
00369 Factors2String(w,s);
00370 bool isEpsilon=(s=="" || s==EPSILON);
00371
00372
00373 CHECK(currCol[colidx].second.size() >= m_numInputScores);
00374
00375
00376 if(isEpsilon && curr.begin()==curr.end() && curr.begin()>0) continue;
00377
00378
00379
00380 PPtr nextP = (isEpsilon ? curr.ptr : m_dict->Extend(curr.ptr,s));
00381
00382 if(nextP) {
00383 Range newRange(curr.begin(),curr.end()+src.GetColumnIncrement(curr.end(),colidx));
00384
00385
00386 float inputScoreSum = 0;
00387 std::vector<float> newInputScores(m_numInputScores,0.0);
00388 if (m_numInputScores) {
00389 std::transform(currCol[colidx].second.begin(), currCol[colidx].second.end(),
00390 curr.GetScores().begin(),
00391 newInputScores.begin(),
00392 std::plus<float>());
00393
00394
00395
00396
00397
00398 inputScoreSum = std::accumulate(newInputScores.begin(),newInputScores.begin()+m_numInputScores,0.0);
00399 }
00400
00401 Phrase newSrc(curr.src);
00402 if(!isEpsilon) newSrc.AddWord(w);
00403 if(newRange.second<srcSize && inputScoreSum>LOWEST_SCORE) {
00404
00405
00406 stack.push_back(State(newRange,nextP,newInputScores));
00407 stack.back().src=newSrc;
00408 }
00409
00410 std::vector<StringTgtCand> tcands;
00411
00412
00413 m_dict->GetTargetCandidates(nextP,tcands);
00414
00415 if(newRange.second>=exploredPaths.size()+newRange.first)
00416 exploredPaths.resize(newRange.second-newRange.first+1,0);
00417 ++exploredPaths[newRange.second-newRange.first];
00418
00419 totalE+=tcands.size();
00420
00421 if(tcands.size()) {
00422 E2Costs& e2costs=cov2cand[newRange];
00423 Phrase const* srcPtr=uniqSrcPhr(newSrc);
00424 for(size_t i=0; i<tcands.size(); ++i) {
00425
00426 std::vector<float> nscores(newInputScores);
00427
00428
00429 nscores.resize(m_numInputScores+tcands[i].second.size(),0.0f);
00430
00431
00432 std::transform(tcands[i].second.begin(),tcands[i].second.end(),nscores.begin() + m_numInputScores,TransformScore);
00433
00434 CHECK(nscores.size()==m_weights.size());
00435
00436
00437 float score=std::inner_product(nscores.begin(), nscores.end(), m_weights.begin(), 0.0f);
00438
00439
00440 score-=tcands[i].first.size() * m_weightWP;
00441
00442 std::pair<E2Costs::iterator,bool> p=e2costs.insert(std::make_pair(tcands[i].first,TScores()));
00443
00444 if(p.second) ++distinctE;
00445
00446 TScores & scores=p.first->second;
00447 if(p.second || scores.total<score) {
00448 scores.total=score;
00449 scores.trans=nscores;
00450 scores.src=srcPtr;
00451 }
00452 }
00453 }
00454 }
00455 }
00456 }
00457
00458
00459 if (StaticData::Instance().GetVerboseLevel() >= 2 && exploredPaths.size()) {
00460 TRACE_ERR("CN (explored): ");
00461 std::copy(exploredPaths.begin()+1,exploredPaths.end(),
00462 std::ostream_iterator<size_t>(std::cerr," "));
00463 TRACE_ERR("\n");
00464 }
00465
00466 if(pathExplored.size()<exploredPaths.size())
00467 pathExplored.resize(exploredPaths.size(),0);
00468 for(size_t len=1; len<=srcSize; ++len)
00469 pathExplored[len]+=exploredPaths[len];
00470
00471
00472 m_rangeCache.resize(src.GetSize(),vTPC(src.GetSize(),0));
00473
00474 for(std::map<Range,E2Costs>::const_iterator i=cov2cand.begin(); i!=cov2cand.end(); ++i) {
00475 CHECK(i->first.first<m_rangeCache.size());
00476 CHECK(i->first.second>0);
00477 CHECK(static_cast<size_t>(i->first.second-1)<m_rangeCache[i->first.first].size());
00478 CHECK(m_rangeCache[i->first.first][i->first.second-1]==0);
00479
00480 std::vector<TargetPhrase> tCands;
00481 tCands.reserve(i->second.size());
00482 std::vector<std::pair<float,size_t> > costs;
00483 costs.reserve(i->second.size());
00484
00485 for(E2Costs::const_iterator j=i->second.begin(); j!=i->second.end(); ++j) {
00486 TScores const & scores=j->second;
00487 TargetPhrase targetPhrase(Output);
00488 CreateTargetPhrase(targetPhrase,j->first,scores.trans,scores.src);
00489 costs.push_back(std::make_pair(-targetPhrase.GetFutureScore(),tCands.size()));
00490 tCands.push_back(targetPhrase);
00491
00492 }
00493
00494 TargetPhraseCollection *rv=PruneTargetCandidates(tCands,costs);
00495
00496 if(rv->IsEmpty())
00497 delete rv;
00498 else {
00499 m_rangeCache[i->first.first][i->first.second-1]=rv;
00500 m_tgtColls.push_back(rv);
00501 }
00502 }
00503
00504 m_dict->FreeMemory();
00505 }
00506
00507
00508 size_t GetNumInputScores() const {
00509 return m_numInputScores;
00510 }
00511 };
00512
00513 }
00514 #endif