00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023 using namespace std;
00024
00025 #include <iostream>
00026 #include <fstream>
00027 #include <vector>
00028 #include <string>
00029 #include <stdlib.h>
00030 #include "util.h"
00031 #include "math.h"
00032 #include "lmtable.h"
00033
00034
00035
// Raw comma-separated threshold list. handle_option() overwrites it when
// --threshold / -t is given; main() hands it to s2t() for parsing.
std::string spthr("0");

// Value of the --abs option (stored via atoi in handle_option());
// main() forwards it to lmtable::wdprune().
int aflag = 0;
00038
00039
/**
 * Print the prune-lm command-line help on standard error.
 *
 * @param msg  optional message printed on its own line before the help
 *             text; nothing extra is printed when it is null.
 *
 * The function only prints; callers decide whether to exit afterwards.
 */
void usage(const char *msg = 0) {
  if (msg) { std::cerr << msg << std::endl; }
  std::cerr << "Usage: prune-lm [--threshold=th2,th3,...] [--abs=1|0] input-file [output-file]" << std::endl << std::endl;
  std::cerr << " prune-lm reads a LM in either ARPA or compiled format and" << std::endl;
  std::cerr << " prunes out n-grams (n=2,3,..) for which backing-off to the" << std::endl;
  std::cerr << " lower order n-gram results in a small difference in probability." << std::endl;
  std::cerr << " The pruned LM is saved in ARPA format" << std::endl << std::endl;
  std::cerr << " Options:" << std::endl;
  // Help-text fixes: "threshods" -> "thresholds", closed the parenthesis,
  // and "less thresholds" -> "fewer thresholds".
  std::cerr << " --threshold=th2,th3,th4,... (pruning thresholds for 2-grams, 3-grams, 4-grams,...)" << std::endl;
  std::cerr << " If fewer thresholds are specified, the last one is " << std::endl;
  std::cerr << " applied to all following n-gram levels. " << std::endl << std::endl;
  std::cerr << " --abs=1|0 if 1, use absolute value of weighted difference" << std::endl;
}
00054
/**
 * Tell whether command-line argument `s` names the option `pre`.
 *
 * Despite the name this is not a plain prefix test: it succeeds only when
 * `s` is exactly `pre`, or when `s` begins with `pre` immediately followed
 * by '=' (the "--opt=value" form). "--absolute" therefore does not match
 * "--abs".
 */
bool starts_with(const std::string &s, const std::string &pre) {
  const std::string::size_type n = pre.size();

  // s must at least begin with the option name itself.
  if (s.size() < n || s.compare(0, n, pre) != 0) return false;

  // Either nothing follows the name, or the very next character is '='.
  return s.size() == n || s[n] == '=';
}
00063
00064 std::string get_param(const std::string& opt, int argc, const char **argv, int& argi)
00065 {
00066 std::string::size_type equals = opt.find_first_of('=');
00067 if (equals != std::string::npos && equals < opt.size()-1) {
00068 return opt.substr(equals+1);
00069 }
00070 std::string nexto;
00071 if (argi + 1 < argc) {
00072 nexto = argv[++argi];
00073 } else {
00074 usage((opt + " requires a value!").c_str());
00075 exit(1);
00076 }
00077 return nexto;
00078 }
00079
00080 void handle_option(const std::string& opt, int argc, const char **argv, int& argi)
00081 {
00082 if (opt == "--help" || opt == "-h") { usage(); exit(1); }
00083
00084 if (starts_with(opt, "--threshold") || starts_with(opt, "-t"))
00085 spthr = get_param(opt, argc, argv, argi);
00086 else if (starts_with(opt, "--abs"))
00087 aflag = atoi(get_param(opt, argc, argv, argi).c_str());
00088
00089 else {
00090 usage(("Don't understand option " + opt).c_str());
00091 exit(1);
00092 }
00093 }
00094
00095 void s2t(string cps,
00096 float *thr)
00097 {
00098 int i;
00099 char *s=strdup(cps.c_str()),
00100 *tk;
00101
00102 thr[0]=0;
00103 for(i=1,tk=strtok(s, ","); tk; tk=strtok(0, ","),i++) thr[i]=atof(tk);
00104 for(; i<MAX_NGRAM; i++) thr[i]=thr[i-1];
00105 }
00106
00107 int main(int argc, const char **argv)
00108 {
00109 float thr[MAX_NGRAM];
00110
00111 if (argc < 2) { usage(); exit(1); }
00112 std::vector<std::string> files;
00113 for(int i=1; i < argc; i++) {
00114 std::string opt = argv[i];
00115 if(opt[0] == '-') handle_option(opt, argc, argv, i);
00116 else files.push_back(opt);
00117 }
00118 if (files.size() > 2) { usage("Too many arguments"); exit(1); }
00119 if (files.size() < 1) { usage("Please specify a LM file to read from"); exit(1); }
00120 memset(thr, 0, sizeof(thr));
00121 if(spthr != "") s2t(spthr, thr);
00122 std::string infile = files[0];
00123 std::string outfile= "";
00124
00125 if (files.size() == 1) {
00126 outfile=infile;
00127
00128
00129 std::string::size_type p = outfile.rfind('/');
00130 if (p != std::string::npos && ((p+1) < outfile.size()))
00131 outfile.erase(0,p+1);
00132
00133
00134 if (outfile.compare(outfile.size()-3,3,".gz")==0)
00135 outfile.erase(outfile.size()-3,3);
00136
00137 outfile+=".plm";
00138 }
00139 else
00140 outfile = files[1];
00141
00142
00143 lmtable lmt;
00144 inputfilestream inp(infile.c_str());
00145 if (!inp.good()) {
00146 std::cerr << "Failed to open " << infile << "!" << std::endl;
00147 exit(1);
00148 }
00149
00150 lmt.load(inp,infile.c_str(),outfile.c_str(),0,NONE);
00151 std::cerr << "pruning LM with thresholds: \n";
00152
00153 for (int i=1;i<lmt.maxlevel();i++) std::cerr<< " " << thr[i];
00154 std::cerr << "\n";
00155 lmt.wdprune((float*)thr, aflag);
00156 lmt.savetxt(outfile.c_str());
00157 return 0;
00158 }
00159