00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020 #include "ChartTranslationOptions.h"
00021 #include "ChartHypothesis.h"
00022 #include "ChartCellLabel.h"
00023 #include "ChartTranslationOption.h"
00024 #include "InputPath.h"
00025 #include "StaticData.h"
00026 #include "TranslationTask.h"
00027
00028 using namespace std;
00029
00030 namespace Moses
00031 {
00032
00033 ChartTranslationOptions::ChartTranslationOptions(const TargetPhraseCollection &targetPhraseColl,
00034 const StackVec &stackVec,
00035 const Range &range,
00036 float score)
00037 : m_stackVec(stackVec)
00038 , m_wordsRange(&range)
00039 , m_estimateOfBestScore(score)
00040 {
00041 TargetPhraseCollection::const_iterator iter;
00042 for (iter = targetPhraseColl.begin(); iter != targetPhraseColl.end(); ++iter) {
00043 const TargetPhrase *origTP = *iter;
00044
00045 boost::shared_ptr<ChartTranslationOption> ptr(new ChartTranslationOption(*origTP));
00046 m_collection.push_back(ptr);
00047 }
00048 }
00049
00050 ChartTranslationOptions::~ChartTranslationOptions()
00051 {
00052
00053 }
00054
00056 class ChartTranslationOptionScoreOrderer
00057 {
00058 public:
00059 bool operator()(const boost::shared_ptr<ChartTranslationOption> &transOptA
00060 , const boost::shared_ptr<ChartTranslationOption> &transOptB) const {
00061 const ScoreComponentCollection &scoresA = transOptA->GetScores();
00062 const ScoreComponentCollection &scoresB = transOptB->GetScores();
00063 return scoresA.GetWeightedScore() > scoresB.GetWeightedScore();
00064 }
00065 };
00066
00067 void ChartTranslationOptions::EvaluateWithSourceContext(const InputType &input, const InputPath &inputPath)
00068 {
00069 SetInputPath(&inputPath);
00070
00071 if (inputPath.ttask->options()->input.placeholder_factor != NOT_FOUND) {
00072 CreateSourceRuleFromInputPath();
00073 }
00074
00075 CollType::iterator iter;
00076 for (iter = m_collection.begin(); iter != m_collection.end(); ++iter) {
00077 ChartTranslationOption &transOpt = **iter;
00078 transOpt.SetInputPath(&inputPath);
00079 transOpt.EvaluateWithSourceContext(input, inputPath, m_stackVec);
00080 }
00081
00082
00083 size_t numDiscard = 0;
00084 for (size_t i = 0; i < m_collection.size(); ++i) {
00085 ChartTranslationOption *transOpt = m_collection[i].get();
00086
00087 if (transOpt->GetScores().GetWeightedScore() == - std::numeric_limits<float>::infinity()) {
00088 ++numDiscard;
00089 } else if (numDiscard) {
00090 m_collection[i - numDiscard] = m_collection[i];
00091 }
00092 }
00093
00094 size_t newSize = m_collection.size() - numDiscard;
00095 m_collection.resize(newSize);
00096
00097
00098 const StaticData &staticData = StaticData::Instance();
00099 if (staticData.RequireSortingAfterSourceContext()) {
00100 std::sort(m_collection.begin()
00101 , m_collection.begin() + newSize
00102 , ChartTranslationOptionScoreOrderer());
00103 }
00104
00105 }
00106
00107 void ChartTranslationOptions::SetInputPath(const InputPath *inputPath)
00108 {
00109 CollType::iterator iter;
00110 for (iter = m_collection.begin(); iter != m_collection.end(); ++iter) {
00111 ChartTranslationOption &transOpt = **iter;
00112 transOpt.SetInputPath(inputPath);
00113 }
00114 }
00115
00116 void ChartTranslationOptions::CreateSourceRuleFromInputPath()
00117 {
00118 if (m_collection.size() == 0) {
00119 return;
00120 }
00121
00122 const InputPath *inputPath = m_collection.front()->GetInputPath();
00123 assert(inputPath);
00124 std::vector<const Word*> &ruleSourceFromInputPath = inputPath->AddRuleSourceFromInputPath();
00125
00126 size_t chartCellIndex = 0;
00127 const ChartCellLabel *chartCellLabel = (chartCellIndex < m_stackVec.size()) ? m_stackVec[chartCellIndex] : NULL;
00128
00129 size_t ind = 0;
00130 for (size_t sourcePos = m_wordsRange->GetStartPos(); sourcePos <= m_wordsRange->GetEndPos(); ++sourcePos, ++ind) {
00131 if (chartCellLabel) {
00132 if (sourcePos == chartCellLabel->GetCoverage().GetEndPos()) {
00133
00134 ruleSourceFromInputPath.push_back(NULL);
00135 ++chartCellIndex;
00136 chartCellLabel = (chartCellIndex < m_stackVec.size()) ? m_stackVec[chartCellIndex] : NULL;
00137 } else if (sourcePos >= chartCellLabel->GetCoverage().GetStartPos()) {
00138
00139 } else {
00140
00141 ruleSourceFromInputPath.push_back(&inputPath->GetPhrase().GetWord(ind));
00142 }
00143 } else {
00144
00145 ruleSourceFromInputPath.push_back(&inputPath->GetPhrase().GetWord(ind));
00146 }
00147 }
00148
00149
00150 CollType::iterator iter;
00151 for (iter = m_collection.begin(); iter != m_collection.end(); ++iter) {
00152 ChartTranslationOption &transOpt = **iter;
00153 transOpt.SetSourceRuleFromInputPath(&ruleSourceFromInputPath);
00154 }
00155
00156 }
00157
00158 std::ostream& operator<<(std::ostream &out, const ChartTranslationOptions &obj)
00159 {
00160 for (size_t i = 0; i < obj.m_collection.size(); ++i) {
00161 const ChartTranslationOption &transOpt = *obj.m_collection[i];
00162 out << transOpt << endl;
00163 }
00164
00165 return out;
00166 }
00167
00168 }