00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
#include <cctype>
#include <cstdlib>
#include <iostream>
#include <map>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>

#include <boost/foreach.hpp>

#include "moses/IOWrapper.h"
#include "moses/LatticeMBR.h"
#include "moses/Manager.h"
#include "moses/Timer.h"
#include "moses/StaticData.h"
#include "util/exception.hh"
#include "moses/TranslationTask.h"
00058
00059 using namespace std;
00060 using namespace Moses;
00061
00062
// Identifies each hyper-parameter that the grid search can sweep over.
// The values are used as keys into Grid's per-parameter value lists.
enum gridkey {
  lmbr_p = 0,      // registered in main() under "-lmbr-p"
  lmbr_r = 1,      // registered in main() under "-lmbr-r"
  lmbr_prune = 2,  // registered in main() under "-lmbr-pruning-factor"
  lmbr_scale = 3   // registered in main() under "-mbr-scale"
};
00064
00065 namespace Moses
00066 {
00067
00068 class Grid
00069 {
00070 public:
00072 void addParam(gridkey key, const string& arg, float defaultValue) {
00073 m_args[arg] = key;
00074 UTIL_THROW_IF2(m_grid.find(key) != m_grid.end(),
00075 "Couldn't find value for key " << (int) key);
00076 m_grid[key].push_back(defaultValue);
00077 }
00078
00080 void parseArgs(int& argc, char const**& argv) {
00081 char const** newargv = new char const*[argc+1];
00082 int newargc = 0;
00083 for (int i = 0; i < argc; ++i) {
00084 bool consumed = false;
00085 for (map<string,gridkey>::const_iterator argi = m_args.begin(); argi != m_args.end(); ++argi) {
00086 if (!strcmp(argv[i], argi->first.c_str())) {
00087 ++i;
00088 if (i >= argc) {
00089 cerr << "Error: missing parameter for " << argi->first << endl;
00090 throw runtime_error("Missing parameter");
00091 } else {
00092 string value = argv[i];
00093 gridkey key = argi->second;
00094 if (m_grid[key].size() != 1) {
00095 throw runtime_error("Duplicate grid argument");
00096 }
00097 m_grid[key].clear();
00098 char delim = ',';
00099 string::size_type lastpos = value.find_first_not_of(delim);
00100 string::size_type pos = value.find_first_of(delim,lastpos);
00101 while (string::npos != pos || string::npos != lastpos) {
00102 float param = atof(value.substr(lastpos, pos-lastpos).c_str());
00103 if (!param) {
00104 cerr << "Error: Illegal grid parameter for " << argi->first << endl;
00105 throw runtime_error("Illegal grid parameter");
00106 }
00107 m_grid[key].push_back(param);
00108 lastpos = value.find_first_not_of(delim,pos);
00109 pos = value.find_first_of(delim,lastpos);
00110 }
00111 consumed = true;
00112 }
00113 if (consumed) break;
00114 }
00115 }
00116 if (!consumed) {
00117
00118
00119 newargv[newargc] = argv[i];
00120 ++newargc;
00121 }
00122 }
00123 argc = newargc;
00124 argv = newargv;
00125 }
00126
00128 const vector<float>& getGrid(gridkey key) const {
00129 map<gridkey,vector<float> >::const_iterator iter = m_grid.find(key);
00130 assert (iter != m_grid.end());
00131 return iter->second;
00132
00133 }
00134
00135 private:
00136 map<gridkey,vector<float> > m_grid;
00137 map<string,gridkey> m_args;
00138 };
00139
00140 }
00141
00142 int main(int argc, char const* argv[])
00143 {
00144 cerr << "Lattice MBR Grid search" << endl;
00145
00146 Grid grid;
00147 grid.addParam(lmbr_p, "-lmbr-p", 0.5);
00148 grid.addParam(lmbr_r, "-lmbr-r", 0.5);
00149 grid.addParam(lmbr_prune, "-lmbr-pruning-factor",30.0);
00150 grid.addParam(lmbr_scale, "-mbr-scale",1.0);
00151
00152 grid.parseArgs(argc,argv);
00153
00154 Parameter* params = new Parameter();
00155 if (!params->LoadParam(argc,argv)) {
00156 params->Explain();
00157 exit(1);
00158 }
00159
00160 ResetUserTime();
00161 if (!StaticData::LoadDataStatic(params, argv[0])) {
00162 exit(1);
00163 }
00164
00165 StaticData& SD = const_cast<StaticData&>(StaticData::Instance());
00166 boost::shared_ptr<AllOptions> opts(new AllOptions(*SD.options()));
00167 LMBR_Options& lmbr = opts->lmbr;
00168 MBR_Options& mbr = opts->mbr;
00169 lmbr.enabled = true;
00170
00171 boost::shared_ptr<IOWrapper> ioWrapper(new IOWrapper(*opts));
00172 if (!ioWrapper) {
00173 throw runtime_error("Failed to initialise IOWrapper");
00174 }
00175 size_t nBestSize = mbr.size;
00176
00177 if (nBestSize <= 0) {
00178 throw new runtime_error("Non-positive size specified for n-best list");
00179 }
00180
00181 const vector<float>& pgrid = grid.getGrid(lmbr_p);
00182 const vector<float>& rgrid = grid.getGrid(lmbr_r);
00183 const vector<float>& prune_grid = grid.getGrid(lmbr_prune);
00184 const vector<float>& scale_grid = grid.getGrid(lmbr_scale);
00185
00186 boost::shared_ptr<InputType> source;
00187 while((source = ioWrapper->ReadInput()) != NULL) {
00188
00189 boost::shared_ptr<TranslationTask> ttask;
00190 ttask = TranslationTask::create(source, ioWrapper);
00191 Manager manager(ttask);
00192 manager.Decode();
00193 TrellisPathList nBestList;
00194 manager.CalcNBest(nBestSize, nBestList,true);
00195
00196 BOOST_FOREACH(float const& p, pgrid) {
00197 lmbr.precision = p;
00198 BOOST_FOREACH(float const& r, rgrid) {
00199 lmbr.ratio = r;
00200 BOOST_FOREACH(size_t const prune_i, prune_grid) {
00201 lmbr.pruning_factor = prune_i;
00202 BOOST_FOREACH(float const& scale_i, scale_grid) {
00203 mbr.scale = scale_i;
00204 size_t lineCount = source->GetTranslationId();
00205 cout << lineCount << " ||| " << p << " "
00206 << r << " " << size_t(prune_i) << " " << scale_i
00207 << " ||| ";
00208 vector<Word> mbrBestHypo = doLatticeMBR(manager,nBestList);
00209 manager.OutputBestHypo(mbrBestHypo, cout);
00210 }
00211 }
00212 }
00213 }
00214 }
00215 }