#include "Manager.h"
#include "Util.h"
#include "SearchCubePruning.h"
#include "StaticData.h"
#include "InputType.h"
#include "TranslationOptionCollection.h"
#include <boost/foreach.hpp>
#include <functional>
using namespace std;
00009
00010 namespace Moses
00011 {
00012 class BitmapContainerOrderer
00013 {
00014 public:
00015 bool operator()(const BitmapContainer* A, const BitmapContainer* B) const {
00016 if (B->Empty()) {
00017 if (A->Empty()) {
00018 return A < B;
00019 }
00020 return false;
00021 }
00022 if (A->Empty()) {
00023 return true;
00024 }
00025
00026
00027 const float scoreA = A->Top()->GetHypothesis()->GetFutureScore();
00028 const float scoreB = B->Top()->GetHypothesis()->GetFutureScore();
00029
00030 if (scoreA < scoreB) {
00031 return true;
00032 } else if (scoreA > scoreB) {
00033 return false;
00034 } else {
00035
00036
00037
00038
00039
00040
00041 boost::shared_ptr<TargetPhrase> phrA = A->Top()->GetTargetPhrase();
00042 boost::shared_ptr<TargetPhrase> phrB = B->Top()->GetTargetPhrase();
00043 if (!phrA || !phrB) {
00044
00045 return A < B;
00046 }
00047 return (phrA->Compare(*phrB) > 0);
00048 }
00049 }
00050 };
00051
00052 SearchCubePruning::
00053 SearchCubePruning(Manager& manager, TranslationOptionCollection const& transOptColl)
00054 : Search(manager)
00055 , m_hypoStackColl(manager.GetSource().GetSize() + 1)
00056 , m_transOptColl(transOptColl)
00057 {
00058 std::vector < HypothesisStackCubePruning >::iterator iterStack;
00059 for (size_t ind = 0 ; ind < m_hypoStackColl.size() ; ++ind) {
00060 HypothesisStackCubePruning *sourceHypoColl = new HypothesisStackCubePruning(m_manager);
00061 sourceHypoColl->SetMaxHypoStackSize(m_options.search.stack_size);
00062 sourceHypoColl->SetBeamWidth(m_options.search.beam_width);
00063
00064 m_hypoStackColl[ind] = sourceHypoColl;
00065 }
00066 }
00067
00068 SearchCubePruning::~SearchCubePruning()
00069 {
00070 RemoveAllInColl(m_hypoStackColl);
00071 }
00072
00077 void SearchCubePruning::Decode()
00078 {
00079
00080 const Bitmap &initBitmap = m_bitmaps.GetInitialBitmap();
00081 Hypothesis *hypo = new Hypothesis(m_manager, m_source, m_initialTransOpt, initBitmap, m_manager.GetNextHypoId());
00082
00083 HypothesisStackCubePruning &firstStack
00084 = *static_cast<HypothesisStackCubePruning*>(m_hypoStackColl.front());
00085 firstStack.AddInitial(hypo);
00086
00087 firstStack.CleanupArcList();
00088 CreateForwardTodos(firstStack);
00089
00090 const size_t PopLimit = m_manager.options()->cube.pop_limit;
00091 VERBOSE(2,"Cube Pruning pop limit is " << PopLimit << std::endl);
00092
00093 const size_t Diversity = m_manager.options()->cube.diversity;
00094 VERBOSE(2,"Cube Pruning diversity is " << Diversity << std::endl);
00095 VERBOSE(2,"Max Phrase length is "
00096 << m_manager.options()->search.max_phrase_length << std::endl);
00097
00098
00099 size_t stackNo = 1;
00100 std::vector < HypothesisStack* >::iterator iterStack;
00101 for (iterStack = m_hypoStackColl.begin() + 1 ; iterStack != m_hypoStackColl.end() ; ++iterStack) {
00102
00103 if (this->out_of_time()) return;
00104
00105 HypothesisStackCubePruning &sourceHypoColl
00106 = *static_cast<HypothesisStackCubePruning*>(*iterStack);
00107
00108
00109
00110 std::priority_queue < BitmapContainer*, std::vector< BitmapContainer* >,
00111 BitmapContainerOrderer > BCQueue;
00112
00113 _BMType::const_iterator bmIter;
00114 const _BMType &accessor = sourceHypoColl.GetBitmapAccessor();
00115
00116 for(bmIter = accessor.begin(); bmIter != accessor.end(); ++bmIter) {
00117
00118 IFVERBOSE(2) {
00119 m_manager.GetSentenceStats().StartTimeOtherScore();
00120 }
00121 bmIter->second->InitializeEdges();
00122 IFVERBOSE(2) {
00123 m_manager.GetSentenceStats().StopTimeOtherScore();
00124 }
00125 m_manager.GetSentenceStats().StartTimeManageCubes();
00126 BCQueue.push(bmIter->second);
00127 m_manager.GetSentenceStats().StopTimeManageCubes();
00128
00129 }
00130
00131
00132 for (size_t numpops = 1; numpops <= PopLimit && !BCQueue.empty(); numpops++) {
00133
00134 m_manager.GetSentenceStats().StartTimeManageCubes();
00135 BitmapContainer *bc = BCQueue.top();
00136 BCQueue.pop();
00137 m_manager.GetSentenceStats().StopTimeManageCubes();
00138 IFVERBOSE(2) {
00139 m_manager.GetSentenceStats().AddPopped();
00140 }
00141
00142 IFVERBOSE(2) {
00143 m_manager.GetSentenceStats().StartTimeOtherScore();
00144 }
00145 bc->ProcessBestHypothesis();
00146 IFVERBOSE(2) {
00147 m_manager.GetSentenceStats().StopTimeOtherScore();
00148 }
00149
00150 m_manager.GetSentenceStats().StartTimeManageCubes();
00151 if (!bc->Empty())
00152 BCQueue.push(bc);
00153 m_manager.GetSentenceStats().StopTimeManageCubes();
00154 }
00155
00156
00157
00158 if (Diversity > 0) {
00159 IFVERBOSE(2) {
00160 m_manager.GetSentenceStats().StartTimeOtherScore();
00161 }
00162 for(bmIter = accessor.begin(); bmIter != accessor.end(); ++bmIter) {
00163 bmIter->second->EnsureMinStackHyps(Diversity);
00164 }
00165 IFVERBOSE(2) {
00166 m_manager.GetSentenceStats().StopTimeOtherScore();
00167 }
00168 }
00169
00170
00171 VERBOSE(3,"processing hypothesis from next stack");
00172 IFVERBOSE(2) {
00173 m_manager.GetSentenceStats().StartTimeStack();
00174 }
00175 sourceHypoColl.PruneToSize(m_options.search.stack_size);
00176 VERBOSE(3,std::endl);
00177 sourceHypoColl.CleanupArcList();
00178 IFVERBOSE(2) {
00179 m_manager.GetSentenceStats().StopTimeStack();
00180 }
00181
00182 IFVERBOSE(2) {
00183 m_manager.GetSentenceStats().StartTimeSetupCubes();
00184 }
00185 CreateForwardTodos(sourceHypoColl);
00186 IFVERBOSE(2) {
00187 m_manager.GetSentenceStats().StopTimeSetupCubes();
00188 }
00189
00190 stackNo++;
00191 }
00192 }
00193
00194 void SearchCubePruning::CreateForwardTodos(HypothesisStackCubePruning &stack)
00195 {
00196 const _BMType &bitmapAccessor = stack.GetBitmapAccessor();
00197 _BMType::const_iterator iterAccessor;
00198 size_t size = m_source.GetSize();
00199
00200 stack.AddHypothesesToBitmapContainers();
00201
00202 for (iterAccessor = bitmapAccessor.begin() ; iterAccessor != bitmapAccessor.end() ; ++iterAccessor) {
00203 const Bitmap &bitmap = *iterAccessor->first;
00204 BitmapContainer &bitmapContainer = *iterAccessor->second;
00205
00206 if (bitmapContainer.GetHypothesesSize() == 0) {
00207
00208 continue;
00209 }
00210
00211
00212 bitmapContainer.SortHypotheses();
00213
00214
00215 size_t startPos, endPos;
00216 for (startPos = 0 ; startPos < size ; startPos++) {
00217 if (bitmap.GetValue(startPos))
00218 continue;
00219
00220
00221 Range applyRange(startPos, startPos);
00222 if (CheckDistortion(bitmap, applyRange)) {
00223
00224 CreateForwardTodos(bitmap, applyRange, bitmapContainer);
00225 }
00226
00227 size_t maxSize = size - startPos;
00228 size_t maxSizePhrase = m_manager.options()->search.max_phrase_length;
00229 maxSize = std::min(maxSize, maxSizePhrase);
00230 for (endPos = startPos+1; endPos < startPos + maxSize; endPos++) {
00231 if (bitmap.GetValue(endPos))
00232 break;
00233
00234 Range applyRange(startPos, endPos);
00235 if (CheckDistortion(bitmap, applyRange)) {
00236
00237 CreateForwardTodos(bitmap, applyRange, bitmapContainer);
00238 }
00239 }
00240 }
00241 }
00242 }
00243
00244 void
00245 SearchCubePruning::
00246 CreateForwardTodos(Bitmap const& bitmap, Range const& range,
00247 BitmapContainer& bitmapContainer)
00248 {
00249 const Bitmap &newBitmap = m_bitmaps.GetBitmap(bitmap, range);
00250
00251 size_t numCovered = newBitmap.GetNumWordsCovered();
00252 const TranslationOptionList* transOptList;
00253 transOptList = m_transOptColl.GetTranslationOptionList(range);
00254 const SquareMatrix &estimatedScores = m_transOptColl.GetEstimatedScores();
00255
00256 if (transOptList && transOptList->size() > 0) {
00257 HypothesisStackCubePruning& newStack
00258 = *static_cast<HypothesisStackCubePruning*>(m_hypoStackColl[numCovered]);
00259 newStack.SetBitmapAccessor(newBitmap, newStack, range, bitmapContainer,
00260 estimatedScores, *transOptList);
00261 }
00262 }
00263
00264 bool
00265 SearchCubePruning::
00266 CheckDistortion(const Bitmap &hypoBitmap, const Range &range) const
00267 {
00268
00269 int maxDistortion = m_manager.options()->reordering.max_distortion;
00270 if (maxDistortion < 0) return true;
00271
00272
00273
00274 size_t const startPos = range.GetStartPos();
00275 size_t const endPos = range.GetEndPos();
00276
00277
00278
00279 if (!m_source.GetReorderingConstraint().Check(hypoBitmap, startPos, endPos))
00280 return false;
00281
00282 size_t const hypoFirstGapPos = hypoBitmap.GetFirstGapPos();
00283
00284 if (hypoFirstGapPos == startPos) return true;
00285
00286
00287
00288
00289
00290
00291
00292
00293
00294 Range bestNextExtension(hypoFirstGapPos, hypoFirstGapPos);
00295 return (m_source.ComputeDistortionDistance(range, bestNextExtension)
00296 <= maxDistortion);
00297 }
00298
00303 Hypothesis const*
00304 SearchCubePruning::
00305 GetBestHypothesis() const
00306 {
00307
00308 const HypothesisStack &hypoColl = *m_hypoStackColl.back();
00309 return hypoColl.GetBestHypothesis();
00310 }
00311
00315 void
00316 SearchCubePruning::
00317 OutputHypoStackSize()
00318 {
00319 std::vector < HypothesisStack* >::const_iterator iterStack = m_hypoStackColl.begin();
00320 TRACE_ERR( "Stack sizes: " << (int)(*iterStack)->size());
00321 for (++iterStack; iterStack != m_hypoStackColl.end() ; ++iterStack) {
00322 TRACE_ERR( ", " << (int)(*iterStack)->size());
00323 }
00324 TRACE_ERR( endl);
00325 }
00326
00327 void SearchCubePruning::PrintBitmapContainerGraph()
00328 {
00329 HypothesisStackCubePruning &lastStack = *static_cast<HypothesisStackCubePruning*>(m_hypoStackColl.back());
00330 const _BMType &bitmapAccessor = lastStack.GetBitmapAccessor();
00331
00332 _BMType::const_iterator iterAccessor;
00333 for (iterAccessor = bitmapAccessor.begin(); iterAccessor != bitmapAccessor.end(); ++iterAccessor) {
00334 cerr << iterAccessor->first << endl;
00335
00336 }
00337
00338 }
00339
00344 void SearchCubePruning::OutputHypoStack(int stack)
00345 {
00346 if (stack >= 0) {
00347 TRACE_ERR( "Stack " << stack << ": " << endl << m_hypoStackColl[stack] << endl);
00348 } else {
00349
00350 int i = 0;
00351 vector < HypothesisStack* >::iterator iterStack;
00352 for (iterStack = m_hypoStackColl.begin() ; iterStack != m_hypoStackColl.end() ; ++iterStack) {
00353 HypothesisStackCubePruning &hypoColl = *static_cast<HypothesisStackCubePruning*>(*iterStack);
00354 TRACE_ERR( "Stack " << i++ << ": " << endl << hypoColl << endl);
00355 }
00356 }
00357 }
00358
00359 const std::vector < HypothesisStack* >& SearchCubePruning::GetHypothesisStacks() const
00360 {
00361 return m_hypoStackColl;
00362 }
00363
00364 }
00365