00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021 using namespace std;
00022
00023 #include <cmath>
00024 #include <math.h>
00025 #include "mfstream.h"
00026 #include <fstream>
00027 #include <stdio.h>
00028 #include <iostream>
00029 #include "mempool.h"
00030 #include "htable.h"
00031 #include "dictionary.h"
00032 #include "n_gram.h"
00033 #include "mempool.h"
00034 #include "ngramcache.h"
00035 #include "ngramtable.h"
00036 #include "interplm.h"
00037 #include "normcache.h"
00038 #include "mdiadapt.h"
00039 #include "shiftlm.h"
00040 #include "linearlm.h"
00041 #include "mixture.h"
00042 #include "cmd.h"
00043 #include "lmtable.h"
00044
00045
// Integer values stored by the yes/no command-line options (see BooleanEnum).
#define YES 1
#define NO 0


// NOTE(review): NGRAM..TEXT appear unused in this file; TEXT (5) is a
// different constant from the TXT file type (14) defined further below.
#define NGRAM 1
#define SEQUENCE 2
#define ADAPT 3
#define TURN 4
#define TEXT 5


// Closing entry of every Enum_T table below (presumably the sentinel the
// cmd.h parser stops at -- confirm against cmd.h).
#define END_ENUM { (char*)0, 0 }

// Accepted spellings for the yes/no options (Back-off, PruneSingletons,
// PruneTopSingletons, ComputeLMSize, MemoryMap, CheckProb).
static Enum_T BooleanEnum [] = {
{ "Yes", YES },
{ "No", NO},
{ "yes", YES },
{ "no", NO},
{ "y", YES },
{ "n", NO},
END_ENUM
};

// Names accepted by -LanguageModelType / -lm, mapped to LM type codes that
// come from the project headers.
// NOTE(review): "LinearGoodTuring" and "Mixture"/"mix" are accepted here,
// but init() cannot build those model types.
static Enum_T LmTypeEnum [] = {
{ "ModifiedShiftBeta", MOD_SHIFT_BETA },
{ "msb", MOD_SHIFT_BETA },
{ "InterpShiftBeta", SHIFT_BETA },
{ "ShiftBeta", SHIFT_BETA },
{ "sb", SHIFT_BETA },
{ "InterpShiftOne", SHIFT_ONE },
{ "ShiftOne", SHIFT_ONE },
{ "s1", SHIFT_ONE },
{ "LinearWittenBell", LINEAR_WB },
{ "wb", LINEAR_WB },
{ "LinearGoodTuring", LINEAR_GT },
{ "Mixture", MIXTURE },
{ "mix", MIXTURE },
END_ENUM
};


// Command codes selected by the keyword that follows an "@CMD@" token on stdin.
#define RESET 1
#define SAVE 2
#define LOAD 3
#define INIT 4
#define STOP 5

// File formats selected by the SAVE*/LOAD* command keywords.
#define BIN 11
#define ARPA 12
#define ASR 13
#define TXT 14
#define NGT 15


// init() allocates a new LM into *lm; deinit() deletes it. Both are defined after main().
int init(mdiadaptlm** lm, int lmtype, char *trainfile, int size, int prunefreq, double beta, int backoff, int dub, double oovrate, int mcl);
int deinit(mdiadaptlm** lm);
00102
00103 int main(int argc, char **argv)
00104 {
00105
00106 char *dictfile=NULL;
00107 char *trainfile=NULL;
00108
00109 char *BINfile=NULL;
00110 char *ARPAfile=NULL;
00111 char *ASRfile=NULL;
00112
00113 int backoff=0;
00114 int lmtype=0;
00115 int dub=0;
00116 int size=0;
00117
00118 int statistics=0;
00119
00120 int prunefreq=NO;
00121 int prunesingletons=YES;
00122 int prunetopsingletons=NO;
00123
00124 double beta=-1;
00125
00126 int compsize=NO;
00127 int checkpr=NO;
00128 double oovrate=0;
00129 int max_caching_level=0;
00130
00131 char *outpr=NULL;
00132
00133 int memmap = 0;
00134
00135 DeclareParams(
00136
00137 "Back-off",CMDENUMTYPE, &backoff, BooleanEnum,
00138 "bo",CMDENUMTYPE, &backoff, BooleanEnum,
00139
00140 "Dictionary", CMDSTRINGTYPE, &dictfile,
00141 "d", CMDSTRINGTYPE, &dictfile,
00142
00143 "DictionaryUpperBound", CMDINTTYPE, &dub,
00144 "dub", CMDINTTYPE, &dub,
00145
00146 "NgramSize", CMDSUBRANGETYPE, &size, 1 , MAX_NGRAM,
00147 "n", CMDSUBRANGETYPE, &size, 1 , MAX_NGRAM,
00148
00149 "Ngram", CMDSTRINGTYPE, &trainfile,
00150 "TrainOn", CMDSTRINGTYPE, &trainfile,
00151 "tr", CMDSTRINGTYPE, &trainfile,
00152
00153 "oASR", CMDSTRINGTYPE, &ASRfile,
00154 "oasr", CMDSTRINGTYPE, &ASRfile,
00155
00156 "o", CMDSTRINGTYPE, &ARPAfile,
00157 "oARPA", CMDSTRINGTYPE, &ARPAfile,
00158 "oarpa", CMDSTRINGTYPE, &ARPAfile,
00159
00160 "oBIN", CMDSTRINGTYPE, &BINfile,
00161 "obin", CMDSTRINGTYPE, &BINfile,
00162
00163 "LanguageModelType",CMDENUMTYPE, &lmtype, LmTypeEnum,
00164 "lm",CMDENUMTYPE, &lmtype, LmTypeEnum,
00165
00166 "Statistics",CMDSUBRANGETYPE, &statistics, 1 , 3,
00167 "s",CMDSUBRANGETYPE, &statistics, 1 , 3,
00168
00169 "PruneThresh",CMDSUBRANGETYPE, &prunefreq, 1 , 1000,
00170 "p",CMDSUBRANGETYPE, &prunefreq, 1 , 1000,
00171
00172 "PruneSingletons",CMDENUMTYPE, &prunesingletons, BooleanEnum,
00173 "ps",CMDENUMTYPE, &prunesingletons, BooleanEnum,
00174
00175 "PruneTopSingletons",CMDENUMTYPE, &prunetopsingletons, BooleanEnum,
00176 "pts",CMDENUMTYPE, &prunetopsingletons, BooleanEnum,
00177
00178 "ComputeLMSize",CMDENUMTYPE, &compsize, BooleanEnum,
00179 "sz",CMDENUMTYPE, &compsize, BooleanEnum,
00180
00181 "MaximumCachingLevel", CMDINTTYPE , &max_caching_level,
00182 "mcl", CMDINTTYPE, &max_caching_level,
00183
00184 "MemoryMap", CMDENUMTYPE, &memmap, BooleanEnum,
00185 "memmap", CMDENUMTYPE, &memmap, BooleanEnum,
00186 "mm", CMDENUMTYPE, &memmap, BooleanEnum,
00187
00188 "CheckProb",CMDENUMTYPE, &checkpr, BooleanEnum,
00189 "cp",CMDENUMTYPE, &checkpr, BooleanEnum,
00190
00191 "OutProb",CMDSTRINGTYPE, &outpr,
00192 "op",CMDSTRINGTYPE, &outpr,
00193
00194 "SetOovRate", CMDDOUBLETYPE, &oovrate,
00195 "or", CMDDOUBLETYPE, &oovrate,
00196
00197 "Beta", CMDDOUBLETYPE, &beta,
00198 "beta", CMDDOUBLETYPE, &beta,
00199
00200 (char *)NULL
00201 );
00202
00203 GetParams(&argc, &argv, (char*) NULL);
00204
00205 if (!lmtype) {
00206 cerr <<"Missing parameters\n";
00207 exit(1);
00208 }
00209
00210
00211 cerr <<"LM size: " << size << "\n";
00212
00213
00214 char header[BUFSIZ];
00215 char filename[BUFSIZ];
00216 int cmdcounter=0;
00217 mdiadaptlm *lm=NULL;
00218
00219
00220 int cmdtype=INIT;
00221 int filetype=0;
00222 int BoSfreq=0;
00223
00224 init(&lm, lmtype, trainfile, size, prunefreq, beta, backoff, dub, oovrate, max_caching_level);
00225
00226 ngram ng(lm->dict), ng2(lm->dict);
00227
00228 cerr << "filling the initial n-grams with BoS\n";
00229 for (int i=1; i<lm->maxlevel(); i++) {
00230 ng.pushw(lm->dict->BoS());
00231 ng.freq=1;
00232 }
00233
00234 mfstream inp("/dev/stdin",ios::in );
00235 int c=0;
00236
00237 while (inp >> header) {
00238
00239 if (strncmp(header,"@CMD@",5)==0) {
00240 cmdcounter++;
00241 inp >> header;
00242
00243 cerr << "Read |@CMD@| |" << header << "|";
00244
00245 cmdtype=INIT;
00246 filetype=BIN;
00247 if (strncmp(header,"RESET",5)==0) cmdtype=RESET;
00248 else if (strncmp(header,"INIT",4)==0) cmdtype=INIT;
00249 else if (strncmp(header,"SAVEBIN",7)==0) {
00250 cmdtype=SAVE;
00251 filetype=BIN;
00252 } else if (strncmp(header,"SAVEARPA",8)==0) {
00253 cmdtype=SAVE;
00254 filetype=ARPA;
00255 } else if (strncmp(header,"SAVEASR",7)==0) {
00256 cmdtype=SAVE;
00257 filetype=ASR;
00258 } else if (strncmp(header,"SAVENGT",7)==0) {
00259 cmdtype=SAVE;
00260 filetype=NGT;
00261 } else if (strncmp(header,"LOADNGT",7)==0) {
00262 cmdtype=LOAD;
00263 filetype=NGT;
00264 } else if (strncmp(header,"LOADTXT",7)==0) {
00265 cmdtype=LOAD;
00266 filetype=TXT;
00267 } else if (strncmp(header,"STOP",4)==0) cmdtype=STOP;
00268 else {
00269 cerr << "CMD " << header << " is unknown\n";
00270 exit(1);
00271 }
00272
00273 char** lastwords;
00274 char *isym;
00275 switch (cmdtype) {
00276
00277 case STOP:
00278 cerr << "\n";
00279 exit(1);
00280 break;
00281
00282 case SAVE:
00283
00284 inp >> filename;
00285 cerr << " |" << filename << "|\n";
00286
00287
00288 char tmpngtfile[BUFSIZ];
00289 sprintf(tmpngtfile,"%s.ngt%d",filename,cmdcounter);
00290 cerr << "saving temporary ngramtable (binary)..." << tmpngtfile << "\n";
00291 ((ngramtable*) lm)->ngtype("ngram");
00292 ((ngramtable*) lm)->savetxt(tmpngtfile,size);
00293
00294
00295 BoSfreq=lm->dict->freq(lm->dict->encode(lm->dict->BoS()));
00296
00297 lm->train();
00298
00299 lm->prunesingletons(prunesingletons==YES);
00300 lm->prunetopsingletons(prunetopsingletons==YES);
00301
00302 if (prunetopsingletons==YES)
00303 lm->prunesingletons(NO);
00304
00305
00306 switch (filetype) {
00307
00308 case BIN:
00309 cerr << "saving lm (binary) ... " << filename << "\n";
00310 lm->saveBIN(filename,backoff,dictfile,memmap);
00311 cerr << "\n";
00312 break;
00313
00314 case ARPA:
00315 cerr << "save lm (ARPA)... " << filename << "\n";
00316 lm->saveARPA(filename,backoff,dictfile);
00317 cerr << "\n";
00318 break;
00319
00320 case ASR:
00321 cerr << "save lm (ASR)... " << filename << "\n";
00322 lm->saveASR(filename,backoff,dictfile);
00323 cerr << "\n";
00324 break;
00325
00326 case NGT:
00327 cerr << "save the ngramtable on ... " << filename << "\n";
00328 {
00329 ifstream ifs(tmpngtfile, ios::binary);
00330 std::ofstream ofs(filename, std::ios::binary);
00331 ofs << ifs.rdbuf();
00332 }
00333 cerr << "\n";
00334 break;
00335
00336 default:
00337 cerr << "Saving type is unknown\n";
00338 exit(1);
00339 };
00340
00341
00342 ng.size=(ng.size>lm->maxlevel())?lm->maxlevel():ng.size;
00343 lastwords = new char*[lm->maxlevel()];
00344
00345 for (int i=1; i<lm->maxlevel(); i++) {
00346 lastwords[i] = new char[BUFSIZ];
00347 if (i<=ng.size)
00348 strcpy(lastwords[i],lm->dict->decode(*ng.wordp(i)));
00349 else
00350 strcpy(lastwords[i],lm->dict->BoS());
00351 }
00352
00353 deinit(&lm);
00354
00355 init(&lm, lmtype, tmpngtfile, size, prunefreq, beta, backoff, dub, oovrate, max_caching_level);
00356 if (remove(tmpngtfile) != 0)
00357 cerr << "Error deleting file " << tmpngtfile << endl;
00358 else
00359 cerr << "File " << tmpngtfile << " successfully deleted" << endl;
00360
00361
00362 ng.dict=ng2.dict=lm->dict;
00363 ng.size=lm->maxlevel();
00364
00365
00366 for (int i=1; i<lm->maxlevel(); i++) {
00367 *ng.wordp(i)=lm->dict->encode(lastwords[i]);
00368 delete []lastwords[i];
00369 }
00370 delete []lastwords;
00371
00372
00373
00374 lm->dict->freq(lm->dict->encode(lm->dict->BoS()), BoSfreq);
00375 break;
00376
00377
00378 case RESET:
00379 deinit(&lm);
00380
00381 init(&lm, lmtype, NULL, size, prunefreq, beta, backoff, dub, oovrate, max_caching_level);
00382
00383 ng.dict=ng2.dict=lm->dict;
00384 cerr << "filling the initial n-grams with BoS\n";
00385 for (int i=1; i<lm->maxlevel(); i++) {
00386 ng.pushw(lm->dict->BoS());
00387 ng.freq=1;
00388 }
00389 break;
00390
00391
00392 case INIT:
00393 cerr << "CMD " << header << " not yet implemented\n";
00394 exit(1);
00395 break;
00396
00397 case LOAD:
00398 inp >> filename;
00399 cerr << " |" << filename << "|\n";
00400
00401
00402 isym=new char[BUFSIZ];
00403 strcpy(isym,lm->dict->EoS());
00404 ngramtable* ngt;
00405
00406 switch (filetype) {
00407
00408 case NGT:
00409 cerr << "loading an ngramtable..." << filename << "\n";
00410 ngt = new ngramtable(filename,size,isym,NULL,NULL);
00411 ((ngramtable*) lm)->augment(ngt);
00412 cerr << "\n";
00413 break;
00414
00415 case TXT:
00416 cerr << "loading from text..." << filename << "\n";
00417 ngt= new ngramtable(filename,size,isym,NULL,NULL);
00418 ((ngramtable*) lm)->augment(ngt);
00419 cerr << "\n";
00420 break;
00421
00422 default:
00423 cerr << "This file type is unknown\n";
00424 exit(1);
00425 };
00426
00427 break;
00428
00429 default:
00430 cerr << "CMD " << header << " is unknown\n";
00431 exit(1);
00432 };
00433 } else {
00434 ng.pushw(header);
00435
00436
00437 ng2.trans(ng);
00438
00439 lm->check_dictsize_bound();
00440
00441
00442 if (ng.size) lm->dict->incfreq(*ng2.wordp(1),1);
00443
00444
00445
00446
00447
00448
00449
00450
00451
00452 lm->put(ng2);
00453
00454 if (!(++c % 1000000)) cerr << ".";
00455 }
00456 }
00457
00458 if (statistics) {
00459 cerr << "TLM: lm stat ...";
00460 lm->lmstat(statistics);
00461 cerr << "\n";
00462 }
00463
00464 cerr << "TLM: deleting lm ...";
00465
00466 cerr << "\n";
00467
00468 exit(0);
00469 }
00470
00471 int init(mdiadaptlm** lm, int lmtype, char *trainfile, int size, int prunefreq, double beta, int backoff, int dub, double oovrate, int mcl)
00472 {
00473
00474 cerr << "initializing lm... \n";
00475 if (trainfile) cerr << "creating lm from " << trainfile << "\n";
00476 else cerr << "creating an empty lm\n";
00477 switch (lmtype) {
00478
00479 case SHIFT_BETA:
00480 if (beta==-1 || (beta<1.0 && beta>0))
00481 *lm=new shiftbeta(trainfile,size,prunefreq,beta,(backoff?SHIFTBETA_B:SHIFTBETA_I));
00482 else {
00483 cerr << "ShiftBeta: beta must be >0 and <1\n";
00484 exit(1);
00485 }
00486 break;
00487
00488 case MOD_SHIFT_BETA:
00489 if (size>1)
00490 *lm=new mshiftbeta(trainfile,size,prunefreq,(backoff?MSHIFTBETA_B:MSHIFTBETA_I));
00491 else {
00492 cerr << "Modified Shift Beta requires size > 1!\n";
00493 exit(1);
00494 }
00495 break;
00496
00497 case SHIFT_ONE:
00498 *lm=new shiftone(trainfile,size,prunefreq,(backoff?SIMPLE_B:SIMPLE_I));
00499 break;
00500
00501 case LINEAR_WB:
00502 *lm=new linearwb(trainfile,size,prunefreq,(backoff?MSHIFTBETA_B:MSHIFTBETA_I));
00503 break;
00504
00505 case LINEAR_GT:
00506 cerr << "This LM is no more supported\n";
00507 break;
00508
00509 case MIXTURE:
00510 cerr << "not implemented yet\n";
00511 break;
00512
00513 default:
00514 cerr << "not implemented yet\n";
00515 exit(1);
00516 };
00517
00518 if (dub) (*lm)->dub(dub);
00519 (*lm)->create_caches(mcl);
00520
00521 cerr << "eventually generate OOV code\n";
00522 (*lm)->dict->genoovcode();
00523
00524 if (oovrate) (*lm)->dict->setoovrate(oovrate);
00525
00526 (*lm)->dict->incflag(1);
00527
00528 if (!trainfile) {
00529 cerr << "adding the initial dummy n-grams to make table consistent\n";
00530
00531 ngram dummyng((*lm)->dict);
00532 cerr << "preparing initial dummy n-grams\n";
00533 for (int i=1; i<(*lm)->maxlevel(); i++) {
00534 dummyng.pushw((*lm)->dict->BoS());
00535 dummyng.freq=1;
00536 }
00537 cerr << "inside init: dict: " << (*lm)->dict << " dictsize: " << (*lm)->dict->size() << "\n";
00538 cerr << "dummyng: |" << dummyng << "\n";
00539 (*lm)->put(dummyng);
00540 cerr << "inside init: dict: " << (*lm)->dict << " dictsize: " << (*lm)->dict->size() << "\n";
00541
00542 }
00543
00544 cerr << "lm initialized \n";
00545 return 1;
00546 }
00547
00548 int deinit(mdiadaptlm** lm)
00549 {
00550 delete *lm;
00551 return 1;
00552 }