00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
#include <cctype>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <map>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>

#include "IOWrapper.h"
#include "LatticeMBR.h"
#include "Manager.h"
#include "StaticData.h"
00053
00054
00055 using namespace std;
00056 using namespace Moses;
00057
00058
/// Identifiers for the lattice-MBR parameters that can be swept by the grid search.
enum gridkey {lmbr_p, lmbr_r, lmbr_prune, lmbr_scale};

/**
 * The list of values to try for each grid-search parameter.
 *
 * Each parameter is registered with addParam() under a command-line flag and a
 * single default value. parseArgs() replaces that default with the
 * comma-separated list given on the command line (e.g. "-lmbr-p 0.2,0.4,0.6")
 * and strips the flag and its value from argc/argv.
 */
class Grid
{
public:
  /**
   * Register a grid parameter.
   * @param key          identifier later passed to getGrid()
   * @param arg          command-line flag that sets this parameter (exact match)
   * @param defaultValue the single value used when the flag is not given
   * @throws std::logic_error if @p key has already been registered
   */
  void addParam(gridkey key, const std::string& arg, float defaultValue) {
    // Validate before touching either map so a rejected call changes nothing.
    if (m_grid.find(key) != m_grid.end()) {
      throw std::logic_error("Grid parameter registered twice");
    }
    m_args[arg] = key;
    m_grid[key].push_back(defaultValue);
  }

  /**
   * Consume the registered grid flags (and their values) from the command line.
   *
   * On success argc/argv are replaced by a newly allocated, NULL-terminated
   * copy of the arguments that were not grid flags, in their original order.
   * The new array and its strings are never freed by this class.
   *
   * @throws std::runtime_error if a flag has no value, is given more than once,
   *         or its value is not a comma-separated list of non-zero numbers.
   *         argc/argv are not modified when this happens.
   */
  void parseArgs(int& argc, char**& argv) {
    std::vector<const char*> remaining;
    for (int i = 0; i < argc; ++i) {
      const std::map<std::string, gridkey>::const_iterator argi = m_args.find(argv[i]);
      if (argi == m_args.end()) {
        remaining.push_back(argv[i]);
        continue;
      }
      ++i;  // the flag's value is the next argument
      if (i >= argc) {
        std::cerr << "Error: missing parameter for " << argi->first << std::endl;
        throw std::runtime_error("Missing parameter");
      }
      const gridkey key = argi->second;
      // Track explicitly given flags: checking the number of stored values
      // cannot detect a repeat when the first occurrence had a single value.
      if (m_given.count(key)) {
        throw std::runtime_error("Duplicate grid argument");
      }
      m_grid[key] = parseValues(argi->first, argv[i]);
      m_given.insert(key);
    }

    // Allocate only after all parsing succeeded, so a throw above leaks nothing.
    // One extra slot holds the terminating NULL that argv conventionally carries.
    char** newargv = new char*[remaining.size() + 1];
    for (std::size_t j = 0; j < remaining.size(); ++j) {
      newargv[j] = new char[std::strlen(remaining[j]) + 1];
      std::strcpy(newargv[j], remaining[j]);
    }
    newargv[remaining.size()] = NULL;
    argc = static_cast<int>(remaining.size());
    argv = newargv;
  }

  /**
   * @return the values to try for @p key: the registered default unless
   *         parseArgs() saw the corresponding flag.
   * @throws std::logic_error if @p key was never registered with addParam().
   */
  const std::vector<float>& getGrid(gridkey key) const {
    const std::map<gridkey, std::vector<float> >::const_iterator iter = m_grid.find(key);
    if (iter == m_grid.end()) {
      // An assert would vanish under NDEBUG and leave iter->second undefined.
      throw std::logic_error("Unknown grid parameter");
    }
    return iter->second;
  }

private:
  /// Split a comma-separated list into numbers; empty fields (",,") are skipped.
  /// Throws std::runtime_error when no value at all is present.
  static std::vector<float> parseValues(const std::string& flag, const std::string& value) {
    const char delim = ',';
    std::vector<float> values;
    std::string::size_type start = value.find_first_not_of(delim);
    while (start != std::string::npos) {
      const std::string::size_type end = value.find_first_of(delim, start);
      // When end is npos, substr clamps the length to the rest of the string.
      values.push_back(parseNumber(flag, value.substr(start, end - start)));
      start = value.find_first_not_of(delim, end);
    }
    if (values.empty()) {
      // An empty grid would make the search silently produce no output.
      std::cerr << "Error: Illegal grid parameter for " << flag << std::endl;
      throw std::runtime_error("Illegal grid parameter");
    }
    return values;
  }

  /// Parse a single field. Rejects text with trailing junk (e.g. "0.5x") and
  /// values that are zero as a float (zero is also what a failed parse yields).
  static float parseNumber(const std::string& flag, const std::string& field) {
    const char* text = field.c_str();
    char* stop = NULL;
    const double number = std::strtod(text, &stop);
    bool ok = (stop != text);
    while (std::isspace(static_cast<unsigned char>(*stop))) {
      ++stop;
    }
    const float param = static_cast<float>(number);
    ok = ok && *stop == '\0' && param != 0;
    if (!ok) {
      std::cerr << "Error: Illegal grid parameter for " << flag << std::endl;
      throw std::runtime_error("Illegal grid parameter");
    }
    return param;
  }

  std::map<gridkey, std::vector<float> > m_grid;  // values to try, per parameter
  std::map<std::string, gridkey> m_args;          // command-line flag -> parameter
  std::set<gridkey> m_given;                      // parameters already set by parseArgs
};
00130
00131 int main(int argc, char* argv[])
00132 {
00133 cerr << "Lattice MBR Grid search" << endl;
00134
00135 Grid grid;
00136 grid.addParam(lmbr_p, "-lmbr-p", 0.5);
00137 grid.addParam(lmbr_r, "-lmbr-r", 0.5);
00138 grid.addParam(lmbr_prune, "-lmbr-pruning-factor",30.0);
00139 grid.addParam(lmbr_scale, "-mbr-scale",1.0);
00140
00141 grid.parseArgs(argc,argv);
00142
00143 Parameter* params = new Parameter();
00144 if (!params->LoadParam(argc,argv)) {
00145 params->Explain();
00146 exit(1);
00147 }
00148 if (!StaticData::LoadDataStatic(params)) {
00149 exit(1);
00150 }
00151
00152 StaticData& staticData = const_cast<StaticData&>(StaticData::Instance());
00153 staticData.SetUseLatticeMBR(true);
00154 IOWrapper* ioWrapper = GetIODevice(staticData);
00155
00156 if (!ioWrapper) {
00157 throw runtime_error("Failed to initialise IOWrapper");
00158 }
00159 size_t nBestSize = staticData.GetMBRSize();
00160
00161 if (nBestSize <= 0) {
00162 throw new runtime_error("Non-positive size specified for n-best list");
00163 }
00164
00165 size_t lineCount = 0;
00166 InputType* source = NULL;
00167
00168 const vector<float>& pgrid = grid.getGrid(lmbr_p);
00169 const vector<float>& rgrid = grid.getGrid(lmbr_r);
00170 const vector<float>& prune_grid = grid.getGrid(lmbr_prune);
00171 const vector<float>& scale_grid = grid.getGrid(lmbr_scale);
00172
00173 while(ReadInput(*ioWrapper,staticData.GetInputType(),source)) {
00174 ++lineCount;
00175 Sentence sentence;
00176 const TranslationSystem& system = staticData.GetTranslationSystem(TranslationSystem::DEFAULT);
00177 Manager manager(*source,staticData.GetSearchAlgorithm(), &system);
00178 manager.ProcessSentence();
00179 TrellisPathList nBestList;
00180 manager.CalcNBest(nBestSize, nBestList,true);
00181
00182 for (vector<float>::const_iterator pi = pgrid.begin(); pi != pgrid.end(); ++pi) {
00183 float p = *pi;
00184 staticData.SetLatticeMBRPrecision(p);
00185 for (vector<float>::const_iterator ri = rgrid.begin(); ri != rgrid.end(); ++ri) {
00186 float r = *ri;
00187 staticData.SetLatticeMBRPRatio(r);
00188 for (vector<float>::const_iterator prune_i = prune_grid.begin(); prune_i != prune_grid.end(); ++prune_i) {
00189 size_t prune = (size_t)(*prune_i);
00190 staticData.SetLatticeMBRPruningFactor(prune);
00191 for (vector<float>::const_iterator scale_i = scale_grid.begin(); scale_i != scale_grid.end(); ++scale_i) {
00192 float scale = *scale_i;
00193 staticData.SetMBRScale(scale);
00194 cout << lineCount << " ||| " << p << " " << r << " " << prune << " " << scale << " ||| ";
00195 vector<Word> mbrBestHypo = doLatticeMBR(manager,nBestList);
00196 OutputBestHypo(mbrBestHypo, lineCount, staticData.GetReportSegmentation(),
00197 staticData.GetReportAllFactors(),cout);
00198 }
00199 }
00200
00201 }
00202 }
00203
00204
00205 }
00206
00207 }