00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <map>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>

#include "IOWrapper.h"
#include "LatticeMBR.h"
#include "moses/Manager.h"
#include "moses/StaticData.h"
00053
00054
00055 using namespace std;
00056 using namespace Moses;
00057 using namespace MosesCmd;
00058
00059
/** Identifies each lattice MBR setting that can take part in the grid search. */
enum gridkey {lmbr_p,lmbr_r,lmbr_prune,lmbr_scale};

namespace MosesCmd
{

/**
 * Maps command-line switches to grid keys and stores, for every key, the
 * list of values to be tried. Each key starts with a single default value
 * which is replaced when the corresponding switch is parsed.
 */
class Grid
{
public:
  /**
   * Register a grid parameter.
   *
   * @param key           identifier under which the values are stored
   * @param arg           command-line switch that selects this parameter
   * @param defaultValue  sole grid value used when the switch is absent
   * @throws std::logic_error if key has already been registered
   */
  void addParam(gridkey key, const std::string& arg, float defaultValue) {
    // Validate before touching any member so a failed call leaves no trace.
    if (m_grid.find(key) != m_grid.end()) {
      throw std::logic_error("Grid parameter registered twice");
    }
    m_args[arg] = key;
    m_grid[key].push_back(defaultValue);
  }

  /**
   * Strip the registered grid switches (and their comma-separated value
   * lists) from the argument vector and store the parsed values.
   *
   * On success argc/argv are replaced by a freshly allocated, NULL-terminated
   * copy of the remaining arguments. The new array and its strings are
   * allocated with new[] and are never released by this class. On failure
   * argc/argv are left untouched.
   *
   * @throws std::runtime_error if a switch has no value, appears more than
   *         once, or carries an empty or non-numeric value list
   */
  void parseArgs(int& argc, char**& argv) {
    // Collect surviving arguments first; the raw array is only built once
    // parsing has succeeded, so nothing leaks when an exception is thrown.
    std::vector<std::string> kept;
    for (int i = 0; i < argc; ++i) {
      const std::map<std::string,gridkey>::const_iterator argi = m_args.find(argv[i]);
      if (argi == m_args.end()) {
        kept.push_back(argv[i]);
        continue;
      }
      ++i; // step onto the value that belongs to the switch
      if (i >= argc) {
        std::cerr << "Error: missing parameter for " << argi->first << std::endl;
        throw std::runtime_error("Missing parameter");
      }
      const gridkey key = argi->second;
      if (m_parsed.count(key) != 0) {
        throw std::runtime_error("Duplicate grid argument");
      }
      parseValues(argi->first, argv[i], m_grid[key]);
      m_parsed.insert(key);
    }

    // One extra slot for the terminating NULL pointer.
    char** newargv = new char*[kept.size() + 1];
    for (std::size_t j = 0; j < kept.size(); ++j) {
      newargv[j] = new char[kept[j].size() + 1];
      std::strcpy(newargv[j], kept[j].c_str());
    }
    newargv[kept.size()] = NULL;
    argc = static_cast<int>(kept.size());
    argv = newargv;
  }

  /**
   * Return the values to try for the given key.
   *
   * @throws std::logic_error if key was never registered with addParam()
   */
  const std::vector<float>& getGrid(gridkey key) const {
    const std::map<gridkey,std::vector<float> >::const_iterator iter = m_grid.find(key);
    if (iter == m_grid.end()) {
      throw std::logic_error("Unknown grid key");
    }
    return iter->second;
  }

private:
  /**
   * Parse a comma-separated list of numbers into out. Empty fields between
   * delimiters are skipped. out is only modified when the whole list is valid.
   */
  static void parseValues(const std::string& arg, const std::string& value,
                          std::vector<float>& out) {
    const char delim = ',';
    std::vector<float> parsed;
    std::string::size_type start = value.find_first_not_of(delim);
    while (start != std::string::npos) {
      const std::string::size_type stop = value.find_first_of(delim, start);
      // When stop is npos the count is clamped to the rest of the string.
      const std::string token = value.substr(start, stop - start);
      char* tail = NULL;
      const double number = std::strtod(token.c_str(), &tail);
      // Use the end pointer rather than the result to detect failure, so
      // that a genuine 0 is accepted and trailing garbage is rejected.
      if (tail == token.c_str() || *tail != '\0') {
        std::cerr << "Error: Illegal grid parameter for " << arg << std::endl;
        throw std::runtime_error("Illegal grid parameter");
      }
      parsed.push_back(static_cast<float>(number));
      start = value.find_first_not_of(delim, stop);
    }
    if (parsed.empty()) {
      // A list with no values would leave nothing to search over.
      std::cerr << "Error: Illegal grid parameter for " << arg << std::endl;
      throw std::runtime_error("Illegal grid parameter");
    }
    out.swap(parsed);
  }

  std::map<gridkey,std::vector<float> > m_grid;  // values to try, per key
  std::map<std::string,gridkey> m_args;          // switch name -> key
  std::set<gridkey> m_parsed;                    // keys already set from the command line
};

}
00136
00137 int main(int argc, char* argv[])
00138 {
00139 cerr << "Lattice MBR Grid search" << endl;
00140
00141 Grid grid;
00142 grid.addParam(lmbr_p, "-lmbr-p", 0.5);
00143 grid.addParam(lmbr_r, "-lmbr-r", 0.5);
00144 grid.addParam(lmbr_prune, "-lmbr-pruning-factor",30.0);
00145 grid.addParam(lmbr_scale, "-mbr-scale",1.0);
00146
00147 grid.parseArgs(argc,argv);
00148
00149 Parameter* params = new Parameter();
00150 if (!params->LoadParam(argc,argv)) {
00151 params->Explain();
00152 exit(1);
00153 }
00154 if (!StaticData::LoadDataStatic(params, argv[0])) {
00155 exit(1);
00156 }
00157
00158 StaticData& staticData = const_cast<StaticData&>(StaticData::Instance());
00159 staticData.SetUseLatticeMBR(true);
00160 IOWrapper* ioWrapper = GetIOWrapper(staticData);
00161
00162 if (!ioWrapper) {
00163 throw runtime_error("Failed to initialise IOWrapper");
00164 }
00165 size_t nBestSize = staticData.GetMBRSize();
00166
00167 if (nBestSize <= 0) {
00168 throw new runtime_error("Non-positive size specified for n-best list");
00169 }
00170
00171 size_t lineCount = 0;
00172 InputType* source = NULL;
00173
00174 const vector<float>& pgrid = grid.getGrid(lmbr_p);
00175 const vector<float>& rgrid = grid.getGrid(lmbr_r);
00176 const vector<float>& prune_grid = grid.getGrid(lmbr_prune);
00177 const vector<float>& scale_grid = grid.getGrid(lmbr_scale);
00178
00179 while(ReadInput(*ioWrapper,staticData.GetInputType(),source)) {
00180 ++lineCount;
00181 Sentence sentence;
00182 Manager manager(lineCount, *source, staticData.GetSearchAlgorithm());
00183 manager.ProcessSentence();
00184 TrellisPathList nBestList;
00185 manager.CalcNBest(nBestSize, nBestList,true);
00186
00187 for (vector<float>::const_iterator pi = pgrid.begin(); pi != pgrid.end(); ++pi) {
00188 float p = *pi;
00189 staticData.SetLatticeMBRPrecision(p);
00190 for (vector<float>::const_iterator ri = rgrid.begin(); ri != rgrid.end(); ++ri) {
00191 float r = *ri;
00192 staticData.SetLatticeMBRPRatio(r);
00193 for (vector<float>::const_iterator prune_i = prune_grid.begin(); prune_i != prune_grid.end(); ++prune_i) {
00194 size_t prune = (size_t)(*prune_i);
00195 staticData.SetLatticeMBRPruningFactor(prune);
00196 for (vector<float>::const_iterator scale_i = scale_grid.begin(); scale_i != scale_grid.end(); ++scale_i) {
00197 float scale = *scale_i;
00198 staticData.SetMBRScale(scale);
00199 cout << lineCount << " ||| " << p << " " << r << " " << prune << " " << scale << " ||| ";
00200 vector<Word> mbrBestHypo = doLatticeMBR(manager,nBestList);
00201 OutputBestHypo(mbrBestHypo, lineCount, staticData.GetReportSegmentation(),
00202 staticData.GetReportAllFactors(),cout);
00203 }
00204 }
00205
00206 }
00207 }
00208
00209
00210 }
00211
00212 }