00001 #include "SearchNormalBatch.h"
00002 #include "LM/Base.h"
00003 #include "Manager.h"
00004 #include "Hypothesis.h"
00005 #include "util/exception.hh"
00006
00007
00008
00009 using namespace std;
00010
00011 namespace Moses
00012 {
00013 SearchNormalBatch::SearchNormalBatch(Manager& manager, const InputType &source, const TranslationOptionCollection &transOptColl)
00014 :SearchNormal(manager, source, transOptColl)
00015 ,m_batch_size(10000)
00016 {
00017 m_max_stack_size = StaticData::Instance().GetMaxHypoStackSize();
00018
00019
00020
00021 const vector<const StatefulFeatureFunction*>& ffs =
00022 StatefulFeatureFunction::GetStatefulFeatureFunctions();
00023 for (unsigned i = 0; i < ffs.size(); ++i) {
00024 if (ffs[i]->GetScoreProducerDescription() == "DLM_5gram") {
00025 m_dlm_ffs[i] = const_cast<LanguageModel*>(static_cast<const LanguageModel* const>(ffs[i]));
00026 m_dlm_ffs[i]->SetFFStateIdx(i);
00027 } else {
00028 m_stateful_ffs[i] = const_cast<StatefulFeatureFunction*>(ffs[i]);
00029 }
00030 }
00031 m_stateless_ffs = StatelessFeatureFunction::GetStatelessFeatureFunctions();
00032
00033 }
00034
// Nothing to release explicitly: the FF maps hold non-owning pointers into
// the globally registered feature functions, and any hypotheses still in
// m_partial_hypos are owned by the hypothesis stacks / recombination logic.
SearchNormalBatch::~SearchNormalBatch()
{
}
00038
/**
 * Main batched beam-search decoding loop.
 *
 * Walks the hypothesis stacks in order of covered source words; for each
 * stack it prunes, expands every hypothesis (which queues partially-scored
 * hypotheses via ExpandHypothesis), then flushes the batch with
 * EvalAndMergePartialHypos before moving to the next stack.
 */
void SearchNormalBatch::Decode()
{
  const StaticData &staticData = StaticData::Instance();
  SentenceStats &stats = m_manager.GetSentenceStats();

  // Seed stack 0 with the empty hypothesis.
  Hypothesis *hypo = Hypothesis::Create(m_manager,m_source, m_initialTransOpt);
  m_hypoStackColl[0]->AddPrune(hypo);

  // Process each stack (stack i holds hypotheses covering i source words).
  std::vector < HypothesisStack* >::iterator iterStack;
  for (iterStack = m_hypoStackColl.begin() ; iterStack != m_hypoStackColl.end() ; ++iterStack) {
    // Abort decoding if the wall-clock budget is exhausted.
    // NOTE(review): this path returns without a final EvalAndMergePartialHypos(),
    // so any queued partial hypotheses are abandoned — confirm this is intended.
    double _elapsed_time = GetUserTime();
    if (_elapsed_time > staticData.GetTimeoutThreshold()) {
      VERBOSE(1,"Decoding is out of time (" << _elapsed_time << "," << staticData.GetTimeoutThreshold() << ")" << std::endl);
      interrupted_flag = 1;
      return;
    }
    HypothesisStackNormal &sourceHypoColl = *static_cast<HypothesisStackNormal*>(*iterStack);

    // Prune the stack to the beam size and tidy recombination arcs before
    // expanding, so we only expand surviving hypotheses.
    VERBOSE(3,"processing hypothesis from next stack");
    IFVERBOSE(2) {
      stats.StartTimeStack();
    }
    sourceHypoColl.PruneToSize(staticData.GetMaxHypoStackSize());
    VERBOSE(3,std::endl);
    sourceHypoColl.CleanupArcList();
    IFVERBOSE(2) {
      stats.StopTimeStack();
    }

    // Expand every hypothesis on this stack; expansions accumulate in
    // m_partial_hypos until the batch is flushed below.
    HypothesisStackNormal::const_iterator iterHypo;
    for (iterHypo = sourceHypoColl.begin() ; iterHypo != sourceHypoColl.end() ; ++iterHypo) {
      Hypothesis &hypothesis = **iterHypo;
      ProcessOneHypothesis(hypothesis);
    }
    // Score the queued batch and merge results into the target stacks.
    EvalAndMergePartialHypos();

    IFVERBOSE(2) {
      OutputHypoStackSize();
    }

    // Remember the last processed stack for best-hypothesis extraction.
    actual_hypoStack = &sourceHypoColl;
  }

  // Flush any hypotheses still pending after the final stack.
  EvalAndMergePartialHypos();
}
00095
00106 void
00107 SearchNormalBatch::
00108 ExpandHypothesis(const Hypothesis &hypothesis,
00109 const TranslationOption &transOpt, float expectedScore)
00110 {
00111
00112 if (m_partial_hypos.size() >= m_batch_size) {
00113 EvalAndMergePartialHypos();
00114 }
00115
00116 const StaticData &staticData = StaticData::Instance();
00117 SentenceStats &stats = m_manager.GetSentenceStats();
00118
00119 Hypothesis *newHypo;
00120 if (! staticData.UseEarlyDiscarding()) {
00121
00122 IFVERBOSE(2) {
00123 stats.StartTimeBuildHyp();
00124 }
00125 newHypo = hypothesis.CreateNext(transOpt);
00126 IFVERBOSE(2) {
00127 stats.StopTimeBuildHyp();
00128 }
00129 if (newHypo==NULL) return;
00130
00131
00132
00133
00134 std::map<int, LanguageModel*>::iterator dlm_iter;
00135 for (dlm_iter = m_dlm_ffs.begin();
00136 dlm_iter != m_dlm_ffs.end();
00137 ++dlm_iter) {
00138 const FFState* input_state = newHypo->GetPrevHypo() ? newHypo->GetPrevHypo()->GetFFState((*dlm_iter).first) : NULL;
00139 (*dlm_iter).second->IssueRequestsFor(*newHypo, input_state);
00140 }
00141 m_partial_hypos.push_back(newHypo);
00142 } else {
00143 UTIL_THROW2("can't use early discarding with batch decoding!");
00144 }
00145 }
00146
/**
 * Finish scoring all queued partial hypotheses and merge them into the
 * hypothesis stacks.
 *
 * Runs in three ordered phases:
 *   1. apply all synchronous stateful and stateless feature functions;
 *   2. sync() each distributed LM so the asynchronous score requests issued
 *      in ExpandHypothesis() have completed;
 *   3. fold the now-available DLM scores into each hypothesis and add it to
 *      the stack matching its source-word coverage.
 * Finally every stack is re-pruned, since the merge may have overfilled them.
 */
void SearchNormalBatch::EvalAndMergePartialHypos()
{
  // Phase 1: synchronous feature functions.
  std::vector<Hypothesis*>::iterator partial_hypo_iter;
  for (partial_hypo_iter = m_partial_hypos.begin();
       partial_hypo_iter != m_partial_hypos.end();
       ++partial_hypo_iter) {
    Hypothesis* hypo = *partial_hypo_iter;

    // Non-DLM stateful features, each evaluated against its own state slot.
    std::map<int, StatefulFeatureFunction*>::iterator sfff_iter;
    for (sfff_iter = m_stateful_ffs.begin();
         sfff_iter != m_stateful_ffs.end();
         ++sfff_iter) {
      const StatefulFeatureFunction &ff = *(sfff_iter->second);
      int state_idx = sfff_iter->first;
      hypo->EvaluateWhenApplied(ff, state_idx);
    }
    // Stateless features.
    std::vector<const StatelessFeatureFunction*>::iterator slff_iter;
    for (slff_iter = m_stateless_ffs.begin();
         slff_iter != m_stateless_ffs.end();
         ++slff_iter) {
      hypo->EvaluateWhenApplied(**slff_iter);
    }
  }

  // Phase 2: wait for the distributed LMs to answer all outstanding requests.
  std::map<int, LanguageModel*>::iterator dlm_iter;
  for (dlm_iter = m_dlm_ffs.begin();
       dlm_iter != m_dlm_ffs.end();
       ++dlm_iter) {
    (*dlm_iter).second->sync();
  }

  // Phase 3: apply the DLM scores and merge each hypothesis into the stack
  // for its coverage count.
  for (partial_hypo_iter = m_partial_hypos.begin();
       partial_hypo_iter != m_partial_hypos.end();
       ++partial_hypo_iter) {
    Hypothesis* hypo = *partial_hypo_iter;

    std::map<int, LanguageModel*>::iterator dlm_iter;
    for (dlm_iter = m_dlm_ffs.begin();
         dlm_iter != m_dlm_ffs.end();
         ++dlm_iter) {
      LanguageModel &lm = *(dlm_iter->second);
      hypo->EvaluateWhenApplied(lm, (*dlm_iter).first);
    }

    // Stack index == number of source words this hypothesis covers.
    size_t wordsTranslated = hypo->GetWordsBitmap().GetNumWordsCovered();
    m_hypoStackColl[wordsTranslated]->AddPrune(hypo);
  }
  m_partial_hypos.clear();

  // Re-prune every stack: AddPrune above may have pushed them past the beam.
  std::vector < HypothesisStack* >::iterator stack_iter;
  HypothesisStackNormal* stack;
  for (stack_iter = m_hypoStackColl.begin();
       stack_iter != m_hypoStackColl.end();
       ++stack_iter) {
    stack = static_cast<HypothesisStackNormal*>(*stack_iter);
    stack->PruneToSize(m_max_stack_size);
  }
}
00211
00212 }