
/Main/Models/BBB_t_linear.py

https://bitbucket.org/RamiroCope/thesis_repo
Python | 195 lines | 137 code | 30 blank | 28 comment
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import torch.distributions as distributions


#############################
## Bayesian Neural Network ##
#############################
class BNN(nn.Module):
    def __init__(self, input_size,
                 hidden_sizes,
                 output_size=None,
                 prior_prec=1.0,
                 prec_init=1.0,
                 df_prior=5.0,
                 df_init=5.0):
        super(type(self), self).__init__()
        self.input_size = input_size
        # Prior and initial posterior scales are specified through precisions
        sigma_prior = 1.0 / math.sqrt(prior_prec)
        sigma_init = 1.0 / math.sqrt(prec_init)
        if output_size:
            self.output_size = output_size
            self.squeeze_output = False
        else:
            self.output_size = 1
            self.squeeze_output = True
        if len(hidden_sizes) == 0:
            # No hidden layers: a single stochastic linear map from input to output
            self.hidden_layers = []
            self.output_layer = StochasticLinear(self.input_size,
                                                 self.output_size,
                                                 sigma_prior=sigma_prior,
                                                 sigma_init=sigma_init,
                                                 df_prior=df_prior,
                                                 df_init=df_init)
        else:
            self.hidden_layers = nn.ModuleList(
                [StochasticLinear(in_size, out_size,
                                  sigma_prior=sigma_prior, df_prior=df_prior,
                                  sigma_init=sigma_init, df_init=df_init)
                 for in_size, out_size in zip([self.input_size] + hidden_sizes[:-1], hidden_sizes)])
            self.output_layer = StochasticLinear(hidden_sizes[-1], self.output_size,
                                                 sigma_prior=sigma_prior, df_prior=df_prior,
                                                 sigma_init=sigma_init, df_init=df_init)

    def forward(self, x):
        out = x
        for layer in self.hidden_layers:
            out = layer(out)
        logits = self.output_layer(out)
        if self.squeeze_output:
            logits = torch.squeeze(logits)
        return logits

    def kl_divergence(self):
        # Sum the KL contributions of every stochastic layer
        kl = 0
        for layer in self.hidden_layers:
            kl += layer.kl_divergence()
        kl += self.output_layer.kl_divergence()
        return kl

    def return_params(self):
        # Collect the variational parameters and the last sampled weights, layer by layer
        mus = []
        sigmas = []
        dfs = []
        weights = []
        for layer in self.hidden_layers:
            mu, sigma, df, weight = layer.return_parameters()
            mus += [mu.numpy()]
            sigmas += [sigma.numpy()]
            dfs += [df.numpy()]
            weights += [weight.numpy()]
        mu, sigma, df, weight = self.output_layer.return_parameters()
        mus += [mu.numpy()]
        sigmas += [sigma.numpy()]
        dfs += [df.numpy()]
        weights += [weight.numpy()]
        return mus, sigmas, dfs, weights
################################################
## Student-t Mean-Field Linear Transformation ##
################################################
class StochasticLinear(nn.Module):
    """Applies a stochastic linear transformation to the incoming data: :math:`y = Ax + b`

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``False``

    Shape:
        - Input: :math:`(N, *, in\_features)` where :math:`*` means any number of
          additional dimensions
        - Output: :math:`(N, *, out\_features)` where all but the last dimension
          are the same shape as the input.

    Attributes:
        weight: the learnable weights of the module of shape
            `(out_features x in_features)`
        bias: the learnable bias of the module of shape `(out_features)`

    Examples::

        >>> m = StochasticLinear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
    """
    def __init__(self, in_features,
                 out_features,
                 sigma_prior=1.0,
                 df_prior=5.0,
                 sigma_init=1.0,
                 df_init=5.0,
                 bias=False):
        super(type(self), self).__init__()
        self.count = 0
        self.in_features = in_features
        self.out_features = out_features
        self.sigma_prior = sigma_prior
        self.df_prior = df_prior
        self.sigma_init = sigma_init
        self.df_init = df_init
        # Variational parameters: mean, plus softplus-parameterized scale and degrees of freedom
        self.weight_spsigma = Parameter(torch.Tensor(out_features, in_features))
        self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
        self.weight_spdf = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = True
            self.bias_spsigma = Parameter(torch.Tensor(out_features))
            self.bias_mu = Parameter(torch.Tensor(out_features))
            self.bias_spdf = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        # Means use the usual nn.Linear fan-in initialization; scales and
        # degrees of freedom start at their configured initial values
        stdv = 1. / math.sqrt(self.weight_mu.size(1))
        self.weight_mu.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias_mu.data.uniform_(-stdv, stdv)
        self.weight_spsigma.data.fill_(self.sigma_init)
        self.weight_spdf.data.fill_(self.df_init)
        if self.bias is not None:
            self.bias_spsigma.data.fill_(self.sigma_init)
            self.bias_spdf.data.fill_(self.df_init)
    def forward(self, input):
        # Construct a Student-t distribution over the weights and sample it with the
        # reparameterization trick (rsample), so gradients flow to mu, sigma and df
        t_dist = distributions.StudentT(F.softplus(self.weight_spdf),
                                        loc=self.weight_mu,
                                        scale=F.softplus(self.weight_spsigma))
        self.weight = t_dist.rsample()
        if self.bias is not None:
            t_dist = distributions.StudentT(F.softplus(self.bias_spdf),
                                            loc=self.bias_mu,
                                            scale=F.softplus(self.bias_spsigma))
            self.bias = t_dist.rsample()
        return F.linear(input, self.weight, self.bias)
    def kl_divergence(self):
        # KL term computed between mean-field Normal distributions: q built from the
        # variational mean and scale, against a zero-mean Normal prior with scale sigma_prior
        mu = self.weight_mu
        sigma = F.softplus(self.weight_spsigma) + 1e-5
        mu0 = torch.zeros_like(mu)
        sigma0 = torch.ones_like(sigma) * self.sigma_prior
        q = distributions.Normal(mu, sigma)
        p = distributions.Normal(mu0, sigma0)
        kl = distributions.kl_divergence(q, p).sum()
        if self.bias is not None:
            mu = self.bias_mu
            sigma = F.softplus(self.bias_spsigma) + 1e-5
            mu0 = torch.zeros_like(mu)
            sigma0 = torch.ones_like(sigma) * self.sigma_prior
            q = distributions.Normal(mu, sigma)
            p = distributions.Normal(mu0, sigma0)
            kl += distributions.kl_divergence(q, p).sum()
        return kl
    def return_parameters(self):
        mu = self.weight_mu.data
        sigma = F.softplus(self.weight_spsigma.data)
        df = F.softplus(self.weight_spdf.data)
        weight = self.weight.data
        return mu, sigma, df, weight

    def extra_repr(self):
        return 'in_features={}, out_features={}, sigma_prior={}, sigma_init={}, bias={}'.format(
            self.in_features, self.out_features, self.sigma_prior, self.sigma_init, self.bias is not None
        )
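
For reference, below is a minimal sketch of how this module might be trained with the (negative) evidence lower bound: a Monte Carlo data-fit term plus the kl_divergence() regularizer defined above. The synthetic data, layer sizes, loss choice and hyperparameters are illustrative assumptions, not part of the repository; the sketch reuses the imports at the top of this file.

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; not part of the original file).
# Synthetic regression data, layer sizes and hyperparameters are assumptions.
if __name__ == "__main__":
    torch.manual_seed(0)
    X = torch.randn(256, 4)                          # hypothetical inputs
    y = X.sum(dim=1) + 0.1 * torch.randn(256)        # hypothetical regression targets

    model = BNN(input_size=4, hidden_sizes=[16, 16])         # scalar output -> squeezed
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

    for step in range(1000):
        optimizer.zero_grad()
        preds = model(X)                             # one Monte Carlo weight sample per call
        nll = F.mse_loss(preds, y, reduction='sum')  # Gaussian log-likelihood up to a constant
        loss = nll + model.kl_divergence()           # negative ELBO: data term + KL regularizer
        loss.backward()
        optimizer.step()

When training on minibatches, the KL term is commonly down-weighted by the fraction of the dataset seen per batch (standard Bayes-by-Backprop practice); the full-batch sketch above omits that scaling.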