00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022 #include <algorithm>
00023 #include "StaticData.h"
00024 #include "ChartHypothesisCollection.h"
00025 #include "ChartHypothesis.h"
00026 #include "ChartManager.h"
00027 #include "HypergraphOutput.h"
00028 #include "util/exception.hh"
00029 #include "parameters/AllOptions.h"
00030
00031 using namespace std;
00032 using namespace Moses;
00033
00034 namespace Moses
00035 {
00036
00037 ChartHypothesisCollection::ChartHypothesisCollection(AllOptions const& opts)
00038 {
00039
00040
00041 m_beamWidth = opts.search.beam_width;
00042 m_maxHypoStackSize = opts.search.stack_size;
00043 m_nBestIsEnabled = opts.nbest.enabled;
00044 m_bestScore = -std::numeric_limits<float>::infinity();
00045 }
00046
00047 ChartHypothesisCollection::~ChartHypothesisCollection()
00048 {
00049 HCType::iterator iter;
00050 for (iter = m_hypos.begin() ; iter != m_hypos.end() ; ++iter) {
00051 ChartHypothesis *hypo = *iter;
00052 delete hypo;
00053 }
00054
00055 }
00056
00064 bool ChartHypothesisCollection::AddHypothesis(ChartHypothesis *hypo, ChartManager &manager)
00065 {
00066 if (hypo->GetFutureScore() == - std::numeric_limits<float>::infinity()) {
00067 manager.GetSentenceStats().AddDiscarded();
00068 VERBOSE(3,"discarded, -inf score" << std::endl);
00069 delete hypo;
00070 return false;
00071 }
00072
00073 if (hypo->GetFutureScore() < m_bestScore + m_beamWidth) {
00074
00075 manager.GetSentenceStats().AddDiscarded();
00076 VERBOSE(3,"discarded, too bad for stack" << std::endl);
00077 delete hypo;
00078 return false;
00079 }
00080
00081
00082 std::pair<HCType::iterator, bool> addRet = Add(hypo, manager);
00083
00084
00085 if (addRet.second) {
00086
00087 return true;
00088 }
00089
00090
00091 HCType::iterator &iterExisting = addRet.first;
00092 ChartHypothesis *hypoExisting = *iterExisting;
00093 UTIL_THROW_IF2(iterExisting == m_hypos.end(),
00094 "Adding a hypothesis should have returned a valid iterator");
00095
00096
00097
00098
00099
00100 if (hypo->GetFutureScore() > hypoExisting->GetFutureScore()) {
00101
00102 VERBOSE(3,"better than matching hyp " << hypoExisting->GetId() << ", recombining, ");
00103 if (m_nBestIsEnabled) {
00104 hypo->AddArc(hypoExisting);
00105 Detach(iterExisting);
00106 } else {
00107 Remove(iterExisting);
00108 }
00109
00110 bool added = Add(hypo, manager).second;
00111 if (!added) {
00112 iterExisting = m_hypos.find(hypo);
00113 UTIL_THROW2("Offending hypo = " << **iterExisting);
00114 }
00115 return false;
00116 } else {
00117
00118 VERBOSE(3,"worse than matching hyp " << hypoExisting->GetId() << ", recombining" << std::endl)
00119 if (m_nBestIsEnabled) {
00120 hypoExisting->AddArc(hypo);
00121 } else {
00122 delete hypo;
00123 }
00124 return false;
00125 }
00126 }
00127
00133 pair<ChartHypothesisCollection::HCType::iterator, bool> ChartHypothesisCollection::Add(ChartHypothesis *hypo, ChartManager &manager)
00134 {
00135 std::pair<HCType::iterator, bool> ret = m_hypos.insert(hypo);
00136 if (ret.second) {
00137
00138 VERBOSE(3,"added hyp to stack");
00139
00140
00141 if (hypo->GetFutureScore() > m_bestScore) {
00142 VERBOSE(3,", best on stack");
00143 m_bestScore = hypo->GetFutureScore();
00144 }
00145
00146
00147 VERBOSE(3,", now size " << m_hypos.size());
00148 if (m_hypos.size() > 2*m_maxHypoStackSize-1) {
00149 PruneToSize(manager);
00150 } else {
00151 VERBOSE(3,std::endl);
00152 }
00153 }
00154
00155 return ret;
00156 }
00157
00161 void ChartHypothesisCollection::Detach(const HCType::iterator &iter)
00162 {
00163 m_hypos.erase(iter);
00164 }
00165
00168 void ChartHypothesisCollection::Remove(const HCType::iterator &iter)
00169 {
00170 ChartHypothesis *h = *iter;
00171 Detach(iter);
00172 delete h;
00173 }
00174
00179 void ChartHypothesisCollection::PruneToSize(ChartManager &manager)
00180 {
00181 if (m_maxHypoStackSize == 0) return;
00182
00183 if (GetSize() > m_maxHypoStackSize) {
00184 priority_queue<float> bestScores;
00185
00186
00187
00188 HCType::iterator iter = m_hypos.begin();
00189 float score = 0;
00190 while (iter != m_hypos.end()) {
00191 ChartHypothesis *hypo = *iter;
00192 score = hypo->GetFutureScore();
00193 if (score > m_bestScore+m_beamWidth) {
00194 bestScores.push(score);
00195 }
00196 ++iter;
00197 }
00198
00199
00200
00201 size_t minNewSizeHeapSize = m_maxHypoStackSize > bestScores.size() ? bestScores.size() : m_maxHypoStackSize;
00202 for (size_t i = 1 ; i < minNewSizeHeapSize ; i++)
00203 bestScores.pop();
00204
00205
00206 float scoreThreshold = bestScores.top();
00207
00208
00209 iter = m_hypos.begin();
00210 while (iter != m_hypos.end()) {
00211 ChartHypothesis *hypo = *iter;
00212 float score = hypo->GetFutureScore();
00213 if (score < scoreThreshold) {
00214 HCType::iterator iterRemove = iter++;
00215 Remove(iterRemove);
00216 manager.GetSentenceStats().AddPruning();
00217 } else {
00218 ++iter;
00219 }
00220 }
00221 VERBOSE(3,", pruned to size " << m_hypos.size() << endl);
00222
00223 IFVERBOSE(3) {
00224 TRACE_ERR("stack now contains: ");
00225 for(iter = m_hypos.begin(); iter != m_hypos.end(); iter++) {
00226 ChartHypothesis *hypo = *iter;
00227 TRACE_ERR( hypo->GetId() << " (" << hypo->GetFutureScore() << ") ");
00228 }
00229 TRACE_ERR( endl);
00230 }
00231
00232
00233 if (m_hypos.size() > m_maxHypoStackSize * 2) {
00234 std::vector<ChartHypothesis*> hyposOrdered;
00235
00236
00237 std::copy(m_hypos.begin(), m_hypos.end(), std::inserter(hyposOrdered, hyposOrdered.end()));
00238 std::sort(hyposOrdered.begin(), hyposOrdered.end(), ChartHypothesisScoreOrderer());
00239
00240
00241 std::vector<ChartHypothesis*>::iterator iter;
00242 for (iter = hyposOrdered.begin() + (m_maxHypoStackSize * 2); iter != hyposOrdered.end(); ++iter) {
00243 ChartHypothesis *hypo = *iter;
00244 HCType::iterator iterFindHypo = m_hypos.find(hypo);
00245 UTIL_THROW_IF2(iterFindHypo == m_hypos.end(),
00246 "Adding a hypothesis should have returned a valid iterator");
00247
00248 Remove(iterFindHypo);
00249 }
00250 }
00251 }
00252 }
00253
00255 void ChartHypothesisCollection::SortHypotheses()
00256 {
00257 UTIL_THROW_IF2(!m_hyposOrdered.empty(), "Hypotheses already sorted");
00258 if (!m_hypos.empty()) {
00259
00260
00261
00262 m_hyposOrdered.reserve(m_hypos.size());
00263 std::copy(m_hypos.begin(), m_hypos.end(), back_inserter(m_hyposOrdered));
00264 std::sort(m_hyposOrdered.begin(), m_hyposOrdered.end(), ChartHypothesisScoreOrderer());
00265 }
00266 }
00267
00269 void ChartHypothesisCollection::CleanupArcList()
00270 {
00271 HCType::iterator iter;
00272 for (iter = m_hypos.begin() ; iter != m_hypos.end() ; ++iter) {
00273 ChartHypothesis *mainHypo = *iter;
00274 mainHypo->CleanupArcList();
00275 }
00276 }
00277
00284 void ChartHypothesisCollection::WriteSearchGraph(const ChartSearchGraphWriter& writer, const std::map<unsigned, bool> &reachable) const
00285 {
00286 writer.WriteHypos(*this,reachable);
00287 }
00288
00289 std::ostream& operator<<(std::ostream &out, const ChartHypothesisCollection &coll)
00290 {
00291 HypoList::const_iterator iterInside;
00292 for (iterInside = coll.m_hyposOrdered.begin(); iterInside != coll.m_hyposOrdered.end(); ++iterInside) {
00293 const ChartHypothesis &hypo = **iterInside;
00294 out << hypo << endl;
00295 }
00296
00297 return out;
00298 }
00299
00300
00301 }