#include <sys/stat.h>

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstring>
#include <fstream>
#include <iostream>
#include <queue>
#include <string>
#include <vector>

#include <boost/foreach.hpp>

#include "line_splitter.hh"
#include "storing.hh"
#include "StoreTarget.h"
#include "StoreVocab.h"
#include "moses/Util.h"
#include "moses/InputFileStream.h"
00009
00010 using namespace std;
00011
00012 namespace Moses
00013 {
00014
00016 void Node::Add(Table &table, const SourcePhrase &sourcePhrase, size_t pos)
00017 {
00018 if (pos < sourcePhrase.size()) {
00019 uint64_t vocabId = sourcePhrase[pos];
00020
00021 Node *child;
00022 Children::iterator iter = m_children.find(vocabId);
00023 if (iter == m_children.end()) {
00024
00025 BOOST_FOREACH(Children::value_type &valPair, m_children) {
00026 Node &otherChild = valPair.second;
00027 otherChild.Write(table);
00028 }
00029 m_children.clear();
00030
00031
00032 child = &m_children[vocabId];
00033 assert(!child->done);
00034 child->key = key + (vocabId << pos);
00035 } else {
00036 child = &iter->second;
00037 }
00038
00039 child->Add(table, sourcePhrase, pos + 1);
00040 } else {
00041
00042 done = true;
00043 }
00044 }
00045
00046 void Node::Write(Table &table)
00047 {
00048
00049 BOOST_FOREACH(Children::value_type &valPair, m_children) {
00050 Node &child = valPair.second;
00051 child.Write(table);
00052 }
00053
00054 if (!done) {
00055
00056 Entry sourceEntry;
00057 sourceEntry.value = NONE;
00058 sourceEntry.key = key;
00059
00060
00061 table.Insert(sourceEntry);
00062 }
00063 }
00064
00066 void createProbingPT(const std::string &phrasetable_path,
00067 const std::string &basepath, int num_scores, int num_lex_scores,
00068 bool log_prob, int max_cache_size, bool scfg)
00069 {
00070 std::cerr << "Starting..." << std::endl;
00071
00072
00073 mkdir(basepath.c_str(), S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH);
00074
00075 StoreTarget storeTarget(basepath);
00076
00077
00078 unsigned long uniq_entries = countUniqueSource(phrasetable_path);
00079
00080
00081 StoreVocab<uint64_t> sourceVocab(basepath + "/source_vocabids");
00082
00083
00084 util::FilePiece filein(phrasetable_path.c_str());
00085
00086
00087 size_t size = Table::Size(uniq_entries, 1.2);
00088 char * mem = new char[size];
00089 memset(mem, 0, size);
00090 Table sourceEntries(mem, size);
00091
00092 std::priority_queue<CacheItem*, std::vector<CacheItem*>, CacheItemOrderer> cache;
00093 float totalSourceCount = 0;
00094
00095
00096 size_t line_num = 0;
00097
00098
00099 std::string prevSource;
00100
00101 Node sourcePhrases;
00102 sourcePhrases.done = true;
00103 sourcePhrases.key = 0;
00104
00105 while (true) {
00106 try {
00107
00108 line_text line;
00109 line = splitLine(filein.ReadLine(), scfg);
00110
00111
00112 ++line_num;
00113 if (line_num % 1000000 == 0) {
00114 std::cerr << line_num << " " << std::flush;
00115 }
00116
00117
00118 add_to_map(sourceVocab, line.source_phrase);
00119
00120 if (prevSource.empty()) {
00121
00122 prevSource = line.source_phrase.as_string();
00123 storeTarget.Append(line, log_prob, scfg);
00124 } else if (prevSource == line.source_phrase) {
00125
00126 storeTarget.Append(line, log_prob, scfg);
00127 } else {
00128 assert(prevSource != line.source_phrase);
00129
00130
00131
00132
00133 uint64_t targetInd = storeTarget.Save();
00134
00135
00136 storeTarget.Append(line, log_prob, scfg);
00137
00138
00139 Entry sourceEntry;
00140 sourceEntry.value = targetInd;
00141
00142
00143 std::vector<uint64_t> vocabid_source = getVocabIDs(prevSource);
00144 if (scfg) {
00145
00146 sourcePhrases.Add(sourceEntries, vocabid_source);
00147 }
00148 sourceEntry.key = getKey(vocabid_source);
00149
00150
00151
00152
00153
00154
00155
00156 sourceEntries.Insert(sourceEntry);
00157
00158
00159 if (max_cache_size) {
00160 std::string countStr = line.counts.as_string();
00161 countStr = Trim(countStr);
00162 if (!countStr.empty()) {
00163 std::vector<float> toks = Tokenize<float>(countStr);
00164
00165
00166 if (toks.size() >= 2) {
00167 totalSourceCount += toks[1];
00168
00169
00170 std::vector<uint64_t> currVocabidSource = getVocabIDs(line.source_phrase.as_string());
00171 uint64_t currKey = getKey(currVocabidSource);
00172
00173 CacheItem *item = new CacheItem(
00174 Trim(line.source_phrase.as_string()),
00175 currKey,
00176 toks[1]);
00177 cache.push(item);
00178
00179 if (max_cache_size > 0 && cache.size() > max_cache_size) {
00180 cache.pop();
00181 }
00182 }
00183 }
00184 }
00185
00186
00187 prevSource = line.source_phrase.as_string();
00188 }
00189
00190 } catch (util::EndOfFileException e) {
00191 std::cerr
00192 << "Reading phrase table finished, writing remaining files to disk."
00193 << std::endl;
00194
00195
00196
00197 uint64_t targetInd = storeTarget.Save();
00198
00199 Entry sourceEntry;
00200 sourceEntry.value = targetInd;
00201
00202
00203 std::vector<uint64_t> vocabid_source = getVocabIDs(prevSource);
00204 sourceEntry.key = getKey(vocabid_source);
00205
00206
00207 sourceEntries.Insert(sourceEntry);
00208
00209 break;
00210 }
00211 }
00212
00213 sourcePhrases.Write(sourceEntries);
00214
00215 storeTarget.SaveAlignment();
00216
00217 serialize_table(mem, size, (basepath + "/probing_hash.dat"));
00218
00219 sourceVocab.Save();
00220
00221 serialize_cache(cache, (basepath + "/cache"), totalSourceCount);
00222
00223 delete[] mem;
00224
00225
00226 std::ofstream configfile;
00227 configfile.open((basepath + "/config").c_str());
00228 configfile << "API_VERSION\t" << API_VERSION << '\n';
00229 configfile << "uniq_entries\t" << uniq_entries << '\n';
00230 configfile << "num_scores\t" << num_scores << '\n';
00231 configfile << "num_lex_scores\t" << num_lex_scores << '\n';
00232 configfile << "log_prob\t" << log_prob << '\n';
00233 configfile.close();
00234 }
00235
00236 size_t countUniqueSource(const std::string &path)
00237 {
00238 size_t ret = 0;
00239 InputFileStream strme(path);
00240
00241 std::string line, prevSource;
00242 while (std::getline(strme, line)) {
00243 std::vector<std::string> toks = TokenizeMultiCharSeparator(line, "|||");
00244 assert(toks.size() != 0);
00245
00246 if (prevSource != toks[0]) {
00247 prevSource = toks[0];
00248 ++ret;
00249 }
00250 }
00251
00252 return ret;
00253 }
00254
00255 void serialize_cache(
00256 std::priority_queue<CacheItem*, std::vector<CacheItem*>, CacheItemOrderer> &cache,
00257 const std::string &path, float totalSourceCount)
00258 {
00259 std::vector<const CacheItem*> vec(cache.size());
00260
00261 size_t ind = cache.size() - 1;
00262 while (!cache.empty()) {
00263 const CacheItem *item = cache.top();
00264 vec[ind] = item;
00265 cache.pop();
00266 --ind;
00267 }
00268
00269 std::ofstream os(path.c_str());
00270
00271 os << totalSourceCount << std::endl;
00272 for (size_t i = 0; i < vec.size(); ++i) {
00273 const CacheItem *item = vec[i];
00274 os << item->count << "\t" << item->sourceKey << "\t" << item->source << std::endl;
00275 delete item;
00276 }
00277
00278 os.close();
00279 }
00280
00281 uint64_t getKey(const std::vector<uint64_t> &vocabid_source)
00282 {
00283 return getKey(vocabid_source.data(), vocabid_source.size());
00284 }
00285
00286 std::vector<uint64_t> CreatePrefix(const std::vector<uint64_t> &vocabid_source, size_t endPos)
00287 {
00288 assert(endPos < vocabid_source.size());
00289
00290 std::vector<uint64_t> ret(endPos + 1);
00291 for (size_t i = 0; i <= endPos; ++i) {
00292 ret[i] = vocabid_source[i];
00293 }
00294 return ret;
00295 }
00296
00297 }
00298