00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022 #include <stdio.h>
00023 #include "ChartManager.h"
00024 #include "ChartCell.h"
00025 #include "ChartHypothesis.h"
00026 #include "ChartTranslationOptions.h"
00027 #include "ChartTrellisDetourQueue.h"
00028 #include "ChartTrellisNode.h"
00029 #include "ChartTrellisPath.h"
00030 #include "ChartTrellisPathList.h"
00031 #include "StaticData.h"
00032 #include "DecodeStep.h"
00033 #include "TreeInput.h"
00034 #include "moses/FF/WordPenaltyProducer.h"
00035
00036 using namespace std;
00037 using namespace Moses;
00038
00039 namespace Moses
00040 {
00041 extern bool g_debug;
00042
00043
00044
00045
00046
00047 ChartManager::ChartManager(InputType const& source)
00048 :m_source(source)
00049 ,m_hypoStackColl(source, *this)
00050 ,m_start(clock())
00051 ,m_hypothesisId(0)
00052 ,m_parser(source, m_hypoStackColl)
00053 ,m_translationOptionList(StaticData::Instance().GetRuleLimit())
00054 {
00055 }
00056
00057 ChartManager::~ChartManager()
00058 {
00059 clock_t end = clock();
00060 float et = (end - m_start);
00061 et /= (float)CLOCKS_PER_SEC;
00062 VERBOSE(1, "Translation took " << et << " seconds" << endl);
00063
00064 }
00065
00067 void ChartManager::ProcessSentence()
00068 {
00069 VERBOSE(1,"Translating: " << m_source << endl);
00070
00071 ResetSentenceStats(m_source);
00072
00073 VERBOSE(2,"Decoding: " << endl);
00074
00075
00076 AddXmlChartOptions();
00077
00078
00079 size_t size = m_source.GetSize();
00080 for (size_t width = 1; width <= size; ++width) {
00081 for (size_t startPos = 0; startPos <= size-width; ++startPos) {
00082 size_t endPos = startPos + width - 1;
00083 WordsRange range(startPos, endPos);
00084
00085
00086 m_translationOptionList.Clear();
00087 m_parser.Create(range, m_translationOptionList);
00088 m_translationOptionList.ApplyThreshold();
00089
00090
00091 ChartCell &cell = m_hypoStackColl.Get(range);
00092
00093 cell.ProcessSentence(m_translationOptionList, m_hypoStackColl);
00094 m_translationOptionList.Clear();
00095 cell.PruneToSize();
00096 cell.CleanupArcList();
00097 cell.SortHypotheses();
00098 }
00099 }
00100
00101 IFVERBOSE(1) {
00102
00103 for (size_t startPos = 0; startPos < size; ++startPos) {
00104 cerr.width(3);
00105 cerr << startPos << " ";
00106 }
00107 cerr << endl;
00108 for (size_t width = 1; width <= size; width++) {
00109 for( size_t space = 0; space < width-1; space++ ) {
00110 cerr << " ";
00111 }
00112 for (size_t startPos = 0; startPos <= size-width; ++startPos) {
00113 WordsRange range(startPos, startPos+width-1);
00114 cerr.width(3);
00115 cerr << m_hypoStackColl.Get(range).GetSize() << " ";
00116 }
00117 cerr << endl;
00118 }
00119 }
00120 }
00121
00126 void ChartManager::AddXmlChartOptions()
00127 {
00128 const StaticData &staticData = StaticData::Instance();
00129 const std::vector <ChartTranslationOptions*> xmlChartOptionsList = m_source.GetXmlChartTranslationOptions();
00130 IFVERBOSE(2) {
00131 cerr << "AddXmlChartOptions " << xmlChartOptionsList.size() << endl;
00132 }
00133 if (xmlChartOptionsList.size() == 0) return;
00134
00135 for(std::vector<ChartTranslationOptions*>::const_iterator i = xmlChartOptionsList.begin();
00136 i != xmlChartOptionsList.end(); ++i) {
00137 ChartTranslationOptions* opt = *i;
00138
00139 TargetPhrase &targetPhrase = *opt->GetTargetPhraseCollection().GetCollection()[0];
00140 targetPhrase.GetScoreBreakdown().Assign(staticData.GetWordPenaltyProducer(), -1);
00141
00142 const WordsRange &range = opt->GetSourceWordsRange();
00143 RuleCubeItem* item = new RuleCubeItem( *opt, m_hypoStackColl );
00144 ChartHypothesis* hypo = new ChartHypothesis(*opt, *item, *this);
00145 hypo->Evaluate();
00146 ChartCell &cell = m_hypoStackColl.Get(range);
00147 cell.AddHypothesis(hypo);
00148 }
00149 }
00150
00152 const ChartHypothesis *ChartManager::GetBestHypothesis() const
00153 {
00154 size_t size = m_source.GetSize();
00155
00156 if (size == 0)
00157 return NULL;
00158 else {
00159 WordsRange range(0, size-1);
00160 const ChartCell &lastCell = m_hypoStackColl.Get(range);
00161 return lastCell.GetBestHypothesis();
00162 }
00163 }
00164
00171 void ChartManager::CalcNBest(size_t count, ChartTrellisPathList &ret,bool onlyDistinct) const
00172 {
00173 size_t size = m_source.GetSize();
00174 if (count == 0 || size == 0)
00175 return;
00176
00177
00178 WordsRange range(0, size-1);
00179 const ChartCell &lastCell = m_hypoStackColl.Get(range);
00180 const ChartHypothesis *hypo = lastCell.GetBestHypothesis();
00181 if (hypo == NULL) {
00182
00183 return;
00184 }
00185 boost::shared_ptr<ChartTrellisPath> basePath(new ChartTrellisPath(*hypo));
00186
00187
00188 if (count == 1) {
00189 ret.Add(basePath);
00190 return;
00191 }
00192
00193
00194
00195
00196 const StaticData &staticData = StaticData::Instance();
00197 const size_t nBestFactor = staticData.GetNBestFactor();
00198 size_t popLimit;
00199 if (!onlyDistinct) {
00200 popLimit = count-1;
00201 } else if (nBestFactor == 0) {
00202
00203
00204 popLimit = count * 1000;
00205 } else {
00206 popLimit = count * nBestFactor;
00207 }
00208
00209
00210
00211 ChartTrellisDetourQueue contenders(popLimit);
00212
00213
00214 const HypoList *topHypos = lastCell.GetAllSortedHypotheses();
00215
00216
00217 HypoList::const_iterator iter;
00218 for (iter = topHypos->begin(); iter != topHypos->end(); ++iter) {
00219 const ChartHypothesis &hypo = **iter;
00220 boost::shared_ptr<ChartTrellisPath> basePath(new ChartTrellisPath(hypo));
00221 ChartTrellisDetour *detour = new ChartTrellisDetour(basePath, basePath->GetFinalNode(), hypo);
00222 contenders.Push(detour);
00223 }
00224
00225 delete topHypos;
00226
00227
00228 set<Phrase> distinctHyps;
00229
00230
00231 for (size_t i = 0; ret.GetSize() < count && !contenders.Empty() && i < popLimit; ++i) {
00232
00233 std::auto_ptr<const ChartTrellisDetour> detour(contenders.Pop());
00234 CHECK(detour.get());
00235
00236
00237
00238 boost::shared_ptr<ChartTrellisPath> path(new ChartTrellisPath(*detour));
00239
00240
00241
00242
00243 CHECK(path->GetDeviationPoint());
00244 CreateDeviantPaths(path, *(path->GetDeviationPoint()), contenders);
00245
00246
00247
00248
00249 if (!onlyDistinct) {
00250 ret.Add(path);
00251 } else {
00252 Phrase tgtPhrase = path->GetOutputPhrase();
00253 if (distinctHyps.insert(tgtPhrase).second) {
00254 ret.Add(path);
00255 }
00256 }
00257 }
00258 }
00259
00260 void ChartManager::GetSearchGraph(long translationId, std::ostream &outputSearchGraphStream) const
00261 {
00262 size_t size = m_source.GetSize();
00263
00264
00265 std::map<unsigned,bool> reachable;
00266 WordsRange fullRange(0, size-1);
00267 const ChartCell &lastCell = m_hypoStackColl.Get(fullRange);
00268 const ChartHypothesis *hypo = lastCell.GetBestHypothesis();
00269
00270 if (hypo == NULL) {
00271
00272 return;
00273 }
00274 FindReachableHypotheses( hypo, reachable);
00275
00276 for (size_t width = 1; width <= size; ++width) {
00277 for (size_t startPos = 0; startPos <= size-width; ++startPos) {
00278 size_t endPos = startPos + width - 1;
00279 WordsRange range(startPos, endPos);
00280 TRACE_ERR(" " << range << "=");
00281
00282 const ChartCell &cell = m_hypoStackColl.Get(range);
00283 cell.GetSearchGraph(translationId, outputSearchGraphStream, reachable);
00284 }
00285 }
00286 }
00287
00288 void ChartManager::FindReachableHypotheses( const ChartHypothesis *hypo, std::map<unsigned,bool> &reachable ) const
00289 {
00290
00291 if (reachable.find(hypo->GetId()) != reachable.end()) {
00292 return;
00293 }
00294
00295
00296 reachable[ hypo->GetId() ] = true;
00297 const std::vector<const ChartHypothesis*> &previous = hypo->GetPrevHypos();
00298 for(std::vector<const ChartHypothesis*>::const_iterator i = previous.begin(); i != previous.end(); ++i) {
00299 FindReachableHypotheses( *i, reachable );
00300 }
00301
00302
00303 const ChartArcList *arcList = hypo->GetArcList();
00304 if (arcList) {
00305 ChartArcList::const_iterator iterArc;
00306 for (iterArc = arcList->begin(); iterArc != arcList->end(); ++iterArc) {
00307 const ChartHypothesis &arc = **iterArc;
00308 FindReachableHypotheses( &arc, reachable );
00309 }
00310 }
00311 }
00312
00313 void ChartManager::CreateDeviantPaths(
00314 boost::shared_ptr<const ChartTrellisPath> basePath,
00315 ChartTrellisDetourQueue &q)
00316 {
00317 CreateDeviantPaths(basePath, basePath->GetFinalNode(), q);
00318 }
00319
00320 void ChartManager::CreateDeviantPaths(
00321 boost::shared_ptr<const ChartTrellisPath> basePath,
00322 const ChartTrellisNode &substitutedNode,
00323 ChartTrellisDetourQueue &queue)
00324 {
00325 const ChartArcList *arcList = substitutedNode.GetHypothesis().GetArcList();
00326 if (arcList) {
00327 for (ChartArcList::const_iterator iter = arcList->begin();
00328 iter != arcList->end(); ++iter) {
00329 const ChartHypothesis &replacement = **iter;
00330 queue.Push(new ChartTrellisDetour(basePath, substitutedNode,
00331 replacement));
00332 }
00333 }
00334
00335 const ChartTrellisNode::NodeChildren &children = substitutedNode.GetChildren();
00336 ChartTrellisNode::NodeChildren::const_iterator iter;
00337 for (iter = children.begin(); iter != children.end(); ++iter) {
00338 const ChartTrellisNode &child = **iter;
00339 CreateDeviantPaths(basePath, child, queue);
00340 }
00341 }
00342
00343
00344 }