00001
00002 #pragma once
00003
00004 #include <iostream>
00005 #include <sstream>
00006
00007 #include "moses/DecodeGraph.h"
00008 #include "moses/StaticData.h"
00009 #include "moses/Syntax/BoundedPriorityContainer.h"
00010 #include "moses/Syntax/CubeQueue.h"
00011 #include "moses/Syntax/PHyperedge.h"
00012 #include "moses/Syntax/RuleTable.h"
00013 #include "moses/Syntax/RuleTableFF.h"
00014 #include "moses/Syntax/SHyperedgeBundle.h"
00015 #include "moses/Syntax/SVertex.h"
00016 #include "moses/Syntax/SVertexRecombinationEqualityPred.h"
00017 #include "moses/Syntax/SVertexRecombinationHasher.h"
00018 #include "moses/Syntax/SymbolEqualityPred.h"
00019 #include "moses/Syntax/SymbolHasher.h"
00020
00021 #include "DerivationWriter.h"
00022 #include "OovHandler.h"
00023 #include "PChart.h"
00024 #include "RuleTrie.h"
00025 #include "SChart.h"
00026
00027 namespace Moses
00028 {
00029 namespace Syntax
00030 {
00031 namespace S2T
00032 {
00033
00034 template<typename Parser>
00035 Manager<Parser>::Manager(ttasksptr const& ttask)
00036 : Syntax::Manager(ttask)
00037 , m_pchart(m_source.GetSize(), Parser::RequiresCompressedChart())
00038 , m_schart(m_source.GetSize())
00039 { }
00040
00041 template<typename Parser>
00042 void Manager<Parser>::InitializeCharts()
00043 {
00044
00045 for (std::size_t i = 0; i < m_source.GetSize(); ++i) {
00046 const Word &terminal = m_source.GetWord(i);
00047
00048
00049 PVertex tmp(Range(i,i), terminal);
00050 PVertex &pvertex = m_pchart.AddVertex(tmp);
00051
00052
00053 boost::shared_ptr<SVertex> v(new SVertex());
00054 v->best = 0;
00055 v->pvertex = &pvertex;
00056 SChart::Cell &scell = m_schart.GetCell(i,i);
00057 SVertexStack stack(1, v);
00058 SChart::Cell::TMap::value_type x(terminal, stack);
00059 scell.terminalStacks.insert(x);
00060 }
00061 }
00062
00063 template<typename Parser>
00064 void Manager<Parser>::InitializeParsers(PChart &pchart,
00065 std::size_t ruleLimit)
00066 {
00067 const std::vector<RuleTableFF*> &ffs = RuleTableFF::Instances();
00068
00069 const std::vector<DecodeGraph*> &graphs =
00070 StaticData::Instance().GetDecodeGraphs();
00071
00072 UTIL_THROW_IF2(ffs.size() != graphs.size(),
00073 "number of RuleTables does not match number of decode graphs");
00074
00075 for (std::size_t i = 0; i < ffs.size(); ++i) {
00076 RuleTableFF *ff = ffs[i];
00077 std::size_t maxChartSpan = graphs[i]->GetMaxChartSpan();
00078
00079
00080
00081
00082 const RuleTable *table = ff->GetTable();
00083 assert(table);
00084 RuleTable *nonConstTable = const_cast<RuleTable*>(table);
00085 boost::shared_ptr<Parser> parser;
00086 typename Parser::RuleTrie *trie =
00087 dynamic_cast<typename Parser::RuleTrie*>(nonConstTable);
00088 assert(trie);
00089 parser.reset(new Parser(pchart, *trie, maxChartSpan));
00090 m_parsers.push_back(parser);
00091 }
00092
00093
00094
00095 m_oovs.clear();
00096 std::size_t maxOovWidth = 0;
00097 FindOovs(pchart, m_oovs, maxOovWidth);
00098 if (!m_oovs.empty()) {
00099
00100 OovHandler<typename Parser::RuleTrie> oovHandler(*ffs[0]);
00101 m_oovRuleTrie = oovHandler.SynthesizeRuleTrie(m_oovs.begin(), m_oovs.end());
00102
00103 boost::shared_ptr<Parser> parser(
00104 new Parser(pchart, *m_oovRuleTrie, maxOovWidth));
00105 m_parsers.push_back(parser);
00106 }
00107 }
00108
00109
00110
00111 template<typename Parser>
00112 void Manager<Parser>::FindOovs(const PChart &pchart, boost::unordered_set<Word> &oovs,
00113 std::size_t maxOovWidth)
00114 {
00115
00116 std::vector<const RuleTrie *> tries;
00117 const std::vector<RuleTableFF*> &ffs = RuleTableFF::Instances();
00118 for (std::size_t i = 0; i < ffs.size(); ++i) {
00119 const RuleTableFF *ff = ffs[i];
00120 if (ff->GetTable()) {
00121 const RuleTrie *trie = dynamic_cast<const RuleTrie*>(ff->GetTable());
00122 assert(trie);
00123 tries.push_back(trie);
00124 }
00125 }
00126
00127
00128
00129
00130 oovs.clear();
00131 maxOovWidth = 0;
00132
00133
00134 for (std::size_t i = 1; i < pchart.GetWidth()-1; ++i) {
00135 for (std::size_t j = i; j < pchart.GetWidth()-1; ++j) {
00136 std::size_t width = j-i+1;
00137 const PChart::Cell::TMap &map = pchart.GetCell(i,j).terminalVertices;
00138 for (PChart::Cell::TMap::const_iterator p = map.begin();
00139 p != map.end(); ++p) {
00140 const Word &word = p->first;
00141 assert(!word.IsNonTerminal());
00142 bool found = false;
00143 for (std::vector<const RuleTrie *>::const_iterator q = tries.begin();
00144 q != tries.end(); ++q) {
00145 const RuleTrie *trie = *q;
00146 if (trie->HasPreterminalRule(word)) {
00147 found = true;
00148 break;
00149 }
00150 }
00151 if (!found) {
00152 oovs.insert(word);
00153 maxOovWidth = std::max(maxOovWidth, width);
00154 }
00155 }
00156 }
00157 }
00158 }
00159
00160 template<typename Parser>
00161 void Manager<Parser>::Decode()
00162 {
00163
00164 const std::size_t popLimit = options()->cube.pop_limit;
00165 const std::size_t ruleLimit = options()->syntax.rule_limit;
00166 const std::size_t stackLimit = options()->search.stack_size;
00167
00168
00169 InitializeCharts();
00170
00171
00172 InitializeParsers(m_pchart, ruleLimit);
00173
00174
00175 typename Parser::CallbackType callback(m_schart, ruleLimit);
00176
00177
00178 std::size_t size = m_source.GetSize();
00179 for (int start = size-1; start >= 0; --start) {
00180 for (std::size_t width = 1; width <= size-start; ++width) {
00181 std::size_t end = start + width - 1;
00182
00183
00184 SChart::Cell &scell = m_schart.GetCell(start, end);
00185
00186 Range range(start, end);
00187
00188
00189
00190
00191 callback.InitForRange(range);
00192 for (typename std::vector<boost::shared_ptr<Parser> >::iterator
00193 p = m_parsers.begin(); p != m_parsers.end(); ++p) {
00194 (*p)->EnumerateHyperedges(range, callback);
00195 }
00196
00197
00198 const BoundedPriorityContainer<SHyperedgeBundle> &bundles =
00199 callback.GetContainer();
00200
00201
00202
00203 CubeQueue cubeQueue(bundles.Begin(), bundles.End());
00204 std::size_t count = 0;
00205 typedef boost::unordered_map<Word, std::vector<SHyperedge*>,
00206 SymbolHasher, SymbolEqualityPred > BufferMap;
00207 BufferMap buffers;
00208 while (count < popLimit && !cubeQueue.IsEmpty()) {
00209 SHyperedge *hyperedge = cubeQueue.Pop();
00210
00211
00212
00213
00214
00215
00216
00217 const Word &lhs = hyperedge->label.translation->GetTargetLHS();
00218 hyperedge->head->pvertex = &m_pchart.AddVertex(PVertex(range, lhs));
00219
00220 buffers[lhs].push_back(hyperedge);
00221 ++count;
00222 }
00223
00224
00225 for (BufferMap::const_iterator p = buffers.begin(); p != buffers.end();
00226 ++p) {
00227 const Word &category = p->first;
00228 const std::vector<SHyperedge*> &buffer = p->second;
00229 std::pair<SChart::Cell::NMap::Iterator, bool> ret =
00230 scell.nonTerminalStacks.Insert(category, SVertexStack());
00231 assert(ret.second);
00232 SVertexStack &stack = ret.first->second;
00233 RecombineAndSort(buffer, stack);
00234 }
00235
00236
00237 if (stackLimit > 0) {
00238 for (SChart::Cell::NMap::Iterator p = scell.nonTerminalStacks.Begin();
00239 p != scell.nonTerminalStacks.End(); ++p) {
00240 SVertexStack &stack = p->second;
00241 if (stack.size() > stackLimit) {
00242 stack.resize(stackLimit);
00243 }
00244 }
00245 }
00246
00247
00248
00249
00250
00251 }
00252 }
00253 }
00254
00255 template<typename Parser>
00256 const SHyperedge *Manager<Parser>::GetBestSHyperedge() const
00257 {
00258 const SChart::Cell &cell = m_schart.GetCell(0, m_source.GetSize()-1);
00259 const SChart::Cell::NMap &stacks = cell.nonTerminalStacks;
00260 if (stacks.Size() == 0) {
00261 return 0;
00262 }
00263 assert(stacks.Size() == 1);
00264 const std::vector<boost::shared_ptr<SVertex> > &stack = stacks.Begin()->second;
00265
00266 return stack[0]->best;
00267 }
00268
00269 template<typename Parser>
00270 void Manager<Parser>::ExtractKBest(
00271 std::size_t k,
00272 std::vector<boost::shared_ptr<KBestExtractor::Derivation> > &kBestList,
00273 bool onlyDistinct) const
00274 {
00275 kBestList.clear();
00276 if (k == 0 || m_source.GetSize() == 0) {
00277 return;
00278 }
00279
00280
00281 const SChart::Cell &cell = m_schart.GetCell(0, m_source.GetSize()-1);
00282 const SChart::Cell::NMap &stacks = cell.nonTerminalStacks;
00283 if (stacks.Size() == 0) {
00284 return;
00285 }
00286 assert(stacks.Size() == 1);
00287 const std::vector<boost::shared_ptr<SVertex> > &stack = stacks.Begin()->second;
00288
00289
00290 KBestExtractor extractor;
00291
00292 if (!onlyDistinct) {
00293
00294 extractor.Extract(stack, k, kBestList);
00295 return;
00296 }
00297
00298
00299
00300
00301
00302
00303 const StaticData &staticData = StaticData::Instance();
00304 const std::size_t nBestFactor = staticData.options()->nbest.factor;
00305 std::size_t numDerivations = (nBestFactor == 0) ? k*1000 : k*nBestFactor;
00306
00307
00308 KBestExtractor::KBestVec bigList;
00309 bigList.reserve(numDerivations);
00310 extractor.Extract(stack, numDerivations, bigList);
00311
00312
00313 std::set<Phrase> distinct;
00314 for (KBestExtractor::KBestVec::const_iterator p = bigList.begin();
00315 kBestList.size() < k && p != bigList.end(); ++p) {
00316 boost::shared_ptr<KBestExtractor::Derivation> derivation = *p;
00317 Phrase translation = KBestExtractor::GetOutputPhrase(*derivation);
00318 if (distinct.insert(translation).second) {
00319 kBestList.push_back(derivation);
00320 }
00321 }
00322 }
00323
00324 template<typename Parser>
00325 void Manager<Parser>::PrunePChart(const SChart::Cell &scell,
00326 PChart::Cell &pcell)
00327 {
00328
00329
00330
00331
00332
00333
00334
00335
00336
00337
00338
00339
00340 }
00341
00342 template<typename Parser>
00343 void Manager<Parser>::RecombineAndSort(const std::vector<SHyperedge*> &buffer,
00344 SVertexStack &stack)
00345 {
00346
00347
00348
00349
00350
00351 typedef boost::unordered_map<SVertex *, SVertex *,
00352 SVertexRecombinationHasher,
00353 SVertexRecombinationEqualityPred> Map;
00354 Map map;
00355 for (std::vector<SHyperedge*>::const_iterator p = buffer.begin();
00356 p != buffer.end(); ++p) {
00357 SHyperedge *h = *p;
00358 SVertex *v = h->head;
00359 assert(v->best == h);
00360 assert(v->recombined.empty());
00361 std::pair<Map::iterator, bool> result = map.insert(Map::value_type(v, v));
00362 if (result.second) {
00363 continue;
00364 }
00365
00366
00367
00368 SVertex *storedVertex = result.first->second;
00369 if (h->label.futureScore > storedVertex->best->label.futureScore) {
00370
00371 storedVertex->recombined.push_back(storedVertex->best);
00372 storedVertex->best = h;
00373 } else {
00374 storedVertex->recombined.push_back(h);
00375 }
00376 h->head->best = 0;
00377 delete h->head;
00378 h->head = storedVertex;
00379 }
00380
00381
00382 stack.clear();
00383 stack.reserve(map.size());
00384 for (Map::const_iterator p = map.begin(); p != map.end(); ++p) {
00385 stack.push_back(boost::shared_ptr<SVertex>(p->first));
00386 }
00387
00388
00389 std::sort(stack.begin(), stack.end(), SVertexStackContentOrderer());
00390 }
00391
00392 template<typename Parser>
00393 void Manager<Parser>::OutputDetailedTranslationReport(
00394 OutputCollector *collector) const
00395 {
00396 const SHyperedge *best = GetBestSHyperedge();
00397 if (best == NULL || collector == NULL) {
00398 return;
00399 }
00400 long translationId = m_source.GetTranslationId();
00401 std::ostringstream out;
00402 DerivationWriter::Write(*best, translationId, out);
00403 collector->Write(translationId, out.str());
00404 }
00405
00406 }
00407 }
00408 }