00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035 #ifdef WIN32
00036
00037
00038 #endif
00039
00040 #include <exception>
00041 #include <fstream>
00042 #include "Main.h"
00043 #include "TranslationAnalysis.h"
00044 #include "mbr.h"
00045 #include "IOWrapper.h"
00046
00047 #include "moses/DummyScoreProducers.h"
00048 #include "moses/FactorCollection.h"
00049 #include "moses/Manager.h"
00050 #include "moses/Phrase.h"
00051 #include "moses/Util.h"
00052 #include "moses/Timer.h"
00053 #include "moses/Sentence.h"
00054 #include "moses/ConfusionNet.h"
00055 #include "moses/WordLattice.h"
00056 #include "moses/TreeInput.h"
00057 #include "moses/ThreadPool.h"
00058 #include "moses/ChartManager.h"
00059 #include "moses/ChartHypothesis.h"
00060 #include "moses/ChartTrellisPath.h"
00061 #include "moses/ChartTrellisPathList.h"
00062 #include "moses/Incremental.h"
00063
00064 #include "util/usage.hh"
00065
00066 using namespace std;
00067 using namespace Moses;
00068 using namespace MosesChartCmd;
00069
00073 class TranslationTask : public Task
00074 {
00075 public:
00076 TranslationTask(InputType *source, IOWrapper &ioWrapper)
00077 : m_source(source)
00078 , m_ioWrapper(ioWrapper)
00079 {}
00080
00081 ~TranslationTask() {
00082 delete m_source;
00083 }
00084
00085 void Run() {
00086 const StaticData &staticData = StaticData::Instance();
00087 const TranslationSystem &system = staticData.GetTranslationSystem(TranslationSystem::DEFAULT);
00088 const size_t translationId = m_source->GetTranslationId();
00089
00090 VERBOSE(2,"\nTRANSLATING(" << translationId << "): " << *m_source);
00091
00092 if (staticData.GetSearchAlgorithm() == ChartIncremental) {
00093 Incremental::Manager manager(*m_source, system);
00094 const std::vector<search::Applied> &nbest = manager.ProcessSentence();
00095 if (!nbest.empty()) {
00096 m_ioWrapper.OutputBestHypo(nbest[0], translationId);
00097 } else {
00098 m_ioWrapper.OutputBestNone(translationId);
00099 }
00100 if (staticData.GetNBestSize() > 0)
00101 m_ioWrapper.OutputNBestList(nbest, system, translationId);
00102 return;
00103 }
00104
00105 ChartManager manager(*m_source, &system);
00106 manager.ProcessSentence();
00107
00108 CHECK(!staticData.UseMBR());
00109
00110
00111 const ChartHypothesis *bestHypo = manager.GetBestHypothesis();
00112 m_ioWrapper.OutputBestHypo(bestHypo, translationId);
00113 IFVERBOSE(2) {
00114 PrintUserTime("Best Hypothesis Generation Time:");
00115 }
00116
00117 if (!staticData.GetAlignmentOutputFile().empty()) {
00118 m_ioWrapper.OutputAlignment(translationId, bestHypo);
00119 }
00120
00121 if (staticData.IsDetailedTranslationReportingEnabled()) {
00122 const Sentence &sentence = dynamic_cast<const Sentence &>(*m_source);
00123 m_ioWrapper.OutputDetailedTranslationReport(bestHypo, sentence, translationId);
00124 }
00125
00126
00127 size_t nBestSize = staticData.GetNBestSize();
00128 if (nBestSize > 0) {
00129 VERBOSE(2,"WRITING " << nBestSize << " TRANSLATION ALTERNATIVES TO " << staticData.GetNBestFilePath() << endl);
00130 ChartTrellisPathList nBestList;
00131 manager.CalcNBest(nBestSize, nBestList,staticData.GetDistinctNBest());
00132 m_ioWrapper.OutputNBestList(nBestList, &system, translationId);
00133 IFVERBOSE(2) {
00134 PrintUserTime("N-Best Hypotheses Generation Time:");
00135 }
00136 }
00137
00138 if (staticData.GetOutputSearchGraph()) {
00139 std::ostringstream out;
00140 manager.GetSearchGraph(translationId, out);
00141 OutputCollector *oc = m_ioWrapper.GetSearchGraphOutputCollector();
00142 CHECK(oc);
00143 oc->Write(translationId, out.str());
00144 }
00145
00146 IFVERBOSE(2) {
00147 PrintUserTime("Sentence Decoding Time:");
00148 }
00149 manager.CalcDecoderStatistics();
00150 }
00151
00152 private:
00153
00154 TranslationTask(const TranslationTask &);
00155 TranslationTask &operator=(const TranslationTask &);
00156
00157 InputType *m_source;
00158 IOWrapper &m_ioWrapper;
00159 };
00160
00161 bool ReadInput(IOWrapper &ioWrapper, InputTypeEnum inputType, InputType*& source)
00162 {
00163 delete source;
00164 switch(inputType) {
00165 case SentenceInput:
00166 source = ioWrapper.GetInput(new Sentence);
00167 break;
00168 case ConfusionNetworkInput:
00169 source = ioWrapper.GetInput(new ConfusionNet);
00170 break;
00171 case WordLatticeInput:
00172 source = ioWrapper.GetInput(new WordLattice);
00173 break;
00174 case TreeInputType:
00175 source = ioWrapper.GetInput(new TreeInput);
00176 break;
00177 default:
00178 TRACE_ERR("Unknown input type: " << inputType << "\n");
00179 }
00180 return (source ? true : false);
00181 }
00182 static void PrintFeatureWeight(const FeatureFunction* ff)
00183 {
00184 size_t numScoreComps = ff->GetNumScoreComponents();
00185 if (numScoreComps != ScoreProducer::unlimited) {
00186 vector<float> values = StaticData::Instance().GetAllWeights().GetScoresForProducer(ff);
00187 for (size_t i = 0; i < numScoreComps; ++i)
00188 cout << ff->GetScoreProducerDescription() << " "
00189 << ff->GetScoreProducerWeightShortName() << " "
00190 << values[i] << endl;
00191 }
00192 else {
00193 if (ff->GetSparseProducerWeight() == 1)
00194 cout << ff->GetScoreProducerDescription() << " " <<
00195 ff->GetScoreProducerWeightShortName() << " sparse" << endl;
00196 else
00197 cout << ff->GetScoreProducerDescription() << " " <<
00198 ff->GetScoreProducerWeightShortName() << " " << ff->GetSparseProducerWeight() << endl;
00199 }
00200 }
00201
00202 static void ShowWeights()
00203 {
00204 cout.precision(6);
00205 const StaticData& staticData = StaticData::Instance();
00206 const TranslationSystem& system = staticData.GetTranslationSystem(TranslationSystem::DEFAULT);
00207
00208
00209
00210 const LMList& lml = system.GetLanguageModels();
00211 LMList::const_iterator lmi = lml.begin();
00212 for (; lmi != lml.end(); ++lmi) {
00213 PrintFeatureWeight(*lmi);
00214 }
00215
00216
00217 const vector<const StatefulFeatureFunction*>& sff = system.GetStatefulFeatureFunctions();
00218 for( size_t i=0; i<sff.size(); i++ ) {
00219 if (sff[i]->GetNumScoreComponents() == ScoreProducer::unlimited) {
00220 PrintFeatureWeight(sff[i]);
00221 }
00222 }
00223
00224
00225 const vector<PhraseDictionaryFeature*>& pds = system.GetPhraseDictionaries();
00226 for( size_t i=0; i<pds.size(); i++ ) {
00227 PrintFeatureWeight(pds[i]);
00228 }
00229
00230
00231 PrintFeatureWeight(system.GetWordPenaltyProducer());
00232
00233
00234 const vector<GenerationDictionary*>& gds = system.GetGenerationDictionaries();
00235 for( size_t i=0; i<gds.size(); i++ ) {
00236 PrintFeatureWeight(gds[i]);
00237 }
00238
00239
00240 const vector<const StatelessFeatureFunction*>& slf = system.GetStatelessFeatureFunctions();
00241 for( size_t i=0; i<slf.size(); i++ ) {
00242 if (slf[i]->GetNumScoreComponents() == ScoreProducer::unlimited) {
00243 PrintFeatureWeight(slf[i]);
00244 }
00245 }
00246
00247
00248 }
00249
00250
00251 int main(int argc, char* argv[])
00252 {
00253 try {
00254 IFVERBOSE(1) {
00255 TRACE_ERR("command: ");
00256 for(int i=0; i<argc; ++i) TRACE_ERR(argv[i]<<" ");
00257 TRACE_ERR(endl);
00258 }
00259
00260 IOWrapper::FixPrecision(cout);
00261 IOWrapper::FixPrecision(cerr);
00262
00263
00264 Parameter parameter;
00265 if (!parameter.LoadParam(argc, argv)) {
00266 return EXIT_FAILURE;
00267 }
00268
00269 const StaticData &staticData = StaticData::Instance();
00270 if (!StaticData::LoadDataStatic(¶meter, argv[0]))
00271 return EXIT_FAILURE;
00272
00273 if (parameter.isParamSpecified("show-weights")) {
00274 ShowWeights();
00275 exit(0);
00276 }
00277
00278 CHECK(staticData.IsChart());
00279
00280
00281 IOWrapper *ioWrapper = GetIOWrapper(staticData);
00282
00283
00284 const ScoreComponentCollection& weights = staticData.GetAllWeights();
00285 IFVERBOSE(2) {
00286 TRACE_ERR("The global weight vector looks like this: ");
00287 TRACE_ERR(weights);
00288 TRACE_ERR("\n");
00289 }
00290
00291 if (ioWrapper == NULL)
00292 return EXIT_FAILURE;
00293
00294 #ifdef WITH_THREADS
00295 ThreadPool pool(staticData.ThreadCount());
00296 #endif
00297
00298
00299 InputType *source=0;
00300 while(ReadInput(*ioWrapper,staticData.GetInputType(),source)) {
00301 IFVERBOSE(1)
00302 ResetUserTime();
00303 TranslationTask *task = new TranslationTask(source, *ioWrapper);
00304 source = NULL;
00305 #ifdef WITH_THREADS
00306 pool.Submit(task);
00307 #else
00308 task->Run();
00309 delete task;
00310 #endif
00311 }
00312
00313 #ifdef WITH_THREADS
00314 pool.Stop(true);
00315 #endif
00316
00317 delete ioWrapper;
00318
00319 IFVERBOSE(1)
00320 PrintUserTime("End.");
00321
00322 } catch (const std::exception &e) {
00323 std::cerr << "Exception: " << e.what() << std::endl;
00324 return EXIT_FAILURE;
00325 }
00326
00327 IFVERBOSE(1) util::PrintUsage(std::cerr);
00328
00329 #ifndef EXIT_RETURN
00330
00331 exit(EXIT_SUCCESS);
00332 #else
00333 return EXIT_SUCCESS;
00334 #endif
00335 }
00336
00337 IOWrapper *GetIOWrapper(const StaticData &staticData)
00338 {
00339 IOWrapper *ioWrapper;
00340 const std::vector<FactorType> &inputFactorOrder = staticData.GetInputFactorOrder()
00341 ,&outputFactorOrder = staticData.GetOutputFactorOrder();
00342 FactorMask inputFactorUsed(inputFactorOrder);
00343
00344
00345 if (staticData.GetParam("input-file").size() == 1) {
00346 VERBOSE(2,"IO from File" << endl);
00347 string filePath = staticData.GetParam("input-file")[0];
00348
00349 ioWrapper = new IOWrapper(inputFactorOrder, outputFactorOrder, inputFactorUsed
00350 , staticData.GetNBestSize()
00351 , staticData.GetNBestFilePath()
00352 , filePath);
00353 } else {
00354 VERBOSE(1,"IO from STDOUT/STDIN" << endl);
00355 ioWrapper = new IOWrapper(inputFactorOrder, outputFactorOrder, inputFactorUsed
00356 , staticData.GetNBestSize()
00357 , staticData.GetNBestFilePath());
00358 }
00359 ioWrapper->ResetTranslationId();
00360
00361 IFVERBOSE(1)
00362 PrintUserTime("Created input-output object");
00363
00364 return ioWrapper;
00365 }