brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/nn/_conv.py CHANGED
@@ -1,4 +1,4 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -18,19 +18,22 @@
  import collections.abc
  from typing import Callable, Tuple, Union, Sequence, Optional, TypeVar

+ import brainunit as u
  import jax
  import jax.numpy as jnp

- from brainstate import init, functional
  from brainstate._state import ParamState
  from brainstate.typing import ArrayLike
+ from . import init as init
  from ._module import Module
+ from ._normalizations import weight_standardization

  T = TypeVar('T')

  __all__ = [
  'Conv1d', 'Conv2d', 'Conv3d',
  'ScaledWSConv1d', 'ScaledWSConv2d', 'ScaledWSConv3d',
+ 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d',
  ]


@@ -39,7 +42,40 @@ def to_dimension_numbers(
  channels_last: bool,
  transpose: bool
  ) -> jax.lax.ConvDimensionNumbers:
- """Create a `lax.ConvDimensionNumbers` for the given inputs."""
+ """
+ Create a `lax.ConvDimensionNumbers` for the given inputs.
+
+ This function generates the dimension specification needed for JAX's convolution
+ operations based on the number of spatial dimensions and data format.
+
+ Parameters
+ ----------
+ num_spatial_dims : int
+ The number of spatial dimensions (e.g., 1 for Conv1d, 2 for Conv2d, 3 for Conv3d).
+ channels_last : bool
+
+ - If True, the input format is channels-last (e.g., [B, H, W, C] for 2D).
+ - If False, the input format is channels-first (e.g., [B, C, H, W] for 2D).
+ transpose : bool
+
+ - If True, creates dimension numbers for transposed convolution.
+ - If False, creates dimension numbers for standard convolution.
+
+ Returns
+ -------
+ jax.lax.ConvDimensionNumbers
+ A named tuple specifying the dimension layout for lhs (input), rhs (kernel),
+ and output of the convolution operation.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ >>> # For 2D convolution with channels-last format
+ >>> dim_nums = to_dimension_numbers(num_spatial_dims=2, channels_last=True, transpose=False)
+ >>> print(dim_nums.lhs_spec) # Input layout: (batch, spatial_1, spatial_2, channel)
+ (0, 3, 1, 2)
+ """
  num_dims = num_spatial_dims + 2
  if channels_last:
  spatial_dims = tuple(range(1, num_dims - 1))
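
For reference, the dimension specification built by `to_dimension_numbers` can be cross-checked against JAX's public helper. A minimal sketch with illustrative shapes (the shapes and channel counts are assumptions, not taken from the diff):

    import jax

    # Channels-last 2D convolution: input [B, H, W, C], kernel [KH, KW, C_in, C_out].
    dn = jax.lax.conv_dimension_numbers(
        (8, 32, 32, 3),            # illustrative input shape
        (3, 3, 3, 16),             # illustrative kernel shape
        ('NHWC', 'HWIO', 'NHWC'),
    )
    print(dn.lhs_spec)             # (0, 3, 1, 2) -- matches the docstring example above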
@@ -61,7 +97,48 @@ def replicate(
  num_replicate: int,
  name: str,
  ) -> Tuple[T, ...]:
- """Replicates entry in `element` `num_replicate` if needed."""
+ """
+ Replicates entry in `element` `num_replicate` times if needed.
+
+ This utility function ensures that parameters like kernel_size, stride, etc.
+ are properly formatted as tuples with the correct length for multi-dimensional
+ convolutions.
+
+ Parameters
+ ----------
+ element : T or Sequence[T]
+ The element to replicate. Can be a scalar, string, or sequence.
+ num_replicate : int
+ The number of times to replicate the element.
+ name : str
+ The name of the parameter (used for error messages).
+
+ Returns
+ -------
+ tuple of T
+ A tuple containing the replicated elements.
+
+ Raises
+ ------
+ TypeError
+ If the element is a sequence with length not equal to 1 or `num_replicate`.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ >>> # Replicate a scalar value
+ >>> replicate(3, 2, 'kernel_size')
+ (3, 3)
+ >>>
+ >>> # Keep a sequence as is if already correct length
+ >>> replicate((3, 5), 2, 'kernel_size')
+ (3, 5)
+ >>>
+ >>> # Replicate a single-element sequence
+ >>> replicate([3], 2, 'kernel_size')
+ (3, 3)
+ """
  if isinstance(element, (str, bytes)) or not isinstance(element, collections.abc.Sequence):
  return (element,) * num_replicate
  elif len(element) == 1:
@@ -91,6 +168,7 @@ class _BaseConv(Module):
  rhs_dilation: Union[int, Tuple[int, ...]] = 1,
  groups: int = 1,
  w_mask: Optional[Union[ArrayLike, Callable]] = None,
+ channel_first: bool = False,
  name: str = None,
  ):
  super().__init__(name=name)
@@ -98,14 +176,26 @@
  # general parameters
  assert self.num_spatial_dims + 1 == len(in_size)
  self.in_size = tuple(in_size)
- self.in_channels = in_size[-1]
+ self.channel_first = channel_first
+ self.channels_last = not channel_first
+
+ # Determine in_channels based on channel_first
+ if self.channel_first:
+ self.in_channels = in_size[0]
+ else:
+ self.in_channels = in_size[-1]
+
  self.out_channels = out_channels
  self.stride = replicate(stride, self.num_spatial_dims, 'stride')
  self.kernel_size = replicate(kernel_size, self.num_spatial_dims, 'kernel_size')
  self.lhs_dilation = replicate(lhs_dilation, self.num_spatial_dims, 'lhs_dilation')
  self.rhs_dilation = replicate(rhs_dilation, self.num_spatial_dims, 'rhs_dilation')
  self.groups = groups
- self.dimension_numbers = to_dimension_numbers(self.num_spatial_dims, channels_last=True, transpose=False)
+ self.dimension_numbers = to_dimension_numbers(
+ self.num_spatial_dims,
+ channels_last=self.channels_last,
+ transpose=False
+ )

  # the padding parameter
  if isinstance(padding, str):
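
The practical effect of the new `channel_first` flag is only in how `in_size` is interpreted and which dimension numbers are built. A hypothetical usage sketch (the argument values are illustrative; the released docstrings do not demonstrate `channel_first`):

    import jax.numpy as jnp
    import brainstate

    # The same 32x32 RGB image described in both layouts.
    conv_nhwc = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=8, kernel_size=3)
    conv_nchw = brainstate.nn.Conv2d(in_size=(3, 32, 32), out_channels=8, kernel_size=3,
                                     channel_first=True)
    y1 = conv_nhwc(jnp.ones((4, 32, 32, 3)))   # input [B, H, W, C]
    y2 = conv_nchw(jnp.ones((4, 3, 32, 32)))   # input [B, C, H, W]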
@@ -147,30 +237,30 @@
  else:
  raise ValueError(f"expected {self.num_spatial_dims + 2}D (with batch) or "
  f"{self.num_spatial_dims + 1}D (without batch) input (got {x.ndim}D input, {x.shape})")
- if self.in_size != x_shape:
- raise ValueError(f"The expected input shape is {self.in_size}, while we got {x_shape}.")
+
+ # Check shape matches expected in_size
+ if self.channel_first:
+ # For channels-first, expected shape is already (C, spatial...)
+ expected_shape = self.in_size
+ else:
+ # For channels-last, expected shape is (spatial..., C)
+ expected_shape = self.in_size
+
+ if expected_shape != x_shape:
+ raise ValueError(f"The expected input shape is {expected_shape}, while we got {x_shape}.")

  def update(self, x):
  self._check_input_dim(x)
  non_batching = False
  if x.ndim == self.num_spatial_dims + 1:
- x = jnp.expand_dims(x, 0)
+ x = u.math.expand_dims(x, 0)
  non_batching = True
  y = self._conv_op(x, self.weight.value)
- return y[0] if non_batching else y
+ return u.math.squeeze(y, axis=0) if non_batching else y

  def _conv_op(self, x, params):
  raise NotImplementedError

- def __repr__(self):
- return (f'{self.__class__.__name__}('
- f'in_channels={self.in_channels}, '
- f'out_channels={self.out_channels}, '
- f'kernel_size={self.kernel_size}, '
- f'stride={self.stride}, '
- f'padding={self.padding}, '
- f'groups={self.groups})')
-

  class _Conv(_BaseConv):
  num_spatial_dims: int = None
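
The switch from `jnp.expand_dims`/`y[0]` to `u.math.expand_dims`/`u.math.squeeze` makes the batch round-trip unit-aware: `u.math` mirrors the `jax.numpy` API while also accepting brainunit quantities. A small sketch of the idea (the `u.mV` unit is chosen for illustration):

    import brainunit as u
    import jax.numpy as jnp

    x = jnp.ones((28, 3)) * u.mV      # a quantity carrying physical units
    xb = u.math.expand_dims(x, 0)     # shape (1, 28, 3), still in mV
    y = u.math.squeeze(xb, axis=0)    # shape (28, 3), units preserved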
@@ -188,6 +278,7 @@ class _Conv(_BaseConv):
  w_init: Union[Callable, ArrayLike] = init.XavierNormal(),
  b_init: Optional[Union[Callable, ArrayLike]] = None,
  w_mask: Optional[Union[ArrayLike, Callable]] = None,
+ channel_first: bool = False,
  name: str = None,
  param_type: type = ParamState,
  ):
@@ -201,6 +292,7 @@
  rhs_dilation=rhs_dilation,
  groups=groups,
  w_mask=w_mask,
+ channel_first=channel_first,
  name=name
  )

@@ -219,9 +311,10 @@
  self.weight = param_type(params)

  # Evaluate the output shape
+ test_input_shape = (128,) + self.in_size
  abstract_y = jax.eval_shape(
  self._conv_op,
- jax.ShapeDtypeStruct((128,) + self.in_size, weight.dtype),
+ jax.ShapeDtypeStruct(test_input_shape, weight.dtype),
  params
  )
  y_shape = abstract_y.shape[1:]
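
The dummy batch of 128 used above is never materialized: `jax.eval_shape` traces the function abstractly and returns only shapes and dtypes. A standalone sketch with illustrative shapes (not the layer's actual `_conv_op`):

    import jax
    import jax.numpy as jnp

    def conv(x):
        # illustrative 3x3 kernel, 3 -> 16 channels, NHWC layout
        return jax.lax.conv_general_dilated(
            x, jnp.zeros((3, 3, 3, 16)), window_strides=(1, 1), padding='SAME',
            dimension_numbers=('NHWC', 'HWIO', 'NHWC'))

    abstract_y = jax.eval_shape(conv, jax.ShapeDtypeStruct((128, 32, 32, 3), jnp.float32))
    print(abstract_y.shape)    # (128, 32, 32, 16); the layer keeps shape[1:] as out_size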
@@ -247,89 +340,433 @@ class _Conv(_BaseConv):


  class Conv1d(_Conv):
- """One-dimensional convolution.
+ """
+ One-dimensional convolution layer.

- The input should be a 3d array with the shape of ``[B, H, C]``.
+ Applies a 1D convolution over an input signal composed of several input planes.
+ The input should be a 3D array with the shape of ``[B, L, C]`` where B is batch size,
+ L is the sequence length, and C is the number of input channels.
+
+ This layer creates a convolution kernel that is convolved with the layer input
+ over a single spatial dimension to produce a tensor of outputs.

  Parameters
  ----------
- %s
+ in_size : tuple of int
+ The input shape without the batch dimension. This argument is important as it is
+ used to evaluate the output shape. For Conv1d: (L, C), Conv2d: (H, W, C), Conv3d: (H, W, D, C).
+ out_channels : int
+ The number of output channels (also called filters or feature maps).
+ kernel_size : int or tuple of int
+ The shape of the convolutional kernel. For 1D convolution, the kernel size can be
+ passed as an integer. For 2D and 3D convolutions, it should be a tuple of integers
+ or a single integer (which will be replicated for all spatial dimensions).
+ stride : int or tuple of int, optional
+ The stride of the convolution. An integer or a sequence of `n` integers, representing
+ the inter-window strides along each spatial dimension. Default: 1.
+ padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
+ The padding strategy. Can be:
+
+ - 'SAME': pads the input so the output has the same shape as input when stride=1
+ - 'VALID': no padding
+ - int: symmetric padding applied to all spatial dimensions
+ - tuple of (low, high): padding for each dimension
+ - sequence of tuples: explicit padding for each spatial dimension
+
+ Default: 'SAME'.
+ lhs_dilation : int or tuple of int, optional
+ The dilation factor for the input. An integer or a sequence of `n` integers, giving
+ the dilation factor to apply in each spatial dimension of inputs. Convolution with
+ input dilation `d` is equivalent to transposed convolution with stride `d`.
+ Default: 1.
+ rhs_dilation : int or tuple of int, optional
+ The dilation factor for the kernel. An integer or a sequence of `n` integers, giving
+ the dilation factor to apply in each spatial dimension of the convolution kernel.
+ Convolution with kernel dilation is also known as 'atrous convolution', which increases
+ the receptive field without increasing the number of parameters. Default: 1.
+ groups : int, optional
+ Number of groups for grouped convolution. Controls the connections between inputs and
+ outputs. Both `in_channels` and `out_channels` must be divisible by `groups`. When
+ groups=1 (default), all inputs are convolved to all outputs. When groups>1, the input
+ and output channels are divided into groups, and each group is convolved independently.
+ When groups=in_channels, this becomes a depthwise convolution. Default: 1.
+ w_init : Callable or ArrayLike, optional
+ The initializer for the convolutional kernel weights. Can be an initializer instance
+ or a direct array. Default: XavierNormal().
+ b_init : Callable or ArrayLike or None, optional
+ The initializer for the bias. If None, no bias is added. Default: None.
+ w_mask : ArrayLike or Callable or None, optional
+ An optional mask applied to the weights during forward pass. Useful for implementing
+ structured sparsity or custom connectivity patterns. Default: None.
+ name : str, optional
+ The name of the module. Default: None.
+ param_type : type, optional
+ The type of parameter state to use. Default: ParamState.
+
+ Attributes
+ ----------
+ in_size : tuple of int
+ The input shape (L, C) without batch dimension.
+ out_size : tuple of int
+ The output shape (L_out, out_channels) without batch dimension.
+ in_channels : int
+ Number of input channels.
+ out_channels : int
+ Number of output channels.
+ kernel_size : tuple of int
+ Size of the convolving kernel.
+ weight : ParamState
+ The learnable weights (and bias if specified) of the module.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ >>> import brainstate as brainstate
+ >>> import jax.numpy as jnp
+ >>>
+ >>> # Create a 1D convolution layer
+ >>> conv = brainstate.nn.Conv1d(in_size=(28, 3), out_channels=16, kernel_size=5)
+ >>>
+ >>> # Apply to input: batch_size=2, length=28, channels=3
+ >>> x = jnp.ones((2, 28, 3))
+ >>> y = conv(x)
+ >>> print(y.shape) # (2, 28, 16) with 'SAME' padding
+ >>>
+ >>> # Without batch dimension
+ >>> x_single = jnp.ones((28, 3))
+ >>> y_single = conv(x_single)
+ >>> print(y_single.shape) # (28, 16)
+ >>>
+ >>> # With custom parameters
+ >>> conv = brainstate.nn.Conv1d(
+ ... in_size=(100, 8),
+ ... out_channels=32,
+ ... kernel_size=3,
+ ... stride=2,
+ ... padding='VALID',
+ ... b_init=brainstate.init.ZeroInit()
+ ... )
+
+ Notes
+ -----
+ **Output dimensions:**
+
+ The output shape depends on the padding mode:
+
+ - 'SAME': output length = ceil(input_length / stride)
+ - 'VALID': output length = ceil((input_length - kernel_size + 1) / stride)
+
+ **Grouped convolution:**
+
+ When groups > 1, the convolution becomes a grouped convolution where input and
+ output channels are divided into groups, reducing computational cost.
  """
  __module__ = 'brainstate.nn'
  num_spatial_dims: int = 1


  class Conv2d(_Conv):
- """Two-dimensional convolution.
+ """
+ Two-dimensional convolution layer.

- The input should be a 4d array with the shape of ``[B, H, W, C]``.
+ Applies a 2D convolution over an input signal composed of several input planes.
+ The input should be a 4D array with the shape of ``[B, H, W, C]`` where B is batch size,
+ H is height, W is width, and C is the number of input channels (channels-last format).
+
+ This layer creates a convolution kernel that is convolved with the layer input
+ to produce a tensor of outputs. It is commonly used in computer vision tasks.

  Parameters
  ----------
- %s
+ in_size : tuple of int
+ The input shape without the batch dimension. For Conv2d: (H, W, C) where H is height,
+ W is width, and C is the number of input channels. This argument is important as it is
+ used to evaluate the output shape.
+ out_channels : int
+ The number of output channels (also called filters or feature maps). These determine
+ the depth of the output feature map.
+ kernel_size : int or tuple of int
+ The shape of the convolutional kernel. Can be:
+
+ - An integer (e.g., 3): creates a square kernel (3, 3)
+ - A tuple of two integers (e.g., (3, 5)): creates a (height, width) kernel
+ stride : int or tuple of int, optional
+ The stride of the convolution. Controls how much the kernel moves at each step.
+ Can be:
+
+ - An integer: same stride for both dimensions
+ - A tuple of two integers: (stride_height, stride_width)
+
+ Default: 1.
+ padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
+ The padding strategy. Options:
+
+ - 'SAME': output spatial size equals input size when stride=1
+ - 'VALID': no padding, output size reduced by kernel size
+ - int: same symmetric padding for all dimensions
+ - (pad_h, pad_w): different padding for each dimension
+ - [(pad_h_before, pad_h_after), (pad_w_before, pad_w_after)]: explicit padding
+
+ Default: 'SAME'.
+ lhs_dilation : int or tuple of int, optional
+ The dilation factor for the input (left-hand side). Controls spacing between input elements.
+ A value > 1 inserts zeros between input elements, equivalent to transposed convolution.
+ Default: 1.
+ rhs_dilation : int or tuple of int, optional
+ The dilation factor for the kernel (right-hand side). Also known as atrous convolution
+ or dilated convolution. Increases the receptive field without increasing parameters by
+ inserting zeros between kernel elements. Useful for capturing multi-scale context.
+ Default: 1.
+ groups : int, optional
+ Number of groups for grouped convolution. Must divide both `in_channels` and `out_channels`.
+
+ - groups=1: standard convolution (all-to-all connections)
+ - groups>1: grouped convolution (reduces parameters by factor of groups)
+ - groups=in_channels: depthwise convolution (each input channel convolved separately)
+
+ Default: 1.
+ w_init : Callable or ArrayLike, optional
+ Weight initializer for the convolutional kernel. Can be:
+
+ - An initializer instance (e.g., brainstate.init.XavierNormal())
+ - A callable that returns an array given a shape
+ - A direct array matching the kernel shape
+
+ Default: XavierNormal().
+ b_init : Callable or ArrayLike or None, optional
+ Bias initializer. If None, no bias term is added to the output.
+ Default: None.
+ w_mask : ArrayLike or Callable or None, optional
+ Optional weight mask for structured sparsity or custom connectivity. The mask is
+ element-wise multiplied with the kernel weights during the forward pass.
+ Default: None.
+ name : str, optional
+ Name identifier for this module instance.
+ Default: None.
+ param_type : type, optional
+ The parameter state class to use for managing learnable parameters.
+ Default: ParamState.
+
+ Attributes
+ ----------
+ in_size : tuple of int
+ The input shape (H, W, C) without batch dimension.
+ out_size : tuple of int
+ The output shape (H_out, W_out, out_channels) without batch dimension.
+ in_channels : int
+ Number of input channels.
+ out_channels : int
+ Number of output channels.
+ kernel_size : tuple of int
+ Size of the convolving kernel (height, width).
+ weight : ParamState
+ The learnable weights (and bias if specified) of the module.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ >>> import brainstate as brainstate
+ >>> import jax.numpy as jnp
+ >>>
+ >>> # Create a 2D convolution layer
+ >>> conv = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=64, kernel_size=3)
+ >>>
+ >>> # Apply to input: batch_size=8, height=32, width=32, channels=3
+ >>> x = jnp.ones((8, 32, 32, 3))
+ >>> y = conv(x)
+ >>> print(y.shape) # (8, 32, 32, 64) with 'SAME' padding
+ >>>
+ >>> # Without batch dimension
+ >>> x_single = jnp.ones((32, 32, 3))
+ >>> y_single = conv(x_single)
+ >>> print(y_single.shape) # (32, 32, 64)
+ >>>
+ >>> # With custom kernel size and stride
+ >>> conv = brainstate.nn.Conv2d(
+ ... in_size=(224, 224, 3),
+ ... out_channels=128,
+ ... kernel_size=(5, 5),
+ ... stride=2,
+ ... padding='VALID'
+ ... )
+ >>>
+ >>> # Depthwise convolution (groups = in_channels)
+ >>> conv = brainstate.nn.Conv2d(
+ ... in_size=(64, 64, 32),
+ ... out_channels=32,
+ ... kernel_size=3,
+ ... groups=32
+ ... )
+
+ Notes
+ -----
+ **Output dimensions:**
+
+ The output spatial dimensions depend on the padding mode:
+
+ - 'SAME': output_size = ceil(input_size / stride)
+ - 'VALID': output_size = ceil((input_size - kernel_size + 1) / stride)
+
+ **Grouped convolution:**
+
+ When groups > 1, the input and output channels are divided into groups.
+ Each group is convolved independently, which can significantly reduce
+ computational cost while maintaining representational power.
  """
  __module__ = 'brainstate.nn'
  num_spatial_dims: int = 2


  class Conv3d(_Conv):
- """Three-dimensional convolution.
+ """
+ Three-dimensional convolution layer.

- The input should be a 5d array with the shape of ``[B, H, W, D, C]``.
+ Applies a 3D convolution over an input signal composed of several input planes.
+ The input should be a 5D array with the shape of ``[B, H, W, D, C]`` where B is batch size,
+ H is height, W is width, D is depth, and C is the number of input channels (channels-last format).
+
+ This layer is commonly used for processing 3D data such as video sequences or
+ volumetric medical imaging data.

  Parameters
  ----------
- %s
+ in_size : tuple of int
+ The input shape without the batch dimension. For Conv3d: (H, W, D, C) where H is height,
+ W is width, D is depth, and C is the number of input channels. This argument is important
+ as it is used to evaluate the output shape.
+ out_channels : int
+ The number of output channels (also called filters or feature maps). These determine
+ the depth of the output feature map.
+ kernel_size : int or tuple of int
+ The shape of the convolutional kernel. Can be:
+
+ - An integer (e.g., 3): creates a cubic kernel (3, 3, 3)
+ - A tuple of three integers (e.g., (3, 5, 5)): creates a (height, width, depth) kernel
+ stride : int or tuple of int, optional
+ The stride of the convolution. Controls how much the kernel moves at each step.
+ Can be:
+
+ - An integer: same stride for all dimensions
+ - A tuple of three integers: (stride_h, stride_w, stride_d)
+ Default: 1.
+ padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
+ The padding strategy. Options:
+
+ - 'SAME': output spatial size equals input size when stride=1
+ - 'VALID': no padding, output size reduced by kernel size
+ - int: same symmetric padding for all dimensions
+ - (pad_h, pad_w, pad_d): different padding for each dimension
+ - [(pad_h_before, pad_h_after), (pad_w_before, pad_w_after), (pad_d_before, pad_d_after)]: explicit padding
+
+ Default: 'SAME'.
+ lhs_dilation : int or tuple of int, optional
+ The dilation factor for the input (left-hand side). Controls spacing between input elements.
+ A value > 1 inserts zeros between input elements, equivalent to transposed convolution.
+ Default: 1.
+ rhs_dilation : int or tuple of int, optional
+ The dilation factor for the kernel (right-hand side). Also known as atrous convolution
+ or dilated convolution. Increases the receptive field without increasing parameters by
+ inserting zeros between kernel elements. Particularly useful for 3D data to capture
+ larger temporal/spatial context.
+ Default: 1.
+ groups : int, optional
+ Number of groups for grouped convolution. Must divide both `in_channels` and `out_channels`.
+
+ - groups=1: standard convolution (all-to-all connections)
+ - groups>1: grouped convolution (significantly reduces parameters and computation for 3D)
+ - groups=in_channels: depthwise convolution (each input channel convolved separately)
+
+ Default: 1.
+ w_init : Callable or ArrayLike, optional
+ Weight initializer for the convolutional kernel. Can be:
+
+ - An initializer instance (e.g., brainstate.init.XavierNormal())
+ - A callable that returns an array given a shape
+ - A direct array matching the kernel shape
+
+ Default: XavierNormal().
+ b_init : Callable or ArrayLike or None, optional
+ Bias initializer. If None, no bias term is added to the output.
+ Default: None.
+ w_mask : ArrayLike or Callable or None, optional
+ Optional weight mask for structured sparsity or custom connectivity. The mask is
+ element-wise multiplied with the kernel weights during the forward pass.
+ Default: None.
+ name : str, optional
+ Name identifier for this module instance.
+ Default: None.
+ param_type : type, optional
+ The parameter state class to use for managing learnable parameters.
+ Default: ParamState.
+
+ Attributes
+ ----------
+ in_size : tuple of int
+ The input shape (H, W, D, C) without batch dimension.
+ out_size : tuple of int
+ The output shape (H_out, W_out, D_out, out_channels) without batch dimension.
+ in_channels : int
+ Number of input channels.
+ out_channels : int
+ Number of output channels.
+ kernel_size : tuple of int
+ Size of the convolving kernel (height, width, depth).
+ weight : ParamState
+ The learnable weights (and bias if specified) of the module.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ >>> import brainstate as brainstate
+ >>> import jax.numpy as jnp
+ >>>
+ >>> # Create a 3D convolution layer for video data
+ >>> conv = brainstate.nn.Conv3d(in_size=(16, 64, 64, 3), out_channels=32, kernel_size=3)
+ >>>
+ >>> # Apply to input: batch_size=4, frames=16, height=64, width=64, channels=3
+ >>> x = jnp.ones((4, 16, 64, 64, 3))
+ >>> y = conv(x)
+ >>> print(y.shape) # (4, 16, 64, 64, 32) with 'SAME' padding
+ >>>
+ >>> # Without batch dimension
+ >>> x_single = jnp.ones((16, 64, 64, 3))
+ >>> y_single = conv(x_single)
+ >>> print(y_single.shape) # (16, 64, 64, 32)
+ >>>
+ >>> # For medical imaging with custom parameters
+ >>> conv = brainstate.nn.Conv3d(
+ ... in_size=(32, 32, 32, 1),
+ ... out_channels=64,
+ ... kernel_size=(3, 3, 3),
+ ... stride=2,
+ ... padding='VALID',
+ ... b_init=brainstate.init.Constant(0.1)
+ ... )
+
+ Notes
+ -----
+ **Output dimensions:**
+
+ The output spatial dimensions depend on the padding mode:
+
+ - 'SAME': output_size = ceil(input_size / stride)
+ - 'VALID': output_size = ceil((input_size - kernel_size + 1) / stride)
+
+ **Performance considerations:**
+
+ 3D convolutions are computationally expensive. Consider using:
+
+ - Smaller kernel sizes
+ - Grouped convolutions (groups > 1)
+ - Separable convolutions for large-scale applications
  """
  __module__ = 'brainstate.nn'
  num_spatial_dims: int = 3


- _conv_doc = '''
- in_size: tuple of int
- The input shape, without the batch size. This argument is important, since it is
- used to evaluate the shape of the output.
- out_channels: int
- The number of output channels.
- kernel_size: int, sequence of int
- The shape of the convolutional kernel.
- For 1D convolution, the kernel size can be passed as an integer.
- For all other cases, it must be a sequence of integers.
- stride: int, sequence of int
- An integer or a sequence of `n` integers, representing the inter-window strides (default: 1).
- padding: str, int, sequence of int, sequence of tuple
- Either the string `'SAME'`, the string `'VALID'`, or a sequence of n `(low,
- high)` integer pairs that give the padding to apply before and after each
- spatial dimension.
- lhs_dilation: int, sequence of int
- An integer or a sequence of `n` integers, giving the
- dilation factor to apply in each spatial dimension of `inputs`
- (default: 1). Convolution with input dilation `d` is equivalent to
- transposed convolution with stride `d`.
- rhs_dilation: int, sequence of int
- An integer or a sequence of `n` integers, giving the
- dilation factor to apply in each spatial dimension of the convolution
- kernel (default: 1). Convolution with kernel dilation
- is also known as 'atrous convolution'.
- groups: int
- If specified, divides the input features into groups. default 1.
- w_init: Callable, ArrayLike, Initializer
- The initializer for the convolutional kernel.
- b_init: Optional, Callable, ArrayLike, Initializer
- The initializer for the bias.
- w_mask: ArrayLike, Callable, Optional
- The optional mask of the weights.
- mode: Mode
- The computation mode of the current object. Default it is `training`.
- name: str, Optional
- The name of the object.
- '''
-
- Conv1d.__doc__ = Conv1d.__doc__ % _conv_doc
- Conv2d.__doc__ = Conv2d.__doc__ % _conv_doc
- Conv3d.__doc__ = Conv3d.__doc__ % _conv_doc
-
-
  class _ScaledWSConv(_BaseConv):
  def __init__(
  self,
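
The 'SAME'/'VALID' output-size formulas quoted in the Notes sections above can be checked with a quick worked example:

    import math

    length, kernel, stride = 100, 3, 2
    print(math.ceil(length / stride))                 # 'SAME'  -> 50
    print(math.ceil((length - kernel + 1) / stride))  # 'VALID' -> 49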
@@ -346,19 +783,23 @@ class _ScaledWSConv(_BaseConv):
  w_init: Union[Callable, ArrayLike] = init.XavierNormal(),
  b_init: Optional[Union[Callable, ArrayLike]] = None,
  w_mask: Optional[Union[ArrayLike, Callable]] = None,
+ channel_first: bool = False,
  name: str = None,
  param_type: type = ParamState,
  ):
- super().__init__(in_size=in_size,
- out_channels=out_channels,
- kernel_size=kernel_size,
- stride=stride,
- padding=padding,
- lhs_dilation=lhs_dilation,
- rhs_dilation=rhs_dilation,
- groups=groups,
- w_mask=w_mask,
- name=name, )
+ super().__init__(
+ in_size=in_size,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ lhs_dilation=lhs_dilation,
+ rhs_dilation=rhs_dilation,
+ groups=groups,
+ w_mask=w_mask,
+ channel_first=channel_first,
+ name=name,
+ )

  self.w_initializer = w_init
  self.b_initializer = b_init
@@ -384,9 +825,14 @@ class _ScaledWSConv(_BaseConv):
  self.weight = param_type(params)

  # Evaluate the output shape
+ if self.channel_first:
+ test_input_shape = (128,) + self.in_size
+ else:
+ test_input_shape = (128,) + self.in_size
+
  abstract_y = jax.eval_shape(
  self._conv_op,
- jax.ShapeDtypeStruct((128,) + self.in_size, weight.dtype),
+ jax.ShapeDtypeStruct(test_input_shape, weight.dtype),
  params
  )
  y_shape = abstract_y.shape[1:]
@@ -394,7 +840,7 @@

  def _conv_op(self, x, params):
  w = params['weight']
- w = functional.weight_standardization(w, self.eps, params.get('gain', None))
+ w = weight_standardization(w, self.eps, params.get('gain', None))
  if self.w_mask is not None:
  w = w * self.w_mask
  y = jax.lax.conv_general_dilated(
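
For reference, the `weight_standardization` step applied above corresponds to the formula given in the docstrings below. A minimal re-statement (assumed per-output-channel statistics; this is a sketch, not the package's implementation):

    import jax.numpy as jnp

    def weight_standardization_sketch(w, eps=1e-4, gain=None):
        # kernel layout (..., C_in, C_out): standardize over all axes except C_out
        axes = tuple(range(w.ndim - 1))
        mean = w.mean(axis=axes, keepdims=True)
        std = w.std(axis=axes, keepdims=True)
        w_hat = (w - mean) / (std + eps)
        return w_hat if gain is None else w_hat * gain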
@@ -413,89 +859,1152 @@ class _ScaledWSConv(_BaseConv):
413
859
 
414
860
 
415
861
  class ScaledWSConv1d(_ScaledWSConv):
416
- """One-dimensional convolution with weight standardization.
862
+ """
863
+ One-dimensional convolution with weight standardization.
417
864
 
418
- The input should be a 3d array with the shape of ``[B, H, C]``.
865
+ This layer applies weight standardization to the convolutional kernel before
866
+ performing the convolution operation. Weight standardization normalizes the
867
+ weights to have zero mean and unit variance, which can accelerate training
868
+ and improve model performance, especially when combined with group normalization.
869
+
870
+ The input should be a 3D array with the shape of ``[B, L, C]`` where B is batch size,
871
+ L is the sequence length, and C is the number of input channels.
419
872
 
420
873
  Parameters
421
874
  ----------
422
- %s
875
+ in_size : tuple of int
876
+ The input shape without the batch dimension. For Conv1d: (L, C) where L is the sequence
877
+ length and C is the number of input channels. This argument is important as it is used
878
+ to evaluate the output shape.
879
+ out_channels : int
880
+ The number of output channels (also called filters or feature maps). These determine
881
+ the depth of the output feature map.
882
+ kernel_size : int or tuple of int
883
+ The shape of the convolutional kernel. For 1D convolution, can be:
884
+
885
+ - An integer (e.g., 5): creates a kernel of size 5
886
+ - A tuple with one integer (e.g., (5,)): equivalent to the above
887
+ stride : int or tuple of int, optional
888
+ The stride of the convolution. Controls how much the kernel moves at each step.
889
+ Default: 1.
890
+ padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
891
+ The padding strategy. Options:
892
+
893
+ - 'SAME': output length equals input length when stride=1
894
+ - 'VALID': no padding, output length reduced by kernel size
895
+ - int: symmetric padding
896
+ - (pad_before, pad_after): explicit padding for the sequence dimension
897
+
898
+ Default: 'SAME'.
899
+ lhs_dilation : int or tuple of int, optional
900
+ The dilation factor for the input (left-hand side). Controls spacing between input elements.
901
+ A value > 1 inserts zeros between input elements, equivalent to transposed convolution.
902
+ Default: 1.
903
+ rhs_dilation : int or tuple of int, optional
904
+ The dilation factor for the kernel (right-hand side). Also known as atrous convolution
905
+ or dilated convolution. Increases the receptive field without increasing parameters by
906
+ inserting zeros between kernel elements. Useful for capturing long-range dependencies.
907
+ Default: 1.
908
+ groups : int, optional
909
+ Number of groups for grouped convolution. Must divide both `in_channels` and `out_channels`.
910
+
911
+ - groups=1: standard convolution (all-to-all connections)
912
+ - groups>1: grouped convolution (reduces parameters by factor of groups)
913
+ - groups=in_channels: depthwise convolution (each input channel convolved separately)
914
+
915
+ Default: 1.
916
+ w_init : Callable or ArrayLike, optional
917
+ Weight initializer for the convolutional kernel. Can be:
918
+
919
+ - An initializer instance (e.g., brainstate.init.XavierNormal())
920
+ - A callable that returns an array given a shape
921
+ - A direct array matching the kernel shape
922
+
923
+ Default: XavierNormal().
924
+ b_init : Callable or ArrayLike or None, optional
925
+ Bias initializer. If None, no bias term is added to the output.
926
+ Default: None.
927
+ ws_gain : bool, optional
928
+ Whether to include a learnable per-channel gain parameter in weight standardization.
929
+ When True, adds a scaling factor that can be learned during training, improving
930
+ model expressiveness. Recommended for most applications.
931
+ Default: True.
932
+ eps : float, optional
933
+ Small constant for numerical stability in weight standardization. Prevents division
934
+ by zero when computing weight standard deviation. Typical values: 1e-4 to 1e-5.
935
+ Default: 1e-4.
936
+ w_mask : ArrayLike or Callable or None, optional
937
+ Optional weight mask for structured sparsity or custom connectivity. The mask is
938
+ element-wise multiplied with the standardized kernel weights during the forward pass.
939
+ Default: None.
940
+ name : str, optional
941
+ Name identifier for this module instance.
942
+ Default: None.
943
+ param_type : type, optional
944
+ The parameter state class to use for managing learnable parameters.
945
+ Default: ParamState.
946
+
947
+ Attributes
948
+ ----------
949
+ in_size : tuple of int
950
+ The input shape (L, C) without batch dimension.
951
+ out_size : tuple of int
952
+ The output shape (L_out, out_channels) without batch dimension.
953
+ in_channels : int
954
+ Number of input channels.
955
+ out_channels : int
956
+ Number of output channels.
957
+ kernel_size : tuple of int
958
+ Size of the convolving kernel.
959
+ weight : ParamState
960
+ The learnable weights (and bias if specified) of the module.
961
+ eps : float
962
+ Small constant for numerical stability in weight standardization.
963
+
964
+ Examples
965
+ --------
966
+ .. code-block:: python
967
+
968
+ >>> import brainstate as brainstate
969
+ >>> import jax.numpy as jnp
970
+ >>>
971
+ >>> # Create a 1D convolution with weight standardization
972
+ >>> conv = brainstate.nn.ScaledWSConv1d(
973
+ ... in_size=(100, 16),
974
+ ... out_channels=32,
975
+ ... kernel_size=5
976
+ ... )
977
+ >>>
978
+ >>> # Apply to input
979
+ >>> x = jnp.ones((4, 100, 16))
980
+ >>> y = conv(x)
981
+ >>> print(y.shape) # (4, 100, 32)
982
+ >>>
983
+ >>> # With custom epsilon and no gain
984
+ >>> conv = brainstate.nn.ScaledWSConv1d(
985
+ ... in_size=(50, 8),
986
+ ... out_channels=16,
987
+ ... kernel_size=3,
988
+ ... ws_gain=False,
989
+ ... eps=1e-5
990
+ ... )
991
+
992
+ Notes
993
+ -----
994
+ **Weight standardization formula:**
995
+
996
+ Weight standardization reparameterizes the convolutional weights as:
997
+
998
+ .. math::
999
+ \\hat{W} = g \\cdot \\frac{W - \\mu_W}{\\sigma_W + \\epsilon}
1000
+
1001
+ where :math:`\\mu_W` and :math:`\\sigma_W` are the mean and standard deviation
1002
+ of the weights, :math:`g` is a learnable gain parameter (if ws_gain=True),
1003
+ and :math:`\\epsilon` is a small constant for numerical stability.
1004
+
1005
+ **When to use:**
1006
+
1007
+ This technique is particularly effective when used with Group Normalization
1008
+ instead of Batch Normalization, as it reduces the dependence on batch statistics.
1009
+
1010
+ References
1011
+ ----------
1012
+ .. [1] Qiao, S., Wang, H., Liu, C., Shen, W., & Yuille, A. (2019).
1013
+ Weight Standardization. arXiv preprint arXiv:1903.10520.
423
1014
  """
424
1015
  __module__ = 'brainstate.nn'
425
1016
  num_spatial_dims: int = 1
426
1017
 
427
1018
 
428
1019
  class ScaledWSConv2d(_ScaledWSConv):
429
- """Two-dimensional convolution with weight standardization.
1020
+ """
1021
+ Two-dimensional convolution with weight standardization.
1022
+
1023
+ This layer applies weight standardization to the convolutional kernel before
1024
+ performing the convolution operation. Weight standardization normalizes the
1025
+ weights to have zero mean and unit variance, improving training dynamics and
1026
+ model generalization, particularly in combination with group normalization.
430
1027
 
431
- The input should be a 4d array with the shape of ``[B, H, W, C]``.
1028
+ The input should be a 4D array with the shape of ``[B, H, W, C]`` where B is batch size,
1029
+ H is height, W is width, and C is the number of input channels (channels-last format).
432
1030
 
433
1031
  Parameters
434
1032
  ----------
435
- %s
1033
+ in_size : tuple of int
1034
+ The input shape without the batch dimension. For Conv2d: (H, W, C) where H is height,
1035
+ W is width, and C is the number of input channels. This argument is important as it is
1036
+ used to evaluate the output shape.
1037
+ out_channels : int
1038
+ The number of output channels (also called filters or feature maps). These determine
1039
+ the depth of the output feature map.
1040
+ kernel_size : int or tuple of int
1041
+ The shape of the convolutional kernel. Can be:
1042
+
1043
+ - An integer (e.g., 3): creates a square kernel (3, 3)
1044
+ - A tuple of two integers (e.g., (3, 5)): creates a (height, width) kernel
1045
+ stride : int or tuple of int, optional
1046
+ The stride of the convolution. Controls how much the kernel moves at each step.
1047
+ Can be:
1048
+
1049
+ - An integer: same stride for both dimensions
1050
+ - A tuple of two integers: (stride_height, stride_width)
1051
+
1052
+ Default: 1.
1053
+ padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
1054
+ The padding strategy. Options:
1055
+
1056
+ - 'SAME': output spatial size equals input size when stride=1
1057
+ - 'VALID': no padding, output size reduced by kernel size
1058
+ - int: same symmetric padding for all dimensions
1059
+ - (pad_h, pad_w): different padding for each dimension
1060
+ - [(pad_h_before, pad_h_after), (pad_w_before, pad_w_after)]: explicit padding
1061
+
1062
+ Default: 'SAME'.
1063
+ lhs_dilation : int or tuple of int, optional
1064
+ The dilation factor for the input (left-hand side). Controls spacing between input elements.
1065
+ A value > 1 inserts zeros between input elements, equivalent to transposed convolution.
1066
+ Default: 1.
1067
+ rhs_dilation : int or tuple of int, optional
1068
+ The dilation factor for the kernel (right-hand side). Also known as atrous convolution
1069
+ or dilated convolution. Increases the receptive field without increasing parameters by
1070
+ inserting zeros between kernel elements. Useful for semantic segmentation and dense
1071
+ prediction tasks.
1072
+ Default: 1.
1073
+ groups : int, optional
1074
+ Number of groups for grouped convolution. Must divide both `in_channels` and `out_channels`.
1075
+
1076
+ - groups=1: standard convolution (all-to-all connections)
1077
+ - groups>1: grouped convolution (reduces parameters by factor of groups)
1078
+ - groups=in_channels: depthwise convolution (each input channel convolved separately)
1079
+
1080
+ Default: 1.
1081
+ w_init : Callable or ArrayLike, optional
1082
+ Weight initializer for the convolutional kernel. Can be:
1083
+
1084
+ - An initializer instance (e.g., brainstate.init.XavierNormal())
1085
+ - A callable that returns an array given a shape
1086
+ - A direct array matching the kernel shape
1087
+
1088
+ Default: XavierNormal().
1089
+ b_init : Callable or ArrayLike or None, optional
1090
+ Bias initializer. If None, no bias term is added to the output.
1091
+ Default: None.
1092
+ ws_gain : bool, optional
1093
+ Whether to include a learnable per-channel gain parameter in weight standardization.
1094
+ When True, adds a scaling factor that can be learned during training, improving
1095
+ model expressiveness. Highly recommended when using with Group Normalization.
1096
+ Default: True.
1097
+ eps : float, optional
1098
+ Small constant for numerical stability in weight standardization. Prevents division
1099
+ by zero when computing weight standard deviation. Typical values: 1e-4 to 1e-5.
1100
+ Default: 1e-4.
1101
+ w_mask : ArrayLike or Callable or None, optional
1102
+ Optional weight mask for structured sparsity or custom connectivity. The mask is
1103
+ element-wise multiplied with the standardized kernel weights during the forward pass.
1104
+ Default: None.
1105
+ name : str, optional
1106
+ Name identifier for this module instance.
1107
+ Default: None.
1108
+ param_type : type, optional
1109
+ The parameter state class to use for managing learnable parameters.
1110
+ Default: ParamState.
1111
+
1112
+ Attributes
1113
+ ----------
1114
+ in_size : tuple of int
1115
+ The input shape (H, W, C) without batch dimension.
1116
+ out_size : tuple of int
1117
+ The output shape (H_out, W_out, out_channels) without batch dimension.
1118
+ in_channels : int
1119
+ Number of input channels.
1120
+ out_channels : int
1121
+ Number of output channels.
1122
+ kernel_size : tuple of int
1123
+ Size of the convolving kernel (height, width).
1124
+ weight : ParamState
1125
+ The learnable weights (and bias if specified) of the module.
1126
+ eps : float
1127
+ Small constant for numerical stability in weight standardization.
1128
+
1129
+ Examples
1130
+ --------
1131
+ .. code-block:: python
1132
+
1133
+ >>> import brainstate as brainstate
1134
+ >>> import jax.numpy as jnp
1135
+ >>>
1136
+ >>> # Create a 2D convolution with weight standardization
1137
+ >>> conv = brainstate.nn.ScaledWSConv2d(
1138
+ ... in_size=(64, 64, 3),
1139
+ ... out_channels=32,
1140
+ ... kernel_size=3
1141
+ ... )
1142
+ >>>
1143
+ >>> # Apply to input
1144
+ >>> x = jnp.ones((8, 64, 64, 3))
1145
+ >>> y = conv(x)
1146
+ >>> print(y.shape) # (8, 64, 64, 32)
1147
+ >>>
1148
+ >>> # Combine with custom settings for ResNet-style architecture
1149
+ >>> conv = brainstate.nn.ScaledWSConv2d(
1150
+ ... in_size=(224, 224, 3),
1151
+ ... out_channels=64,
1152
+ ... kernel_size=7,
1153
+ ... stride=2,
1154
+ ... padding='SAME',
1155
+ ... ws_gain=True,
1156
+ ... b_init=brainstate.init.ZeroInit()
1157
+ ... )
1158
+ >>>
1159
+ >>> # Depthwise separable convolution with weight standardization
1160
+ >>> conv = brainstate.nn.ScaledWSConv2d(
1161
+ ... in_size=(32, 32, 128),
1162
+ ... out_channels=128,
1163
+ ... kernel_size=3,
1164
+ ... groups=128,
1165
+ ... ws_gain=False
1166
+ ... )
1167
+
1168
+ Notes
1169
+ -----
1170
+ **Weight standardization formula:**
1171
+
1172
+ Weight standardization reparameterizes the convolutional weights as:
1173
+
1174
+ .. math::
1175
+ \\hat{W} = g \\cdot \\frac{W - \\mu_W}{\\sigma_W + \\epsilon}
1176
+
1177
+ where :math:`\\mu_W` and :math:`\\sigma_W` are the mean and standard deviation
1178
+ of the weights computed per output channel, :math:`g` is a learnable gain
1179
+ parameter (if ws_gain=True), and :math:`\\epsilon` is a small constant.
1180
+
1181
+ **Benefits:**
1182
+
1183
+ - Reduces internal covariate shift
1184
+ - Smooths the loss landscape
1185
+ - Works well with Group Normalization
1186
+ - Improves training stability with small batch sizes
1187
+ - Enables training deeper networks more easily
1188
+
1189
+ References
1190
+ ----------
1191
+ .. [1] Qiao, S., Wang, H., Liu, C., Shen, W., & Yuille, A. (2019).
1192
+ Weight Standardization. arXiv preprint arXiv:1903.10520.
436
1193
  """
437
1194
  __module__ = 'brainstate.nn'
438
1195
  num_spatial_dims: int = 2
439
1196
 
440
1197
 
441
1198
  class ScaledWSConv3d(_ScaledWSConv):
442
- """Three-dimensional convolution with weight standardization.
1199
+ """
1200
+ Three-dimensional convolution with weight standardization.
1201
+
1202
+ This layer applies weight standardization to the convolutional kernel before
1203
+ performing the 3D convolution operation. Weight standardization normalizes the
1204
+ weights to have zero mean and unit variance, which improves training dynamics
1205
+ especially for 3D networks that are typically deeper and more parameter-heavy.
443
1206
 
444
- The input should be a 5d array with the shape of ``[B, H, W, D, C]``.
1207
+ The input should be a 5D array with the shape of ``[B, H, W, D, C]`` where B is batch size,
1208
+ H is height, W is width, D is depth, and C is the number of input channels (channels-last format).
445
1209
 
446
1210
  Parameters
447
1211
  ----------
448
- %s
1212
+ in_size : tuple of int
1213
+ The input shape without the batch dimension. For Conv3d: (H, W, D, C) where H is height,
1214
+ W is width, D is depth, and C is the number of input channels. This argument is important
1215
+ as it is used to evaluate the output shape.
1216
+ out_channels : int
1217
+ The number of output channels (also called filters or feature maps). These determine
1218
+ the depth of the output feature map.
1219
+ kernel_size : int or tuple of int
1220
+ The shape of the convolutional kernel. Can be:
1221
+
1222
+ - An integer (e.g., 3): creates a cubic kernel (3, 3, 3)
1223
+ - A tuple of three integers (e.g., (3, 5, 5)): creates a (height, width, depth) kernel
1224
+ stride : int or tuple of int, optional
1225
+ The stride of the convolution. Controls how much the kernel moves at each step.
1226
+ Can be:
1227
+
1228
+ - An integer: same stride for all dimensions
1229
+ - A tuple of three integers: (stride_h, stride_w, stride_d)
1230
+
1231
+ Default: 1.
1232
+ padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
1233
+ The padding strategy. Options:
1234
+
1235
+ - 'SAME': output spatial size equals input size when stride=1
1236
+ - 'VALID': no padding, output size reduced by kernel size
1237
+ - int: same symmetric padding for all dimensions
1238
+ - (pad_h, pad_w, pad_d): different padding for each dimension
1239
+ - [(pad_h_before, pad_h_after), (pad_w_before, pad_w_after), (pad_d_before, pad_d_after)]: explicit padding
1240
+
1241
+ Default: 'SAME'.
1242
+ lhs_dilation : int or tuple of int, optional
1243
+ The dilation factor for the input (left-hand side). Controls spacing between input elements.
1244
+ A value > 1 inserts zeros between input elements, equivalent to transposed convolution.
1245
+ Default: 1.
1246
+ rhs_dilation : int or tuple of int, optional
1247
+ The dilation factor for the kernel (right-hand side). Also known as atrous convolution
1248
+ or dilated convolution. Increases the receptive field without increasing parameters by
1249
+ inserting zeros between kernel elements. Particularly valuable for 3D to capture
1250
+ multi-scale temporal/spatial context efficiently.
1251
+ Default: 1.
1252
+ groups : int, optional
1253
+ Number of groups for grouped convolution. Must divide both `in_channels` and `out_channels`.
1254
+
1255
+ - groups=1: standard convolution (all-to-all connections)
1256
+ - groups>1: grouped convolution (critical for reducing 3D conv computational cost)
1257
+ - groups=in_channels: depthwise convolution (each input channel convolved separately)
1258
+
1259
+ Default: 1.
1260
+ w_init : Callable or ArrayLike, optional
1261
+ Weight initializer for the convolutional kernel. Can be:
1262
+
1263
+ - An initializer instance (e.g., brainstate.init.XavierNormal())
1264
+ - A callable that returns an array given a shape
1265
+ - A direct array matching the kernel shape
1266
+
1267
+ Default: XavierNormal().
1268
+ b_init : Callable or ArrayLike or None, optional
1269
+ Bias initializer. If None, no bias term is added to the output.
1270
+ Default: None.
1271
+ ws_gain : bool, optional
1272
+ Whether to include a learnable per-channel gain parameter in weight standardization.
1273
+ When True, adds a scaling factor that can be learned during training, improving
1274
+ model expressiveness. Particularly beneficial for deep 3D networks.
1275
+ Default: True.
1276
+ eps : float, optional
1277
+ Small constant for numerical stability in weight standardization. Prevents division
1278
+ by zero when computing weight standard deviation. Typical values: 1e-4 to 1e-5.
1279
+ Default: 1e-4.
1280
+ w_mask : ArrayLike or Callable or None, optional
1281
+ Optional weight mask for structured sparsity or custom connectivity. The mask is
1282
+ element-wise multiplied with the standardized kernel weights during the forward pass.
1283
+ Default: None.
1284
+ name : str, optional
1285
+ Name identifier for this module instance.
1286
+ Default: None.
1287
+ param_type : type, optional
1288
+ The parameter state class to use for managing learnable parameters.
1289
+ Default: ParamState.
1290
+
1291
+ Attributes
1292
+ ----------
1293
+ in_size : tuple of int
1294
+ The input shape (H, W, D, C) without batch dimension.
1295
+ out_size : tuple of int
1296
+ The output shape (H_out, W_out, D_out, out_channels) without batch dimension.
1297
+ in_channels : int
1298
+ Number of input channels.
1299
+ out_channels : int
1300
+ Number of output channels.
1301
+ kernel_size : tuple of int
1302
+ Size of the convolving kernel (height, width, depth).
1303
+ weight : ParamState
1304
+ The learnable weights (and bias if specified) of the module.
1305
+ eps : float
1306
+ Small constant for numerical stability in weight standardization.
1307
+
1308
+ Examples
1309
+ --------
1310
+ .. code-block:: python
1311
+
1312
+ >>> import brainstate
1313
+ >>> import jax.numpy as jnp
1314
+ >>>
1315
+ >>> # Create a 3D convolution with weight standardization for video
1316
+ >>> conv = brainstate.nn.ScaledWSConv3d(
1317
+ ... in_size=(16, 64, 64, 3),
1318
+ ... out_channels=32,
1319
+ ... kernel_size=3
1320
+ ... )
1321
+ >>>
1322
+ >>> # Apply to input
1323
+ >>> x = jnp.ones((4, 16, 64, 64, 3))
1324
+ >>> y = conv(x)
1325
+ >>> print(y.shape) # (4, 16, 64, 64, 32)
1326
+ >>>
1327
+ >>> # For medical imaging with custom parameters
1328
+ >>> conv = brainstate.nn.ScaledWSConv3d(
1329
+ ... in_size=(32, 32, 32, 1),
1330
+ ... out_channels=64,
1331
+ ... kernel_size=(3, 3, 3),
1332
+ ... stride=2,
1333
+ ... ws_gain=True,
1334
+ ... eps=1e-5,
1335
+ ... b_init=brainstate.init.Constant(0.01)
1336
+ ... )
1337
+ >>>
1338
+ >>> # 3D grouped convolution with weight standardization
1339
+ >>> conv = brainstate.nn.ScaledWSConv3d(
1340
+ ... in_size=(8, 16, 16, 64),
1341
+ ... out_channels=64,
1342
+ ... kernel_size=3,
1343
+ ... groups=8,
1344
+ ... ws_gain=False
1345
+ ... )
1346
+
1347
+ Notes
1348
+ -----
1349
+ **Weight standardization formula:**
1350
+
1351
+ Weight standardization reparameterizes the convolutional weights as:
1352
+
1353
+ .. math::
1354
+ \\hat{W} = g \\cdot \\frac{W - \\mu_W}{\\sigma_W + \\epsilon}
1355
+
1356
+ where :math:`\\mu_W` and :math:`\\sigma_W` are the mean and standard deviation
1357
+ of the weights computed per output channel, :math:`g` is a learnable gain parameter (if ws_gain=True),
1358
+ and :math:`\\epsilon` is a small constant for numerical stability.
1359
+
1360
+ **Why weight standardization for 3D:**
1361
+
1362
+ For 3D convolutions, weight standardization is particularly beneficial because:
1363
+
1364
+ - 3D networks are typically much deeper and harder to train
1365
+ - Reduces sensitivity to weight initialization
1366
+ - Improves gradient flow through very deep networks
1367
+ - Works well with limited computational resources (small batches)
1368
+ - Compatible with Group Normalization for batch-independent normalization
1369
+
1370
+ **Applications:**
1371
+
1372
+ Video understanding, medical imaging (CT, MRI scans), 3D object recognition,
1373
+ and temporal sequence modeling.
1374
+
1375
+ References
1376
+ ----------
1377
+ .. [1] Qiao, S., Wang, H., Liu, C., Shen, W., & Yuille, A. (2019).
1378
+ Weight Standardization. arXiv preprint arXiv:1903.10520.
449
1379
  """
450
1380
  __module__ = 'brainstate.nn'
451
1381
  num_spatial_dims: int = 3
452
1382
 
453
1383
 
454
- _ws_conv_doc = '''
455
- in_size: tuple of int
456
- The input shape, without the batch size. This argument is important, since it is
457
- used to evaluate the shape of the output.
458
- out_channels: int
459
- The number of output channels.
460
- kernel_size: int, sequence of int
461
- The shape of the convolutional kernel.
462
- For 1D convolution, the kernel size can be passed as an integer.
463
- For all other cases, it must be a sequence of integers.
464
- stride: int, sequence of int
465
- An integer or a sequence of `n` integers, representing the inter-window strides (default: 1).
466
- padding: str, int, sequence of int, sequence of tuple
467
- Either the string `'SAME'`, the string `'VALID'`, or a sequence of n `(low,
468
- high)` integer pairs that give the padding to apply before and after each
469
- spatial dimension.
470
- lhs_dilation: int, sequence of int
471
- An integer or a sequence of `n` integers, giving the
472
- dilation factor to apply in each spatial dimension of `inputs`
473
- (default: 1). Convolution with input dilation `d` is equivalent to
474
- transposed convolution with stride `d`.
475
- rhs_dilation: int, sequence of int
476
- An integer or a sequence of `n` integers, giving the
477
- dilation factor to apply in each spatial dimension of the convolution
478
- kernel (default: 1). Convolution with kernel dilation
479
- is also known as 'atrous convolution'.
480
- groups: int
481
- If specified, divides the input features into groups. default 1.
482
- w_init: Callable, ArrayLike, Initializer
483
- The initializer for the convolutional kernel.
484
- b_init: Optional, Callable, ArrayLike, Initializer
485
- The initializer for the bias.
486
- ws_gain: bool
487
- Whether to add a gain term for the weight standarization. The default is `True`.
488
- eps: float
489
- The epsilon value for numerical stability.
490
- w_mask: ArrayLike, Callable, Optional
491
- The optional mask of the weights.
492
- mode: Mode
493
- The computation mode of the current object. Default it is `training`.
494
- name: str, Optional
495
- The name of the object.
496
-
497
- '''
498
-
499
- ScaledWSConv1d.__doc__ = ScaledWSConv1d.__doc__ % _ws_conv_doc
500
- ScaledWSConv2d.__doc__ = ScaledWSConv2d.__doc__ % _ws_conv_doc
501
- ScaledWSConv3d.__doc__ = ScaledWSConv3d.__doc__ % _ws_conv_doc
1384
+ class _ConvTranspose(_BaseConv):
1385
+ """Base class for transposed convolution layers."""
1386
+ num_spatial_dims: Optional[int] = None
1387
+
1388
+ def __init__(
1389
+ self,
1390
+ in_size: Sequence[int],
1391
+ out_channels: int,
1392
+ kernel_size: Union[int, Tuple[int, ...]],
1393
+ stride: Union[int, Tuple[int, ...]] = 1,
1394
+ padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
1395
+ lhs_dilation: Union[int, Tuple[int, ...]] = 1,
1396
+ rhs_dilation: Union[int, Tuple[int, ...]] = 1,
1397
+ groups: int = 1,
1398
+ w_init: Union[Callable, ArrayLike] = init.XavierNormal(),
1399
+ b_init: Optional[Union[Callable, ArrayLike]] = None,
1400
+ w_mask: Optional[Union[ArrayLike, Callable]] = None,
1401
+ channel_first: bool = False,
1402
+ name: str = None,
1403
+ param_type: type = ParamState,
1404
+ ):
1405
+ # Initialize with transpose=True for dimension numbers
1406
+ Module.__init__(self, name=name)
1407
+
1408
+ # general parameters
1409
+ assert self.num_spatial_dims + 1 == len(in_size)
1410
+ self.in_size = tuple(in_size)
1411
+ self.channel_first = channel_first
1412
+ self.channels_last = not channel_first
1413
+
1414
+ # Determine in_channels based on channel_first
1415
+ if self.channel_first:
1416
+ self.in_channels = in_size[0]
1417
+ else:
1418
+ self.in_channels = in_size[-1]
1419
+
1420
+ self.out_channels = out_channels
1421
+ self.stride = replicate(stride, self.num_spatial_dims, 'stride')
1422
+ self.kernel_size = replicate(kernel_size, self.num_spatial_dims, 'kernel_size')
1423
+ self.lhs_dilation = replicate(lhs_dilation, self.num_spatial_dims, 'lhs_dilation')
1424
+ self.rhs_dilation = replicate(rhs_dilation, self.num_spatial_dims, 'rhs_dilation')
1425
+ self.groups = groups
1426
+ self.dimension_numbers = to_dimension_numbers(
1427
+ self.num_spatial_dims,
1428
+ channels_last=self.channels_last,
1429
+ transpose=True # Key difference from regular Conv
1430
+ )
1431
+
1432
+ # the padding parameter
1433
+ # For transposed convolution, string padding needs to be converted to explicit padding
1434
+ # when using lhs_dilation (stride) > 1
1435
+ if isinstance(padding, str):
1436
+ assert padding in ['SAME', 'VALID']
1437
+ self.padding_mode = padding
1438
+ # Compute explicit padding for transposed convolution
1439
+ if max(self.stride) > 1:
+ # For transposed conv with stride, compute the edge padding of the
+ # zero-dilated input that yields the documented output size.
+ # With lhs_dilation = stride, the output length is
+ # (in - 1) * stride + 1 + total_pad - kernel + 1,
+ # so the required total_pad does not depend on the input size.
+ if padding == 'SAME':
+ # Target: output_size = input_size * stride
+ # => total_pad = kernel + stride - 2
+ explicit_padding = []
+ for k, s in zip(self.kernel_size, self.stride):
+ total_pad = max(k + s - 2, 0)
+ pad_left = total_pad // 2
+ pad_right = total_pad - pad_left
+ explicit_padding.append((pad_left, pad_right))
+ padding = tuple(explicit_padding)
+ else: # 'VALID'
+ # Target: output_size = input_size * stride + max(kernel - stride, 0),
+ # i.e. a "full" transposed convolution: kernel - 1 on the left,
+ # extended on the right when the stride exceeds the kernel size
+ explicit_padding = []
+ for k, s in zip(self.kernel_size, self.stride):
+ explicit_padding.append((k - 1, k - 1 + max(s - k, 0)))
+ padding = tuple(explicit_padding)
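+ # Worked example (illustrative): kernel=4, stride=2, 'SAME', input length 28:
+ # dilated length = (28 - 1) * 2 + 1 = 55, total_pad = 4 + 2 - 2 = 4,
+ # output = 55 + 4 - 4 + 1 = 56 = 28 * 2.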
1460
+ # If stride is 1, keep string padding
1461
+ elif isinstance(padding, int):
1462
+ self.padding_mode = 'explicit'
1463
+ padding = tuple((padding, padding) for _ in range(self.num_spatial_dims))
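+ # e.g. padding=2 with two spatial dims -> ((2, 2), (2, 2))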
1464
+ elif isinstance(padding, (tuple, list)):
1465
+ self.padding_mode = 'explicit'
1466
+ if isinstance(padding[0], int):
1467
+ padding = (padding,) * self.num_spatial_dims
1468
+ elif isinstance(padding[0], (tuple, list)):
1469
+ if len(padding) == 1:
1470
+ padding = tuple(padding) * self.num_spatial_dims
1471
+ else:
1472
+ if len(padding) != self.num_spatial_dims:
1473
+ raise ValueError(
1474
+ f"Padding {padding} must be a Tuple[int, int], "
1475
+ f"or sequence of Tuple[int, int] with length 1, "
1476
+ f"or sequence of Tuple[int, int] with length {self.num_spatial_dims}."
1477
+ )
1478
+ padding = tuple(padding)
1479
+ else:
1480
+ raise ValueError(f'Unsupported padding: {padding!r}')
1481
+ self.padding = padding
1482
+
1483
+ # the number of in-/out-channels
1484
+ assert self.out_channels % self.groups == 0, '"out_channels" should be divisible by groups'
1485
+ assert self.in_channels % self.groups == 0, '"in_channels" should be divisible by groups'
1486
+
1487
+ # kernel shape for transpose conv
1488
+ # When transpose=True in dimension_numbers, kernel is (spatial..., out_channels, in_channels // groups)
1489
+ # This matches JAX's expectation for transposed convolution
1490
+ kernel_shape = tuple(self.kernel_size) + (self.out_channels, self.in_channels // self.groups)
1491
+ self.kernel_shape = kernel_shape
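+ # e.g. kernel_size=(4, 4), out_channels=32, in_channels=64, groups=1
+ # gives kernel_shape == (4, 4, 32, 64)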
1492
+ self.w_mask = init.param(w_mask, kernel_shape, allow_none=True)
1493
+
1494
+ self.w_initializer = w_init
1495
+ self.b_initializer = b_init
1496
+
1497
+ # --- weights --- #
1498
+ weight = init.param(self.w_initializer, self.kernel_shape, allow_none=False)
1499
+ params = dict(weight=weight)
1500
+ if self.b_initializer is not None:
1501
+ bias_shape = (1,) * len(self.kernel_size) + (self.out_channels,)
1502
+ bias = init.param(self.b_initializer, bias_shape, allow_none=True)
1503
+ params['bias'] = bias
1504
+
1505
+ # The weight operation
1506
+ self.weight = param_type(params)
1507
+
1508
+ # Evaluate the output shape
1509
+ test_input_shape = (128,) + self.in_size
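+ # jax.eval_shape traces shapes abstractly, so the dummy batch size of
+ # 128 incurs no computation or memory cost here.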
1510
+ abstract_y = jax.eval_shape(
1511
+ self._conv_op,
1512
+ jax.ShapeDtypeStruct(test_input_shape, weight.dtype),
1513
+ params
1514
+ )
1515
+ y_shape = abstract_y.shape[1:]
1516
+ self.out_size = y_shape
1517
+
1518
+ def _conv_op(self, x, params):
1519
+ w = params['weight']
1520
+ if self.w_mask is not None:
1521
+ w = w * self.w_mask
1522
+ # For transposed convolution:
1523
+ # - window_strides should be (1,1,...) - no striding in the conv operation
1524
+ # - lhs_dilation should be the stride - this creates the upsampling effect
1525
+ window_strides = (1,) * self.num_spatial_dims
1526
+ y = jax.lax.conv_general_dilated(
1527
+ lhs=x,
1528
+ rhs=w,
1529
+ window_strides=window_strides,
1530
+ padding=self.padding,
1531
+ lhs_dilation=self.stride, # For transpose conv, use stride as lhs_dilation
1532
+ rhs_dilation=self.rhs_dilation,
1533
+ feature_group_count=self.groups,
1534
+ dimension_numbers=self.dimension_numbers
1535
+ )
1536
+ if 'bias' in params:
1537
+ y = y + params['bias']
1538
+ return y
1539
+
1540
+
1541
+ class ConvTranspose1d(_ConvTranspose):
1542
+ """
1543
+ One-dimensional transposed convolution layer (also known as deconvolution).
1544
+
1545
+ Applies a 1D transposed convolution over an input signal. Transposed convolution
1546
+ is used for upsampling, reversing the spatial transformation of a regular convolution.
1547
+ It's commonly used in autoencoders, GANs, and semantic segmentation networks.
1548
+
1549
+ The input should be a 3D array with the shape of ``[B, L, C]`` where B is batch size,
1550
+ L is the sequence length, and C is the number of input channels (channels-last format).
1551
+
1552
+ Parameters
1553
+ ----------
1554
+ in_size : tuple of int
1555
+ The input shape without the batch dimension. For ConvTranspose1d: (L, C) where L
1556
+ is the sequence length and C is the number of input channels.
1557
+ out_channels : int
1558
+ The number of output channels (feature maps) produced by the transposed convolution.
1559
+ kernel_size : int or tuple of int
1560
+ The shape of the convolutional kernel. For 1D, can be an integer or a single-element tuple.
1561
+ stride : int or tuple of int, optional
1562
+ The stride of the transposed convolution. Larger strides produce larger output sizes,
1563
+ which is the opposite behavior of regular convolution. Default: 1.
1564
+ padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
1565
+ The padding strategy. Options:
1566
+
1567
+ - 'SAME': output length approximately equals input_length * stride
1568
+ - 'VALID': no padding, maximum output size
1569
+ - int: symmetric padding
1570
+ - (pad_before, pad_after): explicit padding for the sequence dimension
1571
+
1572
+ Default: 'SAME'.
1573
+ lhs_dilation : int or tuple of int, optional
1574
+ The dilation factor for the input. For transposed convolution, this is typically
1575
+ set equal to stride internally. Default: 1.
1576
+ rhs_dilation : int or tuple of int, optional
1577
+ The dilation factor for the kernel. Increases the receptive field without increasing
1578
+ parameters by inserting zeros between kernel elements. Default: 1.
1579
+ groups : int, optional
1580
+ Number of groups for grouped transposed convolution. Both `in_channels` and
1581
+ `out_channels` must be divisible by `groups`. Default: 1.
1582
+ w_init : Callable or ArrayLike, optional
1583
+ The initializer for the convolutional kernel weights. Default: XavierNormal().
1584
+ b_init : Callable or ArrayLike or None, optional
1585
+ The initializer for the bias. If None, no bias is added. Default: None.
1586
+ w_mask : ArrayLike or Callable or None, optional
1587
+ An optional mask applied to the weights during forward pass. Default: None.
1588
+ channel_first : bool, optional
1589
+ If True, uses channels-first format (e.g., [B, C, L]). If False, uses channels-last
1590
+ format (e.g., [B, L, C]). Default: False (channels-last, JAX convention).
1591
+ name : str, optional
1592
+ The name of the module. Default: None.
1593
+ param_type : type, optional
1594
+ The type of parameter state to use. Default: ParamState.
1595
+
1596
+ Attributes
1597
+ ----------
1598
+ in_size : tuple of int
1599
+ The input shape (L, C) without batch dimension.
1600
+ out_size : tuple of int
1601
+ The output shape (L_out, out_channels) without batch dimension.
1602
+ in_channels : int
1603
+ Number of input channels.
1604
+ out_channels : int
1605
+ Number of output channels.
1606
+ kernel_size : tuple of int
1607
+ Size of the convolving kernel.
1608
+ weight : ParamState
1609
+ The learnable weights (and bias if specified) of the module.
1610
+
1611
+ Examples
1612
+ --------
1613
+ .. code-block:: python
1614
+
1615
+ >>> import brainstate
1616
+ >>> import jax.numpy as jnp
1617
+ >>>
1618
+ >>> # Create a 1D transposed convolution layer for upsampling
1619
+ >>> conv_transpose = brainstate.nn.ConvTranspose1d(
1620
+ ... in_size=(28, 16),
1621
+ ... out_channels=8,
1622
+ ... kernel_size=4,
1623
+ ... stride=2
1624
+ ... )
1625
+ >>>
1626
+ >>> # Apply to input: batch_size=2, length=28, channels=16
1627
+ >>> x = jnp.ones((2, 28, 16))
1628
+ >>> y = conv_transpose(x)
1629
+ >>> print(y.shape)  # (2, 56, 8): length upsampled by the stride
1630
+ >>>
1631
+ >>> # Without batch dimension
1632
+ >>> x_single = jnp.ones((28, 16))
1633
+ >>> y_single = conv_transpose(x_single)
1634
+ >>>
1635
+ >>> # Channels-first format (PyTorch style)
1636
+ >>> conv_transpose = brainstate.nn.ConvTranspose1d(
1637
+ ... in_size=(16, 28),
1638
+ ... out_channels=8,
1639
+ ... kernel_size=4,
1640
+ ... stride=2,
1641
+ ... channel_first=True
1642
+ ... )
1643
+ >>> x = jnp.ones((2, 16, 28))
1644
+ >>> y = conv_transpose(x)
1645
+
1646
+ Notes
1647
+ -----
1648
+ **Output dimensions:**
1649
+
1650
+ Unlike regular convolution, transposed convolution increases spatial dimensions.
1651
+ With stride > 1, the output is larger than the input:
1652
+
1653
+ - output_length ≈ input_length * stride (depends on padding and kernel size); see the sketch below
1654
+
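+ For instance, the layer's precomputed ``out_size`` makes this concrete
+ (a sketch; the exact size depends on the padding mode):
+
+ .. code-block:: python
+
+ >>> conv_t = brainstate.nn.ConvTranspose1d(in_size=(28, 16), out_channels=8,
+ ...                                        kernel_size=4, stride=2, padding='SAME')
+ >>> conv_t.out_size  # 'SAME' targets length 28 * 2
+ (56, 8)
+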
1655
+ **Relationship to regular convolution:**
1656
+
1657
+ Transposed convolution performs the gradient computation of a regular convolution
1658
+ with respect to its input. It's sometimes called "deconvolution" but this term
1659
+ is mathematically imprecise.
1660
+
1661
+ **Common use cases:**
1662
+
1663
+ - Upsampling in encoder-decoder architectures
1664
+ - Generative models (GANs, VAEs)
1665
+ - Semantic segmentation (U-Net, FCN)
1666
+ - Super-resolution networks
1667
+
1668
+ **Comparison with PyTorch:**
1669
+
1670
+ - PyTorch uses channels-first by default; BrainState uses channels-last
1671
+ - Set `channel_first=True` for PyTorch-compatible format
1672
+ - PyTorch's `output_padding` is handled through padding parameter
1673
+ """
1674
+ __module__ = 'brainstate.nn'
1675
+ num_spatial_dims: int = 1
1676
+
1677
+
1678
+ class ConvTranspose2d(_ConvTranspose):
1679
+ """
1680
+ Two-dimensional transposed convolution layer (also known as deconvolution).
1681
+
1682
+ Applies a 2D transposed convolution over an input signal. Transposed convolution
1683
+ is the gradient of a regular convolution with respect to its input, commonly used
1684
+ for upsampling feature maps in encoder-decoder architectures, GANs, and segmentation.
1685
+
1686
+ The input should be a 4D array with the shape of ``[B, H, W, C]`` where B is batch size,
1687
+ H is height, W is width, and C is the number of input channels (channels-last format).
1688
+
1689
+ Parameters
1690
+ ----------
1691
+ in_size : tuple of int
1692
+ The input shape without the batch dimension. For ConvTranspose2d: (H, W, C) where
1693
+ H is height, W is width, and C is the number of input channels.
1694
+ out_channels : int
1695
+ The number of output channels (feature maps) produced by the transposed convolution.
1696
+ kernel_size : int or tuple of int
1697
+ The shape of the convolutional kernel. Can be:
1698
+
1699
+ - An integer (e.g., 4): creates a square kernel (4, 4)
1700
+ - A tuple of two integers (e.g., (4, 4)): creates a (height, width) kernel
1701
+ stride : int or tuple of int, optional
1702
+ The stride of the transposed convolution. Controls the upsampling factor.
1703
+ Can be:
1704
+
1705
+ - An integer: same stride for both dimensions
1706
+ - A tuple of two integers: (stride_height, stride_width)
1707
+
1708
+ Larger strides produce larger outputs. Default: 1.
1709
+ padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
1710
+ The padding strategy. Options:
1711
+
1712
+ - 'SAME': output size approximately equals input_size * stride
1713
+ - 'VALID': no padding, maximum output size
1714
+ - int: same symmetric padding for all dimensions
1715
+ - (pad_h, pad_w): different padding for each dimension
1716
+ - [(pad_h_before, pad_h_after), (pad_w_before, pad_w_after)]: explicit padding
1717
+
1718
+ Default: 'SAME'.
1719
+ lhs_dilation : int or tuple of int, optional
1720
+ The dilation factor for the input. For transposed convolution, this is typically
1721
+ set equal to stride internally. Default: 1.
1722
+ rhs_dilation : int or tuple of int, optional
1723
+ The dilation factor for the kernel. Increases the receptive field without increasing
1724
+ parameters by inserting zeros between kernel elements. Default: 1.
1725
+ groups : int, optional
1726
+ Number of groups for grouped transposed convolution. Must divide both `in_channels`
1727
+ and `out_channels`. Default: 1.
1728
+ w_init : Callable or ArrayLike, optional
1729
+ Weight initializer for the convolutional kernel. Default: XavierNormal().
1730
+ b_init : Callable or ArrayLike or None, optional
1731
+ Bias initializer. If None, no bias term is added. Default: None.
1732
+ w_mask : ArrayLike or Callable or None, optional
1733
+ Optional weight mask for structured sparsity. Default: None.
1734
+ channel_first : bool, optional
1735
+ If True, uses channels-first format (e.g., [B, C, H, W]). If False, uses channels-last
1736
+ format (e.g., [B, H, W, C]). Default: False (channels-last, JAX convention).
1737
+ name : str, optional
1738
+ Name identifier for this module instance. Default: None.
1739
+ param_type : type, optional
1740
+ The parameter state class to use. Default: ParamState.
1741
+
1742
+ Attributes
1743
+ ----------
1744
+ in_size : tuple of int
1745
+ The input shape (H, W, C) without batch dimension.
1746
+ out_size : tuple of int
1747
+ The output shape (H_out, W_out, out_channels) without batch dimension.
1748
+ in_channels : int
1749
+ Number of input channels.
1750
+ out_channels : int
1751
+ Number of output channels.
1752
+ kernel_size : tuple of int
1753
+ Size of the convolving kernel (height, width).
1754
+ weight : ParamState
1755
+ The learnable weights (and bias if specified) of the module.
1756
+
1757
+ Examples
1758
+ --------
1759
+ .. code-block:: python
1760
+
1761
+ >>> import brainstate
1762
+ >>> import jax.numpy as jnp
1763
+ >>>
1764
+ >>> # Create a 2D transposed convolution for upsampling
1765
+ >>> conv_transpose = brainstate.nn.ConvTranspose2d(
1766
+ ... in_size=(32, 32, 64),
1767
+ ... out_channels=32,
1768
+ ... kernel_size=4,
1769
+ ... stride=2
1770
+ ... )
1771
+ >>>
1772
+ >>> # Apply to input: batch_size=8, height=32, width=32, channels=64
1773
+ >>> x = jnp.ones((8, 32, 32, 64))
1774
+ >>> y = conv_transpose(x)
1775
+ >>> print(y.shape) # Output will be approximately (8, 64, 64, 32)
1776
+ >>>
1777
+ >>> # Without batch dimension
1778
+ >>> x_single = jnp.ones((32, 32, 64))
1779
+ >>> y_single = conv_transpose(x_single)
1780
+ >>>
1781
+ >>> # Decoder in autoencoder (upsampling path)
1782
+ >>> decoder = brainstate.nn.ConvTranspose2d(
1783
+ ... in_size=(16, 16, 128),
1784
+ ... out_channels=64,
1785
+ ... kernel_size=4,
1786
+ ... stride=2,
1787
+ ... padding='SAME',
1788
+ ... b_init=brainstate.init.Constant(0.0)
1789
+ ... )
1790
+ >>>
1791
+ >>> # Channels-first format (PyTorch style)
1792
+ >>> conv_transpose = brainstate.nn.ConvTranspose2d(
1793
+ ... in_size=(64, 32, 32),
1794
+ ... out_channels=32,
1795
+ ... kernel_size=4,
1796
+ ... stride=2,
1797
+ ... channel_first=True
1798
+ ... )
1799
+ >>> x = jnp.ones((8, 64, 32, 32))
1800
+ >>> y = conv_transpose(x)
1801
+
1802
+ Notes
1803
+ -----
1804
+ **Output dimensions:**
1805
+
1806
+ Transposed convolution increases spatial dimensions, with the upsampling factor
1807
+ primarily controlled by stride:
1808
+
1809
+ - output_size ≈ input_size * stride (exact size depends on padding and kernel size)
1810
+ - 'SAME' padding: output_size = input_size * stride
1811
+ - 'VALID' padding: output_size = input_size * stride + max(kernel_size - stride, 0)
1812
+
1813
+ **Relationship to regular convolution:**
1814
+
1815
+ Transposed convolution is the backward pass of a regular convolution. If a regular
1816
+ convolution reduces spatial dimensions from X to Y, a transposed convolution with
1817
+ the same parameters increases dimensions from Y back to approximately X.
1818
+
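+ For example, pairing a strided convolution with a matching transposed
+ convolution round-trips the spatial shape (a sketch assuming
+ ``brainstate.nn.Conv2d`` shares this constructor signature):
+
+ .. code-block:: python
+
+ >>> down = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=8,
+ ...                             kernel_size=4, stride=2, padding='SAME')
+ >>> up = brainstate.nn.ConvTranspose2d(in_size=down.out_size, out_channels=3,
+ ...                                    kernel_size=4, stride=2, padding='SAME')
+ >>> down.out_size, up.out_size
+ ((16, 16, 8), (32, 32, 3))
+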
1819
+ **Common use cases:**
1820
+
1821
+ - Image segmentation (U-Net, SegNet, FCN)
1822
+ - Image-to-image translation (pix2pix, CycleGAN)
1823
+ - Generative models (DCGAN, VAE decoders)
1824
+ - Super-resolution networks
1825
+ - Autoencoders (decoder path)
1826
+
1827
+ **Comparison with PyTorch:**
1828
+
1829
+ - PyTorch uses channels-first by default; BrainState uses channels-last
1830
+ - Set `channel_first=True` for PyTorch-compatible format
1831
+ - Kernel shape convention: PyTorch stores (C_in, C_out, H, W), BrainState uses (H, W, C_out, C_in)
1832
+ - PyTorch's `output_padding` parameter controls output size; use padding parameter here
1833
+
1834
+ **Tips:**
1835
+
1836
+ - Use kernel_size=stride*2 for smooth upsampling (e.g., kernel_size=4, stride=2)
1837
+ - Initialize with bilinear upsampling weights for better convergence in segmentation (see the sketch below)
1838
+ - Combine with batch normalization or group normalization for stable training
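+
+ A bilinear initialization, as suggested above, can be sketched with the
+ classic FCN recipe (the helper name is hypothetical, not part of the
+ brainstate API):
+
+ .. code-block:: python
+
+ >>> import numpy as np
+ >>> def bilinear_kernel(size):
+ ...     # 1D triangular filter; its outer product gives 2D bilinear weights
+ ...     factor = (size + 1) // 2
+ ...     center = factor - 1 if size % 2 == 1 else factor - 0.5
+ ...     og = np.arange(size)
+ ...     filt = 1 - np.abs(og - center) / factor
+ ...     return np.outer(filt, filt)
+ >>> bilinear_kernel(4)[0]
+ array([0.0625, 0.1875, 0.1875, 0.0625])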
1839
+ """
1840
+ __module__ = 'brainstate.nn'
1841
+ num_spatial_dims: int = 2
1842
+
1843
+
1844
+ class ConvTranspose3d(_ConvTranspose):
1845
+ """
1846
+ Three-dimensional transposed convolution layer (also known as deconvolution).
1847
+
1848
+ Applies a 3D transposed convolution over an input signal. Used for upsampling
1849
+ 3D feature maps in video generation, 3D segmentation, and volumetric reconstruction.
1850
+
1851
+ The input should be a 5D array with the shape of ``[B, H, W, D, C]`` where B is batch size,
1852
+ H is height, W is width, D is depth, and C is the number of input channels (channels-last format).
1853
+
1854
+ Parameters
1855
+ ----------
1856
+ in_size : tuple of int
1857
+ The input shape without the batch dimension. For ConvTranspose3d: (H, W, D, C) where
1858
+ H is height, W is width, D is depth, and C is the number of input channels.
1859
+ out_channels : int
1860
+ The number of output channels (feature maps) produced by the transposed convolution.
1861
+ kernel_size : int or tuple of int
1862
+ The shape of the convolutional kernel. Can be:
1863
+
1864
+ - An integer (e.g., 4): creates a cubic kernel (4, 4, 4)
1865
+ - A tuple of three integers (e.g., (4, 4, 4)): creates a (height, width, depth) kernel
1866
+ stride : int or tuple of int, optional
1867
+ The stride of the transposed convolution. Controls the upsampling factor.
1868
+ Can be:
1869
+
1870
+ - An integer: same stride for all dimensions
1871
+ - A tuple of three integers: (stride_h, stride_w, stride_d)
1872
+
1873
+ Larger strides produce larger outputs. Default: 1.
1874
+ padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
1875
+ The padding strategy. Options:
1876
+
1877
+ - 'SAME': output size approximately equals input_size * stride
1878
+ - 'VALID': no padding, maximum output size
1879
+ - int: same symmetric padding for all dimensions
1880
+ - (pad_h, pad_w, pad_d): different padding for each dimension
1881
+ - [(pad_h_before, pad_h_after), (pad_w_before, pad_w_after), (pad_d_before, pad_d_after)]: explicit
1882
+
1883
+ Default: 'SAME'.
1884
+ lhs_dilation : int or tuple of int, optional
1885
+ The dilation factor for the input. For transposed convolution, this is typically
1886
+ set equal to stride internally. Default: 1.
1887
+ rhs_dilation : int or tuple of int, optional
1888
+ The dilation factor for the kernel. Increases the receptive field without increasing
1889
+ parameters. Default: 1.
1890
+ groups : int, optional
1891
+ Number of groups for grouped transposed convolution. Must divide both `in_channels`
1892
+ and `out_channels`. Useful for reducing computational cost in 3D. Default: 1.
1893
+ w_init : Callable or ArrayLike, optional
1894
+ Weight initializer for the convolutional kernel. Default: XavierNormal().
1895
+ b_init : Callable or ArrayLike or None, optional
1896
+ Bias initializer. If None, no bias term is added. Default: None.
1897
+ w_mask : ArrayLike or Callable or None, optional
1898
+ Optional weight mask for structured sparsity. Default: None.
1899
+ channel_first : bool, optional
1900
+ If True, uses channels-first format (e.g., [B, C, H, W, D]). If False, uses channels-last
1901
+ format (e.g., [B, H, W, D, C]). Default: False (channels-last, JAX convention).
1902
+ name : str, optional
1903
+ Name identifier for this module instance. Default: None.
1904
+ param_type : type, optional
1905
+ The parameter state class to use. Default: ParamState.
1906
+
1907
+ Attributes
1908
+ ----------
1909
+ in_size : tuple of int
1910
+ The input shape (H, W, D, C) without batch dimension.
1911
+ out_size : tuple of int
1912
+ The output shape (H_out, W_out, D_out, out_channels) without batch dimension.
1913
+ in_channels : int
1914
+ Number of input channels.
1915
+ out_channels : int
1916
+ Number of output channels.
1917
+ kernel_size : tuple of int
1918
+ Size of the convolving kernel (height, width, depth).
1919
+ weight : ParamState
1920
+ The learnable weights (and bias if specified) of the module.
1921
+
1922
+ Examples
1923
+ --------
1924
+ .. code-block:: python
1925
+
1926
+ >>> import brainstate
1927
+ >>> import jax.numpy as jnp
1928
+ >>>
1929
+ >>> # Create a 3D transposed convolution for video upsampling
1930
+ >>> conv_transpose = brainstate.nn.ConvTranspose3d(
1931
+ ... in_size=(8, 16, 16, 64),
1932
+ ... out_channels=32,
1933
+ ... kernel_size=4,
1934
+ ... stride=2
1935
+ ... )
1936
+ >>>
1937
+ >>> # Apply to input: batch_size=4, frames=8, height=16, width=16, channels=64
1938
+ >>> x = jnp.ones((4, 8, 16, 16, 64))
1939
+ >>> y = conv_transpose(x)
1940
+ >>> print(y.shape) # Output will be approximately (4, 16, 32, 32, 32)
1941
+ >>>
1942
+ >>> # Without batch dimension
1943
+ >>> x_single = jnp.ones((8, 16, 16, 64))
1944
+ >>> y_single = conv_transpose(x_single)
1945
+ >>>
1946
+ >>> # For medical imaging reconstruction
1947
+ >>> decoder = brainstate.nn.ConvTranspose3d(
1948
+ ... in_size=(16, 16, 16, 128),
1949
+ ... out_channels=64,
1950
+ ... kernel_size=(4, 4, 4),
1951
+ ... stride=2,
1952
+ ... padding='SAME',
1953
+ ... b_init=brainstate.init.Constant(0.0)
1954
+ ... )
1955
+ >>>
1956
+ >>> # Channels-first format (PyTorch style)
1957
+ >>> conv_transpose = brainstate.nn.ConvTranspose3d(
1958
+ ... in_size=(64, 8, 16, 16),
1959
+ ... out_channels=32,
1960
+ ... kernel_size=4,
1961
+ ... stride=2,
1962
+ ... channel_first=True
1963
+ ... )
1964
+ >>> x = jnp.ones((4, 64, 8, 16, 16))
1965
+ >>> y = conv_transpose(x)
1966
+
1967
+ Notes
1968
+ -----
1969
+ **Output dimensions:**
1970
+
1971
+ Transposed convolution increases spatial dimensions:
1972
+
1973
+ - output_size ≈ input_size * stride (exact size depends on padding and kernel size)
1974
+ - 'SAME' padding: output_size = input_size * stride
1975
+ - 'VALID' padding: output_size = input_size * stride + max(kernel_size - stride, 0)
1976
+
1977
+ **Computational considerations:**
1978
+
1979
+ 3D transposed convolutions are very computationally expensive. Consider:
1980
+
1981
+ - Using grouped convolutions (groups > 1) to reduce parameters
1982
+ - Smaller kernel sizes
1983
+ - Progressive upsampling (multiple layers with stride=2)
1984
+ - Separable convolutions for large-scale applications
1985
+
1986
+ **Common use cases:**
1987
+
1988
+ - Video generation and prediction
1989
+ - 3D medical image segmentation (U-Net 3D)
1990
+ - Volumetric reconstruction
1991
+ - 3D super-resolution
1992
+ - Video frame interpolation
1993
+ - 3D VAE decoders
1994
+
1995
+ **Comparison with PyTorch:**
1996
+
1997
+ - PyTorch uses channels-first by default; BrainState uses channels-last
1998
+ - Set `channel_first=True` for PyTorch-compatible format
1999
+ - Kernel shape convention differs between frameworks
2000
+ - PyTorch's `output_padding` parameter is handled through padding here
2001
+
2002
+ **Tips:**
2003
+
2004
+ - Use kernel_size=stride*2 for smooth upsampling (e.g., kernel_size=4, stride=2)
2005
+ - Group normalization often works better than batch normalization for 3D
2006
+ - Consider using smaller batch sizes due to memory constraints
2007
+ - Progressive upsampling (2x at a time) is more stable than large strides (see the sketch below)
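+
+ Progressive upsampling, per the last tip, can be sketched by chaining two
+ stride-2 layers (a minimal sketch; shapes assume 'SAME' padding):
+
+ .. code-block:: python
+
+ >>> up1 = brainstate.nn.ConvTranspose3d(in_size=(8, 8, 8, 64), out_channels=32,
+ ...                                     kernel_size=4, stride=2)
+ >>> up2 = brainstate.nn.ConvTranspose3d(in_size=up1.out_size, out_channels=16,
+ ...                                     kernel_size=4, stride=2)
+ >>> up2.out_size  # each spatial dimension doubled twice
+ (32, 32, 32, 16)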
2008
+ """
2009
+ __module__ = 'brainstate.nn'
2010
+ num_spatial_dims: int = 3