/Archive/Testing/Peter/Pytorch04/Models/BBB_t.py

https://bitbucket.org/RamiroCope/thesis_repo

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import torch.distributions as distributions
import numpy as np
from scipy.special import gammaln, psi

#############################
## Bayesian Neural Network ##
#############################
class BNN(nn.Module):
    def __init__(self, input_size,
                 hidden_sizes,
                 output_size,
                 act_func,
                 prior_prec=1.0,
                 prec_init=1.0,
                 df_prior=5.0,
                 df_init=5.0):
        super(BNN, self).__init__()
        self.input_size = input_size
        sigma_prior = 1.0 / math.sqrt(prior_prec)
        sigma_init = 1.0 / math.sqrt(prec_init)
        if output_size:
            self.output_size = output_size
            self.squeeze_output = False
        else:
            self.output_size = 1
            self.squeeze_output = True
        self.act = F.tanh if act_func == "tanh" else F.relu
        if len(hidden_sizes) == 0:
            # No hidden layers: a single stochastic linear map from input to output.
            self.hidden_layers = []
            self.output_layer = StochasticLinear(self.input_size,
                                                 self.output_size,
                                                 sigma_prior=sigma_prior,
                                                 sigma_init=sigma_init,
                                                 df_prior=df_prior,
                                                 df_init=df_init)
        else:
            self.hidden_layers = nn.ModuleList(
                [StochasticLinear(in_size, out_size,
                                  sigma_prior=sigma_prior, df_prior=df_prior,
                                  sigma_init=sigma_init, df_init=df_init)
                 for in_size, out_size in zip([self.input_size] + hidden_sizes[:-1], hidden_sizes)])
            self.output_layer = StochasticLinear(hidden_sizes[-1], self.output_size,
                                                 sigma_prior=sigma_prior, df_prior=df_prior,
                                                 sigma_init=sigma_init, df_init=df_init)

    def forward(self, x, training=True):
        x = x.view(-1, self.input_size)
        out = x
        for layer in self.hidden_layers:
            out = self.act(layer(out, training=training))
        logits = self.output_layer(out, training=training)
        if self.squeeze_output:
            logits = torch.squeeze(logits)
        return logits

    def kl_divergence(self):
        # Sum the per-layer (Monte Carlo) KL contributions of all stochastic layers.
        kl = 0
        for layer in self.hidden_layers:
            kl += layer.kl_divergence()
        kl += self.output_layer.kl_divergence()
        return kl

    def return_params(self):
        mus = []
        sigmas = []
        dfs = []
        weights = []
        for layer in self.hidden_layers:
            mu, sigma, df, weight = layer.return_parameters()
            mus += [mu.numpy()]
            sigmas += [sigma.numpy()]
            dfs += [df.numpy()]
            weights += [weight.numpy()]
        mu, sigma, df, weight = self.output_layer.return_parameters()
        mus += [mu.numpy()]
        sigmas += [sigma.numpy()]
        dfs += [df.numpy()]
        weights += [weight.numpy()]
        return mus, sigmas, dfs, weights

################################################
## Student-t Mean-Field Linear Transformation ##
################################################
class StochasticLinear(nn.Module):
    """Applies a stochastic linear transformation to the incoming data: :math:`y = Ax + b`,
    where the weights (and bias) are sampled from a factorized Student-t posterior.

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input: :math:`(N, *, in\_features)` where :math:`*` means any number of
          additional dimensions
        - Output: :math:`(N, *, out\_features)` where all but the last dimension
          are the same shape as the input.

    Attributes:
        weight: the learnable weights of the module of shape
            `(out_features x in_features)`
        bias: the learnable bias of the module of shape `(out_features)`

    Examples::

        >>> m = StochasticLinear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
    """
    def __init__(self, in_features,
                 out_features,
                 sigma_prior=1.0,
                 df_prior=5.0,
                 sigma_init=1.0,
                 df_init=5.0,
                 bias=True):
        super(StochasticLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.sigma_prior = sigma_prior
        self.df_prior = df_prior
        self.sigma_init = sigma_init
        self.df_init = df_init
        # Variational parameters: per-weight mean, softplus-parameterized scale (spsigma)
        # and softplus-parameterized degrees of freedom (spdf).
        self.weight_spsigma = Parameter(torch.Tensor(out_features, in_features))
        self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
        self.weight_spdf = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = True
            self.bias_spsigma = Parameter(torch.Tensor(out_features))
            self.bias_mu = Parameter(torch.Tensor(out_features))
            self.bias_spdf = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight_mu.size(1))
        self.weight_mu.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias_mu.data.uniform_(-stdv, stdv)
        # Scale is initialized via the inverse softplus so that
        # softplus(weight_spsigma) == sigma_init at the start of training.
        # The degrees of freedom are filled with df_init directly, so the effective
        # initial value is softplus(df_init).
        self.weight_spsigma.data.fill_(math.log(math.exp(self.sigma_init) - 1))
        self.weight_spdf.data.fill_(self.df_init)
        if self.bias is not None:
            self.bias_spsigma.data.fill_(math.log(math.exp(self.sigma_init) - 1))
            self.bias_spdf.data.fill_(self.df_init)
    def forward(self, input, training=True):
        if not training:
            # At evaluation time, use the posterior means instead of sampling.
            self.weight = self.weight_mu.data
            if self.bias is not None:
                self.bias = self.bias_mu.data
            return F.linear(input, self.weight, self.bias)
        # At training time, sample the weights from the Student-t posterior with the
        # reparameterization trick (rsample) so gradients reach the variational parameters.
        t_dist = distributions.StudentT(F.softplus(self.weight_spdf),
                                        loc=self.weight_mu,
                                        scale=F.softplus(self.weight_spsigma))
        sample = t_dist.rsample()
        self.weight = sample
        if self.bias is not None:
            t_dist = distributions.StudentT(F.softplus(self.bias_spdf),
                                            loc=self.bias_mu,
                                            scale=F.softplus(self.bias_spsigma))
            sample = t_dist.rsample()
            self.bias = sample
        return F.linear(input, self.weight, self.bias)
    def kl_divergence(self):
        # Single-sample Monte Carlo estimate of KL(q || p):
        # log q(w) - log p(w), evaluated at the weights sampled in forward().
        sample = self.weight
        # posterior parameters
        mu = self.weight_mu
        sigma = F.softplus(self.weight_spsigma)
        df = F.softplus(self.weight_spdf)
        # prior parameters
        mu0 = torch.zeros_like(mu)
        sigma0 = torch.ones_like(sigma) * self.sigma_prior
        df0 = torch.ones_like(df) * self.df_prior
        q_w = distributions.StudentT(df, loc=mu, scale=sigma)
        p_w = distributions.StudentT(df0, loc=mu0, scale=sigma0)
        kl = (q_w.log_prob(sample) - p_w.log_prob(sample)).sum()
        if self.bias is not None:
            # same single-sample estimate for the bias
            sample = self.bias
            mu = self.bias_mu
            sigma = F.softplus(self.bias_spsigma)
            df = F.softplus(self.bias_spdf)
            mu0 = torch.zeros_like(mu)
            sigma0 = torch.ones_like(sigma) * self.sigma_prior
            df0 = torch.ones_like(df) * self.df_prior
            q_w = distributions.StudentT(df, loc=mu, scale=sigma)
            p_w = distributions.StudentT(df0, loc=mu0, scale=sigma0)
            kl += (q_w.log_prob(sample) - p_w.log_prob(sample)).sum()
        return kl
    def return_parameters(self):
        # Assumes forward() has been called at least once so that self.weight exists.
        mu = self.weight_mu.data
        sigma = F.softplus(self.weight_spsigma.data)
        df = self.weight_spdf.data
        weight = self.weight.data
        return mu, sigma, df, weight

    def extra_repr(self):
        return 'in_features={}, out_features={}, sigma_prior={}, sigma_init={}, bias={}'.format(
            self.in_features, self.out_features, self.sigma_prior, self.sigma_init, self.bias is not None
        )
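
For context, here is a minimal sketch of how this module might be trained in the usual Bayes-by-Backprop fashion: the loss is the negative log-likelihood of one sampled forward pass plus the single-sample KL estimate returned by kl_divergence(). The dataset, layer sizes, learning rate, and number of epochs below are illustrative placeholders, not part of the original file.

# Usage sketch (not part of BBB_t.py); shapes and hyperparameters are assumptions.
torch.manual_seed(0)
N, D, K = 256, 10, 3                          # hypothetical dataset size, input dim, classes
X = torch.randn(N, D)
y = torch.randint(0, K, (N,))

model = BNN(input_size=D, hidden_sizes=[32], output_size=K, act_func="relu")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(10):
    optimizer.zero_grad()
    logits = model(X, training=True)          # forward pass with sampled Student-t weights
    nll = F.cross_entropy(logits, y, reduction='sum')
    loss = nll + model.kl_divergence()        # negative ELBO, single-sample KL estimate
    loss.backward()
    optimizer.step()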