00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022 #include <algorithm>
00023 #include <set>
00024 #include <queue>
00025 #include "HypothesisStackCubePruning.h"
00026 #include "TypeDef.h"
00027 #include "Util.h"
00028 #include "StaticData.h"
00029 #include "Manager.h"
00030
00031 using namespace std;
00032
00033 namespace Moses
00034 {
00035 HypothesisStackCubePruning::HypothesisStackCubePruning(Manager& manager) :
00036 HypothesisStack(manager)
00037 {
00038 m_nBestIsEnabled = StaticData::Instance().IsNBestEnabled();
00039 m_bestScore = -std::numeric_limits<float>::infinity();
00040 m_worstScore = -std::numeric_limits<float>::infinity();
00041 }
00042
00044 void HypothesisStackCubePruning::RemoveAll()
00045 {
00046
00047 _BMType::iterator iter;
00048 for (iter = m_bitmapAccessor.begin(); iter != m_bitmapAccessor.end(); ++iter) {
00049 delete iter->second;
00050 }
00051 }
00052
00053 pair<HypothesisStackCubePruning::iterator, bool> HypothesisStackCubePruning::Add(Hypothesis *hypo)
00054 {
00055 std::pair<iterator, bool> ret = m_hypos.insert(hypo);
00056
00057 if (ret.second) {
00058
00059 VERBOSE(3,"added hyp to stack");
00060
00061
00062 if (hypo->GetTotalScore() > m_bestScore) {
00063 VERBOSE(3,", best on stack");
00064 m_bestScore = hypo->GetTotalScore();
00065
00066 if ( m_bestScore + m_beamWidth > m_worstScore )
00067 m_worstScore = m_bestScore + m_beamWidth;
00068 }
00069
00070
00071 VERBOSE(3,", now size " << m_hypos.size());
00072 if (m_hypos.size() > 2*m_maxHypoStackSize-1) {
00073 PruneToSize(m_maxHypoStackSize);
00074 } else {
00075 VERBOSE(3,std::endl);
00076 }
00077 }
00078
00079 return ret;
00080 }
00081
00082 bool HypothesisStackCubePruning::AddPrune(Hypothesis *hypo)
00083 {
00084 if (hypo->GetTotalScore() < m_worstScore) {
00085
00086 m_manager.GetSentenceStats().AddDiscarded();
00087 VERBOSE(3,"discarded, too bad for stack" << std::endl);
00088 FREEHYPO(hypo);
00089 return false;
00090 }
00091
00092
00093 std::pair<iterator, bool> addRet = Add(hypo);
00094 if (addRet.second) {
00095
00096 return true;
00097 }
00098
00099
00100 iterator &iterExisting = addRet.first;
00101 Hypothesis *hypoExisting = *iterExisting;
00102 CHECK(iterExisting != m_hypos.end());
00103
00104 m_manager.GetSentenceStats().AddRecombination(*hypo, **iterExisting);
00105
00106
00107
00108 if (hypo->GetTotalScore() > hypoExisting->GetTotalScore()) {
00109
00110 VERBOSE(3,"better than matching hyp " << hypoExisting->GetId() << ", recombining, ");
00111 if (m_nBestIsEnabled) {
00112 hypo->AddArc(hypoExisting);
00113 Detach(iterExisting);
00114 } else {
00115 Remove(iterExisting);
00116 }
00117
00118 bool added = Add(hypo).second;
00119 if (!added) {
00120 iterExisting = m_hypos.find(hypo);
00121 TRACE_ERR("Offending hypo = " << **iterExisting << endl);
00122 CHECK(false);
00123 }
00124 return false;
00125 } else {
00126
00127 VERBOSE(3,"worse than matching hyp " << hypoExisting->GetId() << ", recombining" << std::endl)
00128 if (m_nBestIsEnabled) {
00129 hypoExisting->AddArc(hypo);
00130 } else {
00131 FREEHYPO(hypo);
00132 }
00133 return false;
00134 }
00135 }
00136
00137 void HypothesisStackCubePruning::AddInitial(Hypothesis *hypo)
00138 {
00139 std::pair<iterator, bool> addRet = Add(hypo);
00140 CHECK(addRet.second);
00141
00142 const WordsBitmap &bitmap = hypo->GetWordsBitmap();
00143 m_bitmapAccessor[bitmap] = new BitmapContainer(bitmap, *this);
00144 }
00145
00146 void HypothesisStackCubePruning::PruneToSize(size_t newSize)
00147 {
00148 if (m_hypos.size() > newSize) {
00149 priority_queue<float> bestScores;
00150
00151
00152
00153 iterator iter = m_hypos.begin();
00154 float score = 0;
00155 while (iter != m_hypos.end()) {
00156 Hypothesis *hypo = *iter;
00157 score = hypo->GetTotalScore();
00158 if (score > m_bestScore+m_beamWidth) {
00159 bestScores.push(score);
00160 }
00161 ++iter;
00162 }
00163
00164
00165
00166 size_t minNewSizeHeapSize = newSize > bestScores.size() ? bestScores.size() : newSize;
00167 for (size_t i = 1 ; i < minNewSizeHeapSize ; i++)
00168 bestScores.pop();
00169
00170
00171 float scoreThreshold = bestScores.top();
00172
00173
00174 iter = m_hypos.begin();
00175 while (iter != m_hypos.end()) {
00176 Hypothesis *hypo = *iter;
00177 float score = hypo->GetTotalScore();
00178 if (score < scoreThreshold) {
00179 iterator iterRemove = iter++;
00180 Remove(iterRemove);
00181 m_manager.GetSentenceStats().AddPruning();
00182 } else {
00183 ++iter;
00184 }
00185 }
00186 VERBOSE(3,", pruned to size " << size() << endl);
00187
00188 IFVERBOSE(3) {
00189 TRACE_ERR("stack now contains: ");
00190 for(iter = m_hypos.begin(); iter != m_hypos.end(); iter++) {
00191 Hypothesis *hypo = *iter;
00192 TRACE_ERR( hypo->GetId() << " (" << hypo->GetTotalScore() << ") ");
00193 }
00194 TRACE_ERR( endl);
00195 }
00196
00197
00198 m_worstScore = scoreThreshold;
00199 }
00200 }
00201
00202 const Hypothesis *HypothesisStackCubePruning::GetBestHypothesis() const
00203 {
00204 if (!m_hypos.empty()) {
00205 const_iterator iter = m_hypos.begin();
00206 Hypothesis *bestHypo = *iter;
00207 while (++iter != m_hypos.end()) {
00208 Hypothesis *hypo = *iter;
00209 if (hypo->GetTotalScore() > bestHypo->GetTotalScore())
00210 bestHypo = hypo;
00211 }
00212 return bestHypo;
00213 }
00214 return NULL;
00215 }
00216
00217 vector<const Hypothesis*> HypothesisStackCubePruning::GetSortedList() const
00218 {
00219 vector<const Hypothesis*> ret;
00220 ret.reserve(m_hypos.size());
00221 std::copy(m_hypos.begin(), m_hypos.end(), std::inserter(ret, ret.end()));
00222 sort(ret.begin(), ret.end(), CompareHypothesisTotalScore());
00223
00224 return ret;
00225 }
00226
00227
00228 void HypothesisStackCubePruning::CleanupArcList()
00229 {
00230
00231 if (!m_nBestIsEnabled) return;
00232
00233 iterator iter;
00234 for (iter = m_hypos.begin() ; iter != m_hypos.end() ; ++iter) {
00235 Hypothesis *mainHypo = *iter;
00236 mainHypo->CleanupArcList();
00237 }
00238 }
00239
00240 void HypothesisStackCubePruning::SetBitmapAccessor(const WordsBitmap &newBitmap
00241 , HypothesisStackCubePruning &stack
00242 , const WordsRange &
00243 , BitmapContainer &bitmapContainer
00244 , const SquareMatrix &futureScore
00245 , const TranslationOptionList &transOptList)
00246 {
00247 _BMType::iterator bcExists = m_bitmapAccessor.find(newBitmap);
00248
00249 BitmapContainer *bmContainer;
00250 if (bcExists == m_bitmapAccessor.end()) {
00251 bmContainer = new BitmapContainer(newBitmap, stack);
00252 m_bitmapAccessor[newBitmap] = bmContainer;
00253 } else {
00254 bmContainer = bcExists->second;
00255 }
00256
00257 BackwardsEdge *edge = new BackwardsEdge(bitmapContainer
00258 , *bmContainer
00259 , transOptList
00260 , futureScore,
00261 m_manager.GetSource(),
00262 m_manager.GetTranslationSystem());
00263 bmContainer->AddBackwardsEdge(edge);
00264 }
00265
00266
00267 TO_STRING_BODY(HypothesisStackCubePruning);
00268
00269
00270
00271 std::ostream& operator<<(std::ostream& out, const HypothesisStackCubePruning& hypoColl)
00272 {
00273 HypothesisStackCubePruning::const_iterator iter;
00274
00275 for (iter = hypoColl.begin() ; iter != hypoColl.end() ; ++iter) {
00276 const Hypothesis &hypo = **iter;
00277 out << hypo << endl;
00278
00279 }
00280 return out;
00281 }
00282
00283 void
00284 HypothesisStackCubePruning::AddHypothesesToBitmapContainers()
00285 {
00286 HypothesisStackCubePruning::const_iterator iter;
00287 for (iter = m_hypos.begin() ; iter != m_hypos.end() ; ++iter) {
00288 Hypothesis *h = *iter;
00289 const WordsBitmap &bitmap = h->GetWordsBitmap();
00290 BitmapContainer *container = m_bitmapAccessor[bitmap];
00291 container->AddHypothesis(h);
00292 }
00293 }
00294
00295 }
00296