00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029 using namespace std;
00030
00031 #include <cmath>
00032 #include "util.h"
00033 #include <sstream>
00034 #include "mfstream.h"
00035 #include "mempool.h"
00036 #include "htable.h"
00037 #include "dictionary.h"
00038 #include "n_gram.h"
00039 #include "ngramtable.h"
00040 #include "cmd.h"
00041
00042 #define YES 1
00043 #define NO 0
00044
00045 void print_help(int TypeFlag=0){
00046 std::cerr << std::endl << "dtsel - performs data selection" << std::endl;
00047 std::cerr << std::endl << "USAGE:" << std::endl
00048 << " dtsel -s=<outfile> [options]" << std::endl;
00049 std::cerr << std::endl << "OPTIONS:" << std::endl;
00050 FullPrintParams(TypeFlag, 0, 1, stderr);
00051 }
00052
00053 void usage(const char *msg = 0)
00054 {
00055 if (msg){
00056 std::cerr << msg << std::endl;
00057 }
00058 else{
00059 print_help();
00060 }
00061 exit(1);
00062 }
00063
00064 double prob(ngramtable* ngt,ngram ng,int size,int cv){
00065 double fstar,lambda;
00066
00067 assert(size<=ngt->maxlevel() && size<=ng.size);
00068 if (size>1){
00069 ngram history=ng;
00070 if (ngt->get(history,size,size-1) && history.freq>cv){
00071 fstar=0.0;
00072 if (ngt->get(ng,size,size)){
00073 cv=(cv>ng.freq)?ng.freq:cv;
00074 if (ng.freq>cv){
00075 fstar=(double)(ng.freq-cv)/(double)(history.freq -cv + history.succ);
00076 lambda=(double)history.succ/(double)(history.freq -cv + history.succ);
00077 }else
00078 lambda=(double)(history.succ-1)/(double)(history.freq -cv + history.succ-1);
00079 }
00080 else
00081 lambda=(double)history.succ/(double)(history.freq -cv + history.succ);
00082
00083 return fstar + lambda * prob(ngt,ng,size-1,cv);
00084 }
00085 else return prob(ngt,ng,size-1,cv);
00086
00087 }else{
00088 if (ngt->get(ng,1,1) && ng.freq>cv)
00089 return (double)(ng.freq-cv)/(ngt->totfreq()-1);
00090 else{
00091
00092 *ng.wordp(1)=ngt->dict->oovcode();
00093 if (ngt->get(ng,1,1) && ng.freq>0)
00094 return (double)ng.freq/ngt->totfreq();
00095 else
00096 return (double)ngt->dict->size()/(ngt->totfreq()+ngt->dict->size());
00097 }
00098
00099 }
00100
00101 }
00102
00103
00104 double computePP(ngramtable* train,ngramtable* test,double oovpenalty,double& oovrate,int cv=0){
00105
00106
00107 ngram ng2(test->dict);ngram ng1(train->dict);
00108 int N=0; double H=0; oovrate=0;
00109
00110 test->scan(ng2,INIT,test->maxlevel());
00111 while(test->scan(ng2,CONT,test->maxlevel())) {
00112
00113 ng1.trans(ng2);
00114 H-=log(prob(train,ng1,ng1.size,cv));
00115 if (*ng1.wordp(1)==train->dict->oovcode()){
00116 H-=oovpenalty;
00117 oovrate++;
00118 }
00119 N++;
00120 }
00121 oovrate/=N;
00122 return exp(H/N);
00123 }
00124
00125
00126 int main(int argc, char **argv)
00127 {
00128 char *indom=NULL;
00129 char *outdom=NULL;
00130 char *scorefile=NULL;
00131 char *evalset=NULL;
00132
00133 int minfreq=2;
00134 int ngsz=0;
00135 int dub=10000000;
00136 int model=2;
00137
00138 int cv=1;
00139
00140 int blocksize=100000;
00141 int verbose=0;
00142 int useindex=0;
00143 double convergence_treshold=0;
00144
00145 bool help=false;
00146
00147 DeclareParams((char*)
00148 "min-word-freq", CMDINTTYPE|CMDMSG, &minfreq, "frequency threshold for dictionary pruning, default: 2",
00149 "f", CMDINTTYPE|CMDMSG, &minfreq, "frequency threshold for dictionary pruning, default: 2",
00150
00151 "ngram-order", CMDSUBRANGETYPE|CMDMSG, &ngsz, 1 , MAX_NGRAM, "n-gram default size, default: 0",
00152 "n", CMDSUBRANGETYPE|CMDMSG, &ngsz, 1 , MAX_NGRAM, "n-gram default size, default: 0",
00153
00154 "in-domain-file", CMDSTRINGTYPE|CMDMSG, &indom, "indomain data file: one sentence per line",
00155 "i", CMDSTRINGTYPE|CMDMSG, &indom, "indomain data file: one sentence per line",
00156
00157 "out-domain-file", CMDSTRINGTYPE|CMDMSG, &outdom, "domain data file: one sentence per line",
00158 "o", CMDSTRINGTYPE|CMDMSG, &outdom, "domain data file: one sentence per line",
00159
00160 "score-file", CMDSTRINGTYPE|CMDMSG, &scorefile, "score output file",
00161 "s", CMDSTRINGTYPE|CMDMSG, &scorefile, "score output file",
00162
00163 "dictionary-upper-bound", CMDINTTYPE|CMDMSG, &dub, "upper bound of true vocabulary, default: 10000000",
00164 "dub", CMDINTTYPE|CMDMSG, &dub, "upper bound of true vocabulary, default: 10000000",
00165
00166 "model", CMDSUBRANGETYPE|CMDMSG, &model, 1 , 2, "data selection model: 1 only in-domain cross-entropy, 2 cross-entropy difference; default: 2",
00167 "m", CMDSUBRANGETYPE|CMDMSG, &model, 1 , 2, "data selection model: 1 only in-domain cross-entropy, 2 cross-entropy difference; default: 2",
00168
00169 "cross-validation", CMDSUBRANGETYPE|CMDMSG, &cv, 1 , 3, "cross-validation parameter: 1 only in-domain cross-entropy; default: 1",
00170 "cv", CMDSUBRANGETYPE|CMDMSG, &cv, 1 , 3, "cross-validation parameter: 1 only in-domain cross-entropy; default: 1",
00171
00172 "test", CMDSTRINGTYPE|CMDMSG, &evalset, "evaluation set file to measure performance",
00173 "t", CMDSTRINGTYPE|CMDMSG, &evalset, "evaluation set file to measure performance",
00174
00175 "block-size", CMDINTTYPE|CMDMSG, &blocksize, "block-size in words, default: 100000",
00176 "bs", CMDINTTYPE|CMDMSG, &blocksize, "block-size in words, default: 100000",
00177
00178 "convergence-threshold", CMDDOUBLETYPE|CMDMSG, &convergence_treshold, "convergence threshold, default: 0",
00179 "c", CMDDOUBLETYPE|CMDMSG, &convergence_treshold, "convergence threshold, default: 0",
00180
00181 "index", CMDSUBRANGETYPE|CMDMSG, &useindex,0,1, "provided score file includes and index, default: 0",
00182 "x", CMDSUBRANGETYPE|CMDMSG, &useindex,0,1, "provided score file includes and index, default: 0",
00183
00184 "verbose", CMDSUBRANGETYPE|CMDMSG, &verbose,0,2, "verbose level, default: 0",
00185 "v", CMDSUBRANGETYPE|CMDMSG, &verbose,0,2, "verbose level, default: 0",
00186 "Help", CMDBOOLTYPE|CMDMSG, &help, "print this help",
00187 "h", CMDBOOLTYPE|CMDMSG, &help, "print this help",
00188
00189 (char *)NULL
00190 );
00191
00192
00193
00194 GetParams(&argc, &argv, (char*) NULL);
00195
00196 if (help){
00197 usage();
00198 }
00199 if (scorefile==NULL) {
00200 usage();
00201 }
00202
00203 if (!evalset && (!indom || !outdom)){
00204 cerr <<"Must specify in-domain and out-domain data files\n";
00205 exit(1);
00206 };
00207
00208
00209 if (!scorefile){
00210 cerr <<"Must specify score file\n";
00211 exit(1);
00212 };
00213
00214 if (!evalset && !model){
00215 cerr <<"Must specify data selection model\n";
00216 exit(1);
00217 }
00218
00219 if (evalset && (convergence_treshold<0 || convergence_treshold > 0.1)){
00220 cerr <<"Convergence threshold must be between 0 and 0.1. \n";
00221 exit(1);
00222 }
00223
00224 TABLETYPE table_type=COUNT;
00225
00226
00227 if (!evalset){
00228
00229
00230 dictionary *dict = new dictionary(indom,1000000,0);
00231 dictionary *pd=new dictionary(dict,true,minfreq);
00232 delete dict;dict=pd;
00233
00234
00235 ngramtable *indngt=new ngramtable(indom,ngsz,NULL,dict,NULL,0,0,NULL,0,table_type);
00236 double indoovpenalty=-log(dub-indngt->dict->size());
00237 ngram indng(indngt->dict);
00238 int indoovcode=indngt->dict->oovcode();
00239
00240
00241 char command[1000]="";
00242
00243 if (useindex)
00244 sprintf(command,"cut -d \" \" -f 2- %s",outdom);
00245 else
00246 sprintf(command,"%s",outdom);
00247
00248
00249 ngramtable *outdngt=new ngramtable(command,ngsz,NULL,dict,NULL,0,0,NULL,0,table_type);
00250 double outdoovpenalty=-log(dub-outdngt->dict->size());
00251 ngram outdng(outdngt->dict);
00252 int outdoovcode=outdngt->dict->oovcode();
00253
00254 cerr << "dict size idom: " << indngt->dict->size() << " odom: " << outdngt->dict->size() << "\n";
00255 cerr << "oov penalty idom: " << indoovpenalty << " odom: " << outdoovpenalty << "\n";
00256
00257
00258 int bos=dict->encode(dict->BoS());
00259 mfstream inp(outdom,ios::in); ngram ng(dict);
00260 mfstream txt(outdom,ios::in);
00261 mfstream output(scorefile,ios::out);
00262
00263
00264 int linenumber=1; string line;
00265 int lenght=0;float deltaH=0; float deltaHoov=0; int words=0;string index;
00266
00267 while (getline(inp,line)){
00268
00269 istringstream lninp(line);
00270
00271 linenumber++;
00272
00273 if (useindex) lninp >> index;
00274
00275
00276 ng.size=1; deltaH=0;deltaHoov=0; lenght=0;
00277
00278 while(lninp>>ng){
00279
00280 if (*ng.wordp(1)==bos) continue;
00281
00282 lenght++; words++;
00283
00284 if ((words % 1000000)==0) cerr << ".";
00285
00286 if (ng.size>ngsz) ng.size=ngsz;
00287 indng.trans(ng);outdng.trans(ng);
00288
00289 if (model==1){
00290 deltaH-=log(prob(indngt,indng,indng.size,0));
00291 deltaHoov-=(*indng.wordp(1)==indoovcode?indoovpenalty:0);
00292 }
00293
00294 if (model==2){
00295 deltaH+=log(prob(outdngt,outdng,outdng.size,cv))-log(prob(indngt,indng,indng.size,0));
00296 deltaHoov+=(*outdng.wordp(1)==outdoovcode?outdoovpenalty:0)-(*indng.wordp(1)==indoovcode?indoovpenalty:0);
00297 }
00298 }
00299
00300 output << (deltaH + deltaHoov)/lenght << " " << line << "\n";
00301 }
00302 }
00303 else{
00304
00305
00306 ngramtable *tstngt=new ngramtable(evalset,ngsz,NULL,NULL,NULL,0,0,NULL,0,table_type);
00307
00308
00309 ngramtable *outdngt=new ngramtable(NULL,ngsz,NULL,NULL,NULL,0,0,NULL,0,table_type);
00310
00311
00312 dictionary *dict = NULL;
00313 if (indom){
00314 cerr << "dtsel: limit evaluation dict to indomain words with freq >=" << minfreq << "\n";
00315
00316 dict = new dictionary(indom,1000000,0);
00317 dictionary *pd=new dictionary(dict,true,minfreq);
00318 delete dict;dict=pd;
00319 outdngt->dict=dict;
00320 }
00321
00322 dictionary* outddict=outdngt->dict;
00323
00324
00325 outddict->incflag(1);
00326 int bos=outddict->encode(outddict->BoS());
00327 int oov=outddict->encode(outddict->OOV());
00328 outddict->incflag(0);
00329 outddict->oovcode(oov);
00330
00331
00332 double oldPP=dub; double newPP=0; double oovrate=0;
00333
00334 long totwords=0; long totlines=0; long nextstep=blocksize;
00335
00336 double score; string index;
00337
00338 mfstream outd(scorefile,ios::in); string line;
00339
00340
00341 ngram ng(outdngt->dict); for (int i=1;i<ngsz;i++) ng.pushc(bos); ng.freq=1;
00342
00343
00344
00345 if (!dict) outddict->incflag(1);
00346
00347 while (getline(outd,line)){
00348
00349 istringstream lninp(line);
00350
00351
00352 lninp >> score; if (useindex) lninp >> index;
00353
00354 while (lninp >> ng){
00355
00356 if (*ng.wordp(1) == bos) continue;
00357
00358 if (ng.size>ngsz) ng.size=ngsz;
00359
00360 outdngt->put(ng);
00361
00362 totwords++;
00363 }
00364
00365 totlines++;
00366
00367 if (totwords>=nextstep){
00368
00369 if (!dict) outddict->incflag(0);
00370
00371 newPP=computePP(outdngt,tstngt,-log(dub-outddict->size()),oovrate);
00372
00373 if (!dict) outddict->incflag(1);
00374
00375 cout << totwords << " " << newPP;
00376 if (verbose) cout << " " << totlines << " " << oovrate;
00377 cout << "\n";
00378
00379 if (convergence_treshold && (oldPP-newPP)/oldPP < convergence_treshold) return 1;
00380
00381 oldPP=newPP;
00382
00383 nextstep+=blocksize;
00384 }
00385 }
00386
00387 if (!dict) outddict->incflag(0);
00388 newPP=computePP(outdngt,tstngt,-log(dub-outddict->size()),oovrate);
00389 cout << totwords << " " << newPP;
00390 if (verbose) cout << " " << totlines << " " << oovrate;
00391
00392 }
00393
00394 }
00395
00396
00397