
/Main/Models/BBB_t_linear.py

https://bitbucket.org/RamiroCope/thesis_repo
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import torch.distributions as distributions

#############################
## Bayesian Neural Network ##
#############################
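
# Bayes-by-Backprop-style network (BBB, per the file name): every layer
# samples its weights from a Student-t variational posterior via the
# reparameterization trick (see StochasticLinear below).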

class BNN(nn.Module):
    def __init__(self, input_size,
                 hidden_sizes,
                 output_size=None,
                 prior_prec=1.0,
                 prec_init=1.0,
                 df_prior=5.0,
                 df_init=5.0):

        super().__init__()

        self.input_size = input_size
        sigma_prior = 1.0 / math.sqrt(prior_prec)
        sigma_init = 1.0 / math.sqrt(prec_init)
        if output_size:
            self.output_size = output_size
            self.squeeze_output = False
        else:
            self.output_size = 1
            self.squeeze_output = True
        if len(hidden_sizes) == 0:
            self.hidden_layers = []
            self.output_layer = StochasticLinear(self.input_size,
                                                 self.output_size,
                                                 sigma_prior=sigma_prior,
                                                 sigma_init=sigma_init,
                                                 df_prior=df_prior,
                                                 df_init=df_init)
        else:
            self.hidden_layers = nn.ModuleList(
                [StochasticLinear(in_size, out_size,
                                  sigma_prior=sigma_prior, df_prior=df_prior,
                                  sigma_init=sigma_init, df_init=df_init)
                 for in_size, out_size
                 in zip([self.input_size] + hidden_sizes[:-1], hidden_sizes)])
            self.output_layer = StochasticLinear(hidden_sizes[-1], self.output_size,
                                                 sigma_prior=sigma_prior, df_prior=df_prior,
                                                 sigma_init=sigma_init, df_init=df_init)

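    # Forward pass: every StochasticLinear call draws a fresh weight sample,
    # so repeated forward passes yield Monte Carlo samples of the output.
    # Note that no nonlinearity is applied between the hidden layers here.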
    def forward(self, x):
        out = x
        for layer in self.hidden_layers:
            out = layer(out)
        logits = self.output_layer(out)
        if self.squeeze_output:
            logits = torch.squeeze(logits)
        return logits

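    # Total KL term for a variational (ELBO-style) objective: the sum of
    # the per-layer KL divergences between posterior and prior.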
    def kl_divergence(self):
        kl = 0
        for layer in self.hidden_layers:
            kl += layer.kl_divergence()
        kl += self.output_layer.kl_divergence()
        return kl

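    # Collects posterior parameters and the latest weight samples as numpy
    # arrays (each layer must have run forward at least once; assumes the
    # model lives on the CPU, since .numpy() fails on CUDA tensors).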
    def return_params(self):
        mus = []
        sigmas = []
        dfs = []
        weights = []
        for layer in self.hidden_layers:
            mu, sigma, df, weight = layer.return_parameters()
            mus += [mu.numpy()]
            sigmas += [sigma.numpy()]
            dfs += [df.numpy()]
            weights += [weight.numpy()]
        mu, sigma, df, weight = self.output_layer.return_parameters()
        mus += [mu.numpy()]
        sigmas += [sigma.numpy()]
        dfs += [df.numpy()]
        weights += [weight.numpy()]
        return mus, sigmas, dfs, weights


################################################
## Student-t Mean-Field Linear Transformation ##
################################################

class StochasticLinear(nn.Module):
    """Applies a stochastic linear transformation to the incoming data:
    :math:`y = Ax + b`, where the weights (and optional bias) are sampled
    from a Student-t variational posterior on every forward pass.
    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``True``, the layer samples an additive bias.
            Default: ``False``
    Shape:
        - Input: :math:`(N, *, in\_features)` where :math:`*` means any number of
          additional dimensions
        - Output: :math:`(N, *, out\_features)` where all but the last dimension
          are the same shape as the input.
    Attributes:
        weight_mu, weight_spsigma, weight_spdf: the learnable location,
            scale, and degrees-of-freedom parameters of the weight
            posterior, each of shape `(out_features x in_features)`
        bias_mu, bias_spsigma, bias_spdf: the corresponding learnable bias
            parameters, each of shape `(out_features)`, when ``bias=True``
    Examples::
        >>> m = StochasticLinear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
    """

    def __init__(self, in_features,
                 out_features,
                 sigma_prior=1.0,
                 df_prior=5.0,
                 sigma_init=1.0,
                 df_init=5.0,
                 bias=False):

        super().__init__()

        self.count = 0

        self.in_features = in_features
        self.out_features = out_features

        self.sigma_prior = sigma_prior
        self.df_prior = df_prior

        self.sigma_init = sigma_init
        self.df_init = df_init

        # Scale (spsigma) and degrees of freedom (spdf) are stored
        # pre-softplus; softplus in forward() keeps them positive.
        self.weight_spsigma = Parameter(torch.Tensor(out_features, in_features))
        self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
        self.weight_spdf = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            # self.bias acts as a flag here; forward() overwrites it
            # with the sampled bias tensor on every pass.
            self.bias = True
            self.bias_spsigma = Parameter(torch.Tensor(out_features))
            self.bias_mu = Parameter(torch.Tensor(out_features))
            self.bias_spdf = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

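    # Locations are initialized uniformly as in nn.Linear. The pre-softplus
    # scale/df tensors are filled with sigma_init/df_init directly, so the
    # effective scale is softplus(sigma_init) rather than sigma_init itself.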
    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight_mu.size(1))
        self.weight_mu.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias_mu.data.uniform_(-stdv, stdv)
        self.weight_spsigma.data.fill_(self.sigma_init)
        self.weight_spdf.data.fill_(self.df_init)
        if self.bias is not None:
            self.bias_spsigma.data.fill_(self.sigma_init)
            self.bias_spdf.data.fill_(self.df_init)

    def forward(self, input):
        # Construct a Student-t distribution and draw the weights with
        # rsample() so gradients flow through the reparameterized sample.
        t_dist = distributions.StudentT(F.softplus(self.weight_spdf),
                                        loc=self.weight_mu,
                                        scale=F.softplus(self.weight_spsigma))

        self.weight = t_dist.rsample()

        if self.bias is not None:
            t_dist = distributions.StudentT(F.softplus(self.bias_spdf),
                                            loc=self.bias_mu,
                                            scale=F.softplus(self.bias_spsigma))
            self.bias = t_dist.rsample()
        return F.linear(input, self.weight, self.bias)

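    # KL(q || p) term. There is no simple closed form for a KL involving
    # Student-t distributions, so this uses a Gaussian-Gaussian surrogate:
    # a Normal with the same loc/scale against a zero-mean Normal prior.
    # The degrees-of-freedom parameters do not enter this term.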
    def kl_divergence(self):
        mu = self.weight_mu
        sigma = F.softplus(self.weight_spsigma) + 1e-5
        mu0 = torch.zeros_like(mu)
        sigma0 = torch.ones_like(sigma) * self.sigma_prior

        q = distributions.Normal(mu, sigma)
        p = distributions.Normal(mu0, sigma0)

        kl = distributions.kl_divergence(q, p).sum()

        if self.bias is not None:
            mu = self.bias_mu
            sigma = F.softplus(self.bias_spsigma) + 1e-5
            mu0 = torch.zeros_like(mu)
            sigma0 = torch.ones_like(sigma) * self.sigma_prior

            q = distributions.Normal(mu, sigma)
            p = distributions.Normal(mu0, sigma0)

            kl += distributions.kl_divergence(q, p).sum()

        return kl

    def return_parameters(self):
        # Note: self.weight only exists after at least one forward pass.
        mu = self.weight_mu.data
        sigma = F.softplus(self.weight_spsigma.data)
        df = F.softplus(self.weight_spdf.data)
        weight = self.weight.data
        return mu, sigma, df, weight

    def extra_repr(self):
        return ('in_features={}, out_features={}, sigma_prior={}, sigma_init={}, '
                'df_prior={}, df_init={}, bias={}'.format(
                    self.in_features, self.out_features, self.sigma_prior,
                    self.sigma_init, self.df_prior, self.df_init,
                    self.bias is not None))
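

##################################################
## Usage sketch (not part of the original file) ##
##################################################

# A minimal, illustrative training loop. The scalar-regression data, the
# Gaussian-likelihood stand-in (MSE), the KL weighting by batch size, and
# the Adam settings are assumptions for illustration, not taken from the repo.
model = BNN(input_size=10, hidden_sizes=[50, 50])
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(128, 10)
y = torch.randn(128)

for step in range(100):
    optimizer.zero_grad()
    pred = model(x)                                 # one Monte Carlo weight sample
    nll = F.mse_loss(pred, y)                       # stand-in for the negative log-likelihood
    loss = nll + model.kl_divergence() / x.size(0)  # ELBO-style objective
    loss.backward()
    optimizer.step()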