#include "lm/model.hh"

#include <cstdlib>
#include <cstring>
#include <vector>

#define BOOST_TEST_MODULE ModelTest
#include <boost/test/unit_test.hpp>
#include <boost/test/floating_point_comparison.hpp>
00009
00010
// BOOST_CHECK_CLOSE with all three arguments cast to double, so callers can
// mix float and double operands freely.  The tolerance is forwarded unchanged
// to BOOST_CHECK_CLOSE.
// The macro deliberately does not end in a semicolon: the caller writes it, so
// an invocation behaves like one ordinary statement (e.g. inside if/else).
#define SLOPPY_CHECK_CLOSE(ref, value, tol) BOOST_CHECK_CLOSE(static_cast<double>(ref), static_cast<double>(value), static_cast<double>(tol))
00012
00013 namespace lm {
00014 namespace ngram {
00015
00016 std::ostream &operator<<(std::ostream &o, const State &state) {
00017 o << "State length " << static_cast<unsigned int>(state.length) << ':';
00018 for (const WordIndex *i = state.words; i < state.words + state.length; ++i) {
00019 o << ' ' << *i;
00020 }
00021 return o;
00022 }
00023
00024 namespace {
00025
00026
00027 const char *TestLocation() {
00028 if (boost::unit_test::framework::master_test_suite().argc < 3) {
00029 return "test.arpa";
00030 }
00031 char **argv = boost::unit_test::framework::master_test_suite().argv;
00032 return argv[strstr(argv[1], "nounk") ? 2 : 1];
00033 }
00034 const char *TestNoUnkLocation() {
00035 if (boost::unit_test::framework::master_test_suite().argc < 3) {
00036 return "test_nounk.arpa";
00037 }
00038 char **argv = boost::unit_test::framework::master_test_suite().argv;
00039 return argv[strstr(argv[1], "nounk") ? 1 : 2];
00040 }
00041
00042 template <class Model> State GetState(const Model &model, const char *word, const State &in) {
00043 WordIndex context[in.length + 1];
00044 context[0] = model.GetVocabulary().Index(word);
00045 std::copy(in.words, in.words + in.length, context + 1);
00046 State ret;
00047 model.GetState(context, context + in.length + 1, ret);
00048 return ret;
00049 }
00050
// StartTest(word, ngram, score, indep_left)
//
// Scores `word` from `state` with model.FullScore, writing the new state to
// `out` and the result to `ret`, then checks:
//   * ret.prob is close to `score` (tolerance 0.001 via SLOPPY_CHECK_CLOSE),
//   * ret.ngram_length equals `ngram`,
//   * out.length is at most min(ngram, 5 - 1)
//     (the 5 is presumably the order of the test model -- confirm),
//   * ret.independent_left equals `indep_left`,
//   * `out` equals the state GetState() derives for the same context.
// `state` itself is left untouched.
//
// Expects `model`, `state`, `out` and `ret` to exist in the calling scope.
// Wrapped in do { } while (0) so an invocation followed by ';' is exactly one
// statement.
#define StartTest(word, ngram, score, indep_left) \
  do { \
    ret = model.FullScore( \
        state, \
        model.GetVocabulary().Index(word), \
        out); \
    SLOPPY_CHECK_CLOSE(score, ret.prob, 0.001); \
    BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \
    BOOST_CHECK_GE(std::min<unsigned char>(ngram, 5 - 1), out.length); \
    BOOST_CHECK_EQUAL(indep_left, ret.independent_left); \
    BOOST_CHECK_EQUAL(out, GetState(model, word, state)); \
  } while (0)

// AppendTest(word, ngram, score, indep_left)
//
// Same checks as StartTest, then advances `state` to `out` so that the next
// call scores the following word of the same sentence.
#define AppendTest(word, ngram, score, indep_left) \
  do { \
    StartTest(word, ngram, score, indep_left); \
    state = out; \
  } while (0)
00065
00066 template <class M> void Starters(const M &model) {
00067 FullScoreReturn ret;
00068 Model::State state(model.BeginSentenceState());
00069 Model::State out;
00070
00071 StartTest("looking", 2, -0.4846522, true);
00072
00073
00074 StartTest(",", 1, -1.383514 + -0.4149733, true);
00075
00076 StartTest("this_is_not_found", 1, -1.995635 + -0.4149733, true);
00077 }
00078
00079 template <class M> void Continuation(const M &model) {
00080 FullScoreReturn ret;
00081 Model::State state(model.BeginSentenceState());
00082 Model::State out;
00083
00084 AppendTest("looking", 2, -0.484652, true);
00085 AppendTest("on", 3, -0.348837, true);
00086 AppendTest("a", 4, -0.0155266, true);
00087 AppendTest("little", 5, -0.00306122, true);
00088 State preserve = state;
00089 AppendTest("the", 1, -4.04005, true);
00090 AppendTest("biarritz", 1, -1.9889, true);
00091 AppendTest("not_found", 1, -2.29666, true);
00092 AppendTest("more", 1, -1.20632 - 20.0, true);
00093 AppendTest(".", 2, -0.51363, true);
00094 AppendTest("</s>", 3, -0.0191651, true);
00095 BOOST_CHECK_EQUAL(0, state.length);
00096
00097 state = preserve;
00098 AppendTest("more", 5, -0.00181395, true);
00099 BOOST_CHECK_EQUAL(4, state.length);
00100 AppendTest("loin", 5, -0.0432557, true);
00101 BOOST_CHECK_EQUAL(1, state.length);
00102 }
00103
00104 template <class M> void Blanks(const M &model) {
00105 FullScoreReturn ret;
00106 State state(model.NullContextState());
00107 State out;
00108 AppendTest("also", 1, -1.687872, false);
00109 AppendTest("would", 2, -2, true);
00110 AppendTest("consider", 3, -3, true);
00111 State preserve = state;
00112 AppendTest("higher", 4, -4, true);
00113 AppendTest("looking", 5, -5, true);
00114 BOOST_CHECK_EQUAL(1, state.length);
00115
00116 state = preserve;
00117
00118 AppendTest("not_found", 1, -1.995635 - 7.0 - 0.30103, true);
00119
00120 state = model.NullContextState();
00121
00122 AppendTest("higher", 1, -1.509559, false);
00123 AppendTest("looking", 2, -1.285941 - 0.30103, false);
00124
00125 State higher_looking = state;
00126
00127 BOOST_CHECK_EQUAL(1, state.length);
00128 AppendTest("not_found", 1, -1.995635 - 0.4771212, true);
00129
00130 state = higher_looking;
00131
00132 AppendTest("consider", 1, -1.687872 - 0.4771212, true);
00133
00134 state = model.NullContextState();
00135 AppendTest("would", 1, -1.687872, false);
00136 BOOST_CHECK_EQUAL(1, state.length);
00137 AppendTest("consider", 2, -1.687872 -0.30103, false);
00138 BOOST_CHECK_EQUAL(2, state.length);
00139 AppendTest("higher", 3, -1.509559 - 0.30103, false);
00140 BOOST_CHECK_EQUAL(3, state.length);
00141 AppendTest("looking", 4, -1.285941 - 0.30103, false);
00142 }
00143
00144 template <class M> void Unknowns(const M &model) {
00145 FullScoreReturn ret;
00146 State state(model.NullContextState());
00147 State out;
00148
00149 AppendTest("not_found", 1, -1.995635, false);
00150 State preserve = state;
00151 AppendTest("not_found2", 2, -15.0, true);
00152 AppendTest("not_found3", 2, -15.0 - 2.0, true);
00153
00154 state = preserve;
00155 AppendTest("however", 2, -4, true);
00156 AppendTest("not_found3", 3, -6, true);
00157 }
00158
00159 template <class M> void MinimalState(const M &model) {
00160 FullScoreReturn ret;
00161 State state(model.NullContextState());
00162 State out;
00163
00164 AppendTest("baz", 1, -6.535897, true);
00165 BOOST_CHECK_EQUAL(0, state.length);
00166 state = model.NullContextState();
00167 AppendTest("foo", 1, -3.141592, true);
00168 BOOST_CHECK_EQUAL(1, state.length);
00169 AppendTest("bar", 2, -6.0, true);
00170
00171 BOOST_CHECK_EQUAL(1, state.length);
00172 AppendTest("bar", 1, -2.718281 + 3.0, true);
00173 BOOST_CHECK_EQUAL(1, state.length);
00174
00175 state = model.NullContextState();
00176 AppendTest("to", 1, -1.687872, false);
00177 AppendTest("look", 2, -0.2922095, true);
00178 BOOST_CHECK_EQUAL(2, state.length);
00179 AppendTest("a", 3, -7, true);
00180 }
00181
00182 template <class M> void ExtendLeftTest(const M &model) {
00183 State right;
00184 FullScoreReturn little(model.FullScore(model.NullContextState(), model.GetVocabulary().Index("little"), right));
00185 const float kLittleProb = -1.285941;
00186 SLOPPY_CHECK_CLOSE(kLittleProb, little.prob, 0.001);
00187 unsigned char next_use;
00188 float backoff_out[4];
00189
00190 FullScoreReturn extend_none(model.ExtendLeft(NULL, NULL, NULL, little.extend_left, 1, NULL, next_use));
00191 BOOST_CHECK_EQUAL(0, next_use);
00192 BOOST_CHECK_EQUAL(little.extend_left, extend_none.extend_left);
00193 SLOPPY_CHECK_CLOSE(little.prob - little.rest, extend_none.prob, 0.001);
00194 BOOST_CHECK_EQUAL(1, extend_none.ngram_length);
00195
00196 const WordIndex a = model.GetVocabulary().Index("a");
00197 float backoff_in = 3.14;
00198
00199 FullScoreReturn extend_a(model.ExtendLeft(&a, &a + 1, &backoff_in, little.extend_left, 1, backoff_out, next_use));
00200 BOOST_CHECK_EQUAL(1, next_use);
00201 SLOPPY_CHECK_CLOSE(-0.69897, backoff_out[0], 0.001);
00202 SLOPPY_CHECK_CLOSE(-0.09132547 - little.rest, extend_a.prob, 0.001);
00203 BOOST_CHECK_EQUAL(2, extend_a.ngram_length);
00204 BOOST_CHECK(!extend_a.independent_left);
00205
00206 const WordIndex on = model.GetVocabulary().Index("on");
00207 FullScoreReturn extend_on(model.ExtendLeft(&on, &on + 1, &backoff_in, extend_a.extend_left, 2, backoff_out, next_use));
00208 BOOST_CHECK_EQUAL(1, next_use);
00209 SLOPPY_CHECK_CLOSE(-0.4771212, backoff_out[0], 0.001);
00210 SLOPPY_CHECK_CLOSE(-0.0283603 - (extend_a.rest + little.rest), extend_on.prob, 0.001);
00211 BOOST_CHECK_EQUAL(3, extend_on.ngram_length);
00212 BOOST_CHECK(!extend_on.independent_left);
00213
00214 const WordIndex both[2] = {a, on};
00215 float backoff_in_arr[4];
00216 FullScoreReturn extend_both(model.ExtendLeft(both, both + 2, backoff_in_arr, little.extend_left, 1, backoff_out, next_use));
00217 BOOST_CHECK_EQUAL(2, next_use);
00218 SLOPPY_CHECK_CLOSE(-0.69897, backoff_out[0], 0.001);
00219 SLOPPY_CHECK_CLOSE(-0.4771212, backoff_out[1], 0.001);
00220 SLOPPY_CHECK_CLOSE(-0.0283603 - little.rest, extend_both.prob, 0.001);
00221 BOOST_CHECK_EQUAL(3, extend_both.ngram_length);
00222 BOOST_CHECK(!extend_both.independent_left);
00223 BOOST_CHECK_EQUAL(extend_on.extend_left, extend_both.extend_left);
00224 }
00225
// StatelessTest(word, provide, ngram, score)
//
// For use inside Stateless(), where `indices` holds the sentence in reverse
// order.  indices[num_words - word - 1] is the word being scored (position
// `word` in the sentence) and the `provide` entries starting at
// indices[num_words - word] are the words before it, nearest first.
//
// Checks that model.FullScoreForgotState on that raw context yields `score`
// and n-gram length `ngram`, then rebuilds a state from the same context with
// model.GetState, scores through model.FullScore, and checks that the
// resulting state, probability and n-gram length all agree.
//
// Expects `model`, `indices`, `num_words`, `state`, `before`, `out` and `ret`
// in the calling scope.  Wrapped in do { } while (0) so an invocation followed
// by ';' is exactly one statement.
#define StatelessTest(word, provide, ngram, score) \
  do { \
    ret = model.FullScoreForgotState(indices + num_words - word, indices + num_words - word + provide, indices[num_words - word - 1], state); \
    SLOPPY_CHECK_CLOSE(score, ret.prob, 0.001); \
    BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \
    model.GetState(indices + num_words - word, indices + num_words - word + provide, before); \
    ret = model.FullScore(before, indices[num_words - word - 1], out); \
    BOOST_CHECK(state == out); \
    SLOPPY_CHECK_CLOSE(score, ret.prob, 0.001); \
    BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \
  } while (0)
00235
00236 template <class M> void Stateless(const M &model) {
00237 const char *words[] = {"<s>", "looking", "on", "a", "little", "the", "biarritz", "not_found", "more", ".", "</s>"};
00238 const size_t num_words = sizeof(words) / sizeof(const char*);
00239
00240 WordIndex indices[num_words + 1];
00241 for (unsigned int i = 0; i < num_words; ++i) {
00242 indices[num_words - 1 - i] = model.GetVocabulary().Index(words[i]);
00243 }
00244 FullScoreReturn ret;
00245 State state, out, before;
00246
00247 ret = model.FullScoreForgotState(indices + num_words - 1, indices + num_words, indices[num_words - 2], state);
00248 SLOPPY_CHECK_CLOSE(-0.484652, ret.prob, 0.001);
00249 StatelessTest(1, 1, 2, -0.484652);
00250
00251
00252 StatelessTest(1, 2, 2, -0.484652);
00253
00254 AppendTest("on", 3, -0.348837, true);
00255 StatelessTest(2, 3, 3, -0.348837);
00256 StatelessTest(2, 2, 3, -0.348837);
00257 StatelessTest(2, 1, 2, -0.4638903);
00258
00259 StatelessTest(3, 4, 4, -0.0155266);
00260
00261 AppendTest("little", 5, -0.00306122, true);
00262 StatelessTest(4, 5, 5, -0.00306122);
00263
00264 AppendTest("the", 1, -4.04005, true);
00265 StatelessTest(5, 5, 1, -4.04005);
00266
00267 StatelessTest(5, 0, 1, -1.687872);
00268
00269 StatelessTest(6, 1, 1, -1.9889);
00270
00271 StatelessTest(7, 1, 1, -2.29666);
00272 StatelessTest(7, 0, 1, -1.995635);
00273
00274 WordIndex unk[1];
00275 unk[0] = 0;
00276 model.GetState(unk, unk + 1, state);
00277 BOOST_CHECK_EQUAL(1, state.length);
00278 BOOST_CHECK_EQUAL(static_cast<WordIndex>(0), state.words[0]);
00279 }
00280
00281 template <class M> void NoUnkCheck(const M &model) {
00282 WordIndex unk_index = 0;
00283 State state;
00284
00285 FullScoreReturn ret = model.FullScoreForgotState(&unk_index, &unk_index + 1, unk_index, state);
00286 SLOPPY_CHECK_CLOSE(-100.0, ret.prob, 0.001);
00287 }
00288
// Runs every scoring test in this file, except NoUnkCheck, against one
// loaded model.
template <class M> void Everything(const M &m) {
  // Scoring through FullScore from begin-of-sentence / null-context states.
  Starters(m);
  Continuation(m);
  Blanks(m);
  Unknowns(m);
  MinimalState(m);

  // ExtendLeft and context-array scoring.
  ExtendLeftTest(m);
  Stateless(m);
}
00298
00299 class ExpectEnumerateVocab : public EnumerateVocab {
00300 public:
00301 ExpectEnumerateVocab() {}
00302
00303 void Add(WordIndex index, const StringPiece &str) {
00304 BOOST_CHECK_EQUAL(seen.size(), index);
00305 seen.push_back(std::string(str.data(), str.length()));
00306 }
00307
00308 void Check(const base::Vocabulary &vocab) {
00309 BOOST_CHECK_EQUAL(37ULL, seen.size());
00310 BOOST_REQUIRE(!seen.empty());
00311 BOOST_CHECK_EQUAL("<unk>", seen[0]);
00312 for (WordIndex i = 0; i < seen.size(); ++i) {
00313 BOOST_CHECK_EQUAL(i, vocab.Index(seen[i]));
00314 }
00315 }
00316
00317 void Clear() {
00318 seen.clear();
00319 }
00320
00321 std::vector<std::string> seen;
00322 };
00323
00324 template <class ModelT> void LoadingTest() {
00325 Config config;
00326 config.arpa_complain = Config::NONE;
00327 config.messages = NULL;
00328 config.probing_multiplier = 2.0;
00329 {
00330 ExpectEnumerateVocab enumerate;
00331 config.enumerate_vocab = &enumerate;
00332 ModelT m(TestLocation(), config);
00333 enumerate.Check(m.GetVocabulary());
00334 BOOST_CHECK_EQUAL((WordIndex)37, m.GetVocabulary().Bound());
00335 Everything(m);
00336 }
00337 {
00338 ExpectEnumerateVocab enumerate;
00339 config.enumerate_vocab = &enumerate;
00340 ModelT m(TestNoUnkLocation(), config);
00341 enumerate.Check(m.GetVocabulary());
00342 BOOST_CHECK_EQUAL((WordIndex)37, m.GetVocabulary().Bound());
00343 NoUnkCheck(m);
00344 }
00345 }
00346
00347 BOOST_AUTO_TEST_CASE(probing) {
00348 LoadingTest<Model>();
00349 }
00350 BOOST_AUTO_TEST_CASE(trie) {
00351 LoadingTest<TrieModel>();
00352 }
00353 BOOST_AUTO_TEST_CASE(quant_trie) {
00354 LoadingTest<QuantTrieModel>();
00355 }
00356 BOOST_AUTO_TEST_CASE(bhiksha_trie) {
00357 LoadingTest<ArrayTrieModel>();
00358 }
00359 BOOST_AUTO_TEST_CASE(quant_bhiksha_trie) {
00360 LoadingTest<QuantArrayTrieModel>();
00361 }
00362
00363 template <class ModelT> void BinaryTest(Config::WriteMethod write_method) {
00364 Config config;
00365 config.write_mmap = "test.binary";
00366 config.messages = NULL;
00367 config.write_method = write_method;
00368 ExpectEnumerateVocab enumerate;
00369 config.enumerate_vocab = &enumerate;
00370
00371 {
00372 ModelT copy_model(TestLocation(), config);
00373 enumerate.Check(copy_model.GetVocabulary());
00374 enumerate.Clear();
00375 Everything(copy_model);
00376 }
00377
00378 config.write_mmap = NULL;
00379
00380 ModelType type;
00381 BOOST_REQUIRE(RecognizeBinary("test.binary", type));
00382 BOOST_CHECK_EQUAL(ModelT::kModelType, type);
00383
00384 {
00385 ModelT binary("test.binary", config);
00386 enumerate.Check(binary.GetVocabulary());
00387 Everything(binary);
00388 }
00389 unlink("test.binary");
00390
00391
00392 config.write_mmap = "test_nounk.binary";
00393 config.messages = NULL;
00394 enumerate.Clear();
00395 {
00396 ModelT copy_model(TestNoUnkLocation(), config);
00397 enumerate.Check(copy_model.GetVocabulary());
00398 enumerate.Clear();
00399 NoUnkCheck(copy_model);
00400 }
00401 config.write_mmap = NULL;
00402 {
00403 ModelT binary(TestNoUnkLocation(), config);
00404 enumerate.Check(binary.GetVocabulary());
00405 NoUnkCheck(binary);
00406 }
00407 unlink("test_nounk.binary");
00408 }
00409
00410 template <class ModelT> void BinaryTest() {
00411 BinaryTest<ModelT>(Config::WRITE_MMAP);
00412 BinaryTest<ModelT>(Config::WRITE_AFTER);
00413 }
00414
00415 BOOST_AUTO_TEST_CASE(write_and_read_probing) {
00416 BinaryTest<ProbingModel>();
00417 }
00418 BOOST_AUTO_TEST_CASE(write_and_read_rest_probing) {
00419 BinaryTest<RestProbingModel>();
00420 }
00421 BOOST_AUTO_TEST_CASE(write_and_read_trie) {
00422 BinaryTest<TrieModel>();
00423 }
00424 BOOST_AUTO_TEST_CASE(write_and_read_quant_trie) {
00425 BinaryTest<QuantTrieModel>();
00426 }
00427 BOOST_AUTO_TEST_CASE(write_and_read_array_trie) {
00428 BinaryTest<ArrayTrieModel>();
00429 }
00430 BOOST_AUTO_TEST_CASE(write_and_read_quant_array_trie) {
00431 BinaryTest<QuantArrayTrieModel>();
00432 }
00433
00434 BOOST_AUTO_TEST_CASE(rest_max) {
00435 Config config;
00436 config.arpa_complain = Config::NONE;
00437 config.messages = NULL;
00438
00439 RestProbingModel model(TestLocation(), config);
00440 State state, out;
00441 FullScoreReturn ret(model.FullScore(model.NullContextState(), model.GetVocabulary().Index("."), state));
00442 SLOPPY_CHECK_CLOSE(-0.2705918, ret.rest, 0.001);
00443 SLOPPY_CHECK_CLOSE(-0.01916512, model.FullScore(state, model.GetVocabulary().EndSentence(), out).rest, 0.001);
00444 }
00445
00446 }
00447 }
00448 }