00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022 #include "ChartParser.h"
00023 #include "ChartParserCallback.h"
00024 #include "ChartRuleLookupManager.h"
00025 #include "StaticData.h"
00026 #include "TreeInput.h"
00027 #include "Sentence.h"
00028 #include "DecodeGraph.h"
00029 #include "moses/FF/UnknownWordPenaltyProducer.h"
00030 #include "moses/TranslationModel/PhraseDictionary.h"
00031 #include "moses/TranslationTask.h"
00032
00033 using namespace std;
00034 using namespace Moses;
00035
00036 namespace Moses
00037 {
00038
00039 ChartParserUnknown
00040 ::ChartParserUnknown(ttasksptr const& ttask)
00041 : m_ttask(ttask)
00042 { }
00043
00044 ChartParserUnknown::~ChartParserUnknown()
00045 {
00046 RemoveAllInColl(m_unksrcs);
00047 }
00048
00049 AllOptions::ptr const&
00050 ChartParserUnknown::
00051 options() const
00052 {
00053 return m_ttask.lock()->options();
00054 }
00055
00056 void
00057 ChartParserUnknown::
00058 Process(const Word &sourceWord, const Range &range, ChartParserCallback &to)
00059 {
00060
00061 const StaticData &staticData = StaticData::Instance();
00062 const UnknownWordPenaltyProducer &unknownWordPenaltyProducer
00063 = UnknownWordPenaltyProducer::Instance();
00064
00065 size_t isDigit = 0;
00066 if (options()->unk.drop) {
00067 const Factor *f = sourceWord[0];
00068 const StringPiece s = f->GetString();
00069 isDigit = s.find_first_of("0123456789");
00070 if (isDigit == string::npos)
00071 isDigit = 0;
00072 else
00073 isDigit = 1;
00074
00075 }
00076
00077 Phrase* unksrc = new Phrase(1);
00078 unksrc->AddWord() = sourceWord;
00079 Word &newWord = unksrc->GetWord(0);
00080 newWord.SetIsOOV(true);
00081
00082 m_unksrcs.push_back(unksrc);
00083
00084
00085 PhraseDictionary *firstPt = NULL;
00086 if (PhraseDictionary::GetColl().size() == 0) {
00087 firstPt = PhraseDictionary::GetColl()[0];
00088 }
00089
00090
00091 if (! options()->unk.drop || isDigit) {
00092
00093 const UnknownLHSList &lhsList = options()->syntax.unknown_lhs;
00094 UnknownLHSList::const_iterator iterLHS;
00095 for (iterLHS = lhsList.begin(); iterLHS != lhsList.end(); ++iterLHS) {
00096 const string &targetLHSStr = iterLHS->first;
00097 float prob = iterLHS->second;
00098
00099
00100
00101 Word *targetLHS = new Word(true);
00102
00103 targetLHS->CreateFromString(Output, options()->output.factor_order,
00104 targetLHSStr, true);
00105 UTIL_THROW_IF2(targetLHS->GetFactor(0) == NULL, "Null factor for target LHS");
00106
00107
00108 TargetPhrase *targetPhrase = new TargetPhrase(firstPt);
00109 Word &targetWord = targetPhrase->AddWord();
00110 targetWord.CreateUnknownWord(sourceWord);
00111
00112
00113 float unknownScore = FloorScore(TransformScore(prob));
00114
00115 targetPhrase->GetScoreBreakdown().Assign(&unknownWordPenaltyProducer, unknownScore);
00116 targetPhrase->SetTargetLHS(targetLHS);
00117 targetPhrase->SetAlignmentInfo("0-0");
00118 targetPhrase->EvaluateInIsolation(*unksrc);
00119
00120 if (!options()->output.detailed_tree_transrep_filepath.empty() ||
00121 options()->nbest.print_trees || staticData.GetTreeStructure() != NULL) {
00122 std::string prop = "[ ";
00123 prop += (*targetLHS)[0]->GetString().as_string() + " ";
00124 prop += sourceWord[0]->GetString().as_string() + " ]";
00125 targetPhrase->SetProperty("Tree", prop);
00126 }
00127
00128
00129 to.AddPhraseOOV(*targetPhrase, m_cacheTargetPhraseCollection, range);
00130 }
00131 } else {
00132
00133 float unknownScore = FloorScore(-numeric_limits<float>::infinity());
00134
00135 TargetPhrase *targetPhrase = new TargetPhrase(firstPt);
00136
00137 const UnknownLHSList &lhsList = options()->syntax.unknown_lhs;
00138 UnknownLHSList::const_iterator iterLHS;
00139 for (iterLHS = lhsList.begin(); iterLHS != lhsList.end(); ++iterLHS) {
00140 const string &targetLHSStr = iterLHS->first;
00141
00142
00143 Word *targetLHS = new Word(true);
00144 targetLHS->CreateFromString(Output, staticData.options()->output.factor_order,
00145 targetLHSStr, true);
00146 UTIL_THROW_IF2(targetLHS->GetFactor(0) == NULL, "Null factor for target LHS");
00147
00148 targetPhrase->GetScoreBreakdown().Assign(&unknownWordPenaltyProducer, unknownScore);
00149 targetPhrase->EvaluateInIsolation(*unksrc);
00150
00151 targetPhrase->SetTargetLHS(targetLHS);
00152
00153
00154 to.AddPhraseOOV(*targetPhrase, m_cacheTargetPhraseCollection, range);
00155 }
00156 }
00157 }
00158
00159 ChartParser
00160 ::ChartParser(ttasksptr const& ttask, ChartCellCollectionBase &cells)
00161 : m_ttask(ttask)
00162 , m_unknown(ttask)
00163 , m_decodeGraphList(StaticData::Instance().GetDecodeGraphs())
00164 , m_source(*(ttask->GetSource().get()))
00165 {
00166 const StaticData &staticData = StaticData::Instance();
00167
00168 staticData.InitializeForInput(ttask);
00169 CreateInputPaths(m_source);
00170
00171 const std::vector<PhraseDictionary*> &dictionaries = PhraseDictionary::GetColl();
00172 assert(dictionaries.size() == m_decodeGraphList.size());
00173 m_ruleLookupManagers.reserve(dictionaries.size());
00174 for (std::size_t i = 0; i < dictionaries.size(); ++i) {
00175 const PhraseDictionary *dict = dictionaries[i];
00176 PhraseDictionary *nonConstDict = const_cast<PhraseDictionary*>(dict);
00177 std::size_t maxChartSpan = m_decodeGraphList[i]->GetMaxChartSpan();
00178 ChartRuleLookupManager *lookupMgr = nonConstDict->CreateRuleLookupManager(*this, cells, maxChartSpan);
00179 m_ruleLookupManagers.push_back(lookupMgr);
00180 }
00181
00182 }
00183
00184 ChartParser::~ChartParser()
00185 {
00186 RemoveAllInColl(m_ruleLookupManagers);
00187 StaticData::Instance().CleanUpAfterSentenceProcessing(m_ttask.lock());
00188
00189 InputPathMatrix::const_iterator iterOuter;
00190 for (iterOuter = m_inputPathMatrix.begin(); iterOuter != m_inputPathMatrix.end(); ++iterOuter) {
00191 const std::vector<InputPath*> &outer = *iterOuter;
00192
00193 std::vector<InputPath*>::const_iterator iterInner;
00194 for (iterInner = outer.begin(); iterInner != outer.end(); ++iterInner) {
00195 InputPath *path = *iterInner;
00196 delete path;
00197 }
00198 }
00199 }
00200
00201 void ChartParser::Create(const Range &range, ChartParserCallback &to)
00202 {
00203 assert(m_decodeGraphList.size() == m_ruleLookupManagers.size());
00204
00205 std::vector <DecodeGraph*>::const_iterator iterDecodeGraph;
00206 std::vector <ChartRuleLookupManager*>::const_iterator iterRuleLookupManagers = m_ruleLookupManagers.begin();
00207 for (iterDecodeGraph = m_decodeGraphList.begin(); iterDecodeGraph != m_decodeGraphList.end(); ++iterDecodeGraph, ++iterRuleLookupManagers) {
00208 const DecodeGraph &decodeGraph = **iterDecodeGraph;
00209 assert(decodeGraph.GetSize() == 1);
00210 ChartRuleLookupManager &ruleLookupManager = **iterRuleLookupManagers;
00211 size_t maxSpan = decodeGraph.GetMaxChartSpan();
00212 size_t last = m_source.GetSize()-1;
00213 if (maxSpan != 0) {
00214 last = min(last, range.GetStartPos()+maxSpan);
00215 }
00216 if (maxSpan == 0 || range.GetNumWordsCovered() <= maxSpan) {
00217 const InputPath &inputPath = GetInputPath(range);
00218 ruleLookupManager.GetChartRuleCollection(inputPath, last, to);
00219 }
00220 }
00221
00222 if (range.GetNumWordsCovered() == 1
00223 && range.GetStartPos() != 0
00224 && range.GetStartPos() != m_source.GetSize()-1) {
00225 bool always = options()->unk.always_create_direct_transopt;
00226 if (to.Empty() || always) {
00227
00228 const Word &sourceWord = m_source.GetWord(range.GetStartPos());
00229 m_unknown.Process(sourceWord, range, to);
00230 }
00231 }
00232 }
00233
00234 void ChartParser::CreateInputPaths(const InputType &input)
00235 {
00236 size_t size = input.GetSize();
00237 m_inputPathMatrix.resize(size);
00238
00239 UTIL_THROW_IF2(input.GetType() != SentenceInput && input.GetType() != TreeInputType,
00240 "Input must be a sentence or a tree, " <<
00241 "not lattice or confusion networks");
00242
00243 TranslationTask const* ttask = m_ttask.lock().get();
00244 for (size_t phaseSize = 1; phaseSize <= size; ++phaseSize) {
00245 for (size_t startPos = 0; startPos < size - phaseSize + 1; ++startPos) {
00246 size_t endPos = startPos + phaseSize -1;
00247 vector<InputPath*> &vec = m_inputPathMatrix[startPos];
00248
00249 Range range(startPos, endPos);
00250 Phrase subphrase(input.GetSubString(Range(startPos, endPos)));
00251 const NonTerminalSet &labels = input.GetLabelSet(startPos, endPos);
00252
00253 InputPath *node;
00254 if (range.GetNumWordsCovered() == 1) {
00255 node = new InputPath(ttask, subphrase, labels, range, NULL, NULL);
00256 vec.push_back(node);
00257 } else {
00258 const InputPath &prevNode = GetInputPath(startPos, endPos - 1);
00259 node = new InputPath(ttask, subphrase, labels, range, &prevNode, NULL);
00260 vec.push_back(node);
00261 }
00262
00263
00264 }
00265 }
00266 }
00267
00268 const InputPath &ChartParser::GetInputPath(const Range &range) const
00269 {
00270 return GetInputPath(range.GetStartPos(), range.GetEndPos());
00271 }
00272
00273 const InputPath &ChartParser::GetInputPath(size_t startPos, size_t endPos) const
00274 {
00275 size_t offset = endPos - startPos;
00276 UTIL_THROW_IF2(offset >= m_inputPathMatrix[startPos].size(),
00277 "Out of bound: " << offset);
00278 return *m_inputPathMatrix[startPos][offset];
00279 }
00280
00281 InputPath &ChartParser::GetInputPath(size_t startPos, size_t endPos)
00282 {
00283 size_t offset = endPos - startPos;
00284 UTIL_THROW_IF2(offset >= m_inputPathMatrix[startPos].size(),
00285 "Out of bound: " << offset);
00286 return *m_inputPathMatrix[startPos][offset];
00287 }
00288
00289
00290
00291
00292
00293
00294 size_t ChartParser::GetSize() const
00295 {
00296 return m_source.GetSize();
00297 }
00298
00299 long ChartParser::GetTranslationId() const
00300 {
00301 return m_source.GetTranslationId();
00302 }
00303
00304
00305 AllOptions::ptr const&
00306 ChartParser::
00307 options() const
00308 {
00309 return m_ttask.lock()->options();
00310 }
00311
00312
00313 }