/ED-iHMM/code/matlab/iHMM/iHmmHyperSample.m

http://github.com/jhuggins/columbia · MATLAB · 79 lines · 71 code · 8 blank · 0 comment · 12 complexity · 8b9434067bace99a63b81c246ca3ceb9 MD5 · raw file

  function [sbeta, salpha0, sgamma, N, M] = iHmmHyperSample(S, ibeta, ialpha0, igamma, hypers, numi)
% IHMMHYPERSAMPLE resamples the hyperparameters of an infinite hmm.
%
% [sbeta, salpha0, sgamma, N, M] = ...
%     iHmmHyperSample(S, ibeta, ialpha0, igamma, hypers, numi) resamples
% beta, alpha0 and gamma given the state sequence S and the previous
% values ibeta, ialpha0 and igamma.
%
% Inputs:
%   S       - state sequence (vector of state indices in 1..K, where
%             K = length(ibeta) - 1).
%   ibeta   - current top-level state weights, length K+1; the last entry
%             is the weight reserved for states that are not yet used.
%   ialpha0 - current value of alpha0.
%   igamma  - current value of gamma.
%   hypers  - structure selecting how alpha0 is updated:
%               * if it has a field alpha0, that value is returned as the
%                 new alpha0 without resampling;
%               * otherwise, if it has fields alpha0_a and alpha0_b (shape
%                 and rate of a gamma prior on alpha0), alpha0 is
%                 Gibbs-resampled;
%               * otherwise the input ialpha0 is returned unchanged.
%             gamma follows the same rules with fields gamma, gamma_a and
%             gamma_b.
%   numi    - number of Gibbs iterations of the auxiliary-variable samplers
%             for alpha0 and gamma (see the HDP paper or Escobar & West);
%             a value of around 20 is recommended.
%
% Outputs:
%   sbeta   - resampled beta; same length as ibeta.
%   salpha0 - new alpha0.
%   sgamma  - new gamma.
%   N       - K x K CRF counts: N(j,k) is the number of transitions j -> k
%             in S, with the initial state S(1) counted as a transition
%             out of state 1.
%   M       - K x K sampled number of tables in restaurant j serving dish k.

K = length(ibeta) - 1;   % number of represented states
T = numel(S);            % length of the state sequence (row or column)

% Transition counts; the initial state is counted as leaving state 1.
N = zeros(K, K);
N(1, S(1)) = 1;
for t = 2:T
    N(S(t-1), S(t)) = N(S(t-1), S(t)) + 1;
end

% Tables per (restaurant, dish): seat the N(j,k) customers one at a time;
% customer l opens a new table with probability a / (a + l - 1), where
% a = alpha0 * beta(k). Uses the beta from *before* resampling.
M = zeros(K, K);
for j = 1:K
    for k = 1:K
        a = ialpha0 * ibeta(k);
        for l = 1:N(j,k)
            M(j,k) = M(j,k) + (rand() < a / (a + l - 1));
        end
    end
end

% Resample beta; the trailing igamma parameter carries the mass of the
% unrepresented states.
ibeta = dirichlet_sample([sum(M, 1) igamma]);

totalTables = sum(M(:));

% Resample alpha0 (auxiliary-variable scheme of Teh et al. 2006).
if isfield(hypers, 'alpha0')
    ialpha0 = hypers.alpha0;
elseif isfield(hypers, 'alpha0_a') && isfield(hypers, 'alpha0_b')
    % Only restaurants with customers contribute. For n_j = 0 the auxiliary
    % w_j is degenerate at 1 and s_j is 0, but betarnd(., 0) returns NaN,
    % which would turn alpha0 into NaN.
    nj = sum(N, 2);
    nj = nj(nj > 0);
    for iter = 1:numi
        w = betarnd(ialpha0 + 1, nj);
        s = binornd(1, nj ./ (nj + ialpha0));
        ialpha0 = gamrnd(hypers.alpha0_a + totalTables - sum(s), ...
                         1.0 / (hypers.alpha0_b - sum(log(w))));
    end
end

% Resample gamma (Escobar & West 1995): the top-level DP has totalTables
% "customers" seated at numDishes distinct "dishes". The extra last entry
% of beta is not a dish, so length(ibeta) must not be used here.
if isfield(hypers, 'gamma')
    igamma = hypers.gamma;
elseif isfield(hypers, 'gamma_a') && isfield(hypers, 'gamma_b')
    numDishes = nnz(sum(M, 1));
    for iter = 1:numi
        mu = betarnd(igamma + 1, totalTables);
        rate = hypers.gamma_b - log(mu);
        % Mixture weight: pi_mu / (1 - pi_mu) = (gamma_a + k - 1) / (m * rate).
        pi_mu = 1 / (1 + (totalTables * rate) / (hypers.gamma_a + numDishes - 1));
        if rand() < pi_mu
            igamma = gamrnd(hypers.gamma_a + numDishes, 1.0 / rate);
        else
            igamma = gamrnd(hypers.gamma_a + numDishes - 1, 1.0 / rate);
        end
    end
end

sbeta = ibeta;
salpha0 = ialpha0;
sgamma = igamma;