
/Archive/Testing/Peter/Pytorch04/Models/BBB_HS_test.py

https://bitbucket.org/RamiroCope/thesis_repo
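# Mean-field Bayesian MLP (Bayes by Backprop) whose linear layers also learn
# per-input multiplicative scale variables; the (mu_sa, mu_sb) / (mu_alpha,
# mu_beta) pairs below appear to parameterise decomposed global and local
# scales of a horseshoe-style prior, consistent with the "BBB_HS" file name.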
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import torch.distributions as distributions
import numpy as np


#############################
## Bayesian Neural Network ##
#############################
class BNN(nn.Module):
    def __init__(self, input_size, hidden_sizes, output_size, act_func, prior_prec=1.0, prec_init=1.0, clip_var=None):
        super(type(self), self).__init__()
        self.input_size = input_size
        self.clip_var = clip_var
        sigma_prior = 1.0 / math.sqrt(prior_prec)
        sigma_init = 1.0 / math.sqrt(prec_init)
        if output_size:
            self.output_size = output_size
            self.squeeze_output = False
        else:
            self.output_size = 1
            self.squeeze_output = True
        self.act = torch.tanh if act_func == "tanh" else F.relu
        if len(hidden_sizes) == 0:
            self.hidden_layers = []
            self.output_layer = StochasticLinear(self.input_size, self.output_size, sigma_prior=sigma_prior, sigma_init=sigma_init, clip_var=clip_var)
        else:
            self.hidden_layers = nn.ModuleList([StochasticLinear(in_size, out_size, sigma_prior=sigma_prior, sigma_init=sigma_init, clip_var=clip_var)
                                                for in_size, out_size in zip([self.input_size] + hidden_sizes[:-1], hidden_sizes)])
            self.output_layer = StochasticLinear(hidden_sizes[-1], self.output_size, sigma_prior=sigma_prior, sigma_init=sigma_init, clip_var=clip_var)

    def forward(self, x, training=True):
        x = x.view(-1, self.input_size)
        out = x
        for layer in self.hidden_layers:
            out = self.act(layer(out, training=training))
        logits = self.output_layer(out, training=training)
        if self.squeeze_output:
            logits = torch.squeeze(logits)
        return logits

    def kl_divergence(self):
        kl = 0
        for layer in self.hidden_layers:
            kl += layer.kl_divergence()
        kl += self.output_layer.kl_divergence()
        return kl

    def return_params(self):
        mus = []
        sigmas = []
        weights = []
        for layer in self.hidden_layers:
            mu, sigma, weight = layer.return_parameters()
            mus += [mu.numpy()]
            sigmas += [F.softplus(sigma).numpy()]
            weights += [weight.numpy()]
        mu, sigma, weight = self.output_layer.return_parameters()
        mus += [mu.numpy()]
        sigmas += [F.softplus(sigma).numpy()]
        weights += [weight.numpy()]
        return mus, sigmas, weights

    def clip_vars(self):
        if self.clip_var:
            for i, layer in enumerate(self.hidden_layers):
                if i == 0:
                    layer.clip_variances()
###############################################
## Gaussian Mean-Field Linear Transformation ##
###############################################
class StochasticLinear(nn.Module):
    r"""Applies a stochastic linear transformation to the incoming data: :math:`y = Ax + b`

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input: :math:`(N, *, in\_features)` where :math:`*` means any number of
          additional dimensions
        - Output: :math:`(N, *, out\_features)` where all but the last dimension
          are the same shape as the input.

    Attributes:
        weight: the learnable weights of the module of shape
            `(out_features x in_features)`
        bias: the learnable bias of the module of shape `(out_features)`

    Examples::

        >>> m = StochasticLinear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
    """
    def __init__(self, in_features, out_features, sigma_prior=1.0, sigma_init=1.0, bias=True, clip_var=None):
        super(type(self), self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        print("out ", out_features)
        print("in ", in_features)
        self.clip_var = clip_var
        self.tau = sigma_prior
        self.sigma_init = sigma_init
        # M_w and Sigma_w
        self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
        self.weight_spsigma = Parameter(torch.Tensor(out_features, in_features))
        # Variational parameters
        self.mu_sa = Parameter(torch.Tensor(1))  # torch.Tensor(in_features)
        self.mu_sb = Parameter(torch.Tensor(1))
        self.sigma_sa = Parameter(torch.Tensor(1))
        self.sigma_sb = Parameter(torch.Tensor(1))
        self.mu_alpha = Parameter(torch.Tensor(in_features))
        self.mu_beta = Parameter(torch.Tensor(in_features))
        self.sigma_alpha = Parameter(torch.Tensor(in_features))
        self.sigma_beta = Parameter(torch.Tensor(in_features))
        if bias:
            self.bias = True
            self.bias_mu = Parameter(torch.Tensor(out_features))
            self.bias_spsigma = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
    def clip_variances(self):
        if self.clip_var:
            # print("clipping vars with %.3f" % (self.clip_var))
            self.weight_spsigma.data.clamp_(max=self.clip_var)
            self.bias_spsigma.data.clamp_(max=self.clip_var)

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight_mu.size(1))
        self.weight_mu.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias_mu.data.uniform_(-stdv, stdv)
        self.weight_spsigma.data.fill_(math.exp(-9))  # normal_(math.exp(-9),1e-2) #use (math.exp(-8),1e-2)
        if self.bias is not None:
            self.bias_spsigma.data.fill_(math.exp(-9))  # normal_(math.exp(-6),1e-2) #use (math.exp(-8),1e-2)
        self.mu_sa.data.uniform_(-7, -7)  # normal_(-6,1e-2) #use normal(-2,1e-2)
        self.mu_sb.data.uniform_(-7, -7)  # normal_(-6, 1e-2) #use normal(-2,1e-2)
        self.sigma_sa.data.normal_(math.exp(-9), 1e-2)
        self.sigma_sb.data.normal_(math.exp(-9), 1e-2)
        self.mu_alpha.data.uniform_(-3, -3)  # normal_(-6, 1e-2) #use normal(-2,1e-2)
        self.mu_beta.data.uniform_(-3, -3)  # normal_(-6, 1e-2) #use normal(-2,1e-2)
        self.sigma_alpha.data.normal_(math.exp(-9), 1e-5)
        self.sigma_beta.data.normal_(math.exp(-9), 1e-5)
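    # The training branch of forward() below samples the multiplicative scales Z
    # via a log-normal reparameterisation and then samples the pre-activations
    # from a Gaussian with mean M_h and variance V_h built from the weight
    # means/variances -- what looks like the local reparameterisation trick.
    # With training=False, a deterministic estimate of Z is used and only the
    # mean pre-activation M_h is returned.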
    def forward(self, input, training=True):
        batch_size = input.size()[0]
        if not training:
            # mu_s = 0.5*self.mu_sa + 0.5*self.mu_sb
            # log_s = mu_s
            # mu_z = 0.5*self.mu_alpha + 0.5*self.mu_beta + log_s
            # Z = torch.exp(mu_z.repeat(batch_size,1))
            # H = input*Z
            # M_h = F.linear(H, self.weight_mu, self.bias_mu)
            # self.weight = mu_z
            # return M_h
            mu_s = 0.5*self.mu_sa + 0.5*self.mu_sb
            log_s = mu_s
            mu_z = 0.5*self.mu_alpha + 0.5*self.mu_beta + log_s
            sigma_z = torch.sqrt(0.25*F.softplus(self.sigma_alpha) + 0.25*F.softplus(self.sigma_beta))
            Z = torch.exp(mu_z + 0.5*sigma_z)
            # Z = torch.exp(mu_z.repeat(batch_size,1) + 0.5*sigma_z.repeat(batch_size,1))
            H = input*Z
            M_h = F.linear(H, self.weight_mu, self.bias_mu)
            self.weight = Z
            return M_h
        else:
            mu_s = 0.5*self.mu_sa + 0.5*self.mu_sb
            sigma_s = torch.sqrt(0.25*F.softplus(self.sigma_sa) + 0.25*F.softplus(self.sigma_sb))
            # noise
            e = torch.normal(mean=torch.zeros_like(self.sigma_sa), std=1.0)
            E = torch.normal(mean=torch.zeros_like(input), std=1.0)
            log_s = mu_s + sigma_s*e
            mu_z = 0.5*self.mu_alpha + 0.5*self.mu_beta + log_s
            sigma_z = torch.sqrt(0.25*F.softplus(self.sigma_alpha) + 0.25*F.softplus(self.sigma_beta))
            Z = torch.exp(mu_z.repeat(batch_size, 1) + sigma_z.repeat(batch_size, 1)*E)
            H = input*Z
            M_h = F.linear(H, self.weight_mu, self.bias_mu)
            V_h = F.linear(H**2, F.softplus(self.weight_spsigma), F.softplus(self.bias_spsigma))
            E_final = torch.normal(mean=torch.zeros_like(V_h), std=1.0)
            # self.weight = mu_z
            return M_h + torch.sqrt(V_h)*E_final
    def kl_divergence(self):
        # KL(q(w|z)||p(w|z))
        KLD_element = -0.5*F.softplus(self.weight_spsigma).log() + 0.5*(F.softplus(self.weight_spsigma) + self.weight_mu**2) - 0.5
        KLD = torch.sum(KLD_element)
        # print(KLD_element.sum())
        if self.bias is not None:
            # KL bias
            KLD_element = -0.5*F.softplus(self.bias_spsigma).log() + 0.5*(F.softplus(self.bias_spsigma) + self.bias_mu**2) - 0.5
            KLD += torch.sum(KLD_element)
            # print(KLD_element.sum())
        # We change sign since the expressions below are the negative KL divergences (-D_KL)
        KL_sa = -torch.sum(math.log(self.tau) - 1.0/self.tau * torch.exp(self.mu_sa + 0.5*F.softplus(self.sigma_sa)) + 0.5*self.mu_sa + 0.5*F.softplus(self.sigma_sa).log() + 0.5 + 0.5*math.log(2))
        KL_sb = -torch.sum(-torch.exp(0.5*F.softplus(self.sigma_sb) - self.mu_sb) - 0.5*self.mu_sb + 0.5*F.softplus(self.sigma_sb).log() + 0.5 + 0.5*math.log(2))
        KL_alpha = -torch.sum(-torch.exp(self.mu_alpha + 0.5*F.softplus(self.sigma_alpha)) + 0.5*self.mu_alpha + 0.5*F.softplus(self.sigma_alpha).log() + 0.5 + 0.5*math.log(2))
        KL_beta = -torch.sum(-torch.exp(0.5*F.softplus(self.sigma_beta) - self.mu_beta) - 0.5*self.mu_beta + 0.5*F.softplus(self.sigma_beta).log() + 0.5 + 0.5*math.log(2))
        # print(KL_sa)
        # print(KL_sb)
        # print(KL_alpha)
        # print(KL_beta)
        # print("\n")
        KLD += KL_sa + KL_sb + KL_alpha + KL_beta
        return KLD
    def return_parameters(self):
        mu = self.weight_mu.data
        sigma = self.weight_spsigma.data
        weight = self.weight.data  # self.weight is only set by a forward pass with training=False
        # print(weight)
        return mu, sigma, weight

    def extra_repr(self):
        return 'in_features={}, out_features={}, sigma_prior={}, sigma_init={}, bias={}'.format(
            self.in_features, self.out_features, self.tau, self.sigma_init, self.bias is not None
        )
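

# Minimal usage sketch (not part of the original class definitions above):
# fits the BNN to a toy 1-D regression problem by minimising a single-sample
# negative ELBO, i.e. a Gaussian negative log-likelihood (unit noise, up to a
# constant) plus the layer-wise KL term from kl_divergence(). All data,
# architecture and optimiser settings here are illustrative assumptions.
if __name__ == "__main__":
    torch.manual_seed(0)

    # Toy regression data
    x = torch.linspace(-1, 1, 64).view(-1, 1)
    y = torch.sin(3 * x) + 0.1 * torch.randn_like(x)

    # output_size=None triggers the single-output / squeeze branch of BNN
    model = BNN(input_size=1, hidden_sizes=[32], output_size=None,
                act_func="relu", prior_prec=1.0, prec_init=1.0)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

    for step in range(200):
        optimizer.zero_grad()
        pred = model(x, training=True)                      # one stochastic forward pass
        nll = 0.5 * torch.sum((pred - y.squeeze()) ** 2)    # Gaussian NLL with unit noise
        loss = nll + model.kl_divergence()                  # negative ELBO (single MC sample)
        loss.backward()
        optimizer.step()

    # Deterministic prediction using the mean scales (training=False branch)
    with torch.no_grad():
        pred_mean = model(x, training=False)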