/layers.py

https://github.com/iwtw/pytorch-TP-GAN
# wrappers for convenience
import torch.nn as nn
from torch.nn.init import xavier_normal_, kaiming_normal_


def sequential(*kargs):
    """Chain layers in an nn.Sequential and expose the output width of the
    last layer that defines one as `out_channels`, so wrapped blocks can be
    composed like single layers."""
    seq = nn.Sequential(*kargs)
    for layer in reversed(kargs):
        if hasattr(layer, 'out_channels'):
            seq.out_channels = layer.out_channels
            break
        if hasattr(layer, 'out_features'):
            seq.out_channels = layer.out_features
            break
    return seq


def weight_initialization(weight, init, activation):
    """Initialize `weight` in place; for Kaiming init the negative slope of a
    leaky-ReLU-style activation is forwarded as the `a` parameter."""
    if init is None:
        return
    if init == "kaiming":
        assert activation is not None
        if hasattr(activation, "negative_slope"):
            kaiming_normal_(weight, a=activation.negative_slope)
        else:
            kaiming_normal_(weight, a=0)
    elif init == "xavier":
        xavier_normal_(weight)
    return


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0,
         init="kaiming", activation=nn.ReLU(), use_batchnorm=False):
    """Build (optional ReflectionPad2d) -> Conv2d -> activation -> BatchNorm2d.
    A 4-element `padding` list selects reflection padding; an int pads the
    Conv2d directly."""
    convs = []
    if isinstance(padding, list):
        assert len(padding) != 3
        if len(padding) == 4:
            convs.append(nn.ReflectionPad2d(padding))
            padding = 0
    convs.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding))
    # weight init
    weight_initialization(convs[-1].weight, init, activation)
    # activation
    if activation is not None:
        convs.append(activation)
    # batch norm
    if use_batchnorm:
        convs.append(nn.BatchNorm2d(out_channels))
    seq = nn.Sequential(*convs)
    seq.out_channels = out_channels
    return seq
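
# Example (hypothetical usage, not from the original file):
#   conv(3, 64, 7, stride=1, padding=[3, 3, 3, 3]) builds
#   ReflectionPad2d -> Conv2d -> ReLU, where the list gives
#   (left, right, top, bottom) reflection padding and the Conv2d itself then
#   runs with padding=0; the returned nn.Sequential carries out_channels = 64.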


def deconv(in_channels, out_channels, kernel_size, stride=1, padding=0,
           output_padding=0, init="kaiming", activation=nn.ReLU(),
           use_batchnorm=False):
    """Build ConvTranspose2d -> activation -> BatchNorm2d (upsampling twin of `conv`)."""
    convs = []
    convs.append(nn.ConvTranspose2d(in_channels, out_channels, kernel_size,
                                    stride, padding, output_padding))
    # weight init
    weight_initialization(convs[0].weight, init, activation)
    # activation
    if activation is not None:
        convs.append(activation)
    # batch norm
    if use_batchnorm:
        convs.append(nn.BatchNorm2d(out_channels))
    seq = nn.Sequential(*convs)
    seq.out_channels = out_channels
    return seq


class ResidualBlock(nn.Module):
    """Residual block computing `activation(layers(x) + scaling_factor * shortcut(x))`.
    With stride 2 the shortcut is a strided 1x1 conv projection; with stride 1
    it is the identity."""

    def __init__(self, in_channels,
                 out_channels=None,
                 kernel_size=3,
                 stride=1,
                 padding=None,
                 weight_init="kaiming",
                 activation=nn.ReLU(),
                 is_bottleneck=False,
                 use_projection=False,
                 scaling_factor=1.0):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels // stride
        self.out_channels = out_channels
        self.use_projection = use_projection
        self.scaling_factor = scaling_factor
        self.activation = activation
        convs = []
        assert stride in [1, 2]
        if stride == 1:
            # identity shortcut
            self.shortcut = nn.Sequential()
        else:
            # strided 1x1 conv projection, no init/activation/batchnorm
            self.shortcut = conv(in_channels, out_channels, 1, stride, 0, None, None, False)
        if is_bottleneck:
            # 1x1 reduce -> kxk (possibly strided) -> 1x1 expand
            convs.append(conv(in_channels, in_channels // 2, 1, 1, 0, weight_init, activation, False))
            convs.append(conv(in_channels // 2, out_channels // 2, kernel_size, stride,
                              (kernel_size - 1) // 2, weight_init, activation, False))
            convs.append(conv(out_channels // 2, out_channels, 1, 1, 0, None, None, False))
        else:
            # two kxk convs; the second has no activation (it is applied after the add)
            convs.append(conv(in_channels, in_channels, kernel_size, 1,
                              padding if padding is not None else (kernel_size - 1) // 2,
                              weight_init, activation, False))
            convs.append(conv(in_channels, out_channels, kernel_size, 1,
                              padding if padding is not None else (kernel_size - 1) // 2,
                              None, None, False))
        self.layers = nn.Sequential(*convs)

    def forward(self, x):
        return self.activation(self.layers(x) + self.scaling_factor * self.shortcut(x))
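

# ---------------------------------------------------------------------------
# Minimal smoke test, not part of the original repository: a sketch of how the
# wrappers above might be exercised. Shapes and hyper-parameters below are
# illustrative assumptions only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    x = torch.randn(2, 3, 64, 64)

    # 7x7 conv with reflection padding (left, right, top, bottom), then ReLU.
    c = conv(3, 64, 7, stride=1, padding=[3, 3, 3, 3])
    y = c(x)
    print(y.shape, c.out_channels)   # torch.Size([2, 64, 64, 64]) 64

    # Strided bottleneck residual block with a 1x1 projection shortcut.
    block = ResidualBlock(64, out_channels=128, stride=2, is_bottleneck=True)
    z = block(y)
    print(z.shape)                   # torch.Size([2, 128, 32, 32])

    # Transposed conv that doubles the spatial resolution again.
    d = deconv(128, 64, 4, stride=2, padding=1)
    print(d(z).shape)                # torch.Size([2, 64, 64, 64])

    # sequential() copies out_channels from the last wrapped layer that has it.
    print(sequential(c, d).out_channels)   # 64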