# /Archive/Testing/Peter/Pytorch04/Models/BBB_HS_linear.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import torch.distributions as distributions
import numpy as np


#############################
## Bayesian Neural Network ##
#############################
class BNN(nn.Module):
    def __init__(self, input_size,
                 hidden_sizes,
                 output_size=None,
                 act_func=None,
                 prior_prec=1.0,
                 prec_init=1.0,
                 clip_var=None):
        super(BNN, self).__init__()
        self.input_size = input_size
        self.clip_var = clip_var
        sigma_prior = 1.0 / math.sqrt(prior_prec)
        sigma_init = 1.0 / math.sqrt(prec_init)
        if output_size:
            self.output_size = output_size
            self.squeeze_output = False
        else:
            self.output_size = 1
            self.squeeze_output = True
        self.act = torch.tanh if act_func == "tanh" else F.relu
        if len(hidden_sizes) == 0:
            self.hidden_layers = []
            self.output_layer = StochasticLinear(self.input_size,
                                                 self.output_size,
                                                 sigma_prior=sigma_prior,
                                                 sigma_init=sigma_init,
                                                 clip_var=clip_var)
        else:
            self.hidden_layers = nn.ModuleList(
                [StochasticLinear(in_size, out_size,
                                  sigma_prior=sigma_prior,
                                  sigma_init=sigma_init,
                                  clip_var=clip_var)
                 for in_size, out_size in zip([self.input_size] + hidden_sizes[:-1], hidden_sizes)])
            self.output_layer = StochasticLinear(hidden_sizes[-1],
                                                 self.output_size,
                                                 sigma_prior=sigma_prior,
                                                 sigma_init=sigma_init,
                                                 clip_var=clip_var)

    def forward(self, x, training=False):
        # x = x.view(-1, self.input_size)
        out = x
        for layer in self.hidden_layers:
            # apply the nonlinearity between stochastic layers
            out = self.act(layer(out, training=training))
        logits = self.output_layer(out, training=training)
        if self.squeeze_output:
            logits = torch.squeeze(logits)
        return logits

    def kl_divergence(self):
        kl = 0
        for layer in self.hidden_layers:
            kl += layer.kl_divergence()
        kl += self.output_layer.kl_divergence()
        return kl

    def return_params(self):
        mus = []
        sigmas = []
        weights = []
        for layer in self.hidden_layers:
            mu, sigma, weight = layer.return_parameters()
            mus += [mu.cpu().numpy()]
            sigmas += [F.softplus(sigma).cpu().numpy()]
            weights += [weight.cpu().numpy()]
        mu, sigma, weight = self.output_layer.return_parameters()
        mus += [mu.cpu().numpy()]
        sigmas += [F.softplus(sigma).cpu().numpy()]
        weights += [weight.cpu().numpy()]
        return mus, sigmas, weights

    def clip_vars(self):
        # Note: only the first hidden layer is clipped here.
        if self.clip_var:
            for i, layer in enumerate(self.hidden_layers):
                if i == 0:
                    layer.clip_variances()
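

# A minimal usage sketch, not part of the original model code: it shows how the KL term
# above could be combined with a data-fit term to build an ELBO-style training loss.
# The function name, its arguments, and the squared-error likelihood are assumptions
# made purely for illustration.
def example_elbo_loss(model, x, y, num_batches):
    """Negative ELBO: data-fit term plus the KL divergence, scaled per minibatch."""
    preds = model(x, training=True)         # stochastic forward pass
    data_fit = torch.sum((preds - y) ** 2)  # assumed Gaussian likelihood (up to a constant)
    kl = model.kl_divergence()              # sum of the per-layer KL terms defined above
    return data_fit + kl / num_batches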


###############################################
## Gaussian Mean-Field Linear Transformation ##
###############################################
class StochasticLinear(nn.Module):
    r"""Applies a stochastic linear transformation to the incoming data: :math:`y = Ax + b`.

    The weights have a factorized Gaussian variational posterior, and the inputs are
    rescaled by sampled log-normal scale variables before the linear map is applied.
    During training the output is sampled with the local reparameterization trick.

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        sigma_prior: prior scale used in the KL term. Default: ``1.0``
        sigma_init: initial posterior scale (stored; not used by ``reset_parameters``).
            Default: ``1.0``
        bias: If set to ``True``, the layer learns an additive bias.
            Default: ``False``
        clip_var: optional upper bound applied to the variance parameters

    Shape:
        - Input: :math:`(N, in\_features)`
        - Output: :math:`(N, out\_features)`

    Attributes:
        weight_mu: the learnable posterior means of shape
            `(out_features x in_features)`
        weight_spsigma: the learnable pre-softplus posterior variances of the same shape
        bias_mu, bias_spsigma: the corresponding bias parameters of shape `(out_features)`

    Examples::

        >>> m = StochasticLinear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input, training=True)
        >>> print(output.size())
    """

    def __init__(self, in_features, out_features, sigma_prior=1.0, sigma_init=1.0,
                 bias=False, clip_var=None):
        super(StochasticLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.clip_var = clip_var

        self.tau = sigma_prior
        self.sigma_init = sigma_init

        # M_w and Sigma_w: mean and pre-softplus variance of the Gaussian weight posterior
        self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
        self.weight_spsigma = Parameter(torch.Tensor(out_features, in_features))

        # Variational parameters of the global scale (decomposed into sa and sb)
        self.mu_sa = Parameter(torch.Tensor(1))
        self.mu_sb = Parameter(torch.Tensor(1))

        self.sigma_sa = Parameter(torch.Tensor(1))
        self.sigma_sb = Parameter(torch.Tensor(1))

        # Variational parameters of the per-input local scales (decomposed into alpha and beta)
        self.mu_alpha = Parameter(torch.Tensor(in_features))
        self.mu_beta = Parameter(torch.Tensor(in_features))

        self.sigma_alpha = Parameter(torch.Tensor(in_features))
        self.sigma_beta = Parameter(torch.Tensor(in_features))

        if bias:
            self.bias = True
            self.bias_mu = Parameter(torch.Tensor(out_features))
            self.bias_spsigma = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def clip_variances(self):
        if self.clip_var:
            self.weight_spsigma.data.clamp_(max=self.clip_var)
            if self.bias is not None:
                self.bias_spsigma.data.clamp_(max=self.clip_var)

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight_mu.size(1))

        self.weight_mu.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias_mu.data.uniform_(-stdv, stdv)

        # pre-softplus variance parameters start near exp(-3) with small noise
        self.weight_spsigma.data.normal_(math.exp(-3), 1e-2)
        if self.bias is not None:
            self.bias_spsigma.data.normal_(math.exp(-3), 1e-2)

        self.mu_sa.data.normal_(-1, 1e-2)
        self.mu_sb.data.normal_(-1, 1e-2)

        self.sigma_sa.data.normal_(math.exp(-3), 1e-2)
        self.sigma_sb.data.normal_(math.exp(-3), 1e-2)

        self.mu_alpha.data.normal_(0, 1e-2)
        self.mu_beta.data.normal_(0, 1e-2)

        self.sigma_alpha.data.normal_(math.exp(-3), 1e-2)
        self.sigma_beta.data.normal_(math.exp(-3), 1e-2)

    def forward(self, input, training=True):
        batch_size = input.size(0)
        if not training:
            # At test time, use the posterior means of the log-scales (no sampling).
            mu_s = 0.5 * self.mu_sa + 0.5 * self.mu_sb
            log_s = mu_s
            mu_z = 0.5 * self.mu_alpha + 0.5 * self.mu_beta + log_s
            Z = torch.exp(mu_z.repeat(batch_size, 1))
            H = input * Z
            M_h = F.linear(H, self.weight_mu)
            self.weight = mu_z
            return M_h
        else:
            mu_s = 0.5 * self.mu_sa + 0.5 * self.mu_sb
            sigma_s = torch.sqrt(0.25 * F.softplus(self.sigma_sa) + 0.25 * F.softplus(self.sigma_sb))

            # reparameterization noise
            e = torch.randn_like(self.sigma_sa)
            E = torch.randn_like(input)

            # sample the global log-scale, then the per-input scales Z
            log_s = mu_s + sigma_s * e

            mu_z = 0.5 * self.mu_alpha + 0.5 * self.mu_beta + log_s
            sigma_z = torch.sqrt(0.25 * F.softplus(self.sigma_alpha) + 0.25 * F.softplus(self.sigma_beta))

            Z = torch.exp(mu_z.repeat(batch_size, 1) + sigma_z.repeat(batch_size, 1) * E)

            # local reparameterization: sample the pre-activations from a Gaussian
            # with mean M_h and variance V_h induced by the weight posterior
            H = input * Z
            M_h = F.linear(H, self.weight_mu)
            V_h = F.linear(H ** 2, F.softplus(self.weight_spsigma))

            E_final = torch.randn_like(V_h)

            self.weight = mu_z

            return M_h + torch.sqrt(V_h) * E_final
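
    # The KL divergence below has two pieces: (i) a closed-form KL between the factorized
    # Gaussian weight (and bias) posterior and a unit-Gaussian prior, and (ii) terms for
    # the log-normal posteriors over the decomposed scale variables (sa/sb global,
    # alpha/beta per-input), written as negative KL divergences and hence sign-flipped.
    # The sa/sb and alpha/beta pairs appear to implement a horseshoe-style decomposition
    # of global and per-input scales (suggested by the HS in the file name).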
    def kl_divergence(self):
        # KL(q(w|z) || p(w|z)) for the weights
        KLD_element = -0.5 * F.softplus(self.weight_spsigma).log() \
            + 0.5 * (F.softplus(self.weight_spsigma) + self.weight_mu ** 2) - 0.5
        KLD = torch.sum(KLD_element)

        if self.bias is not None:
            # KL term for the bias
            KLD_element = -0.5 * F.softplus(self.bias_spsigma).log() \
                + 0.5 * (F.softplus(self.bias_spsigma) + self.bias_mu ** 2) - 0.5
            KLD += torch.sum(KLD_element)

        # We change sign since the expressions below compute the negative KL divergences (-D_KL).
        KL_sa = -torch.sum(math.log(self.tau)
                           - 1.0 / self.tau * torch.exp(self.mu_sa + 0.5 * F.softplus(self.sigma_sa))
                           + 0.5 * self.mu_sa + 0.5 * F.softplus(self.sigma_sa).log()
                           + 0.5 + 0.5 * math.log(2))

        KL_sb = -torch.sum(-torch.exp(0.5 * F.softplus(self.sigma_sb) - self.mu_sb)
                           - 0.5 * self.mu_sb + 0.5 * F.softplus(self.sigma_sb).log()
                           + 0.5 + 0.5 * math.log(2))

        KL_alpha = -torch.sum(-torch.exp(self.mu_alpha + 0.5 * F.softplus(self.sigma_alpha))
                              + 0.5 * self.mu_alpha + 0.5 * F.softplus(self.sigma_alpha).log()
                              + 0.5 + 0.5 * math.log(2))

        KL_beta = -torch.sum(-torch.exp(0.5 * F.softplus(self.sigma_beta) - self.mu_beta)
                             - 0.5 * self.mu_beta + 0.5 * F.softplus(self.sigma_beta).log()
                             + 0.5 + 0.5 * math.log(2))

        KLD += KL_sa + KL_sb + KL_alpha + KL_beta

        return KLD

    def return_parameters(self):
        # self.weight (the posterior mean of the log-scales) is only set after a forward pass
        mu = self.weight_mu.data
        sigma = self.weight_spsigma.data
        weight = self.weight.data
        return mu, sigma, weight

    def extra_repr(self):
        return 'in_features={}, out_features={}, sigma_prior={}, sigma_init={}, bias={}'.format(
            self.in_features, self.out_features, self.tau, self.sigma_init, self.bias is not None
        )
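

# A small smoke test, offered as an illustrative sketch rather than part of the original
# model code: it builds a tiny BNN, runs a stochastic forward pass, and prints the KL term.
# The layer sizes, batch size, and seed below are arbitrary choices.
if __name__ == "__main__":
    torch.manual_seed(0)
    model = BNN(input_size=4, hidden_sizes=[8, 8], output_size=2, act_func="relu")
    x = torch.randn(16, 4)
    out = model(x, training=True)
    print("output shape:", out.shape)              # expected: torch.Size([16, 2])
    print("KL divergence:", model.kl_divergence().item())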