/pytorch/vadam/models.py

https://github.com/emtiyaz/vadam

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter


############################
## Multi-Layer Perceptron ##
############################

class MLP(nn.Module):

    def __init__(self, input_size, hidden_sizes, output_size, act_func="relu"):
        super(MLP, self).__init__()
        self.input_size = input_size
        self.hidden_sizes = hidden_sizes
        if output_size is not None:
            self.output_size = output_size
            self.squeeze_output = False
        else:
            self.output_size = 1
            self.squeeze_output = True

        # Set activation function
        if act_func == "relu":
            self.act = F.relu
        elif act_func == "tanh":
            self.act = F.tanh
        elif act_func == "sigmoid":
            self.act = F.sigmoid

        # Define layers
        if len(hidden_sizes) == 0:
            # Linear model
            self.hidden_layers = []
            self.output_layer = nn.Linear(self.input_size, self.output_size)
        else:
            # Neural network
            self.hidden_layers = nn.ModuleList([nn.Linear(in_size, out_size) for in_size, out_size in zip([self.input_size] + hidden_sizes[:-1], hidden_sizes)])
            self.output_layer = nn.Linear(hidden_sizes[-1], self.output_size)

    def forward(self, x):
        x = x.view(-1, self.input_size)
        out = x
        for layer in self.hidden_layers:
            out = self.act(layer(out))
        z = self.output_layer(out)
        if self.squeeze_output:
            z = torch.squeeze(z).view([-1])
        return z
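
# --- Illustrative usage sketch (added here for clarity; not part of the original file) ---
# A minimal forward pass through MLP. The 784-dimensional input and the layer
# sizes below are assumptions chosen only for illustration.
def _mlp_usage_example():
    model = MLP(input_size=784, hidden_sizes=[400, 400], output_size=10, act_func="relu")
    x = torch.randn(32, 784)     # a batch of 32 flattened inputs
    logits = model(x)            # shape: (32, 10)
    return logits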

#############################
## Bayesian Neural Network ##
#############################

class BNN(nn.Module):

    def __init__(self, input_size, hidden_sizes, output_size, act_func="relu", prior_prec=1.0, prec_init=1.0):
        super(type(self), self).__init__()
        self.input_size = input_size
        sigma_prior = 1.0 / math.sqrt(prior_prec)
        sigma_init = 1.0 / math.sqrt(prec_init)
        if output_size:
            self.output_size = output_size
            self.squeeze_output = False
        else:
            self.output_size = 1
            self.squeeze_output = True

        # Set activation function
        if act_func == "relu":
            self.act = F.relu
        elif act_func == "tanh":
            self.act = F.tanh
        elif act_func == "sigmoid":
            self.act = F.sigmoid

        # Define layers
        if len(hidden_sizes) == 0:
            # Linear model
            self.hidden_layers = []
            self.output_layer = StochasticLinear(self.input_size, self.output_size, sigma_prior=sigma_prior, sigma_init=sigma_init)
        else:
            # Neural network
            self.hidden_layers = nn.ModuleList([StochasticLinear(in_size, out_size, sigma_prior=sigma_prior, sigma_init=sigma_init) for in_size, out_size in zip([self.input_size] + hidden_sizes[:-1], hidden_sizes)])
            self.output_layer = StochasticLinear(hidden_sizes[-1], self.output_size, sigma_prior=sigma_prior, sigma_init=sigma_init)

    def forward(self, x):
        x = x.view(-1, self.input_size)
        out = x
        for layer in self.hidden_layers:
            out = self.act(layer(out))
        z = self.output_layer(out)
        if self.squeeze_output:
            z = torch.squeeze(z).view([-1])
        return z

    def kl_divergence(self):
        kl = 0
        for layer in self.hidden_layers:
            kl += layer.kl_divergence()
        kl += self.output_layer.kl_divergence()
        return kl
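
# --- Illustrative usage sketch (added here for clarity; not part of the original file) ---
# How BNN's stochastic forward pass and kl_divergence() are typically combined
# into a variational (negative-ELBO style) objective. The dataset size, batch,
# and cross-entropy loss below are assumptions for illustration only.
def _bnn_loss_example(n_data=60000):
    model = BNN(input_size=784, hidden_sizes=[400], output_size=10, prior_prec=1.0, prec_init=1.0)
    x = torch.randn(16, 784)
    y = torch.randint(0, 10, (16,))
    logits = model(x)                              # one Monte Carlo sample of the weights
    nll = F.cross_entropy(logits, y)
    loss = nll + model.kl_divergence() / n_data    # KL term scaled by the dataset size
    return loss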

###############################################
## Gaussian Mean-Field Linear Transformation ##
###############################################

class StochasticLinear(nn.Module):
    """Applies a stochastic linear transformation to the incoming data: :math:`y = Ax + b`.

    This is a stochastic variant of the in-built torch.nn.Linear().
    """

    def __init__(self, in_features, out_features, sigma_prior=1.0, sigma_init=1.0, bias=True):
        super(type(self), self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.sigma_prior = sigma_prior
        self.sigma_init = sigma_init
        self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
        self.weight_spsigma = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = True
            self.bias_mu = Parameter(torch.Tensor(out_features))
            self.bias_spsigma = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight_mu.size(1))
        self.weight_mu.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias_mu.data.uniform_(-stdv, stdv)
        self.weight_spsigma.data.fill_(math.log(math.exp(self.sigma_init) - 1))
        if self.bias is not None:
            self.bias_spsigma.data.fill_(math.log(math.exp(self.sigma_init) - 1))
    def forward(self, input):
        # Reparameterization trick: sample a fresh weight (and bias) from the
        # factorized Gaussian posterior on every forward pass.
        epsilon_W = torch.normal(mean=torch.zeros_like(self.weight_mu), std=1.0)
        weight = self.weight_mu + F.softplus(self.weight_spsigma) * epsilon_W
        bias = None
        if self.bias is not None:
            epsilon_b = torch.normal(mean=torch.zeros_like(self.bias_mu), std=1.0)
            bias = self.bias_mu + F.softplus(self.bias_spsigma) * epsilon_b
        return F.linear(input, weight, bias)
    def _kl_gaussian(self, p_mu, p_sigma, q_mu, q_sigma):
        var_ratio = (p_sigma / q_sigma).pow(2)
        t1 = ((p_mu - q_mu) / q_sigma).pow(2)
        return 0.5 * torch.sum((var_ratio + t1 - 1 - var_ratio.log()))

    def kl_divergence(self):
        # Compute KL divergence between current distribution and the prior.
        mu = self.weight_mu
        sigma = F.softplus(self.weight_spsigma)
        mu0 = torch.zeros_like(mu)
        sigma0 = torch.ones_like(sigma) * self.sigma_prior
        kl = self._kl_gaussian(p_mu=mu, p_sigma=sigma, q_mu=mu0, q_sigma=sigma0)
        if self.bias is not None:
            mu = self.bias_mu
            sigma = F.softplus(self.bias_spsigma)
            mu0 = torch.zeros_like(mu)
            sigma0 = torch.ones_like(sigma) * self.sigma_prior
            kl += self._kl_gaussian(p_mu=mu, p_sigma=sigma, q_mu=mu0, q_sigma=sigma0)
        return kl

    def extra_repr(self):
        return 'in_features={}, out_features={}, sigma_prior={}, sigma_init={}, bias={}'.format(
            self.in_features, self.out_features, self.sigma_prior, self.sigma_init, self.bias is not None
        )
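
# --- Illustrative sanity check (added here for clarity; not part of the original file) ---
# StochasticLinear samples weights via the reparameterization trick,
#   W = mu + softplus(rho) * eps,  eps ~ N(0, I),
# and _kl_gaussian() is the closed-form KL between diagonal Gaussians,
#   KL(N(mu, sigma^2) || N(mu0, sigma0^2))
#     = 0.5 * sum(sigma^2/sigma0^2 + (mu - mu0)^2/sigma0^2 - 1 - log(sigma^2/sigma0^2)).
# Under these formulas, the KL should vanish when the posterior matches the
# unit-Gaussian prior, which the sketch below checks.
def _stochastic_linear_kl_check():
    layer = StochasticLinear(in_features=5, out_features=3, sigma_prior=1.0, sigma_init=1.0)
    with torch.no_grad():
        layer.weight_mu.zero_()
        layer.bias_mu.zero_()
        # softplus(log(e - 1)) == 1.0, matching the unit-variance prior
        layer.weight_spsigma.fill_(math.log(math.e - 1))
        layer.bias_spsigma.fill_(math.log(math.e - 1))
    return layer.kl_divergence()   # ~0 up to floating-point error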

#################################################################
## MultiLayer Perceptron with support for individual gradients ##
#################################################################

class IndividualGradientMLP(nn.Module):

    def __init__(self, input_size, hidden_sizes, output_size, act_func="relu"):
        super(type(self), self).__init__()
        self.input_size = input_size
        self.hidden_sizes = hidden_sizes
        if output_size is not None:
            self.output_size = output_size
            self.squeeze_output = False
        else:
            self.output_size = 1
            self.squeeze_output = True

        # Set activation function
        if act_func == "relu":
            self.act = F.relu
        elif act_func == "tanh":
            self.act = F.tanh
        elif act_func == "sigmoid":
            self.act = F.sigmoid

        # Define layers
        if len(hidden_sizes) == 0:
            # Linear model
            self.hidden_layers = []
            self.output_layer = nn.Linear(self.input_size, self.output_size)
        else:
            # Neural network
            self.hidden_layers = nn.ModuleList([nn.Linear(in_size, out_size) for in_size, out_size in zip([self.input_size] + hidden_sizes[:-1], hidden_sizes)])
            self.output_layer = nn.Linear(hidden_sizes[-1], self.output_size)
    def forward(self, x, individual_grads=False):
        '''
        x: The input patterns/features.
        individual_grads: Whether or not the activation tensors and linear
            combination tensors from each layer are returned. These tensors
            are necessary for computing the GGN using goodfellow_backprop_ggn.
        '''
        x = x.view(-1, self.input_size)
        out = x

        # Save the model inputs, which are considered the activations of the
        # 0'th layer.
        if individual_grads:
            H_list = [out]
            Z_list = []

        for layer in self.hidden_layers:
            Z = layer(out)
            out = self.act(Z)
            # Save the activations and linear combinations from this layer.
            if individual_grads:
                H_list.append(out)
                Z.retain_grad()
                Z.requires_grad_(True)
                Z_list.append(Z)

        z = self.output_layer(out)
        if self.squeeze_output:
            z = torch.squeeze(z).view([-1])

        # Save the final model outputs, which are the linear combinations
        # from the final layer.
        if individual_grads:
            z.retain_grad()
            z.requires_grad_(True)
            Z_list.append(z)

        if individual_grads:
            return (z, H_list, Z_list)

        return z
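
# --- Illustrative usage sketch (added here for clarity; not part of the original file) ---
# With individual_grads=True, forward() also returns the per-layer activations
# (H_list) and pre-activations (Z_list) whose retained .grad fields feed the
# Goodfellow-style per-example gradient / GGN computation mentioned in the
# docstring (goodfellow_backprop_ggn is defined elsewhere in the repository).
# The regression-style loss below is an assumption for illustration only.
def _individual_grads_example():
    model = IndividualGradientMLP(input_size=10, hidden_sizes=[20], output_size=None)
    x = torch.randn(8, 10)
    y = torch.randn(8)
    z, H_list, Z_list = model(x, individual_grads=True)
    loss = F.mse_loss(z, y)
    loss.backward()
    # Each Z.grad holds the gradient of the loss w.r.t. that layer's linear
    # combinations; per-example quantities are recovered downstream from
    # H_list and Z_list.
    return [Z.grad for Z in Z_list]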