00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023 using namespace std;
00024
00025 #include <iostream>
00026 #include <fstream>
00027 #include <sstream>
00028 #include <vector>
00029 #include <string>
00030 #include <stdlib.h>
00031 #include "util.h"
00032 #include "math.h"
00033
00034 #include "lmtable.h"
00035
00036
00037
// Raw command-line option values. handle_option() stores the text given on
// the command line; main() converts the numeric ones with atoi().
std::string slearn;            // --learn|-l : development text for weight estimation (empty = off)
std::string seval;             // --eval|-e  : evaluation text for perplexity (empty = off)
std::string sorder;            // --order|-o : n-gram order cap (empty = not given)
std::string sscore("no");      // --score|-s : "yes" scores n-grams read from stdin
std::string sdebug("0");       // --debug|-d : verbosity of --eval output
std::string smemmap("0");      // --memmap|-mm|-m : memory-map binary LMs
std::string sdub("10000000");  // --dub : dictionary upper bound (10^7)
00045
00046
00047
00048
00049 lmtable *load_lm(std::string file,int dub,int memmap);
00050
/**
 * Print the command-line synopsis of interpolate-lm to stderr.
 *
 * @param msg  optional error message, printed first. When it is null the
 *             long description of the tool is printed too; the option list
 *             is printed in both cases.
 */
void usage(const char *msg = 0) {
  if (msg != 0)
    std::cerr << msg << std::endl;

  std::cerr << "Usage: interpolate-lm [options] lm-list-file [lm-list-file.out]" << std::endl;

  if (msg == 0) {
    // Every entry is followed by a newline; the empty entries produce the
    // blank lines before and after the description.
    static const char *const about[] = {
      "",
      " interpolate-lm reads a LM list file including interpolation weights ",
      " with the format: N\\n w1 lm1 \\n w2 lm2 ...\\n wN lmN\n",
      " It estimates new weights on a development text, ",
      " computes the perplexity on an evaluation text, ",
      " computes probabilities of n-grams read from stdin.",
      " It reads LMs in ARPA and IRSTLM binary format.",
      ""
    };
    for (size_t k = 0; k < sizeof(about) / sizeof(about[0]); ++k)
      std::cerr << about[k] << std::endl;
  }

  static const char *const options[] = {
    "--learn|-l text-file learn optimal interpolation for text-file",
    "--order|-o n order of n-grams used in --learn (optional)",
    "--eval|-e text-file computes perplexity on text-file",
    "--dub dict-size dictionary upperbound (default 10^7)",
    "--score|-s [yes|no] compute log-probs of n-grams from stdin",
    "--debug|-d [1-3] verbose output for --eval option (see compile-lm)",
    "--memmap| -mm 1 use memory map to read a binary LM"
  };
  std::cerr << "Options:\n";
  for (size_t k = 0; k < sizeof(options) / sizeof(options[0]); ++k)
    std::cerr << options[k] << '\n';
}
00071
00072
/**
 * Test whether command-line token `s` names option `pre`.
 *
 * Despite its name this is not a plain prefix test. It accepts either the
 * exact spelling ("--learn") or the spelling with an attached value
 * ("--learn=dev.txt"). That is why "-dub" does not match "-d".
 */
bool starts_with(const std::string &s, const std::string &pre) {
  if (s == pre)
    return true;
  const std::string::size_type n = pre.size();
  // Need at least one more character, and it must be the '=' separator.
  return s.size() > n && s[n] == '=' && s.compare(0, n, pre) == 0;
}
00081
00082 std::string get_param(const std::string& opt, int argc, const char **argv, int& argi)
00083 {
00084 std::string::size_type equals = opt.find_first_of('=');
00085 if (equals != std::string::npos && equals < opt.size()-1) {
00086 return opt.substr(equals+1);
00087 }
00088 std::string nexto;
00089 if (argi + 1 < argc) {
00090 nexto = argv[++argi];
00091 } else {
00092 usage((opt + " requires a value!").c_str());
00093 exit(1);
00094 }
00095 return nexto;
00096 }
00097
00098 void handle_option(const std::string& opt, int argc, const char **argv, int& argi)
00099 {
00100 if (opt == "--help" || opt == "-h") { usage(); exit(1); }
00101
00102 if (starts_with(opt, "--learn") || starts_with(opt, "-l"))
00103 slearn = get_param(opt, argc, argv, argi);
00104 else
00105 if (starts_with(opt, "--order") || starts_with(opt, "-o"))
00106 sorder = get_param(opt, argc, argv, argi);
00107 else
00108 if (starts_with(opt, "--eval") || starts_with(opt, "-e"))
00109 seval = get_param(opt, argc, argv, argi);
00110 else
00111 if (starts_with(opt, "--score") || starts_with(opt, "-s"))
00112 sscore = get_param(opt, argc, argv, argi);
00113 else
00114 if (starts_with(opt, "--debug") || starts_with(opt, "-d"))
00115 sdebug = get_param(opt, argc, argv, argi);
00116
00117 else
00118 if (starts_with(opt, "--memmap") || starts_with(opt, "-mm") || starts_with(opt, "-m") )
00119 smemmap = get_param(opt, argc, argv, argi);
00120
00121 else
00122 if (starts_with(opt, "--dub") || starts_with(opt, "-dub"))
00123 sdub = get_param(opt, argc, argv, argi);
00124
00125 else {
00126 usage(("Don't understand option " + opt).c_str());
00127 exit(1);
00128 }
00129 }
00130
00131 int main(int argc, const char **argv)
00132 {
00133
00134 if (argc < 2) { usage(); exit(1); }
00135 std::vector<std::string> files;
00136 for (int i=1; i < argc; i++) {
00137 std::string opt = argv[i];
00138 if (opt[0] == '-') { handle_option(opt, argc, argv, i); }
00139 else files.push_back(opt);
00140 }
00141
00142 bool learn = (slearn != ""? true : false);
00143 bool score = (sscore != ""? true : false);
00144 int order=(sorder!=""?atoi(sorder.c_str()):0);
00145 int debug = atoi(sdebug.c_str());
00146 int memmap = atoi(smemmap.c_str());
00147 int dub = atoi(sdub.c_str());
00148
00149 if (sorder != "" && order < 1) {usage("Order must be a positive integer"); exit(1);}
00150
00151 if (files.size() > 2) { usage("Too many arguments"); exit(1); }
00152 if (files.size() < 1) { usage("Please specify a LM list file to read from"); exit(1); }
00153
00154 std::string infile = files[0];
00155 std::string outfile="";
00156
00157 if (files.size() == 1) {
00158 outfile=infile;
00159
00160 std::string::size_type p = outfile.rfind('/');
00161 if (p != std::string::npos && ((p+1) < outfile.size()))
00162 outfile.erase(0,p+1);
00163 outfile+=".out";
00164 }
00165 else
00166 outfile = files[1];
00167
00168 std::cerr << "inpfile: " << infile << std::endl;
00169
00170 if (learn) std::cerr << "outfile: " << outfile << std::endl;
00171 if (score) std::cerr << "interactive: " << sscore << std::endl;
00172 std::cerr << "order: " << order << std::endl;
00173 if (memmap) std::cerr << "memory mapping: " << memmap << std::endl;
00174
00175 std::cerr << "dub: " << dub<< std::endl;
00176
00177
00178 lmtable *lmt[100], *start_lmt[100];
00179 std::string lmf[100];
00180
00181 float w[100];
00182 int N;
00183
00184
00185
00186 std::cerr << "Reading " << infile << "..." << std::endl;
00187 std::fstream inptxt(infile.c_str(),std::ios::in);
00188 inptxt >> N; std::cerr << "Number of LMs: " << N << "..." << std::endl;
00189
00190 if(N > 100) {
00191 std::cerr << "Can't interpolate more than 100 language models." << std::endl;
00192 exit(1);
00193 }
00194
00195 for (int i=0;i<N;i++){
00196 inptxt >> w[i] >> lmf[i];
00197 start_lmt[i] = lmt[i] = load_lm(lmf[i],dub,memmap);
00198 }
00199 inptxt.close();
00200
00201
00202 if (learn){
00203
00204 std::vector<float> p[N];
00205 float c[N];
00206 float den,norm;
00207 float variation=1.0;
00208
00209 dictionary* dict;dict=new dictionary((char*)slearn.c_str(),1000000);
00210 ngram ng(dict);
00211 int bos=ng.dict->encode(ng.dict->BoS());
00212 std::ifstream dev(slearn.c_str(),std::ios::in);
00213
00214 for(;;) {
00215 std::string line;
00216 getline(dev, line);
00217 if(dev.eof())
00218 break;
00219 if(dev.fail()) {
00220 std::cerr << "Problem reading input file " << seval << std::endl;
00221 return 1;
00222 }
00223 std::istringstream lstream(line);
00224 if(line.substr(0, 29) == "###interpolate-lm:replace-lm ") {
00225 std::string token, newlm;
00226 int id;
00227 lstream >> token >> id >> newlm;
00228 if(id <= 0 || id > N) {
00229 std::cerr << "LM id out of range." << std::endl;
00230 return 1;
00231 }
00232 id--;
00233 if(lmt[id] != start_lmt[id])
00234 delete lmt[id];
00235 lmt[id] = load_lm(newlm,dub,memmap);
00236 continue;
00237 }
00238 while(lstream >> ng){
00239
00240
00241 if (*ng.wordp(1)==bos) {ng.size=1;continue;}
00242 if (order > 0 && ng.size > order) ng.size=order;
00243 for (int i=0;i<N;i++){
00244 ngram ong(lmt[i]->dict);ong.trans(ng);
00245 p[i].push_back(pow(10.0,lmt[i]->clprob(ong)));
00246 }
00247 }
00248
00249 for (int i=0;i<N;i++) lmt[i]->check_cache_levels();
00250 }
00251 dev.close();
00252
00253 while( variation > 0.01 ){
00254
00255 for (int i=0;i<N;i++) c[i]=0;
00256
00257 for(unsigned i = 0; i < p[0].size(); i++) {
00258 den=0.0;
00259 for(int j = 0; j < N; j++)
00260 den += w[j] * p[j][i];
00261
00262 for(int j = 0; j < N; j++)
00263 c[j] += w[j] * p[j][i] / den;
00264 }
00265
00266 norm=0.0;
00267 for (int i=0;i<N;i++) norm+=c[i];
00268
00269
00270 variation=0.0;
00271 for (int i=0;i<N;i++){
00272 c[i]/=norm;
00273 variation+=(w[i]>c[i]?(w[i]-c[i]):(c[i]-w[i]));
00274 w[i]=c[i];
00275 }
00276 std::cerr << "Variation " << variation << std::endl;
00277 }
00278
00279
00280 std::cerr << "Saving in " << outfile << "..." << std::endl;
00281
00282 std::fstream outtxt(outfile.c_str(),std::ios::out);
00283 outtxt << N << "\n";
00284 for (int i=0;i<N;i++) outtxt << w[i] << " " << lmf[i] << "\n";
00285 outtxt.close();
00286 }
00287
00288 for(int i = 0; i < N; i++)
00289 if(lmt[i] != start_lmt[i]) {
00290 delete lmt[i];
00291 lmt[i] = start_lmt[i];
00292 }
00293
00294 if (seval != ""){
00295 std::cerr << "Start Eval" << std::endl;
00296
00297 std::cout.setf(ios::fixed);
00298 std::cout.precision(2);
00299 int i,Nw=0,Noov=0, Nbo=0;
00300 double logPr=0,PP=0,Pr;
00301
00302
00303 for (i=0,Pr=0;i<N;i++) Pr+=w[i];
00304 for (i=0;i<N;i++) w[i]/=Pr;
00305
00306 dictionary* dict;dict=new dictionary((char*)seval.c_str(),1000000);
00307 ngram ng(dict);
00308 int bos=ng.dict->encode(ng.dict->BoS());
00309 int eos=ng.dict->encode(ng.dict->EoS());
00310
00311 std::fstream inptxt(seval.c_str(),std::ios::in);
00312
00313 for(;;) {
00314 std::string line;
00315 getline(inptxt, line);
00316 if(inptxt.eof())
00317 break;
00318 if(inptxt.fail()) {
00319 std::cerr << "Problem reading input file " << seval << std::endl;
00320 return 1;
00321 }
00322 std::istringstream lstream(line);
00323 if(line.substr(0, 26) == "###interpolate-lm:weights ") {
00324 std::string token;
00325 lstream >> token;
00326 for(int i = 0; i < N; i++) {
00327 if(lstream.eof()) {
00328 std::cerr << "Not enough weights!" << std::endl;
00329 return 1;
00330 }
00331 lstream >> w[i];
00332 }
00333 continue;
00334 }
00335 if(line.substr(0, 29) == "###interpolate-lm:replace-lm ") {
00336 std::string token, newlm;
00337 int id;
00338 lstream >> token >> id >> newlm;
00339 if(id <= 0 || id > N) {
00340 std::cerr << "LM id out of range." << std::endl;
00341 return 1;
00342 }
00343 id--;
00344 delete lmt[id];
00345 lmt[id] = load_lm(newlm,dub,memmap);
00346 continue;
00347 }
00348
00349 double bow; int bol=0;
00350
00351
00352 while(lstream >> ng){
00353
00354
00355 if (*ng.wordp(1)==bos) {ng.size=1;continue;}
00356 if (order > 0 && ng.size > order) ng.size=order;
00357
00358 if (ng.size>=1){
00359
00360 int minbol=MAX_NGRAM;
00361 bool OOVflag=true;
00362
00363 for (i=0,Pr=0;i<N;i++){
00364 ngram ong(lmt[i]->dict);ong.trans(ng);
00365 Pr+=w[i] * pow(10.0,lmt[i]->lprob(ong,&bow,&bol));
00366 if (bol < minbol) minbol=bol;
00367 if (*ong.wordp(1) != lmt[i]->dict->oovcode()) OOVflag=false;
00368 }
00369
00370 logPr+=(log(Pr)/M_LN10);
00371
00372 if (debug==1){
00373 std::cout << ng.dict->decode(*ng.wordp(1)) << "[" << ng.size-minbol << "]" << " ";
00374 if (*ng.wordp(1)==eos) std::cout << std::endl;
00375 }
00376 if (debug==2)
00377 std::cout << ng << "[" << ng.size-minbol << "-gram]" << " " << logPr << std::endl;
00378
00379 if (debug==3)
00380 std::cout << ng << "[" << ng.size-minbol << "-gram]" << " " << logPr << " bow:" << bow << std::endl;
00381
00382
00383 if (minbol) Nbo++;
00384
00385 if (OOVflag) Noov++;
00386
00387 Nw++;
00388
00389 if ((Nw % 10000)==0) std::cerr << ".";
00390 }
00391 }
00392 }
00393
00394 PP=exp((-logPr * M_LN10) /Nw);
00395
00396 std::cout << "%% Nw=" << Nw << " PP=" << PP
00397 << " Nbo=" << Nbo << " Noov=" << Noov
00398 << " OOV=" << (float)Noov/Nw * 100.0 << "%" << std::endl;
00399
00400 };
00401
00402
00403 if (sscore == "yes"){
00404
00405
00406 dictionary* dict;dict=new dictionary(NULL,1000000);
00407 dict->incflag(1);
00408 ngram ng(dict);
00409 double Pr,logPr;
00410
00411
00412
00413
00414
00415 unsigned int maxstatesize, statesize;
00416 int i,n=0;
00417 std::cout << "> ";
00418 while(std::cin >> ng){
00419 n++;
00420 maxstatesize=0;
00421 for (i=0,Pr=0;i<N;i++){
00422 Pr+=w[i] * pow(10.0,lmt[i]->clprob(ng));
00423 lmt[i]->maxsuffptr(ng,&statesize);
00424 if (maxstatesize<statesize) maxstatesize=statesize;
00425 };
00426 logPr=log(Pr);
00427
00428 ng.size=maxstatesize;
00429 std::cout << "recombine= " << maxstatesize << " " << ng << " p= " << logPr << std::endl;
00430
00431 if ((n % 10000000)==0){
00432 std::cerr << "check cache levels" << std::endl;
00433 for (i=0;i<N;i++) lmt[i]->check_cache_levels();
00434 }
00435
00436
00437 std::cout << "> ";
00438 }
00439
00440
00441 }
00442
00443 for (int i=0;i<N;i++) delete lmt[i];
00444
00445 return 0;
00446 }
00447
00448 lmtable *load_lm(std::string file,int dub,int memmap) {
00449 inputfilestream inplm(file.c_str());
00450 std::cerr << "Reading " << file.c_str() << "..." << std::endl;
00451 lmtable *lmt=new lmtable;
00452 if (file.compare(file.size()-3,3,".mm")==0)
00453 lmt->load(inplm,file.c_str(),NULL,1,NONE);
00454 else
00455 lmt->load(inplm,file.c_str(),NULL,memmap,NONE);
00456 if (dub) lmt->setlogOOVpenalty(dub);
00457 lmt->init_probcache();
00458 return lmt;
00459 }