00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022 #include <algorithm>
00023 #include <set>
00024 #include <queue>
00025 #include "HypothesisStackNormal.h"
00026 #include "TypeDef.h"
00027 #include "Util.h"
00028 #include "StaticData.h"
00029 #include "Manager.h"
00030
00031 using namespace std;
00032
00033 namespace Moses
00034 {
00035 HypothesisStackNormal::HypothesisStackNormal(Manager& manager) :
00036 HypothesisStack(manager)
00037 {
00038 m_nBestIsEnabled = StaticData::Instance().IsNBestEnabled();
00039 m_bestScore = -std::numeric_limits<float>::infinity();
00040 m_worstScore = -std::numeric_limits<float>::infinity();
00041 }
00042
00044 void HypothesisStackNormal::RemoveAll()
00045 {
00046 while (m_hypos.begin() != m_hypos.end()) {
00047 Remove(m_hypos.begin());
00048 }
00049 }
00050
00051 pair<HypothesisStackNormal::iterator, bool> HypothesisStackNormal::Add(Hypothesis *hypo)
00052 {
00053 std::pair<iterator, bool> ret = m_hypos.insert(hypo);
00054 if (ret.second) {
00055
00056 VERBOSE(3,"added hyp to stack");
00057
00058
00059 if (hypo->GetTotalScore() > m_bestScore) {
00060 VERBOSE(3,", best on stack");
00061 m_bestScore = hypo->GetTotalScore();
00062
00063 if ( m_bestScore + m_beamWidth > m_worstScore )
00064 m_worstScore = m_bestScore + m_beamWidth;
00065 }
00066
00067 if ( m_minHypoStackDiversity == 1 &&
00068 hypo->GetTotalScore() > GetWorstScoreForBitmap( hypo->GetWordsBitmap() ) ) {
00069 SetWorstScoreForBitmap( hypo->GetWordsBitmap().GetID(), hypo->GetTotalScore() );
00070 }
00071
00072 VERBOSE(3,", now size " << m_hypos.size());
00073
00074
00075 size_t toleratedSize = 2*m_maxHypoStackSize-1;
00076
00077 if (m_minHypoStackDiversity)
00078 toleratedSize += m_minHypoStackDiversity << StaticData::Instance().GetMaxDistortion();
00079 if (m_hypos.size() > toleratedSize) {
00080 PruneToSize(m_maxHypoStackSize);
00081 } else {
00082 VERBOSE(3,std::endl);
00083 }
00084 }
00085
00086 return ret;
00087 }
00088
00089 bool HypothesisStackNormal::AddPrune(Hypothesis *hypo)
00090 {
00091
00092 if (!StaticData::Instance().GetDisableDiscarding() &&
00093 hypo->GetTotalScore() < m_worstScore
00094 && ! ( m_minHypoStackDiversity > 0
00095 && hypo->GetTotalScore() >= GetWorstScoreForBitmap( hypo->GetWordsBitmap() ) ) ) {
00096 m_manager.GetSentenceStats().AddDiscarded();
00097 VERBOSE(3,"discarded, too bad for stack" << std::endl);
00098 FREEHYPO(hypo);
00099 return false;
00100 }
00101
00102
00103 std::pair<iterator, bool> addRet = Add(hypo);
00104 if (addRet.second) {
00105
00106 return true;
00107 }
00108
00109
00110 iterator &iterExisting = addRet.first;
00111 Hypothesis *hypoExisting = *iterExisting;
00112 CHECK(iterExisting != m_hypos.end());
00113
00114 m_manager.GetSentenceStats().AddRecombination(*hypo, **iterExisting);
00115
00116
00117
00118 if (hypo->GetTotalScore() > hypoExisting->GetTotalScore()) {
00119
00120 VERBOSE(3,"better than matching hyp " << hypoExisting->GetId() << ", recombining, ");
00121 if (m_nBestIsEnabled) {
00122 hypo->AddArc(hypoExisting);
00123 Detach(iterExisting);
00124 } else {
00125 Remove(iterExisting);
00126 }
00127
00128 bool added = Add(hypo).second;
00129 if (!added) {
00130 iterExisting = m_hypos.find(hypo);
00131 TRACE_ERR("Offending hypo = " << **iterExisting << endl);
00132 abort();
00133 }
00134 return false;
00135 } else {
00136
00137 VERBOSE(3,"worse than matching hyp " << hypoExisting->GetId() << ", recombining" << std::endl)
00138 if (m_nBestIsEnabled) {
00139 hypoExisting->AddArc(hypo);
00140 } else {
00141 FREEHYPO(hypo);
00142 }
00143 return false;
00144 }
00145 }
00146
00147 void HypothesisStackNormal::PruneToSize(size_t newSize)
00148 {
00149 if ( size() <= newSize ) return;
00150
00151
00152 vector< Hypothesis* > hypos = GetSortedListNOTCONST();
00153 bool* included = (bool*) malloc(sizeof(bool) * hypos.size());
00154 for(size_t i=0; i<hypos.size(); i++) included[i] = false;
00155
00156
00157 for( iterator iter = m_hypos.begin(); iter != m_hypos.end(); ) {
00158 iterator removeHyp = iter++;
00159 Detach(removeHyp);
00160 }
00161
00162
00163 if ( m_minHypoStackDiversity > 0 ) {
00164 map< WordsBitmapID, size_t > diversityCount;
00165 for(size_t i=0; i<hypos.size(); i++) {
00166 Hypothesis *hyp = hypos[i];
00167 WordsBitmapID coverage = hyp->GetWordsBitmap().GetID();;
00168 if (diversityCount.find( coverage ) == diversityCount.end())
00169 diversityCount[ coverage ] = 0;
00170
00171 if (diversityCount[ coverage ] < m_minHypoStackDiversity) {
00172 m_hypos.insert( hyp );
00173 included[i] = true;
00174 diversityCount[ coverage ]++;
00175 if (diversityCount[ coverage ] == m_minHypoStackDiversity)
00176 SetWorstScoreForBitmap( coverage, hyp->GetTotalScore());
00177 }
00178 }
00179 }
00180
00181
00182 if ( size() < newSize ) {
00183
00184
00185 for(size_t i=0; i<hypos.size()
00186 && size() < newSize
00187 && hypos[i]->GetTotalScore() > m_bestScore+m_beamWidth; i++) {
00188 if (! included[i]) {
00189 m_hypos.insert( hypos[i] );
00190 included[i] = true;
00191 if (size() == newSize)
00192 m_worstScore = hypos[i]->GetTotalScore();
00193 }
00194 }
00195 }
00196
00197
00198 for(size_t i=0; i<hypos.size(); i++) {
00199 if (! included[i]) {
00200 FREEHYPO( hypos[i] );
00201 m_manager.GetSentenceStats().AddPruning();
00202 }
00203 }
00204 free(included);
00205
00206
00207 VERBOSE(3,", pruned to size " << size() << endl);
00208 IFVERBOSE(3) {
00209 TRACE_ERR("stack now contains: ");
00210 for(iterator iter = m_hypos.begin(); iter != m_hypos.end(); iter++) {
00211 Hypothesis *hypo = *iter;
00212 TRACE_ERR( hypo->GetId() << " (" << hypo->GetTotalScore() << ") ");
00213 }
00214 TRACE_ERR( endl);
00215 }
00216 }
00217
00218 const Hypothesis *HypothesisStackNormal::GetBestHypothesis() const
00219 {
00220 if (!m_hypos.empty()) {
00221 const_iterator iter = m_hypos.begin();
00222 Hypothesis *bestHypo = *iter;
00223 while (++iter != m_hypos.end()) {
00224 Hypothesis *hypo = *iter;
00225 if (hypo->GetTotalScore() > bestHypo->GetTotalScore())
00226 bestHypo = hypo;
00227 }
00228 return bestHypo;
00229 }
00230 return NULL;
00231 }
00232
00233 vector<const Hypothesis*> HypothesisStackNormal::GetSortedList() const
00234 {
00235 vector<const Hypothesis*> ret;
00236 ret.reserve(m_hypos.size());
00237 std::copy(m_hypos.begin(), m_hypos.end(), std::inserter(ret, ret.end()));
00238 sort(ret.begin(), ret.end(), CompareHypothesisTotalScore());
00239
00240 return ret;
00241 }
00242
00243 vector<Hypothesis*> HypothesisStackNormal::GetSortedListNOTCONST()
00244 {
00245 vector<Hypothesis*> ret;
00246 ret.reserve(m_hypos.size());
00247 std::copy(m_hypos.begin(), m_hypos.end(), std::inserter(ret, ret.end()));
00248 sort(ret.begin(), ret.end(), CompareHypothesisTotalScore());
00249
00250 return ret;
00251 }
00252
00253 void HypothesisStackNormal::CleanupArcList()
00254 {
00255
00256 if (!m_nBestIsEnabled) return;
00257
00258 iterator iter;
00259 for (iter = m_hypos.begin() ; iter != m_hypos.end() ; ++iter) {
00260 Hypothesis *mainHypo = *iter;
00261 mainHypo->CleanupArcList();
00262 }
00263 }
00264
// Expands the project-defined TO_STRING_BODY macro for HypothesisStackNormal.
// NOTE(review): the macro's definition is not visible in this file; presumably
// it generates the class's ToString() in terms of the operator<< defined
// below — confirm against the header that declares the macro.
00265 TO_STRING_BODY(HypothesisStackNormal);
00266
00267
00268
00269 std::ostream& operator<<(std::ostream& out, const HypothesisStackNormal& hypoColl)
00270 {
00271 HypothesisStackNormal::const_iterator iter;
00272
00273 for (iter = hypoColl.begin() ; iter != hypoColl.end() ; ++iter) {
00274 const Hypothesis &hypo = **iter;
00275 out << hypo << endl;
00276
00277 }
00278 return out;
00279 }
00280
00281
00282 }
00283