
/Archive/Testing/Peter/Pytorch04/Models/BBB.py

https://bitbucket.org/RamiroCope/thesis_repo
Python | 167 lines | 117 code | 21 blank | 29 comment | 19 complexity | 6bf594b9cae0af274f2e52f6783838a9 MD5
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import torch.distributions as distributions
import numpy as np


#############################
## Bayesian Neural Network ##
#############################
class BNN(nn.Module):
    def __init__(self, input_size, hidden_sizes, output_size, act_func, prior_prec=1.0, prec_init=1.0):
        super(type(self), self).__init__()
        self.input_size = input_size
        # Convert prior precision and initial posterior precision to standard deviations.
        sigma_prior = 1.0 / math.sqrt(prior_prec)
        sigma_init = 1.0 / math.sqrt(prec_init)
        if output_size:
            self.output_size = output_size
            self.squeeze_output = False
        else:
            # A falsy output_size means a scalar output that gets squeezed in forward().
            self.output_size = 1
            self.squeeze_output = True
        self.act = F.tanh if act_func == "tanh" else F.relu
        if len(hidden_sizes) == 0:
            self.hidden_layers = []
            self.output_layer = StochasticLinear(self.input_size, self.output_size, sigma_prior=sigma_prior, sigma_init=sigma_init)
        else:
            self.hidden_layers = nn.ModuleList([StochasticLinear(in_size, out_size, sigma_prior=sigma_prior, sigma_init=sigma_init)
                                                for in_size, out_size in zip([self.input_size] + hidden_sizes[:-1], hidden_sizes)])
            self.output_layer = StochasticLinear(hidden_sizes[-1], self.output_size, sigma_prior=sigma_prior, sigma_init=sigma_init)

    def forward(self, x, training=True):
        x = x.view(-1, self.input_size)
        out = x
        for layer in self.hidden_layers:
            out = self.act(layer(out, training=training))
        logits = self.output_layer(out, training=training)
        if self.squeeze_output:
            logits = torch.squeeze(logits)
        return logits

    def kl_divergence(self):
        # Total KL(q || p) accumulated over all stochastic layers.
        kl = 0
        for layer in self.hidden_layers:
            kl += layer.kl_divergence()
        kl += self.output_layer.kl_divergence()
        return kl

    def return_params(self):
        mus = []
        sigmas = []
        weights = []
        for layer in self.hidden_layers:
            mu, sigma, weight = layer.return_parameters()
            mus += [mu.numpy()]
            sigmas += [F.softplus(sigma).numpy()]
            weights += [weight.numpy()]
        mu, sigma, weight = self.output_layer.return_parameters()
        mus += [mu.numpy()]
        sigmas += [F.softplus(sigma).numpy()]
        weights += [weight.numpy()]
        return mus, sigmas, weights


###############################################
## Gaussian Mean-Field Linear Transformation ##
###############################################
class StochasticLinear(nn.Module):
    r"""Applies a stochastic linear transformation to the incoming data: :math:`y = Ax + b`

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input: :math:`(N, *, in\_features)` where :math:`*` means any number of
          additional dimensions
        - Output: :math:`(N, *, out\_features)` where all but the last dimension
          are the same shape as the input.

    Attributes:
        weight: the learnable weights of the module of shape
            `(out_features x in_features)`
        bias: the learnable bias of the module of shape `(out_features)`

    Examples::

        >>> m = StochasticLinear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
    """
    def __init__(self, in_features, out_features, sigma_prior=1.0, sigma_init=1.0, bias=True):
        super(type(self), self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.sigma_prior = sigma_prior
        self.sigma_init = sigma_init
        # Variational posterior parameters: mean and pre-softplus scale for each weight.
        self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
        self.weight_spsigma = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = True
            self.bias_mu = Parameter(torch.Tensor(out_features))
            self.bias_spsigma = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight_mu.size(1))
        self.weight_mu.data.uniform_(-0.2, 0.2)  # (-stdv, stdv)
        if self.bias is not None:
            self.bias_mu.data.uniform_(-0.2, 0.2)  # (-stdv, stdv)
        # self.weight_spsigma.data.uniform_(-5, -4)  # (math.log(math.exp(self.sigma_init/4.0) - 1), math.log(math.exp(self.sigma_init/2.0) - 1))
        self.weight_spsigma.data.fill_(math.log(math.exp(self.sigma_init) - 1))
        if self.bias is not None:
            # self.bias_spsigma.data.uniform_(-5, -4)  # (math.log(math.exp(self.sigma_init/4.0) - 1), math.log(math.exp(self.sigma_init/2.0) - 1))
            self.bias_spsigma.data.fill_(math.log(math.exp(self.sigma_init) - 1))

    def forward(self, input, training=True):
        if not training:
            # At evaluation time use the posterior means instead of a sample.
            self.weight = self.weight_mu.data
            if self.bias is not None:
                self.bias = self.bias_mu.data
            return F.linear(input, self.weight, self.bias)
        # Reparameterization trick: w = mu + softplus(rho) * eps with eps ~ N(0, 1),
        # so gradients flow to both the mean and the scale parameters.
        epsilon_W = torch.normal(mean=torch.zeros_like(self.weight_mu), std=1.0)
        self.weight = self.weight_mu + (F.softplus(self.weight_spsigma) + 1e-5) * epsilon_W
        if self.bias is not None:
            epsilon_b = torch.normal(mean=torch.zeros_like(self.bias_mu), std=1.0)
            self.bias = self.bias_mu + (F.softplus(self.bias_spsigma) + 1e-5) * epsilon_b
        return F.linear(input, self.weight, self.bias)

    def kl_divergence(self):
        # KL(q || p) between the factorized Gaussian posterior and the
        # zero-mean Gaussian prior with standard deviation sigma_prior.
        mu = self.weight_mu
        sigma = F.softplus(self.weight_spsigma) + 1e-5
        mu0 = torch.zeros_like(mu)
        sigma0 = torch.ones_like(sigma) * self.sigma_prior
        q = distributions.Normal(mu, sigma)
        p = distributions.Normal(mu0, sigma0)
        kl = distributions.kl_divergence(q, p).sum()
        if self.bias is not None:
            mu = self.bias_mu
            sigma = F.softplus(self.bias_spsigma) + 1e-5
            mu0 = torch.zeros_like(mu)
            sigma0 = torch.ones_like(sigma) * self.sigma_prior
            q = distributions.Normal(mu, sigma)
            p = distributions.Normal(mu0, sigma0)
            kl += distributions.kl_divergence(q, p).sum()
        return kl

    def return_parameters(self):
        mu = self.weight_mu.data
        sigma = self.weight_spsigma.data
        weight = self.weight.data  # most recently sampled weight; requires a prior forward pass
        return mu, sigma, weight

    def extra_repr(self):
        return 'in_features={}, out_features={}, sigma_prior={}, sigma_init={}, bias={}'.format(
            self.in_features, self.out_features, self.sigma_prior, self.sigma_init, self.bias is not None
        )
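

Below is a minimal usage sketch, not part of BBB.py itself. It assumes the file is importable as a module and uses hypothetical tensors (x, y) and hyperparameters; it simply exercises the pattern the module exposes, a sampled forward pass plus kl_divergence() forming a negative-ELBO-style loss.

# Minimal usage sketch (assumption: not part of BBB.py).
# Tensor names x, y and all hyperparameters below are hypothetical placeholders.
import torch
import torch.nn.functional as F
from BBB import BNN  # assuming BBB.py is on the import path

model = BNN(input_size=10, hidden_sizes=[32, 32], output_size=1,
            act_func="relu", prior_prec=1.0, prec_init=10.0)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(64, 10)   # hypothetical inputs
y = torch.randn(64, 1)    # hypothetical regression targets

for step in range(100):
    optimizer.zero_grad()
    preds = model(x, training=True)              # weights sampled via the reparameterization trick
    nll = F.mse_loss(preds, y, reduction='sum')  # squared-error fit, standing in for the negative log-likelihood
    kl = model.kl_divergence()                   # KL(q || p) summed over all stochastic layers
    loss = nll + kl                              # negative ELBO for a single pass over the data
    loss.backward()
    optimizer.step()

# At evaluation time, training=False uses the posterior means instead of a fresh sample.
with torch.no_grad():
    mean_preds = model(x, training=False)

When training on minibatches, the KL term is usually rescaled (for example divided by the number of batches per epoch) so that the summed per-batch losses match the full-data objective; the sketch above keeps the simplest full-batch form.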