PageRenderTime 45ms CodeModel.GetById 19ms RepoModel.GetById 0ms app.codeStats 0ms

/fhmm.c

https://github.com/chuan/chmm
C | 507 lines | 404 code | 65 blank | 38 comment | 150 complexity | 7e7d3ca56e9bc32ec16a11e47c0a1ad3 MD5 | raw file
  1. /*
  2. * Copyright (c) 2009, Chuan Liu <chuan@cs.jhu.edu>
  3. *
  4. * Permission is hereby granted, free of charge, to any person
  5. * obtaining a copy of this software and associated documentation
  6. * files (the "Software"), to deal in the Software without
  7. * restriction, including without limitation the rights to use, copy,
  8. * modify, merge, publish, distribute, sublicense, and/or sell copies
  9. * of the Software, and to permit persons to whom the Software is
  10. * furnished to do so, subject to the following conditions:
  11. *
  12. * The above copyright notice and this permission notice shall be
  13. * included in all copies or substantial portions of the Software.
  14. *
  15. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
  16. * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
  17. * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
  18. * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
  19. * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
  20. * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
  21. * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  22. * SOFTWARE.
  23. *
  24. */
  25. #ifndef _GNU_SOURCE
  26. #define _GNU_SOURCE
  27. #endif
  28. #include <math.h>
  29. #include <stdio.h>
  30. #include <stdlib.h>
  31. #include <unistd.h>
  /* Print msg with the current errno description and terminate the program. */
  #define handle_error(msg) \
      do { perror(msg); exit(EXIT_FAILURE); } while (0)
  /* Row-major index of element (i, j) in a matrix with d columns. */
  #define IDX(i,j,d) (((i)*(d))+(j))

  int nstates = 0; /* number of hidden states */
  int nobvs = 0;   /* number of distinct observation symbols */
  int nseq = 0;    /* number of data sequences */
  int length = 0;  /* length of every data sequence */

  /* Model parameters; main() stores them as natural logs (logf) of the config values. */
  float *prior = NULL; /* initial state probabilities, nstates entries */
  float *trans = NULL; /* state transition probabilities, nstates x nstates */
  float *obvs = NULL;  /* output probabilities, nstates x nobvs */
  int *data = NULL;    /* observation symbols, nseq x length */

  /* Baum-Welch expected counts (log domain), reset by init_count(). */
  float *gmm = NULL;   /* gamma: state/symbol counts, nstates x nobvs */
  float *xi = NULL;    /* xi: state-to-state transition counts, nstates x nstates */
  float *pi = NULL;    /* pi: initial-state counts, nstates entries */

  float logadd(float x, float y);
  float sumf(float *data, int size);
  float forward_backward(int *data, size_t len, int backward);
  void viterbi(int *data, size_t len);
  void init_count(void);
  void update_prob(void);
  void usage(void);
  void freeall(void);
  54. int main(int argc, char *argv[])
  55. {
  56. char *configfile = NULL;
  57. FILE *fin, *bin;
  58. char *linebuf = NULL;
  59. size_t buflen = 0;
  60. int iterations = 3;
  61. int mode = 3;
  62. int c;
  63. float d;
  64. float *loglik;
  65. float p;
  66. int i, j, k;
  67. opterr = 0;
  68. while ((c = getopt(argc, argv, "c:n:hp:")) != -1) {
  69. switch (c) {
  70. case 'c':
  71. configfile = optarg;
  72. break;
  73. case 'h':
  74. usage();
  75. exit(EXIT_SUCCESS);
  76. case 'n':
  77. iterations = atoi(optarg);
  78. break;
  79. case 'p':
  80. mode = atoi(optarg);
  81. if (mode != 1 && mode != 2 && mode != 3) {
  82. fprintf(stderr, "illegal mode: %d\n", mode);
  83. exit(EXIT_FAILURE);
  84. }
  85. break;
  86. case '?':
  87. fprintf(stderr, "illegal options\n");
  88. exit(EXIT_FAILURE);
  89. default:
  90. abort();
  91. }
  92. }
  93. if (configfile == NULL) {
  94. fin = stdin;
  95. } else {
  96. fin = fopen(configfile, "r");
  97. if (fin == NULL) {
  98. handle_error("fopen");
  99. }
  100. }
  101. i = 0;
  102. while ((c = getline(&linebuf, &buflen, fin)) != -1) {
  103. if (c <= 1 || linebuf[0] == '#')
  104. continue;
  105. if (i == 0) {
  106. if (sscanf(linebuf, "%d", &nstates) != 1) {
  107. fprintf(stderr, "config file format error: %d\n", i);
  108. freeall();
  109. exit(EXIT_FAILURE);
  110. }
  111. prior = (float *) malloc(sizeof(float) * nstates);
  112. if (prior == NULL) handle_error("malloc");
  113. trans = (float *) malloc(sizeof(float) * nstates * nstates);
  114. if (trans == NULL) handle_error("malloc");
  115. xi = (float *) malloc(sizeof(float) * nstates * nstates);
  116. if (xi == NULL) handle_error("malloc");
  117. pi = (float *) malloc(sizeof(float) * nstates);
  118. if (pi == NULL) handle_error("malloc");
  119. } else if (i == 1) {
  120. if (sscanf(linebuf, "%d", &nobvs) != 1) {
  121. fprintf(stderr, "config file format error: %d\n", i);
  122. freeall();
  123. exit(EXIT_FAILURE);
  124. }
  125. obvs = (float *) malloc(sizeof(float) * nstates * nobvs);
  126. if (obvs == NULL) handle_error("malloc");
  127. gmm = (float *) malloc(sizeof(float) * nstates * nobvs);
  128. if (gmm == NULL) handle_error("malloc");
  129. } else if (i == 2) {
  130. /* read initial state probabilities */
  131. bin = fmemopen(linebuf, buflen, "r");
  132. if (bin == NULL) handle_error("fmemopen");
  133. for (j = 0; j < nstates; j++) {
  134. if (fscanf(bin, "%f", &d) != 1) {
  135. fprintf(stderr, "config file format error: %d\n", i);
  136. freeall();
  137. exit(EXIT_FAILURE);
  138. }
  139. prior[j] = logf(d);
  140. }
  141. fclose(bin);
  142. } else if (i <= 2 + nstates) {
  143. /* read state transition probabilities */
  144. bin = fmemopen(linebuf, buflen, "r");
  145. if (bin == NULL) handle_error("fmemopen");
  146. for (j = 0; j < nstates; j++) {
  147. if (fscanf(bin, "%f", &d) != 1) {
  148. fprintf(stderr, "config file format error: %d\n", i);
  149. freeall();
  150. exit(EXIT_FAILURE);
  151. }
  152. trans[IDX((i - 3),j,nstates)] = logf(d);
  153. }
  154. fclose(bin);
  155. } else if (i <= 2 + nstates * 2) {
  156. /* read output probabilities */
  157. bin = fmemopen(linebuf, buflen, "r");
  158. if (bin == NULL) handle_error("fmemopen");
  159. for (j = 0; j < nobvs; j++) {
  160. if (fscanf(bin, "%f", &d) != 1) {
  161. fprintf(stderr, "config file format error: %d\n", i);
  162. freeall();
  163. exit(EXIT_FAILURE);
  164. }
  165. obvs[IDX((i - 3 - nstates),j,nobvs)] = logf(d);
  166. }
  167. fclose(bin);
  168. } else if (i == 3 + nstates * 2) {
  169. if (sscanf(linebuf, "%d %d", &nseq, &length) != 2) {
  170. fprintf(stderr, "config file format error: %d\n", i);
  171. freeall();
  172. exit(EXIT_FAILURE);
  173. }
  174. data = (int *) malloc (sizeof(int) * nseq * length);
  175. if (data == NULL) handle_error("malloc");
  176. } else if (i <= 3 + nstates * 2 + nseq) {
  177. /* read data */
  178. bin = fmemopen(linebuf, buflen, "r");
  179. if (bin == NULL) handle_error("fmemopen");
  180. for (j = 0; j < length; j++) {
  181. if (fscanf(bin, "%d", &k) != 1 || k < 0 || k >= nobvs) {
  182. fprintf(stderr, "config file format error: %d\n", i);
  183. freeall();
  184. exit(EXIT_FAILURE);
  185. }
  186. data[(i - 4 - nstates * 2) * length + j] = k;
  187. }
  188. fclose(bin);
  189. }
  190. i++;
  191. }
  192. fclose(fin);
  193. if (linebuf) free(linebuf);
  194. if (i < 4 + nstates * 2 + nseq) {
  195. fprintf(stderr, "configuration incomplete.\n");
  196. freeall();
  197. exit(EXIT_FAILURE);
  198. }
  199. if (mode == 3) {
  200. loglik = (float *) malloc(sizeof(float) * nseq);
  201. if (loglik == NULL) handle_error("malloc");
  202. for (i = 0; i < iterations; i++) {
  203. init_count();
  204. for (j = 0; j < nseq; j++) {
  205. loglik[j] = forward_backward(data + length * j, length, 1);
  206. }
  207. p = sumf(loglik, nseq);
  208. update_prob();
  209. printf("iteration %d log-likelihood: %.4f\n", i + 1, p);
  210. printf("updated parameters:\n");
  211. printf("# initial state probability\n");
  212. for (j = 0; j < nstates; j++) {
  213. printf(" %.4f", exp(prior[j]));
  214. }
  215. printf("\n");
  216. printf("# state transition probability\n");
  217. for (j = 0; j < nstates; j++) {
  218. for (k = 0; k < nstates; k++) {
  219. printf(" %.4f", exp(trans[IDX(j,k,nstates)]));
  220. }
  221. printf("\n");
  222. }
  223. printf("# state output probility\n");
  224. for (j = 0; j < nstates; j++) {
  225. for (k = 0; k < nobvs; k++) {
  226. printf(" %.4f", exp(obvs[IDX(j,k,nobvs)]));
  227. }
  228. printf("\n");
  229. }
  230. printf("\n");
  231. }
  232. free(loglik);
  233. } else if (mode == 2) {
  234. for (i = 0; i < nseq; i++) {
  235. viterbi(data + length * i, length);
  236. }
  237. } else if (mode == 1) {
  238. loglik = (float *) malloc(sizeof(float) * nseq);
  239. if (loglik == NULL) handle_error("malloc");
  240. for (i = 0; i < nseq; i++) {
  241. loglik[i] = forward_backward(data + length * i, length, 0);
  242. }
  243. p = sumf(loglik, nseq);
  244. for (i = 0; i < nseq; i++)
  245. printf("%.4f\n", loglik[i]);
  246. printf("total: %.4f\n", p);
  247. free(loglik);
  248. }
  249. freeall();
  250. return 0;
  251. }
  /*
   * Sum data[0 .. size-1] with Kahan (compensated) summation to limit rounding
   * error.  Returns 0 for an empty array (size <= 0).
   *
   * Once the running sum becomes infinite or NaN the compensation term is
   * reset to 0: otherwise a single -inf element (e.g. the log-likelihood of
   * an impossible sequence) would turn the result into NaN via -inf - -inf
   * instead of the correct -inf.
   */
  float sumf(float *data, int size)
  {
      float sum;
      float comp = 0.0f;  /* running compensation for lost low-order bits */
      float y, t;
      int i;

      if (size <= 0)
          return 0.0f;
      sum = data[0];
      for (i = 1; i < size; i++) {
          y = data[i] - comp;
          t = sum + y;
          comp = isfinite(t) ? (t - sum) - y : 0.0f;
          sum = t;
      }
      return sum;
  }
  267. /* initilize counts */
  268. void init_count() {
  269. size_t i;
  270. for (i = 0; i < nstates * nobvs; i++)
  271. gmm[i] = - INFINITY;
  272. for (i = 0; i < nstates * nstates; i++)
  273. xi[i] = - INFINITY;
  274. for (i = 0; i < nstates; i++)
  275. pi[i] = - INFINITY;
  276. }
  277. void update_prob() {
  278. float pisum = - INFINITY;
  279. float gmmsum[nstates];
  280. float xisum[nstates];
  281. size_t i, j;
  282. for (i = 0; i < nstates; i++) {
  283. gmmsum[i] = - INFINITY;
  284. xisum[i] = - INFINITY;
  285. pisum = logadd(pi[i], pisum);
  286. }
  287. for (i = 0; i < nstates; i++) {
  288. prior[i] = pi[i] - pisum;
  289. }
  290. for (i = 0; i < nstates; i++) {
  291. for (j = 0; j < nstates; j++) {
  292. xisum[i] = logadd(xisum[i], xi[IDX(i,j,nstates)]);
  293. }
  294. for (j = 0; j < nobvs; j++) {
  295. gmmsum[i] = logadd(gmmsum[i], gmm[IDX(i,j,nobvs)]);
  296. }
  297. }
  298. for (i = 0; i < nstates; i++) {
  299. for (j = 0; j < nstates; j++) {
  300. trans[IDX(i,j,nstates)] = xi[IDX(i,j,nstates)] - xisum[i];
  301. }
  302. for (j = 0; j < nobvs; j++) {
  303. obvs[IDX(i,j,nobvs)] = gmm[IDX(i,j,nobvs)] - gmmsum[i];
  304. }
  305. }
  306. }
  307. /* forward backward algoritm: return observation likelihood */
  308. float forward_backward(int *data, size_t len, int backward)
  309. {
  310. /* construct trellis */
  311. float alpha[len][nstates];
  312. float beta[len][nstates];
  313. size_t i, j, k;
  314. float p, e;
  315. float loglik;
  316. for (i = 0; i < len; i++) {
  317. for (j = 0; j < nstates; j++) {
  318. alpha[i][j] = - INFINITY;
  319. beta[i][j] = - INFINITY;
  320. }
  321. }
  322. /* forward pass */
  323. for (i = 0; i < nstates; i++) {
  324. alpha[0][i] = prior[i] + obvs[IDX(i,data[0],nobvs)];
  325. }
  326. for (i = 1; i < len; i++) {
  327. for (j = 0; j < nstates; j++) {
  328. for (k = 0; k < nstates; k++) {
  329. p = alpha[i-1][k] + trans[IDX(k,j,nstates)] + obvs[IDX(j,data[i],nobvs)];
  330. alpha[i][j] = logadd(alpha[i][j], p);
  331. }
  332. }
  333. }
  334. loglik = -INFINITY;
  335. for (i = 0; i < nstates; i++) {
  336. loglik = logadd(loglik, alpha[len-1][i]);
  337. }
  338. if (! backward)
  339. return loglik;
  340. /* backward pass & update counts */
  341. for (i = 0; i < nstates; i++) {
  342. beta[len-1][i] = 0; /* 0 = log (1.0) */
  343. }
  344. for (i = 1; i < len; i++) {
  345. for (j = 0; j < nstates; j++) {
  346. e = alpha[len-i][j] + beta[len-i][j] - loglik;
  347. gmm[IDX(j,data[len-i],nobvs)] = logadd(gmm[IDX(j,data[len-i],nobvs)], e);
  348. for (k = 0; k < nstates; k++) {
  349. p = beta[len-i][k] + trans[IDX(j,k,nstates)] + obvs[IDX(k,data[len-i],nobvs)];
  350. beta[len-1-i][j] = logadd(beta[len-1-i][j], p);
  351. e = alpha[len-1-i][j] + beta[len-i][k]
  352. + trans[IDX(j,k,nstates)] + obvs[IDX(k,data[len-i],nobvs)] - loglik;
  353. xi[IDX(j,k,nstates)] = logadd(xi[IDX(j,k,nstates)], e);
  354. }
  355. }
  356. }
  357. p = -INFINITY;
  358. for (i = 0; i < nstates; i++) {
  359. p = logadd(p, prior[i] + beta[0][i] + obvs[IDX(i,data[0],nobvs)]);
  360. e = alpha[0][i] + beta[0][i] - loglik;
  361. gmm[IDX(i,data[0],nobvs)] = logadd(gmm[IDX(i,data[0],nobvs)], e);
  362. pi[i] = logadd(pi[i], e);
  363. }
  364. #ifdef DEBUG
  365. /* verify if forward prob == backward prob */
  366. if (fabs(p - loglik) > 1e-3) {
  367. fprintf(stderr, "Error: forward and backward incompatible: %f, %f\n", loglik, p);
  368. }
  369. #endif
  370. return loglik;
  371. }
  372. /* find the most probable sequence */
  373. void viterbi(int *data, size_t len)
  374. {
  375. float lambda[len][nstates];
  376. int backtrace[len][nstates];
  377. int stack[len];
  378. size_t i, j, k;
  379. float p;
  380. for (i = 0; i < len; i++) {
  381. for (j = 0; j < nstates; j++) {
  382. lambda[i][j] = - INFINITY;
  383. }
  384. }
  385. for (i = 0; i < nstates; i++) {
  386. lambda[0][i] = prior[i] + obvs[IDX(i,data[0],nobvs)];
  387. backtrace[0][i] = -1; /* -1 is starting point */
  388. }
  389. for (i = 1; i < len; i++) {
  390. for (j = 0; j < nstates; j++) {
  391. for (k = 0; k < nstates; k++) {
  392. p = lambda[i-1][k] + trans[IDX(k,j,nstates)] + obvs[IDX(j,data[i],nobvs)];
  393. if (p > lambda[i][j]) {
  394. lambda[i][j] = p;
  395. backtrace[i][j] = k;
  396. }
  397. }
  398. }
  399. }
  400. /* backtrace */
  401. for (i = 0; i < nstates; i++) {
  402. if (i == 0 || lambda[len-1][i] > p) {
  403. p = lambda[len-1][i];
  404. k = i;
  405. }
  406. }
  407. stack[len - 1] = k;
  408. for (i = 1; i < len; i++) {
  409. stack[len - 1 - i] = backtrace[len - i][stack[len - i]];
  410. }
  411. for (i = 0; i < len; i++) {
  412. printf("%d ", stack[i]);
  413. }
  414. printf("\n");
  415. }
  416. float logadd(float x, float y) {
  417. if (y <= x)
  418. return x + log1pf(expf(y - x));
  419. else
  420. return y + log1pf(expf(x - y));
  421. }
  /*
   * Print command-line help to stdout.  Lists only the options main() accepts
   * ("c:n:hp:"); the old text advertised a -t option that getopt rejects.
   */
  void usage(void) {
      fprintf(stdout, "hmm [-h] [-c config] [-n iterations] [-p(1|2|3)]\n");
      fprintf(stdout, "usage:\n");
      fprintf(stdout, " -h help\n");
      fprintf(stdout, " -c configuration file (default: standard input)\n");
      fprintf(stdout, " -p1 compute the probability of the observation sequence\n");
      fprintf(stdout, " -p2 compute the most probable sequence (Viterbi)\n");
      fprintf(stdout, " -p3 train hidden Markov model parameters (Baum-Welch, default)\n");
      fprintf(stdout, " -n number of training iterations for -p3 (default: 3)\n");
  }
  433. /* free all memory */
  434. void freeall() {
  435. if (trans) free(trans);
  436. if (obvs) free(obvs);
  437. if (prior) free(prior);
  438. if (data) free(data);
  439. if (gmm) free(gmm);
  440. if (xi) free(xi);
  441. if (pi) free(pi);
  442. }