/Archive/Testing/Peter/Pytorch04/Models/BBB_scale_mix_v2.py

https://bitbucket.org/RamiroCope/thesis_repo · Python · 254 lines · 162 code · 38 blank · 54 comment

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import torch.distributions as distributions
import numpy as np
from scipy.special import gammaln, psi


#############################
## Bayesian Neural Network ##
#############################
class BNN(nn.Module):
    def __init__(self, input_size,
                 hidden_sizes,
                 output_size,
                 act_func,
                 prior_prec=1.0,
                 prec_init=1.0,
                 alpha_prior=5.0,
                 beta_prior=5.0,
                 alpha_init=5.0,
                 beta_init=5.0):
        super(type(self), self).__init__()
        self.input_size = input_size
        sigma_prior = 1.0 / math.sqrt(prior_prec)
        sigma_init = 1.0 / math.sqrt(prec_init)
        if output_size:
            self.output_size = output_size
            self.squeeze_output = False
        else:
            self.output_size = 1
            self.squeeze_output = True
        self.act = F.tanh if act_func == "tanh" else F.relu
        if len(hidden_sizes) == 0:
            self.hidden_layers = []
            self.output_layer = StochasticLinear(self.input_size,
                                                 self.output_size,
                                                 sigma_prior=sigma_prior,
                                                 alpha_prior=alpha_prior,
                                                 beta_prior=beta_prior,
                                                 sigma_init=sigma_init,
                                                 alpha_init=alpha_init,
                                                 beta_init=beta_init)
        else:
            self.hidden_layers = nn.ModuleList(
                [StochasticLinear(in_size, out_size,
                                  sigma_prior=sigma_prior,
                                  alpha_prior=alpha_prior,
                                  beta_prior=beta_prior,
                                  sigma_init=sigma_init,
                                  alpha_init=alpha_init,
                                  beta_init=beta_init)
                 for in_size, out_size in zip([self.input_size] + hidden_sizes[:-1], hidden_sizes)])
            self.output_layer = StochasticLinear(hidden_sizes[-1], self.output_size,
                                                 sigma_prior=sigma_prior,
                                                 alpha_prior=alpha_prior,
                                                 beta_prior=beta_prior,
                                                 sigma_init=sigma_init,
                                                 alpha_init=alpha_init,
                                                 beta_init=beta_init)

    def forward(self, x, training=True):
        x = x.view(-1, self.input_size)
        out = x
        for layer in self.hidden_layers:
            out = self.act(layer(out, training=training))
        logits = self.output_layer(out, training=training)
        if self.squeeze_output:
            logits = torch.squeeze(logits)
        return logits

    def kl_divergence(self):
        kl = 0
        for layer in self.hidden_layers:
            kl += layer.kl_divergence()
        kl += self.output_layer.kl_divergence()
        return kl

    def return_params(self):
        mus = []
        alphas = []
        betas = []
        sigmas = []
        for layer in self.hidden_layers:
            mu, alpha, beta, sigma = layer.return_parameters()
            mus += [mu.numpy()]
            alphas += [alpha.numpy()]
            betas += [beta.numpy()]
            sigmas += [sigma.numpy()]
        mu, alpha, beta, sigma = self.output_layer.return_parameters()
        mus += [mu.numpy()]
        alphas += [alpha.numpy()]
        betas += [beta.numpy()]
        sigmas += [sigma.numpy()]
        return mus, alphas, betas, sigmas


###############################################
## Gaussian Mean-Field Linear Transformation ##
###############################################
class StochasticLinear(nn.Module):
    """Applies a stochastic linear transformation to the incoming data: :math:`y = Ax + b`

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input: :math:`(N, *, in\_features)` where :math:`*` means any number of
          additional dimensions
        - Output: :math:`(N, *, out\_features)` where all but the last dimension
          are the same shape as the input.

    Attributes:
        weight: the learnable weights of the module of shape
            `(out_features x in_features)`
        bias: the learnable bias of the module of shape `(out_features)`

    Examples::

        >>> m = StochasticLinear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
    """

    def __init__(self, in_features,
                 out_features,
                 sigma_prior=1.0,
                 alpha_prior=5.0,
                 beta_prior=5.0,
                 sigma_init=1.0,
                 alpha_init=5.0,
                 beta_init=5.0,
                 bias=True):
        super(type(self), self).__init__()
        self.count = 0
        self.in_features = in_features
        self.out_features = out_features
        self.sigma_prior = sigma_prior
        self.alpha_prior = alpha_prior
        self.beta_prior = beta_prior
        self.sigma_init = sigma_init
        self.alpha_init = alpha_init
        self.beta_init = beta_init
        self.weight_spsigma = None
        self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
        self.weight_spalpha = Parameter(torch.Tensor(out_features, in_features))
        self.weight_spbeta = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = True
            self.bias_spsigma = None
            self.bias_mu = Parameter(torch.Tensor(out_features))
            self.bias_spalpha = Parameter(torch.Tensor(out_features))
            self.bias_spbeta = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight_mu.size(1))
        self.weight_mu.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias_mu.data.uniform_(-stdv, stdv)
        self.weight_spalpha.data.uniform_(self.alpha_init - 0.1, self.alpha_init + 0.1)  # fill_(self.alpha_init)
        self.weight_spbeta.data.uniform_(self.beta_init - 0.1, self.beta_init + 0.1)  # fill_(self.beta_init)
        if self.bias is not None:
            self.bias_spalpha.data.uniform_(self.alpha_init - 0.1, self.alpha_init + 0.1)  # fill_(self.alpha_init)
            self.bias_spbeta.data.uniform_(self.beta_init - 0.1, self.beta_init + 0.1)  # fill_(self.beta_init)
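
    # In forward(), each weight is sampled hierarchically (a scale-mixture posterior):
    #   lambda ~ Gamma(softplus(spalpha), softplus(spbeta))    (precision)
    #   sigma  = sqrt(1 / lambda)                              (so sigma^2 is inverse-Gamma)
    #   w      = mu + sigma * eps,   eps ~ N(0, 1)             (reparameterization trick)
    # Marginalizing over lambda gives a heavier-tailed, Student-t-like distribution
    # over the weights rather than a plain mean-field Gaussian.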

    def forward(self, input, training=True):
        if not training:
            self.weight = self.weight_mu.data
            if self.bias is not None:
                self.bias = self.bias_mu.data
            return F.linear(input, self.weight, self.bias)
        # Construct Gamma distribution for reparameterization sampling
        gamma_dist = distributions.Gamma(F.softplus(self.weight_spalpha), F.softplus(self.weight_spbeta))
        # Sample the scale parameters - note that spsigma**2 is inverse-Gamma distributed
        # and that spsigma is used as the standard deviation of a Normal
        self.weight_spsigma = torch.sqrt(1.0 / gamma_dist.rsample())
        epsilon_W = torch.normal(mean=torch.zeros_like(self.weight_mu), std=1.0)
        self.weight = self.weight_mu + self.weight_spsigma * epsilon_W
        if self.bias is not None:
            gamma_dist = distributions.Gamma(F.softplus(self.bias_spalpha), F.softplus(self.bias_spbeta))
            self.bias_spsigma = torch.sqrt(1.0 / gamma_dist.rsample())
            epsilon_b = torch.normal(mean=torch.zeros_like(self.bias_mu), std=1.0)
            self.bias = self.bias_mu + self.bias_spsigma * epsilon_b
        return F.linear(input, self.weight, self.bias)

    # def _kl_gaussian(self, p_mu, p_sigma, q_mu, q_sigma):
    #     var_ratio = (p_sigma / q_sigma).pow(2)
    #     t1 = ((p_mu - q_mu) / q_sigma).pow(2)
    #     return 0.5 * torch.sum((var_ratio + t1 - 1 - var_ratio.log()))
    #
    # def kl_divergence(self):
    #     mu = self.weight_mu
    #     sigma = F.softplus(self.weight_spsigma)
    #     mu0 = torch.zeros_like(mu)
    #     sigma0 = torch.ones_like(sigma) * self.sigma_prior
    #     kl = self._kl_gaussian(p_mu=mu, p_sigma=sigma, q_mu=mu0, q_sigma=sigma0)
    #     if self.bias is not None:
    #         mu = self.bias_mu
    #         sigma = F.softplus(self.bias_spsigma)
    #         mu0 = torch.zeros_like(mu)
    #         sigma0 = torch.ones_like(sigma) * self.sigma_prior
    #         kl += self._kl_gaussian(p_mu=mu, p_sigma=sigma, q_mu=mu0, q_sigma=sigma0)
    #     return kl

    def kl_divergence(self):
        # define posterior parameters
        mu = self.weight_mu
        sigma = self.weight_spsigma
        alpha = F.softplus(self.weight_spalpha)
        beta = F.softplus(self.weight_spbeta)
        # define prior parameters
        alpha0 = torch.ones_like(alpha) * self.alpha_prior
        beta0 = torch.ones_like(beta) * self.beta_prior
        q_w = distributions.Normal(mu, sigma)
        kl = (E_q_w_lambda + E_q_lambda - E_p_w_lambda - E_p_lambda).sum()
        if self.bias is not None:
            # define posterior parameters
            mu = self.bias_mu
            sigma = self.bias_spsigma
            alpha = F.softplus(self.bias_spalpha)
            beta = F.softplus(self.bias_spbeta)
            # define prior parameters
            alpha0 = torch.ones_like(alpha) * self.alpha_prior
            beta0 = torch.ones_like(beta) * self.beta_prior
            # Compute E_q( log p(lambda) )
            # Compute the KL divergence
            kl += (E_q_w_lambda + E_q_lambda - E_p_w_lambda - E_p_lambda).sum()
        return kl
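
    # NOTE: kl_divergence() above uses E_q_w_lambda, E_q_lambda, E_p_w_lambda and
    # E_p_lambda, which are defined in parts of the full 254-line file not shown in
    # this excerpt. The helper below is only an illustrative, hedged sketch of one
    # way such terms could be obtained: single-sample Monte Carlo estimates of
    # E_q[log q(w|lambda)], E_q[log q(lambda)], E_q[log p(w|lambda)] and
    # E_q[log p(lambda)] under the Normal/Gamma-precision model sampled in forward().
    # The name _mc_kl_terms and its exact form are assumptions, not original code.
    def _mc_kl_terms(self, mu, sigma, alpha, beta, alpha0, beta0, w):
        lam = 1.0 / sigma.pow(2)                       # sampled precision, lambda ~ Gamma(alpha, beta)
        q_lambda = distributions.Gamma(alpha, beta)    # variational q(lambda)
        p_lambda = distributions.Gamma(alpha0, beta0)  # prior p(lambda)
        q_w = distributions.Normal(mu, sigma)                    # q(w | lambda) = N(mu, 1/lambda)
        p_w = distributions.Normal(torch.zeros_like(mu), sigma)  # p(w | lambda) = N(0, 1/lambda)
        E_q_w_lambda = q_w.log_prob(w)        # log q(w | lambda) at the sampled w
        E_q_lambda = q_lambda.log_prob(lam)   # log q(lambda) at the sampled lambda
        E_p_w_lambda = p_w.log_prob(w)        # log p(w | lambda) at the sampled w
        E_p_lambda = p_lambda.log_prob(lam)   # log p(lambda) at the sampled lambda
        return E_q_w_lambda, E_q_lambda, E_p_w_lambda, E_p_lambda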

    def return_parameters(self):
        mu = self.weight_mu.data
        alpha = self.weight_spalpha.data
        beta = self.weight_spbeta.data
        sigma = self.weight_spsigma.data
        return mu, alpha, beta, sigma

    def extra_repr(self):
        return 'in_features={}, out_features={}, sigma_prior={}, sigma_init={}, bias={}'.format(
            self.in_features, self.out_features, self.sigma_prior, self.sigma_init, self.bias is not None
        )
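

# ---------------------------------------------------------------------------
# Illustrative usage (not part of the original file): a minimal sketch of
# driving the classes above on a toy regression batch. The KL term is left
# commented out here because kl_divergence() relies on expectation terms
# defined in parts of the full 254-line file that this excerpt does not show.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    model = BNN(input_size=5, hidden_sizes=[16], output_size=None, act_func="relu")
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    x = torch.randn(32, 5)
    y = torch.randn(32)
    for step in range(100):
        optimizer.zero_grad()
        pred = model(x, training=True)  # forward pass with sampled weights
        loss = F.mse_loss(pred, y)      # data-fit term (negative log-likelihood up to constants)
        # loss = loss + model.kl_divergence() / x.size(0)  # add the scaled KL term in the full file
        loss.backward()
        optimizer.step()
    # Predict with the variational means (no weight sampling)
    print(model(x, training=False).shape)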