
/Main/Models/BBB_linear.py

https://bitbucket.org/RamiroCope/thesis_repo
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import torch.distributions as distributions


#############################
## Bayesian Neural Network ##
#############################
class BNN(nn.Module):
    def __init__(self, input_size, hidden_sizes, output_size=None, prior_prec=1.0, prec_init=1.0):
        super().__init__()
        self.input_size = input_size
        # Precisions are converted to standard deviations for the Gaussians.
        sigma_prior = 1.0 / math.sqrt(prior_prec)
        sigma_init = 1.0 / math.sqrt(prec_init)
        if output_size:
            self.output_size = output_size
            self.squeeze_output = False
        else:
            self.output_size = 1
            self.squeeze_output = True
        if len(hidden_sizes) == 0:
            # No hidden layers: a single stochastic linear map from input to output.
            self.hidden_layers = []
            self.output_layer = StochasticLinear(self.input_size, self.output_size,
                                                 sigma_prior=sigma_prior, sigma_init=sigma_init)
        else:
            self.hidden_layers = nn.ModuleList(
                [StochasticLinear(in_size, out_size, sigma_prior=sigma_prior, sigma_init=sigma_init)
                 for in_size, out_size in zip([self.input_size] + hidden_sizes[:-1], hidden_sizes)])
            self.output_layer = StochasticLinear(hidden_sizes[-1], self.output_size,
                                                 sigma_prior=sigma_prior, sigma_init=sigma_init)

    def forward(self, x, training=True):
        x = x.view(-1, self.input_size)
        out = x
        for layer in self.hidden_layers:
            out = layer(out, training=training)
        logits = self.output_layer(out, training=training)
        if self.squeeze_output:
            logits = torch.squeeze(logits)
        return logits

    def kl_divergence(self):
        # Sum the KL(q || p) contributions of all stochastic layers.
        kl = 0
        for layer in self.hidden_layers:
            kl += layer.kl_divergence()
        kl += self.output_layer.kl_divergence()
        return kl

    def return_params(self):
        # Collect per-layer posterior parameters as numpy arrays (assumes CPU tensors).
        mus = []
        sigmas = []
        weights = []
        for layer in self.hidden_layers:
            mu, sigma, weight = layer.return_parameters()
            mus += [mu.numpy()]
            sigmas += [sigma.numpy()]
            weights += [weight.numpy()]
        mu, sigma, weight = self.output_layer.return_parameters()
        mus += [mu.numpy()]
        sigmas += [sigma.numpy()]
        weights += [weight.numpy()]
        return mus, sigmas, weights
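
# --- Illustrative sketch (not part of the original file) ---
# BNN.kl_divergence() above supplies the complexity term of the negative ELBO
# minimized in Bayes by Backprop. A minimal helper under assumed settings
# (classification with integer targets; `num_batches` is the number of
# minibatches per epoch, so the KL is counted once per pass over the data):
def elbo_loss(model, x, y, num_batches):
    logits = model(x, training=True)                   # one Monte Carlo weight sample
    nll = F.cross_entropy(logits, y, reduction='sum')  # negative log-likelihood
    return nll + model.kl_divergence() / num_batches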
###############################################
## Gaussian Mean-Field Linear Transformation ##
###############################################
class StochasticLinear(nn.Module):
    r"""Applies a stochastic linear transformation to the incoming data: :math:`y = Ax + b`

    The weights (and optional bias) are factorized Gaussian random variables
    with learnable means and standard deviations; a fresh sample is drawn on
    every training forward pass via the reparameterization trick.

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``False``

    Shape:
        - Input: :math:`(N, *, in\_features)` where :math:`*` means any number of
          additional dimensions
        - Output: :math:`(N, *, out\_features)` where all but the last dimension
          are the same shape as the input.

    Attributes:
        weight_mu: the learnable weight means, of shape
            `(out_features x in_features)`
        weight_spsigma: the learnable softplus pre-activations of the weight
            standard deviations, of the same shape
        bias_mu, bias_spsigma: the corresponding bias parameters, of shape
            `(out_features)`, present only when ``bias=True``

    Examples::

        >>> m = StochasticLinear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
    """
    def __init__(self, in_features, out_features, sigma_prior=1.0, sigma_init=1.0, bias=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.sigma_prior = sigma_prior
        self.sigma_init = sigma_init
        # Variational posterior: mean and softplus pre-activation of the std.
        self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
        self.weight_spsigma = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = True
            self.bias_mu = Parameter(torch.Tensor(out_features))
            self.bias_spsigma = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight_mu.size(1))
        self.weight_mu.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias_mu.data.uniform_(-stdv, stdv)
        # Inverse softplus, so that softplus(spsigma) == sigma_init at initialization.
        self.weight_spsigma.data.fill_(math.log(math.exp(self.sigma_init) - 1))
        if self.bias is not None:
            self.bias_spsigma.data.fill_(math.log(math.exp(self.sigma_init) - 1))

    def forward(self, input, training=True):
        if not training:
            # At test time, use the posterior means instead of sampling.
            self.weight = self.weight_mu.data
            if self.bias is not None:
                self.bias = self.bias_mu.data
            return F.linear(input, self.weight, self.bias)
        # Reparameterization trick: w = mu + softplus(spsigma) * eps, eps ~ N(0, I).
        epsilon_W = torch.randn_like(self.weight_mu)
        self.weight = self.weight_mu + F.softplus(self.weight_spsigma) * epsilon_W
        if self.bias is not None:
            epsilon_b = torch.randn_like(self.bias_mu)
            self.bias = self.bias_mu + F.softplus(self.bias_spsigma) * epsilon_b
        return F.linear(input, self.weight, self.bias)
    def kl_divergence(self):
        # KL between the factorized Gaussian posterior q and the zero-mean
        # Gaussian prior p with std sigma_prior; 1e-5 guards against sigma = 0.
        mu = self.weight_mu
        sigma = F.softplus(self.weight_spsigma) + 1e-5
        mu0 = torch.zeros_like(mu)
        sigma0 = torch.ones_like(sigma) * self.sigma_prior
        q = distributions.Normal(mu, sigma)
        p = distributions.Normal(mu0, sigma0)
        kl = distributions.kl_divergence(q, p).sum()
        if self.bias is not None:
            mu = self.bias_mu
            sigma = F.softplus(self.bias_spsigma) + 1e-5
            mu0 = torch.zeros_like(mu)
            sigma0 = torch.ones_like(sigma) * self.sigma_prior
            q = distributions.Normal(mu, sigma)
            p = distributions.Normal(mu0, sigma0)
            kl += distributions.kl_divergence(q, p).sum()
        return kl

    def return_parameters(self):
        # Note: `sigma` is the softplus pre-activation, not the std itself, and
        # `weight` is the sample drawn in the most recent forward pass.
        mu = self.weight_mu.data
        sigma = self.weight_spsigma.data
        weight = self.weight.data
        return mu, sigma, weight

    def extra_repr(self):
        return 'in_features={}, out_features={}, sigma_prior={}, sigma_init={}, bias={}'.format(
            self.in_features, self.out_features, self.sigma_prior, self.sigma_init, self.bias is not None
        )
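
# --- Illustrative usage (not part of the original file) ---
# A minimal training sketch under assumed settings: random regression data,
# a squared-error likelihood, and the KL-regularized objective. Since the
# full batch is used here, the KL term is added once per step.
if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(128, 4)
    y = torch.randn(128)                      # output_size=None -> squeezed 1-d output
    model = BNN(input_size=4, hidden_sizes=[16])
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    for step in range(200):
        optimizer.zero_grad()
        pred = model(x, training=True)        # fresh weight sample each step
        loss = F.mse_loss(pred, y, reduction='sum') + model.kl_divergence()
        loss.backward()
        optimizer.step()
    print("final negative ELBO:", loss.item())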