00001
00002
00003 #include "mmsapt.h"
00004 #include <boost/foreach.hpp>
00005 #include <boost/scoped_ptr.hpp>
00006 #include <boost/intrusive_ptr.hpp>
00007 #include <boost/tokenizer.hpp>
00008 #include <boost/thread/locks.hpp>
00009 #include <algorithm>
00010 #include "util/exception.hh"
00011 #include <set>
00012 #include "util/usage.hh"
00013
00014 namespace Moses
00015 {
00016 using namespace sapt;
00017 using namespace std;
00018 using namespace boost;
00019
00020 void
00021 fillIdSeq(Phrase const& mophrase, std::vector<FactorType> const& ifactors,
00022 TokenIndex const& V, vector<id_type>& dest)
00023 {
00024 dest.resize(mophrase.GetSize());
00025 for (size_t i = 0; i < mophrase.GetSize(); ++i)
00026 {
00027
00028 dest[i] = V[mophrase.GetWord(i).GetString(ifactors, false)];
00029 }
00030 }
00031
00032 void
00033 parseLine(string const& line, map<string,string> & param)
00034 {
00035 char_separator<char> sep("; ");
00036 tokenizer<char_separator<char> > tokens(line,sep);
00037 BOOST_FOREACH(string const& t,tokens)
00038 {
00039 size_t i = t.find_first_not_of(" =");
00040 size_t j = t.find_first_of(" =",i+1);
00041 size_t k = t.find_first_not_of(" =",j+1);
00042 UTIL_THROW_IF2(i == string::npos || k == string::npos,
00043 "[" << HERE << "] "
00044 << "Parameter specification error near '"
00045 << t << "' in moses ini line\n"
00046 << line);
00047 assert(i != string::npos);
00048 assert(k != string::npos);
00049 param[t.substr(i,j)] = t.substr(k);
00050 }
00051 }
00052
00053 vector<string> const&
00054 Mmsapt::
00055 GetFeatureNames() const
00056 {
00057 return m_feature_names;
00058 }
00059
00060 Mmsapt::
00061 Mmsapt(string const& line)
00062 : PhraseDictionary(line, false)
00063 , btfix(new mmbitext)
00064 , m_bias_log(NULL)
00065 , m_bias_loglevel(0)
00066 #ifndef NO_MOSES
00067 , m_lr_func(NULL)
00068 #endif
00069 , m_sampling_method(random_sampling)
00070 , bias_key(((char*)this)+3)
00071 , cache_key(((char*)this)+2)
00072 , context_key(((char*)this)+1)
00073 , m_track_coord(false)
00074
00075
00076
00077 {
00078 init(line);
00079 setup_local_feature_functions();
00080
00081
00082
00083
00084
00085 SetFeaturesToApply();
00086
00087 }
00088
00089 void
00090 Mmsapt::
00091 read_config_file(string fname, map<string,string>& param)
00092 {
00093 string line;
00094 ifstream config(fname.c_str());
00095 while (getline(config,line))
00096 {
00097 if (line[0] == '#') continue;
00098 char_separator<char> sep(" \t");
00099 tokenizer<char_separator<char> > tokens(line,sep);
00100 tokenizer<char_separator<char> >::const_iterator t = tokens.begin();
00101 if (t == tokens.end()) continue;
00102 string& foo = param[*t++];
00103 if (t == tokens.end() || foo.size()) continue;
00104
00105 UTIL_THROW_IF2(*t++ != "=" || t == tokens.end(),
00106 "Syntax error in Mmsapt config file '" << fname << "'.");
00107 for (foo = *t++; t != tokens.end(); foo += " " + *t++);
00108 }
00109 }
00110
00111 void
00112 Mmsapt::
00113 register_ff(SPTR<pscorer> const& ff, vector<SPTR<pscorer> > & registry)
00114 {
00115 registry.push_back(ff);
00116 ff->setIndex(m_feature_names.size());
00117 for (int i = 0; i < ff->fcnt(); ++i)
00118 {
00119 m_feature_names.push_back(ff->fname(i));
00120 m_is_logval.push_back(ff->isLogVal(i));
00121 m_is_integer.push_back(ff->isIntegerValued(i));
00122 }
00123 }
00124
00125 bool Mmsapt::isLogVal(int i) const { return m_is_logval.at(i); }
00126 bool Mmsapt::isInteger(int i) const { return m_is_integer.at(i); }
00127
00128 void
00129 Mmsapt::
00130 parse_factor_spec(std::vector<FactorType>& flist, std::string const key)
00131 {
00132 pair<string,string> dflt(key, "0");
00133 string factors = this->param.insert(dflt).first->second;
00134 size_t p = 0, q = factors.find(',');
00135 for (; q < factors.size(); q = factors.find(',', p=q+1))
00136 flist.push_back(atoi(factors.substr(p, q-p).c_str()));
00137 flist.push_back(atoi(factors.substr(p).c_str()));
00138 }
00139
00140
00141 void Mmsapt::init(string const& line)
00142 {
00143 map<string,string>::const_iterator m;
00144 parseLine(line,this->param);
00145
00146 this->m_numScoreComponents = atoi(param["num-features"].c_str());
00147
00148 m = param.find("config");
00149 if (m != param.end())
00150 read_config_file(m->second,param);
00151
00152 m = param.find("base");
00153 if (m != param.end())
00154 {
00155 m_bname = m->second;
00156 m = param.find("path");
00157 UTIL_THROW_IF2((m != param.end() && m->second != m_bname),
00158 "Conflicting aliases for path:\n"
00159 << "path=" << string(m->second) << "\n"
00160 << "base=" << m_bname.c_str() );
00161 }
00162 else m_bname = param["path"];
00163 L1 = param["L1"];
00164 L2 = param["L2"];
00165
00166 UTIL_THROW_IF2(m_bname.size() == 0, "Missing corpus base name at " << HERE);
00167 UTIL_THROW_IF2(L1.size() == 0, "Missing L1 tag at " << HERE);
00168 UTIL_THROW_IF2(L2.size() == 0, "Missing L2 tag at " << HERE);
00169
00170
00171 parse_factor_spec(m_ifactor,"input-factor");
00172 parse_factor_spec(m_ofactor,"output-factor");
00173
00174
00175 m_inputFactors = FactorMask(m_ifactor);
00176 m_outputFactors = FactorMask(m_ofactor);
00177
00178 pair<string,string> dflt = pair<string,string> ("smooth",".01");
00179 m_lbop_conf = atof(param.insert(dflt).first->second.c_str());
00180
00181 dflt = pair<string,string> ("lexalpha","0");
00182 m_lex_alpha = atof(param.insert(dflt).first->second.c_str());
00183
00184 dflt = pair<string,string> ("sample","1000");
00185 m_default_sample_size = atoi(param.insert(dflt).first->second.c_str());
00186
00187 dflt = pair<string,string> ("min-sample","0");
00188 m_min_sample_size = atoi(param.insert(dflt).first->second.c_str());
00189
00190 dflt = pair<string,string>("workers","0");
00191 m_workers = atoi(param.insert(dflt).first->second.c_str());
00192 if (m_workers == 0) m_workers = StaticData::Instance().ThreadCount();
00193 else m_workers = min(m_workers,size_t(boost::thread::hardware_concurrency()));
00194
00195 dflt = pair<string,string>("bias-loglevel","0");
00196 m_bias_loglevel = atoi(param.insert(dflt).first->second.c_str());
00197
00198 dflt = pair<string,string>("table-limit","20");
00199 m_tableLimit = atoi(param.insert(dflt).first->second.c_str());
00200
00201 dflt = pair<string,string>("cache","100000");
00202 m_cache_size = max(10000,atoi(param.insert(dflt).first->second.c_str()));
00203
00204 m_cache.reset(new TPCollCache(m_cache_size));
00205
00206
00207
00208
00209
00210
00211 param.insert(pair<string,string>("pfwd", "g"));
00212 param.insert(pair<string,string>("pbwd", "g"));
00213 param.insert(pair<string,string>("lenrat", "1"));
00214 param.insert(pair<string,string>("rare", "1"));
00215 param.insert(pair<string,string>("logcnt", "0"));
00216 param.insert(pair<string,string>("coh", "0"));
00217 param.insert(pair<string,string>("prov", "0"));
00218 param.insert(pair<string,string>("cumb", "0"));
00219
00220 poolCounts = true;
00221
00222
00223 if ((m = param.find("bias")) != param.end())
00224 m_bias_file = m->second;
00225
00226 if ((m = param.find("bias-server")) != param.end())
00227 m_bias_server = m->second;
00228 if (m_bias_loglevel)
00229 {
00230 dflt = pair<string,string>("bias-logfile","/dev/stderr");
00231 param.insert(dflt);
00232 }
00233 if ((m = param.find("bias-logfile")) != param.end())
00234 {
00235 m_bias_logfile = m->second;
00236 if (m_bias_logfile == "/dev/stderr")
00237 m_bias_log = &std::cerr;
00238 else if (m_bias_logfile == "/dev/stdout")
00239 m_bias_log = &std::cout;
00240 else
00241 {
00242 m_bias_logger.reset(new std::ofstream(m_bias_logfile.c_str()));
00243 m_bias_log = m_bias_logger.get();
00244 }
00245 }
00246
00247 if ((m = param.find("lr-func")) != param.end())
00248 m_lr_func_name = m->second;
00249
00250
00251 if ((m = param.find("lrfunc")) != param.end())
00252 m_lr_func_name = m->second;
00253
00254 if ((m = param.find("extra")) != param.end())
00255 m_extra_data = m->second;
00256
00257 if ((m = param.find("method")) != param.end())
00258 {
00259 if (m->second == "random")
00260 m_sampling_method = random_sampling;
00261 else if (m->second == "ranked")
00262 m_sampling_method = ranked_sampling;
00263 else if (m->second == "ranked2")
00264 m_sampling_method = ranked_sampling2;
00265 else if (m->second == "full")
00266 m_sampling_method = full_coverage;
00267 else UTIL_THROW2("unrecognized specification 'method='" << m->second
00268 << "' in line:\n" << line);
00269 }
00270
00271 dflt = pair<string,string>("tuneable","true");
00272 m_tuneable = Scan<bool>(param.insert(dflt).first->second.c_str());
00273
00274 dflt = pair<string,string>("feature-sets","standard");
00275 m_feature_set_names = Tokenize(param.insert(dflt).first->second.c_str(), ",");
00276 m = param.find("name");
00277 if (m != param.end()) m_name = m->second;
00278
00279
00280
00281
00282 param.insert(pair<string,string>("coord","0"));
00283 if(param["coord"] != "0")
00284 {
00285 m_track_coord = true;
00286 vector<string> coord_instances = Tokenize(param["coord"], ",");
00287 BOOST_FOREACH(std::string instance, coord_instances)
00288 {
00289 vector<string> toks = Moses::Tokenize(instance, ":");
00290 string space = toks[0];
00291 string file = toks[1];
00292
00293 m_coord_spaces.push_back(StaticData::InstanceNonConst().MapCoordSpace(space));
00294
00295 m_sid_coord_list.push_back(vector<SPTR<vector<float> > >());
00296 vector<SPTR<vector<float> > >& sid_coord = m_sid_coord_list[m_sid_coord_list.size() - 1];
00297
00298 sid_coord.reserve(btfix->T1->size());
00299 string line;
00300 cerr << "Loading coordinate lines for space \"" << space << "\" from " << file << endl;
00301 iostreams::filtering_istream in;
00302 ugdiss::open_input_stream(file, in);
00303 while(getline(in, line))
00304 {
00305 SPTR<vector<float> > coord(new vector<float>);
00306 Scan<float>(*coord, Tokenize(line));
00307 sid_coord.push_back(coord);
00308 }
00309 cerr << "Loaded " << sid_coord.size() << " lines" << endl;
00310 }
00311 }
00312
00313
00314 vector<string> known_parameters; known_parameters.reserve(50);
00315 known_parameters.push_back("L1");
00316 known_parameters.push_back("L2");
00317 known_parameters.push_back("Mmsapt");
00318 known_parameters.push_back("PhraseDictionaryBitextSampling");
00319
00320 known_parameters.push_back("base");
00321 known_parameters.push_back("bias");
00322 known_parameters.push_back("bias-server");
00323 known_parameters.push_back("bias-logfile");
00324 known_parameters.push_back("bias-loglevel");
00325 known_parameters.push_back("cache");
00326 known_parameters.push_back("coh");
00327 known_parameters.push_back("config");
00328 known_parameters.push_back("coord");
00329 known_parameters.push_back("cumb");
00330 known_parameters.push_back("extra");
00331 known_parameters.push_back("feature-sets");
00332 known_parameters.push_back("input-factor");
00333 known_parameters.push_back("lenrat");
00334 known_parameters.push_back("lexalpha");
00335
00336 known_parameters.push_back("logcnt");
00337 known_parameters.push_back("lr-func");
00338 known_parameters.push_back("lrfunc");
00339 known_parameters.push_back("method");
00340 known_parameters.push_back("name");
00341 known_parameters.push_back("num-features");
00342 known_parameters.push_back("output-factor");
00343 known_parameters.push_back("path");
00344 known_parameters.push_back("pbwd");
00345 known_parameters.push_back("pfwd");
00346 known_parameters.push_back("prov");
00347 known_parameters.push_back("rare");
00348 known_parameters.push_back("sample");
00349 known_parameters.push_back("min-sample");
00350 known_parameters.push_back("smooth");
00351 known_parameters.push_back("table-limit");
00352 known_parameters.push_back("tuneable");
00353 known_parameters.push_back("unal");
00354 known_parameters.push_back("workers");
00355 sort(known_parameters.begin(),known_parameters.end());
00356 for (map<string,string>::iterator m = param.begin(); m != param.end(); ++m)
00357 {
00358 UTIL_THROW_IF2(!binary_search(known_parameters.begin(),
00359 known_parameters.end(), m->first),
00360 HERE << ": Unknown parameter specification for Mmsapt: "
00361 << m->first);
00362 }
00363 }
00364
00365 void
00366 Mmsapt::
00367 load_bias(string const fname)
00368 {
00369 m_bias = btfix->loadSentenceBias(fname);
00370 }
00371
00372 void
00373 Mmsapt::
00374 load_extra_data(string bname, bool locking = true)
00375 {
00376 using namespace boost;
00377 using namespace ugdiss;
00378
00379
00380
00381
00382 vector<string> text1,text2,symal;
00383 string line;
00384 boost::iostreams::filtering_istream in1,in2,ina;
00385
00386 open_input_stream(bname+L1+".txt.gz",in1);
00387 open_input_stream(bname+L2+".txt.gz",in2);
00388 open_input_stream(bname+L1+"-"+L2+".symal.gz",ina);
00389
00390 while(getline(in1,line)) text1.push_back(line);
00391 while(getline(in2,line)) text2.push_back(line);
00392 while(getline(ina,line)) symal.push_back(line);
00393
00394 scoped_ptr<boost::unique_lock<shared_mutex> > guard;
00395 if (locking) guard.reset(new boost::unique_lock<shared_mutex>(m_lock));
00396 btdyn = btdyn->add(text1,text2,symal);
00397 assert(btdyn);
00398 cerr << "Loaded " << btdyn->T1->size() << " sentence pairs" << endl;
00399 }
00400
00401 template<typename fftype>
00402 void
00403 Mmsapt::
00404 check_ff(string const ffname, vector<SPTR<pscorer> >* registry)
00405 {
00406 string const& spec = param[ffname];
00407 if (spec == "" || spec == "0") return;
00408 if (registry)
00409 {
00410 SPTR<fftype> ff(new fftype(spec));
00411 register_ff(ff, *registry);
00412 }
00413 else if (spec[spec.size()-1] == '+')
00414 {
00415 SPTR<fftype> ff(new fftype(spec));
00416 register_ff(ff, m_active_ff_fix);
00417 ff.reset(new fftype(spec));
00418 register_ff(ff, m_active_ff_dyn);
00419 }
00420 else
00421 {
00422 SPTR<fftype> ff(new fftype(spec));
00423 register_ff(ff, m_active_ff_common);
00424 }
00425 }
00426
00427 template<typename fftype>
00428 void
00429 Mmsapt::
00430 check_ff(string const ffname, float const xtra,
00431 vector<SPTR<pscorer> >* registry)
00432 {
00433 string const& spec = param[ffname];
00434 if (spec == "" || spec == "0") return;
00435 if (registry)
00436 {
00437 SPTR<fftype> ff(new fftype(xtra,spec));
00438 register_ff(ff, *registry);
00439 }
00440 else if (spec[spec.size()-1] == '+')
00441 {
00442 SPTR<fftype> ff(new fftype(xtra,spec));
00443 register_ff(ff, m_active_ff_fix);
00444 ff.reset(new fftype(xtra,spec));
00445 register_ff(ff, m_active_ff_dyn);
00446 }
00447 else
00448 {
00449 SPTR<fftype> ff(new fftype(xtra,spec));
00450 register_ff(ff, m_active_ff_common);
00451 }
00452 }
00453
00454 void
00455 Mmsapt::
00456 Load(AllOptions::ptr const& opts)
00457 {
00458 Load(opts, true);
00459 }
00460
00461 void
00462 Mmsapt
00463 ::setup_local_feature_functions()
00464 {
00465 boost::unique_lock<boost::shared_mutex> lock(m_lock);
00466
00467 BOOST_FOREACH(string const& fsname, m_feature_set_names)
00468 {
00469
00470 if (fsname == "standard")
00471 {
00472
00473 string lexfile = m_bname + L1 + "-" + L2 + ".lex";
00474 SPTR<PScoreLex1<Token> >
00475 ff(new PScoreLex1<Token>(param["lex_alpha"],lexfile));
00476 register_ff(ff,m_active_ff_common);
00477
00478
00479 check_ff<PScoreRareness<Token> > ("rare", &m_active_ff_common);
00480 check_ff<PScoreUnaligned<Token> >("unal", &m_active_ff_common);
00481 check_ff<PScoreCoherence<Token> >("coh", &m_active_ff_common);
00482 check_ff<PScoreCumBias<Token> >("cumb", &m_active_ff_common);
00483 check_ff<PScoreLengthRatio<Token> > ("lenrat", &m_active_ff_common);
00484
00485
00486
00487 check_ff<PScorePfwd<Token> >("pfwd", m_lbop_conf);
00488 check_ff<PScorePbwd<Token> >("pbwd", m_lbop_conf);
00489 check_ff<PScoreLogCnt<Token> >("logcnt");
00490
00491
00492 check_ff<PScoreProvenance<Token> >("prov", &m_active_ff_fix);
00493 check_ff<PScoreProvenance<Token> >("prov", &m_active_ff_dyn);
00494 }
00495
00496
00497
00498 else if (fsname == "datasource")
00499 {
00500 SPTR<PScorePC<Token> > ffpcnt(new PScorePC<Token>("pcnt"));
00501 register_ff(ffpcnt,m_active_ff_common);
00502 SPTR<PScoreWC<Token> > ffwcnt(new PScoreWC<Token>("wcnt"));
00503 register_ff(ffwcnt,m_active_ff_common);
00504 }
00505 }
00506
00507 this->m_numScoreComponents = this->m_feature_names.size();
00508 this->m_numTuneableComponents = this->m_numScoreComponents;
00509 }
00510
00511 void
00512 Mmsapt::
00513 Load(AllOptions::ptr const& opts, bool with_checks)
00514 {
00515 m_options = opts;
00516 boost::unique_lock<boost::shared_mutex> lock(m_lock);
00517
00518 BOOST_FOREACH(SPTR<pscorer>& ff, m_active_ff_fix) ff->load();
00519 BOOST_FOREACH(SPTR<pscorer>& ff, m_active_ff_dyn) ff->load();
00520 BOOST_FOREACH(SPTR<pscorer>& ff, m_active_ff_common) ff->load();
00521 #if 0
00522 if (with_checks)
00523 {
00524 UTIL_THROW_IF2(this->m_feature_names.size() != this->m_numScoreComponents,
00525 "At " << HERE << ": number of feature values provided by "
00526 << "Phrase table (" << this->m_feature_names.size()
00527 << ") does not match number specified in Moses config file ("
00528 << this->m_numScoreComponents << ")!\n";);
00529 }
00530 #endif
00531
00532 m_thread_pool.reset(new ug::ThreadPool(max(m_workers,size_t(1))));
00533
00534
00535
00536
00537 btfix->m_num_workers = this->m_workers;
00538 btfix->open(m_bname, L1, L2);
00539 btfix->setDefaultSampleSize(m_default_sample_size);
00540
00541 btdyn.reset(new imbitext(btfix->V1, btfix->V2, m_default_sample_size, m_workers));
00542 if (m_bias_file.size())
00543 load_bias(m_bias_file);
00544
00545 if (m_extra_data.size())
00546 load_extra_data(m_extra_data, false);
00547
00548 #if 0
00549
00550 LexicalPhraseScorer2<Token>::table_t & COOC = calc_lex.scorer.COOC;
00551 typedef LexicalPhraseScorer2<Token>::table_t::Cell cell_t;
00552 wlex21.resize(COOC.numCols);
00553 for (size_t r = 0; r < COOC.numRows; ++r)
00554 for (cell_t const* c = COOC[r].start; c < COOC[r].stop; ++c)
00555 wlex21[c->id].push_back(r);
00556 COOCraw.open(m_bname + L1 + "-" + L2 + ".coc");
00557 #endif
00558 assert(btdyn);
00559
00560 }
00561
00562 void
00563 Mmsapt::
00564 add(string const& s1, string const& s2, string const& a)
00565 {
00566 vector<string> S1(1,s1);
00567 vector<string> S2(1,s2);
00568 vector<string> ALN(1,a);
00569 boost::unique_lock<boost::shared_mutex> guard(m_lock);
00570 btdyn = btdyn->add(S1,S2,ALN);
00571 }
00572
00573
00574 TargetPhrase*
00575 Mmsapt::
00576 mkTPhrase(ttasksptr const& ttask,
00577 Phrase const& src,
00578 PhrasePair<Token>* fix,
00579 PhrasePair<Token>* dyn,
00580 SPTR<Bitext<Token> > const& dynbt) const
00581 {
00582 UTIL_THROW_IF2(!fix && !dyn, HERE <<
00583 ": Can't create target phrase from nothing.");
00584 vector<float> fvals(this->m_numScoreComponents);
00585 PhrasePair<Token> pool = fix ? *fix : *dyn;
00586 if (fix)
00587 {
00588 BOOST_FOREACH(SPTR<pscorer> const& ff, m_active_ff_fix)
00589 (*ff)(*btfix, *fix, &fvals);
00590 }
00591 if (dyn)
00592 {
00593 BOOST_FOREACH(SPTR<pscorer> const& ff, m_active_ff_dyn)
00594 (*ff)(*dynbt, *dyn, &fvals);
00595 }
00596
00597 if (fix && dyn) { pool += *dyn; }
00598 else if (fix)
00599 {
00600 PhrasePair<Token> zilch; zilch.init();
00601 TSA<Token>::tree_iterator m(dynbt->I2.get(), fix->start2, fix->len2);
00602 if (m.size() == fix->len2)
00603 zilch.raw2 = m.approxOccurrenceCount();
00604 pool += zilch;
00605 BOOST_FOREACH(SPTR<pscorer> const& ff, m_active_ff_dyn)
00606 (*ff)(*dynbt, ff->allowPooling() ? pool : zilch, &fvals);
00607 }
00608 else if (dyn)
00609 {
00610 PhrasePair<Token> zilch; zilch.init();
00611 TSA<Token>::tree_iterator m(btfix->I2.get(), dyn->start2, dyn->len2);
00612 if (m.size() == dyn->len2)
00613 zilch.raw2 = m.approxOccurrenceCount();
00614 pool += zilch;
00615 BOOST_FOREACH(SPTR<pscorer> const& ff, m_active_ff_fix)
00616 (*ff)(*dynbt, ff->allowPooling() ? pool : zilch, &fvals);
00617 }
00618 if (fix)
00619 {
00620 BOOST_FOREACH(SPTR<pscorer> const& ff, m_active_ff_common)
00621 (*ff)(*btfix, pool, &fvals);
00622 }
00623 else
00624 {
00625 BOOST_FOREACH(SPTR<pscorer> const& ff, m_active_ff_common)
00626 (*ff)(*dynbt, pool, &fvals);
00627 }
00628
00629 TargetPhrase* tp = new TargetPhrase(const_cast<ttasksptr&>(ttask), this);
00630 Token const* x = fix ? fix->start2 : dyn->start2;
00631 uint32_t len = fix ? fix->len2 : dyn->len2;
00632 for (uint32_t k = 0; k < len; ++k, x = x->next())
00633 {
00634 StringPiece wrd = (*(btfix->V2))[x->id()];
00635 Word w;
00636 w.CreateFromString(Output, m_ofactor, wrd, false);
00637 tp->AddWord(w);
00638 }
00639 tp->SetAlignTerm(pool.aln);
00640 tp->GetScoreBreakdown().Assign(this, fvals);
00641
00642 tp->EvaluateInIsolation(src, m_featuresToApply);
00643
00644 #ifndef NO_MOSES
00645 if (m_lr_func)
00646 {
00647 LRModel::ModelType mdl = m_lr_func->GetModel().GetModelType();
00648 LRModel::Direction dir = m_lr_func->GetModel().GetDirection();
00649 SPTR<Scores> scores(new Scores());
00650 pool.fill_lr_vec(dir, mdl, *scores);
00651 tp->SetExtraScores(m_lr_func, scores);
00652 }
00653 #endif
00654
00655
00656 if (m_track_coord)
00657 {
00658 BOOST_FOREACH(uint32_t const sid, *pool.sids)
00659 {
00660 for(size_t i = 0; i < m_coord_spaces.size(); ++i)
00661 {
00662 tp->PushCoord(m_coord_spaces[i], m_sid_coord_list[i][sid]);
00663 }
00664 }
00665
00666
00667
00668
00669
00670
00671
00672
00673
00674
00675
00676 }
00677
00678 return tp;
00679 }
00680
00681 void
00682 Mmsapt::
00683 GetTargetPhraseCollectionBatch(ttasksptr const& ttask,
00684 const InputPathList &inputPathQueue) const
00685 {
00686 InputPathList::const_iterator iter;
00687 for (iter = inputPathQueue.begin(); iter != inputPathQueue.end(); ++iter)
00688 {
00689 InputPath &inputPath = **iter;
00690 const Phrase &phrase = inputPath.GetPhrase();
00691 PrefixExists(ttask, phrase);
00692 }
00693 for (iter = inputPathQueue.begin(); iter != inputPathQueue.end(); ++iter)
00694 {
00695 InputPath &inputPath = **iter;
00696 const Phrase &phrase = inputPath.GetPhrase();
00697 TargetPhraseCollection::shared_ptr targetPhrases
00698 = this->GetTargetPhraseCollectionLEGACY(ttask,phrase);
00699 inputPath.SetTargetPhrases(*this, targetPhrases, NULL);
00700 }
00701 }
00702
00703
00704
00705
00706
00707
00708
00709
00710
00711 TargetPhraseCollection::shared_ptr
00712 Mmsapt::
00713 GetTargetPhraseCollectionLEGACY(ttasksptr const& ttask, const Phrase& src) const
00714 {
00715 SPTR<TPCollWrapper> ret;
00716
00717
00718
00719 vector<id_type> sphrase;
00720 fillIdSeq(src, m_ifactor, *(btfix->V1), sphrase);
00721 if (sphrase.size() == 0) return ret;
00722
00723
00724
00725
00726 SPTR<imBitext<Token> > dyn;
00727 {
00728 boost::unique_lock<boost::shared_mutex> guard(m_lock);
00729 assert(btdyn);
00730 dyn = btdyn;
00731 }
00732 assert(dyn);
00733
00734
00735 TSA<Token>::tree_iterator mfix(btfix->I1.get(), &sphrase[0], sphrase.size());
00736 TSA<Token>::tree_iterator mdyn(dyn->I1.get());
00737 if (dyn->I1.get())
00738 for (size_t i = 0; mdyn.size() == i && i < sphrase.size(); ++i)
00739 mdyn.extend(sphrase[i]);
00740
00741 if (mdyn.size() != sphrase.size() && mfix.size() != sphrase.size())
00742 return ret;
00743
00744
00745 uint64_t phrasekey = (mfix.size() == sphrase.size()
00746 ? (mfix.getPid()<<1)
00747 : (mdyn.getPid()<<1)+1);
00748
00749
00750 SPTR<ContextScope> const& scope = ttask->GetScope();
00751 SPTR<TPCollCache> cache = scope->get<TPCollCache>(cache_key);
00752 if (!cache) cache = m_cache;
00753
00754 ret = cache->get(phrasekey, dyn->revision());
00755
00756
00757
00758
00759
00760
00761
00762
00763
00764
00765
00766 boost::upgrade_lock<boost::shared_mutex> rlock(ret->lock);
00767 if (ret->GetSize()) return ret;
00768
00769
00770 boost::upgrade_to_unique_lock<boost::shared_mutex> wlock(rlock);
00771
00772 if (ret->GetSize()) return ret;
00773
00774
00775
00776
00777
00778
00779
00780 SPTR<pstats> sfix,sdyn;
00781
00782 if (mfix.size() == sphrase.size())
00783 {
00784 SPTR<ContextForQuery> context = scope->get<ContextForQuery>(btfix.get());
00785 SPTR<pstats> const* foo = context->cache1->get(mfix.getPid());
00786 if (foo) { sfix = *foo; sfix->wait(); }
00787 else
00788 {
00789 BitextSampler<Token> s(btfix, mfix, context->bias,
00790 m_min_sample_size,
00791 m_default_sample_size,
00792 m_sampling_method,
00793 m_track_coord);
00794 s();
00795 sfix = s.stats();
00796 }
00797 }
00798
00799 if (mdyn.size() == sphrase.size())
00800 sdyn = dyn->lookup(ttask, mdyn);
00801
00802 vector<PhrasePair<Token> > ppfix,ppdyn;
00803 PhrasePair<Token>::SortByTargetIdSeq sort_by_tgt_id;
00804 if (sfix)
00805 {
00806 expand(mfix, *btfix, *sfix, ppfix, m_bias_log);
00807 sort(ppfix.begin(), ppfix.end(),sort_by_tgt_id);
00808 }
00809 if (sdyn)
00810 {
00811 expand(mdyn, *dyn, *sdyn, ppdyn, m_bias_log);
00812 sort(ppdyn.begin(), ppdyn.end(),sort_by_tgt_id);
00813 }
00814
00815
00816 PhrasePair<Token>::SortByTargetIdSeq sorter;
00817 size_t i = 0; size_t k = 0;
00818 while (i < ppfix.size() && k < ppdyn.size())
00819 {
00820 int cmp = sorter.cmp(ppfix[i], ppdyn[k]);
00821 if (cmp < 0) ret->Add(mkTPhrase(ttask,src,&ppfix[i++],NULL,dyn));
00822 else if (cmp == 0) ret->Add(mkTPhrase(ttask,src,&ppfix[i++],&ppdyn[k++],dyn));
00823 else ret->Add(mkTPhrase(ttask,src,NULL,&ppdyn[k++],dyn));
00824 }
00825 while (i < ppfix.size()) ret->Add(mkTPhrase(ttask,src,&ppfix[i++],NULL,dyn));
00826 while (k < ppdyn.size()) ret->Add(mkTPhrase(ttask,src,NULL,&ppdyn[k++],dyn));
00827
00828
00829 if (m_tableLimit) ret->Prune(true, m_tableLimit);
00830 else ret->Prune(true,ret->GetSize());
00831
00832 #if 1
00833 if (m_bias_log && m_lr_func && m_bias_loglevel > 3)
00834 {
00835 PhrasePair<Token>::SortDescendingByJointCount sorter;
00836 sort(ppfix.begin(), ppfix.end(),sorter);
00837 BOOST_FOREACH(PhrasePair<Token> const& pp, ppfix)
00838 {
00839
00840 pp.print(*m_bias_log,*btfix->V1, *btfix->V2, m_lr_func->GetModel());
00841 }
00842 }
00843 #endif
00844 return ret;
00845 }
00846
00847 size_t
00848 Mmsapt::
00849 SetTableLimit(size_t limit)
00850 {
00851 std::swap(m_tableLimit,limit);
00852 return limit;
00853 }
00854
00855 void
00856 Mmsapt::
00857 CleanUpAfterSentenceProcessing(ttasksptr const& ttask)
00858 { }
00859
00860
00861 ChartRuleLookupManager*
00862 Mmsapt::
00863 CreateRuleLookupManager(const ChartParser &, const ChartCellCollectionBase &)
00864 {
00865 throw "CreateRuleLookupManager is currently not supported in Mmsapt!";
00866 }
00867
00868 ChartRuleLookupManager*
00869 Mmsapt::
00870 CreateRuleLookupManager(const ChartParser &, const ChartCellCollectionBase &,
00871 size_t )
00872 {
00873 throw "CreateRuleLookupManager is currently not supported in Mmsapt!";
00874 }
00875
00876 void
00877 Mmsapt::
00878 setup_bias(ttasksptr const& ttask)
00879 {
00880 SPTR<ContextScope> const& scope = ttask->GetScope();
00881 SPTR<ContextForQuery> context;
00882 context = scope->get<ContextForQuery>(btfix.get(), true);
00883 if (context->bias) return;
00884
00885
00886 SPTR<std::map<std::string, float> const> w;
00887 w = ttask->GetScope()->GetContextWeights();
00888 if (w && !w->empty())
00889 {
00890 if (m_bias_log)
00891 *m_bias_log << "BIAS WEIGHTS GIVEN WITH INPUT at " << HERE << endl;
00892 context->bias = btfix->SetupDocumentBias(*w, m_bias_log);
00893 }
00894 else if (m_bias_server.size() && ttask->GetContextWindow())
00895 {
00896
00897 string context_words;
00898 BOOST_FOREACH(string const& line, *ttask->GetContextWindow())
00899 {
00900 if (context_words.size()) context_words += " ";
00901 context_words += line;
00902 }
00903 if (context_words.size())
00904 {
00905 if (m_bias_log)
00906 *m_bias_log << "GETTING BIAS FROM SERVER at " << HERE << endl
00907 << "BIAS LOOKUP CONTEXT: " << context_words << endl;
00908 context->bias
00909 = btfix->SetupDocumentBias(m_bias_server,context_words,m_bias_log);
00910
00911
00912 ttask->GetScope()->SetContextWeights(context->bias->getBiasMap());
00913 }
00914 }
00915 if (context->bias)
00916 {
00917 context->bias_log = m_bias_log;
00918 context->bias->loglevel = m_bias_loglevel;
00919 }
00920 }
00921
00922 void
00923 Mmsapt::
00924 InitializeForInput(ttasksptr const& ttask)
00925 {
00926 boost::unique_lock<boost::shared_mutex> mylock(m_lock);
00927
00928 SPTR<ContextScope> const& scope = ttask->GetScope();
00929 SPTR<TPCollCache> localcache = scope->get<TPCollCache>(cache_key);
00930 SPTR<ContextForQuery> context = scope->get<ContextForQuery>(btfix.get(), true);
00931 boost::unique_lock<boost::shared_mutex> ctxlock(context->lock);
00932
00933
00934
00935 if (!localcache)
00936 {
00937
00938 setup_bias(ttask);
00939 if (context->bias)
00940 {
00941 localcache.reset(new TPCollCache(m_cache_size));
00942 }
00943 else localcache = m_cache;
00944 scope->set<TPCollCache>(cache_key, localcache);
00945 }
00946
00947 if (!context->cache1) context->cache1.reset(new pstats::cache_t);
00948 if (!context->cache2) context->cache2.reset(new pstats::cache_t);
00949
00950 #ifndef NO_MOSES
00951 if (m_lr_func_name.size() && m_lr_func == NULL)
00952 {
00953 FeatureFunction* lr = &FeatureFunction::FindFeatureFunction(m_lr_func_name);
00954 m_lr_func = dynamic_cast<LexicalReordering*>(lr);
00955 UTIL_THROW_IF2(lr == NULL, "FF " << m_lr_func_name
00956 << " does not seem to be a lexical reordering function!");
00957
00958 }
00959 #endif
00960 }
00961
00962 bool
00963 Mmsapt::
00964 PrefixExists(ttasksptr const& ttask, Moses::Phrase const& phrase) const
00965 {
00966 if (phrase.GetSize() == 0) return false;
00967 SPTR<ContextScope> const& scope = ttask->GetScope();
00968
00969 vector<id_type> myphrase;
00970 fillIdSeq(phrase, m_ifactor, *btfix->V1, myphrase);
00971
00972 TSA<Token>::tree_iterator mfix(btfix->I1.get(),&myphrase[0],myphrase.size());
00973 if (mfix.size() == myphrase.size())
00974 {
00975 SPTR<ContextForQuery> context = scope->get<ContextForQuery>(btfix.get(), true);
00976 uint64_t pid = mfix.getPid();
00977 if (!context->cache1->get(pid))
00978 {
00979 BitextSampler<Token> s(btfix, mfix, context->bias,
00980 m_min_sample_size, m_default_sample_size,
00981 m_sampling_method, m_track_coord);
00982 if (*context->cache1->get(pid, s.stats()) == s.stats())
00983 m_thread_pool->add(s);
00984 }
00985
00986
00987 return true;
00988 }
00989
00990 SPTR<imBitext<Token> > dyn;
00991 {
00992 boost::unique_lock<boost::shared_mutex> guard(m_lock);
00993 dyn = btdyn;
00994 }
00995 assert(dyn);
00996 TSA<Token>::tree_iterator mdyn(dyn->I1.get());
00997 if (dyn->I1.get())
00998 {
00999 for (size_t i = 0; mdyn.size() == i && i < myphrase.size(); ++i)
01000 mdyn.extend(myphrase[i]);
01001
01002 if (mdyn.size() == myphrase.size()) dyn->prep(ttask, mdyn, m_track_coord);
01003 }
01004 return mdyn.size() == myphrase.size();
01005 }
01006
// Disabled code (compiled out by #if 0): hands a target phrase collection
// back to the cache stored in the task's scope and resets the caller's
// pointer to NULL.
#if 0
void
Mmsapt
::Release(ttasksptr const& ttask, TargetPhraseCollection::shared_ptr*& tpc) const
{
  if (!tpc)
    {
      // Nothing to release.
      return;
    }
  SPTR<TPCollCache> cache = ttask->GetScope()->get<TPCollCache>(cache_key);

  // NOTE(review): `foo` is not used below.
  TPCollWrapper const* foo = static_cast<TPCollWrapper const*>(tpc);

  if (cache) cache->release(static_cast<TPCollWrapper const*>(tpc));
  tpc = NULL;
}
#endif
01027
01028 bool Mmsapt
01029 ::ProvidesPrefixCheck() const { return true; }
01030
01031 string const& Mmsapt
01032 ::GetName() const { return m_name; }
01033
01034
01035
01036
01037
01038
01039
01040
01041 vector<float>
01042 Mmsapt
01043 ::DefaultWeights() const
01044 { return vector<float>(this->GetNumScoreComponents(), 1.); }
01045
01046 }