00001 #include "lm/wrappers/nplm.hh"
00002 #include "util/exception.hh"
00003 #include "util/file.hh"
00004
00005 #include <algorithm>
00006 #include <cstring>
00007
00008 #include "neuralLM.h"
00009
00010 namespace lm {
00011 namespace np {
00012
// Adapt an nplm vocabulary to KenLM's base::Vocabulary interface.
// The three sentinel indices (<s>, </s>, <unk>) are resolved once here and
// handed to the base class; the "<null>" index is cached for context padding.
Vocabulary::Vocabulary(const nplm::vocabulary &vocab)
  : base::Vocabulary(vocab.lookup_word("<s>"), vocab.lookup_word("</s>"), vocab.lookup_word("<unk>")),
    vocab_(vocab), null_word_(vocab.lookup_word("<null>")) {}
00016
00017 Vocabulary::~Vocabulary() {}
00018
00019 WordIndex Vocabulary::Index(const std::string &str) const {
00020 return vocab_.lookup_word(str);
00021 }
00022
00023 class Backend {
00024 public:
00025 Backend(const nplm::neuralLM &from, const std::size_t cache_size) : lm_(from), ngram_(from.get_order()) {
00026 lm_.set_cache(cache_size);
00027 }
00028
00029 nplm::neuralLM &LM() { return lm_; }
00030 const nplm::neuralLM &LM() const { return lm_; }
00031
00032 Eigen::Matrix<int,Eigen::Dynamic,1> &staging_ngram() { return ngram_; }
00033
00034 double lookup_from_staging() { return lm_.lookup_ngram(ngram_); }
00035
00036 int order() const { return lm_.get_order(); }
00037
00038 private:
00039 nplm::neuralLM lm_;
00040 Eigen::Matrix<int,Eigen::Dynamic,1> ngram_;
00041 };
00042
00043 bool Model::Recognize(const std::string &name) {
00044 try {
00045 util::scoped_fd file(util::OpenReadOrThrow(name.c_str()));
00046 char magic_check[16];
00047 util::ReadOrThrow(file.get(), magic_check, sizeof(magic_check));
00048 const char nnlm_magic[] = "\\config\nversion ";
00049 return !memcmp(magic_check, nnlm_magic, 16);
00050 } catch (const util::Exception &) {
00051 return false;
00052 }
00053 }
00054
00055 namespace {
00056 nplm::neuralLM *LoadNPLM(const std::string &file) {
00057 util::scoped_ptr<nplm::neuralLM> ret(new nplm::neuralLM());
00058 ret->read(file);
00059 return ret.release();
00060 }
00061 }
00062
00063 Model::Model(const std::string &file, std::size_t cache)
00064 : base_instance_(LoadNPLM(file)), vocab_(base_instance_->get_vocabulary()), cache_size_(cache) {
00065 UTIL_THROW_IF(base_instance_->get_order() > NPLM_MAX_ORDER, util::Exception, "This NPLM has order " << (unsigned int)base_instance_->get_order() << " but the KenLM wrapper was compiled with " << NPLM_MAX_ORDER << ". Change the defintion of NPLM_MAX_ORDER and recompile.");
00066
00067 base_instance_->set_log_base(10.0);
00068 State begin_sentence, null_context;
00069 std::fill(begin_sentence.words, begin_sentence.words + NPLM_MAX_ORDER - 1, base_instance_->lookup_word("<s>"));
00070 null_word_ = base_instance_->lookup_word("<null>");
00071 std::fill(null_context.words, null_context.words + NPLM_MAX_ORDER - 1, null_word_);
00072
00073 Init(begin_sentence, null_context, vocab_, base_instance_->get_order());
00074 }
00075
00076 Model::~Model() {}
00077
00078 FullScoreReturn Model::FullScore(const State &from, const WordIndex new_word, State &out_state) const {
00079 Backend *backend = backend_.get();
00080 if (!backend) {
00081 backend = new Backend(*base_instance_, cache_size_);
00082 backend_.reset(backend);
00083 }
00084
00085 FullScoreReturn ret;
00086 for (int i = 0; i < backend->order() - 1; ++i) {
00087 backend->staging_ngram()(i) = from.words[i];
00088 }
00089 backend->staging_ngram()(backend->order() - 1) = new_word;
00090 ret.prob = backend->lookup_from_staging();
00091
00092 ret.ngram_length = backend->order();
00093
00094 memcpy(out_state.words, from.words + 1, sizeof(WordIndex) * (backend->order() - 2));
00095 out_state.words[backend->order() - 2] = new_word;
00096
00097 memset(out_state.words + backend->order() - 1, 0, sizeof(WordIndex) * (NPLM_MAX_ORDER - backend->order()));
00098 return ret;
00099 }
00100
00101
00102 FullScoreReturn Model::FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const {
00103
00104 std::size_t state_length = std::min<std::size_t>(Order() - 1, context_rend - context_rbegin);
00105 State state;
00106
00107 for (lm::WordIndex *i = state.words; i < state.words + Order() - 1 - state_length; ++i) {
00108 *i = null_word_;
00109 }
00110
00111 std::reverse_copy(context_rbegin, context_rbegin + state_length, state.words + Order() - 1 - state_length);
00112 return FullScore(state, new_word, out_state);
00113 }
00114
00115 }
00116 }