# Main/Models/BBB_t_linear.py
# Bayes-by-Backprop linear layers with Student-t variational posteriors.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as distributions
from torch.nn.parameter import Parameter

#############################
## Bayesian Neural Network ##
#############################

class BNN(nn.Module):
    """Mean-field Bayesian neural network built from StochasticLinear layers.

    Each linear map samples its weights from a Student-t variational
    posterior on every forward pass.  NOTE(review): there is no activation
    function between layers — the network is a chain of stochastic linear
    maps, which appears intentional here (confirm against the caller).

    Args:
        input_size: number of input features.
        hidden_sizes: list of hidden-layer widths; may be empty.
        output_size: number of outputs.  If falsy (None/0), a single output
            is produced and its trailing dimension is squeezed away.
        prior_prec: prior precision; sigma_prior = 1 / sqrt(prior_prec).
        prec_init: initial posterior precision; sigma_init = 1 / sqrt(prec_init).
        df_prior: degrees-of-freedom hyperparameter forwarded to each layer.
        df_init: initial degrees-of-freedom value for each layer.
    """

    def __init__(self, input_size,
                 hidden_sizes,
                 output_size=None,
                 prior_prec=1.0,
                 prec_init=1.0,
                 df_prior=5.0,
                 df_init=5.0):
        # FIX: the original used super(type(self), self).__init__(), which
        # recurses forever as soon as the class is subclassed; the zero-arg
        # form resolves the MRO correctly.
        super().__init__()

        self.input_size = input_size
        sigma_prior = 1.0 / math.sqrt(prior_prec)
        sigma_init = 1.0 / math.sqrt(prec_init)

        if output_size:
            self.output_size = output_size
            self.squeeze_output = False
        else:
            # Scalar regression/classification head: one output, squeezed.
            self.output_size = 1
            self.squeeze_output = True

        if len(hidden_sizes) == 0:
            # Degenerate case: a single stochastic linear map, no hidden stack.
            self.hidden_layers = []
            self.output_layer = StochasticLinear(
                self.input_size, self.output_size,
                sigma_prior=sigma_prior, sigma_init=sigma_init,
                df_prior=df_prior, df_init=df_init)
        else:
            # Pair consecutive widths: (input, h1), (h1, h2), ...
            sizes = [self.input_size] + hidden_sizes
            self.hidden_layers = nn.ModuleList([
                StochasticLinear(in_size, out_size,
                                 sigma_prior=sigma_prior, df_prior=df_prior,
                                 sigma_init=sigma_init, df_init=df_init)
                for in_size, out_size in zip(sizes[:-1], sizes[1:])])
            self.output_layer = StochasticLinear(
                hidden_sizes[-1], self.output_size,
                sigma_prior=sigma_prior, df_prior=df_prior,
                sigma_init=sigma_init, df_init=df_init)

    def forward(self, x):
        """Run the stochastic forward pass.

        Returns a fresh sample of the network output; every call re-samples
        the weights, so repeated calls on the same input differ.
        """
        out = x
        for layer in self.hidden_layers:
            out = layer(out)
        logits = self.output_layer(out)
        if self.squeeze_output:
            # FIX: squeeze only the trailing output dimension.  The original
            # torch.squeeze(logits) removed every size-1 dimension, so a
            # batch of size 1 silently collapsed to a 0-d scalar.
            logits = logits.squeeze(-1)
        return logits

    def kl_divergence(self):
        """Sum the per-layer KL(q || p) terms over all stochastic layers."""
        kl = 0
        for layer in self.hidden_layers:
            kl = kl + layer.kl_divergence()
        kl = kl + self.output_layer.kl_divergence()
        return kl

    def return_params(self):
        """Return per-layer (mu, sigma, df, sampled weight) as numpy arrays.

        Must be called after at least one forward pass, since the sampled
        weight tensor is created in StochasticLinear.forward().
        """
        mus, sigmas, dfs, weights = [], [], [], []
        layers = list(self.hidden_layers) + [self.output_layer]
        for layer in layers:
            mu, sigma, df, weight = layer.return_parameters()
            # FIX: detach + move to CPU so this also works for CUDA tensors
            # and tensors that participate in autograd.
            mus.append(mu.detach().cpu().numpy())
            sigmas.append(sigma.detach().cpu().numpy())
            dfs.append(df.detach().cpu().numpy())
            weights.append(weight.detach().cpu().numpy())
        return mus, sigmas, dfs, weights


#################################################
## Student-t Mean-Field Linear Transformation  ##
#################################################

class StochasticLinear(nn.Module):
    r"""Stochastic linear transformation :math:`y = Ax + b` with a Student-t
    mean-field variational posterior over the weights.

    On every forward pass the weight matrix (and bias, if enabled) is
    re-sampled via the reparameterized Student-t distribution, so the layer
    output is stochastic.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        sigma_prior: scale of the Normal prior used in kl_divergence().
        df_prior: degrees-of-freedom hyperparameter (stored; not used in the
            KL below, which is computed against a Normal prior —
            NOTE(review): this looks like a deliberate approximation, since
            the t/t KL has no closed form; confirm with the model author).
        sigma_init: initial value written into the pre-softplus scale
            parameter.  NOTE(review): forward() applies softplus, so the
            effective initial scale is softplus(sigma_init), not sigma_init.
        df_init: initial value written into the pre-softplus df parameter
            (same softplus caveat as sigma_init).
        bias: if False (default), the layer has no additive bias.

    Shape:
        - Input: :math:`(N, *, \text{in\_features})`
        - Output: :math:`(N, *, \text{out\_features})`
    """

    def __init__(self, in_features,
                 out_features,
                 sigma_prior=1.0,
                 df_prior=5.0,
                 sigma_init=1.0,
                 df_init=5.0,
                 bias=False):
        # FIX: zero-arg super() instead of super(type(self), self) — the
        # latter recurses infinitely when this class is subclassed.
        super().__init__()

        self.count = 0  # unused; kept for backward compatibility

        self.in_features = in_features
        self.out_features = out_features

        self.sigma_prior = sigma_prior
        self.df_prior = df_prior

        self.sigma_init = sigma_init
        self.df_init = df_init

        # Variational parameters: mean, pre-softplus scale, pre-softplus df.
        self.weight_spsigma = Parameter(torch.Tensor(out_features, in_features))
        self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
        self.weight_spdf = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            # True acts as a sentinel until forward() replaces it with the
            # sampled bias tensor; `self.bias is not None` gates both paths.
            self.bias = True
            self.bias_spsigma = Parameter(torch.Tensor(out_features))
            self.bias_mu = Parameter(torch.Tensor(out_features))
            self.bias_spdf = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize means uniformly (fan-in scaled) and fill the
        pre-softplus scale/df parameters with their init constants."""
        stdv = 1. / math.sqrt(self.weight_mu.size(1))
        self.weight_mu.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias_mu.data.uniform_(-stdv, stdv)
        self.weight_spsigma.data.fill_(self.sigma_init)
        self.weight_spdf.data.fill_(self.df_init)
        if self.bias is not None:
            self.bias_spsigma.data.fill_(self.sigma_init)
            self.bias_spdf.data.fill_(self.df_init)

    def forward(self, input):
        """Sample weights (and bias) via the reparameterization trick and
        apply the linear map.  The sample is stored on self.weight /
        self.bias so return_parameters() can report it."""
        t_dist = distributions.StudentT(F.softplus(self.weight_spdf),
                                        loc=self.weight_mu,
                                        scale=F.softplus(self.weight_spsigma))
        self.weight = t_dist.rsample()

        if self.bias is not None:
            t_dist = distributions.StudentT(F.softplus(self.bias_spdf),
                                            loc=self.bias_mu,
                                            scale=F.softplus(self.bias_spsigma))
            self.bias = t_dist.rsample()
        return F.linear(input, self.weight, self.bias)

    def kl_divergence(self):
        """KL between a Normal(mu, softplus(spsigma)) surrogate posterior and
        the Normal(0, sigma_prior) prior, summed over all weights (and bias).

        The 1e-5 floor keeps the scale strictly positive for the Normal
        constructor.
        """
        mu = self.weight_mu
        sigma = F.softplus(self.weight_spsigma) + 1e-5
        q = distributions.Normal(mu, sigma)
        p = distributions.Normal(torch.zeros_like(mu),
                                 torch.ones_like(sigma) * self.sigma_prior)
        kl = distributions.kl_divergence(q, p).sum()

        if self.bias is not None:
            mu = self.bias_mu
            sigma = F.softplus(self.bias_spsigma) + 1e-5
            q = distributions.Normal(mu, sigma)
            p = distributions.Normal(torch.zeros_like(mu),
                                     torch.ones_like(sigma) * self.sigma_prior)
            kl = kl + distributions.kl_divergence(q, p).sum()

        return kl

    def return_parameters(self):
        """Return (mu, sigma, df, sampled weight) tensors.

        Valid only after forward() has run once (self.weight is created
        there).
        """
        mu = self.weight_mu.data
        sigma = F.softplus(self.weight_spsigma.data)
        df = F.softplus(self.weight_spdf.data)
        weight = self.weight.data
        return mu, sigma, df, weight

    def extra_repr(self):
        return 'in_features={}, out_features={}, sigma_prior={}, sigma_init={}, bias={}'.format(
            self.in_features, self.out_features, self.sigma_prior, self.sigma_init, self.bias is not None
        )