
/Archive/Testing/Peter/Pytorch04/Models/BBB_HS_linear.py

https://bitbucket.org/RamiroCope/thesis_repo
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import torch.distributions as distributions
import numpy as np


#############################
## Bayesian Neural Network ##
#############################
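# A fully connected Bayesian neural network built from StochasticLinear layers.
# Each layer carries its own variational parameters; kl_divergence() sums the
# per-layer KL terms so the caller can add them to the data-fit term of a
# variational objective.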
class BNN(nn.Module):
    def __init__(self, input_size,
                 hidden_sizes,
                 output_size=None,
                 act_func=None,
                 prior_prec=1.0,
                 prec_init=1.0,
                 clip_var=None):
        super(BNN, self).__init__()
        self.input_size = input_size
        self.clip_var = clip_var
        sigma_prior = 1.0 / math.sqrt(prior_prec)
        sigma_init = 1.0 / math.sqrt(prec_init)
        if output_size:
            self.output_size = output_size
            self.squeeze_output = False
        else:
            self.output_size = 1
            self.squeeze_output = True
        self.act = torch.tanh if act_func == "tanh" else F.relu
        if len(hidden_sizes) == 0:
            self.hidden_layers = []
            self.output_layer = StochasticLinear(self.input_size,
                                                 self.output_size,
                                                 sigma_prior=sigma_prior,
                                                 sigma_init=sigma_init,
                                                 clip_var=clip_var)
        else:
            self.hidden_layers = nn.ModuleList(
                [StochasticLinear(in_size, out_size,
                                  sigma_prior=sigma_prior,
                                  sigma_init=sigma_init,
                                  clip_var=clip_var)
                 for in_size, out_size in zip([self.input_size] + hidden_sizes[:-1],
                                              hidden_sizes)])
            self.output_layer = StochasticLinear(hidden_sizes[-1],
                                                 self.output_size,
                                                 sigma_prior=sigma_prior,
                                                 sigma_init=sigma_init,
                                                 clip_var=clip_var)

    def forward(self, x, training=False):
        # x = x.view(-1, self.input_size)
        out = x
        # Note: self.act is set in __init__ but not applied here, so hidden
        # layers are composed without a nonlinearity in between.
        for layer in self.hidden_layers:
            out = layer(out, training=training)
        logits = self.output_layer(out, training=training)
        if self.squeeze_output:
            logits = torch.squeeze(logits)
        return logits

    def kl_divergence(self):
        kl = 0
        for layer in self.hidden_layers:
            kl += layer.kl_divergence()
        kl += self.output_layer.kl_divergence()
        return kl

    def return_params(self):
        mus = []
        sigmas = []
        weights = []
        for layer in self.hidden_layers:
            mu, sigma, weight = layer.return_parameters()
            mus += [mu.numpy()]
            sigmas += [F.softplus(sigma).numpy()]
            weights += [weight.numpy()]
        mu, sigma, weight = self.output_layer.return_parameters()
        mus += [mu.numpy()]
        sigmas += [F.softplus(sigma).numpy()]
        weights += [weight.numpy()]
        print(weight.shape)
        print(mu.shape)
        return mus, sigmas, weights

    def clip_vars(self):
        # Only the first hidden layer's variances are clipped.
        if self.clip_var:
            for i, layer in enumerate(self.hidden_layers):
                if i == 0:
                    layer.clip_variances()


###############################################
## Gaussian Mean-Field Linear Transformation ##
###############################################

class StochasticLinear(nn.Module):
    r"""Applies a stochastic linear transformation to the incoming data: :math:`y = Ax + b`

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``False``

    Shape:
        - Input: :math:`(N, *, in\_features)` where :math:`*` means any number of
          additional dimensions
        - Output: :math:`(N, *, out\_features)` where all but the last dimension
          are the same shape as the input.

    Attributes:
        weight: the learnable weights of the module of shape
            `(out_features x in_features)`
        bias: the learnable bias of the module of shape `(out_features)`

    Examples::

        >>> m = StochasticLinear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input, training=True)
        >>> print(output.size())
    """
    def __init__(self, in_features, out_features, sigma_prior=1.0, sigma_init=1.0,
                 bias=False, clip_var=None):
        super(StochasticLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.clip_var = clip_var
        self.tau = sigma_prior
        self.sigma_init = sigma_init
        # M_w and Sigma_w
        self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
        self.weight_spsigma = Parameter(torch.Tensor(out_features, in_features))
        # Variational parameters of the scales
        self.mu_sa = Parameter(torch.Tensor(1))
        self.mu_sb = Parameter(torch.Tensor(1))
        self.sigma_sa = Parameter(torch.Tensor(1))
        self.sigma_sb = Parameter(torch.Tensor(1))
        self.mu_alpha = Parameter(torch.Tensor(in_features))
        self.mu_beta = Parameter(torch.Tensor(in_features))
        self.sigma_alpha = Parameter(torch.Tensor(in_features))
        self.sigma_beta = Parameter(torch.Tensor(in_features))
        if bias:
            self.bias = True
            self.bias_mu = Parameter(torch.Tensor(out_features))
            self.bias_spsigma = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def clip_variances(self):
        if self.clip_var:
            self.weight_spsigma.data.clamp_(max=self.clip_var)
            if self.bias is not None:
                self.bias_spsigma.data.clamp_(max=self.clip_var)

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight_mu.size(1))
        self.weight_mu.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias_mu.data.uniform_(-stdv, stdv)
        self.weight_spsigma.data.normal_(math.exp(-3), 1e-2)
        if self.bias is not None:
            self.bias_spsigma.data.normal_(math.exp(-3), 1e-2)
        self.mu_sa.data.normal_(-1, 1e-2)
        self.mu_sb.data.normal_(-1, 1e-2)
        self.sigma_sa.data.normal_(math.exp(-3), 1e-2)
        self.sigma_sb.data.normal_(math.exp(-3), 1e-2)
        self.mu_alpha.data.normal_(0, 1e-2)
        self.mu_beta.data.normal_(0, 1e-2)
        self.sigma_alpha.data.normal_(math.exp(-3), 1e-2)
        self.sigma_beta.data.normal_(math.exp(-3), 1e-2)
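
    # forward() samples the layer output rather than the weights themselves: the input
    # is rescaled by sampled per-feature scales Z, and the pre-activation is then drawn
    # from a Gaussian with mean F.linear(H, weight_mu) and variance
    # F.linear(H**2, softplus(weight_spsigma)), i.e. a local-reparameterization-style
    # sampling of activations. With training=False the noise terms are dropped and the
    # posterior means of the log-scales are used instead.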
    def forward(self, input, training=True):
        batch_size = input.size()[0]
        if not training:
            mu_s = 0.5 * self.mu_sa + 0.5 * self.mu_sb
            log_s = mu_s
            mu_z = 0.5 * self.mu_alpha + 0.5 * self.mu_beta + log_s
            Z = torch.exp(mu_z.repeat(batch_size, 1))
            H = input * Z
            M_h = F.linear(H, self.weight_mu)
            self.weight = mu_z
            return M_h
        else:
            mu_s = 0.5 * self.mu_sa + 0.5 * self.mu_sb
            sigma_s = torch.sqrt(0.25 * F.softplus(self.sigma_sa) + 0.25 * F.softplus(self.sigma_sb))
            # noise
            e = torch.normal(mean=torch.zeros_like(self.sigma_sa), std=1.0)
            E = torch.normal(mean=torch.zeros_like(input), std=1.0)
            log_s = mu_s + sigma_s * e
            mu_z = 0.5 * self.mu_alpha + 0.5 * self.mu_beta + log_s
            sigma_z = torch.sqrt(0.25 * F.softplus(self.sigma_alpha) + 0.25 * F.softplus(self.sigma_beta))
            Z = torch.exp(mu_z.repeat(batch_size, 1) + sigma_z.repeat(batch_size, 1) * E)
            H = input * Z
            M_h = F.linear(H, self.weight_mu)
            V_h = F.linear(H**2, F.softplus(self.weight_spsigma))
            E_final = torch.normal(mean=torch.zeros_like(V_h), std=1.0)
            self.weight = mu_z
            return M_h + torch.sqrt(V_h) * E_final
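
    # kl_divergence() returns the closed-form KL between the mean-field Gaussian
    # posterior over the weights (and bias, if present) and a standard-normal prior,
    # plus the KL terms for the variational posteriors over the scales
    # (sa, sb, alpha, beta).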
    def kl_divergence(self):
        # KL(q(w|z) || p(w|z))
        KLD_element = (-0.5 * F.softplus(self.weight_spsigma).log()
                       + 0.5 * (F.softplus(self.weight_spsigma) + self.weight_mu**2) - 0.5)
        KLD = torch.sum(KLD_element)
        if self.bias is not None:
            # KL for the bias
            KLD_element = (-0.5 * F.softplus(self.bias_spsigma).log()
                           + 0.5 * (F.softplus(self.bias_spsigma) + self.bias_mu**2) - 0.5)
            KLD += torch.sum(KLD_element)
        # The summands below are the negative KL divergences (-D_KL), hence the
        # leading minus signs.
        KL_sa = -torch.sum(math.log(self.tau)
                           - 1.0 / self.tau * torch.exp(self.mu_sa + 0.5 * F.softplus(self.sigma_sa))
                           + 0.5 * self.mu_sa + 0.5 * F.softplus(self.sigma_sa).log()
                           + 0.5 + 0.5 * math.log(2))
        KL_sb = -torch.sum(-torch.exp(0.5 * F.softplus(self.sigma_sb) - self.mu_sb)
                           - 0.5 * self.mu_sb + 0.5 * F.softplus(self.sigma_sb).log()
                           + 0.5 + 0.5 * math.log(2))
        KL_alpha = -torch.sum(-torch.exp(self.mu_alpha + 0.5 * F.softplus(self.sigma_alpha))
                              + 0.5 * self.mu_alpha + 0.5 * F.softplus(self.sigma_alpha).log()
                              + 0.5 + 0.5 * math.log(2))
        KL_beta = -torch.sum(-torch.exp(0.5 * F.softplus(self.sigma_beta) - self.mu_beta)
                             - 0.5 * self.mu_beta + 0.5 * F.softplus(self.sigma_beta).log()
                             + 0.5 + 0.5 * math.log(2))
        KLD += KL_sa + KL_sb + KL_alpha + KL_beta
        return KLD

    def return_parameters(self):
        mu = self.weight_mu.data
        sigma = self.weight_spsigma.data
        weight = self.weight.data
        return mu, sigma, weight

    def extra_repr(self):
        return 'in_features={}, out_features={}, sigma_prior={}, sigma_init={}, bias={}'.format(
            self.in_features, self.out_features, self.tau, self.sigma_init, self.bias is not None
        )
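

# ---------------------------------------------------------------------------
# Illustrative usage sketch -- not part of the original file or repository.
# It fits the (hidden-layer-free) linear Bayesian model to synthetic data by
# minimizing the negative ELBO: a Gaussian negative log-likelihood plus the
# model's KL term. All hyperparameters below are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    N, D = 128, 10
    X = torch.randn(N, D)
    true_w = torch.randn(D)
    y = X @ true_w + 0.1 * torch.randn(N)

    model = BNN(input_size=D, hidden_sizes=[], output_size=None)  # scalar output, squeezed
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

    for epoch in range(200):
        optimizer.zero_grad()
        pred = model(X, training=True)              # stochastic forward pass
        nll = 0.5 * torch.sum((pred - y) ** 2)      # Gaussian NLL up to additive constants
        loss = nll + model.kl_divergence()          # negative ELBO (full batch)
        loss.backward()
        optimizer.step()

    with torch.no_grad():
        pred_mean = model(X, training=False)        # deterministic pass with posterior means
        print('final train MSE:', torch.mean((pred_mean - y) ** 2).item())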