/Archive/Testing/Peter/Pytorch04/Models/BBB_scale_mix_v2.py
https://bitbucket.org/RamiroCope/thesis_repo
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import torch.distributions as distributions
import numpy as np
from scipy.special import gammaln, psi
#############################
## Bayesian Neural Network ##
#############################
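# Summary comment added for readability; it restates what the code below does.
# Each StochasticLinear layer keeps one set of variational parameters
# (mu, alpha, beta) per weight. In its forward pass a precision lambda is drawn
# from Gamma(softplus(alpha), softplus(beta)) via rsample(), the standard
# deviation is set to sigma = sqrt(1 / lambda), and the weight is sampled as
# w = mu + sigma * eps with eps ~ N(0, 1), i.e. the weight variance is
# inverse-Gamma distributed under the variational posterior.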
class BNN(nn.Module):
    def __init__(self, input_size,
                 hidden_sizes,
                 output_size,
                 act_func,
                 prior_prec=1.0,
                 prec_init=1.0,
                 alpha_prior=5.0,
                 beta_prior=5.0,
                 alpha_init=5.0,
                 beta_init=5.0):

        super(type(self), self).__init__()

        self.input_size = input_size
        sigma_prior = 1.0 / math.sqrt(prior_prec)
        sigma_init = 1.0 / math.sqrt(prec_init)
        if output_size:
            self.output_size = output_size
            self.squeeze_output = False
        else:
            self.output_size = 1
            self.squeeze_output = True
        self.act = F.tanh if act_func == "tanh" else F.relu
        if len(hidden_sizes) == 0:
            self.hidden_layers = []
            self.output_layer = StochasticLinear(self.input_size,
                                                 self.output_size,
                                                 sigma_prior=sigma_prior,
                                                 alpha_prior=alpha_prior,
                                                 beta_prior=beta_prior,
                                                 sigma_init=sigma_init,
                                                 alpha_init=alpha_init,
                                                 beta_init=beta_init)
        else:
            self.hidden_layers = nn.ModuleList(
                [StochasticLinear(in_size, out_size,
                                  sigma_prior=sigma_prior,
                                  alpha_prior=alpha_prior,
                                  beta_prior=beta_prior,
                                  sigma_init=sigma_init,
                                  alpha_init=alpha_init,
                                  beta_init=beta_init)
                 for in_size, out_size in zip([self.input_size] + hidden_sizes[:-1], hidden_sizes)])
            self.output_layer = StochasticLinear(hidden_sizes[-1], self.output_size,
                                                 sigma_prior=sigma_prior,
                                                 alpha_prior=alpha_prior,
                                                 beta_prior=beta_prior,
                                                 sigma_init=sigma_init,
                                                 alpha_init=alpha_init,
                                                 beta_init=beta_init)
    def forward(self, x, training=True):
        x = x.view(-1, self.input_size)
        out = x
        for layer in self.hidden_layers:
            out = self.act(layer(out, training=training))
        logits = self.output_layer(out, training=training)
        if self.squeeze_output:
            logits = torch.squeeze(logits)
        return logits
    def kl_divergence(self):
        kl = 0
        for layer in self.hidden_layers:
            kl += layer.kl_divergence()
        kl += self.output_layer.kl_divergence()
        return kl

    def return_params(self):
        mus = []
        alphas = []
        betas = []
        sigmas = []
        for layer in self.hidden_layers:
            mu, alpha, beta, sigma = layer.return_parameters()
            mus += [mu.numpy()]
            alphas += [alpha.numpy()]
            betas += [beta.numpy()]
            sigmas += [sigma.numpy()]
        mu, alpha, beta, sigma = self.output_layer.return_parameters()
        mus += [mu.numpy()]
        alphas += [alpha.numpy()]
        betas += [beta.numpy()]
        sigmas += [sigma.numpy()]
        return mus, alphas, betas, sigmas
###############################################
## Gaussian Mean-Field Linear Transformation ##
###############################################
class StochasticLinear(nn.Module):
    r"""Applies a stochastic linear transformation to the incoming data: :math:`y = Ax + b`,
    where the weights and biases are sampled from a factorised variational posterior.

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to False, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input: :math:`(N, *, in\_features)` where :math:`*` means any number of
          additional dimensions
        - Output: :math:`(N, *, out\_features)` where all but the last dimension
          are the same shape as the input.

    Attributes:
        weight_mu, weight_spalpha, weight_spbeta: the learnable variational
            parameters of the weights, each of shape `(out_features x in_features)`
        bias_mu, bias_spalpha, bias_spbeta: the learnable variational parameters
            of the bias, each of shape `(out_features)`

    Examples::

        >>> m = StochasticLinear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
    """
    def __init__(self, in_features,
                 out_features,
                 sigma_prior=1.0,
                 alpha_prior=5.0,
                 beta_prior=5.0,
                 sigma_init=1.0,
                 alpha_init=5.0,
                 beta_init=5.0,
                 bias=True):

        super(type(self), self).__init__()

        self.count = 0

        self.in_features = in_features
        self.out_features = out_features

        self.sigma_prior = sigma_prior
        self.alpha_prior = alpha_prior
        self.beta_prior = beta_prior

        self.sigma_init = sigma_init
        self.alpha_init = alpha_init
        self.beta_init = beta_init

        self.weight_spsigma = None  # set to the sampled standard deviation in forward()
        self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
        self.weight_spalpha = Parameter(torch.Tensor(out_features, in_features))
        self.weight_spbeta = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = True  # placeholder; overwritten with a sampled bias tensor in forward()
            self.bias_spsigma = None
            self.bias_mu = Parameter(torch.Tensor(out_features))
            self.bias_spalpha = Parameter(torch.Tensor(out_features))
            self.bias_spbeta = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight_mu.size(1))
        self.weight_mu.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias_mu.data.uniform_(-stdv, stdv)
        self.weight_spalpha.data.uniform_(self.alpha_init - 0.1, self.alpha_init + 0.1)  # fill_(self.alpha_init)
        self.weight_spbeta.data.uniform_(self.beta_init - 0.1, self.beta_init + 0.1)  # fill_(self.beta_init)
        if self.bias is not None:
            self.bias_spalpha.data.uniform_(self.alpha_init - 0.1, self.alpha_init + 0.1)  # fill_(self.alpha_init)
            self.bias_spbeta.data.uniform_(self.beta_init - 0.1, self.beta_init + 0.1)  # fill_(self.beta_init)
    def forward(self, input, training=True):
        if not training:
            # at evaluation time use the posterior means and skip sampling
            self.weight = self.weight_mu.data
            if self.bias is not None:
                self.bias = self.bias_mu.data
            return F.linear(input, self.weight, self.bias)

        # Construct Gamma distribution for reparameterization sampling
        gamma_dist = distributions.Gamma(F.softplus(self.weight_spalpha), F.softplus(self.weight_spbeta))

        # Sample the scale parameters - the sampled variance 1/lambda is inverse-Gamma
        # distributed and spsigma = sqrt(1/lambda) is used as the standard deviation
        # of the Normal over the weights
        self.weight_spsigma = torch.sqrt(1.0 / gamma_dist.rsample())

        epsilon_W = torch.normal(mean=torch.zeros_like(self.weight_mu), std=1.0)

        self.weight = self.weight_mu + self.weight_spsigma * epsilon_W

        if self.bias is not None:
            gamma_dist = distributions.Gamma(F.softplus(self.bias_spalpha), F.softplus(self.bias_spbeta))

            self.bias_spsigma = torch.sqrt(1.0 / gamma_dist.rsample())

            epsilon_b = torch.normal(mean=torch.zeros_like(self.bias_mu), std=1.0)
            self.bias = self.bias_mu + self.bias_spsigma * epsilon_b
        return F.linear(input, self.weight, self.bias)
    # def _kl_gaussian(self, p_mu, p_sigma, q_mu, q_sigma):
    #     var_ratio = (p_sigma / q_sigma).pow(2)
    #     t1 = ((p_mu - q_mu) / q_sigma).pow(2)
    #     return 0.5 * torch.sum((var_ratio + t1 - 1 - var_ratio.log()))
    #
    # def kl_divergence(self):
    #     mu = self.weight_mu
    #     sigma = F.softplus(self.weight_spsigma)
    #     mu0 = torch.zeros_like(mu)
    #     sigma0 = torch.ones_like(sigma) * self.sigma_prior
    #     kl = self._kl_gaussian(p_mu=mu, p_sigma=sigma, q_mu=mu0, q_sigma=sigma0)
    #     if self.bias is not None:
    #         mu = self.bias_mu
    #         sigma = F.softplus(self.bias_spsigma)
    #         mu0 = torch.zeros_like(mu)
    #         sigma0 = torch.ones_like(sigma) * self.sigma_prior
    #         kl += self._kl_gaussian(p_mu=mu, p_sigma=sigma, q_mu=mu0, q_sigma=sigma0)
    #     return kl
    def kl_divergence(self):
        # NOTE: the original file referenced E_q_w_lambda, E_q_lambda, E_p_w_lambda
        # and E_p_lambda without defining them. Below they are filled in as a
        # single-sample Monte Carlo estimate of E_q[log q(w, lambda) - log p(w, lambda)],
        # evaluated at the weights and scales drawn in the last training forward pass
        # (so forward(..., training=True) must be called first). The assumed model is
        # q(lambda) = Gamma(alpha, beta), q(w | lambda) = N(mu, 1/lambda),
        # p(lambda) = Gamma(alpha0, beta0), p(w | lambda) = N(0, 1/lambda).

        # define posterior parameters
        mu = self.weight_mu
        sigma = self.weight_spsigma
        alpha = F.softplus(self.weight_spalpha)
        beta = F.softplus(self.weight_spbeta)

        # define prior parameters
        alpha0 = torch.ones_like(alpha) * self.alpha_prior
        beta0 = torch.ones_like(beta) * self.beta_prior

        lam = 1.0 / sigma.pow(2)                                       # sampled precision lambda = 1 / sigma^2
        q_w = distributions.Normal(mu, sigma)
        p_w = distributions.Normal(torch.zeros_like(mu), sigma)
        E_q_w_lambda = q_w.log_prob(self.weight)                       # log q(w | lambda) at the sampled w
        E_p_w_lambda = p_w.log_prob(self.weight)                       # log p(w | lambda) at the sampled w
        E_q_lambda = distributions.Gamma(alpha, beta).log_prob(lam)    # log q(lambda) at the sampled lambda
        E_p_lambda = distributions.Gamma(alpha0, beta0).log_prob(lam)  # log p(lambda) at the sampled lambda

        # Compute the KL divergence
        kl = (E_q_w_lambda + E_q_lambda - E_p_w_lambda - E_p_lambda).sum()

        if self.bias is not None:
            # define posterior parameters
            mu = self.bias_mu
            sigma = self.bias_spsigma
            alpha = F.softplus(self.bias_spalpha)
            beta = F.softplus(self.bias_spbeta)

            # define prior parameters
            alpha0 = torch.ones_like(alpha) * self.alpha_prior
            beta0 = torch.ones_like(beta) * self.beta_prior

            lam = 1.0 / sigma.pow(2)
            q_b = distributions.Normal(mu, sigma)
            p_b = distributions.Normal(torch.zeros_like(mu), sigma)
            E_q_w_lambda = q_b.log_prob(self.bias)
            E_p_w_lambda = p_b.log_prob(self.bias)
            E_q_lambda = distributions.Gamma(alpha, beta).log_prob(lam)
            E_p_lambda = distributions.Gamma(alpha0, beta0).log_prob(lam)

            # Compute the KL divergence for the bias terms
            kl += (E_q_w_lambda + E_q_lambda - E_p_w_lambda - E_p_lambda).sum()
        return kl
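    # Added sketch, not in the original file: the Gamma part of the KL also has a
    # closed form, KL(Gamma(alpha, beta) || Gamma(alpha0, beta0)) =
    #   (alpha - alpha0) * psi(alpha) - lgamma(alpha) + lgamma(alpha0)
    #   + alpha0 * (log(beta) - log(beta0)) + alpha * (beta0 - beta) / beta.
    # This helper uses torch.digamma / torch.lgamma in place of the scipy psi /
    # gammaln imported above; whether the analytic term or the sampled estimate
    # was intended here is an assumption.
    def _kl_gamma_gamma(self, alpha, beta, alpha0, beta0):
        return ((alpha - alpha0) * torch.digamma(alpha)
                - torch.lgamma(alpha) + torch.lgamma(alpha0)
                + alpha0 * (torch.log(beta) - torch.log(beta0))
                + alpha * (beta0 - beta) / beta)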
    def return_parameters(self):
        mu = self.weight_mu.data
        alpha = self.weight_spalpha.data
        beta = self.weight_spbeta.data
        sigma = self.weight_spsigma.data
        return mu, alpha, beta, sigma

    def extra_repr(self):
        return 'in_features={}, out_features={}, sigma_prior={}, sigma_init={}, bias={}'.format(
            self.in_features, self.out_features, self.sigma_prior, self.sigma_init, self.bias is not None
        )
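
# ---------------------------------------------------------------------------
# Usage sketch added for illustration (not part of the original file): a small
# training-style loop on random toy data. The sizes, learning rate and the
# plain sum-of-squares data term are assumptions, not choices taken from the
# thesis repository.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    # toy regression data
    x = torch.randn(64, 5)
    y = torch.randn(64)

    model = BNN(input_size=5, hidden_sizes=[16], output_size=None, act_func="relu")
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

    for step in range(100):
        optimizer.zero_grad()
        pred = model(x, training=True)      # one stochastic forward pass
        nll = ((pred - y) ** 2).sum()       # data-fit term (Gaussian likelihood up to a constant)
        kl = model.kl_divergence()          # single-sample KL estimate from the layers
        loss = nll + kl                     # ELBO-style objective with KL weight 1
        loss.backward()
        optimizer.step()
        if step % 20 == 0:
            print("step {:3d}  nll {:8.3f}  kl {:8.3f}".format(step, float(nll), float(kl)))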