/torch/csrc/api/include/torch/nn/options/conv.h

https://github.com/pytorch/pytorch · C Header · 382 lines · 113 code · 72 blank · 197 comment · 0 complexity · 13449af63ba4092a883529a212371d61 MD5 · raw file

  1. #pragma once
  2. #include <torch/arg.h>
  3. #include <torch/enum.h>
  4. #include <torch/csrc/WindowsTorchApiMacro.h>
  5. #include <torch/expanding_array.h>
  6. #include <torch/types.h>
  7. namespace torch {
  8. namespace nn {
  9. namespace detail {
  10. typedef c10::variant<
  11. enumtype::kZeros,
  12. enumtype::kReflect,
  13. enumtype::kReplicate,
  14. enumtype::kCircular
  15. > conv_padding_mode_t;
  16. /// Options for a `D`-dimensional convolution or convolution transpose module.
  17. template <size_t D>
  18. struct ConvNdOptions {
  19. ConvNdOptions(
  20. int64_t in_channels,
  21. int64_t out_channels,
  22. ExpandingArray<D> kernel_size) :
  23. in_channels_(in_channels),
  24. out_channels_(out_channels),
  25. kernel_size_(std::move(kernel_size)) {}
  26. /// The number of channels the input volumes will have.
  27. /// Changing this parameter after construction __has no effect__.
  28. TORCH_ARG(int64_t, in_channels);
  29. /// The number of output channels the convolution should produce.
  30. /// Changing this parameter after construction __has no effect__.
  31. TORCH_ARG(int64_t, out_channels);
  32. /// The kernel size to use.
  33. /// For a `D`-dim convolution, must be a single number or a list of `D`
  34. /// numbers.
  35. /// This parameter __can__ be changed after construction.
  36. TORCH_ARG(ExpandingArray<D>, kernel_size);
  37. /// The stride of the convolution.
  38. /// For a `D`-dim convolution, must be a single number or a list of `D`
  39. /// numbers.
  40. /// This parameter __can__ be changed after construction.
  41. TORCH_ARG(ExpandingArray<D>, stride) = 1;
  42. /// The padding to add to the input volumes.
  43. /// For a `D`-dim convolution, must be a single number or a list of `D`
  44. /// numbers.
  45. /// This parameter __can__ be changed after construction.
  46. TORCH_ARG(ExpandingArray<D>, padding) = 0;
  47. /// The kernel dilation.
  48. /// For a `D`-dim convolution, must be a single number or a list of `D`
  49. /// numbers.
  50. /// This parameter __can__ be changed after construction.
  51. TORCH_ARG(ExpandingArray<D>, dilation) = 1;
  52. /// If true, convolutions will be transpose convolutions (a.k.a.
  53. /// deconvolutions).
  54. /// Changing this parameter after construction __has no effect__.
  55. TORCH_ARG(bool, transposed) = false;
  56. /// For transpose convolutions, the padding to add to output volumes.
  57. /// For a `D`-dim convolution, must be a single number or a list of `D`
  58. /// numbers.
  59. /// This parameter __can__ be changed after construction.
  60. TORCH_ARG(ExpandingArray<D>, output_padding) = 0;
  61. /// The number of convolution groups.
  62. /// This parameter __can__ be changed after construction.
  63. TORCH_ARG(int64_t, groups) = 1;
  64. /// Whether to add a bias after individual applications of the kernel.
  65. /// Changing this parameter after construction __has no effect__.
  66. TORCH_ARG(bool, bias) = true;
  67. /// Accepted values `torch::kZeros`, `torch::kReflect`, `torch::kReplicate` or `torch::kCircular`. Default: `torch::kZeros`
  68. TORCH_ARG(conv_padding_mode_t, padding_mode) = torch::kZeros;
  69. };
  70. } // namespace detail
  71. // ============================================================================
  72. /// Options for a `D`-dimensional convolution module.
  73. template <size_t D>
  74. struct ConvOptions {
  75. using padding_mode_t = detail::conv_padding_mode_t;
  76. ConvOptions(
  77. int64_t in_channels,
  78. int64_t out_channels,
  79. ExpandingArray<D> kernel_size) :
  80. in_channels_(in_channels),
  81. out_channels_(out_channels),
  82. kernel_size_(std::move(kernel_size)) {}
  83. /// The number of channels the input volumes will have.
  84. /// Changing this parameter after construction __has no effect__.
  85. TORCH_ARG(int64_t, in_channels);
  86. /// The number of output channels the convolution should produce.
  87. /// Changing this parameter after construction __has no effect__.
  88. TORCH_ARG(int64_t, out_channels);
  89. /// The kernel size to use.
  90. /// For a `D`-dim convolution, must be a single number or a list of `D`
  91. /// numbers.
  92. /// This parameter __can__ be changed after construction.
  93. TORCH_ARG(ExpandingArray<D>, kernel_size);
  94. /// The stride of the convolution.
  95. /// For a `D`-dim convolution, must be a single number or a list of `D`
  96. /// numbers.
  97. /// This parameter __can__ be changed after construction.
  98. TORCH_ARG(ExpandingArray<D>, stride) = 1;
  99. /// The padding to add to the input volumes.
  100. /// For a `D`-dim convolution, must be a single number or a list of `D`
  101. /// numbers.
  102. /// This parameter __can__ be changed after construction.
  103. TORCH_ARG(ExpandingArray<D>, padding) = 0;
  104. /// The kernel dilation.
  105. /// For a `D`-dim convolution, must be a single number or a list of `D`
  106. /// numbers.
  107. /// This parameter __can__ be changed after construction.
  108. TORCH_ARG(ExpandingArray<D>, dilation) = 1;
  109. /// The number of convolution groups.
  110. /// This parameter __can__ be changed after construction.
  111. TORCH_ARG(int64_t, groups) = 1;
  112. /// Whether to add a bias after individual applications of the kernel.
  113. /// Changing this parameter after construction __has no effect__.
  114. TORCH_ARG(bool, bias) = true;
  115. /// Accepted values `torch::kZeros`, `torch::kReflect`, `torch::kReplicate` or `torch::kCircular`. Default: `torch::kZeros`
  116. TORCH_ARG(padding_mode_t, padding_mode) = torch::kZeros;
  117. };
  118. /// `ConvOptions` specialized for the `Conv1d` module.
  119. ///
  120. /// Example:
  121. /// ```
  122. /// Conv1d model(Conv1dOptions(3, 2, 3).stride(1).bias(false));
  123. /// ```
  124. using Conv1dOptions = ConvOptions<1>;
  125. /// `ConvOptions` specialized for the `Conv2d` module.
  126. ///
  127. /// Example:
  128. /// ```
  129. /// Conv2d model(Conv2dOptions(3, 2, 3).stride(1).bias(false));
  130. /// ```
  131. using Conv2dOptions = ConvOptions<2>;
  132. /// `ConvOptions` specialized for the `Conv3d` module.
  133. ///
  134. /// Example:
  135. /// ```
  136. /// Conv3d model(Conv3dOptions(3, 2, 3).stride(1).bias(false));
  137. /// ```
  138. using Conv3dOptions = ConvOptions<3>;
  139. // ============================================================================
  140. namespace functional {
  141. /// Options for a `D`-dimensional convolution functional.
  142. template <size_t D>
  143. struct ConvFuncOptions {
  144. /// optional bias of shape `(out_channels)`. Default: ``None``
  145. TORCH_ARG(torch::Tensor, bias) = Tensor();
  146. /// The stride of the convolving kernel.
  147. /// For a `D`-dim convolution, must be a single number or a list of `D`
  148. /// numbers.
  149. TORCH_ARG(ExpandingArray<D>, stride) = 1;
  150. /// Implicit paddings on both sides of the input.
  151. /// For a `D`-dim convolution, must be a single number or a list of `D`
  152. /// numbers.
  153. TORCH_ARG(ExpandingArray<D>, padding) = 0;
  154. /// The spacing between kernel elements.
  155. /// For a `D`-dim convolution, must be a single number or a list of `D`
  156. /// numbers.
  157. TORCH_ARG(ExpandingArray<D>, dilation) = 1;
  158. /// Split input into groups, `in_channels` should be divisible by
  159. /// the number of groups.
  160. TORCH_ARG(int64_t, groups) = 1;
  161. };
  162. /// `ConvFuncOptions` specialized for `torch::nn::functional::conv1d`.
  163. ///
  164. /// Example:
  165. /// ```
  166. /// namespace F = torch::nn::functional;
  167. /// F::conv1d(x, weight, F::Conv1dFuncOptions().stride(1));
  168. /// ```
  169. using Conv1dFuncOptions = ConvFuncOptions<1>;
  170. /// `ConvFuncOptions` specialized for `torch::nn::functional::conv2d`.
  171. ///
  172. /// Example:
  173. /// ```
  174. /// namespace F = torch::nn::functional;
  175. /// F::conv2d(x, weight, F::Conv2dFuncOptions().stride(1));
  176. /// ```
  177. using Conv2dFuncOptions = ConvFuncOptions<2>;
  178. /// `ConvFuncOptions` specialized for `torch::nn::functional::conv3d`.
  179. ///
  180. /// Example:
  181. /// ```
  182. /// namespace F = torch::nn::functional;
  183. /// F::conv3d(x, weight, F::Conv3dFuncOptions().stride(1));
  184. /// ```
  185. using Conv3dFuncOptions = ConvFuncOptions<3>;
  186. } // namespace functional
  187. // ============================================================================
  188. template <size_t D>
  189. struct ConvTransposeOptions {
  190. using padding_mode_t = detail::conv_padding_mode_t;
  191. ConvTransposeOptions(
  192. int64_t in_channels,
  193. int64_t out_channels,
  194. ExpandingArray<D> kernel_size) :
  195. in_channels_(in_channels),
  196. out_channels_(out_channels),
  197. kernel_size_(std::move(kernel_size)) {}
  198. /// The number of channels the input volumes will have.
  199. /// Changing this parameter after construction __has no effect__.
  200. TORCH_ARG(int64_t, in_channels);
  201. /// The number of output channels the convolution should produce.
  202. /// Changing this parameter after construction __has no effect__.
  203. TORCH_ARG(int64_t, out_channels);
  204. /// The kernel size to use.
  205. /// For a `D`-dim convolution, must be a single number or a list of `D`
  206. /// numbers.
  207. /// This parameter __can__ be changed after construction.
  208. TORCH_ARG(ExpandingArray<D>, kernel_size);
  209. /// The stride of the convolution.
  210. /// For a `D`-dim convolution, must be a single number or a list of `D`
  211. /// numbers.
  212. /// This parameter __can__ be changed after construction.
  213. TORCH_ARG(ExpandingArray<D>, stride) = 1;
  214. /// The padding to add to the input volumes.
  215. /// For a `D`-dim convolution, must be a single number or a list of `D`
  216. /// numbers.
  217. /// This parameter __can__ be changed after construction.
  218. TORCH_ARG(ExpandingArray<D>, padding) = 0;
  219. /// For transpose convolutions, the padding to add to output volumes.
  220. /// For a `D`-dim convolution, must be a single number or a list of `D`
  221. /// numbers.
  222. /// This parameter __can__ be changed after construction.
  223. TORCH_ARG(ExpandingArray<D>, output_padding) = 0;
  224. /// The number of convolution groups.
  225. /// This parameter __can__ be changed after construction.
  226. TORCH_ARG(int64_t, groups) = 1;
  227. /// Whether to add a bias after individual applications of the kernel.
  228. /// Changing this parameter after construction __has no effect__.
  229. TORCH_ARG(bool, bias) = true;
  230. /// The kernel dilation.
  231. /// For a `D`-dim convolution, must be a single number or a list of `D`
  232. /// numbers.
  233. /// This parameter __can__ be changed after construction.
  234. TORCH_ARG(ExpandingArray<D>, dilation) = 1;
  235. /// Accepted values `torch::kZeros`, `torch::kReflect`, `torch::kReplicate` or `torch::kCircular`. Default: `torch::kZeros`
  236. TORCH_ARG(padding_mode_t, padding_mode) = torch::kZeros;
  237. };
  238. /// `ConvTransposeOptions` specialized for the `ConvTranspose1d` module.
  239. ///
  240. /// Example:
  241. /// ```
  242. /// ConvTranspose1d model(ConvTranspose1dOptions(3, 2, 3).stride(1).bias(false));
  243. /// ```
  244. using ConvTranspose1dOptions = ConvTransposeOptions<1>;
  245. /// `ConvTransposeOptions` specialized for the `ConvTranspose2d` module.
  246. ///
  247. /// Example:
  248. /// ```
  249. /// ConvTranspose2d model(ConvTranspose2dOptions(3, 2, 3).stride(1).bias(false));
  250. /// ```
  251. using ConvTranspose2dOptions = ConvTransposeOptions<2>;
  252. /// `ConvTransposeOptions` specialized for the `ConvTranspose3d` module.
  253. ///
  254. /// Example:
  255. /// ```
  256. /// ConvTranspose3d model(ConvTranspose3dOptions(2, 2, 2).stride(1).bias(false));
  257. /// ```
  258. using ConvTranspose3dOptions = ConvTransposeOptions<3>;
  259. // ============================================================================
  260. namespace functional {
  261. /// Options for a `D`-dimensional convolution functional.
  262. template <size_t D>
  263. struct ConvTransposeFuncOptions {
  264. /// optional bias of shape `(out_channels)`. Default: ``None``
  265. TORCH_ARG(torch::Tensor, bias) = Tensor();
  266. /// The stride of the convolving kernel.
  267. /// For a `D`-dim convolution, must be a single number or a list of `D`
  268. /// numbers.
  269. TORCH_ARG(ExpandingArray<D>, stride) = 1;
  270. /// Implicit paddings on both sides of the input.
  271. /// For a `D`-dim convolution, must be a single number or a list of `D`
  272. /// numbers.
  273. TORCH_ARG(ExpandingArray<D>, padding) = 0;
  274. /// Additional size added to one side of each dimension in the output shape. Default: 0
  275. TORCH_ARG(ExpandingArray<D>, output_padding) = 0;
  276. /// Split input into groups, `in_channels` should be divisible by
  277. /// the number of groups.
  278. TORCH_ARG(int64_t, groups) = 1;
  279. /// The spacing between kernel elements.
  280. /// For a `D`-dim convolution, must be a single number or a list of `D`
  281. /// numbers.
  282. TORCH_ARG(ExpandingArray<D>, dilation) = 1;
  283. };
  284. /// `ConvTransposeFuncOptions` specialized for `torch::nn::functional::conv_transpose1d`.
  285. ///
  286. /// Example:
  287. /// ```
  288. /// namespace F = torch::nn::functional;
  289. /// F::conv_transpose1d(x, weight, F::ConvTranspose1dFuncOptions().stride(1));
  290. /// ```
  291. using ConvTranspose1dFuncOptions = ConvTransposeFuncOptions<1>;
  292. /// `ConvTransposeFuncOptions` specialized for `torch::nn::functional::conv_transpose2d`.
  293. ///
  294. /// Example:
  295. /// ```
  296. /// namespace F = torch::nn::functional;
  297. /// F::conv_transpose2d(x, weight, F::ConvTranspose2dFuncOptions().stride(1));
  298. /// ```
  299. using ConvTranspose2dFuncOptions = ConvTransposeFuncOptions<2>;
  300. /// `ConvTransposeFuncOptions` specialized for `torch::nn::functional::conv_transpose3d`.
  301. ///
  302. /// Example:
  303. /// ```
  304. /// namespace F = torch::nn::functional;
  305. /// F::conv_transpose3d(x, weight, F::ConvTranspose3dFuncOptions().stride(1));
  306. /// ```
  307. using ConvTranspose3dFuncOptions = ConvTransposeFuncOptions<3>;
  308. } // namespace functional
  309. } // namespace nn
  310. } // namespace torch