
/scripts/training/mbr/mbr.cpp

https://github.com/kowey/moses
#include <iostream>
#include <fstream>
#include <sstream>
#include <iomanip>
#include <vector>
#include <map>
#include <stdlib.h>
#include <math.h>
#include <time.h> // needed for time() in main()
#include <algorithm>
#include <stdio.h>
#include <unistd.h>
#include <cstring>

using namespace std;
/* Input:
   1. a sorted n-best list, with duplicates filtered out, in the following format:
      0 ||| amr moussa is currently on a visit to libya , tomorrow , sunday , to hold talks with regard to the in sudan . ||| 0 -4.94418 0 0 -2.16036 0 0 -81.4462 -106.593 -114.43 -105.55 -12.7873 -26.9057 -25.3715 -52.9336 7.99917 -24 ||| -4.58432
   2. a weight vector
   3. bleu order (default = 4)
   4. scaling factor to weigh the weight vector (default = 1.0)
   Output:
   translations that minimise the Bayes risk of the n-best list
*/
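/* Example invocation (a sketch; flags as parsed by getopt() in main() below,
   with -s and -n optional since they default to 1.0 and 4):
     mbr -s 1.0 -n 4 -w weights.txt -i nbest.txt > best_translations.txt
*/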
int TABLE_LINE_MAX_LENGTH = 5000;
vector<double> weights;
float SCALE = 1.0;
int BLEU_ORDER = 4;
int SMOOTH = 1;
int DEBUG = 0;
double min_interval = 1e-4;

#define SAFE_GETLINE(_IS, _LINE, _SIZE, _DELIM) {_IS.getline(_LINE, _SIZE, _DELIM); if(_IS.fail() && !_IS.bad() && !_IS.eof()) _IS.clear();}

typedef string WORD;
typedef unsigned int WORD_ID;

map<WORD, WORD_ID> lookup;
vector<WORD> vocab;
class candidate_t
{
public:
  vector<WORD_ID> translation;
  vector<double> features;
  int translation_size;
};
void usage(void)
{
  fprintf(stderr,
          "usage: mbr -s SCALE -n BLEU_ORDER -w weights.txt -i nbest.txt\n");
}
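// Like strsep(), but splits on a multi-character delimiter: returns the text
// before the first occurrence of delim and advances *stringp past it, or
// returns the remainder (setting *stringp to NULL) when delim is not found.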
char *strstrsep(char **stringp, const char *delim)
{
  char *match, *save;
  save = *stringp;
  if (*stringp == NULL)
    return NULL;
  match = strstr(*stringp, delim);
  if (match == NULL) {
    *stringp = NULL;
    return save;
  }
  *match = '\0';
  *stringp = match + strlen(delim);
  return save;
}
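// Splits a C string on spaces and tabs into a vector of tokens.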
vector<string> tokenize(const char input[])
{
  vector<string> token;
  bool betweenWords = true;
  int start = 0;
  int i = 0;
  for (; input[i] != '\0'; i++) {
    bool isSpace = (input[i] == ' ' || input[i] == '\t');
    if (!isSpace && betweenWords) {
      start = i;
      betweenWords = false;
    } else if (isSpace && !betweenWords) {
      token.push_back(string(input + start, i - start));
      betweenWords = true;
    }
  }
  if (!betweenWords)
    // length is i - start, not i - start + 1: input[i] is the terminating
    // '\0' and must not be copied into the token
    token.push_back(string(input + start, i - start));
  return token;
}
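// Interns a word: returns its existing vocabulary id, or assigns and
// records a fresh id for a previously unseen word.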
WORD_ID storeIfNew(WORD word)
{
  if (lookup.find(word) != lookup.end())
    return lookup[word];
  WORD_ID id = vocab.size();
  vocab.push_back(word);
  lookup[word] = id;
  return id;
}
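// Counts the occurrences of delim in input (not called elsewhere in this file).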
int count(string input, char delim)
{
  int count = 0;
  for (int i = 0; i < input.size(); i++) {
    if (input[i] == delim)
      count++;
  }
  return count;
}
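// Unnormalised posterior of a hypothesis under the log-linear model:
// exp(SCALE * dot(feats, weights)).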
double calculate_probability(const vector<double> & feats, const vector<double> & weights, double SCALE)
{
  if (feats.size() != weights.size())
    cerr << "ERROR : Number of features <> number of weights " << endl;
  double prob = 0;
  for (int i = 0; i < feats.size(); i++) {
    prob += feats[i] * weights[i] * SCALE;
  }
  return exp(prob);
}
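// Collects counts of all n-grams of order 1..BLEU_ORDER in sentence.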
void extract_ngrams(const vector<WORD_ID> & sentence, map< vector<WORD_ID>, int > & allngrams)
{
  vector<WORD_ID> ngram;
  for (int k = 0; k < BLEU_ORDER; k++) {
    for (int i = 0; i < max((int)sentence.size() - k, 0); i++) {
      for (int j = i; j <= i + k; j++) {
        ngram.push_back(sentence[j]);
      }
      ++allngrams[ngram];
      ngram.clear();
    }
  }
}
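// Smoothed sentence-level BLEU of hypothesis hyp against pseudo-reference ref.
// comps[2i] holds the matched (i+1)-gram count, comps[2i+1] the total number
// of (i+1)-grams in the hypothesis, and comps[2*BLEU_ORDER] the reference length.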
double calculate_score(const vector<candidate_t*> & sents, int ref, int hyp, vector< map< vector<WORD_ID>, int > > & ngram_stats)
{
  int comps_n = 2 * BLEU_ORDER + 1;
  int comps[comps_n];
  double logbleu = 0.0, brevity;
  int hyp_length = sents[hyp]->translation_size;
  for (int i = 0; i < BLEU_ORDER; i++) {
    comps[2*i] = 0;
    comps[2*i+1] = max(hyp_length - i, 0);
  }
  map< vector<WORD_ID>, int > & hyp_ngrams = ngram_stats[hyp];
  map< vector<WORD_ID>, int > & ref_ngrams = ngram_stats[ref];
  for (map< vector<WORD_ID>, int >::iterator it = hyp_ngrams.begin();
       it != hyp_ngrams.end(); it++) {
    map< vector<WORD_ID>, int >::iterator ref_it = ref_ngrams.find(it->first);
    if (ref_it != ref_ngrams.end()) {
      comps[2 * (it->first.size() - 1)] += min(ref_it->second, it->second);
    }
  }
  comps[comps_n-1] = sents[ref]->translation_size;
  if (DEBUG) {
    for (int i = 0; i < comps_n; i++)
      cerr << "Comp " << i << " : " << comps[i] << endl;
  }
  for (int i = 0; i < BLEU_ORDER; i++) {
    if (comps[0] == 0)
      return 0.0;
    if (i > 0)
      logbleu += log(static_cast<double>(comps[2*i] + SMOOTH)) - log(static_cast<double>(comps[2*i+1] + SMOOTH));
    else
      logbleu += log(static_cast<double>(comps[2*i])) - log(static_cast<double>(comps[2*i+1]));
  }
  logbleu /= BLEU_ORDER;
  brevity = 1.0 - (double)comps[comps_n-1] / comps[1]; // comps[comps_n-1] is the ref length, comps[1] is the test length
  if (brevity < 0.0)
    logbleu += brevity;
  return exp(logbleu);
}
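// Reads every whitespace-separated number in fileName into a weight vector.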
vector<double> read_weights(string fileName)
{
  ifstream inFile;
  inFile.open(fileName.c_str());
  istream *inFileP = &inFile;
  char line[TABLE_LINE_MAX_LENGTH];
  int i = 0;
  vector<double> weights;
  while (true) {
    i++;
    SAFE_GETLINE((*inFileP), line, TABLE_LINE_MAX_LENGTH, '\n');
    if (inFileP->eof()) break;
    vector<string> token = tokenize(line);
    for (int j = 0; j < token.size(); j++) {
      weights.push_back(atof(token[j].c_str()));
    }
  }
  cerr << endl;
  return weights;
}
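// Returns the index of the smallest element of vec (not called elsewhere in this file).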
int find_pos_of_min_element(const vector<double> & vec)
{
  int min_pos = -1;
  double min_element = 10000;
  for (int i = 0; i < vec.size(); i++) {
    if (vec[i] < min_element) {
      min_element = vec[i];
      min_pos = i;
    }
  }
  /* cerr << "Min pos is : " << min_pos << endl;
     cerr << "Min mbr loss is : " << min_element << endl; */
  return min_pos;
}
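// MBR decoding for one source sentence: among the candidate translations,
// pick the hypothesis e that minimises the expected loss
//   sum over e' != e of (1 - BLEU(e', e)) * P(e' | f),
// where P(e' | f) is the scaled, normalised model posterior, and print it.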
void process(int sent, const vector<candidate_t*> & sents)
{
  // cerr << "Sentence " << sent << " has " << sents.size() << " candidate translations" << endl;
  double marginal = 0;
  vector<double> joint_prob_vec;
  double joint_prob;
  vector< map< vector<WORD_ID>, int > > ngram_stats;
  for (int i = 0; i < sents.size(); i++) {
    // cerr << "Sents " << i << " has trans : " << sents[i]->translation << endl;
    // Calculate the marginal and cache the posteriors
    joint_prob = calculate_probability(sents[i]->features, weights, SCALE);
    marginal += joint_prob;
    joint_prob_vec.push_back(joint_prob);
    // Cache ngram counts
    map< vector<WORD_ID>, int > counts;
    extract_ngrams(sents[i]->translation, counts);
    ngram_stats.push_back(counts);
  }
  // cerr << "Marginal is " << marginal;
  vector<double> mbr_loss;
  double bleu, weightedLoss;
  double weightedLossCumul = 0;
  double minMBRLoss = 1000000;
  int minMBRLossIdx = -1;
  /* Main MBR computation done here */
  for (int i = 0; i < sents.size(); i++) {
    weightedLossCumul = 0;
    for (int j = 0; j < sents.size(); j++) {
      if (i != j) {
        bleu = calculate_score(sents, j, i, ngram_stats);
        weightedLoss = (1 - bleu) * (joint_prob_vec[j] / marginal);
        weightedLossCumul += weightedLoss;
        if (weightedLossCumul > minMBRLoss)
          break;
      }
    }
    if (weightedLossCumul < minMBRLoss) {
      minMBRLoss = weightedLossCumul;
      minMBRLossIdx = i;
    }
  }
  // cerr << "Min pos is : " << minMBRLossIdx << endl;
  // cerr << "Min mbr loss is : " << minMBRLoss << endl;
  /* Find the sentence that minimises the Bayes risk under 1 - BLEU loss */
  vector<WORD_ID> best_translation = sents[minMBRLossIdx]->translation;
  for (int i = 0; i < best_translation.size(); i++)
    cout << vocab[best_translation[i]] << " ";
  cout << endl;
}
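// Parses a moses-style n-best list ("id ||| translation ||| features ||| score"),
// groups consecutive lines by sentence id, and runs MBR on each group.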
void read_nbest_data(string fileName)
{
  FILE *fp = fopen(fileName.c_str(), "r");
  if (fp == NULL) {
    fprintf(stderr, "could not open n-best file %s\n", fileName.c_str());
    exit(1);
  }
  static char buf[10000];
  char *rest, *tok;
  int field;
  int sent_i, cur_sent;
  candidate_t *cand = NULL;
  vector<candidate_t*> testsents;
  cur_sent = -1;
  while (fgets(buf, sizeof(buf), fp) != NULL) {
    field = 0;
    rest = buf;
    while ((tok = strstrsep(&rest, "|||")) != NULL) {
      if (field == 0) {
        sent_i = strtol(tok, NULL, 10);
        cand = new candidate_t;
      } else if (field == 2) {
        vector<double> features;
        char *subtok = strtok(tok, " ");
        while (subtok != NULL) {
          features.push_back(atof(subtok));
          subtok = strtok(NULL, " ");
        }
        cand->features = features;
      } else if (field == 1) {
        vector<string> trans_str = tokenize(tok);
        vector<WORD_ID> trans_int;
        for (int j = 0; j < trans_str.size(); j++) {
          trans_int.push_back(storeIfNew(trans_str[j]));
        }
        cand->translation = trans_int;
        cand->translation_size = cand->translation.size();
      } else if (field == 3) {
        continue;
      } else {
        fprintf(stderr, "too many fields in n-best list line\n");
      }
      field++;
    }
    if (sent_i != cur_sent) {
      if (cur_sent != -1) {
        process(cur_sent, testsents);
      }
      cur_sent = sent_i;
      // candidates are heap-allocated; free the previous group before reuse
      for (int k = 0; k < (int)testsents.size(); k++)
        delete testsents[k];
      testsents.clear();
    }
    testsents.push_back(cand);
  }
  if (!testsents.empty()) // guard against an empty input file
    process(cur_sent, testsents);
  fclose(fp);
  cerr << endl;
}
int main(int argc, char **argv)
{
  time_t starttime = time(NULL);
  int c;
  string f_weight = "";
  string f_nbest = "";
  while ((c = getopt(argc, argv, "s:w:n:i:")) != -1) {
    switch (c) {
    case 's':
      SCALE = atof(optarg);
      break;
    case 'n':
      BLEU_ORDER = atoi(optarg);
      break;
    case 'w':
      f_weight = optarg;
      break;
    case 'i':
      f_nbest = optarg;
      break;
    default:
      usage();
    }
  }
  argc -= optind;
  argv += optind;
  if (f_weight.empty() || f_nbest.empty()) { // both -w and -i are required
    usage();
    exit(1);
  }
  weights = read_weights(f_weight);
  read_nbest_data(f_nbest);
  time_t endtime = time(NULL);
  cerr << "Processed data in " << (endtime - starttime) << " seconds\n";
}