00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022 #include <algorithm>
00023 #include "StaticData.h"
00024 #include "ChartHypothesisCollection.h"
00025 #include "ChartHypothesis.h"
00026 #include "ChartManager.h"
00027
00028 using namespace std;
00029 using namespace Moses;
00030
00031 namespace Moses
00032 {
00033
00034 ChartHypothesisCollection::ChartHypothesisCollection()
00035 {
00036 const StaticData &staticData = StaticData::Instance();
00037
00038 m_beamWidth = staticData.GetBeamWidth();
00039 m_maxHypoStackSize = staticData.GetMaxHypoStackSize();
00040 m_nBestIsEnabled = staticData.IsNBestEnabled();
00041 m_bestScore = -std::numeric_limits<float>::infinity();
00042 }
00043
00044 ChartHypothesisCollection::~ChartHypothesisCollection()
00045 {
00046 HCType::iterator iter;
00047 for (iter = m_hypos.begin() ; iter != m_hypos.end() ; ++iter) {
00048 ChartHypothesis *hypo = *iter;
00049 ChartHypothesis::Delete(hypo);
00050 }
00051
00052 }
00053
00061 bool ChartHypothesisCollection::AddHypothesis(ChartHypothesis *hypo, ChartManager &manager)
00062 {
00063 if (hypo->GetTotalScore() < m_bestScore + m_beamWidth) {
00064
00065 manager.GetSentenceStats().AddDiscarded();
00066 VERBOSE(3,"discarded, too bad for stack" << std::endl);
00067 ChartHypothesis::Delete(hypo);
00068 return false;
00069 }
00070
00071
00072 std::pair<HCType::iterator, bool> addRet = Add(hypo, manager);
00073
00074
00075 if (addRet.second) {
00076
00077 return true;
00078 }
00079
00080
00081 HCType::iterator &iterExisting = addRet.first;
00082 ChartHypothesis *hypoExisting = *iterExisting;
00083 CHECK(iterExisting != m_hypos.end());
00084
00085
00086
00087
00088
00089 if (hypo->GetTotalScore() > hypoExisting->GetTotalScore()) {
00090
00091 VERBOSE(3,"better than matching hyp " << hypoExisting->GetId() << ", recombining, ");
00092 if (m_nBestIsEnabled) {
00093 hypo->AddArc(hypoExisting);
00094 Detach(iterExisting);
00095 } else {
00096 Remove(iterExisting);
00097 }
00098
00099 bool added = Add(hypo, manager).second;
00100 if (!added) {
00101 iterExisting = m_hypos.find(hypo);
00102 TRACE_ERR("Offending hypo = " << **iterExisting << endl);
00103 abort();
00104 }
00105 return false;
00106 } else {
00107
00108 VERBOSE(3,"worse than matching hyp " << hypoExisting->GetId() << ", recombining" << std::endl)
00109 if (m_nBestIsEnabled) {
00110 hypoExisting->AddArc(hypo);
00111 }
00112 else {
00113 ChartHypothesis::Delete(hypo);
00114 }
00115 return false;
00116 }
00117 }
00118
00124 pair<ChartHypothesisCollection::HCType::iterator, bool> ChartHypothesisCollection::Add(ChartHypothesis *hypo, ChartManager &manager)
00125 {
00126 std::pair<HCType::iterator, bool> ret = m_hypos.insert(hypo);
00127 if (ret.second) {
00128
00129 VERBOSE(3,"added hyp to stack");
00130
00131
00132 if (hypo->GetTotalScore() > m_bestScore) {
00133 VERBOSE(3,", best on stack");
00134 m_bestScore = hypo->GetTotalScore();
00135 }
00136
00137
00138 VERBOSE(3,", now size " << m_hypos.size());
00139 if (m_hypos.size() > 2*m_maxHypoStackSize-1) {
00140 PruneToSize(manager);
00141 } else {
00142 VERBOSE(3,std::endl);
00143 }
00144 }
00145
00146 return ret;
00147 }
00148
00152 void ChartHypothesisCollection::Detach(const HCType::iterator &iter)
00153 {
00154 m_hypos.erase(iter);
00155 }
00156
00159 void ChartHypothesisCollection::Remove(const HCType::iterator &iter)
00160 {
00161 ChartHypothesis *h = *iter;
00162
00163
00164
00165
00166
00167
00168
00169
00170
00171
00172
00173
00174
00175
00176 Detach(iter);
00177 ChartHypothesis::Delete(h);
00178 }
00179
00184 void ChartHypothesisCollection::PruneToSize(ChartManager &manager)
00185 {
00186 if (GetSize() > m_maxHypoStackSize) {
00187 priority_queue<float> bestScores;
00188
00189
00190
00191 HCType::iterator iter = m_hypos.begin();
00192 float score = 0;
00193 while (iter != m_hypos.end()) {
00194 ChartHypothesis *hypo = *iter;
00195 score = hypo->GetTotalScore();
00196 if (score > m_bestScore+m_beamWidth) {
00197 bestScores.push(score);
00198 }
00199 ++iter;
00200 }
00201
00202
00203
00204 size_t minNewSizeHeapSize = m_maxHypoStackSize > bestScores.size() ? bestScores.size() : m_maxHypoStackSize;
00205 for (size_t i = 1 ; i < minNewSizeHeapSize ; i++)
00206 bestScores.pop();
00207
00208
00209 float scoreThreshold = bestScores.top();
00210
00211
00212 iter = m_hypos.begin();
00213 while (iter != m_hypos.end()) {
00214 ChartHypothesis *hypo = *iter;
00215 float score = hypo->GetTotalScore();
00216 if (score < scoreThreshold) {
00217 HCType::iterator iterRemove = iter++;
00218 Remove(iterRemove);
00219 manager.GetSentenceStats().AddPruning();
00220 } else {
00221 ++iter;
00222 }
00223 }
00224 VERBOSE(3,", pruned to size " << m_hypos.size() << endl);
00225
00226 IFVERBOSE(3) {
00227 TRACE_ERR("stack now contains: ");
00228 for(iter = m_hypos.begin(); iter != m_hypos.end(); iter++) {
00229 ChartHypothesis *hypo = *iter;
00230 TRACE_ERR( hypo->GetId() << " (" << hypo->GetTotalScore() << ") ");
00231 }
00232 TRACE_ERR( endl);
00233 }
00234
00235
00236 if (m_hypos.size() > m_maxHypoStackSize * 2) {
00237 std::vector<ChartHypothesis*> hyposOrdered;
00238
00239
00240 std::copy(m_hypos.begin(), m_hypos.end(), std::inserter(hyposOrdered, hyposOrdered.end()));
00241 std::sort(hyposOrdered.begin(), hyposOrdered.end(), ChartHypothesisScoreOrderer());
00242
00243
00244 std::vector<ChartHypothesis*>::iterator iter;
00245 for (iter = hyposOrdered.begin() + (m_maxHypoStackSize * 2); iter != hyposOrdered.end(); ++iter) {
00246 ChartHypothesis *hypo = *iter;
00247 HCType::iterator iterFindHypo = m_hypos.find(hypo);
00248 CHECK(iterFindHypo != m_hypos.end());
00249 Remove(iterFindHypo);
00250 }
00251 }
00252 }
00253 }
00254
00256 void ChartHypothesisCollection::SortHypotheses()
00257 {
00258 CHECK(m_hyposOrdered.empty());
00259 if (!m_hypos.empty()) {
00260
00261
00262
00263 m_hyposOrdered.reserve(m_hypos.size());
00264 std::copy(m_hypos.begin(), m_hypos.end(), back_inserter(m_hyposOrdered));
00265 std::sort(m_hyposOrdered.begin(), m_hyposOrdered.end(), ChartHypothesisScoreOrderer());
00266 }
00267 }
00268
00270 void ChartHypothesisCollection::CleanupArcList()
00271 {
00272 HCType::iterator iter;
00273 for (iter = m_hypos.begin() ; iter != m_hypos.end() ; ++iter) {
00274 ChartHypothesis *mainHypo = *iter;
00275 mainHypo->CleanupArcList();
00276 }
00277 }
00278
00285 void ChartHypothesisCollection::GetSearchGraph(long translationId, std::ostream &outputSearchGraphStream, const std::map<unsigned, bool> &reachable) const
00286 {
00287 HCType::const_iterator iter;
00288 for (iter = m_hypos.begin() ; iter != m_hypos.end() ; ++iter) {
00289 ChartHypothesis &mainHypo = **iter;
00290 if (StaticData::Instance().GetUnprunedSearchGraph() ||
00291 reachable.find(mainHypo.GetId()) != reachable.end()) {
00292 outputSearchGraphStream << translationId << " " << mainHypo << endl;
00293 }
00294
00295 const ChartArcList *arcList = mainHypo.GetArcList();
00296 if (arcList) {
00297 ChartArcList::const_iterator iterArc;
00298 for (iterArc = arcList->begin(); iterArc != arcList->end(); ++iterArc) {
00299 const ChartHypothesis &arc = **iterArc;
00300 if (reachable.find(arc.GetId()) != reachable.end()) {
00301 outputSearchGraphStream << translationId << " " << arc << endl;
00302 }
00303 }
00304 }
00305 }
00306 }
00307
00308 std::ostream& operator<<(std::ostream &out, const ChartHypothesisCollection &coll)
00309 {
00310 HypoList::const_iterator iterInside;
00311 for (iterInside = coll.m_hyposOrdered.begin(); iterInside != coll.m_hyposOrdered.end(); ++iterInside) {
00312 const ChartHypothesis &hypo = **iterInside;
00313 out << hypo << endl;
00314 }
00315
00316 return out;
00317 }
00318
00319
00320 }