/Main/Models/BBB_t_linear.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import torch.distributions as distributions


#############################
## Bayesian Neural Network ##
#############################
class BNN(nn.Module):
    def __init__(self, input_size,
                 hidden_sizes,
                 output_size=None,
                 prior_prec=1.0,
                 prec_init=1.0,
                 df_prior=5.0,
                 df_init=5.0):

        super().__init__()

        self.input_size = input_size
        sigma_prior = 1.0 / math.sqrt(prior_prec)
        sigma_init = 1.0 / math.sqrt(prec_init)
        if output_size:
            self.output_size = output_size
            self.squeeze_output = False
        else:
            self.output_size = 1
            self.squeeze_output = True
        if len(hidden_sizes) == 0:
            # No hidden layers: a single stochastic linear map from input to output.
            self.hidden_layers = []
            self.output_layer = StochasticLinear(self.input_size,
                                                 self.output_size,
                                                 sigma_prior=sigma_prior,
                                                 sigma_init=sigma_init,
                                                 df_prior=df_prior,
                                                 df_init=df_init)
        else:
            self.hidden_layers = nn.ModuleList(
                [StochasticLinear(in_size, out_size,
                                  sigma_prior=sigma_prior, df_prior=df_prior,
                                  sigma_init=sigma_init, df_init=df_init)
                 for in_size, out_size in zip([self.input_size] + hidden_sizes[:-1],
                                              hidden_sizes)])
            self.output_layer = StochasticLinear(hidden_sizes[-1], self.output_size,
                                                 sigma_prior=sigma_prior, df_prior=df_prior,
                                                 sigma_init=sigma_init, df_init=df_init)

    def forward(self, x):
        out = x
        for layer in self.hidden_layers:
            out = layer(out)
        logits = self.output_layer(out)
        if self.squeeze_output:
            logits = torch.squeeze(logits)
        return logits

    def kl_divergence(self):
        # Total KL between the variational posterior and the prior,
        # accumulated over every stochastic layer.
        kl = 0
        for layer in self.hidden_layers:
            kl += layer.kl_divergence()
        kl += self.output_layer.kl_divergence()
        return kl

    def return_params(self):
        mus = []
        sigmas = []
        dfs = []
        weights = []
        for layer in self.hidden_layers:
            mu, sigma, df, weight = layer.return_parameters()
            mus += [mu.numpy()]
            sigmas += [sigma.numpy()]
            dfs += [df.numpy()]
            weights += [weight.numpy()]
        mu, sigma, df, weight = self.output_layer.return_parameters()
        mus += [mu.numpy()]
        sigmas += [sigma.numpy()]
        dfs += [df.numpy()]
        weights += [weight.numpy()]
        return mus, sigmas, dfs, weights
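

# A minimal sketch of how the KL term above is typically combined with a data
# likelihood in a Bayes-by-Backprop ELBO objective. This is an assumed usage
# pattern, not part of the original module: `elbo_loss`, `x_batch`, `y_batch`,
# and `num_batches` are illustrative names, and the squared-error likelihood
# is one common choice for regression.
def elbo_loss(model, x_batch, y_batch, num_batches):
    preds = model(x_batch)  # stochastic forward pass (fresh weight samples)
    nll = F.mse_loss(preds, y_batch, reduction='sum')
    # Scale the KL so that summing over all minibatches in an epoch
    # counts the full KL divergence exactly once.
    return nll + model.kl_divergence() / num_batches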


################################################
## Student-t Mean-Field Linear Transformation ##
################################################
class StochasticLinear(nn.Module):
    """Applies a stochastic linear transformation to the incoming data,
    :math:`y = Ax + b`, with weights (and optional biases) sampled from a
    factorized Student-t variational posterior on every forward pass.

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``False``

    Shape:
        - Input: :math:`(N, *, in\_features)` where :math:`*` means any number of
          additional dimensions
        - Output: :math:`(N, *, out\_features)` where all but the last dimension
          are the same shape as the input.

    Attributes:
        weight: the sampled weights of the module, of shape
            `(out_features x in_features)`, redrawn on each forward pass
        bias: the sampled bias of the module, of shape `(out_features)`

    Examples::

        >>> m = StochasticLinear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """

    def __init__(self, in_features,
                 out_features,
                 sigma_prior=1.0,
                 df_prior=5.0,
                 sigma_init=1.0,
                 df_init=5.0,
                 bias=False):

        super().__init__()

        self.count = 0

        self.in_features = in_features
        self.out_features = out_features

        self.sigma_prior = sigma_prior
        self.df_prior = df_prior

        self.sigma_init = sigma_init
        self.df_init = df_init

        # Variational parameters per weight: location (mu), pre-softplus
        # scale (spsigma), and pre-softplus degrees of freedom (spdf).
        self.weight_spsigma = Parameter(torch.Tensor(out_features, in_features))
        self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
        self.weight_spdf = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            # Placeholder flag; forward() overwrites self.bias with a sampled tensor.
            self.bias = True
            self.bias_spsigma = Parameter(torch.Tensor(out_features))
            self.bias_mu = Parameter(torch.Tensor(out_features))
            self.bias_spdf = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight_mu.size(1))
        self.weight_mu.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias_mu.data.uniform_(-stdv, stdv)
        # Note: spsigma and spdf are mapped through softplus in forward(), so
        # the effective initial scale is softplus(sigma_init) rather than
        # sigma_init itself (and likewise softplus(df_init) for the df).
        self.weight_spsigma.data.fill_(self.sigma_init)
        self.weight_spdf.data.fill_(self.df_init)
        if self.bias is not None:
            self.bias_spsigma.data.fill_(self.sigma_init)
            self.bias_spdf.data.fill_(self.df_init)

    def forward(self, input):
        # Construct a Student-t distribution over the weights and sample it
        # with the reparameterization trick (rsample), so gradients flow
        # back to mu, spsigma, and spdf.
        t_dist = distributions.StudentT(F.softplus(self.weight_spdf),
                                        loc=self.weight_mu,
                                        scale=F.softplus(self.weight_spsigma))

        self.weight = t_dist.rsample()

        if self.bias is not None:
            t_dist = distributions.StudentT(F.softplus(self.bias_spdf),
                                            loc=self.bias_mu,
                                            scale=F.softplus(self.bias_spsigma))
            self.bias = t_dist.rsample()
        return F.linear(input, self.weight, self.bias)

    def kl_divergence(self):
        # torch.distributions has no closed-form KL registered for Student-t
        # pairs, so the KL term here is computed between Gaussians with the
        # same location and scale, as a tractable stand-in for the Student-t
        # posterior and prior.
        mu = self.weight_mu
        sigma = F.softplus(self.weight_spsigma) + 1e-5
        mu0 = torch.zeros_like(mu)
        sigma0 = torch.ones_like(sigma) * self.sigma_prior

        q = distributions.Normal(mu, sigma)
        p = distributions.Normal(mu0, sigma0)

        kl = distributions.kl_divergence(q, p).sum()

        if self.bias is not None:
            mu = self.bias_mu
            sigma = F.softplus(self.bias_spsigma) + 1e-5
            mu0 = torch.zeros_like(mu)
            sigma0 = torch.ones_like(sigma) * self.sigma_prior

            q = distributions.Normal(mu, sigma)
            p = distributions.Normal(mu0, sigma0)

            kl += distributions.kl_divergence(q, p).sum()

        return kl

    def return_parameters(self):
        # self.weight holds the most recent sample, so forward() must have
        # been called at least once before this method.
        mu = self.weight_mu.data
        sigma = F.softplus(self.weight_spsigma.data)
        df = F.softplus(self.weight_spdf.data)
        weight = self.weight.data
        return mu, sigma, df, weight

    def extra_repr(self):
        return 'in_features={}, out_features={}, sigma_prior={}, df_prior={}, sigma_init={}, df_init={}, bias={}'.format(
            self.in_features, self.out_features, self.sigma_prior, self.df_prior,
            self.sigma_init, self.df_init, self.bias is not None
        )
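

# A minimal usage sketch under assumed shapes and hyperparameters (the seed,
# sizes, and batch below are illustrative, not part of the original module):
# build a small regression BNN, run one stochastic forward pass, and inspect
# the KL term that enters the ELBO.
if __name__ == "__main__":
    torch.manual_seed(0)
    model = BNN(input_size=10, hidden_sizes=[32, 32])
    x = torch.randn(128, 10)
    preds = model(x)                      # fresh weight samples on every call
    print(preds.shape)                    # torch.Size([128]) after squeezing
    print(model.kl_divergence().item())   # scalar KL summed over all layers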