00001
00002
00003 #include <list>
00004 #include <vector>
00005 #include "TranslationOptionCollectionConfusionNet.h"
00006 #include "ConfusionNet.h"
00007 #include "DecodeGraph.h"
00008 #include "DecodeStepTranslation.h"
00009 #include "DecodeStepGeneration.h"
00010 #include "FactorCollection.h"
00011 #include "FF/InputFeature.h"
00012 #include "TranslationModel/PhraseDictionaryTreeAdaptor.h"
00013 #include "util/exception.hh"
00014 #include <boost/foreach.hpp>
00015 #include "TranslationTask.h"
00016 using namespace std;
00017
00018 namespace Moses
00019 {
00020
00022 TranslationOptionCollectionConfusionNet::
00023 TranslationOptionCollectionConfusionNet(ttasksptr const& ttask,
00024 const ConfusionNet &input)
00025
00026 : TranslationOptionCollection(ttask,input)
00027
00028 {
00029 size_t maxNoTransOptPerCoverage = ttask->options()->search.max_trans_opt_per_cov;
00030 float translationOptionThreshold = ttask->options()->search.trans_opt_threshold;
00031
00032
00033
00034
00035 vector<PhraseDictionary*> prefixCheckers;
00036 BOOST_FOREACH(PhraseDictionary* pd, PhraseDictionary::GetColl())
00037 if (pd->ProvidesPrefixCheck()) prefixCheckers.push_back(pd);
00038
00039 const InputFeature *inputFeature = InputFeature::InstancePtr();
00040 UTIL_THROW_IF2(inputFeature == NULL, "Input feature must be specified");
00041
00042 size_t inputSize = input.GetSize();
00043 m_inputPathMatrix.resize(inputSize);
00044
00045 size_t maxSizePhrase = ttask->options()->search.max_phrase_length;
00046 maxSizePhrase = std::min(inputSize, maxSizePhrase);
00047
00048
00049 for (size_t startPos = 0; startPos < inputSize; ++startPos) {
00050 vector<InputPathList> &vec = m_inputPathMatrix[startPos];
00051 vec.push_back(InputPathList());
00052 InputPathList &list = vec.back();
00053
00054 Range range(startPos, startPos);
00055 const NonTerminalSet &labels = input.GetLabelSet(startPos, startPos);
00056
00057 const ConfusionNet::Column &col = input.GetColumn(startPos);
00058 for (size_t i = 0; i < col.size(); ++i) {
00059 const Word &word = col[i].first;
00060 Phrase subphrase;
00061 subphrase.AddWord(word);
00062
00063 const ScorePair &scores = col[i].second;
00064 ScorePair *inputScore = new ScorePair(scores);
00065
00066 InputPath* path = new InputPath(ttask.get(), subphrase, labels,
00067 range, NULL, inputScore);
00068 list.push_back(path);
00069
00070 m_inputPathQueue.push_back(path);
00071 }
00072 }
00073
00074
00075 for (size_t phraseSize = 2; phraseSize <= maxSizePhrase; ++phraseSize) {
00076 for (size_t startPos = 0; startPos < inputSize - phraseSize + 1; ++startPos) {
00077 size_t endPos = startPos + phraseSize -1;
00078
00079 Range range(startPos, endPos);
00080 const NonTerminalSet &labels = input.GetLabelSet(startPos, endPos);
00081
00082 vector<InputPathList> &vec = m_inputPathMatrix[startPos];
00083 vec.push_back(InputPathList());
00084 InputPathList &list = vec.back();
00085
00086
00087 const InputPathList &prevPaths = GetInputPathList(startPos, endPos - 1);
00088
00089 int prevNodesInd = 0;
00090 InputPathList::const_iterator iterPath;
00091 for (iterPath = prevPaths.begin(); iterPath != prevPaths.end(); ++iterPath) {
00092
00093 const InputPath &prevPath = **iterPath;
00094
00095
00096 const Phrase &prevPhrase = prevPath.GetPhrase();
00097 const ScorePair *prevInputScore = prevPath.GetInputScore();
00098 UTIL_THROW_IF2(prevInputScore == NULL,
00099 "No input score for path: " << prevPath);
00100
00101
00102 const ConfusionNet::Column &col = input.GetColumn(endPos);
00103
00104 for (size_t i = 0; i < col.size(); ++i) {
00105 const Word &word = col[i].first;
00106 Phrase subphrase(prevPhrase);
00107 subphrase.AddWord(word);
00108
00109 bool OK = prefixCheckers.size() == 0;
00110 for (size_t k = 0; !OK && k < prefixCheckers.size(); ++k)
00111 OK = prefixCheckers[k]->PrefixExists(m_ttask.lock(), subphrase);
00112 if (!OK) continue;
00113
00114 const ScorePair &scores = col[i].second;
00115 ScorePair *inputScore = new ScorePair(*prevInputScore);
00116 inputScore->PlusEquals(scores);
00117
00118 InputPath *path = new InputPath(ttask.get(), subphrase, labels, range,
00119 &prevPath, inputScore);
00120 list.push_back(path);
00121
00122 m_inputPathQueue.push_back(path);
00123 }
00124
00125 ++prevNodesInd;
00126 }
00127 }
00128 }
00129
00130
00131
00132 }
00133
00134 InputPathList &TranslationOptionCollectionConfusionNet::GetInputPathList(size_t startPos, size_t endPos)
00135 {
00136 size_t offset = endPos - startPos;
00137 UTIL_THROW_IF2(offset >= m_inputPathMatrix[startPos].size(),
00138 "Out of bound access: " << offset);
00139
00140 return m_inputPathMatrix[startPos][offset];
00141 }
00142
00143
00144
00145
00146
00147 void TranslationOptionCollectionConfusionNet::ProcessUnknownWord(size_t sourcePos)
00148 {
00149 ConfusionNet const& source=static_cast<ConfusionNet const&>(m_source);
00150
00151 ConfusionNet::Column const& coll=source.GetColumn(sourcePos);
00152 const InputPathList &inputPathList = GetInputPathList(sourcePos, sourcePos);
00153
00154 ConfusionNet::Column::const_iterator iterCol;
00155 InputPathList::const_iterator iterInputPath;
00156 size_t j=0;
00157 for(iterCol = coll.begin(), iterInputPath = inputPathList.begin();
00158 iterCol != coll.end();
00159 ++iterCol , ++iterInputPath) {
00160 const InputPath &inputPath = **iterInputPath;
00161 size_t length = source.GetColumnIncrement(sourcePos, j++);
00162 const ScorePair &inputScores = iterCol->second;
00163 ProcessOneUnknownWord(inputPath ,sourcePos, length, &inputScores);
00164 }
00165
00166 }
00167
00168 void
00169 TranslationOptionCollectionConfusionNet
00170 ::CreateTranslationOptions()
00171 {
00172 if (!StaticData::Instance().GetUseLegacyPT()) {
00173 GetTargetPhraseCollectionBatch();
00174 }
00175 TranslationOptionCollection::CreateTranslationOptions();
00176 }
00177
00178
00190 bool
00191 TranslationOptionCollectionConfusionNet::
00192 CreateTranslationOptionsForRange(const DecodeGraph &decodeGraph,
00193 size_t startPos, size_t endPos,
00194 bool adhereTableLimit, size_t graphInd)
00195 {
00196 if (StaticData::Instance().GetUseLegacyPT()) {
00197 return CreateTranslationOptionsForRangeLEGACY(decodeGraph, startPos, endPos,
00198 adhereTableLimit, graphInd);
00199 } else {
00200 return CreateTranslationOptionsForRangeNew(decodeGraph, startPos, endPos,
00201 adhereTableLimit, graphInd);
00202 }
00203 }
00204
00205 bool
00206 TranslationOptionCollectionConfusionNet::
00207 CreateTranslationOptionsForRangeNew
00208 ( const DecodeGraph &decodeGraph, size_t startPos, size_t endPos,
00209 bool adhereTableLimit, size_t graphInd)
00210 {
00211 InputPathList &inputPathList = GetInputPathList(startPos, endPos);
00212 if (inputPathList.size() == 0) return false;
00213 InputPathList::iterator iter;
00214 for (iter = inputPathList.begin(); iter != inputPathList.end(); ++iter) {
00215 InputPath &inputPath = **iter;
00216 TranslationOptionCollection::CreateTranslationOptionsForRange
00217 (decodeGraph, startPos, endPos, adhereTableLimit, graphInd, inputPath);
00218 }
00219 return true;
00220 }
00221
00222 bool
00223 TranslationOptionCollectionConfusionNet::
00224 CreateTranslationOptionsForRangeLEGACY(const DecodeGraph &decodeGraph,
00225 size_t startPos, size_t endPos,
00226 bool adhereTableLimit, size_t graphInd)
00227 {
00228 bool retval = true;
00229 size_t const max_phrase_length
00230 = StaticData::Instance().options()->search.max_phrase_length;
00231 XmlInputType intype = m_ttask.lock()->options()->input.xml_policy;
00232 if ((intype != XmlExclusive) || !HasXmlOptionsOverlappingRange(startPos,endPos)) {
00233 InputPathList &inputPathList = GetInputPathList(startPos, endPos);
00234
00235
00236 PartialTranslOptColl* oldPtoc = new PartialTranslOptColl(max_phrase_length);
00237 size_t totalEarlyPruned = 0;
00238
00239
00240 list <const DecodeStep* >::const_iterator iterStep = decodeGraph.begin();
00241 const DecodeStep &decodeStep = **iterStep;
00242
00243 DecodeStepTranslation const& dstep
00244 = static_cast<const DecodeStepTranslation&>(decodeStep);
00245 dstep.ProcessInitialTransLEGACY(m_source, *oldPtoc, startPos, endPos,
00246 adhereTableLimit, inputPathList);
00247
00248
00249 int indexStep = 0;
00250
00251 for (++iterStep ; iterStep != decodeGraph.end() ; ++iterStep) {
00252
00253 const DecodeStep *decodeStep = *iterStep;
00254 const DecodeStepTranslation *transStep =dynamic_cast<const DecodeStepTranslation*>(decodeStep);
00255 const DecodeStepGeneration *genStep =dynamic_cast<const DecodeStepGeneration*>(decodeStep);
00256
00257 PartialTranslOptColl* newPtoc = new PartialTranslOptColl(max_phrase_length);
00258
00259
00260 const vector<TranslationOption*>& partTransOptList = oldPtoc->GetList();
00261 vector<TranslationOption*>::const_iterator iterPartialTranslOpt;
00262 for (iterPartialTranslOpt = partTransOptList.begin();
00263 iterPartialTranslOpt != partTransOptList.end();
00264 ++iterPartialTranslOpt) {
00265 TranslationOption &inputPartialTranslOpt = **iterPartialTranslOpt;
00266
00267 if (transStep) {
00268 transStep->ProcessLEGACY(inputPartialTranslOpt
00269 , *decodeStep
00270 , *newPtoc
00271 , this
00272 , adhereTableLimit);
00273 } else {
00274 assert(genStep);
00275 genStep->Process(inputPartialTranslOpt
00276 , *decodeStep
00277 , *newPtoc
00278 , this
00279 , adhereTableLimit);
00280 }
00281 }
00282
00283
00284 totalEarlyPruned += newPtoc->GetPrunedCount();
00285 delete oldPtoc;
00286 oldPtoc = newPtoc;
00287
00288 indexStep++;
00289 }
00290
00291
00292 PartialTranslOptColl &lastPartialTranslOptColl = *oldPtoc;
00293 const vector<TranslationOption*>& partTransOptList = lastPartialTranslOptColl.GetList();
00294 vector<TranslationOption*>::const_iterator iterColl;
00295 for (iterColl = partTransOptList.begin() ; iterColl != partTransOptList.end() ; ++iterColl) {
00296 TranslationOption *transOpt = *iterColl;
00297 Add(transOpt);
00298 }
00299
00300 lastPartialTranslOptColl.DetachAll();
00301 totalEarlyPruned += oldPtoc->GetPrunedCount();
00302 delete oldPtoc;
00303
00304
00305 }
00306
00307
00308 if (graphInd == 0 && intype != XmlPassThrough &&
00309 HasXmlOptionsOverlappingRange(startPos,endPos)) {
00310 CreateXmlOptionsForRange(startPos, endPos);
00311 }
00312 return retval;
00313 }
00314
00315
00316 }
00317
00318