00001 #include <cstdio>
00002 #include <cstdlib>
00003 #include <cstring>
00004 #include <unistd.h>
00005 #include <sys/types.h>
00006 #include "Remote.h"
00007 #include "moses/Factor.h"
00008 #include "util/string_stream.hh"
00009
00010 #if !defined(_WIN32) && !defined(_WIN64)
00011 #include <arpa/inet.h>
00012 #endif
00013
00014 namespace Moses
00015 {
00016
00017 const Factor* LanguageModelRemote::BOS = NULL;
00018 const Factor* LanguageModelRemote::EOS = (LanguageModelRemote::BOS + 1);
00019
00020 bool LanguageModelRemote::Load(const std::string &filePath
00021 , FactorType factorType
00022 , size_t nGramOrder)
00023 {
00024 m_factorType = factorType;
00025 m_nGramOrder = nGramOrder;
00026
00027 int cutAt = filePath.find(':',0);
00028 std::string host = filePath.substr(0,cutAt);
00029
00030 int port = atoi(filePath.substr(cutAt+1,filePath.size()-cutAt).c_str());
00031 bool good = start(host,port);
00032 if (!good) {
00033 std::cerr << "failed to connect to lm server on " << host << " on port " << port << std::endl;
00034 }
00035 ClearSentenceCache();
00036 return good;
00037 }
00038
00039
00040 bool LanguageModelRemote::start(const std::string& host, int port)
00041 {
00042
00043 sock = socket(AF_INET, SOCK_STREAM, 0);
00044 hp = gethostbyname(host.c_str());
00045 if (hp==NULL) {
00046 #if defined(_WIN32) || defined(_WIN64)
00047 fprintf(stderr, "gethostbyname failed\n");
00048 #else
00049 herror("gethostbyname failed");
00050 #endif
00051 exit(1);
00052 }
00053
00054 memset(&server, '\0', sizeof(server));
00055 memcpy((char *)&server.sin_addr, hp->h_addr, hp->h_length);
00056 server.sin_family = hp->h_addrtype;
00057 server.sin_port = htons(port);
00058
00059 int errors = 0;
00060 while (connect(sock, (struct sockaddr *)&server, sizeof(server)) < 0) {
00061
00062 sleep(1);
00063 errors++;
00064 if (errors > 5) return false;
00065 }
00066 return true;
00067 }
00068
00069 LMResult LanguageModelRemote::GetValue(const std::vector<const Word*> &contextFactor, State* finalState) const
00070 {
00071 LMResult ret;
00072 ret.unknown = false;
00073 size_t count = contextFactor.size();
00074 if (count == 0) {
00075 if (finalState) *finalState = NULL;
00076 ret.score = 0.0;
00077 return ret;
00078 }
00079
00080 size_t max = m_nGramOrder;
00081 const FactorType factor = GetFactorType();
00082 if (max > count) max = count;
00083
00084 Cache* cur = &m_cache;
00085 int pc = static_cast<int>(count) - 1;
00086 for (int i = 0; i < pc; ++i) {
00087 const Factor* f = contextFactor[i]->GetFactor(factor);
00088 cur = &cur->tree[f ? f : BOS];
00089 }
00090 const Factor* event_word = contextFactor[pc]->GetFactor(factor);
00091 cur = &cur->tree[event_word ? event_word : EOS];
00092 if (cur->prob) {
00093 if (finalState) *finalState = cur->boState;
00094 ret.score = cur->prob;
00095 return ret;
00096 }
00097 cur->boState = *reinterpret_cast<const State*>(&m_curId);
00098 ++m_curId;
00099
00100 util::StringStream os;
00101 os << "prob ";
00102 if (event_word == NULL) {
00103 os << "</s>";
00104 } else {
00105 os << event_word->GetString();
00106 }
00107 for (size_t i=1; i<max; i++) {
00108 const Factor* f = contextFactor[count-1-i]->GetFactor(factor);
00109 if (f == NULL) {
00110 os << " <s>";
00111 } else {
00112 os << ' ' << f->GetString();
00113 }
00114 }
00115 os << "\n";
00116 write(sock, os.str().c_str(), os.str().size());
00117 char res[6];
00118 int r = read(sock, res, 6);
00119 int errors = 0;
00120 int cnt = 0;
00121 while (1) {
00122 if (r < 0) {
00123 errors++;
00124 sleep(1);
00125
00126 if (errors > 5) exit(1);
00127 } else if (r==0 || res[cnt] == '\n') {
00128 break;
00129 } else {
00130 cnt += r;
00131 if (cnt==6) break;
00132 read(sock, &res[cnt], 6-cnt);
00133 }
00134 }
00135 cur->prob = FloorScore(TransformLMScore(*reinterpret_cast<float*>(res)));
00136 if (finalState) {
00137 *finalState = cur->boState;
00138 }
00139 ret.score = cur->prob;
00140 return ret;
00141 }
00142
00143 LanguageModelRemote::~LanguageModelRemote()
00144 {
00145
00146 close(sock);
00147 }
00148
00149 }