00001 #pragma once
00002
00003 #include <string>
00004 #include "moses/FF/StatefulFeatureFunction.h"
00005 #include "moses/FF/FFState.h"
00006 #include <boost/thread/tss.hpp>
00007 #include "moses/Hypothesis.h"
00008 #include "moses/ChartHypothesis.h"
00009 #include "moses/InputPath.h"
00010 #include "moses/Manager.h"
00011 #include "moses/ChartManager.h"
00012 #include "moses/FactorCollection.h"
00013
00014 namespace Moses
00015 {
00016
00017 class BilingualLMState : public FFState
00018 {
00019 size_t m_hash;
00020 std::vector<int> word_alignments;
00021 std::vector<int> neuralLM_ids;
00022 public:
00023 BilingualLMState(size_t hash)
00024 :m_hash(hash) {
00025 }
00026 BilingualLMState(size_t hash, std::vector<int>& word_alignments_vec, std::vector<int>& neural_ids)
00027 :m_hash(hash)
00028 , word_alignments(word_alignments_vec)
00029 , neuralLM_ids(neural_ids) {
00030 }
00031
00032 const std::vector<int>& GetWordAlignmentVector() const {
00033 return word_alignments;
00034 }
00035
00036 const std::vector<int>& GetWordIdsVector() const {
00037 return neuralLM_ids;
00038 }
00039
00040 virtual size_t hash() const {
00041 return m_hash;
00042 }
00043 virtual bool operator==(const FFState& other) const {
00044 const BilingualLMState &otherState = static_cast<const BilingualLMState&>(other);
00045 return m_hash == otherState.m_hash;
00046 }
00047
00048 };
00049
00050 class BilingualLM : public StatefulFeatureFunction
00051 {
00052 private:
00053 virtual float Score(std::vector<int>& source_words, std::vector<int>& target_words) const = 0;
00054
00055 virtual int getNeuralLMId(const Word& word, bool is_source_word) const = 0;
00056
00057 virtual void loadModel() = 0;
00058
00059 virtual const Word& getNullWord() const = 0;
00060
00061 size_t selectMiddleAlignment(const std::set<size_t>& alignment_links) const;
00062
00063 void getSourceWords(
00064 const TargetPhrase &targetPhrase,
00065 int targetWordIdx,
00066 const Sentence &source_sent,
00067 const Range &sourceWordRange,
00068 std::vector<int> &words) const;
00069
00070 void appendSourceWordsToVector(const Sentence &source_sent, std::vector<int> &words, int source_word_mid_idx) const;
00071
00072 void getTargetWords(
00073 const Hypothesis &cur_hypo,
00074 const TargetPhrase &targetPhrase,
00075 int current_word_index,
00076 std::vector<int> &words) const;
00077
00078 size_t getState(const Hypothesis &cur_hypo) const;
00079
00080 void requestPrevTargetNgrams(const Hypothesis &cur_hypo, int amount, std::vector<int> &words) const;
00081
00082
00083 void getTargetWordsChart(
00084 std::vector<int>& neuralLMids,
00085 int current_word_index,
00086 std::vector<int>& words,
00087 bool sentence_begin) const;
00088
00089 size_t getStateChart(std::vector<int>& neuralLMids) const;
00090
00091
00092 void getAllTargetIdsChart(const ChartHypothesis& cur_hypo, size_t featureID, std::vector<int>& wordIds) const;
00093
00094 void getAllAlignments(const ChartHypothesis& cur_hypo, size_t featureID, std::vector<int>& alignemnts) const;
00095
00096 protected:
00097
00098 std::string m_filePath;
00099 int target_ngrams;
00100 int source_ngrams;
00101
00102
00103 FactorType word_factortype;
00104 FactorType pos_factortype;
00105 const Factor* BOS_factor;
00106 const Factor* EOS_factor;
00107 mutable Word BOS_word;
00108 mutable Word EOS_word;
00109
00110 public:
00111 BilingualLM(const std::string &line);
00112
00113 bool IsUseable(const FactorMask &mask) const {
00114 return true;
00115 }
00116 virtual const FFState* EmptyHypothesisState(const InputType &input) const {
00117 return new BilingualLMState(0);
00118 }
00119
00120 void Load(AllOptions::ptr const& opts);
00121
00122 FFState* EvaluateWhenApplied(
00123 const Hypothesis& cur_hypo,
00124 const FFState* prev_state,
00125 ScoreComponentCollection* accumulator) const;
00126
00127 FFState* EvaluateWhenApplied(
00128 const ChartHypothesis& cur_hypo ,
00129 int featureID,
00130 ScoreComponentCollection* accumulator) const;
00131
00132 void SetParameter(const std::string& key, const std::string& value);
00133 };
00134
00135 }
00136