
/edward/inferences.py

https://gitlab.com/yogeshc/edward
from __future__ import print_function
import numpy as np
import tensorflow as tf

from edward.data import Data
from edward.util import kl_multivariate_normal, log_sum_exp
from edward.variationals import PointMass

try:
    import prettytensor as pt
except ImportError:
    pass


class Inference:
    """
    Base class for inference methods.

    Arguments
    ----------
    model: Model
        probability model p(x, z)
    data: Data, optional
        data x
    """
    def __init__(self, model, data=Data()):
        self.model = model
        self.data = data


class MonteCarlo(Inference):
    """
    Base class for Monte Carlo methods.

    Arguments
    ----------
    model: Model
        probability model p(x, z)
    data: Data, optional
        data x
    """
    def __init__(self, *args, **kwargs):
        Inference.__init__(self, *args, **kwargs)


class VariationalInference(Inference):
    """
    Base class for variational inference methods.

    Arguments
    ----------
    model: Model
        probability model p(x, z)
    variational: Variational
        variational model q(z; lambda)
    data: Data, optional
        data x
    """
    def __init__(self, model, variational, data=Data()):
        Inference.__init__(self, model, data)
        self.variational = variational

    def run(self, *args, **kwargs):
        """
        A simple wrapper to run the inference algorithm.
        """
        sess = self.initialize(*args, **kwargs)
        for t in range(self.n_iter):
            loss = self.update(sess)
            self.print_progress(t, loss, sess)

        return sess

    def initialize(self, n_iter=1000, n_data=None, n_print=100):
        """
        Initialize the inference algorithm.

        Arguments
        ----------
        n_iter: int, optional
            Number of iterations for optimization.
        n_data: int, optional
            Number of samples for data subsampling. Default is to use
            all the data.
        n_print: int, optional
            Number of iterations between progress printouts.
        """
        self.n_iter = n_iter
        self.n_data = n_data
        self.n_print = n_print

        self.losses = tf.constant(0.0)
        loss = self.build_loss()

        # Use ADAM with a decaying scale factor.
        global_step = tf.Variable(0, trainable=False)
        starter_learning_rate = 0.1
        learning_rate = tf.train.exponential_decay(starter_learning_rate,
                                                   global_step,
                                                   100, 0.9, staircase=True)
        self.train = tf.train.AdamOptimizer(learning_rate).minimize(
            loss, global_step=global_step)

        init = tf.initialize_all_variables()
        sess = tf.Session()
        sess.run(init)
        return sess

    def update(self, sess):
        _, loss = sess.run([self.train, self.losses])
        return loss

    def print_progress(self, t, losses, sess):
        if t % self.n_print == 0:
            print("iter %d loss %.2f " % (t, np.mean(losses)))
            self.variational.print_params(sess)

    def build_loss(self):
        raise NotImplementedError()
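

# Illustrative sketch, not part of the original library: it isolates the
# optimizer pattern used in VariationalInference.initialize() -- Adam driven
# by an exponentially decaying learning rate -- on a toy quadratic loss, so
# the schedule can be inspected independently of any model. The helper name
# `_sketch_decayed_adam` is hypothetical.
def _sketch_decayed_adam(n_steps=300):
    theta = tf.Variable(5.0)
    toy_loss = tf.square(theta - 2.0)

    # Same schedule as initialize(): start at 0.1, decay by 0.9 every 100 steps.
    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(0.1, global_step,
                                               100, 0.9, staircase=True)
    train = tf.train.AdamOptimizer(learning_rate).minimize(
        toy_loss, global_step=global_step)

    sess = tf.Session()
    sess.run(tf.initialize_all_variables())
    for _ in range(n_steps):
        sess.run(train)

    # theta approaches 2.0 while the learning rate has decayed twice.
    return sess.run([theta, learning_rate])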


class MFVI(VariationalInference):
    # TODO this isn't MFVI so much as VI where q is analytic
    """
    Mean-field variational inference
    (Ranganath et al., 2014)
    """
    def __init__(self, *args, **kwargs):
        VariationalInference.__init__(self, *args, **kwargs)

    def initialize(self, n_minibatch=1, score=None, *args, **kwargs):
        # TODO if score=True, make Normal do sess.run()
        """
        Parameters
        ----------
        n_minibatch: int, optional
            Number of samples from the variational model for
            calculating stochastic gradients.
        score: bool, optional
            Whether to force inference to use the score function
            gradient estimator. Otherwise the default is to use the
            reparameterization gradient if available.
        """
        if score is None and self.variational.is_reparam:
            self.score = False
        else:
            self.score = True

        self.n_minibatch = n_minibatch
        self.samples = tf.placeholder(shape=(self.n_minibatch, self.variational.num_vars),
                                      dtype=tf.float32,
                                      name='samples')
        return VariationalInference.initialize(self, *args, **kwargs)

    def update(self, sess):
        if self.score:
            # TODO the mapping should go here before sampling.
            # In principle the mapping should go here but we don't
            # want to have to run this twice. Also I've noticed that it
            # is significantly slower if I have it here for some reason,
            # so I'm leaving this as an open problem.
            #x = self.data.sample(self.n_data)
            #self.variational.set_params(self.variational.mapping(x))
            samples = self.variational.sample(self.samples.get_shape(), sess)
        else:
            samples = self.variational.sample_noise(self.samples.get_shape())

        _, loss = sess.run([self.train, self.losses], {self.samples: samples})
        return loss

    def build_loss(self):
        if self.score and hasattr(self.variational, 'entropy'):
            return self.build_score_loss_entropy()
        elif self.score:
            return self.build_score_loss()
        elif not self.score and hasattr(self.variational, 'entropy'):
            return self.build_reparam_loss_entropy()
        else:
            return self.build_reparam_loss()

    def build_score_loss(self):
        """
        Loss function to minimize, whose gradient is a stochastic
        gradient based on the score function estimator.
        (Paisley et al., 2012)
        """
        # ELBO = E_{q(z; lambda)} [ log p(x, z) - log q(z; lambda) ]
        x = self.data.sample(self.n_data)
        self.variational.set_params(self.variational.mapping(x))

        q_log_prob = tf.zeros([self.n_minibatch], dtype=tf.float32)
        for i in range(self.variational.num_vars):
            q_log_prob += self.variational.log_prob_zi(i, self.samples)

        self.losses = self.model.log_prob(x, self.samples) - q_log_prob
        return -tf.reduce_mean(q_log_prob * tf.stop_gradient(self.losses))

    def build_reparam_loss(self):
        """
        Loss function to minimize, whose gradient is a stochastic
        gradient based on the reparameterization trick.
        (Kingma and Welling, 2014)
        """
        # ELBO = E_{q(z; lambda)} [ log p(x, z) - log q(z; lambda) ]
        x = self.data.sample(self.n_data)
        self.variational.set_params(self.variational.mapping(x))
        z = self.variational.reparam(self.samples)

        q_log_prob = tf.zeros([self.n_minibatch], dtype=tf.float32)
        for i in range(self.variational.num_vars):
            q_log_prob += self.variational.log_prob_zi(i, z)

        self.losses = self.model.log_prob(x, z) - q_log_prob
        return -tf.reduce_mean(self.losses)

    def build_score_loss_entropy(self):
        """
        Loss function to minimize, whose gradient is a stochastic
        gradient based on the score function estimator.
        """
        # ELBO = E_{q(z; lambda)} [ log p(x, z) ] + H(q(z; lambda))
        # where the entropy is analytic
        x = self.data.sample(self.n_data)
        self.variational.set_params(self.variational.mapping(x))

        q_log_prob = tf.zeros([self.n_minibatch], dtype=tf.float32)
        for i in range(self.variational.num_vars):
            q_log_prob += self.variational.log_prob_zi(i, self.samples)

        p_log_prob = self.model.log_prob(x, self.samples)
        q_entropy = self.variational.entropy()
        self.losses = p_log_prob + q_entropy
        # Negate so that minimizing the loss maximizes the ELBO.
        return -(tf.reduce_mean(q_log_prob * tf.stop_gradient(p_log_prob)) +
                 q_entropy)

    def build_reparam_loss_entropy(self):
        """
        Loss function to minimize, whose gradient is a stochastic
        gradient based on the reparameterization trick.
        """
        # ELBO = E_{q(z; lambda)} [ log p(x, z) ] + H(q(z; lambda))
        # where the entropy is analytic
        x = self.data.sample(self.n_data)
        self.variational.set_params(self.variational.mapping(x))
        z = self.variational.reparam(self.samples)
        self.losses = self.model.log_prob(x, z) + self.variational.entropy()
        return -tf.reduce_mean(self.losses)
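

# Illustrative sketch, not part of the original library: a NumPy toy problem
# contrasting the two gradient estimators MFVI can use. The target is
# p(z) = N(0, 1) and the variational family is q(z; mu) = N(mu, 1), so
# ELBO(mu) = -0.5 * mu^2 and d ELBO / d mu = -mu. The helper name
# `_sketch_gradient_estimators` is hypothetical.
def _sketch_gradient_estimators(mu=1.5, n_samples=10000, seed=0):
    rng = np.random.RandomState(seed)

    def log_p(z):
        # log density of the standard normal target
        return -0.5 * z ** 2 - 0.5 * np.log(2 * np.pi)

    def log_q(z, mu):
        # log density of the variational family q(z; mu) = N(mu, 1)
        return -0.5 * (z - mu) ** 2 - 0.5 * np.log(2 * np.pi)

    # Score function estimator, as in build_score_loss():
    # E_q[ grad_mu log q(z) * (log p(z) - log q(z)) ], with grad_mu log q = z - mu.
    z = rng.randn(n_samples) + mu
    score_grad = np.mean((z - mu) * (log_p(z) - log_q(z, mu)))

    # Reparameterization estimator, as in build_reparam_loss():
    # z = mu + eps with eps ~ N(0, 1); the -log q term no longer depends on mu,
    # and grad_mu log p(mu + eps) = -(mu + eps).
    eps = rng.randn(n_samples)
    reparam_grad = np.mean(-(mu + eps))

    # Both estimate -mu; the reparameterization estimate has much lower variance.
    return score_grad, reparam_grad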


class VAE(VariationalInference):
    # TODO refactor into MFVI
    def __init__(self, *args, **kwargs):
        VariationalInference.__init__(self, *args, **kwargs)

    def initialize(self, n_data=None):
        # TODO refactor to use VariationalInference's initialize()
        self.n_data = n_data

        # TODO don't fix number of covariates
        self.x = tf.placeholder(tf.float32, [self.n_data, 28 * 28])
        self.losses = tf.constant(0.0)

        loss = self.build_loss()
        optimizer = tf.train.AdamOptimizer(1e-2, epsilon=1.0)
        # TODO move this to not rely on Pretty Tensor
        self.train = pt.apply_optimizer(optimizer, losses=[loss])
        init = tf.initialize_all_variables()
        sess = tf.Session()
        sess.run(init)
        return sess

    def update(self, sess):
        x = self.data.sample(self.n_data)
        _, loss_value = sess.run([self.train, self.losses], {self.x: x})
        return loss_value

    def build_loss(self):
        # ELBO = E_{q(z | x)} [ log p(x | z) ] - KL( q(z | x) || p(z) )
        # In general, there should be a scale factor due to data
        # subsampling, so that
        #     ELBO = N / M * ( ELBO using x_b )
        # where x_b is a mini-batch of x, with sizes M and N respectively.
        # This is absorbed into the learning rate.
        with tf.variable_scope("model") as scope:
            self.variational.set_params(self.variational.mapping(self.x))
            z = self.variational.sample([self.n_data, self.variational.num_vars])
            self.losses = tf.reduce_sum(self.model.log_likelihood(self.x, z)) - \
                kl_multivariate_normal(self.variational.m,
                                       self.variational.s)
        return -self.losses
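

# Illustrative sketch, not part of the original library: the ELBO above
# subtracts kl_multivariate_normal(m, s), which is assumed here to be the
# analytic KL from a diagonal Gaussian q(z | x) to a standard normal prior.
# This NumPy check compares that closed form, in one dimension, against a
# Monte Carlo estimate. The helper name `_sketch_gaussian_kl` is hypothetical.
def _sketch_gaussian_kl(m=0.7, s=0.5, n_samples=100000, seed=0):
    rng = np.random.RandomState(seed)

    # Closed form: KL( N(m, s^2) || N(0, 1) ) = 0.5 * (s^2 + m^2 - 1 - log s^2).
    analytic = 0.5 * (s ** 2 + m ** 2 - 1.0 - np.log(s ** 2))

    # Monte Carlo estimate of E_q[ log q(z) - log p(z) ] with z ~ N(m, s^2).
    z = m + s * rng.randn(n_samples)
    log_q = -0.5 * ((z - m) / s) ** 2 - np.log(s) - 0.5 * np.log(2 * np.pi)
    log_p = -0.5 * z ** 2 - 0.5 * np.log(2 * np.pi)
    monte_carlo = np.mean(log_q - log_p)

    # The two agree up to Monte Carlo error.
    return analytic, monte_carlo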


class KLpq(VariationalInference):
    """
    Kullback-Leibler divergence from posterior to variational model,
    KL( p(z | x) || q(z) ).
    (Cappe et al., 2008)
    """
    def __init__(self, *args, **kwargs):
        VariationalInference.__init__(self, *args, **kwargs)

    def initialize(self, n_minibatch=1, *args, **kwargs):
        self.n_minibatch = n_minibatch
        self.samples = tf.placeholder(shape=(self.n_minibatch, self.variational.num_vars),
                                      dtype=tf.float32,
                                      name='samples')
        return VariationalInference.initialize(self, *args, **kwargs)

    def update(self, sess):
        samples = self.variational.sample(self.samples.get_shape(), sess)
        _, loss = sess.run([self.train, self.losses], {self.samples: samples})
        return loss

    def build_loss(self):
        """
        Loss function to minimize, whose gradient is a stochastic
        gradient inspired by adaptive importance sampling.
        """
        # loss = E_{q(z; lambda)} [ w_norm(z; lambda) *
        #                           ( log p(x, z) - log q(z; lambda) ) ]
        # where
        #     w_norm(z; lambda) = w(z; lambda) / sum_z( w(z; lambda) )
        #     w(z; lambda) = p(x, z) / q(z; lambda)
        #
        # gradient = - E_{q(z; lambda)} [ w_norm(z; lambda) *
        #                                 grad_{lambda} log q(z; lambda) ]
        x = self.data.sample(self.n_data)
        self.variational.set_params(self.variational.mapping(x))

        q_log_prob = tf.zeros([self.n_minibatch], dtype=tf.float32)
        for i in range(self.variational.num_vars):
            q_log_prob += self.variational.log_prob_zi(i, self.samples)

        # 1/B sum_{b=1}^B grad_log_q * w_norm
        # = 1/B sum_{b=1}^B grad_log_q * exp{ log(w_norm) }
        log_w = self.model.log_prob(x, self.samples) - q_log_prob
        # normalized log importance weights
        log_w_norm = log_w - log_sum_exp(log_w)
        w_norm = tf.exp(log_w_norm)

        self.losses = w_norm * log_w
        return -tf.reduce_mean(q_log_prob * tf.stop_gradient(w_norm))
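

# Illustrative sketch, not part of the original library: KLpq.build_loss()
# normalizes importance weights in log space before exponentiating. This
# NumPy version of that step shows why: subtracting the log-sum-exp keeps the
# weights finite even when the raw log weights would overflow exp(). The
# helper name `_sketch_normalized_weights` is hypothetical.
def _sketch_normalized_weights(log_w):
    log_w = np.asarray(log_w, dtype=np.float64)
    # Numerically stable log_sum_exp: shift by the maximum before exponentiating.
    log_sum = np.max(log_w) + np.log(np.sum(np.exp(log_w - np.max(log_w))))
    w_norm = np.exp(log_w - log_sum)
    # w_norm is nonnegative and sums to one.
    return w_norm

# For example, _sketch_normalized_weights([1000.0, 1001.0, 999.0]) is finite,
# whereas exponentiating the raw log weights before normalizing would overflow.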


class MAP(VariationalInference):
    """
    Maximum a posteriori
    """
    def __init__(self, model, data=Data(), transform=tf.identity):
        variational = PointMass(model.num_vars, transform)
        VariationalInference.__init__(self, model, variational, data)

    def build_loss(self):
        x = self.data.sample(self.n_data)
        self.variational.set_params(self.variational.mapping(x))
        z = self.variational.get_params()
        self.losses = self.model.log_prob(x, z)
        return -tf.reduce_mean(self.losses)
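

# Illustrative sketch, not part of the original library: MAP.build_loss()
# minimizes -log p(x, z) in the point-mass parameters, i.e. gradient ascent
# on the log joint. This NumPy toy does the same for a conjugate model
# (z ~ N(0, 1), x_i | z ~ N(z, 1)), where the MAP estimate has the closed
# form sum(x) / (n + 1). The helper name `_sketch_map_estimate` is hypothetical.
def _sketch_map_estimate(seed=0):
    rng = np.random.RandomState(seed)
    x = rng.randn(20) + 2.0

    z = 0.0
    for _ in range(500):
        grad = np.sum(x - z) - z       # d/dz log p(x, z)
        z += 0.01 * grad

    # The iterate matches the closed-form posterior mode.
    return z, np.sum(x) / (len(x) + 1.0)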