nmn 0.1.7__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nmn/torch/conv.py ADDED
@@ -0,0 +1,2105 @@
1
+ # mypy: allow-untyped-defs
2
+ import math
3
+ from typing import Optional, Union
4
+ from typing_extensions import deprecated
5
+
6
+ import torch
7
+ from torch import Tensor
8
+ from torch._torch_docs import reproducibility_notes
9
+ from torch.nn import functional as F, init
10
+ from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t
11
+ from torch.nn.parameter import Parameter, UninitializedParameter
12
+
13
+ from .lazy import LazyModuleMixin
14
+ from .module import Module
15
+ from .utils import _pair, _reverse_repeat_tuple, _single, _triple
16
+
17
+
18
+ __all__ = [
19
+ "Conv1d",
20
+ "Conv2d",
21
+ "Conv3d",
22
+ "ConvTranspose1d",
23
+ "ConvTranspose2d",
24
+ "ConvTranspose3d",
25
+ "LazyConv1d",
26
+ "LazyConv2d",
27
+ "LazyConv3d",
28
+ "LazyConvTranspose1d",
29
+ "LazyConvTranspose2d",
30
+ "LazyConvTranspose3d",
31
+ "YatConv1d",
32
+ "YatConv2d",
33
+ "YatConv3d",
34
+ ]
35
+
36
# Reusable reST fragments for the convolution class docstrings below; each
# class splices them in with ``str.format(..., **convolution_notes)``.
convolution_notes = dict(
    groups_note=r"""* :attr:`groups` controls the connections between inputs and outputs.
      :attr:`in_channels` and :attr:`out_channels` must both be divisible by
      :attr:`groups`. For example,

        * At groups=1, all inputs are convolved to all outputs.
        * At groups=2, the operation becomes equivalent to having two conv
          layers side by side, each seeing half the input channels
          and producing half the output channels, and both subsequently
          concatenated.
        * At groups= :attr:`in_channels`, each input channel is convolved with
          its own set of filters (of size
          :math:`\frac{\text{out\_channels}}{\text{in\_channels}}`).""",
    depthwise_separable_note=r"""When `groups == in_channels` and `out_channels == K * in_channels`,
        where `K` is a positive integer, this operation is also known as a "depthwise convolution".

        In other words, for an input of size :math:`(N, C_{in}, L_{in})`,
        a depthwise convolution with a depthwise multiplier `K` can be performed with the arguments
        :math:`(C_\text{in}=C_\text{in}, C_\text{out}=C_\text{in} \times \text{K}, ..., \text{groups}=C_\text{in})`.""",
)  # noqa: B950
56
+
57
+
58
class _ConvNd(Module):
    """Common base for the convolution modules in this file.

    The constructor validates ``groups``, the padding string and
    ``padding_mode``, and stores the hyper-parameters as plain attributes. It
    precomputes the ``F.pad`` argument used for explicit padding, allocates
    ``weight`` (and optionally ``bias``), and initializes them via
    :meth:`reset_parameters`. Subclasses implement ``_conv_forward``.
    """

    __constants__ = [
        "stride",
        "padding",
        "dilation",
        "groups",
        "padding_mode",
        "output_padding",
        "in_channels",
        "out_channels",
        "kernel_size",
    ]
    __annotations__ = {"bias": Optional[torch.Tensor]}

    def _conv_forward(  # type: ignore[empty-body]
        self, input: Tensor, weight: Tensor, bias: Optional[Tensor]
    ) -> Tensor: ...

    in_channels: int
    _reversed_padding_repeated_twice: list[int]
    out_channels: int
    kernel_size: tuple[int, ...]
    stride: tuple[int, ...]
    padding: Union[str, tuple[int, ...]]
    dilation: tuple[int, ...]
    transposed: bool
    output_padding: tuple[int, ...]
    groups: int
    padding_mode: str
    weight: Tensor
    bias: Optional[Tensor]

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: tuple[int, ...],
        stride: tuple[int, ...],
        padding: Union[str, tuple[int, ...]],
        dilation: tuple[int, ...],
        transposed: bool,
        output_padding: tuple[int, ...],
        groups: int,
        bias: bool,
        padding_mode: str,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        # ---- argument validation ------------------------------------------
        if groups <= 0:
            raise ValueError("groups must be a positive integer")
        for label, channel_count in (
            ("in_channels", in_channels),
            ("out_channels", out_channels),
        ):
            if channel_count % groups != 0:
                raise ValueError(f"{label} must be divisible by groups")

        valid_padding_strings = {"same", "valid"}
        if isinstance(padding, str):
            if padding not in valid_padding_strings:
                raise ValueError(
                    f"Invalid padding string {padding!r}, should be one of {valid_padding_strings}"
                )
            if padding == "same" and any(step != 1 for step in stride):
                raise ValueError(
                    "padding='same' is not supported for strided convolutions"
                )

        valid_padding_modes = {"zeros", "reflect", "replicate", "circular"}
        if padding_mode not in valid_padding_modes:
            raise ValueError(
                f"padding_mode must be one of {valid_padding_modes}, but got padding_mode='{padding_mode}'"
            )

        # ---- hyper-parameters ---------------------------------------------
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups
        self.padding_mode = padding_mode

        # ---- explicit-padding argument for F.pad ---------------------------
        # F.pad lists (left, right) pairs starting from the LAST spatial dim,
        # so spatial axis `axis` lands in slots 2*(rank-1-axis) and +1.
        if isinstance(self.padding, str):
            rank = len(kernel_size)
            pads = [0, 0] * rank
            if padding == "same":
                for axis, (dil, ksize) in enumerate(zip(dilation, kernel_size)):
                    total = dil * (ksize - 1)
                    left = total // 2
                    slot = 2 * (rank - 1 - axis)
                    # Any odd leftover goes to the right-hand side.
                    pads[slot], pads[slot + 1] = left, total - left
            # "valid" keeps the all-zero list.
            self._reversed_padding_repeated_twice = pads
        else:
            self._reversed_padding_repeated_twice = _reverse_repeat_tuple(
                self.padding, 2
            )

        # ---- parameters -----------------------------------------------------
        # Transposed convolutions store weights as (in, out/groups, *k);
        # regular ones as (out, in/groups, *k).
        if transposed:
            leading_dims = (in_channels, out_channels // groups)
        else:
            leading_dims = (out_channels, in_channels // groups)
        self.weight = Parameter(
            torch.empty((*leading_dims, *kernel_size), **factory_kwargs)
        )

        if not bias:
            self.register_parameter("bias", None)
        else:
            self.bias = Parameter(torch.empty(out_channels, **factory_kwargs))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialize ``weight`` and, if present, ``bias`` in place.

        The weight uses ``kaiming_uniform_`` with ``a=sqrt(5)``. The bias is
        drawn from U(-1/sqrt(fan_in), 1/sqrt(fan_in)) and is left untouched
        when ``fan_in`` is 0.
        """
        # a=sqrt(5) makes kaiming_uniform equivalent to
        # uniform(-1/sqrt(k), 1/sqrt(k)) with k = weight.size(1) * prod(kernel_size).
        # See https://github.com/pytorch/pytorch/issues/15314#issuecomment-477448573
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is None:
            return
        fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
        if fan_in == 0:
            return
        bound = 1 / math.sqrt(fan_in)
        init.uniform_(self.bias, -bound, bound)

    def extra_repr(self):
        """Summarize the hyper-parameters, listing only non-default optional ones."""
        optional_fields = (
            (self.padding != (0,) * len(self.padding), ", padding={padding}"),
            (self.dilation != (1,) * len(self.dilation), ", dilation={dilation}"),
            (
                self.output_padding != (0,) * len(self.output_padding),
                ", output_padding={output_padding}",
            ),
            (self.groups != 1, ", groups={groups}"),
            (self.bias is None, ", bias=False"),
            (self.padding_mode != "zeros", ", padding_mode={padding_mode}"),
        )
        template = (
            "{in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}"
            + "".join(fragment for enabled, fragment in optional_fields if enabled)
        )
        return template.format(**self.__dict__)

    def __setstate__(self, state):
        """Restore pickled state, defaulting ``padding_mode`` for old pickles."""
        super().__setstate__(state)
        # Objects pickled before padding_mode existed lack the attribute.
        if not hasattr(self, "padding_mode"):
            self.padding_mode = "zeros"
213
+
214
+
215
class Conv1d(_ConvNd):
    __doc__ = (
        r"""Applies a 1D convolution over an input signal composed of several input
    planes.

    In the simplest case, the output value of the layer with input size
    :math:`(N, C_{\text{in}}, L)` and output :math:`(N, C_{\text{out}}, L_{\text{out}})` can be
    precisely described as:

    .. math::
        \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
        \sum_{k = 0}^{C_{in} - 1} \text{weight}(C_{\text{out}_j}, k)
        \star \text{input}(N_i, k)

    where :math:`\star` is the valid `cross-correlation`_ operator,
    :math:`N` is a batch size, :math:`C` denotes a number of channels,
    :math:`L` is a length of signal sequence.
    """
        + r"""

    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.

    On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.

    * :attr:`stride` controls the stride for the cross-correlation, a single
      number or a one-element tuple.

    * :attr:`padding` controls the amount of padding applied to the input. It
      can be either a string {{'valid', 'same'}} or a tuple of ints giving the
      amount of implicit padding applied on both sides.
    """
        """
    * :attr:`dilation` controls the spacing between the kernel points; also
      known as the \u00e0 trous algorithm. It is harder to describe, but this `link`_
      has a nice visualization of what :attr:`dilation` does.
    """
        r"""
    {groups_note}

    Note:
        {depthwise_separable_note}

    Note:
        {cudnn_reproducibility_note}

    Note:
        ``padding='valid'`` is the same as no padding. ``padding='same'`` pads
        the input so the output has the same shape as the input. However, this mode
        doesn't support any stride values other than 1.

    Note:
        This module supports complex data types i.e. ``complex32, complex64, complex128``.

    Args:
        in_channels (int): Number of channels in the input signal
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int, tuple or str, optional): Padding added to both sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel
            elements. Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``
        padding_mode (str, optional): ``'zeros'``, ``'reflect'``,
            ``'replicate'`` or ``'circular'``. Default: ``'zeros'``

    """.format(**reproducibility_notes, **convolution_notes)
        + r"""

    Shape:
        - Input: :math:`(N, C_{in}, L_{in})` or :math:`(C_{in}, L_{in})`
        - Output: :math:`(N, C_{out}, L_{out})` or :math:`(C_{out}, L_{out})`, where

          .. math::
              L_{out} = \left\lfloor\frac{L_{in} + 2 \times \text{padding} - \text{dilation}
                        \times (\text{kernel\_size} - 1) - 1}{\text{stride}} + 1\right\rfloor

    Attributes:
        weight (Tensor): the learnable weights of the module of shape
            :math:`(\text{out\_channels},
            \frac{\text{in\_channels}}{\text{groups}}, \text{kernel\_size})`.
            The values of these weights are sampled from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{groups}{C_\text{in} * \text{kernel\_size}}`
        bias (Tensor): the learnable bias of the module of shape
            (out_channels). If :attr:`bias` is ``True``, then the values of the bias are
            sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{groups}{C_\text{in} * \text{kernel\_size}}`

    Examples::

        >>> m = nn.Conv1d(16, 33, 3, stride=2)
        >>> input = torch.randn(20, 16, 50)
        >>> output = m(input)

    .. _cross-correlation:
        https://en.wikipedia.org/wiki/Cross-correlation

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """
    )

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: Union[str, _size_1_t] = 0,
        dilation: _size_1_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",  # TODO: refine this type
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        # we create new variables below to make mypy happy since kernel_size has
        # type Union[int, Tuple[int]] and kernel_size_ has type Tuple[int]
        kernel_size_ = _single(kernel_size)
        stride_ = _single(stride)
        # String paddings ('same' / 'valid') are passed through unchanged.
        padding_ = padding if isinstance(padding, str) else _single(padding)
        dilation_ = _single(dilation)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size_,
            stride_,
            padding_,
            dilation_,
            False,  # transposed
            _single(0),  # output_padding (unused by non-transposed convs)
            groups,
            bias,
            padding_mode,
            **factory_kwargs,
        )

    def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
        """Run ``F.conv1d`` on ``input`` with the given ``weight`` and ``bias``.

        When ``padding_mode`` is not ``"zeros"``, the input is padded
        explicitly with ``F.pad`` using ``_reversed_padding_repeated_twice``.
        The convolution itself then runs with zero padding.
        """
        if self.padding_mode != "zeros":
            return F.conv1d(
                F.pad(
                    input, self._reversed_padding_repeated_twice, mode=self.padding_mode
                ),
                weight,
                bias,
                self.stride,
                _single(0),
                self.dilation,
                self.groups,
            )
        return F.conv1d(
            input, weight, bias, self.stride, self.padding, self.dilation, self.groups
        )

    def forward(self, input: Tensor) -> Tensor:
        """Apply the convolution to ``input`` with this module's own parameters."""
        return self._conv_forward(input, self.weight, self.bias)
375
+
376
+
377
class Conv2d(_ConvNd):
    __doc__ = (
        r"""Applies a 2D convolution over an input signal composed of several input
    planes.

    In the simplest case, the output value of the layer with input size
    :math:`(N, C_{\text{in}}, H, W)` and output :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})`
    can be precisely described as:

    .. math::
        \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
        \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)


    where :math:`\star` is the valid 2D `cross-correlation`_ operator,
    :math:`N` is a batch size, :math:`C` denotes a number of channels,
    :math:`H` is a height of input planes in pixels, and :math:`W` is
    width in pixels.
    """
        + r"""

    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.

    On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.

    * :attr:`stride` controls the stride for the cross-correlation, a single
      number or a tuple.

    * :attr:`padding` controls the amount of padding applied to the input. It
      can be either a string {{'valid', 'same'}} or an int / a tuple of ints giving the
      amount of implicit padding applied on both sides.
    """
        """
    * :attr:`dilation` controls the spacing between the kernel points; also
      known as the \u00e0 trous algorithm. It is harder to describe, but this `link`_
      has a nice visualization of what :attr:`dilation` does.
    """
        r"""

    {groups_note}

    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:

        - a single ``int`` -- in which case the same value is used for the height and width dimension
        - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
          and the second `int` for the width dimension

    Note:
        {depthwise_separable_note}

    Note:
        {cudnn_reproducibility_note}

    Note:
        ``padding='valid'`` is the same as no padding. ``padding='same'`` pads
        the input so the output has the same shape as the input. However, this mode
        doesn't support any stride values other than 1.

    Note:
        This module supports complex data types i.e. ``complex32, complex64, complex128``.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int, tuple or str, optional): Padding added to all four sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``
        padding_mode (str, optional): ``'zeros'``, ``'reflect'``,
            ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
    """.format(**reproducibility_notes, **convolution_notes)
        + r"""

    Shape:
        - Input: :math:`(N, C_{in}, H_{in}, W_{in})` or :math:`(C_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, H_{out}, W_{out})` or :math:`(C_{out}, H_{out}, W_{out})`, where

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - \text{dilation}[0]
                        \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] - \text{dilation}[1]
                        \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor

    Attributes:
        weight (Tensor): the learnable weights of the module of shape
            :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},`
            :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]})`.
            The values of these weights are sampled from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`
        bias (Tensor): the learnable bias of the module of shape
            (out_channels). If :attr:`bias` is ``True``,
            then the values of the bias are
            sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`

    Examples:

        >>> # With square kernels and equal stride
        >>> m = nn.Conv2d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> # non-square kernels and unequal stride and with padding and dilation
        >>> m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
        >>> input = torch.randn(20, 16, 50, 100)
        >>> output = m(input)

    .. _cross-correlation:
        https://en.wikipedia.org/wiki/Cross-correlation

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """
    )

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",  # TODO: refine this type
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        # Normalize int-or-pair arguments to 2-tuples; string paddings pass through.
        kernel_size_ = _pair(kernel_size)
        stride_ = _pair(stride)
        padding_ = padding if isinstance(padding, str) else _pair(padding)
        dilation_ = _pair(dilation)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size_,
            stride_,
            padding_,
            dilation_,
            False,  # transposed
            _pair(0),  # output_padding (unused by non-transposed convs)
            groups,
            bias,
            padding_mode,
            **factory_kwargs,
        )

    def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
        """Run ``F.conv2d`` on ``input`` with the given ``weight`` and ``bias``.

        When ``padding_mode`` is not ``"zeros"``, the input is padded
        explicitly with ``F.pad`` using ``_reversed_padding_repeated_twice``.
        The convolution itself then runs with zero padding.
        """
        if self.padding_mode != "zeros":
            return F.conv2d(
                F.pad(
                    input, self._reversed_padding_repeated_twice, mode=self.padding_mode
                ),
                weight,
                bias,
                self.stride,
                _pair(0),
                self.dilation,
                self.groups,
            )
        return F.conv2d(
            input, weight, bias, self.stride, self.padding, self.dilation, self.groups
        )

    def forward(self, input: Tensor) -> Tensor:
        """Apply the convolution to ``input`` with this module's own parameters."""
        return self._conv_forward(input, self.weight, self.bias)
552
+
553
+
554
class Conv3d(_ConvNd):
    __doc__ = (
        r"""Applies a 3D convolution over an input signal composed of several input
    planes.

    In the simplest case, the output value of the layer with input size :math:`(N, C_{in}, D, H, W)`
    and output :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` can be precisely described as:

    .. math::
        out(N_i, C_{out_j}) = bias(C_{out_j}) +
                                \sum_{k = 0}^{C_{in} - 1} weight(C_{out_j}, k) \star input(N_i, k)

    where :math:`\star` is the valid 3D `cross-correlation`_ operator.
    """
        + r"""

    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.

    On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.

    * :attr:`stride` controls the stride for the cross-correlation.

    * :attr:`padding` controls the amount of padding applied to the input. It
      can be either a string {{'valid', 'same'}} or a tuple of ints giving the
      amount of implicit padding applied on both sides.
    """
        """
    * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm.
      It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
    """
        r"""

    {groups_note}

    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:

        - a single ``int`` -- in which case the same value is used for the depth, height and width dimension
        - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
          the second `int` for the height dimension and the third `int` for the width dimension

    Note:
        {depthwise_separable_note}

    Note:
        {cudnn_reproducibility_note}

    Note:
        ``padding='valid'`` is the same as no padding. ``padding='same'`` pads
        the input so the output has the same shape as the input. However, this mode
        doesn't support any stride values other than 1.

    Note:
        This module supports complex data types i.e. ``complex32, complex64, complex128``.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int, tuple or str, optional): Padding added to all six sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
        padding_mode (str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
    """.format(**reproducibility_notes, **convolution_notes)
        + r"""

    Shape:
        - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` or :math:`(C_{in}, D_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` or :math:`(C_{out}, D_{out}, H_{out}, W_{out})`,
          where

          .. math::
              D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0]
                    \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{dilation}[1]
                    \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{dilation}[2]
                    \times (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor

    Attributes:
        weight (Tensor): the learnable weights of the module of shape
                         :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},`
                         :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]}, \text{kernel\_size[2]})`.
                         The values of these weights are sampled from
                         :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                         :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{2}\text{kernel\_size}[i]}`
        bias (Tensor):   the learnable bias of the module of shape (out_channels). If :attr:`bias` is ``True``,
                         then the values of the bias are
                         sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                         :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{2}\text{kernel\_size}[i]}`

    Examples::

        >>> # With square kernels and equal stride
        >>> m = nn.Conv3d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0))
        >>> input = torch.randn(20, 16, 10, 50, 100)
        >>> output = m(input)

    .. _cross-correlation:
        https://en.wikipedia.org/wiki/Cross-correlation

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """
    )

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_3_t,
        stride: _size_3_t = 1,
        padding: Union[str, _size_3_t] = 0,
        dilation: _size_3_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        # Normalize int-or-triple arguments to 3-tuples; string paddings pass through.
        kernel_size_ = _triple(kernel_size)
        stride_ = _triple(stride)
        padding_ = padding if isinstance(padding, str) else _triple(padding)
        dilation_ = _triple(dilation)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size_,
            stride_,
            padding_,
            dilation_,
            False,  # transposed
            _triple(0),  # output_padding (unused by non-transposed convs)
            groups,
            bias,
            padding_mode,
            **factory_kwargs,
        )

    def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
        """Run ``F.conv3d`` on ``input`` with the given ``weight`` and ``bias``.

        When ``padding_mode`` is not ``"zeros"``, the input is padded
        explicitly with ``F.pad`` using ``_reversed_padding_repeated_twice``.
        The convolution itself then runs with zero padding.
        """
        if self.padding_mode != "zeros":
            return F.conv3d(
                F.pad(
                    input, self._reversed_padding_repeated_twice, mode=self.padding_mode
                ),
                weight,
                bias,
                self.stride,
                _triple(0),
                self.dilation,
                self.groups,
            )
        return F.conv3d(
            input, weight, bias, self.stride, self.padding, self.dilation, self.groups
        )

    def forward(self, input: Tensor) -> Tensor:
        """Apply the convolution to ``input`` with this module's own parameters."""
        return self._conv_forward(input, self.weight, self.bias)
721
+
722
+
723
class _ConvTransposeNd(_ConvNd):
    """Common base for the transposed-convolution modules in this file.

    Only ``padding_mode="zeros"`` is accepted; any other value raises
    ``ValueError`` before the base class allocates parameters.
    :meth:`_output_padding` turns an optional ``output_size`` request into the
    per-dimension output padding to use.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
        bias,
        padding_mode,
        device=None,
        dtype=None,
    ) -> None:
        if padding_mode != "zeros":
            raise ValueError(
                f'Only "zeros" padding mode is supported for {self.__class__.__name__}'
            )

        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
            bias,
            padding_mode,
            **factory_kwargs,
        )

    # dilation being an optional parameter is for backwards
    # compatibility
    def _output_padding(
        self,
        input: Tensor,
        output_size: Optional[list[int]],
        stride: list[int],
        padding: list[int],
        kernel_size: list[int],
        num_spatial_dims: int,
        dilation: Optional[list[int]] = None,
    ) -> list[int]:
        """Return the output padding to use for a transposed convolution call.

        Args:
            input: the input tensor. It is treated as batched when it has
                ``num_spatial_dims + 2`` dimensions, otherwise as unbatched.
            output_size: requested output size, either spatial dims only or the
                full shape including the batch/channel dims. ``None`` means
                "use ``self.output_padding``".
            stride, padding, kernel_size: per-spatial-dim hyper-parameters.
            num_spatial_dims: number of spatial dimensions (1, 2 or 3).
            dilation: per-spatial-dim dilation; ``None`` is treated as 1.

        Returns:
            ``_single(self.output_padding)`` when ``output_size`` is ``None``.
            Otherwise, for each spatial dim, ``output_size[d]`` minus the
            smallest output size reachable for this input.

        Raises:
            ValueError: if ``output_size`` has the wrong number of elements, or
                requests a size outside ``[min_size, min_size + stride - 1]``.
        """
        if output_size is None:
            # NOTE: despite the historical "convert to list" comment, this is
            # whatever `_single` returns (a tuple in torch's implementation).
            ret = _single(self.output_padding)
        else:
            has_batch_dim = input.dim() == num_spatial_dims + 2
            num_non_spatial_dims = 2 if has_batch_dim else 1
            # Accept a full shape and keep only its spatial tail.
            if len(output_size) == num_non_spatial_dims + num_spatial_dims:
                output_size = output_size[num_non_spatial_dims:]
            if len(output_size) != num_spatial_dims:
                raise ValueError(
                    f"ConvTranspose{num_spatial_dims}D: for {input.dim()}D input, output_size must have {num_spatial_dims} "
                    f"or {num_non_spatial_dims + num_spatial_dims} elements (got {len(output_size)})"
                )

            min_sizes = torch.jit.annotate(list[int], [])
            max_sizes = torch.jit.annotate(list[int], [])
            for d in range(num_spatial_dims):
                # Output length along dim d when output_padding is 0.
                dim_size = (
                    (input.size(d + num_non_spatial_dims) - 1) * stride[d]
                    - 2 * padding[d]
                    + (dilation[d] if dilation is not None else 1)
                    * (kernel_size[d] - 1)
                    + 1
                )
                min_sizes.append(dim_size)
                max_sizes.append(min_sizes[d] + stride[d] - 1)

            for i in range(len(output_size)):
                size = output_size[i]
                min_size = min_sizes[i]
                max_size = max_sizes[i]
                if size < min_size or size > max_size:
                    # Report the input's *spatial* size; slicing at a fixed
                    # index 2 would drop a spatial dim for unbatched input.
                    raise ValueError(
                        f"requested an output size of {output_size}, but valid sizes range "
                        f"from {min_sizes} to {max_sizes} (for an input of "
                        f"{input.size()[num_non_spatial_dims:]})"
                    )

            res = torch.jit.annotate(list[int], [])
            for d in range(num_spatial_dims):
                res.append(output_size[d] - min_sizes[d])

            ret = res
        return ret
815
+
816
+
817
class ConvTranspose1d(_ConvTransposeNd):
    __doc__ = (
        r"""Applies a 1D transposed convolution operator over an input image
    composed of several input planes.

    This module can be seen as the gradient of Conv1d with respect to its input.
    It is also known as a fractionally-strided convolution or
    a deconvolution (although it is not an actual deconvolution operation as it does
    not compute a true inverse of convolution). For more information, see the visualizations
    `here`_ and the `Deconvolutional Networks`_ paper.

    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.

    On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.

    * :attr:`stride` controls the stride for the cross-correlation.

    * :attr:`padding` controls the amount of implicit zero padding on both
      sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note
      below for details.

    * :attr:`output_padding` controls the additional size added to one side
      of the output shape. See note below for details.
    """
        """
    * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm.
      It is harder to describe, but the link `here`_ has a nice visualization of what :attr:`dilation` does.
    """
        r"""
    {groups_note}

    Note:
        The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding``
        amount of zero padding to both sides of the input. This is set so that
        when a :class:`~torch.nn.Conv1d` and a :class:`~torch.nn.ConvTranspose1d`
        are initialized with same parameters, they are inverses of each other in
        regard to the input and output shapes. However, when ``stride > 1``,
        :class:`~torch.nn.Conv1d` maps multiple input shapes to the same output
        shape. :attr:`output_padding` is provided to resolve this ambiguity by
        effectively increasing the calculated output shape on one side. Note
        that :attr:`output_padding` is only used to find output shape, but does
        not actually add zero-padding to output.

    Note:
        {cudnn_reproducibility_note}

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
            will be added to both sides of the input. Default: 0
        output_padding (int or tuple, optional): Additional size added to one side
            of the output shape. Default: 0
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
    """.format(**reproducibility_notes, **convolution_notes)
        + r"""

    Shape:
        - Input: :math:`(N, C_{in}, L_{in})` or :math:`(C_{in}, L_{in})`
        - Output: :math:`(N, C_{out}, L_{out})` or :math:`(C_{out}, L_{out})`, where

          .. math::
              L_{out} = (L_{in} - 1) \times \text{stride} - 2 \times \text{padding} + \text{dilation}
                        \times (\text{kernel\_size} - 1) + \text{output\_padding} + 1

    Attributes:
        weight (Tensor): the learnable weights of the module of shape
                         :math:`(\text{in\_channels}, \frac{\text{out\_channels}}{\text{groups}},`
                         :math:`\text{kernel\_size})`.
                         The values of these weights are sampled from
                         :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                         :math:`k = \frac{groups}{C_\text{out} * \text{kernel\_size}}`
        bias (Tensor):   the learnable bias of the module of shape (out_channels).
                         If :attr:`bias` is ``True``, then the values of the bias are
                         sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                         :math:`k = \frac{groups}{C_\text{out} * \text{kernel\_size}}`

    Examples::

        >>> m = nn.ConvTranspose1d(16, 33, 3, stride=2)
        >>> input = torch.randn(20, 16, 50)
        >>> output = m(input)

    .. _`here`:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

    .. _`Deconvolutional Networks`:
        https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf
    """
    )

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: _size_1_t = 0,
        output_padding: _size_1_t = 0,
        groups: int = 1,
        bias: bool = True,
        dilation: _size_1_t = 1,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        # Normalize scalar hyper-parameters to 1-tuples.
        kernel_size = _single(kernel_size)
        stride = _single(stride)
        padding = _single(padding)
        dilation = _single(dilation)
        output_padding = _single(output_padding)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            True,  # transposed
            output_padding,
            groups,
            bias,
            padding_mode,
            **factory_kwargs,
        )

    def forward(self, input: Tensor, output_size: Optional[list[int]] = None) -> Tensor:
        """Apply the 1D transposed convolution to ``input``.

        Args:
            input (Tensor): The input tensor.
            output_size (list[int], optional): Requested output size; it is
                passed to ``_output_padding`` to derive the output padding used
                for this call. Default: ``None``.

        Raises:
            ValueError: If ``padding_mode`` is not ``"zeros"``.
        """
        if self.padding_mode != "zeros":
            raise ValueError(
                "Only `zeros` padding mode is supported for ConvTranspose1d"
            )

        # Type refinement for TorchScript/mypy.
        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        num_spatial_dims = 1
        output_padding = self._output_padding(
            input,
            output_size,
            self.stride,  # type: ignore[arg-type]
            self.padding,  # type: ignore[arg-type]
            self.kernel_size,  # type: ignore[arg-type]
            num_spatial_dims,
            self.dilation,  # type: ignore[arg-type]
        )
        return F.conv_transpose1d(
            input,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            output_padding,
            self.groups,
            self.dilation,
        )
977
+
978
+
979
class ConvTranspose2d(_ConvTransposeNd):
    __doc__ = (
        r"""Applies a 2D transposed convolution operator over an input image
    composed of several input planes.

    This module can be seen as the gradient of Conv2d with respect to its input.
    It is also known as a fractionally-strided convolution or
    a deconvolution (although it is not an actual deconvolution operation as it does
    not compute a true inverse of convolution). For more information, see the visualizations
    `here`_ and the `Deconvolutional Networks`_ paper.

    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.

    On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.

    * :attr:`stride` controls the stride for the cross-correlation. When stride > 1, ConvTranspose2d inserts zeros between input
      elements along the spatial dimensions before applying the convolution kernel. This zero-insertion operation is the standard
      behavior of transposed convolutions, which can increase the spatial resolution and is equivalent to a learnable
      upsampling operation.

    * :attr:`padding` controls the amount of implicit zero padding on both
      sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note
      below for details.

    * :attr:`output_padding` controls the additional size added to one side
      of the output shape. See note below for details.
    """
        """
    * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm.
      It is harder to describe, but the link `here`_ has a nice visualization of what :attr:`dilation` does.
    """
        r"""
    {groups_note}

    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding`
    can either be:

        - a single ``int`` -- in which case the same value is used for the height and width dimensions
        - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
          and the second `int` for the width dimension

    Note:
        The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding``
        amount of zero padding to both sides of the input. This is set so that
        when a :class:`~torch.nn.Conv2d` and a :class:`~torch.nn.ConvTranspose2d`
        are initialized with same parameters, they are inverses of each other in
        regard to the input and output shapes. However, when ``stride > 1``,
        :class:`~torch.nn.Conv2d` maps multiple input shapes to the same output
        shape. :attr:`output_padding` is provided to resolve this ambiguity by
        effectively increasing the calculated output shape on one side. Note
        that :attr:`output_padding` is only used to find output shape, but does
        not actually add zero-padding to output.

    Note:
        {cudnn_reproducibility_note}

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
            will be added to both sides of each dimension in the input. Default: 0
        output_padding (int or tuple, optional): Additional size added to one side
            of each dimension in the output shape. Default: 0
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
    """.format(**reproducibility_notes, **convolution_notes)
        + r"""

    Shape:
        - Input: :math:`(N, C_{in}, H_{in}, W_{in})` or :math:`(C_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, H_{out}, W_{out})` or :math:`(C_{out}, H_{out}, W_{out})`, where

        .. math::
              H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0]
                        \times (\text{kernel\_size}[0] - 1) + \text{output\_padding}[0] + 1
        .. math::
              W_{out} = (W_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{dilation}[1]
                        \times (\text{kernel\_size}[1] - 1) + \text{output\_padding}[1] + 1

    Attributes:
        weight (Tensor): the learnable weights of the module of shape
                         :math:`(\text{in\_channels}, \frac{\text{out\_channels}}{\text{groups}},`
                         :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]})`.
                         The values of these weights are sampled from
                         :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                         :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`
        bias (Tensor):   the learnable bias of the module of shape (out_channels)
                         If :attr:`bias` is ``True``, then the values of the bias are
                         sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                         :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`

    Examples::

        >>> # With square kernels and equal stride
        >>> m = nn.ConvTranspose2d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> input = torch.randn(20, 16, 50, 100)
        >>> output = m(input)
        >>> # exact output size can be also specified as an argument
        >>> input = torch.randn(1, 16, 12, 12)
        >>> downsample = nn.Conv2d(16, 16, 3, stride=2, padding=1)
        >>> upsample = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
        >>> h = downsample(input)
        >>> h.size()
        torch.Size([1, 16, 6, 6])
        >>> output = upsample(h, output_size=input.size())
        >>> output.size()
        torch.Size([1, 16, 12, 12])

    .. _`here`:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

    .. _`Deconvolutional Networks`:
        https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf
    """
    )

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: _size_2_t = 0,
        output_padding: _size_2_t = 0,
        groups: int = 1,
        bias: bool = True,
        dilation: _size_2_t = 1,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        # Normalize scalar hyper-parameters to (height, width) pairs.
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        output_padding = _pair(output_padding)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            True,  # transposed
            output_padding,
            groups,
            bias,
            padding_mode,
            **factory_kwargs,
        )

    def forward(self, input: Tensor, output_size: Optional[list[int]] = None) -> Tensor:
        """Apply the 2D transposed convolution to ``input``.

        Args:
            input (Tensor): The input tensor.
            output_size (list[int], optional): Requested output size; it is
                passed to ``_output_padding`` to derive the output padding used
                for this call. Default: ``None``.

        Raises:
            ValueError: If ``padding_mode`` is not ``"zeros"``.
        """
        if self.padding_mode != "zeros":
            raise ValueError(
                "Only `zeros` padding mode is supported for ConvTranspose2d"
            )

        # Type refinement for TorchScript/mypy.
        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        num_spatial_dims = 2
        output_padding = self._output_padding(
            input,
            output_size,
            self.stride,  # type: ignore[arg-type]
            self.padding,  # type: ignore[arg-type]
            self.kernel_size,  # type: ignore[arg-type]
            num_spatial_dims,
            self.dilation,  # type: ignore[arg-type]
        )

        return F.conv_transpose2d(
            input,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            output_padding,
            self.groups,
            self.dilation,
        )
1174
+
1175
+
1176
class ConvTranspose3d(_ConvTransposeNd):
    __doc__ = (
        r"""Applies a 3D transposed convolution operator over an input image composed of several input
    planes.
    The transposed convolution operator multiplies each input value element-wise by a learnable kernel,
    and sums over the outputs from all input feature planes.

    This module can be seen as the gradient of Conv3d with respect to its input.
    It is also known as a fractionally-strided convolution or
    a deconvolution (although it is not an actual deconvolution operation as it does
    not compute a true inverse of convolution). For more information, see the visualizations
    `here`_ and the `Deconvolutional Networks`_ paper.

    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.

    On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.

    * :attr:`stride` controls the stride for the cross-correlation.

    * :attr:`padding` controls the amount of implicit zero padding on both
      sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note
      below for details.

    * :attr:`output_padding` controls the additional size added to one side
      of the output shape. See note below for details.
    """
        """
    * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm.
      It is harder to describe, but the link `here`_ has a nice visualization of what :attr:`dilation` does.
    """
        r"""
    {groups_note}

    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding`
    can either be:

        - a single ``int`` -- in which case the same value is used for the depth, height and width dimensions
        - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
          the second `int` for the height dimension and the third `int` for the width dimension

    Note:
        The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding``
        amount of zero padding to both sides of the input. This is set so that
        when a :class:`~torch.nn.Conv3d` and a :class:`~torch.nn.ConvTranspose3d`
        are initialized with same parameters, they are inverses of each other in
        regard to the input and output shapes. However, when ``stride > 1``,
        :class:`~torch.nn.Conv3d` maps multiple input shapes to the same output
        shape. :attr:`output_padding` is provided to resolve this ambiguity by
        effectively increasing the calculated output shape on one side. Note
        that :attr:`output_padding` is only used to find output shape, but does
        not actually add zero-padding to output.

    Note:
        {cudnn_reproducibility_note}

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
            will be added to both sides of each dimension in the input. Default: 0
        output_padding (int or tuple, optional): Additional size added to one side
            of each dimension in the output shape. Default: 0
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
    """.format(**reproducibility_notes, **convolution_notes)
        + r"""

    Shape:
        - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` or :math:`(C_{in}, D_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` or
          :math:`(C_{out}, D_{out}, H_{out}, W_{out})`, where

        .. math::
              D_{out} = (D_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0]
                        \times (\text{kernel\_size}[0] - 1) + \text{output\_padding}[0] + 1
        .. math::
              H_{out} = (H_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{dilation}[1]
                        \times (\text{kernel\_size}[1] - 1) + \text{output\_padding}[1] + 1
        .. math::
              W_{out} = (W_{in} - 1) \times \text{stride}[2] - 2 \times \text{padding}[2] + \text{dilation}[2]
                        \times (\text{kernel\_size}[2] - 1) + \text{output\_padding}[2] + 1


    Attributes:
        weight (Tensor): the learnable weights of the module of shape
                         :math:`(\text{in\_channels}, \frac{\text{out\_channels}}{\text{groups}},`
                         :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]}, \text{kernel\_size[2]})`.
                         The values of these weights are sampled from
                         :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                         :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{2}\text{kernel\_size}[i]}`
        bias (Tensor):   the learnable bias of the module of shape (out_channels)
                         If :attr:`bias` is ``True``, then the values of the bias are
                         sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                         :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{2}\text{kernel\_size}[i]}`

    Examples::

        >>> # With square kernels and equal stride
        >>> m = nn.ConvTranspose3d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nn.ConvTranspose3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(0, 4, 2))
        >>> input = torch.randn(20, 16, 10, 50, 100)
        >>> output = m(input)

    .. _`here`:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

    .. _`Deconvolutional Networks`:
        https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf
    """
    )

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_3_t,
        stride: _size_3_t = 1,
        padding: _size_3_t = 0,
        output_padding: _size_3_t = 0,
        groups: int = 1,
        bias: bool = True,
        dilation: _size_3_t = 1,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        # Normalize scalar hyper-parameters to (depth, height, width) triples.
        kernel_size = _triple(kernel_size)
        stride = _triple(stride)
        padding = _triple(padding)
        dilation = _triple(dilation)
        output_padding = _triple(output_padding)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            True,  # transposed
            output_padding,
            groups,
            bias,
            padding_mode,
            **factory_kwargs,
        )

    def forward(self, input: Tensor, output_size: Optional[list[int]] = None) -> Tensor:
        """Apply the 3D transposed convolution to ``input``.

        Args:
            input (Tensor): The input tensor.
            output_size (list[int], optional): Requested output size; it is
                passed to ``_output_padding`` to derive the output padding used
                for this call. Default: ``None``.

        Raises:
            ValueError: If ``padding_mode`` is not ``"zeros"``.
        """
        if self.padding_mode != "zeros":
            raise ValueError(
                "Only `zeros` padding mode is supported for ConvTranspose3d"
            )

        # Type refinement for TorchScript/mypy.
        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        num_spatial_dims = 3
        output_padding = self._output_padding(
            input,
            output_size,
            self.stride,  # type: ignore[arg-type]
            self.padding,  # type: ignore[arg-type]
            self.kernel_size,  # type: ignore[arg-type]
            num_spatial_dims,
            self.dilation,  # type: ignore[arg-type]
        )

        return F.conv_transpose3d(
            input,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            output_padding,
            self.groups,
            self.dilation,
        )
1357
+
1358
+
1359
# TODO: Deprecate and remove the `_ConvTransposeMixin` alias defined below.
#
# Historically, `_ConvTransposeMixin` was combined with `_ConvNd` to assemble
# concrete transposed-convolution modules, for example:
#
#     class MyConvTranspose(_ConvNd, _ConvTransposeMixin):
#         ...
#
# That role now belongs to `_ConvTransposeNd`, which is a genuine subclass of
# `_ConvNd`. Some external code still (incorrectly) imports the old internal
# name, so this inexpensive alias is kept for backward compatibility. Even
# though `_ConvTransposeNd` is no longer a mixin, the multiple-inheritance
# pattern shown above continues to work.
class _ConvTransposeMixin(_ConvTransposeNd):
    """Deprecated alias of ``_ConvTransposeNd``; instantiating it emits a ``FutureWarning``."""

    @deprecated(
        "`_ConvTransposeMixin` is a deprecated internal class. Please consider using public APIs.",
        category=FutureWarning,
    )
    def __init__(self, *args, **kwargs):
        # All construction is delegated unchanged to ``_ConvTransposeNd``.
        super().__init__(*args, **kwargs)
1382
+
1383
+
1384
class YatConvNd(_ConvNd):
    """Convolution base that replaces the plain dot product with a "YAT" response.

    For every output position, with input patch ``x`` and filter ``w``, the
    layer computes::

        y = (x . w) ** 2 / (||x - w|| ** 2 + epsilon)   (+ bias, if enabled)

    When ``use_alpha`` is true, the result is additionally multiplied by
    ``(sqrt(out_channels) / log(1 + out_channels)) ** alpha``, where ``alpha``
    is a learnable scalar parameter initialised to 1.

    Args (beyond those of ``_ConvNd``):
        use_alpha (bool): Whether to create and apply the learnable ``alpha``
            scale. Default: ``True``.
        epsilon (float): Constant added to the squared distance to keep the
            denominator positive. Default: ``1e-5``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: tuple[int, ...],
        stride: tuple[int, ...],
        padding: Union[str, tuple[int, ...]],
        dilation: tuple[int, ...],
        transposed: bool,
        output_padding: tuple[int, ...],
        groups: int,
        bias: bool,
        padding_mode: str,
        use_alpha: bool = True,
        epsilon: float = 1e-5,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
            bias,
            padding_mode,
            device=device,
            dtype=dtype,
        )
        self.use_alpha = use_alpha
        self.epsilon = epsilon
        if self.use_alpha:
            self.alpha = Parameter(torch.ones(1, device=device, dtype=dtype))
        else:
            self.register_parameter("alpha", None)

    def _patch_sq_sum(self, input: Tensor, conv_fn: callable) -> Tensor:
        """Return ``||x||^2`` for every input patch ``x``.

        ``input`` must carry a batch dimension. With ``groups == 1`` the result
        has a single channel that broadcasts against every filter. Otherwise it
        has ``out_channels`` channels: each group's patch norm is repeated for
        every filter in that group.
        """
        input_sq = input * input

        # One all-ones filter per group: output channel g sums the squared
        # values of that group's in_channels // groups channels over the window.
        sum_kernel = torch.ones(
            (self.groups, self.in_channels // self.groups) + tuple(self.kernel_size),
            device=input.device,
            dtype=input.dtype,
        )

        if self.padding_mode != "zeros":
            # Non-zero padding modes are applied explicitly with F.pad; the
            # convolution itself then runs without implicit padding.
            input_sq = F.pad(
                input_sq,
                self._reversed_padding_repeated_twice,
                mode=self.padding_mode,
            )
            conv_padding = [0] * len(self.kernel_size)
        else:
            conv_padding = self.padding

        raw = conv_fn(
            input_sq,
            sum_kernel,
            None,
            self.stride,
            conv_padding,
            self.dilation,
            self.groups,
        )

        if self.groups == 1:
            return raw
        if self.out_channels % self.groups != 0:
            raise ValueError("out_channels must be divisible by groups")
        return raw.repeat_interleave(self.out_channels // self.groups, dim=1)

    def _yat_forward(self, input: Tensor, conv_fn: callable) -> Tensor:
        """Compute the YAT response for ``input``.

        Args:
            input: Batched ``(N, C_in, *spatial)`` or unbatched
                ``(C_in, *spatial)`` input; the output has matching batching.
            conv_fn: Functional convolution matching the layer's
                dimensionality (for example ``F.conv2d``). It is called as
                ``conv_fn(x, weight, bias, stride, padding, dilation, groups)``
                to compute per-patch squared norms.
        """
        # The functional convolutions accept unbatched input, but the channel
        # broadcasting below assumes a leading batch dim, so add one temporarily.
        is_unbatched = input.dim() == len(self.kernel_size) + 1
        if is_unbatched:
            input = input.unsqueeze(0)

        # x . w for every (patch, filter) pair; the bias is added separately below.
        # NOTE(review): ``_conv_forward`` is not defined in this class; it is
        # expected to come from the concrete subclass — confirm.
        dot_prod_map = self._conv_forward(input, self.weight, None)

        patch_sq_sum_map = self._patch_sq_sum(input, conv_fn)

        # ||w||^2 per filter, reshaped to broadcast over (N, C_out, *spatial).
        reduce_dims = tuple(range(1, self.weight.dim()))
        kernel_sq_sum_per_filter = torch.sum(self.weight**2, dim=reduce_dims)
        view_shape = (1, -1) + (1,) * (dot_prod_map.dim() - 2)
        kernel_sq_sum_reshaped = kernel_sq_sum_per_filter.view(*view_shape)

        # ||x - w||^2 = ||x||^2 + ||w||^2 - 2 x.w. The expansion can dip
        # slightly below zero through floating-point cancellation, so clamp it
        # to keep the denominator >= epsilon.
        distance_sq_map = torch.clamp(
            patch_sq_sum_map + kernel_sq_sum_reshaped - 2 * dot_prod_map, min=0.0
        )
        y = dot_prod_map**2 / (distance_sq_map + self.epsilon)

        if self.bias is not None:
            y = y + self.bias.view(*view_shape)

        if self.use_alpha and self.alpha is not None:
            scale = (
                math.sqrt(self.out_channels) / math.log(1.0 + self.out_channels)
            ) ** self.alpha
            y = y * scale

        return y.squeeze(0) if is_unbatched else y
1492
+
1493
+
1494
+ # TODO: Conv2dLocal
1495
+ # TODO: Conv2dMap
1496
+ # TODO: ConvTranspose2dMap
1497
+
1498
+
1499
class _LazyConvXdMixin(LazyModuleMixin):
    """Lazy ``in_channels`` inference for convolution modules (regular or transposed).

    On the first input, ``in_channels`` is read from the input's channel
    dimension. ``weight`` (and ``bias``, if present) are then materialized
    with the matching shape and initialized.
    """

    groups: int
    transposed: bool
    in_channels: int
    out_channels: int
    kernel_size: tuple[int, ...]
    weight: UninitializedParameter
    bias: UninitializedParameter

    def reset_parameters(self) -> None:
        """Initialize parameters once they are materialized and ``in_channels`` is known."""
        # ``has_uninitialized_params`` comes from the parent mixin and relies on a protocol on self.
        if self.has_uninitialized_params() or self.in_channels == 0:  # type: ignore[misc]
            return
        # mypy cannot see ``reset_parameters`` on the supertype; it is supplied by the
        # conv base class that each concrete lazy module also inherits from.
        super().reset_parameters()  # type: ignore[misc]

    # The signature intentionally differs from LazyModuleMixin.initialize_parameters.
    def initialize_parameters(self, input: Tensor, *args, **kwargs) -> None:  # type: ignore[override]
        """Infer ``in_channels`` from ``input`` and materialize ``weight``/``bias``.

        Raises:
            ValueError: If the inferred ``in_channels`` is not divisible by ``groups``.
        """
        # Provided by the parent mixin via a protocol on self.
        if not self.has_uninitialized_params():  # type: ignore[misc]
            return

        self.in_channels = self._get_in_channels(input)
        if self.in_channels % self.groups != 0:
            raise ValueError("in_channels must be divisible by groups")

        assert isinstance(self.weight, UninitializedParameter)
        # Transposed convs store weights as (in, out // groups, *k);
        # regular convs as (out, in // groups, *k).
        if self.transposed:
            leading_dims = (self.in_channels, self.out_channels // self.groups)
        else:
            leading_dims = (self.out_channels, self.in_channels // self.groups)
        self.weight.materialize((*leading_dims, *self.kernel_size))

        if self.bias is not None:
            assert isinstance(self.bias, UninitializedParameter)
            self.bias.materialize((self.out_channels,))

        self.reset_parameters()

    def _get_in_channels(self, input: Tensor) -> int:
        """Return the channel count of ``input``, accepting batched or unbatched layouts.

        Raises:
            RuntimeError: If ``input`` has neither the unbatched nor the batched rank.
        """
        spatial_rank = self._get_num_spatial_dims()
        unbatched_rank = spatial_rank + 1  # channels + spatial dims
        batched_rank = unbatched_rank + 1  # plus a leading batch dim
        rank = input.dim()
        if rank != unbatched_rank and rank != batched_rank:
            raise RuntimeError(
                f"Expected {unbatched_rank}D (unbatched) or {batched_rank}D (batched) input "
                f"to {self.__class__.__name__}, but "
                f"got input of size: {input.shape}"
            )
        channel_axis = 1 if rank == batched_rank else 0
        return input.shape[channel_axis]

    def _get_num_spatial_dims(self) -> int:
        """Number of spatial dims expected in inputs; subclasses must override."""
        raise NotImplementedError
1562
+
1563
+
1564
# `Conv1d` declares `weight` as a Tensor, whereas this lazy subclass replaces it
# with an `UninitializedParameter`; hence the `type: ignore[misc]` below.
class LazyConv1d(_LazyConvXdMixin, Conv1d):  # type: ignore[misc]
    r"""A :class:`torch.nn.Conv1d` module with lazy initialization of the ``in_channels`` argument.

    The ``in_channels`` argument of the :class:`Conv1d` is inferred from the channel
    dimension of the first input: ``input.size(1)`` for batched input and
    ``input.size(0)`` for unbatched input.
    The attributes that will be lazily initialized are `weight` and `bias`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel
            elements. Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``
        padding_mode (str, optional): ``'zeros'``, ``'reflect'``,
            ``'replicate'`` or ``'circular'``. Default: ``'zeros'``

    .. seealso:: :class:`torch.nn.Conv1d` and :class:`torch.nn.modules.lazy.LazyModuleMixin`
    """

    # The superclass defines this variable as None; the "type: ignore[..]" is
    # required because we are redefining it.
    cls_to_become = Conv1d  # type: ignore[assignment]

    def __init__(
        self,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: _size_1_t = 0,
        dilation: _size_1_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            0,  # in_channels: inferred from the first input
            0,  # out_channels: placeholder, restored below
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            # bias is hardcoded to False to avoid creating a tensor
            # that will soon be overwritten.
            False,
            padding_mode,
            **factory_kwargs,
        )
        self.weight = UninitializedParameter(**factory_kwargs)
        self.out_channels = out_channels
        if bias:
            self.bias = UninitializedParameter(**factory_kwargs)

    def _get_num_spatial_dims(self) -> int:
        return 1
1631
+
1632
+
1633
# Conv2d defines ``weight`` as a Tensor, but this derived class redefines it as an
# UninitializedParameter, which is why the class line carries ``type: ignore[misc]``.
class LazyConv2d(_LazyConvXdMixin, Conv2d):  # type: ignore[misc]
    r"""A :class:`torch.nn.Conv2d` module with lazy initialization of the ``in_channels`` argument.

    The ``in_channels`` argument of the :class:`Conv2d` is inferred from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight` and `bias`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel
            elements. Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``
        padding_mode (str, optional): ``'zeros'``, ``'reflect'``,
            ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
        device (optional): Device for the parameters; forwarded, together with
            ``dtype``, to the parent constructor and to the uninitialized
            parameters. Default: ``None``
        dtype (optional): Data type for the parameters, forwarded like
            ``device``. Default: ``None``

    .. seealso:: :class:`torch.nn.Conv2d` and :class:`torch.nn.modules.lazy.LazyModuleMixin`
    """

    # The superclass defines this variable as None. "type: ignore[..]" is required
    # since we are redefining the variable.
    cls_to_become = Conv2d  # type: ignore[assignment]

    def __init__(
        self,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: _size_2_t = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",  # TODO: refine this type
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            0,  # in_channels placeholder (inferred lazily)
            0,  # out_channels placeholder; the real value is assigned below
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            # bias is hardcoded to False to avoid creating tensor
            # that will soon be overwritten.
            False,
            padding_mode,
            **factory_kwargs,
        )
        self.weight = UninitializedParameter(**factory_kwargs)
        self.out_channels = out_channels
        if bias:
            self.bias = UninitializedParameter(**factory_kwargs)

    def _get_num_spatial_dims(self) -> int:
        return 2
1700
+
1701
+
1702
# Conv3d defines ``weight`` as a Tensor, but this derived class redefines it as an
# UninitializedParameter, which is why the class line carries ``type: ignore[misc]``.
class LazyConv3d(_LazyConvXdMixin, Conv3d):  # type: ignore[misc]
    r"""A :class:`torch.nn.Conv3d` module with lazy initialization of the ``in_channels`` argument.

    The ``in_channels`` argument of the :class:`Conv3d` is inferred from
    the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight` and `bias`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel
            elements. Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``
        padding_mode (str, optional): ``'zeros'``, ``'reflect'``,
            ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
        device (optional): Device for the parameters; forwarded, together with
            ``dtype``, to the parent constructor and to the uninitialized
            parameters. Default: ``None``
        dtype (optional): Data type for the parameters, forwarded like
            ``device``. Default: ``None``

    .. seealso:: :class:`torch.nn.Conv3d` and :class:`torch.nn.modules.lazy.LazyModuleMixin`
    """

    # The superclass defines this variable as None. "type: ignore[..]" is required
    # since we are redefining the variable.
    cls_to_become = Conv3d  # type: ignore[assignment]

    def __init__(
        self,
        out_channels: int,
        kernel_size: _size_3_t,
        stride: _size_3_t = 1,
        padding: _size_3_t = 0,
        dilation: _size_3_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            0,  # in_channels placeholder (inferred lazily)
            0,  # out_channels placeholder; the real value is assigned below
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            # bias is hardcoded to False to avoid creating tensor
            # that will soon be overwritten.
            False,
            padding_mode,
            **factory_kwargs,
        )
        self.weight = UninitializedParameter(**factory_kwargs)
        self.out_channels = out_channels
        if bias:
            self.bias = UninitializedParameter(**factory_kwargs)

    def _get_num_spatial_dims(self) -> int:
        return 3
1770
+
1771
+
1772
# ConvTranspose1d defines ``weight`` as a Tensor, but this derived class redefines it
# as an UninitializedParameter, which is why the class line carries ``type: ignore[misc]``.
class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d):  # type: ignore[misc]
    r"""A :class:`torch.nn.ConvTranspose1d` module with lazy initialization of the ``in_channels`` argument.

    The ``in_channels`` argument of the :class:`ConvTranspose1d` is inferred from
    the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight` and `bias`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
            will be added to both sides of the input. Default: 0
        output_padding (int or tuple, optional): Additional size added to one side
            of the output shape. Default: 0
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        padding_mode (str, optional): Padding mode forwarded to
            :class:`ConvTranspose1d`. Default: ``'zeros'``
        device (optional): Device for the parameters; forwarded, together with
            ``dtype``, to the parent constructor and to the uninitialized
            parameters. Default: ``None``
        dtype (optional): Data type for the parameters, forwarded like
            ``device``. Default: ``None``

    .. seealso:: :class:`torch.nn.ConvTranspose1d` and :class:`torch.nn.modules.lazy.LazyModuleMixin`
    """

    # The superclass defines this variable as None. "type: ignore[..]" is required
    # since we are redefining the variable.
    cls_to_become = ConvTranspose1d  # type: ignore[assignment]

    def __init__(
        self,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: _size_1_t = 0,
        output_padding: _size_1_t = 0,
        groups: int = 1,
        bias: bool = True,
        dilation: _size_1_t = 1,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            0,  # in_channels placeholder (inferred lazily)
            0,  # out_channels placeholder; the real value is assigned below
            kernel_size,
            stride,
            padding,
            output_padding,
            groups,
            # bias is hardcoded to False to avoid creating tensor
            # that will soon be overwritten.
            False,
            dilation,
            padding_mode,
            **factory_kwargs,
        )
        self.weight = UninitializedParameter(**factory_kwargs)
        self.out_channels = out_channels
        if bias:
            self.bias = UninitializedParameter(**factory_kwargs)

    def _get_num_spatial_dims(self) -> int:
        return 1
1839
+
1840
+
1841
# ConvTranspose2d defines ``weight`` as a Tensor, but this derived class redefines it
# as an UninitializedParameter, which is why the class line carries ``type: ignore[misc]``.
class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d):  # type: ignore[misc]
    r"""A :class:`torch.nn.ConvTranspose2d` module with lazy initialization of the ``in_channels`` argument.

    The ``in_channels`` argument of the :class:`ConvTranspose2d` is inferred from
    the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight` and `bias`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
            will be added to both sides of each dimension in the input. Default: 0
        output_padding (int or tuple, optional): Additional size added to one side
            of each dimension in the output shape. Default: 0
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        padding_mode (str, optional): Padding mode forwarded to
            :class:`ConvTranspose2d`. Default: ``'zeros'``
        device (optional): Device for the parameters; forwarded, together with
            ``dtype``, to the parent constructor and to the uninitialized
            parameters. Default: ``None``
        dtype (optional): Data type for the parameters, forwarded like
            ``device``. Default: ``None``

    .. seealso:: :class:`torch.nn.ConvTranspose2d` and :class:`torch.nn.modules.lazy.LazyModuleMixin`
    """

    # The superclass defines this variable as None. "type: ignore[..]" is required
    # since we are redefining the variable.
    cls_to_become = ConvTranspose2d  # type: ignore[assignment]

    def __init__(
        self,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: _size_2_t = 0,
        output_padding: _size_2_t = 0,
        groups: int = 1,
        bias: bool = True,
        # Annotated as _size_2_t (int or tuple), matching the docstring and the
        # 1d/3d siblings; the previous ``int`` annotation rejected tuples in type checks.
        dilation: _size_2_t = 1,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            0,  # in_channels placeholder (inferred lazily)
            0,  # out_channels placeholder; the real value is assigned below
            kernel_size,
            stride,
            padding,
            output_padding,
            groups,
            # bias is hardcoded to False to avoid creating tensor
            # that will soon be overwritten.
            False,
            dilation,
            padding_mode,
            **factory_kwargs,
        )
        self.weight = UninitializedParameter(**factory_kwargs)
        self.out_channels = out_channels
        if bias:
            self.bias = UninitializedParameter(**factory_kwargs)

    def _get_num_spatial_dims(self) -> int:
        return 2
1908
+
1909
+
1910
# ConvTranspose3d defines ``weight`` as a Tensor, but this derived class redefines it
# as an UninitializedParameter, which is why the class line carries ``type: ignore[misc]``.
class LazyConvTranspose3d(_LazyConvXdMixin, ConvTranspose3d):  # type: ignore[misc]
    r"""A :class:`torch.nn.ConvTranspose3d` module with lazy initialization of the ``in_channels`` argument.

    The ``in_channels`` argument of the :class:`ConvTranspose3d` is inferred from
    the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight` and `bias`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
            will be added to both sides of each dimension in the input. Default: 0
        output_padding (int or tuple, optional): Additional size added to one side
            of each dimension in the output shape. Default: 0
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        padding_mode (str, optional): Padding mode forwarded to
            :class:`ConvTranspose3d`. Default: ``'zeros'``
        device (optional): Device for the parameters; forwarded, together with
            ``dtype``, to the parent constructor and to the uninitialized
            parameters. Default: ``None``
        dtype (optional): Data type for the parameters, forwarded like
            ``device``. Default: ``None``

    .. seealso:: :class:`torch.nn.ConvTranspose3d` and :class:`torch.nn.modules.lazy.LazyModuleMixin`
    """

    # The superclass defines this variable as None. "type: ignore[..]" is required
    # since we are redefining the variable.
    cls_to_become = ConvTranspose3d  # type: ignore[assignment]

    def __init__(
        self,
        out_channels: int,
        kernel_size: _size_3_t,
        stride: _size_3_t = 1,
        padding: _size_3_t = 0,
        output_padding: _size_3_t = 0,
        groups: int = 1,
        bias: bool = True,
        dilation: _size_3_t = 1,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            0,  # in_channels placeholder (inferred lazily)
            0,  # out_channels placeholder; the real value is assigned below
            kernel_size,
            stride,
            padding,
            output_padding,
            groups,
            # bias is hardcoded to False to avoid creating tensor
            # that will soon be overwritten.
            False,
            dilation,
            padding_mode,
            **factory_kwargs,
        )
        self.weight = UninitializedParameter(**factory_kwargs)
        self.out_channels = out_channels
        if bias:
            self.bias = UninitializedParameter(**factory_kwargs)

    def _get_num_spatial_dims(self) -> int:
        return 3
1977
+
1978
+
1979
class YatConv1d(YatConvNd, Conv1d):
    r"""1D convolution whose forward pass is delegated to ``YatConvNd._yat_forward``.

    Geometry arguments given as ints or tuples are normalized to 1-tuples with
    ``_single`` before being handed to :class:`YatConvNd`. A string ``padding`` is
    forwarded unchanged.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int or tuple): Size of the convolving kernel.
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int, tuple or str, optional): Padding specification; strings
            are passed through as-is. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): Forwarded to :class:`YatConvNd`. Default: ``True``
        padding_mode (str, optional): Forwarded to :class:`YatConvNd`. Default: ``'zeros'``
        use_alpha (bool, optional): Forwarded to :class:`YatConvNd`; see it for
            semantics. Default: ``True``
        epsilon (float, optional): Forwarded to :class:`YatConvNd`; see it for
            semantics. Default: ``1e-5``
        device, dtype (optional): Factory arguments forwarded to :class:`YatConvNd`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: Union[str, _size_1_t] = 0,
        dilation: _size_1_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        use_alpha: bool = True,
        epsilon: float = 1e-5,
        device=None,
        dtype=None,
    ) -> None:
        expand = _single
        super().__init__(
            in_channels,
            out_channels,
            expand(kernel_size),
            expand(stride),
            padding if isinstance(padding, str) else expand(padding),
            expand(dilation),
            # NOTE(review): presumably YatConvNd's ``transposed`` flag and
            # ``output_padding`` slots (positional order) — verify.
            False,
            expand(0),
            groups,
            bias,
            padding_mode,
            use_alpha,
            epsilon,
            device=device,
            dtype=dtype,
        )

    def forward(self, input: Tensor) -> Tensor:
        """Run ``_yat_forward`` on ``input`` with ``F.conv1d`` as the convolution."""
        conv_fn = F.conv1d
        return self._yat_forward(input, conv_fn)
2020
+
2021
+
2022
class YatConv2d(YatConvNd, Conv2d):
    r"""2D convolution whose forward pass is delegated to ``YatConvNd._yat_forward``.

    Geometry arguments given as ints or tuples are normalized to 2-tuples with
    ``_pair`` before being handed to :class:`YatConvNd`. A string ``padding`` is
    forwarded unchanged.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int or tuple): Size of the convolving kernel.
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int, tuple or str, optional): Padding specification; strings
            are passed through as-is. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): Forwarded to :class:`YatConvNd`. Default: ``True``
        padding_mode (str, optional): Forwarded to :class:`YatConvNd`. Default: ``'zeros'``
        use_alpha (bool, optional): Forwarded to :class:`YatConvNd`; see it for
            semantics. Default: ``True``
        epsilon (float, optional): Forwarded to :class:`YatConvNd`; see it for
            semantics. Default: ``1e-5``
        device, dtype (optional): Factory arguments forwarded to :class:`YatConvNd`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        use_alpha: bool = True,
        epsilon: float = 1e-5,
        device=None,
        dtype=None,
    ) -> None:
        expand = _pair
        super().__init__(
            in_channels,
            out_channels,
            expand(kernel_size),
            expand(stride),
            padding if isinstance(padding, str) else expand(padding),
            expand(dilation),
            # NOTE(review): presumably YatConvNd's ``transposed`` flag and
            # ``output_padding`` slots (positional order) — verify.
            False,
            expand(0),
            groups,
            bias,
            padding_mode,
            use_alpha,
            epsilon,
            device=device,
            dtype=dtype,
        )

    def forward(self, input: Tensor) -> Tensor:
        """Run ``_yat_forward`` on ``input`` with ``F.conv2d`` as the convolution."""
        conv_fn = F.conv2d
        return self._yat_forward(input, conv_fn)
2063
+
2064
+
2065
class YatConv3d(YatConvNd, Conv3d):
    r"""3D convolution whose forward pass is delegated to ``YatConvNd._yat_forward``.

    Geometry arguments given as ints or tuples are normalized to 3-tuples with
    ``_triple`` before being handed to :class:`YatConvNd`. A string ``padding`` is
    forwarded unchanged.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int or tuple): Size of the convolving kernel.
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int, tuple or str, optional): Padding specification; strings
            are passed through as-is. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): Forwarded to :class:`YatConvNd`. Default: ``True``
        padding_mode (str, optional): Forwarded to :class:`YatConvNd`. Default: ``'zeros'``
        use_alpha (bool, optional): Forwarded to :class:`YatConvNd`; see it for
            semantics. Default: ``True``
        epsilon (float, optional): Forwarded to :class:`YatConvNd`; see it for
            semantics. Default: ``1e-5``
        device, dtype (optional): Factory arguments forwarded to :class:`YatConvNd`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_3_t,
        stride: _size_3_t = 1,
        padding: Union[str, _size_3_t] = 0,
        dilation: _size_3_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        use_alpha: bool = True,
        epsilon: float = 1e-5,
        device=None,
        dtype=None,
    ) -> None:
        expand = _triple
        super().__init__(
            in_channels,
            out_channels,
            expand(kernel_size),
            expand(stride),
            padding if isinstance(padding, str) else expand(padding),
            expand(dilation),
            # NOTE(review): presumably YatConvNd's ``transposed`` flag and
            # ``output_padding`` slots (positional order) — verify.
            False,
            expand(0),
            groups,
            bias,
            padding_mode,
            use_alpha,
            epsilon,
            device=device,
            dtype=dtype,
        )

    def forward(self, input: Tensor) -> Tensor:
        """Run ``_yat_forward`` on ``input`` with ``F.conv3d`` as the convolution."""
        conv_fn = F.conv3d
        return self._yat_forward(input, conv_fn)