00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022 #include <cassert>
00023 #include <limits>
00024 #include <iostream>
00025 #include <fstream>
00026
00027 #include "LanguageModelSRI.h"
00028 #include "TypeDef.h"
00029 #include "Util.h"
00030 #include "FactorCollection.h"
00031 #include "Phrase.h"
00032 #include "StaticData.h"
00033
00034 using namespace std;
00035
00036 namespace Moses
00037 {
00038 LanguageModelSRI::LanguageModelSRI()
00039 : m_srilmVocab(0)
00040 , m_srilmModel(0)
00041 {
00042 }
00043
00044 LanguageModelSRI::~LanguageModelSRI()
00045 {
00046 delete m_srilmModel;
00047 delete m_srilmVocab;
00048 }
00049
00050 bool LanguageModelSRI::Load(const std::string &filePath
00051 , FactorType factorType
00052 , size_t nGramOrder)
00053 {
00054 m_srilmVocab = new ::Vocab();
00055 m_srilmModel = new Ngram(*m_srilmVocab, nGramOrder);
00056 m_factorType = factorType;
00057 m_nGramOrder = nGramOrder;
00058 m_filePath = filePath;
00059
00060 m_srilmModel->skipOOVs() = false;
00061
00062 File file( filePath.c_str(), "r" );
00063 m_srilmModel->read(file);
00064
00065
00066 CreateFactors();
00067 m_unknownId = m_srilmVocab->unkIndex();
00068
00069 return true;
00070 }
00071
00072 void LanguageModelSRI::CreateFactors()
00073 {
00074
00075 FactorCollection &factorCollection = FactorCollection::Instance();
00076
00077 std::map<size_t, VocabIndex> lmIdMap;
00078 size_t maxFactorId = 0;
00079
00080 VocabString str;
00081 VocabIter iter(*m_srilmVocab);
00082 while ( (str = iter.next()) != NULL) {
00083 VocabIndex lmId = GetLmID(str);
00084 size_t factorId = factorCollection.AddFactor(Output, m_factorType, str)->GetId();
00085 lmIdMap[factorId] = lmId;
00086 maxFactorId = (factorId > maxFactorId) ? factorId : maxFactorId;
00087 }
00088
00089 size_t factorId;
00090
00091 m_sentenceStart = factorCollection.AddFactor(Output, m_factorType, BOS_);
00092 factorId = m_sentenceStart->GetId();
00093 lmIdMap[factorId] = GetLmID(BOS_);
00094 maxFactorId = (factorId > maxFactorId) ? factorId : maxFactorId;
00095 m_sentenceStartArray[m_factorType] = m_sentenceStart;
00096
00097 m_sentenceEnd = factorCollection.AddFactor(Output, m_factorType, EOS_);
00098 factorId = m_sentenceEnd->GetId();
00099 lmIdMap[factorId] = GetLmID(EOS_);
00100 maxFactorId = (factorId > maxFactorId) ? factorId : maxFactorId;
00101 m_sentenceEndArray[m_factorType] = m_sentenceEnd;
00102
00103
00104 m_lmIdLookup.resize(maxFactorId+1);
00105
00106 fill(m_lmIdLookup.begin(), m_lmIdLookup.end(), m_unknownId);
00107
00108 map<size_t, VocabIndex>::iterator iterMap;
00109 for (iterMap = lmIdMap.begin() ; iterMap != lmIdMap.end() ; ++iterMap) {
00110 m_lmIdLookup[iterMap->first] = iterMap->second;
00111 }
00112 }
00113
00114 VocabIndex LanguageModelSRI::GetLmID( const std::string &str ) const
00115 {
00116 return m_srilmVocab->getIndex( str.c_str(), m_unknownId );
00117 }
00118 VocabIndex LanguageModelSRI::GetLmID( const Factor *factor ) const
00119 {
00120 size_t factorId = factor->GetId();
00121 return ( factorId >= m_lmIdLookup.size()) ? m_unknownId : m_lmIdLookup[factorId];
00122 }
00123
00124 LMResult LanguageModelSRI::GetValue(VocabIndex wordId, VocabIndex *context) const
00125 {
00126 LMResult ret;
00127 ret.score = FloorScore(TransformLMScore(m_srilmModel->wordProb( wordId, context)));
00128 ret.unknown = (wordId == m_unknownId);
00129 return ret;
00130 }
00131
00132 LMResult LanguageModelSRI::GetValue(const vector<const Word*> &contextFactor, State* finalState) const
00133 {
00134 LMResult ret;
00135 FactorType factorType = GetFactorType();
00136 size_t count = contextFactor.size();
00137 if (count <= 0) {
00138 if(finalState)
00139 *finalState = NULL;
00140 ret.score = 0.0;
00141 ret.unknown = false;
00142 return ret;
00143 }
00144
00145
00146 VocabIndex ngram[count + 1];
00147 for (size_t i = 0 ; i < count - 1 ; i++) {
00148 ngram[i+1] = GetLmID((*contextFactor[count-2-i])[factorType]);
00149 }
00150 ngram[count] = Vocab_None;
00151
00152 assert((*contextFactor[count-1])[factorType] != NULL);
00153
00154 VocabIndex lmId = GetLmID((*contextFactor[count-1])[factorType]);
00155 ret = GetValue(lmId, ngram+1);
00156
00157 if (finalState) {
00158 ngram[0] = lmId;
00159 unsigned int dummy;
00160 *finalState = m_srilmModel->contextID(ngram, dummy);
00161 }
00162 return ret;
00163 }
00164
00165 }
00166
00167
00168