00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022 #include <algorithm>
00023 #include <set>
00024 #include <queue>
00025 #include "HypothesisStackCubePruning.h"
00026 #include "TypeDef.h"
00027 #include "Util.h"
00028 #include "StaticData.h"
00029 #include "Manager.h"
00030 #include "util/exception.hh"
00031
00032 using namespace std;
00033
00034 namespace Moses
00035 {
00036 HypothesisStackCubePruning::HypothesisStackCubePruning(Manager& manager) :
00037 HypothesisStack(manager)
00038 {
00039 m_nBestIsEnabled = manager.options()->nbest.enabled;
00040 m_bestScore = -std::numeric_limits<float>::infinity();
00041 m_worstScore = -std::numeric_limits<float>::infinity();
00042 m_deterministic = manager.options()->cube.deterministic_search;
00043 }
00044
00046 void HypothesisStackCubePruning::RemoveAll()
00047 {
00048
00049 _BMType::iterator iter;
00050 for (iter = m_bitmapAccessor.begin(); iter != m_bitmapAccessor.end(); ++iter) {
00051 delete iter->second;
00052 }
00053 }
00054
00055 pair<HypothesisStackCubePruning::iterator, bool> HypothesisStackCubePruning::Add(Hypothesis *hypo)
00056 {
00057 std::pair<iterator, bool> ret = m_hypos.insert(hypo);
00058
00059 if (ret.second) {
00060
00061 VERBOSE(3,"added hyp to stack");
00062
00063
00064 if (hypo->GetFutureScore() > m_bestScore) {
00065 VERBOSE(3,", best on stack");
00066 m_bestScore = hypo->GetFutureScore();
00067
00068 if ( m_bestScore + m_beamWidth > m_worstScore )
00069 m_worstScore = m_bestScore + m_beamWidth;
00070 }
00071
00072
00073 VERBOSE(3,", now size " << m_hypos.size());
00074 if (m_hypos.size() > 2*m_maxHypoStackSize-1) {
00075 PruneToSize(m_maxHypoStackSize);
00076 } else {
00077 VERBOSE(3,std::endl);
00078 }
00079 }
00080
00081 return ret;
00082 }
00083
00084 bool HypothesisStackCubePruning::AddPrune(Hypothesis *hypo)
00085 {
00086 if (hypo->GetFutureScore() == - std::numeric_limits<float>::infinity()) {
00087 m_manager.GetSentenceStats().AddDiscarded();
00088 VERBOSE(3,"discarded, constraint" << std::endl);
00089 delete hypo;
00090 return false;
00091 }
00092
00093 if (hypo->GetFutureScore() < m_worstScore) {
00094
00095 m_manager.GetSentenceStats().AddDiscarded();
00096 VERBOSE(3,"discarded, too bad for stack" << std::endl);
00097 delete hypo;
00098 return false;
00099 }
00100
00101
00102 std::pair<iterator, bool> addRet = Add(hypo);
00103 if (addRet.second) {
00104
00105 return true;
00106 }
00107
00108
00109 iterator &iterExisting = addRet.first;
00110 assert(iterExisting != m_hypos.end());
00111 Hypothesis *hypoExisting = *iterExisting;
00112
00113 m_manager.GetSentenceStats().AddRecombination(*hypo, **iterExisting);
00114
00115
00116
00117 if (hypo->GetFutureScore() > hypoExisting->GetFutureScore()) {
00118
00119 VERBOSE(3,"better than matching hyp " << hypoExisting->GetId() << ", recombining, ");
00120 if (m_nBestIsEnabled) {
00121 hypo->AddArc(hypoExisting);
00122 Detach(iterExisting);
00123 } else {
00124 Remove(iterExisting);
00125 }
00126
00127 bool added = Add(hypo).second;
00128 if (!added) {
00129 iterExisting = m_hypos.find(hypo);
00130 UTIL_THROW(util::Exception, "Should have added hypothesis " << **iterExisting);
00131 }
00132 return false;
00133 } else {
00134
00135 VERBOSE(3,"worse than matching hyp " << hypoExisting->GetId() << ", recombining" << std::endl)
00136 if (m_nBestIsEnabled) {
00137 hypoExisting->AddArc(hypo);
00138 } else {
00139 delete hypo;
00140 }
00141 return false;
00142 }
00143 }
00144
00145 void HypothesisStackCubePruning::AddInitial(Hypothesis *hypo)
00146 {
00147 std::pair<iterator, bool> addRet = Add(hypo);
00148 UTIL_THROW_IF2(!addRet.second,
00149 "Should have added hypothesis " << *hypo);
00150
00151 const Bitmap &bitmap = hypo->GetWordsBitmap();
00152 AddBitmapContainer(bitmap, *this);
00153 }
00154
00155 void HypothesisStackCubePruning::PruneToSize(size_t newSize)
00156 {
00157 if ( newSize == 0) return;
00158
00159 if (m_hypos.size() > newSize) {
00160 priority_queue<float> bestScores;
00161
00162
00163
00164 iterator iter = m_hypos.begin();
00165 float score = 0;
00166 while (iter != m_hypos.end()) {
00167 Hypothesis *hypo = *iter;
00168 score = hypo->GetFutureScore();
00169 if (score > m_bestScore+m_beamWidth) {
00170 bestScores.push(score);
00171 }
00172 ++iter;
00173 }
00174
00175
00176
00177 size_t minNewSizeHeapSize = newSize > bestScores.size() ? bestScores.size() : newSize;
00178 for (size_t i = 1 ; i < minNewSizeHeapSize ; i++)
00179 bestScores.pop();
00180
00181
00182 float scoreThreshold = bestScores.top();
00183
00184
00185 iter = m_hypos.begin();
00186 while (iter != m_hypos.end()) {
00187 Hypothesis *hypo = *iter;
00188 float score = hypo->GetFutureScore();
00189 if (score < scoreThreshold) {
00190 iterator iterRemove = iter++;
00191 Remove(iterRemove);
00192 m_manager.GetSentenceStats().AddPruning();
00193 } else {
00194 ++iter;
00195 }
00196 }
00197 VERBOSE(3,", pruned to size " << size() << endl);
00198
00199 IFVERBOSE(3) {
00200 TRACE_ERR("stack now contains: ");
00201 for(iter = m_hypos.begin(); iter != m_hypos.end(); iter++) {
00202 Hypothesis *hypo = *iter;
00203 TRACE_ERR( hypo->GetId() << " (" << hypo->GetFutureScore() << ") ");
00204 }
00205 TRACE_ERR( endl);
00206 }
00207
00208
00209 m_worstScore = scoreThreshold;
00210 }
00211 }
00212
00213 const Hypothesis *HypothesisStackCubePruning::GetBestHypothesis() const
00214 {
00215 if (!m_hypos.empty()) {
00216 const_iterator iter = m_hypos.begin();
00217 Hypothesis *bestHypo = *iter;
00218 while (++iter != m_hypos.end()) {
00219 Hypothesis *hypo = *iter;
00220 if (hypo->GetFutureScore() > bestHypo->GetFutureScore())
00221 bestHypo = hypo;
00222 }
00223 return bestHypo;
00224 }
00225 return NULL;
00226 }
00227
00228 vector<const Hypothesis*> HypothesisStackCubePruning::GetSortedList() const
00229 {
00230 vector<const Hypothesis*> ret;
00231 ret.reserve(m_hypos.size());
00232 std::copy(m_hypos.begin(), m_hypos.end(), std::inserter(ret, ret.end()));
00233 sort(ret.begin(), ret.end(), CompareHypothesisTotalScore());
00234
00235 return ret;
00236 }
00237
00238
00239 void HypothesisStackCubePruning::CleanupArcList()
00240 {
00241
00242 if (!m_nBestIsEnabled) return;
00243
00244 iterator iter;
00245 for (iter = m_hypos.begin() ; iter != m_hypos.end() ; ++iter) {
00246 Hypothesis *mainHypo = *iter;
00247 mainHypo->CleanupArcList(this->m_manager.options()->nbest.nbest_size, this->m_manager.options()->NBestDistinct());
00248 }
00249 }
00250
00251 void HypothesisStackCubePruning::SetBitmapAccessor(const Bitmap &newBitmap
00252 , HypothesisStackCubePruning &stack
00253 , const Range &
00254 , BitmapContainer &bitmapContainer
00255 , const SquareMatrix &estimatedScores
00256 , const TranslationOptionList &transOptList)
00257 {
00258 BitmapContainer *bmContainer = AddBitmapContainer(newBitmap, stack);
00259 BackwardsEdge *edge = new BackwardsEdge(bitmapContainer
00260 , *bmContainer
00261 , transOptList
00262 , estimatedScores
00263 , m_manager.GetSource()
00264 , m_deterministic);
00265 bmContainer->AddBackwardsEdge(edge);
00266 }
00267
00268
00269 TO_STRING_BODY(HypothesisStackCubePruning);
00270
00271
00272
00273 std::ostream& operator<<(std::ostream& out, const HypothesisStackCubePruning& hypoColl)
00274 {
00275 HypothesisStackCubePruning::const_iterator iter;
00276
00277 for (iter = hypoColl.begin() ; iter != hypoColl.end() ; ++iter) {
00278 const Hypothesis &hypo = **iter;
00279 out << hypo << endl;
00280
00281 }
00282 return out;
00283 }
00284
00285 void
00286 HypothesisStackCubePruning::AddHypothesesToBitmapContainers()
00287 {
00288 HypothesisStackCubePruning::const_iterator iter;
00289 for (iter = m_hypos.begin() ; iter != m_hypos.end() ; ++iter) {
00290 Hypothesis *h = *iter;
00291 const Bitmap &bitmap = h->GetWordsBitmap();
00292 BitmapContainer *container = m_bitmapAccessor[&bitmap];
00293 container->AddHypothesis(h);
00294 }
00295 }
00296
00297 BitmapContainer *HypothesisStackCubePruning::AddBitmapContainer(const Bitmap &bitmap, HypothesisStackCubePruning &stack)
00298 {
00299 _BMType::iterator iter = m_bitmapAccessor.find(&bitmap);
00300
00301 BitmapContainer *bmContainer;
00302 if (iter == m_bitmapAccessor.end()) {
00303 bmContainer = new BitmapContainer(bitmap, stack, m_deterministic);
00304 m_bitmapAccessor[&bitmap] = bmContainer;
00305 } else {
00306 bmContainer = iter->second;
00307 }
00308
00309 return bmContainer;
00310 }
00311
00312 }
00313