00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021 #include "util/check.hh"
00022 #include "ChartTranslationOptionCollection.h"
00023 #include "ChartCellCollection.h"
00024 #include "InputType.h"
00025 #include "StaticData.h"
00026 #include "DecodeStep.h"
00027 #include "DummyScoreProducers.h"
00028 #include "Util.h"
00029
00030 using namespace std;
00031
00032 namespace Moses
00033 {
00034
00035 ChartTranslationOptionCollection::ChartTranslationOptionCollection(InputType const& source
00036 , const TranslationSystem* system
00037 , const ChartCellCollection &hypoStackColl
00038 , const std::vector<ChartRuleLookupManager*> &ruleLookupManagers)
00039 :m_source(source)
00040 ,m_system(system)
00041 ,m_decodeGraphList(system->GetDecodeGraphs())
00042 ,m_hypoStackColl(hypoStackColl)
00043 ,m_ruleLookupManagers(ruleLookupManagers)
00044 ,m_translationOptionList(StaticData::Instance().GetRuleLimit())
00045 {
00046 }
00047
00048 ChartTranslationOptionCollection::~ChartTranslationOptionCollection()
00049 {
00050 RemoveAllInColl(m_unksrcs);
00051 RemoveAllInColl(m_cacheTargetPhraseCollection);
00052 }
00053
00054 void ChartTranslationOptionCollection::CreateTranslationOptionsForRange(
00055 const WordsRange &wordsRange)
00056 {
00057 assert(m_decodeGraphList.size() == m_ruleLookupManagers.size());
00058
00059 m_translationOptionList.Clear();
00060
00061 std::vector <DecodeGraph*>::const_iterator iterDecodeGraph;
00062 std::vector <ChartRuleLookupManager*>::const_iterator iterRuleLookupManagers = m_ruleLookupManagers.begin();
00063 for (iterDecodeGraph = m_decodeGraphList.begin(); iterDecodeGraph != m_decodeGraphList.end(); ++iterDecodeGraph, ++iterRuleLookupManagers) {
00064 const DecodeGraph &decodeGraph = **iterDecodeGraph;
00065 assert(decodeGraph.GetSize() == 1);
00066 ChartRuleLookupManager &ruleLookupManager = **iterRuleLookupManagers;
00067 size_t maxSpan = decodeGraph.GetMaxChartSpan();
00068 if (maxSpan == 0 || wordsRange.GetNumWordsCovered() <= maxSpan) {
00069 ruleLookupManager.GetChartRuleCollection(wordsRange, m_translationOptionList);
00070 }
00071 }
00072
00073 if (wordsRange.GetNumWordsCovered() == 1 && wordsRange.GetStartPos() != 0 && wordsRange.GetStartPos() != m_source.GetSize()-1) {
00074 bool alwaysCreateDirectTranslationOption = StaticData::Instance().IsAlwaysCreateDirectTranslationOption();
00075 if (m_translationOptionList.GetSize() == 0 || alwaysCreateDirectTranslationOption) {
00076
00077 const Word &sourceWord = m_source.GetWord(wordsRange.GetStartPos());
00078 ProcessOneUnknownWord(sourceWord, wordsRange);
00079 }
00080 }
00081
00082 m_translationOptionList.ApplyThreshold();
00083 }
00084
00086 void ChartTranslationOptionCollection::ProcessOneUnknownWord(const Word &sourceWord, const WordsRange &range)
00087 {
00088
00089 const StaticData &staticData = StaticData::Instance();
00090 const UnknownWordPenaltyProducer *unknownWordPenaltyProducer = m_system->GetUnknownWordPenaltyProducer();
00091 vector<float> wordPenaltyScore(1, -0.434294482);
00092
00093 const ChartCell &chartCell = m_hypoStackColl.Get(range);
00094 const ChartCellLabel &sourceWordLabel = chartCell.GetSourceWordLabel();
00095
00096 size_t isDigit = 0;
00097 if (staticData.GetDropUnknown()) {
00098 const Factor *f = sourceWord[0];
00099 const string &s = f->GetString();
00100 isDigit = s.find_first_of("0123456789");
00101 if (isDigit == string::npos)
00102 isDigit = 0;
00103 else
00104 isDigit = 1;
00105
00106 }
00107
00108 Phrase* m_unksrc = new Phrase(1);
00109 m_unksrc->AddWord() = sourceWord;
00110 m_unksrcs.push_back(m_unksrc);
00111
00112
00113 if (! staticData.GetDropUnknown() || isDigit) {
00114
00115 const UnknownLHSList &lhsList = staticData.GetUnknownLHS();
00116 UnknownLHSList::const_iterator iterLHS;
00117 for (iterLHS = lhsList.begin(); iterLHS != lhsList.end(); ++iterLHS) {
00118 const string &targetLHSStr = iterLHS->first;
00119 float prob = iterLHS->second;
00120
00121
00122
00123 Word targetLHS(true);
00124
00125 targetLHS.CreateFromString(Output, staticData.GetOutputFactorOrder(), targetLHSStr, true);
00126 CHECK(targetLHS.GetFactor(0) != NULL);
00127
00128
00129 TargetPhrase *targetPhrase = new TargetPhrase(Output);
00130 TargetPhraseCollection *tpc = new TargetPhraseCollection;
00131 tpc->Add(targetPhrase);
00132
00133 m_cacheTargetPhraseCollection.push_back(tpc);
00134 Word &targetWord = targetPhrase->AddWord();
00135 targetWord.CreateUnknownWord(sourceWord);
00136
00137
00138 vector<float> unknownScore(1, FloorScore(TransformScore(prob)));
00139
00140
00141 targetPhrase->SetScore(unknownWordPenaltyProducer, unknownScore);
00142 targetPhrase->SetScore(m_system->GetWordPenaltyProducer(), wordPenaltyScore);
00143 targetPhrase->SetSourcePhrase(m_unksrc);
00144 targetPhrase->SetTargetLHS(targetLHS);
00145
00146
00147 m_translationOptionList.Add(*tpc, m_emptyStackVec, range);
00148 }
00149 } else {
00150
00151 vector<float> unknownScore(1, FloorScore(-numeric_limits<float>::infinity()));
00152
00153 TargetPhrase *targetPhrase = new TargetPhrase(Output);
00154 TargetPhraseCollection *tpc = new TargetPhraseCollection;
00155 tpc->Add(targetPhrase);
00156
00157 const UnknownLHSList &lhsList = staticData.GetUnknownLHS();
00158 UnknownLHSList::const_iterator iterLHS;
00159 for (iterLHS = lhsList.begin(); iterLHS != lhsList.end(); ++iterLHS) {
00160 const string &targetLHSStr = iterLHS->first;
00161
00162
00163 Word targetLHS(true);
00164 targetLHS.CreateFromString(Output, staticData.GetOutputFactorOrder(), targetLHSStr, true);
00165 CHECK(targetLHS.GetFactor(0) != NULL);
00166
00167 m_cacheTargetPhraseCollection.push_back(tpc);
00168 targetPhrase->SetSourcePhrase(m_unksrc);
00169 targetPhrase->SetScore(unknownWordPenaltyProducer, unknownScore);
00170 targetPhrase->SetTargetLHS(targetLHS);
00171
00172
00173 m_translationOptionList.Add(*tpc, m_emptyStackVec, range);
00174 }
00175 }
00176 }
00177
00178 }