brainstate-0.2.0-py2.py3-none-any.whl → brainstate-0.2.1-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. brainstate/__init__.py +169 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2319 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +1652 -1652
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1624 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1433 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +137 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +633 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +154 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +477 -477
  32. brainstate/nn/_dynamics.py +1267 -1267
  33. brainstate/nn/_dynamics_test.py +67 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +384 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/_rand_funs.py +3938 -3938
  64. brainstate/random/_rand_funs_test.py +640 -640
  65. brainstate/random/_rand_seed.py +675 -675
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1617
  68. brainstate/random/_rand_state_test.py +551 -551
  69. brainstate/transform/__init__.py +59 -59
  70. brainstate/transform/_ad_checkpoint.py +176 -176
  71. brainstate/transform/_ad_checkpoint_test.py +49 -49
  72. brainstate/transform/_autograd.py +1025 -1025
  73. brainstate/transform/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -316
  75. brainstate/transform/_conditions_test.py +220 -220
  76. brainstate/transform/_error_if.py +94 -94
  77. brainstate/transform/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -145
  79. brainstate/transform/_eval_shape_test.py +38 -38
  80. brainstate/transform/_jit.py +399 -399
  81. brainstate/transform/_jit_test.py +143 -143
  82. brainstate/transform/_loop_collect_return.py +675 -675
  83. brainstate/transform/_loop_collect_return_test.py +58 -58
  84. brainstate/transform/_loop_no_collection.py +283 -283
  85. brainstate/transform/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -2016
  87. brainstate/transform/_make_jaxpr_test.py +1510 -1510
  88. brainstate/transform/_mapping.py +529 -529
  89. brainstate/transform/_mapping_test.py +194 -194
  90. brainstate/transform/_progress_bar.py +255 -255
  91. brainstate/transform/_random.py +171 -171
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -108
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate-0.2.0.dist-info/RECORD +0 -111
  111. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  112. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
brainstate/nn/_conv.py CHANGED
@@ -1,2010 +1,2010 @@
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- # -*- coding: utf-8 -*-
-
- import collections.abc
- from typing import Callable, Tuple, Union, Sequence, Optional, TypeVar
-
- import brainunit as u
- import jax
- import jax.numpy as jnp
-
- from brainstate._state import ParamState
- from brainstate.typing import ArrayLike
- from . import init as init
- from ._module import Module
- from ._normalizations import weight_standardization
-
- T = TypeVar('T')
-
- __all__ = [
-     'Conv1d', 'Conv2d', 'Conv3d',
-     'ScaledWSConv1d', 'ScaledWSConv2d', 'ScaledWSConv3d',
-     'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d',
- ]
-
-
- def to_dimension_numbers(
-     num_spatial_dims: int,
-     channels_last: bool,
-     transpose: bool
- ) -> jax.lax.ConvDimensionNumbers:
-     """
-     Create a `lax.ConvDimensionNumbers` for the given inputs.
-
-     This function generates the dimension specification needed for JAX's convolution
-     operations based on the number of spatial dimensions and data format.
-
-     Parameters
-     ----------
-     num_spatial_dims : int
-         The number of spatial dimensions (e.g., 1 for Conv1d, 2 for Conv2d, 3 for Conv3d).
-     channels_last : bool
-
-         - If True, the input format is channels-last (e.g., [B, H, W, C] for 2D).
-         - If False, the input format is channels-first (e.g., [B, C, H, W] for 2D).
-     transpose : bool
-
-         - If True, creates dimension numbers for transposed convolution.
-         - If False, creates dimension numbers for standard convolution.
-
-     Returns
-     -------
-     jax.lax.ConvDimensionNumbers
-         A named tuple specifying the dimension layout for lhs (input), rhs (kernel),
-         and output of the convolution operation.
-
-     Examples
-     --------
-     .. code-block:: python
-
-         >>> # For 2D convolution with channels-last format
-         >>> dim_nums = to_dimension_numbers(num_spatial_dims=2, channels_last=True, transpose=False)
-         >>> print(dim_nums.lhs_spec)  # Input layout: (batch, spatial_1, spatial_2, channel)
-         (0, 3, 1, 2)
-     """
-     num_dims = num_spatial_dims + 2
-     if channels_last:
-         spatial_dims = tuple(range(1, num_dims - 1))
-         image_dn = (0, num_dims - 1) + spatial_dims
-     else:
-         spatial_dims = tuple(range(2, num_dims))
-         image_dn = (0, 1) + spatial_dims
-     if transpose:
-         kernel_dn = (num_dims - 2, num_dims - 1) + tuple(range(num_dims - 2))
-     else:
-         kernel_dn = (num_dims - 1, num_dims - 2) + tuple(range(num_dims - 2))
-     return jax.lax.ConvDimensionNumbers(lhs_spec=image_dn,
-                                         rhs_spec=kernel_dn,
-                                         out_spec=image_dn)
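For the channels-first layout the helper produces the complementary specs; a minimal sketch, assuming `to_dimension_numbers` is in scope as defined above:

    # 2D convolution, channels-first (NCHW), standard (non-transposed) kernel
    dn = to_dimension_numbers(num_spatial_dims=2, channels_last=False, transpose=False)
    print(dn.lhs_spec)  # (0, 1, 2, 3): batch, feature, then the two spatial axes
    print(dn.rhs_spec)  # (3, 2, 0, 1): out-feature at axis 3, in-feature at axis 2,
                        # spatial at axes (0, 1) -- i.e. an HWIO kernel
    print(dn.out_spec)  # (0, 1, 2, 3): same layout as the input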
-
-
- def replicate(
-     element: Union[T, Sequence[T]],
-     num_replicate: int,
-     name: str,
- ) -> Tuple[T, ...]:
-     """
-     Replicate `element` `num_replicate` times if needed.
-
-     This utility function ensures that parameters like kernel_size, stride, etc.
-     are properly formatted as tuples with the correct length for multi-dimensional
-     convolutions.
-
-     Parameters
-     ----------
-     element : T or Sequence[T]
-         The element to replicate. Can be a scalar, string, or sequence.
-     num_replicate : int
-         The number of times to replicate the element.
-     name : str
-         The name of the parameter (used for error messages).
-
-     Returns
-     -------
-     tuple of T
-         A tuple containing the replicated elements.
-
-     Raises
-     ------
-     TypeError
-         If the element is a sequence with length not equal to 1 or `num_replicate`.
-
-     Examples
-     --------
-     .. code-block:: python
-
-         >>> # Replicate a scalar value
-         >>> replicate(3, 2, 'kernel_size')
-         (3, 3)
-         >>>
-         >>> # Keep a sequence as-is if it already has the correct length
-         >>> replicate((3, 5), 2, 'kernel_size')
-         (3, 5)
-         >>>
-         >>> # Replicate a single-element sequence
-         >>> replicate([3], 2, 'kernel_size')
-         (3, 3)
-     """
-     if isinstance(element, (str, bytes)) or not isinstance(element, collections.abc.Sequence):
-         return (element,) * num_replicate
-     elif len(element) == 1:
-         return tuple(list(element) * num_replicate)
-     elif len(element) == num_replicate:
-         return tuple(element)
-     else:
-         raise TypeError(f"{name} must be a scalar, a sequence of length 1, "
-                         f"or a sequence of length {num_replicate}.")
-
-
- class _BaseConv(Module):
-     # the number of spatial dimensions
-     num_spatial_dims: int
-
-     # the weight and its operations
-     weight: ParamState
-
-     def __init__(
-         self,
-         in_size: Sequence[int],
-         out_channels: int,
-         kernel_size: Union[int, Tuple[int, ...]],
-         stride: Union[int, Tuple[int, ...]] = 1,
-         padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
-         lhs_dilation: Union[int, Tuple[int, ...]] = 1,
-         rhs_dilation: Union[int, Tuple[int, ...]] = 1,
-         groups: int = 1,
-         w_mask: Optional[Union[ArrayLike, Callable]] = None,
-         channel_first: bool = False,
-         name: Optional[str] = None,
-     ):
-         super().__init__(name=name)
-
-         # general parameters
-         assert self.num_spatial_dims + 1 == len(in_size)
-         self.in_size = tuple(in_size)
-         self.channel_first = channel_first
-         self.channels_last = not channel_first
-
-         # Determine in_channels based on channel_first
-         if self.channel_first:
-             self.in_channels = in_size[0]
-         else:
-             self.in_channels = in_size[-1]
-
-         self.out_channels = out_channels
-         self.stride = replicate(stride, self.num_spatial_dims, 'stride')
-         self.kernel_size = replicate(kernel_size, self.num_spatial_dims, 'kernel_size')
-         self.lhs_dilation = replicate(lhs_dilation, self.num_spatial_dims, 'lhs_dilation')
-         self.rhs_dilation = replicate(rhs_dilation, self.num_spatial_dims, 'rhs_dilation')
-         self.groups = groups
-         self.dimension_numbers = to_dimension_numbers(
-             self.num_spatial_dims,
-             channels_last=self.channels_last,
-             transpose=False
-         )
-
-         # the padding parameter
-         if isinstance(padding, str):
-             assert padding in ['SAME', 'VALID']
-         elif isinstance(padding, int):
-             padding = tuple((padding, padding) for _ in range(self.num_spatial_dims))
-         elif isinstance(padding, (tuple, list)):
-             if isinstance(padding[0], int):
-                 padding = (padding,) * self.num_spatial_dims
-             elif isinstance(padding[0], (tuple, list)):
-                 if len(padding) == 1:
-                     padding = tuple(padding) * self.num_spatial_dims
-                 else:
-                     if len(padding) != self.num_spatial_dims:
-                         raise ValueError(
-                             f"Padding {padding} must be a Tuple[int, int], "
-                             f"or sequence of Tuple[int, int] with length 1, "
-                             f"or sequence of Tuple[int, int] with length {self.num_spatial_dims}."
-                         )
-                     padding = tuple(padding)
-         else:
-             raise ValueError(f"Unsupported padding: {padding!r}")
-         self.padding = padding
-
-         # the number of in-/out-channels
-         assert self.out_channels % self.groups == 0, '"out_channels" should be divisible by groups'
-         assert self.in_channels % self.groups == 0, '"in_channels" should be divisible by groups'
-
-         # kernel shape and w_mask
-         kernel_shape = tuple(self.kernel_size) + (self.in_channels // self.groups, self.out_channels)
-         self.kernel_shape = kernel_shape
-         self.w_mask = init.param(w_mask, kernel_shape, allow_none=True)
-
-     def _check_input_dim(self, x):
-         if x.ndim == self.num_spatial_dims + 2:
-             x_shape = x.shape[1:]
-         elif x.ndim == self.num_spatial_dims + 1:
-             x_shape = x.shape
-         else:
-             raise ValueError(f"expected {self.num_spatial_dims + 2}D (with batch) or "
-                              f"{self.num_spatial_dims + 1}D (without batch) input (got {x.ndim}D input, {x.shape})")
-
-         # `in_size` is stored in the same layout as the runtime input
-         # (channels-first or channels-last), so it can be compared directly.
-         expected_shape = self.in_size
-         if expected_shape != x_shape:
-             raise ValueError(f"The expected input shape is {expected_shape}, while we got {x_shape}.")
-
-     def update(self, x):
-         self._check_input_dim(x)
-         non_batching = False
-         if x.ndim == self.num_spatial_dims + 1:
-             x = u.math.expand_dims(x, 0)
-             non_batching = True
-         y = self._conv_op(x, self.weight.value)
-         return u.math.squeeze(y, axis=0) if non_batching else y
-
-     def _conv_op(self, x, params):
-         raise NotImplementedError
-
-
- class _Conv(_BaseConv):
-     num_spatial_dims: int = None  # set by concrete subclasses
-
-     def __init__(
-         self,
-         in_size: Sequence[int],
-         out_channels: int,
-         kernel_size: Union[int, Tuple[int, ...]],
-         stride: Union[int, Tuple[int, ...]] = 1,
-         padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
-         lhs_dilation: Union[int, Tuple[int, ...]] = 1,
-         rhs_dilation: Union[int, Tuple[int, ...]] = 1,
-         groups: int = 1,
-         w_init: Union[Callable, ArrayLike] = init.XavierNormal(),
-         b_init: Optional[Union[Callable, ArrayLike]] = None,
-         w_mask: Optional[Union[ArrayLike, Callable]] = None,
-         channel_first: bool = False,
-         name: Optional[str] = None,
-         param_type: type = ParamState,
-     ):
-         super().__init__(
-             in_size=in_size,
-             out_channels=out_channels,
-             kernel_size=kernel_size,
-             stride=stride,
-             padding=padding,
-             lhs_dilation=lhs_dilation,
-             rhs_dilation=rhs_dilation,
-             groups=groups,
-             w_mask=w_mask,
-             channel_first=channel_first,
-             name=name
-         )
-
-         self.w_initializer = w_init
-         self.b_initializer = b_init
-
-         # --- weights --- #
-         weight = init.param(self.w_initializer, self.kernel_shape, allow_none=False)
-         params = dict(weight=weight)
-         if self.b_initializer is not None:
-             bias_shape = (1,) * len(self.kernel_size) + (self.out_channels,)
-             bias = init.param(self.b_initializer, bias_shape, allow_none=True)
-             params['bias'] = bias
-
-         # The weight operation
-         self.weight = param_type(params)
-
-         # Evaluate the output shape with a dummy batch size of 128;
-         # jax.eval_shape traces the computation abstractly, so no arrays
-         # are materialized here.
-         test_input_shape = (128,) + self.in_size
-         abstract_y = jax.eval_shape(
-             self._conv_op,
-             jax.ShapeDtypeStruct(test_input_shape, weight.dtype),
-             params
-         )
-         y_shape = abstract_y.shape[1:]
-         self.out_size = y_shape
-
-     def _conv_op(self, x, params):
-         w = params['weight']
-         if self.w_mask is not None:
-             w = w * self.w_mask
-         y = jax.lax.conv_general_dilated(
-             lhs=x,
-             rhs=w,
-             window_strides=self.stride,
-             padding=self.padding,
-             lhs_dilation=self.lhs_dilation,
-             rhs_dilation=self.rhs_dilation,
-             feature_group_count=self.groups,
-             dimension_numbers=self.dimension_numbers
-         )
-         if 'bias' in params:
-             y = y + params['bias']
-         return y
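The `out_size` computation above relies on `jax.eval_shape`, which traces `_conv_op` abstractly: no convolution is executed and no 128-row batch is ever allocated. A self-contained sketch of the same idea (the shapes here are illustrative, not taken from the package):

    import jax
    import jax.numpy as jnp

    def conv1d_same(x):
        # 1-D channels-last convolution; the kernel is laid out as (spatial, in, out)
        kernel = jnp.zeros((5, 3, 16))
        return jax.lax.conv_general_dilated(
            lhs=x,
            rhs=kernel,
            window_strides=(1,),
            padding='SAME',
            dimension_numbers=jax.lax.ConvDimensionNumbers((0, 2, 1), (2, 1, 0), (0, 2, 1)),
        )

    abstract_y = jax.eval_shape(conv1d_same, jax.ShapeDtypeStruct((128, 28, 3), jnp.float32))
    print(abstract_y.shape)  # (128, 28, 16) -- computed from shapes alone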
-
-
- class Conv1d(_Conv):
-     """
-     One-dimensional convolution layer.
-
-     Applies a 1D convolution over an input signal composed of several input planes.
-     The input should be a 3D array with the shape of ``[B, L, C]`` where B is batch size,
-     L is the sequence length, and C is the number of input channels.
-
-     This layer creates a convolution kernel that is convolved with the layer input
-     over a single spatial dimension to produce a tensor of outputs.
-
-     Parameters
-     ----------
-     in_size : tuple of int
-         The input shape without the batch dimension; for Conv1d this is (L, C).
-         This argument is important as it is used to evaluate the output shape.
-     out_channels : int
-         The number of output channels (also called filters or feature maps).
-     kernel_size : int or tuple of int
-         The shape of the convolutional kernel. For 1D convolution, the kernel size can be
-         passed as an integer or as a tuple with a single integer.
-     stride : int or tuple of int, optional
-         The stride of the convolution. An integer or a sequence of `n` integers, representing
-         the inter-window strides along each spatial dimension. Default: 1.
-     padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
-         The padding strategy. Can be:
-
-         - 'SAME': pads the input so the output has the same shape as the input when stride=1
-         - 'VALID': no padding
-         - int: symmetric padding applied to all spatial dimensions
-         - tuple of (low, high): padding for each dimension
-         - sequence of tuples: explicit padding for each spatial dimension
-
-         Default: 'SAME'.
-     lhs_dilation : int or tuple of int, optional
-         The dilation factor for the input. An integer or a sequence of `n` integers, giving
-         the dilation factor to apply in each spatial dimension of inputs. Convolution with
-         input dilation `d` is equivalent to transposed convolution with stride `d`.
-         Default: 1.
-     rhs_dilation : int or tuple of int, optional
-         The dilation factor for the kernel. An integer or a sequence of `n` integers, giving
-         the dilation factor to apply in each spatial dimension of the convolution kernel.
-         Convolution with kernel dilation is also known as 'atrous convolution', which increases
-         the receptive field without increasing the number of parameters. Default: 1.
-     groups : int, optional
-         Number of groups for grouped convolution. Controls the connections between inputs and
-         outputs. Both `in_channels` and `out_channels` must be divisible by `groups`. When
-         groups=1 (default), all inputs are convolved to all outputs. When groups>1, the input
-         and output channels are divided into groups, and each group is convolved independently.
-         When groups=in_channels, this becomes a depthwise convolution. Default: 1.
-     w_init : Callable or ArrayLike, optional
-         The initializer for the convolutional kernel weights. Can be an initializer instance
-         or a direct array. Default: XavierNormal().
-     b_init : Callable or ArrayLike or None, optional
-         The initializer for the bias. If None, no bias is added. Default: None.
-     w_mask : ArrayLike or Callable or None, optional
-         An optional mask applied to the weights during the forward pass. Useful for implementing
-         structured sparsity or custom connectivity patterns. Default: None.
-     name : str, optional
-         The name of the module. Default: None.
-     param_type : type, optional
-         The type of parameter state to use. Default: ParamState.
-
-     Attributes
-     ----------
-     in_size : tuple of int
-         The input shape (L, C) without batch dimension.
-     out_size : tuple of int
-         The output shape (L_out, out_channels) without batch dimension.
-     in_channels : int
-         Number of input channels.
-     out_channels : int
-         Number of output channels.
-     kernel_size : tuple of int
-         Size of the convolving kernel.
-     weight : ParamState
-         The learnable weights (and bias if specified) of the module.
-
-     Examples
-     --------
-     .. code-block:: python
-
-         >>> import brainstate
-         >>> import jax.numpy as jnp
-         >>>
-         >>> # Create a 1D convolution layer
-         >>> conv = brainstate.nn.Conv1d(in_size=(28, 3), out_channels=16, kernel_size=5)
-         >>>
-         >>> # Apply to input: batch_size=2, length=28, channels=3
-         >>> x = jnp.ones((2, 28, 3))
-         >>> y = conv(x)
-         >>> print(y.shape)  # (2, 28, 16) with 'SAME' padding
-         >>>
-         >>> # Without batch dimension
-         >>> x_single = jnp.ones((28, 3))
-         >>> y_single = conv(x_single)
-         >>> print(y_single.shape)  # (28, 16)
-         >>>
-         >>> # With custom parameters
-         >>> conv = brainstate.nn.Conv1d(
-         ...     in_size=(100, 8),
-         ...     out_channels=32,
-         ...     kernel_size=3,
-         ...     stride=2,
-         ...     padding='VALID',
-         ...     b_init=brainstate.init.ZeroInit()
-         ... )
-
-     Notes
-     -----
-     **Output dimensions:**
-
-     The output length depends on the padding mode:
-
-     - 'SAME': output length = ceil(input_length / stride)
-     - 'VALID': output length = ceil((input_length - kernel_size + 1) / stride)
-
-     **Grouped convolution:**
-
-     When groups > 1, the convolution becomes a grouped convolution where input and
-     output channels are divided into groups, reducing computational cost.
-     """
-     __module__ = 'brainstate.nn'
-     num_spatial_dims: int = 1
-
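A worked instance of the output-size formulas above, read back from the `out_size` attribute the constructor computes (a sketch; the numbers follow directly from the formulas):

    import brainstate

    same = brainstate.nn.Conv1d(in_size=(28, 3), out_channels=8, kernel_size=5,
                                stride=2, padding='SAME')
    print(same.out_size)   # (14, 8): ceil(28 / 2) = 14

    valid = brainstate.nn.Conv1d(in_size=(28, 3), out_channels=8, kernel_size=5,
                                 stride=2, padding='VALID')
    print(valid.out_size)  # (12, 8): ceil((28 - 5 + 1) / 2) = 12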
-
-
- class Conv2d(_Conv):
-     """
-     Two-dimensional convolution layer.
-
-     Applies a 2D convolution over an input signal composed of several input planes.
-     The input should be a 4D array with the shape of ``[B, H, W, C]`` where B is batch size,
-     H is height, W is width, and C is the number of input channels (channels-last format).
-
-     This layer creates a convolution kernel that is convolved with the layer input
-     to produce a tensor of outputs. It is commonly used in computer vision tasks.
-
-     Parameters
-     ----------
-     in_size : tuple of int
-         The input shape without the batch dimension. For Conv2d: (H, W, C) where H is height,
-         W is width, and C is the number of input channels. This argument is important as it is
-         used to evaluate the output shape.
-     out_channels : int
-         The number of output channels (also called filters or feature maps). These determine
-         the depth of the output feature map.
-     kernel_size : int or tuple of int
-         The shape of the convolutional kernel. Can be:
-
-         - An integer (e.g., 3): creates a square kernel (3, 3)
-         - A tuple of two integers (e.g., (3, 5)): creates a (height, width) kernel
-     stride : int or tuple of int, optional
-         The stride of the convolution. Controls how much the kernel moves at each step.
-         Can be:
-
-         - An integer: same stride for both dimensions
-         - A tuple of two integers: (stride_height, stride_width)
-
-         Default: 1.
-     padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
-         The padding strategy. Options:
-
-         - 'SAME': output spatial size equals input size when stride=1
-         - 'VALID': no padding, output size reduced by kernel size
-         - int: same symmetric padding for all dimensions
-         - (pad_h, pad_w): different padding for each dimension
-         - [(pad_h_before, pad_h_after), (pad_w_before, pad_w_after)]: explicit padding
-
-         Default: 'SAME'.
-     lhs_dilation : int or tuple of int, optional
-         The dilation factor for the input (left-hand side). Controls spacing between input elements.
-         A value > 1 inserts zeros between input elements, equivalent to transposed convolution.
-         Default: 1.
-     rhs_dilation : int or tuple of int, optional
-         The dilation factor for the kernel (right-hand side). Also known as atrous convolution
-         or dilated convolution. Increases the receptive field without increasing parameters by
-         inserting zeros between kernel elements. Useful for capturing multi-scale context.
-         Default: 1.
-     groups : int, optional
-         Number of groups for grouped convolution. Must divide both `in_channels` and `out_channels`.
-
-         - groups=1: standard convolution (all-to-all connections)
-         - groups>1: grouped convolution (reduces parameters by factor of groups)
-         - groups=in_channels: depthwise convolution (each input channel convolved separately)
-
-         Default: 1.
-     w_init : Callable or ArrayLike, optional
-         Weight initializer for the convolutional kernel. Can be:
-
-         - An initializer instance (e.g., brainstate.init.XavierNormal())
-         - A callable that returns an array given a shape
-         - A direct array matching the kernel shape
-
-         Default: XavierNormal().
-     b_init : Callable or ArrayLike or None, optional
-         Bias initializer. If None, no bias term is added to the output.
-         Default: None.
-     w_mask : ArrayLike or Callable or None, optional
-         Optional weight mask for structured sparsity or custom connectivity. The mask is
-         element-wise multiplied with the kernel weights during the forward pass.
-         Default: None.
-     name : str, optional
-         Name identifier for this module instance.
-         Default: None.
-     param_type : type, optional
-         The parameter state class to use for managing learnable parameters.
-         Default: ParamState.
-
-     Attributes
-     ----------
-     in_size : tuple of int
-         The input shape (H, W, C) without batch dimension.
-     out_size : tuple of int
-         The output shape (H_out, W_out, out_channels) without batch dimension.
-     in_channels : int
-         Number of input channels.
-     out_channels : int
-         Number of output channels.
-     kernel_size : tuple of int
-         Size of the convolving kernel (height, width).
-     weight : ParamState
-         The learnable weights (and bias if specified) of the module.
-
-     Examples
-     --------
-     .. code-block:: python
-
-         >>> import brainstate
-         >>> import jax.numpy as jnp
-         >>>
-         >>> # Create a 2D convolution layer
-         >>> conv = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=64, kernel_size=3)
-         >>>
-         >>> # Apply to input: batch_size=8, height=32, width=32, channels=3
-         >>> x = jnp.ones((8, 32, 32, 3))
-         >>> y = conv(x)
-         >>> print(y.shape)  # (8, 32, 32, 64) with 'SAME' padding
-         >>>
-         >>> # Without batch dimension
-         >>> x_single = jnp.ones((32, 32, 3))
-         >>> y_single = conv(x_single)
-         >>> print(y_single.shape)  # (32, 32, 64)
-         >>>
-         >>> # With custom kernel size and stride
-         >>> conv = brainstate.nn.Conv2d(
-         ...     in_size=(224, 224, 3),
-         ...     out_channels=128,
-         ...     kernel_size=(5, 5),
-         ...     stride=2,
-         ...     padding='VALID'
-         ... )
-         >>>
-         >>> # Depthwise convolution (groups = in_channels)
-         >>> conv = brainstate.nn.Conv2d(
-         ...     in_size=(64, 64, 32),
-         ...     out_channels=32,
-         ...     kernel_size=3,
-         ...     groups=32
-         ... )
-
-     Notes
-     -----
-     **Output dimensions:**
-
-     The output spatial dimensions depend on the padding mode:
-
-     - 'SAME': output_size = ceil(input_size / stride)
-     - 'VALID': output_size = ceil((input_size - kernel_size + 1) / stride)
-
-     **Grouped convolution:**
-
-     When groups > 1, the input and output channels are divided into groups.
-     Each group is convolved independently, which can significantly reduce
-     computational cost while maintaining representational power.
-     """
-     __module__ = 'brainstate.nn'
-     num_spatial_dims: int = 2
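The padding spellings accepted here are canonicalized by `_BaseConv` as shown earlier; a short sketch of the equivalences for a 2-D layer (the `padding` attribute holds the canonical form):

    import brainstate

    for pad in ['SAME', 2, (2, 2), [(2, 2)], [(2, 2), (2, 2)]]:
        conv = brainstate.nn.Conv2d(in_size=(8, 8, 3), out_channels=4,
                                    kernel_size=5, padding=pad)
        print(repr(pad), '->', conv.padding)

    # 'SAME'            -> 'SAME'            (string modes pass through unchanged)
    # 2                 -> ((2, 2), (2, 2))  (int replicated as symmetric padding)
    # (2, 2)            -> ((2, 2), (2, 2))  (one (low, high) pair, replicated)
    # [(2, 2)]          -> ((2, 2), (2, 2))  (length-1 sequence, replicated)
    # [(2, 2), (2, 2)]  -> ((2, 2), (2, 2))  (one pair per spatial dimension)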
-
-
- class Conv3d(_Conv):
-     """
-     Three-dimensional convolution layer.
-
-     Applies a 3D convolution over an input signal composed of several input planes.
-     The input should be a 5D array with the shape of ``[B, H, W, D, C]`` where B is batch size,
-     H is height, W is width, D is depth, and C is the number of input channels (channels-last format).
-
-     This layer is commonly used for processing 3D data such as video sequences or
-     volumetric medical imaging data.
-
-     Parameters
-     ----------
-     in_size : tuple of int
-         The input shape without the batch dimension. For Conv3d: (H, W, D, C) where H is height,
-         W is width, D is depth, and C is the number of input channels. This argument is important
-         as it is used to evaluate the output shape.
-     out_channels : int
-         The number of output channels (also called filters or feature maps). These determine
-         the depth of the output feature map.
-     kernel_size : int or tuple of int
-         The shape of the convolutional kernel. Can be:
-
-         - An integer (e.g., 3): creates a cubic kernel (3, 3, 3)
-         - A tuple of three integers (e.g., (3, 5, 5)): creates a (height, width, depth) kernel
-     stride : int or tuple of int, optional
-         The stride of the convolution. Controls how much the kernel moves at each step.
-         Can be:
-
-         - An integer: same stride for all dimensions
-         - A tuple of three integers: (stride_h, stride_w, stride_d)
-
-         Default: 1.
-     padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
-         The padding strategy. Options:
-
-         - 'SAME': output spatial size equals input size when stride=1
-         - 'VALID': no padding, output size reduced by kernel size
-         - int: same symmetric padding for all dimensions
-         - (pad_h, pad_w, pad_d): different padding for each dimension
-         - [(pad_h_before, pad_h_after), (pad_w_before, pad_w_after), (pad_d_before, pad_d_after)]: explicit padding
-
-         Default: 'SAME'.
-     lhs_dilation : int or tuple of int, optional
-         The dilation factor for the input (left-hand side). Controls spacing between input elements.
-         A value > 1 inserts zeros between input elements, equivalent to transposed convolution.
-         Default: 1.
-     rhs_dilation : int or tuple of int, optional
-         The dilation factor for the kernel (right-hand side). Also known as atrous convolution
-         or dilated convolution. Increases the receptive field without increasing parameters by
-         inserting zeros between kernel elements. Particularly useful for 3D data to capture
-         larger temporal/spatial context.
-         Default: 1.
-     groups : int, optional
-         Number of groups for grouped convolution. Must divide both `in_channels` and `out_channels`.
-
-         - groups=1: standard convolution (all-to-all connections)
-         - groups>1: grouped convolution (significantly reduces parameters and computation for 3D)
-         - groups=in_channels: depthwise convolution (each input channel convolved separately)
-
-         Default: 1.
-     w_init : Callable or ArrayLike, optional
-         Weight initializer for the convolutional kernel. Can be:
-
-         - An initializer instance (e.g., brainstate.init.XavierNormal())
-         - A callable that returns an array given a shape
-         - A direct array matching the kernel shape
-
-         Default: XavierNormal().
-     b_init : Callable or ArrayLike or None, optional
-         Bias initializer. If None, no bias term is added to the output.
-         Default: None.
-     w_mask : ArrayLike or Callable or None, optional
-         Optional weight mask for structured sparsity or custom connectivity. The mask is
-         element-wise multiplied with the kernel weights during the forward pass.
-         Default: None.
-     name : str, optional
-         Name identifier for this module instance.
-         Default: None.
-     param_type : type, optional
-         The parameter state class to use for managing learnable parameters.
-         Default: ParamState.
-
-     Attributes
-     ----------
-     in_size : tuple of int
-         The input shape (H, W, D, C) without batch dimension.
-     out_size : tuple of int
-         The output shape (H_out, W_out, D_out, out_channels) without batch dimension.
-     in_channels : int
-         Number of input channels.
-     out_channels : int
-         Number of output channels.
-     kernel_size : tuple of int
-         Size of the convolving kernel (height, width, depth).
-     weight : ParamState
-         The learnable weights (and bias if specified) of the module.
-
-     Examples
-     --------
-     .. code-block:: python
-
-         >>> import brainstate
-         >>> import jax.numpy as jnp
-         >>>
-         >>> # Create a 3D convolution layer for video data
-         >>> conv = brainstate.nn.Conv3d(in_size=(16, 64, 64, 3), out_channels=32, kernel_size=3)
-         >>>
-         >>> # Apply to input: batch_size=4, frames=16, height=64, width=64, channels=3
-         >>> x = jnp.ones((4, 16, 64, 64, 3))
-         >>> y = conv(x)
-         >>> print(y.shape)  # (4, 16, 64, 64, 32) with 'SAME' padding
-         >>>
-         >>> # Without batch dimension
-         >>> x_single = jnp.ones((16, 64, 64, 3))
-         >>> y_single = conv(x_single)
-         >>> print(y_single.shape)  # (16, 64, 64, 32)
-         >>>
-         >>> # For medical imaging with custom parameters
-         >>> conv = brainstate.nn.Conv3d(
-         ...     in_size=(32, 32, 32, 1),
-         ...     out_channels=64,
-         ...     kernel_size=(3, 3, 3),
-         ...     stride=2,
-         ...     padding='VALID',
-         ...     b_init=brainstate.init.Constant(0.1)
-         ... )
-
-     Notes
-     -----
-     **Output dimensions:**
-
-     The output spatial dimensions depend on the padding mode:
-
-     - 'SAME': output_size = ceil(input_size / stride)
-     - 'VALID': output_size = ceil((input_size - kernel_size + 1) / stride)
-
-     **Performance considerations:**
-
-     3D convolutions are computationally expensive. Consider using:
-
-     - Smaller kernel sizes
-     - Grouped convolutions (groups > 1)
-     - Separable convolutions for large-scale applications
-     """
-     __module__ = 'brainstate.nn'
-     num_spatial_dims: int = 3
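Grouped convolution shrinks the kernel along the input-channel axis: as established in `_BaseConv`, the kernel shape is `kernel_size + (in_channels // groups, out_channels)`. A quick comparison sketch using the `kernel_shape` attribute:

    import brainstate

    dense = brainstate.nn.Conv2d(in_size=(32, 32, 64), out_channels=64, kernel_size=3)
    depthwise = brainstate.nn.Conv2d(in_size=(32, 32, 64), out_channels=64,
                                     kernel_size=3, groups=64)

    print(dense.kernel_shape)      # (3, 3, 64, 64): 36,864 weights
    print(depthwise.kernel_shape)  # (3, 3, 1, 64):     576 weights, 64x fewer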
-
-
- class _ScaledWSConv(_BaseConv):
-     def __init__(
-         self,
-         in_size: Sequence[int],
-         out_channels: int,
-         kernel_size: Union[int, Tuple[int, ...]],
-         stride: Union[int, Tuple[int, ...]] = 1,
-         padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
-         lhs_dilation: Union[int, Tuple[int, ...]] = 1,
-         rhs_dilation: Union[int, Tuple[int, ...]] = 1,
-         groups: int = 1,
-         ws_gain: bool = True,
-         eps: float = 1e-4,
-         w_init: Union[Callable, ArrayLike] = init.XavierNormal(),
-         b_init: Optional[Union[Callable, ArrayLike]] = None,
-         w_mask: Optional[Union[ArrayLike, Callable]] = None,
-         channel_first: bool = False,
-         name: Optional[str] = None,
-         param_type: type = ParamState,
-     ):
-         super().__init__(
-             in_size=in_size,
-             out_channels=out_channels,
-             kernel_size=kernel_size,
-             stride=stride,
-             padding=padding,
-             lhs_dilation=lhs_dilation,
-             rhs_dilation=rhs_dilation,
-             groups=groups,
-             w_mask=w_mask,
-             channel_first=channel_first,
-             name=name,
-         )
-
-         self.w_initializer = w_init
-         self.b_initializer = b_init
-
-         # --- weights --- #
-         weight = init.param(self.w_initializer, self.kernel_shape, allow_none=False)
-         params = dict(weight=weight)
-         if self.b_initializer is not None:
-             bias_shape = (1,) * len(self.kernel_size) + (self.out_channels,)
-             bias = init.param(self.b_initializer, bias_shape, allow_none=True)
-             params['bias'] = bias
-
-         # the learnable per-output-channel gain for weight standardization
-         if ws_gain:
-             gain_size = (1,) * len(self.kernel_size) + (1, self.out_channels)
-             params['gain'] = jnp.ones(gain_size, dtype=params['weight'].dtype)
-
-         # Epsilon, a small constant to avoid dividing by zero.
-         self.eps = eps
-
-         # The weight operation
-         self.weight = param_type(params)
-
-         # Evaluate the output shape with a dummy batch size of 128; `in_size`
-         # already matches the `channel_first` layout, so the same test shape
-         # works for both layouts.
-         test_input_shape = (128,) + self.in_size
-         abstract_y = jax.eval_shape(
-             self._conv_op,
-             jax.ShapeDtypeStruct(test_input_shape, weight.dtype),
-             params
-         )
-         y_shape = abstract_y.shape[1:]
-         self.out_size = y_shape
-
-     def _conv_op(self, x, params):
-         w = params['weight']
-         w = weight_standardization(w, self.eps, params.get('gain', None))
-         if self.w_mask is not None:
-             w = w * self.w_mask
-         y = jax.lax.conv_general_dilated(
-             lhs=x,
-             rhs=w,
-             window_strides=self.stride,
-             padding=self.padding,
-             lhs_dilation=self.lhs_dilation,
-             rhs_dilation=self.rhs_dilation,
-             feature_group_count=self.groups,
-             dimension_numbers=self.dimension_numbers
-         )
-         if 'bias' in params:
-             y = y + params['bias']
-         return y
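The `weight_standardization` call above (imported from `._normalizations`) is the only difference from `_Conv._conv_op`. A minimal functional sketch of the documented formula, assuming the statistics are reduced over every axis except the trailing output-channel axis (the authoritative reduction axes live in `_normalizations.py`):

    import jax.numpy as jnp

    def weight_standardization_sketch(w, eps=1e-4, gain=None):
        # w has shape kernel_size + (in_channels // groups, out_channels)
        axes = tuple(range(w.ndim - 1))
        mean = jnp.mean(w, axis=axes, keepdims=True)
        std = jnp.std(w, axis=axes, keepdims=True)
        w_hat = (w - mean) / (std + eps)   # \hat{W} = (W - mu_W) / (sigma_W + eps)
        if gain is not None:
            w_hat = w_hat * gain           # learnable per-output-channel gain g
        return w_hat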
-
860
-
861
- class ScaledWSConv1d(_ScaledWSConv):
862
- """
863
- One-dimensional convolution with weight standardization.
864
-
865
- This layer applies weight standardization to the convolutional kernel before
866
- performing the convolution operation. Weight standardization normalizes the
867
- weights to have zero mean and unit variance, which can accelerate training
868
- and improve model performance, especially when combined with group normalization.
869
-
870
- The input should be a 3D array with the shape of ``[B, L, C]`` where B is batch size,
871
- L is the sequence length, and C is the number of input channels.
872
-
873
- Parameters
874
- ----------
875
- in_size : tuple of int
876
- The input shape without the batch dimension. For Conv1d: (L, C) where L is the sequence
877
- length and C is the number of input channels. This argument is important as it is used
878
- to evaluate the output shape.
879
- out_channels : int
880
- The number of output channels (also called filters or feature maps). These determine
881
- the depth of the output feature map.
882
- kernel_size : int or tuple of int
883
- The shape of the convolutional kernel. For 1D convolution, can be:
884
-
885
- - An integer (e.g., 5): creates a kernel of size 5
886
- - A tuple with one integer (e.g., (5,)): equivalent to the above
887
- stride : int or tuple of int, optional
888
- The stride of the convolution. Controls how much the kernel moves at each step.
889
- Default: 1.
890
- padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
891
- The padding strategy. Options:
892
-
893
- - 'SAME': output length equals input length when stride=1
894
- - 'VALID': no padding, output length reduced by kernel size
895
- - int: symmetric padding
896
- - (pad_before, pad_after): explicit padding for the sequence dimension
897
-
898
- Default: 'SAME'.
899
- lhs_dilation : int or tuple of int, optional
900
- The dilation factor for the input (left-hand side). Controls spacing between input elements.
901
- A value > 1 inserts zeros between input elements, equivalent to transposed convolution.
902
- Default: 1.
903
- rhs_dilation : int or tuple of int, optional
904
- The dilation factor for the kernel (right-hand side). Also known as atrous convolution
905
- or dilated convolution. Increases the receptive field without increasing parameters by
906
- inserting zeros between kernel elements. Useful for capturing long-range dependencies.
907
- Default: 1.
908
- groups : int, optional
909
- Number of groups for grouped convolution. Must divide both `in_channels` and `out_channels`.
910
-
911
- - groups=1: standard convolution (all-to-all connections)
912
- - groups>1: grouped convolution (reduces parameters by factor of groups)
913
- - groups=in_channels: depthwise convolution (each input channel convolved separately)
914
-
915
- Default: 1.
916
- w_init : Callable or ArrayLike, optional
917
- Weight initializer for the convolutional kernel. Can be:
918
-
919
- - An initializer instance (e.g., brainstate.init.XavierNormal())
920
- - A callable that returns an array given a shape
921
- - A direct array matching the kernel shape
922
-
923
- Default: XavierNormal().
924
- b_init : Callable or ArrayLike or None, optional
925
- Bias initializer. If None, no bias term is added to the output.
926
- Default: None.
927
- ws_gain : bool, optional
928
- Whether to include a learnable per-channel gain parameter in weight standardization.
929
- When True, adds a scaling factor that can be learned during training, improving
930
- model expressiveness. Recommended for most applications.
931
- Default: True.
932
- eps : float, optional
933
- Small constant for numerical stability in weight standardization. Prevents division
934
- by zero when computing weight standard deviation. Typical values: 1e-4 to 1e-5.
935
- Default: 1e-4.
936
- w_mask : ArrayLike or Callable or None, optional
937
- Optional weight mask for structured sparsity or custom connectivity. The mask is
938
- element-wise multiplied with the standardized kernel weights during the forward pass.
939
- Default: None.
940
- name : str, optional
941
- Name identifier for this module instance.
942
- Default: None.
943
- param_type : type, optional
944
- The parameter state class to use for managing learnable parameters.
945
- Default: ParamState.
946
-
947
- Attributes
948
- ----------
949
- in_size : tuple of int
950
- The input shape (L, C) without batch dimension.
951
- out_size : tuple of int
952
- The output shape (L_out, out_channels) without batch dimension.
953
- in_channels : int
954
- Number of input channels.
955
- out_channels : int
956
- Number of output channels.
957
- kernel_size : tuple of int
958
- Size of the convolving kernel.
959
- weight : ParamState
960
- The learnable weights (and bias if specified) of the module.
961
- eps : float
962
- Small constant for numerical stability in weight standardization.
963
-
964
- Examples
965
- --------
966
- .. code-block:: python
967
-
968
- >>> import brainstate as brainstate
969
- >>> import jax.numpy as jnp
970
- >>>
971
- >>> # Create a 1D convolution with weight standardization
972
- >>> conv = brainstate.nn.ScaledWSConv1d(
973
- ... in_size=(100, 16),
974
- ... out_channels=32,
975
- ... kernel_size=5
976
- ... )
977
- >>>
978
- >>> # Apply to input
979
- >>> x = jnp.ones((4, 100, 16))
980
- >>> y = conv(x)
981
- >>> print(y.shape) # (4, 100, 32)
982
- >>>
983
- >>> # With custom epsilon and no gain
984
- >>> conv = brainstate.nn.ScaledWSConv1d(
985
- ... in_size=(50, 8),
986
- ... out_channels=16,
987
- ... kernel_size=3,
988
- ... ws_gain=False,
989
- ... eps=1e-5
990
- ... )
991
-
992
- Notes
993
- -----
994
- **Weight standardization formula:**
995
-
996
- Weight standardization reparameterizes the convolutional weights as:
997
-
998
- .. math::
999
- \\hat{W} = g \\cdot \\frac{W - \\mu_W}{\\sigma_W + \\epsilon}
1000
-
1001
- where :math:`\\mu_W` and :math:`\\sigma_W` are the mean and standard deviation
1002
- of the weights, :math:`g` is a learnable gain parameter (if ws_gain=True),
1003
- and :math:`\\epsilon` is a small constant for numerical stability.
1004
-
1005
- **When to use:**
1006
-
1007
- This technique is particularly effective when used with Group Normalization
1008
- instead of Batch Normalization, as it reduces the dependence on batch statistics.
1009
-
1010
- References
1011
- ----------
1012
- .. [1] Qiao, S., Wang, H., Liu, C., Shen, W., & Yuille, A. (2019).
1013
- Weight Standardization. arXiv preprint arXiv:1903.10520.
1014
- """
1015
- __module__ = 'brainstate.nn'
1016
- num_spatial_dims: int = 1
1017
-
1018
-
1019
- class ScaledWSConv2d(_ScaledWSConv):
1020
- """
1021
- Two-dimensional convolution with weight standardization.
1022
-
1023
- This layer applies weight standardization to the convolutional kernel before
1024
- performing the convolution operation. Weight standardization normalizes the
1025
- weights to have zero mean and unit variance, improving training dynamics and
1026
- model generalization, particularly in combination with group normalization.
1027
-
1028
- The input should be a 4D array with the shape of ``[B, H, W, C]`` where B is batch size,
1029
- H is height, W is width, and C is the number of input channels (channels-last format).
1030
-
1031
- Parameters
1032
- ----------
1033
- in_size : tuple of int
1034
- The input shape without the batch dimension. For Conv2d: (H, W, C) where H is height,
1035
- W is width, and C is the number of input channels. This argument is important as it is
1036
- used to evaluate the output shape.
1037
- out_channels : int
1038
- The number of output channels (also called filters or feature maps). These determine
1039
- the depth of the output feature map.
1040
- kernel_size : int or tuple of int
1041
- The shape of the convolutional kernel. Can be:
1042
-
1043
- - An integer (e.g., 3): creates a square kernel (3, 3)
1044
- - A tuple of two integers (e.g., (3, 5)): creates a (height, width) kernel
1045
- stride : int or tuple of int, optional
1046
- The stride of the convolution. Controls how much the kernel moves at each step.
1047
- Can be:
1048
-
1049
- - An integer: same stride for both dimensions
1050
- - A tuple of two integers: (stride_height, stride_width)
1051
-
1052
- Default: 1.
1053
- padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
1054
- The padding strategy. Options:
1055
-
1056
- - 'SAME': output spatial size equals input size when stride=1
1057
- - 'VALID': no padding, output size reduced by kernel size
1058
- - int: same symmetric padding for all dimensions
1059
- - (pad_h, pad_w): different padding for each dimension
1060
- - [(pad_h_before, pad_h_after), (pad_w_before, pad_w_after)]: explicit padding
1061
-
1062
- Default: 'SAME'.
1063
- lhs_dilation : int or tuple of int, optional
1064
- The dilation factor for the input (left-hand side). Controls spacing between input elements.
1065
- A value > 1 inserts zeros between input elements, equivalent to transposed convolution.
1066
- Default: 1.
1067
- rhs_dilation : int or tuple of int, optional
1068
- The dilation factor for the kernel (right-hand side). Also known as atrous convolution
1069
- or dilated convolution. Increases the receptive field without increasing parameters by
1070
- inserting zeros between kernel elements. Useful for semantic segmentation and dense
1071
- prediction tasks.
1072
- Default: 1.
1073
- groups : int, optional
1074
- Number of groups for grouped convolution. Must divide both `in_channels` and `out_channels`.
1075
-
1076
- - groups=1: standard convolution (all-to-all connections)
1077
- - groups>1: grouped convolution (reduces parameters by factor of groups)
1078
- - groups=in_channels: depthwise convolution (each input channel convolved separately)
1079
-
1080
- Default: 1.
1081
- w_init : Callable or ArrayLike, optional
1082
- Weight initializer for the convolutional kernel. Can be:
1083
-
1084
- - An initializer instance (e.g., brainstate.init.XavierNormal())
1085
- - A callable that returns an array given a shape
1086
- - A direct array matching the kernel shape
1087
-
1088
- Default: XavierNormal().
1089
- b_init : Callable or ArrayLike or None, optional
1090
- Bias initializer. If None, no bias term is added to the output.
1091
- Default: None.
1092
- ws_gain : bool, optional
1093
- Whether to include a learnable per-channel gain parameter in weight standardization.
1094
- When True, adds a scaling factor that can be learned during training, improving
1095
- model expressiveness. Highly recommended when using with Group Normalization.
1096
- Default: True.
1097
- eps : float, optional
1098
- Small constant for numerical stability in weight standardization. Prevents division
1099
- by zero when computing weight standard deviation. Typical values: 1e-4 to 1e-5.
1100
- Default: 1e-4.
1101
- w_mask : ArrayLike or Callable or None, optional
1102
- Optional weight mask for structured sparsity or custom connectivity. The mask is
1103
- element-wise multiplied with the standardized kernel weights during the forward pass.
1104
- Default: None.
1105
- name : str, optional
1106
- Name identifier for this module instance.
1107
- Default: None.
1108
- param_type : type, optional
1109
- The parameter state class to use for managing learnable parameters.
1110
- Default: ParamState.
1111
-
1112
- Attributes
1113
- ----------
1114
- in_size : tuple of int
1115
- The input shape (H, W, C) without batch dimension.
1116
- out_size : tuple of int
1117
- The output shape (H_out, W_out, out_channels) without batch dimension.
1118
- in_channels : int
1119
- Number of input channels.
1120
- out_channels : int
1121
- Number of output channels.
1122
- kernel_size : tuple of int
1123
- Size of the convolving kernel (height, width).
1124
- weight : ParamState
1125
- The learnable weights (and bias if specified) of the module.
1126
- eps : float
1127
- Small constant for numerical stability in weight standardization.
1128
-
1129
- Examples
1130
- --------
1131
- .. code-block:: python
1132
-
1133
- >>> import brainstate as brainstate
1134
- >>> import jax.numpy as jnp
1135
- >>>
1136
- >>> # Create a 2D convolution with weight standardization
1137
- >>> conv = brainstate.nn.ScaledWSConv2d(
1138
- ... in_size=(64, 64, 3),
1139
- ... out_channels=32,
1140
- ... kernel_size=3
1141
- ... )
1142
- >>>
1143
- >>> # Apply to input
1144
- >>> x = jnp.ones((8, 64, 64, 3))
1145
- >>> y = conv(x)
1146
- >>> print(y.shape) # (8, 64, 64, 32)
1147
- >>>
1148
- >>> # Combine with custom settings for ResNet-style architecture
1149
- >>> conv = brainstate.nn.ScaledWSConv2d(
1150
- ... in_size=(224, 224, 3),
1151
- ... out_channels=64,
1152
- ... kernel_size=7,
1153
- ... stride=2,
1154
- ... padding='SAME',
1155
- ... ws_gain=True,
1156
- ... b_init=brainstate.init.ZeroInit()
1157
- ... )
1158
- >>>
1159
- >>> # Depthwise separable convolution with weight standardization
1160
- >>> conv = brainstate.nn.ScaledWSConv2d(
1161
- ... in_size=(32, 32, 128),
1162
- ... out_channels=128,
1163
- ... kernel_size=3,
1164
- ... groups=128,
1165
- ... ws_gain=False
1166
- ... )
1167
-
1168
- Notes
1169
- -----
1170
- **Weight standardization formula:**
1171
-
1172
- Weight standardization reparameterizes the convolutional weights as:
1173
-
1174
- .. math::
1175
- \\hat{W} = g \\cdot \\frac{W - \\mu_W}{\\sigma_W + \\epsilon}
1176
-
1177
- where :math:`\\mu_W` and :math:`\\sigma_W` are the mean and standard deviation
1178
- of the weights computed per output channel, :math:`g` is a learnable gain
1179
- parameter (if ws_gain=True), and :math:`\\epsilon` is a small constant.
1180
-
1181
- **Benefits:**
1182
-
1183
- - Reduces internal covariate shift
1184
- - Smooths the loss landscape
1185
- - Works well with Group Normalization
1186
- - Improves training stability with small batch sizes
1187
- - Enables training deeper networks more easily
1188
-
1189
- References
1190
- ----------
1191
- .. [1] Qiao, S., Wang, H., Liu, C., Shen, W., & Yuille, A. (2019).
1192
- Weight Standardization. arXiv preprint arXiv:1903.10520.
1193
- """
1194
- __module__ = 'brainstate.nn'
1195
- num_spatial_dims: int = 2
1196
-
1197
-
1198
- class ScaledWSConv3d(_ScaledWSConv):
1199
- """
1200
- Three-dimensional convolution with weight standardization.
1201
-
1202
- This layer applies weight standardization to the convolutional kernel before
1203
- performing the 3D convolution operation. Weight standardization normalizes the
1204
- weights to have zero mean and unit variance, which improves training dynamics
1205
- especially for 3D networks that are typically deeper and more parameter-heavy.
1206
-
1207
- The input should be a 5D array with the shape of ``[B, H, W, D, C]`` where B is batch size,
1208
- H is height, W is width, D is depth, and C is the number of input channels (channels-last format).
1209
-
1210
- Parameters
1211
- ----------
1212
- in_size : tuple of int
1213
- The input shape without the batch dimension. For Conv3d: (H, W, D, C) where H is height,
1214
- W is width, D is depth, and C is the number of input channels. This argument is important
1215
- as it is used to evaluate the output shape.
1216
- out_channels : int
1217
- The number of output channels (also called filters or feature maps). These determine
1218
- the depth of the output feature map.
1219
- kernel_size : int or tuple of int
1220
- The shape of the convolutional kernel. Can be:
1221
-
1222
- - An integer (e.g., 3): creates a cubic kernel (3, 3, 3)
1223
- - A tuple of three integers (e.g., (3, 5, 5)): creates a (height, width, depth) kernel
1224
- stride : int or tuple of int, optional
1225
- The stride of the convolution. Controls how much the kernel moves at each step.
1226
- Can be:
1227
-
1228
- - An integer: same stride for all dimensions
1229
- - A tuple of three integers: (stride_h, stride_w, stride_d)
1230
-
1231
- Default: 1.
1232
- padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
1233
- The padding strategy. Options:
1234
-
1235
- - 'SAME': output spatial size equals input size when stride=1
1236
- - 'VALID': no padding, output size reduced by kernel size
1237
- - int: same symmetric padding for all dimensions
1238
- - (pad_h, pad_w, pad_d): different padding for each dimension
1239
- - [(pad_h_before, pad_h_after), (pad_w_before, pad_w_after), (pad_d_before, pad_d_after)]: explicit padding
1240
-
1241
- Default: 'SAME'.
1242
- lhs_dilation : int or tuple of int, optional
1243
- The dilation factor for the input (left-hand side). Controls spacing between input elements.
1244
- A value > 1 inserts zeros between input elements, equivalent to transposed convolution.
1245
- Default: 1.
1246
- rhs_dilation : int or tuple of int, optional
1247
- The dilation factor for the kernel (right-hand side). Also known as atrous convolution
1248
- or dilated convolution. Increases the receptive field without increasing parameters by
1249
- inserting zeros between kernel elements. Particularly valuable for 3D to capture
1250
- multi-scale temporal/spatial context efficiently.
1251
- Default: 1.
1252
- groups : int, optional
1253
- Number of groups for grouped convolution. Must divide both `in_channels` and `out_channels`.
1254
-
1255
- - groups=1: standard convolution (all-to-all connections)
1256
- - groups>1: grouped convolution (critical for reducing 3D conv computational cost)
1257
- - groups=in_channels: depthwise convolution (each input channel convolved separately)
1258
-
1259
- Default: 1.
1260
- w_init : Callable or ArrayLike, optional
1261
- Weight initializer for the convolutional kernel. Can be:
1262
-
1263
- - An initializer instance (e.g., brainstate.init.XavierNormal())
1264
- - A callable that returns an array given a shape
1265
- - A direct array matching the kernel shape
1266
-
1267
- Default: XavierNormal().
1268
- b_init : Callable or ArrayLike or None, optional
1269
- Bias initializer. If None, no bias term is added to the output.
1270
- Default: None.
1271
- ws_gain : bool, optional
1272
- Whether to include a learnable per-channel gain parameter in weight standardization.
1273
- When True, adds a scaling factor that can be learned during training, improving
1274
- model expressiveness. Particularly beneficial for deep 3D networks.
1275
- Default: True.
1276
- eps : float, optional
1277
- Small constant for numerical stability in weight standardization. Prevents division
1278
- by zero when computing the weight standard deviation. Typical values: 1e-5 to 1e-4.
1279
- Default: 1e-4.
1280
- w_mask : ArrayLike or Callable or None, optional
1281
- Optional weight mask for structured sparsity or custom connectivity. The mask is
1282
- element-wise multiplied with the standardized kernel weights during the forward pass.
1283
- Default: None.
1284
- name : str, optional
1285
- Name identifier for this module instance.
1286
- Default: None.
1287
- param_type : type, optional
1288
- The parameter state class to use for managing learnable parameters.
1289
- Default: ParamState.
1290
-
1291
- Attributes
1292
- ----------
1293
- in_size : tuple of int
1294
- The input shape (H, W, D, C) without batch dimension.
1295
- out_size : tuple of int
1296
- The output shape (H_out, W_out, D_out, out_channels) without batch dimension.
1297
- in_channels : int
1298
- Number of input channels.
1299
- out_channels : int
1300
- Number of output channels.
1301
- kernel_size : tuple of int
1302
- Size of the convolving kernel (height, width, depth).
1303
- weight : ParamState
1304
- The learnable weights (and bias if specified) of the module.
1305
- eps : float
1306
- Small constant for numerical stability in weight standardization.
1307
-
1308
- Examples
1309
- --------
1310
- .. code-block:: python
1311
-
1312
- >>> import brainstate
1313
- >>> import jax.numpy as jnp
1314
- >>>
1315
- >>> # Create a 3D convolution with weight standardization for video
1316
- >>> conv = brainstate.nn.ScaledWSConv3d(
1317
- ... in_size=(16, 64, 64, 3),
1318
- ... out_channels=32,
1319
- ... kernel_size=3
1320
- ... )
1321
- >>>
1322
- >>> # Apply to input
1323
- >>> x = jnp.ones((4, 16, 64, 64, 3))
1324
- >>> y = conv(x)
1325
- >>> print(y.shape) # (4, 16, 64, 64, 32)
1326
- >>>
1327
- >>> # For medical imaging with custom parameters
1328
- >>> conv = brainstate.nn.ScaledWSConv3d(
1329
- ... in_size=(32, 32, 32, 1),
1330
- ... out_channels=64,
1331
- ... kernel_size=(3, 3, 3),
1332
- ... stride=2,
1333
- ... ws_gain=True,
1334
- ... eps=1e-5,
1335
- ... b_init=brainstate.init.Constant(0.01)
1336
- ... )
1337
- >>>
1338
- >>> # 3D grouped convolution with weight standardization
1339
- >>> conv = brainstate.nn.ScaledWSConv3d(
1340
- ... in_size=(8, 16, 16, 64),
1341
- ... out_channels=64,
1342
- ... kernel_size=3,
1343
- ... groups=8,
1344
- ... ws_gain=False
1345
- ... )
1346
-
1347
- Notes
1348
- -----
1349
- **Weight standardization formula:**
1350
-
1351
- Weight standardization reparameterizes the convolutional weights as:
1352
-
1353
- .. math::
1354
- \\hat{W} = g \\cdot \\frac{W - \\mu_W}{\\sigma_W + \\epsilon}
1355
-
1356
- where :math:`\\mu_W` and :math:`\\sigma_W` are the mean and standard deviation
1357
- of the weights, :math:`g` is a learnable gain parameter (if ws_gain=True),
1358
- and :math:`\\epsilon` is a small constant for numerical stability.
1359
-
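- A minimal standalone sketch of this computation (illustrative only; it
- assumes a channels-last kernel and standardizes over all axes except the
- output-channel axis):
- 
- .. code-block:: python
- 
-     >>> import jax.numpy as jnp
-     >>> W = jnp.ones((3, 3, 3, 4, 8))              # (H, W, D, C_in, C_out)
-     >>> mu = W.mean(axis=(0, 1, 2, 3), keepdims=True)
-     >>> sigma = W.std(axis=(0, 1, 2, 3), keepdims=True)
-     >>> g, eps = 1.0, 1e-4                         # gain and stability constant
-     >>> W_hat = g * (W - mu) / (sigma + eps)
- 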
1360
- **Why weight standardization for 3D:**
1361
-
1362
- For 3D convolutions, weight standardization is particularly beneficial because:
1363
-
1364
- - 3D networks are typically much deeper and harder to train
1365
- - Reduces sensitivity to weight initialization
1366
- - Improves gradient flow through very deep networks
1367
- - Works well with limited computational resources (small batches)
1368
- - Compatible with Group Normalization for batch-independent normalization
1369
-
1370
- **Applications:**
1371
-
1372
- Video understanding, medical imaging (CT, MRI scans), 3D object recognition,
1373
- and temporal sequence modeling.
1374
-
1375
- References
1376
- ----------
1377
- .. [1] Qiao, S., Wang, H., Liu, C., Shen, W., & Yuille, A. (2019).
1378
- Weight Standardization. arXiv preprint arXiv:1903.10520.
1379
- """
1380
- __module__ = 'brainstate.nn'
1381
- num_spatial_dims: int = 3
1382
-
1383
-
1384
- class _ConvTranspose(_BaseConv):
1385
- """Base class for transposed convolution layers."""
1386
- num_spatial_dims: int = None
1387
-
1388
- def __init__(
1389
- self,
1390
- in_size: Sequence[int],
1391
- out_channels: int,
1392
- kernel_size: Union[int, Tuple[int, ...]],
1393
- stride: Union[int, Tuple[int, ...]] = 1,
1394
- padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
1395
- lhs_dilation: Union[int, Tuple[int, ...]] = 1,
1396
- rhs_dilation: Union[int, Tuple[int, ...]] = 1,
1397
- groups: int = 1,
1398
- w_init: Union[Callable, ArrayLike] = init.XavierNormal(),
1399
- b_init: Optional[Union[Callable, ArrayLike]] = None,
1400
- w_mask: Optional[Union[ArrayLike, Callable]] = None,
1401
- channel_first: bool = False,
1402
- name: str = None,
1403
- param_type: type = ParamState,
1404
- ):
1405
- # Initialize with transpose=True for dimension numbers
1406
- Module.__init__(self, name=name)
1407
-
1408
- # general parameters
1409
- assert self.num_spatial_dims + 1 == len(in_size)
1410
- self.in_size = tuple(in_size)
1411
- self.channel_first = channel_first
1412
- self.channels_last = not channel_first
1413
-
1414
- # Determine in_channels based on channel_first
1415
- if self.channel_first:
1416
- self.in_channels = in_size[0]
1417
- else:
1418
- self.in_channels = in_size[-1]
1419
-
1420
- self.out_channels = out_channels
1421
- self.stride = replicate(stride, self.num_spatial_dims, 'stride')
1422
- self.kernel_size = replicate(kernel_size, self.num_spatial_dims, 'kernel_size')
1423
- self.lhs_dilation = replicate(lhs_dilation, self.num_spatial_dims, 'lhs_dilation')
1424
- self.rhs_dilation = replicate(rhs_dilation, self.num_spatial_dims, 'rhs_dilation')
1425
- self.groups = groups
1426
- self.dimension_numbers = to_dimension_numbers(
1427
- self.num_spatial_dims,
1428
- channels_last=self.channels_last,
1429
- transpose=True # Key difference from regular Conv
1430
- )
1431
-
1432
- # the padding parameter
1433
- # For transposed convolution, string padding needs to be converted to explicit padding
1434
- # when using lhs_dilation (stride) > 1
1435
- if isinstance(padding, str):
1436
- assert padding in ['SAME', 'VALID']
1437
- self.padding_mode = padding
1438
- # Compute explicit padding for transposed convolution
1439
- if max(self.stride) > 1:
1440
- # For transposed conv with stride, compute padding to achieve desired output size
1441
- spatial_in_size = self.in_size[:-1] if not self.channel_first else self.in_size[1:]
1442
- if padding == 'SAME':
1443
- # For SAME padding with transposed conv: output_size = input_size * stride
1444
- # Compute required padding to achieve this
1445
- explicit_padding = []
1446
- for i, (k, s, in_dim) in enumerate(zip(self.kernel_size, self.stride, spatial_in_size)):
1447
- # Desired output size for 'SAME': out_dim = in_dim * s.
- # Note: with ``lhs_dilation``, JAX applies this padding to the
- # *dilated* input, so the total padding that achieves
- # out_dim = in_dim * s is kernel + stride - 2 (the rule
- # ``jax.lax.conv_transpose`` uses), not kernel - stride.
- pad_len = k + s - 2
- pad_a = k - 1 if s > k - 1 else (pad_len + 1) // 2
- explicit_padding.append((pad_a, pad_len - pad_a))
1456
- padding = tuple(explicit_padding)
1457
- else:  # 'VALID'
- # In the transposed sense, 'VALID' still pads the dilated input by
- # kernel - 1 on the left (following ``jax.lax.conv_transpose``), so
- # that out = in * stride + max(kernel - stride, 0) as documented.
- padding = tuple(
- (k - 1, s - 1 + max(k - s, 0))
- for k, s in zip(self.kernel_size, self.stride)
- )
1460
- # If stride is 1, keep string padding
1461
- elif isinstance(padding, int):
1462
- self.padding_mode = 'explicit'
1463
- padding = tuple((padding, padding) for _ in range(self.num_spatial_dims))
1464
- elif isinstance(padding, (tuple, list)):
1465
- self.padding_mode = 'explicit'
1466
- if isinstance(padding[0], int):
1467
- padding = (padding,) * self.num_spatial_dims
1468
- elif isinstance(padding[0], (tuple, list)):
1469
- if len(padding) == 1:
1470
- padding = tuple(padding) * self.num_spatial_dims
1471
- else:
1472
- if len(padding) != self.num_spatial_dims:
1473
- raise ValueError(
1474
- f"Padding {padding} must be a Tuple[int, int], "
1475
- f"or sequence of Tuple[int, int] with length 1, "
1476
- f"or sequence of Tuple[int, int] with length {self.num_spatial_dims}."
1477
- )
1478
- padding = tuple(padding)
1479
- else:
1480
- raise ValueError(f"Unsupported padding specification: {padding!r}.")
1481
- self.padding = padding
1482
-
1483
- # the number of in-/out-channels
1484
- assert self.out_channels % self.groups == 0, '"out_channels" should be divisible by groups'
1485
- assert self.in_channels % self.groups == 0, '"in_channels" should be divisible by groups'
1486
-
1487
- # kernel shape for transpose conv
1488
- # When transpose=True in dimension_numbers, kernel is (spatial..., out_channels, in_channels // groups)
1489
- # This matches JAX's expectation for transposed convolution
1490
- kernel_shape = tuple(self.kernel_size) + (self.out_channels, self.in_channels // self.groups)
1491
- self.kernel_shape = kernel_shape
1492
- self.w_mask = init.param(w_mask, kernel_shape, allow_none=True)
1493
-
1494
- self.w_initializer = w_init
1495
- self.b_initializer = b_init
1496
-
1497
- # --- weights --- #
1498
- weight = init.param(self.w_initializer, self.kernel_shape, allow_none=False)
1499
- params = dict(weight=weight)
1500
- if self.b_initializer is not None:
1501
- # Place the bias on the channel axis implied by the data layout; a
- # trailing channel axis would broadcast incorrectly for channels-first outputs
- if self.channel_first:
- bias_shape = (self.out_channels,) + (1,) * len(self.kernel_size)
- else:
- bias_shape = (1,) * len(self.kernel_size) + (self.out_channels,)
1502
- bias = init.param(self.b_initializer, bias_shape, allow_none=True)
1503
- params['bias'] = bias
1504
-
1505
- # The weight operation
1506
- self.weight = param_type(params)
1507
-
1508
- # Evaluate the output shape
1509
- test_input_shape = (128,) + self.in_size
1510
- abstract_y = jax.eval_shape(
1511
- self._conv_op,
1512
- jax.ShapeDtypeStruct(test_input_shape, weight.dtype),
1513
- params
1514
- )
1515
- y_shape = abstract_y.shape[1:]
1516
- self.out_size = y_shape
1517
-
1518
- def _conv_op(self, x, params):
1519
- w = params['weight']
1520
- if self.w_mask is not None:
1521
- w = w * self.w_mask
1522
- # For transposed convolution:
1523
- # - window_strides should be (1,1,...) - no striding in the conv operation
1524
- # - lhs_dilation should be the stride - this creates the upsampling effect
1525
- window_strides = (1,) * self.num_spatial_dims
1526
- y = jax.lax.conv_general_dilated(
1527
- lhs=x,
1528
- rhs=w,
1529
- window_strides=window_strides,
1530
- padding=self.padding,
1531
- lhs_dilation=self.stride, # For transpose conv, use stride as lhs_dilation
1532
- rhs_dilation=self.rhs_dilation,
1533
- feature_group_count=self.groups,
1534
- dimension_numbers=self.dimension_numbers
1535
- )
1536
- if 'bias' in params:
1537
- y = y + params['bias']
1538
- return y
1539
-
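- # Sanity check for the formulation above (illustrative): a stride-2
- # transposed convolution is an ordinary convolution with the stride
- # passed as ``lhs_dilation``. With JAX's default NCH/OIH layout:
- #
- #     >>> import jax, jax.numpy as jnp
- #     >>> x = jnp.ones((1, 1, 8))                  # (batch, channel, length)
- #     >>> w = jnp.ones((1, 1, 4))                  # (out_ch, in_ch, kernel)
- #     >>> jax.lax.conv_general_dilated(
- #     ...     x, w, window_strides=(1,), padding=((2, 2),),
- #     ...     lhs_dilation=(2,)).shape             # total pad = k + s - 2 = 4
- #     (1, 1, 16)
- 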
1540
-
1541
- class ConvTranspose1d(_ConvTranspose):
1542
- """
1543
- One-dimensional transposed convolution layer (also known as deconvolution).
1544
-
1545
- Applies a 1D transposed convolution over an input signal. Transposed convolution
1546
- is used for upsampling, reversing the spatial transformation of a regular convolution.
1547
- It's commonly used in autoencoders, GANs, and semantic segmentation networks.
1548
-
1549
- The input should be a 3D array of shape ``[B, L, C]``, where B is batch size,
1550
- L is the sequence length, and C is the number of input channels (channels-last format).
1551
-
1552
- Parameters
1553
- ----------
1554
- in_size : tuple of int
1555
- The input shape without the batch dimension. For ConvTranspose1d: (L, C) where L
1556
- is the sequence length and C is the number of input channels.
1557
- out_channels : int
1558
- The number of output channels (feature maps) produced by the transposed convolution.
1559
- kernel_size : int or tuple of int
1560
- The shape of the convolutional kernel. For 1D, can be an integer or a single-element tuple.
1561
- stride : int or tuple of int, optional
1562
- The stride of the transposed convolution. Larger strides produce larger output sizes,
1563
- which is the opposite behavior of regular convolution. Default: 1.
1564
- padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
1565
- The padding strategy. Options:
1566
-
1567
- - 'SAME': output length approximately equals input_length * stride
1568
- - 'VALID': no padding, maximum output size
1569
- - int: symmetric padding
1570
- - (pad_before, pad_after): explicit padding for the sequence dimension
1571
-
1572
- Default: 'SAME'.
1573
- lhs_dilation : int or tuple of int, optional
1574
- The dilation factor for the input. For transposed convolution, this is typically
1575
- set equal to stride internally. Default: 1.
1576
- rhs_dilation : int or tuple of int, optional
1577
- The dilation factor for the kernel. Increases the receptive field without increasing
1578
- parameters by inserting zeros between kernel elements. Default: 1.
1579
- groups : int, optional
1580
- Number of groups for grouped transposed convolution. Both `in_channels` and
1581
- `out_channels` must be divisible by `groups`. Default: 1.
1582
- w_init : Callable or ArrayLike, optional
1583
- The initializer for the convolutional kernel weights. Default: XavierNormal().
1584
- b_init : Callable or ArrayLike or None, optional
1585
- The initializer for the bias. If None, no bias is added. Default: None.
1586
- w_mask : ArrayLike or Callable or None, optional
1587
- An optional mask applied to the weights during forward pass. Default: None.
1588
- channel_first : bool, optional
1589
- If True, uses channels-first format (e.g., [B, C, L]). If False, uses channels-last
1590
- format (e.g., [B, L, C]). Default: False (channels-last, JAX convention).
1591
- name : str, optional
1592
- The name of the module. Default: None.
1593
- param_type : type, optional
1594
- The type of parameter state to use. Default: ParamState.
1595
-
1596
- Attributes
1597
- ----------
1598
- in_size : tuple of int
1599
- The input shape (L, C) without batch dimension.
1600
- out_size : tuple of int
1601
- The output shape (L_out, out_channels) without batch dimension.
1602
- in_channels : int
1603
- Number of input channels.
1604
- out_channels : int
1605
- Number of output channels.
1606
- kernel_size : tuple of int
1607
- Size of the convolving kernel.
1608
- weight : ParamState
1609
- The learnable weights (and bias if specified) of the module.
1610
-
1611
- Examples
1612
- --------
1613
- .. code-block:: python
1614
-
1615
- >>> import brainstate
1616
- >>> import jax.numpy as jnp
1617
- >>>
1618
- >>> # Create a 1D transposed convolution layer for upsampling
1619
- >>> conv_transpose = brainstate.nn.ConvTranspose1d(
1620
- ... in_size=(28, 16),
1621
- ... out_channels=8,
1622
- ... kernel_size=4,
1623
- ... stride=2
1624
- ... )
1625
- >>>
1626
- >>> # Apply to input: batch_size=2, length=28, channels=16
1627
- >>> x = jnp.ones((2, 28, 16))
1628
- >>> y = conv_transpose(x)
1629
- >>> print(y.shape) # Output will be upsampled
1630
- >>>
1631
- >>> # Without batch dimension
1632
- >>> x_single = jnp.ones((28, 16))
1633
- >>> y_single = conv_transpose(x_single)
1634
- >>>
1635
- >>> # Channels-first format (PyTorch style)
1636
- >>> conv_transpose = brainstate.nn.ConvTranspose1d(
1637
- ... in_size=(16, 28),
1638
- ... out_channels=8,
1639
- ... kernel_size=4,
1640
- ... stride=2,
1641
- ... channel_first=True
1642
- ... )
1643
- >>> x = jnp.ones((2, 16, 28))
1644
- >>> y = conv_transpose(x)
1645
-
1646
- Notes
1647
- -----
1648
- **Output dimensions:**
1649
-
1650
- Unlike regular convolution, transposed convolution increases spatial dimensions.
1651
- With stride > 1, the output is larger than the input:
1652
-
1653
- - output_length ≈ input_length * stride (depends on padding and kernel size)
1654
-
1655
- **Relationship to regular convolution:**
1656
-
1657
- Transposed convolution performs the gradient computation of a regular convolution
1658
- with respect to its input. It's sometimes called "deconvolution" but this term
1659
- is mathematically imprecise.
1660
-
1661
- **Common use cases:**
1662
-
1663
- - Upsampling in encoder-decoder architectures
1664
- - Generative models (GANs, VAEs)
1665
- - Semantic segmentation (U-Net, FCN)
1666
- - Super-resolution networks
1667
-
1668
- **Comparison with PyTorch:**
1669
-
1670
- - PyTorch uses channels-first by default; BrainState uses channels-last
1671
- - Set `channel_first=True` for PyTorch-compatible format
1672
- - PyTorch's `output_padding` is handled through the padding parameter
1673
- """
1674
- __module__ = 'brainstate.nn'
1675
- num_spatial_dims: int = 1
1676
-
1677
-
1678
- class ConvTranspose2d(_ConvTranspose):
1679
- """
1680
- Two-dimensional transposed convolution layer (also known as deconvolution).
1681
-
1682
- Applies a 2D transposed convolution over an input signal. Transposed convolution
1683
- is the gradient of a regular convolution with respect to its input, commonly used
1684
- for upsampling feature maps in encoder-decoder architectures, GANs, and segmentation.
1685
-
1686
- The input should be a 4D array of shape ``[B, H, W, C]``, where B is batch size,
1687
- H is height, W is width, and C is the number of input channels (channels-last format).
1688
-
1689
- Parameters
1690
- ----------
1691
- in_size : tuple of int
1692
- The input shape without the batch dimension. For ConvTranspose2d: (H, W, C) where
1693
- H is height, W is width, and C is the number of input channels.
1694
- out_channels : int
1695
- The number of output channels (feature maps) produced by the transposed convolution.
1696
- kernel_size : int or tuple of int
1697
- The shape of the convolutional kernel. Can be:
1698
-
1699
- - An integer (e.g., 4): creates a square kernel (4, 4)
1700
- - A tuple of two integers (e.g., (4, 4)): creates a (height, width) kernel
1701
- stride : int or tuple of int, optional
1702
- The stride of the transposed convolution. Controls the upsampling factor.
1703
- Can be:
1704
-
1705
- - An integer: same stride for both dimensions
1706
- - A tuple of two integers: (stride_height, stride_width)
1707
-
1708
- Larger strides produce larger outputs. Default: 1.
1709
- padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
1710
- The padding strategy. Options:
1711
-
1712
- - 'SAME': output size approximately equals input_size * stride
1713
- - 'VALID': no padding, maximum output size
1714
- - int: same symmetric padding for all dimensions
1715
- - (pad_h, pad_w): different padding for each dimension
1716
- - [(pad_h_before, pad_h_after), (pad_w_before, pad_w_after)]: explicit padding
1717
-
1718
- Default: 'SAME'.
1719
- lhs_dilation : int or tuple of int, optional
1720
- The dilation factor for the input. For transposed convolution, this is typically
1721
- set equal to stride internally. Default: 1.
1722
- rhs_dilation : int or tuple of int, optional
1723
- The dilation factor for the kernel. Increases the receptive field without increasing
1724
- parameters by inserting zeros between kernel elements. Default: 1.
1725
- groups : int, optional
1726
- Number of groups for grouped transposed convolution. Must divide both `in_channels`
1727
- and `out_channels`. Default: 1.
1728
- w_init : Callable or ArrayLike, optional
1729
- Weight initializer for the convolutional kernel. Default: XavierNormal().
1730
- b_init : Callable or ArrayLike or None, optional
1731
- Bias initializer. If None, no bias term is added. Default: None.
1732
- w_mask : ArrayLike or Callable or None, optional
1733
- Optional weight mask for structured sparsity. Default: None.
1734
- channel_first : bool, optional
1735
- If True, uses channels-first format (e.g., [B, C, H, W]). If False, uses channels-last
1736
- format (e.g., [B, H, W, C]). Default: False (channels-last, JAX convention).
1737
- name : str, optional
1738
- Name identifier for this module instance. Default: None.
1739
- param_type : type, optional
1740
- The parameter state class to use. Default: ParamState.
1741
-
1742
- Attributes
1743
- ----------
1744
- in_size : tuple of int
1745
- The input shape (H, W, C) without batch dimension.
1746
- out_size : tuple of int
1747
- The output shape (H_out, W_out, out_channels) without batch dimension.
1748
- in_channels : int
1749
- Number of input channels.
1750
- out_channels : int
1751
- Number of output channels.
1752
- kernel_size : tuple of int
1753
- Size of the convolving kernel (height, width).
1754
- weight : ParamState
1755
- The learnable weights (and bias if specified) of the module.
1756
-
1757
- Examples
1758
- --------
1759
- .. code-block:: python
1760
-
1761
- >>> import brainstate
1762
- >>> import jax.numpy as jnp
1763
- >>>
1764
- >>> # Create a 2D transposed convolution for upsampling
1765
- >>> conv_transpose = brainstate.nn.ConvTranspose2d(
1766
- ... in_size=(32, 32, 64),
1767
- ... out_channels=32,
1768
- ... kernel_size=4,
1769
- ... stride=2
1770
- ... )
1771
- >>>
1772
- >>> # Apply to input: batch_size=8, height=32, width=32, channels=64
1773
- >>> x = jnp.ones((8, 32, 32, 64))
1774
- >>> y = conv_transpose(x)
1775
- >>> print(y.shape) # Output will be approximately (8, 64, 64, 32)
1776
- >>>
1777
- >>> # Without batch dimension
1778
- >>> x_single = jnp.ones((32, 32, 64))
1779
- >>> y_single = conv_transpose(x_single)
1780
- >>>
1781
- >>> # Decoder in autoencoder (upsampling path)
1782
- >>> decoder = brainstate.nn.ConvTranspose2d(
1783
- ... in_size=(16, 16, 128),
1784
- ... out_channels=64,
1785
- ... kernel_size=4,
1786
- ... stride=2,
1787
- ... padding='SAME',
1788
- ... b_init=brainstate.init.Constant(0.0)
1789
- ... )
1790
- >>>
1791
- >>> # Channels-first format (PyTorch style)
1792
- >>> conv_transpose = brainstate.nn.ConvTranspose2d(
1793
- ... in_size=(64, 32, 32),
1794
- ... out_channels=32,
1795
- ... kernel_size=4,
1796
- ... stride=2,
1797
- ... channel_first=True
1798
- ... )
1799
- >>> x = jnp.ones((8, 64, 32, 32))
1800
- >>> y = conv_transpose(x)
1801
-
1802
- Notes
1803
- -----
1804
- **Output dimensions:**
1805
-
1806
- Transposed convolution increases spatial dimensions, with the upsampling factor
1807
- primarily controlled by stride:
1808
-
1809
- - output_size ≈ input_size * stride (exact size depends on padding and kernel size)
1810
- - 'SAME' padding: output_size = input_size * stride
1811
- - 'VALID' padding: output_size = input_size * stride + max(kernel_size - stride, 0)
1812
-
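- A quick arithmetic check for input size 32, stride 2, kernel 4:
- 
- .. code-block:: python
- 
-     >>> 32 * 2                    # 'SAME'
-     64
-     >>> 32 * 2 + max(4 - 2, 0)    # 'VALID'
-     66
- 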
1813
- **Relationship to regular convolution:**
1814
-
1815
- Transposed convolution is the backward pass of a regular convolution. If a regular
1816
- convolution reduces spatial dimensions from X to Y, a transposed convolution with
1817
- the same parameters increases dimensions from Y back to approximately X.
1818
-
1819
- **Common use cases:**
1820
-
1821
- - Image segmentation (U-Net, SegNet, FCN)
1822
- - Image-to-image translation (pix2pix, CycleGAN)
1823
- - Generative models (DCGAN, VAE decoders)
1824
- - Super-resolution networks
1825
- - Autoencoders (decoder path)
1826
-
1827
- **Comparison with PyTorch:**
1828
-
1829
- - PyTorch uses channels-first by default; BrainState uses channels-last
1830
- - Set `channel_first=True` for PyTorch-compatible format
1831
- - Kernel shape convention: PyTorch stores (C_in, C_out, H, W), BrainState uses (H, W, C_out, C_in)
1832
- - PyTorch's `output_padding` parameter controls output size; use the padding parameter here
1833
-
1834
- **Tips:**
1835
-
1836
- - Use kernel_size=stride*2 for smooth upsampling (e.g., kernel_size=4, stride=2)
1837
- - Initialize with bilinear upsampling weights for better convergence in segmentation (see the sketch below)
1838
- - Combine with batch normalization or group normalization for stable training
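- 
- The bilinear initialization mentioned above can be built by hand (a sketch;
- ``bilinear_1d`` is an illustrative helper, not part of brainstate):
- 
- .. code-block:: python
- 
-     >>> import numpy as np
-     >>> def bilinear_1d(size):
-     ...     factor = (size + 1) // 2
-     ...     center = factor - 1 if size % 2 == 1 else factor - 0.5
-     ...     return 1 - np.abs(np.arange(size) - center) / factor
-     >>> np.outer(bilinear_1d(4), bilinear_1d(4))[1]   # one row of the 4x4 kernel
-     array([0.1875, 0.5625, 0.5625, 0.1875])
- 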
1839
- """
1840
- __module__ = 'brainstate.nn'
1841
- num_spatial_dims: int = 2
1842
-
1843
-
1844
- class ConvTranspose3d(_ConvTranspose):
1845
- """
1846
- Three-dimensional transposed convolution layer (also known as deconvolution).
1847
-
1848
- Applies a 3D transposed convolution over an input signal. Used for upsampling
1849
- 3D feature maps in video generation, 3D segmentation, and volumetric reconstruction.
1850
-
1851
- The input should be a 5D array of shape ``[B, H, W, D, C]``, where B is batch size,
1852
- H is height, W is width, D is depth, and C is the number of input channels (channels-last format).
1853
-
1854
- Parameters
1855
- ----------
1856
- in_size : tuple of int
1857
- The input shape without the batch dimension. For ConvTranspose3d: (H, W, D, C) where
1858
- H is height, W is width, D is depth, and C is the number of input channels.
1859
- out_channels : int
1860
- The number of output channels (feature maps) produced by the transposed convolution.
1861
- kernel_size : int or tuple of int
1862
- The shape of the convolutional kernel. Can be:
1863
-
1864
- - An integer (e.g., 4): creates a cubic kernel (4, 4, 4)
1865
- - A tuple of three integers (e.g., (4, 4, 4)): creates a (height, width, depth) kernel
1866
- stride : int or tuple of int, optional
1867
- The stride of the transposed convolution. Controls the upsampling factor.
1868
- Can be:
1869
-
1870
- - An integer: same stride for all dimensions
1871
- - A tuple of three integers: (stride_h, stride_w, stride_d)
1872
-
1873
- Larger strides produce larger outputs. Default: 1.
1874
- padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
1875
- The padding strategy. Options:
1876
-
1877
- - 'SAME': output size approximately equals input_size * stride
1878
- - 'VALID': no padding, maximum output size
1879
- - int: same symmetric padding for all dimensions
1880
- - (pad_h, pad_w, pad_d): different padding for each dimension
1881
- - [(pad_h_before, pad_h_after), (pad_w_before, pad_w_after), (pad_d_before, pad_d_after)]: explicit
1882
-
1883
- Default: 'SAME'.
1884
- lhs_dilation : int or tuple of int, optional
1885
- The dilation factor for the input. For transposed convolution, this is typically
1886
- set equal to stride internally. Default: 1.
1887
- rhs_dilation : int or tuple of int, optional
1888
- The dilation factor for the kernel. Increases the receptive field without increasing
1889
- parameters. Default: 1.
1890
- groups : int, optional
1891
- Number of groups for grouped transposed convolution. Must divide both `in_channels`
1892
- and `out_channels`. Useful for reducing computational cost in 3D. Default: 1.
1893
- w_init : Callable or ArrayLike, optional
1894
- Weight initializer for the convolutional kernel. Default: XavierNormal().
1895
- b_init : Callable or ArrayLike or None, optional
1896
- Bias initializer. If None, no bias term is added. Default: None.
1897
- w_mask : ArrayLike or Callable or None, optional
1898
- Optional weight mask for structured sparsity. Default: None.
1899
- channel_first : bool, optional
1900
- If True, uses channels-first format (e.g., [B, C, H, W, D]). If False, uses channels-last
1901
- format (e.g., [B, H, W, D, C]). Default: False (channels-last, JAX convention).
1902
- name : str, optional
1903
- Name identifier for this module instance. Default: None.
1904
- param_type : type, optional
1905
- The parameter state class to use. Default: ParamState.
1906
-
1907
- Attributes
1908
- ----------
1909
- in_size : tuple of int
1910
- The input shape (H, W, D, C) without batch dimension.
1911
- out_size : tuple of int
1912
- The output shape (H_out, W_out, D_out, out_channels) without batch dimension.
1913
- in_channels : int
1914
- Number of input channels.
1915
- out_channels : int
1916
- Number of output channels.
1917
- kernel_size : tuple of int
1918
- Size of the convolving kernel (height, width, depth).
1919
- weight : ParamState
1920
- The learnable weights (and bias if specified) of the module.
1921
-
1922
- Examples
1923
- --------
1924
- .. code-block:: python
1925
-
1926
- >>> import brainstate
1927
- >>> import jax.numpy as jnp
1928
- >>>
1929
- >>> # Create a 3D transposed convolution for video upsampling
1930
- >>> conv_transpose = brainstate.nn.ConvTranspose3d(
1931
- ... in_size=(8, 16, 16, 64),
1932
- ... out_channels=32,
1933
- ... kernel_size=4,
1934
- ... stride=2
1935
- ... )
1936
- >>>
1937
- >>> # Apply to input: batch_size=4, frames=8, height=16, width=16, channels=64
1938
- >>> x = jnp.ones((4, 8, 16, 16, 64))
1939
- >>> y = conv_transpose(x)
1940
- >>> print(y.shape) # Output will be approximately (4, 16, 32, 32, 32)
1941
- >>>
1942
- >>> # Without batch dimension
1943
- >>> x_single = jnp.ones((8, 16, 16, 64))
1944
- >>> y_single = conv_transpose(x_single)
1945
- >>>
1946
- >>> # For medical imaging reconstruction
1947
- >>> decoder = brainstate.nn.ConvTranspose3d(
1948
- ... in_size=(16, 16, 16, 128),
1949
- ... out_channels=64,
1950
- ... kernel_size=(4, 4, 4),
1951
- ... stride=2,
1952
- ... padding='SAME',
1953
- ... b_init=brainstate.init.Constant(0.0)
1954
- ... )
1955
- >>>
1956
- >>> # Channels-first format (PyTorch style)
1957
- >>> conv_transpose = brainstate.nn.ConvTranspose3d(
1958
- ... in_size=(64, 8, 16, 16),
1959
- ... out_channels=32,
1960
- ... kernel_size=4,
1961
- ... stride=2,
1962
- ... channel_first=True
1963
- ... )
1964
- >>> x = jnp.ones((4, 64, 8, 16, 16))
1965
- >>> y = conv_transpose(x)
1966
-
1967
- Notes
1968
- -----
1969
- **Output dimensions:**
1970
-
1971
- Transposed convolution increases spatial dimensions:
1972
-
1973
- - output_size ≈ input_size * stride (exact size depends on padding and kernel size)
1974
- - 'SAME' padding: output_size = input_size * stride
1975
- - 'VALID' padding: output_size = input_size * stride + max(kernel_size - stride, 0)
1976
-
1977
- **Computational considerations:**
1978
-
1979
- 3D transposed convolutions are very computationally expensive. Consider:
1980
-
1981
- - Using grouped convolutions (groups > 1) to reduce parameters (see the count below)
1982
- - Smaller kernel sizes
1983
- - Progressive upsampling (multiple layers with stride=2)
1984
- - Separable convolutions for large-scale applications
1985
-
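- For scale, a rough weight count (bias ignored):
- 
- .. code-block:: python
- 
-     >>> 4**3 * 64 * 64            # 4x4x4 kernel, 64 -> 64 channels
-     262144
-     >>> 4**3 * (64 // 8) * 64     # the same kernel with groups=8
-     32768
- 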
1986
- **Common use cases:**
1987
-
1988
- - Video generation and prediction
1989
- - 3D medical image segmentation (U-Net 3D)
1990
- - Volumetric reconstruction
1991
- - 3D super-resolution
1992
- - Video frame interpolation
1993
- - 3D VAE decoders
1994
-
1995
- **Comparison with PyTorch:**
1996
-
1997
- - PyTorch uses channels-first by default; BrainState uses channels-last
1998
- - Set `channel_first=True` for PyTorch-compatible format
1999
- - Kernel shape convention differs between frameworks
2000
- - PyTorch's `output_padding` parameter is handled through the padding parameter here
2001
-
2002
- **Tips:**
2003
-
2004
- - Use kernel_size=stride*2 for smooth upsampling (e.g., kernel_size=4, stride=2)
2005
- - Group normalization often works better than batch normalization for 3D
2006
- - Consider using smaller batch sizes due to memory constraints
2007
- - Progressive upsampling (2x at a time) is more stable than large strides
2008
- """
2009
- __module__ = 'brainstate.nn'
2010
- num_spatial_dims: int = 3
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+ import collections.abc
19
+ from typing import Callable, Tuple, Union, Sequence, Optional, TypeVar
20
+
21
+ import brainunit as u
22
+ import jax
23
+ import jax.numpy as jnp
24
+
25
+ from brainstate._state import ParamState
26
+ from brainstate.typing import ArrayLike
27
+ from . import init as init
28
+ from ._module import Module
29
+ from ._normalizations import weight_standardization
30
+
31
+ T = TypeVar('T')
32
+
33
+ __all__ = [
34
+ 'Conv1d', 'Conv2d', 'Conv3d',
35
+ 'ScaledWSConv1d', 'ScaledWSConv2d', 'ScaledWSConv3d',
36
+ 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d',
37
+ ]
38
+
39
+
40
+ def to_dimension_numbers(
41
+ num_spatial_dims: int,
42
+ channels_last: bool,
43
+ transpose: bool
44
+ ) -> jax.lax.ConvDimensionNumbers:
45
+ """
46
+ Create a `lax.ConvDimensionNumbers` for the given inputs.
47
+
48
+ This function generates the dimension specification needed for JAX's convolution
49
+ operations based on the number of spatial dimensions and data format.
50
+
51
+ Parameters
52
+ ----------
53
+ num_spatial_dims : int
54
+ The number of spatial dimensions (e.g., 1 for Conv1d, 2 for Conv2d, 3 for Conv3d).
55
+ channels_last : bool
56
+
57
+ - If True, the input format is channels-last (e.g., [B, H, W, C] for 2D).
58
+ - If False, the input format is channels-first (e.g., [B, C, H, W] for 2D).
59
+ transpose : bool
60
+
61
+ - If True, creates dimension numbers for transposed convolution.
62
+ - If False, creates dimension numbers for standard convolution.
63
+
64
+ Returns
65
+ -------
66
+ jax.lax.ConvDimensionNumbers
67
+ A named tuple specifying the dimension layout for lhs (input), rhs (kernel),
68
+ and output of the convolution operation.
69
+
70
+ Examples
71
+ --------
72
+ .. code-block:: python
73
+
74
+ >>> # For 2D convolution with channels-last format
75
+ >>> dim_nums = to_dimension_numbers(num_spatial_dims=2, channels_last=True, transpose=False)
76
+ >>> print(dim_nums.lhs_spec) # Input layout: (batch, spatial_1, spatial_2, channel)
77
+ (0, 3, 1, 2)
78
+ """
79
+ num_dims = num_spatial_dims + 2
80
+ if channels_last:
81
+ spatial_dims = tuple(range(1, num_dims - 1))
82
+ image_dn = (0, num_dims - 1) + spatial_dims
83
+ else:
84
+ spatial_dims = tuple(range(2, num_dims))
85
+ image_dn = (0, 1) + spatial_dims
86
+ if transpose:
87
+ kernel_dn = (num_dims - 2, num_dims - 1) + tuple(range(num_dims - 2))
88
+ else:
89
+ kernel_dn = (num_dims - 1, num_dims - 2) + tuple(range(num_dims - 2))
90
+ return jax.lax.ConvDimensionNumbers(lhs_spec=image_dn,
91
+ rhs_spec=kernel_dn,
92
+ out_spec=image_dn)
93
+
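+ # Quick check of the transposed-kernel layout (illustrative): in the
+ # (spatial..., out, in) kernel, the output-channel axis precedes the
+ # input-channel axis, so for a 2D channels-last layer:
+ #
+ #     >>> to_dimension_numbers(2, channels_last=True, transpose=True).rhs_spec
+ #     (2, 3, 0, 1)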
94
+
95
+ def replicate(
96
+ element: Union[T, Sequence[T]],
97
+ num_replicate: int,
98
+ name: str,
99
+ ) -> Tuple[T, ...]:
100
+ """
101
+ Replicates `element` `num_replicate` times if needed.
102
+
103
+ This utility function ensures that parameters like kernel_size, stride, etc.
104
+ are properly formatted as tuples with the correct length for multi-dimensional
105
+ convolutions.
106
+
107
+ Parameters
108
+ ----------
109
+ element : T or Sequence[T]
110
+ The element to replicate. Can be a scalar, string, or sequence.
111
+ num_replicate : int
112
+ The number of times to replicate the element.
113
+ name : str
114
+ The name of the parameter (used for error messages).
115
+
116
+ Returns
117
+ -------
118
+ tuple of T
119
+ A tuple containing the replicated elements.
120
+
121
+ Raises
122
+ ------
123
+ TypeError
124
+ If the element is a sequence with length not equal to 1 or `num_replicate`.
125
+
126
+ Examples
127
+ --------
128
+ .. code-block:: python
129
+
130
+ >>> # Replicate a scalar value
131
+ >>> replicate(3, 2, 'kernel_size')
132
+ (3, 3)
133
+ >>>
134
+ >>> # Keep a sequence as is if already correct length
135
+ >>> replicate((3, 5), 2, 'kernel_size')
136
+ (3, 5)
137
+ >>>
138
+ >>> # Replicate a single-element sequence
139
+ >>> replicate([3], 2, 'kernel_size')
140
+ (3, 3)
141
+ """
142
+ if isinstance(element, (str, bytes)) or not isinstance(element, collections.abc.Sequence):
143
+ return (element,) * num_replicate
144
+ elif len(element) == 1:
145
+ return tuple(list(element) * num_replicate)
146
+ elif len(element) == num_replicate:
147
+ return tuple(element)
148
+ else:
149
+ raise TypeError(f"{name} must be a scalar or sequence of length 1 or "
150
+ f"sequence of length {num_replicate}.")
151
+
152
+
153
+ class _BaseConv(Module):
154
+ # the number of spatial dimensions
155
+ num_spatial_dims: int
156
+
157
+ # the weight and its operations
158
+ weight: ParamState
159
+
160
+ def __init__(
161
+ self,
162
+ in_size: Sequence[int],
163
+ out_channels: int,
164
+ kernel_size: Union[int, Tuple[int, ...]],
165
+ stride: Union[int, Tuple[int, ...]] = 1,
166
+ padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
167
+ lhs_dilation: Union[int, Tuple[int, ...]] = 1,
168
+ rhs_dilation: Union[int, Tuple[int, ...]] = 1,
169
+ groups: int = 1,
170
+ w_mask: Optional[Union[ArrayLike, Callable]] = None,
171
+ channel_first: bool = False,
172
+ name: str = None,
173
+ ):
174
+ super().__init__(name=name)
175
+
176
+ # general parameters
177
+ assert self.num_spatial_dims + 1 == len(in_size)
178
+ self.in_size = tuple(in_size)
179
+ self.channel_first = channel_first
180
+ self.channels_last = not channel_first
181
+
182
+ # Determine in_channels based on channel_first
183
+ if self.channel_first:
184
+ self.in_channels = in_size[0]
185
+ else:
186
+ self.in_channels = in_size[-1]
187
+
188
+ self.out_channels = out_channels
189
+ self.stride = replicate(stride, self.num_spatial_dims, 'stride')
190
+ self.kernel_size = replicate(kernel_size, self.num_spatial_dims, 'kernel_size')
191
+ self.lhs_dilation = replicate(lhs_dilation, self.num_spatial_dims, 'lhs_dilation')
192
+ self.rhs_dilation = replicate(rhs_dilation, self.num_spatial_dims, 'rhs_dilation')
193
+ self.groups = groups
194
+ self.dimension_numbers = to_dimension_numbers(
195
+ self.num_spatial_dims,
196
+ channels_last=self.channels_last,
197
+ transpose=False
198
+ )
199
+
200
+ # the padding parameter
201
+ if isinstance(padding, str):
202
+ assert padding in ['SAME', 'VALID']
203
+ elif isinstance(padding, int):
204
+ padding = tuple((padding, padding) for _ in range(self.num_spatial_dims))
205
+ elif isinstance(padding, (tuple, list)):
206
+ if isinstance(padding[0], int):
207
+ padding = (padding,) * self.num_spatial_dims
208
+ elif isinstance(padding[0], (tuple, list)):
209
+ if len(padding) == 1:
210
+ padding = tuple(padding) * self.num_spatial_dims
211
+ else:
212
+ if len(padding) != self.num_spatial_dims:
213
+ raise ValueError(
214
+ f"Padding {padding} must be a Tuple[int, int], "
215
+ f"or sequence of Tuple[int, int] with length 1, "
216
+ f"or sequence of Tuple[int, int] with length {self.num_spatial_dims}."
217
+ )
218
+ padding = tuple(padding)
219
+ else:
220
+ raise ValueError(f"Unsupported padding specification: {padding!r}.")
221
+ self.padding = padding
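+ # After normalization (illustrative): for a 2D layer, padding=1 becomes
+ # ((1, 1), (1, 1)) and padding=(1, 2) becomes ((1, 2), (1, 2)), while
+ # 'SAME'/'VALID' are kept as strings and resolved by JAX.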
222
+
223
+ # the number of in-/out-channels
224
+ assert self.out_channels % self.groups == 0, '"out_channels" should be divisible by groups'
225
+ assert self.in_channels % self.groups == 0, '"in_channels" should be divisible by groups'
226
+
227
+ # kernel shape and w_mask
228
+ kernel_shape = tuple(self.kernel_size) + (self.in_channels // self.groups, self.out_channels)
229
+ self.kernel_shape = kernel_shape
230
+ self.w_mask = init.param(w_mask, kernel_shape, allow_none=True)
231
+
232
+ def _check_input_dim(self, x):
233
+ if x.ndim == self.num_spatial_dims + 2:
234
+ x_shape = x.shape[1:]
235
+ elif x.ndim == self.num_spatial_dims + 1:
236
+ x_shape = x.shape
237
+ else:
238
+ raise ValueError(f"expected {self.num_spatial_dims + 2}D (with batch) or "
239
+ f"{self.num_spatial_dims + 1}D (without batch) input (got {x.ndim}D input, {x.shape})")
240
+
241
+ # Check the shape against the expected ``in_size``, which is already
+ # stored in the layout implied by ``channel_first``: (C, spatial...)
+ # when True, (spatial..., C) otherwise.
+ expected_shape = self.in_size
248
+
249
+ if expected_shape != x_shape:
250
+ raise ValueError(f"The expected input shape is {expected_shape}, while we got {x_shape}.")
251
+
252
+ def update(self, x):
253
+ self._check_input_dim(x)
254
+ non_batching = False
255
+ if x.ndim == self.num_spatial_dims + 1:
256
+ x = u.math.expand_dims(x, 0)
257
+ non_batching = True
258
+ y = self._conv_op(x, self.weight.value)
259
+ return u.math.squeeze(y, axis=0) if non_batching else y
260
+
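+ # ``update`` above therefore accepts batched and unbatched inputs alike:
+ # an unbatched array gains a singleton batch axis, is convolved, and is
+ # squeezed back, e.g. (28, 3) -> (1, 28, 3) -> (1, 28, 16) -> (28, 16).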
261
+ def _conv_op(self, x, params):
262
+ raise NotImplementedError
263
+
264
+
265
+ class _Conv(_BaseConv):
266
+ num_spatial_dims: int = None
267
+
268
+ def __init__(
269
+ self,
270
+ in_size: Sequence[int],
271
+ out_channels: int,
272
+ kernel_size: Union[int, Tuple[int, ...]],
273
+ stride: Union[int, Tuple[int, ...]] = 1,
274
+ padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
275
+ lhs_dilation: Union[int, Tuple[int, ...]] = 1,
276
+ rhs_dilation: Union[int, Tuple[int, ...]] = 1,
277
+ groups: int = 1,
278
+ w_init: Union[Callable, ArrayLike] = init.XavierNormal(),
279
+ b_init: Optional[Union[Callable, ArrayLike]] = None,
280
+ w_mask: Optional[Union[ArrayLike, Callable]] = None,
281
+ channel_first: bool = False,
282
+ name: str = None,
283
+ param_type: type = ParamState,
284
+ ):
285
+ super().__init__(
286
+ in_size=in_size,
287
+ out_channels=out_channels,
288
+ kernel_size=kernel_size,
289
+ stride=stride,
290
+ padding=padding,
291
+ lhs_dilation=lhs_dilation,
292
+ rhs_dilation=rhs_dilation,
293
+ groups=groups,
294
+ w_mask=w_mask,
295
+ channel_first=channel_first,
296
+ name=name
297
+ )
298
+
299
+ self.w_initializer = w_init
300
+ self.b_initializer = b_init
301
+
302
+ # --- weights --- #
303
+ weight = init.param(self.w_initializer, self.kernel_shape, allow_none=False)
304
+ params = dict(weight=weight)
305
+ if self.b_initializer is not None:
306
+ # Place the bias on the channel axis implied by the data layout; a
+ # trailing channel axis would broadcast incorrectly for channels-first outputs
+ if self.channel_first:
+ bias_shape = (self.out_channels,) + (1,) * len(self.kernel_size)
+ else:
+ bias_shape = (1,) * len(self.kernel_size) + (self.out_channels,)
307
+ bias = init.param(self.b_initializer, bias_shape, allow_none=True)
308
+ params['bias'] = bias
309
+
310
+ # The weight operation
311
+ self.weight = param_type(params)
312
+
313
+ # Evaluate the output shape
314
+ test_input_shape = (128,) + self.in_size
315
+ abstract_y = jax.eval_shape(
316
+ self._conv_op,
317
+ jax.ShapeDtypeStruct(test_input_shape, weight.dtype),
318
+ params
319
+ )
320
+ y_shape = abstract_y.shape[1:]
321
+ self.out_size = y_shape
322
+
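+ # Note on the shape evaluation above: ``jax.eval_shape`` traces the op
+ # abstractly, so the dummy batch of 128 never allocates memory. A
+ # standalone sketch:
+ #
+ #     >>> import jax, jax.numpy as jnp
+ #     >>> jax.eval_shape(lambda a: a.sum(axis=0),
+ #     ...                jax.ShapeDtypeStruct((128, 4), jnp.float32)).shape
+ #     (4,)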
323
+ def _conv_op(self, x, params):
324
+ w = params['weight']
325
+ if self.w_mask is not None:
326
+ w = w * self.w_mask
327
+ y = jax.lax.conv_general_dilated(
328
+ lhs=x,
329
+ rhs=w,
330
+ window_strides=self.stride,
331
+ padding=self.padding,
332
+ lhs_dilation=self.lhs_dilation,
333
+ rhs_dilation=self.rhs_dilation,
334
+ feature_group_count=self.groups,
335
+ dimension_numbers=self.dimension_numbers
336
+ )
337
+ if 'bias' in params:
338
+ y = y + params['bias']
339
+ return y
340
+
341
+
342
+ class Conv1d(_Conv):
343
+ """
344
+ One-dimensional convolution layer.
345
+
346
+ Applies a 1D convolution over an input signal composed of several input planes.
347
+ The input should be a 3D array of shape ``[B, L, C]``, where B is batch size,
348
+ L is the sequence length, and C is the number of input channels.
349
+
350
+ This layer creates a convolution kernel that is convolved with the layer input
351
+ over a single spatial dimension to produce a tensor of outputs.
352
+
353
+ Parameters
354
+ ----------
355
+ in_size : tuple of int
356
+ The input shape without the batch dimension: (L, C) where L is the sequence
357
+ length and C is the number of input channels. It is used to evaluate the output shape.
358
+ out_channels : int
359
+ The number of output channels (also called filters or feature maps).
360
+ kernel_size : int or tuple of int
361
+ The shape of the convolutional kernel. For 1D convolution, the kernel size can be
362
+ passed as an integer. For 2D and 3D convolutions, it should be a tuple of integers
363
+ or a single integer (which will be replicated for all spatial dimensions).
364
+ stride : int or tuple of int, optional
365
+ The stride of the convolution. An integer or a sequence of `n` integers, representing
366
+ the inter-window strides along each spatial dimension. Default: 1.
367
+ padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
368
+ The padding strategy. Can be:
369
+
370
+ - 'SAME': pads the input so the output has the same shape as input when stride=1
371
+ - 'VALID': no padding
372
+ - int: symmetric padding applied to all spatial dimensions
373
+ - tuple of (low, high): padding for each dimension
374
+ - sequence of tuples: explicit padding for each spatial dimension
375
+
376
+ Default: 'SAME'.
377
+ lhs_dilation : int or tuple of int, optional
378
+ The dilation factor for the input. An integer or a sequence of `n` integers, giving
379
+ the dilation factor to apply in each spatial dimension of inputs. Convolution with
380
+ input dilation `d` is equivalent to transposed convolution with stride `d`.
381
+ Default: 1.
382
+ rhs_dilation : int or tuple of int, optional
383
+ The dilation factor for the kernel. An integer or a sequence of `n` integers, giving
384
+ the dilation factor to apply in each spatial dimension of the convolution kernel.
385
+ Convolution with kernel dilation is also known as 'atrous convolution', which increases
386
+ the receptive field without increasing the number of parameters. Default: 1.
387
+ groups : int, optional
388
+ Number of groups for grouped convolution. Controls the connections between inputs and
389
+ outputs. Both `in_channels` and `out_channels` must be divisible by `groups`. When
390
+ groups=1 (default), all inputs are convolved to all outputs. When groups>1, the input
391
+ and output channels are divided into groups, and each group is convolved independently.
392
+ When groups=in_channels, this becomes a depthwise convolution. Default: 1.
393
+ w_init : Callable or ArrayLike, optional
394
+ The initializer for the convolutional kernel weights. Can be an initializer instance
395
+ or a direct array. Default: XavierNormal().
396
+ b_init : Callable or ArrayLike or None, optional
397
+ The initializer for the bias. If None, no bias is added. Default: None.
398
+ w_mask : ArrayLike or Callable or None, optional
399
+ An optional mask applied to the weights during forward pass. Useful for implementing
400
+ structured sparsity or custom connectivity patterns. Default: None.
401
+ name : str, optional
402
+ The name of the module. Default: None.
403
+ param_type : type, optional
404
+ The type of parameter state to use. Default: ParamState.
405
+
406
+ Attributes
407
+ ----------
408
+ in_size : tuple of int
409
+ The input shape (L, C) without batch dimension.
410
+ out_size : tuple of int
411
+ The output shape (L_out, out_channels) without batch dimension.
412
+ in_channels : int
413
+ Number of input channels.
414
+ out_channels : int
415
+ Number of output channels.
416
+ kernel_size : tuple of int
417
+ Size of the convolving kernel.
418
+ weight : ParamState
419
+ The learnable weights (and bias if specified) of the module.
420
+
421
+ Examples
422
+ --------
423
+ .. code-block:: python
424
+
425
+ >>> import brainstate
426
+ >>> import jax.numpy as jnp
427
+ >>>
428
+ >>> # Create a 1D convolution layer
429
+ >>> conv = brainstate.nn.Conv1d(in_size=(28, 3), out_channels=16, kernel_size=5)
430
+ >>>
431
+ >>> # Apply to input: batch_size=2, length=28, channels=3
432
+ >>> x = jnp.ones((2, 28, 3))
433
+ >>> y = conv(x)
434
+ >>> print(y.shape) # (2, 28, 16) with 'SAME' padding
435
+ >>>
436
+ >>> # Without batch dimension
437
+ >>> x_single = jnp.ones((28, 3))
438
+ >>> y_single = conv(x_single)
439
+ >>> print(y_single.shape) # (28, 16)
440
+ >>>
441
+ >>> # With custom parameters
442
+ >>> conv = brainstate.nn.Conv1d(
443
+ ... in_size=(100, 8),
444
+ ... out_channels=32,
445
+ ... kernel_size=3,
446
+ ... stride=2,
447
+ ... padding='VALID',
448
+ ... b_init=brainstate.init.ZeroInit()
449
+ ... )
450
+
451
+ Notes
452
+ -----
453
+ **Output dimensions:**
454
+
455
+ The output shape depends on the padding mode:
456
+
457
+ - 'SAME': output length = ceil(input_length / stride)
458
+ - 'VALID': output length = ceil((input_length - kernel_size + 1) / stride)
459
+
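+ A quick check for length 28, kernel 5, stride 2:
+ 
+ .. code-block:: python
+ 
+     >>> import math
+     >>> math.ceil(28 / 2)              # 'SAME'
+     14
+     >>> math.ceil((28 - 5 + 1) / 2)    # 'VALID'
+     12
+ 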
460
+ **Grouped convolution:**
461
+
462
+ When groups > 1, the convolution becomes a grouped convolution where input and
463
+ output channels are divided into groups, reducing computational cost.
464
+ """
465
+ __module__ = 'brainstate.nn'
466
+ num_spatial_dims: int = 1
467
+
468
+
469
+ class Conv2d(_Conv):
470
+ """
471
+ Two-dimensional convolution layer.
472
+
473
+ Applies a 2D convolution over an input signal composed of several input planes.
474
+ The input should be a 4D array of shape ``[B, H, W, C]``, where B is batch size,
475
+ H is height, W is width, and C is the number of input channels (channels-last format).
476
+
477
+ This layer creates a convolution kernel that is convolved with the layer input
478
+ to produce a tensor of outputs. It is commonly used in computer vision tasks.
479
+
480
+ Parameters
481
+ ----------
482
+ in_size : tuple of int
483
+ The input shape without the batch dimension. For Conv2d: (H, W, C) where H is height,
484
+ W is width, and C is the number of input channels. This argument is important as it is
485
+ used to evaluate the output shape.
486
+ out_channels : int
487
+ The number of output channels (also called filters or feature maps). These determine
488
+ the depth of the output feature map.
489
+ kernel_size : int or tuple of int
490
+ The shape of the convolutional kernel. Can be:
491
+
492
+ - An integer (e.g., 3): creates a square kernel (3, 3)
493
+ - A tuple of two integers (e.g., (3, 5)): creates a (height, width) kernel
494
+ stride : int or tuple of int, optional
495
+ The stride of the convolution. Controls how much the kernel moves at each step.
496
+ Can be:
497
+
498
+ - An integer: same stride for both dimensions
499
+ - A tuple of two integers: (stride_height, stride_width)
500
+
501
+ Default: 1.
502
+ padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
503
+ The padding strategy. Options:
504
+
505
+ - 'SAME': output spatial size equals input size when stride=1
506
+ - 'VALID': no padding, output size reduced by kernel size
507
+ - int: same symmetric padding for all dimensions
508
+ - (pad_h, pad_w): different padding for each dimension
509
+ - [(pad_h_before, pad_h_after), (pad_w_before, pad_w_after)]: explicit padding
510
+
511
+ Default: 'SAME'.
512
+ lhs_dilation : int or tuple of int, optional
513
+ The dilation factor for the input (left-hand side). Controls spacing between input elements.
514
+ A value > 1 inserts zeros between input elements, equivalent to transposed convolution.
515
+ Default: 1.
516
+ rhs_dilation : int or tuple of int, optional
517
+ The dilation factor for the kernel (right-hand side). Also known as atrous convolution
518
+ or dilated convolution. Increases the receptive field without increasing parameters by
519
+ inserting zeros between kernel elements. Useful for capturing multi-scale context.
520
+ Default: 1.
521
+ groups : int, optional
522
+ Number of groups for grouped convolution. Must divide both `in_channels` and `out_channels`.
523
+
524
+ - groups=1: standard convolution (all-to-all connections)
525
+ - groups>1: grouped convolution (reduces parameters by factor of groups)
526
+ - groups=in_channels: depthwise convolution (each input channel convolved separately)
527
+
528
+ Default: 1.
529
+ w_init : Callable or ArrayLike, optional
530
+ Weight initializer for the convolutional kernel. Can be:
531
+
532
+ - An initializer instance (e.g., brainstate.init.XavierNormal())
533
+ - A callable that returns an array given a shape
534
+ - A direct array matching the kernel shape
535
+
536
+ Default: XavierNormal().
537
+ b_init : Callable or ArrayLike or None, optional
538
+ Bias initializer. If None, no bias term is added to the output.
539
+ Default: None.
540
+ w_mask : ArrayLike or Callable or None, optional
541
+ Optional weight mask for structured sparsity or custom connectivity. The mask is
542
+ element-wise multiplied with the kernel weights during the forward pass.
543
+ Default: None.
544
+ name : str, optional
545
+ Name identifier for this module instance.
546
+ Default: None.
547
+ param_type : type, optional
548
+ The parameter state class to use for managing learnable parameters.
549
+ Default: ParamState.
550
+
551
+ Attributes
552
+ ----------
553
+ in_size : tuple of int
554
+ The input shape (H, W, C) without batch dimension.
555
+ out_size : tuple of int
556
+ The output shape (H_out, W_out, out_channels) without batch dimension.
557
+ in_channels : int
558
+ Number of input channels.
559
+ out_channels : int
560
+ Number of output channels.
561
+ kernel_size : tuple of int
562
+ Size of the convolving kernel (height, width).
563
+ weight : ParamState
564
+ The learnable weights (and bias if specified) of the module.
565
+
566
+ Examples
567
+ --------
568
+ .. code-block:: python
569
+
570
+ >>> import brainstate
571
+ >>> import jax.numpy as jnp
572
+ >>>
573
+ >>> # Create a 2D convolution layer
574
+ >>> conv = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=64, kernel_size=3)
575
+ >>>
576
+ >>> # Apply to input: batch_size=8, height=32, width=32, channels=3
577
+ >>> x = jnp.ones((8, 32, 32, 3))
578
+ >>> y = conv(x)
579
+ >>> print(y.shape) # (8, 32, 32, 64) with 'SAME' padding
580
+ >>>
581
+ >>> # Without batch dimension
582
+ >>> x_single = jnp.ones((32, 32, 3))
583
+ >>> y_single = conv(x_single)
584
+ >>> print(y_single.shape) # (32, 32, 64)
585
+ >>>
586
+ >>> # With custom kernel size and stride
587
+ >>> conv = brainstate.nn.Conv2d(
588
+ ... in_size=(224, 224, 3),
589
+ ... out_channels=128,
590
+ ... kernel_size=(5, 5),
591
+ ... stride=2,
592
+ ... padding='VALID'
593
+ ... )
594
+ >>>
595
+ >>> # Depthwise convolution (groups = in_channels)
596
+ >>> conv = brainstate.nn.Conv2d(
597
+ ... in_size=(64, 64, 32),
598
+ ... out_channels=32,
599
+ ... kernel_size=3,
600
+ ... groups=32
601
+ ... )
602
+
603
+ Notes
604
+ -----
605
+ **Output dimensions:**
606
+
607
+ The output spatial dimensions depend on the padding mode:
608
+
609
+ - 'SAME': output_size = ceil(input_size / stride)
610
+ - 'VALID': output_size = ceil((input_size - kernel_size + 1) / stride)
611
+
612
+ **Grouped convolution:**
613
+
614
+ When groups > 1, the input and output channels are divided into groups.
615
+ Each group is convolved independently, which can significantly reduce
616
+ computational cost while maintaining representational power.
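+ 
+ As a rough weight count (bias ignored):
+ 
+ .. code-block:: python
+ 
+     >>> 3 * 3 * 64 * 64           # 3x3 conv, 64 -> 64 channels, groups=1
+     36864
+     >>> 3 * 3 * (64 // 4) * 64    # the same conv with groups=4
+     9216
+ 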
617
+ """
618
+ __module__ = 'brainstate.nn'
619
+ num_spatial_dims: int = 2
620
+
621
+
622
+ class Conv3d(_Conv):
623
+ """
624
+ Three-dimensional convolution layer.
625
+
626
+ Applies a 3D convolution over an input signal composed of several input planes.
627
+ The input should be a 5D array of shape ``[B, H, W, D, C]``, where B is batch size,
628
+ H is height, W is width, D is depth, and C is the number of input channels (channels-last format).
629
+
630
+ This layer is commonly used for processing 3D data such as video sequences or
631
+ volumetric medical imaging data.
632
+
633
+ Parameters
634
+ ----------
635
+ in_size : tuple of int
636
+ The input shape without the batch dimension. For Conv3d: (H, W, D, C) where H is height,
637
+ W is width, D is depth, and C is the number of input channels. This argument is important
638
+ as it is used to evaluate the output shape.
639
+ out_channels : int
640
+ The number of output channels (also called filters or feature maps). These determine
641
+ the depth of the output feature map.
642
+ kernel_size : int or tuple of int
643
+ The shape of the convolutional kernel. Can be:
644
+
645
+ - An integer (e.g., 3): creates a cubic kernel (3, 3, 3)
646
+ - A tuple of three integers (e.g., (3, 5, 5)): creates a (height, width, depth) kernel
647
+ stride : int or tuple of int, optional
648
+ The stride of the convolution. Controls how much the kernel moves at each step.
649
+ Can be:
650
+
651
+ - An integer: same stride for all dimensions
652
+ - A tuple of three integers: (stride_h, stride_w, stride_d)
653
+ 
+ Default: 1.
654
+ padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
655
+ The padding strategy. Options:
656
+
657
+ - 'SAME': output spatial size equals input size when stride=1
658
+ - 'VALID': no padding, output size reduced by kernel size
659
+ - int: same symmetric padding for all dimensions
660
+ - (pad_h, pad_w, pad_d): different padding for each dimension
661
+ - [(pad_h_before, pad_h_after), (pad_w_before, pad_w_after), (pad_d_before, pad_d_after)]: explicit padding
662
+
663
+ Default: 'SAME'.
664
+ lhs_dilation : int or tuple of int, optional
665
+ The dilation factor for the input (left-hand side). Controls spacing between input elements.
666
+ A value > 1 inserts zeros between input elements, equivalent to transposed convolution.
667
+ Default: 1.
668
+ rhs_dilation : int or tuple of int, optional
669
+ The dilation factor for the kernel (right-hand side). Also known as atrous convolution
670
+ or dilated convolution. Increases the receptive field without increasing parameters by
671
+ inserting zeros between kernel elements. Particularly useful for 3D data to capture
672
+ larger temporal/spatial context.
673
+ Default: 1.
674
+ groups : int, optional
675
+ Number of groups for grouped convolution. Must divide both `in_channels` and `out_channels`.
676
+
677
+ - groups=1: standard convolution (all-to-all connections)
678
+ - groups>1: grouped convolution (significantly reduces parameters and computation for 3D)
679
+ - groups=in_channels: depthwise convolution (each input channel convolved separately)
680
+
681
+ Default: 1.
682
+ w_init : Callable or ArrayLike, optional
683
+ Weight initializer for the convolutional kernel. Can be:
684
+
685
+ - An initializer instance (e.g., brainstate.init.XavierNormal())
686
+ - A callable that returns an array given a shape
687
+ - A direct array matching the kernel shape
688
+
689
+ Default: XavierNormal().
690
+ b_init : Callable or ArrayLike or None, optional
691
+ Bias initializer. If None, no bias term is added to the output.
692
+ Default: None.
693
+ w_mask : ArrayLike or Callable or None, optional
694
+ Optional weight mask for structured sparsity or custom connectivity. The mask is
695
+ element-wise multiplied with the kernel weights during the forward pass.
696
+ Default: None.
697
+ name : str, optional
698
+ Name identifier for this module instance.
699
+ Default: None.
700
+ param_type : type, optional
701
+ The parameter state class to use for managing learnable parameters.
702
+ Default: ParamState.
703
+
704
+ Attributes
705
+ ----------
706
+ in_size : tuple of int
707
+ The input shape (H, W, D, C) without batch dimension.
708
+ out_size : tuple of int
709
+ The output shape (H_out, W_out, D_out, out_channels) without batch dimension.
710
+ in_channels : int
711
+ Number of input channels.
712
+ out_channels : int
713
+ Number of output channels.
714
+ kernel_size : tuple of int
715
+ Size of the convolving kernel (height, width, depth).
716
+ weight : ParamState
717
+ The learnable weights (and bias if specified) of the module.
718
+
719
+ Examples
720
+ --------
721
+ .. code-block:: python
722
+
723
+ >>> import brainstate
724
+ >>> import jax.numpy as jnp
725
+ >>>
726
+ >>> # Create a 3D convolution layer for video data
727
+ >>> conv = brainstate.nn.Conv3d(in_size=(16, 64, 64, 3), out_channels=32, kernel_size=3)
728
+ >>>
729
+ >>> # Apply to input: batch_size=4, frames=16, height=64, width=64, channels=3
730
+ >>> x = jnp.ones((4, 16, 64, 64, 3))
731
+ >>> y = conv(x)
732
+ >>> print(y.shape) # (4, 16, 64, 64, 32) with 'SAME' padding
733
+ >>>
734
+ >>> # Without batch dimension
735
+ >>> x_single = jnp.ones((16, 64, 64, 3))
736
+ >>> y_single = conv(x_single)
737
+ >>> print(y_single.shape) # (16, 64, 64, 32)
738
+ >>>
739
+ >>> # For medical imaging with custom parameters
740
+ >>> conv = brainstate.nn.Conv3d(
741
+ ... in_size=(32, 32, 32, 1),
742
+ ... out_channels=64,
743
+ ... kernel_size=(3, 3, 3),
744
+ ... stride=2,
745
+ ... padding='VALID',
746
+ ... b_init=brainstate.init.Constant(0.1)
747
+ ... )
748
+
749
+ Notes
750
+ -----
751
+ **Output dimensions:**
752
+
753
+ The output spatial dimensions depend on the padding mode:
754
+
755
+ - 'SAME': output_size = ceil(input_size / stride)
756
+ - 'VALID': output_size = ceil((input_size - kernel_size + 1) / stride)
757
+
758
+ **Performance considerations:**
759
+
760
+ 3D convolutions are computationally expensive. Consider using:
761
+
762
+ - Smaller kernel sizes
763
+ - Grouped convolutions (groups > 1; see the sketch below)
764
+ - Separable convolutions for large-scale applications
765
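+
+     As a rough illustration of the savings from grouped 3D convolution (a
+     sketch based on the kernel layout used here, which stores
+     ``prod(kernel_size) * (in_channels // groups) * out_channels`` weights):
+
+     .. code-block:: python
+
+         >>> # 3x3x3 kernel, 64 -> 64 channels
+         >>> print(3 * 3 * 3 * 64 * 64)         # standard conv: 110592 weights
+         >>> print(3 * 3 * 3 * (64 // 8) * 64)  # groups=8: 13824 weights, 8x fewer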
+ """
766
+ __module__ = 'brainstate.nn'
767
+ num_spatial_dims: int = 3
768
+
769
+
770
+ class _ScaledWSConv(_BaseConv):
771
+ def __init__(
772
+ self,
773
+ in_size: Sequence[int],
774
+ out_channels: int,
775
+ kernel_size: Union[int, Tuple[int, ...]],
776
+ stride: Union[int, Tuple[int, ...]] = 1,
777
+ padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
778
+ lhs_dilation: Union[int, Tuple[int, ...]] = 1,
779
+ rhs_dilation: Union[int, Tuple[int, ...]] = 1,
780
+ groups: int = 1,
781
+ ws_gain: bool = True,
782
+ eps: float = 1e-4,
783
+ w_init: Union[Callable, ArrayLike] = init.XavierNormal(),
784
+ b_init: Optional[Union[Callable, ArrayLike]] = None,
785
+ w_mask: Optional[Union[ArrayLike, Callable]] = None,
786
+ channel_first: bool = False,
787
+ name: str = None,
788
+ param_type: type = ParamState,
789
+ ):
790
+ super().__init__(
791
+ in_size=in_size,
792
+ out_channels=out_channels,
793
+ kernel_size=kernel_size,
794
+ stride=stride,
795
+ padding=padding,
796
+ lhs_dilation=lhs_dilation,
797
+ rhs_dilation=rhs_dilation,
798
+ groups=groups,
799
+ w_mask=w_mask,
800
+ channel_first=channel_first,
801
+ name=name,
802
+ )
803
+
804
+ self.w_initializer = w_init
805
+ self.b_initializer = b_init
806
+
807
+ # --- weights --- #
808
+ weight = init.param(self.w_initializer, self.kernel_shape, allow_none=False)
809
+ params = dict(weight=weight)
810
+ if self.b_initializer is not None:
811
+ bias_shape = (1,) * len(self.kernel_size) + (self.out_channels,)
812
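+ # bias has shape (1, ..., 1, out_channels), so it broadcasts over the
+ # batch and spatial axes of the channels-last output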
+ bias = init.param(self.b_initializer, bias_shape, allow_none=True)
813
+ params['bias'] = bias
814
+
815
+ # gain
816
+ if ws_gain:
+ gain_size = (1,) * len(self.kernel_size) + (1, self.out_channels)
+ # store the gain array under its own name to avoid shadowing the `ws_gain` flag
+ params['gain'] = jnp.ones(gain_size, dtype=params['weight'].dtype)
820
+
821
+ # Epsilon, a small constant to avoid dividing by zero.
822
+ self.eps = eps
823
+
824
+ # The weight operation
825
+ self.weight = param_type(params)
826
+
827
+ # Evaluate the output shape
828
+ # Both layouts prepend the batch axis, so no branch on channel_first is needed.
+ # 128 is an arbitrary dummy batch size; jax.eval_shape only traces shapes.
+ test_input_shape = (128,) + self.in_size
832
+
833
+ abstract_y = jax.eval_shape(
834
+ self._conv_op,
835
+ jax.ShapeDtypeStruct(test_input_shape, weight.dtype),
836
+ params
837
+ )
838
+ y_shape = abstract_y.shape[1:]
839
+ self.out_size = y_shape
840
+
841
+ def _conv_op(self, x, params):
842
+ w = params['weight']
843
+ w = weight_standardization(w, self.eps, params.get('gain', None))
844
+ if self.w_mask is not None:
845
+ w = w * self.w_mask
846
+ y = jax.lax.conv_general_dilated(
847
+ lhs=x,
848
+ rhs=w,
849
+ window_strides=self.stride,
850
+ padding=self.padding,
851
+ lhs_dilation=self.lhs_dilation,
852
+ rhs_dilation=self.rhs_dilation,
853
+ feature_group_count=self.groups,
854
+ dimension_numbers=self.dimension_numbers
855
+ )
856
+ if 'bias' in params:
857
+ y = y + params['bias']
858
+ return y
859
+
860
+
861
+ class ScaledWSConv1d(_ScaledWSConv):
862
+ """
863
+ One-dimensional convolution with weight standardization.
864
+
865
+ This layer applies weight standardization to the convolutional kernel before
866
+ performing the convolution operation. Weight standardization normalizes the
867
+ weights to have zero mean and unit variance, which can accelerate training
868
+ and improve model performance, especially when combined with group normalization.
869
+
870
+ The input should be a 3D array with the shape of ``[B, L, C]`` where B is batch size,
871
+ L is the sequence length, and C is the number of input channels.
872
+
873
+ Parameters
874
+ ----------
875
+ in_size : tuple of int
876
+ The input shape without the batch dimension. For Conv1d: (L, C) where L is the sequence
877
+ length and C is the number of input channels. This argument is important as it is used
878
+ to evaluate the output shape.
879
+ out_channels : int
880
+ The number of output channels (also called filters or feature maps). These determine
881
+ the depth of the output feature map.
882
+ kernel_size : int or tuple of int
883
+ The shape of the convolutional kernel. For 1D convolution, can be:
884
+
885
+ - An integer (e.g., 5): creates a kernel of size 5
886
+ - A tuple with one integer (e.g., (5,)): equivalent to the above
887
+ stride : int or tuple of int, optional
888
+ The stride of the convolution. Controls how much the kernel moves at each step.
889
+ Default: 1.
890
+ padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
891
+ The padding strategy. Options:
892
+
893
+ - 'SAME': output length equals input length when stride=1
894
+ - 'VALID': no padding, output length reduced by kernel size
895
+ - int: symmetric padding
896
+ - (pad_before, pad_after): explicit padding for the sequence dimension
897
+
898
+ Default: 'SAME'.
899
+ lhs_dilation : int or tuple of int, optional
900
+ The dilation factor for the input (left-hand side). Controls spacing between input elements.
901
+ A value > 1 inserts zeros between input elements, equivalent to transposed convolution.
902
+ Default: 1.
903
+ rhs_dilation : int or tuple of int, optional
904
+ The dilation factor for the kernel (right-hand side). Also known as atrous convolution
905
+ or dilated convolution. Increases the receptive field without increasing parameters by
906
+ inserting zeros between kernel elements. Useful for capturing long-range dependencies.
907
+ Default: 1.
908
+ groups : int, optional
909
+ Number of groups for grouped convolution. Must divide both `in_channels` and `out_channels`.
910
+
911
+ - groups=1: standard convolution (all-to-all connections)
912
+ - groups>1: grouped convolution (reduces parameters by factor of groups)
913
+ - groups=in_channels: depthwise convolution (each input channel convolved separately)
914
+
915
+ Default: 1.
916
+ w_init : Callable or ArrayLike, optional
917
+ Weight initializer for the convolutional kernel. Can be:
918
+
919
+ - An initializer instance (e.g., brainstate.init.XavierNormal())
920
+ - A callable that returns an array given a shape
921
+ - A direct array matching the kernel shape
922
+
923
+ Default: XavierNormal().
924
+ b_init : Callable or ArrayLike or None, optional
925
+ Bias initializer. If None, no bias term is added to the output.
926
+ Default: None.
927
+ ws_gain : bool, optional
928
+ Whether to include a learnable per-channel gain parameter in weight standardization.
929
+ When True, adds a scaling factor that can be learned during training, improving
930
+ model expressiveness. Recommended for most applications.
931
+ Default: True.
932
+ eps : float, optional
933
+ Small constant for numerical stability in weight standardization. Prevents division
934
+ by zero when computing weight standard deviation. Typical values: 1e-4 to 1e-5.
935
+ Default: 1e-4.
936
+ w_mask : ArrayLike or Callable or None, optional
937
+ Optional weight mask for structured sparsity or custom connectivity. The mask is
938
+ element-wise multiplied with the standardized kernel weights during the forward pass.
939
+ Default: None.
940
+ name : str, optional
941
+ Name identifier for this module instance.
942
+ Default: None.
943
+ param_type : type, optional
944
+ The parameter state class to use for managing learnable parameters.
945
+ Default: ParamState.
946
+
947
+ Attributes
948
+ ----------
949
+ in_size : tuple of int
950
+ The input shape (L, C) without batch dimension.
951
+ out_size : tuple of int
952
+ The output shape (L_out, out_channels) without batch dimension.
953
+ in_channels : int
954
+ Number of input channels.
955
+ out_channels : int
956
+ Number of output channels.
957
+ kernel_size : tuple of int
958
+ Size of the convolving kernel.
959
+ weight : ParamState
960
+ The learnable weights (and bias if specified) of the module.
961
+ eps : float
962
+ Small constant for numerical stability in weight standardization.
963
+
964
+ Examples
965
+ --------
966
+ .. code-block:: python
967
+
968
+ >>> import brainstate
969
+ >>> import jax.numpy as jnp
970
+ >>>
971
+ >>> # Create a 1D convolution with weight standardization
972
+ >>> conv = brainstate.nn.ScaledWSConv1d(
973
+ ... in_size=(100, 16),
974
+ ... out_channels=32,
975
+ ... kernel_size=5
976
+ ... )
977
+ >>>
978
+ >>> # Apply to input
979
+ >>> x = jnp.ones((4, 100, 16))
980
+ >>> y = conv(x)
981
+ >>> print(y.shape) # (4, 100, 32)
982
+ >>>
983
+ >>> # With custom epsilon and no gain
984
+ >>> conv = brainstate.nn.ScaledWSConv1d(
985
+ ... in_size=(50, 8),
986
+ ... out_channels=16,
987
+ ... kernel_size=3,
988
+ ... ws_gain=False,
989
+ ... eps=1e-5
990
+ ... )
991
+
992
+ Notes
993
+ -----
994
+ **Weight standardization formula:**
995
+
996
+ Weight standardization reparameterizes the convolutional weights as:
997
+
998
+ .. math::
999
+ \\hat{W} = g \\cdot \\frac{W - \\mu_W}{\\sigma_W + \\epsilon}
1000
+
1001
+ where :math:`\\mu_W` and :math:`\\sigma_W` are the mean and standard deviation
1002
+ of the weights, :math:`g` is a learnable gain parameter (if ws_gain=True),
1003
+ and :math:`\\epsilon` is a small constant for numerical stability.
1004
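+
+     A minimal sketch of the standardization step itself (assuming statistics
+     are computed over all axes except the output-channel axis, following [1]_):
+
+     .. code-block:: python
+
+         >>> import jax.numpy as jnp
+         >>> w = jnp.ones((5, 16, 32))  # (kernel, in_channels, out_channels)
+         >>> mu = w.mean(axis=(0, 1), keepdims=True)
+         >>> sigma = w.std(axis=(0, 1), keepdims=True)
+         >>> w_hat = (w - mu) / (sigma + 1e-4)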
+
1005
+ **When to use:**
1006
+
1007
+ This technique is particularly effective when used with Group Normalization
1008
+ instead of Batch Normalization, as it reduces the dependence on batch statistics.
1009
+
1010
+ References
1011
+ ----------
1012
+ .. [1] Qiao, S., Wang, H., Liu, C., Shen, W., & Yuille, A. (2019).
1013
+ Weight Standardization. arXiv preprint arXiv:1903.10520.
1014
+ """
1015
+ __module__ = 'brainstate.nn'
1016
+ num_spatial_dims: int = 1
1017
+
1018
+
1019
+ class ScaledWSConv2d(_ScaledWSConv):
1020
+ """
1021
+ Two-dimensional convolution with weight standardization.
1022
+
1023
+ This layer applies weight standardization to the convolutional kernel before
1024
+ performing the convolution operation. Weight standardization normalizes the
1025
+ weights to have zero mean and unit variance, improving training dynamics and
1026
+ model generalization, particularly in combination with group normalization.
1027
+
1028
+ The input should be a 4D array with the shape of ``[B, H, W, C]`` where B is batch size,
1029
+ H is height, W is width, and C is the number of input channels (channels-last format).
1030
+
1031
+ Parameters
1032
+ ----------
1033
+ in_size : tuple of int
1034
+ The input shape without the batch dimension. For Conv2d: (H, W, C) where H is height,
1035
+ W is width, and C is the number of input channels. This argument is important as it is
1036
+ used to evaluate the output shape.
1037
+ out_channels : int
1038
+ The number of output channels (also called filters or feature maps). These determine
1039
+ the depth of the output feature map.
1040
+ kernel_size : int or tuple of int
1041
+ The shape of the convolutional kernel. Can be:
1042
+
1043
+ - An integer (e.g., 3): creates a square kernel (3, 3)
1044
+ - A tuple of two integers (e.g., (3, 5)): creates a (height, width) kernel
1045
+ stride : int or tuple of int, optional
1046
+ The stride of the convolution. Controls how much the kernel moves at each step.
1047
+ Can be:
1048
+
1049
+ - An integer: same stride for both dimensions
1050
+ - A tuple of two integers: (stride_height, stride_width)
1051
+
1052
+ Default: 1.
1053
+ padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
1054
+ The padding strategy. Options:
1055
+
1056
+ - 'SAME': output spatial size equals input size when stride=1
1057
+ - 'VALID': no padding, output size reduced by kernel size
1058
+ - int: same symmetric padding for all dimensions
1059
+ - (pad_h, pad_w): different padding for each dimension
1060
+ - [(pad_h_before, pad_h_after), (pad_w_before, pad_w_after)]: explicit padding
1061
+
1062
+ Default: 'SAME'.
1063
+ lhs_dilation : int or tuple of int, optional
1064
+ The dilation factor for the input (left-hand side). Controls spacing between input elements.
1065
+ A value > 1 inserts zeros between input elements, equivalent to transposed convolution.
1066
+ Default: 1.
1067
+ rhs_dilation : int or tuple of int, optional
1068
+ The dilation factor for the kernel (right-hand side). Also known as atrous convolution
1069
+ or dilated convolution. Increases the receptive field without increasing parameters by
1070
+ inserting zeros between kernel elements. Useful for semantic segmentation and dense
1071
+ prediction tasks.
1072
+ Default: 1.
1073
+ groups : int, optional
1074
+ Number of groups for grouped convolution. Must divide both `in_channels` and `out_channels`.
1075
+
1076
+ - groups=1: standard convolution (all-to-all connections)
1077
+ - groups>1: grouped convolution (reduces parameters by factor of groups)
1078
+ - groups=in_channels: depthwise convolution (each input channel convolved separately)
1079
+
1080
+ Default: 1.
1081
+ w_init : Callable or ArrayLike, optional
1082
+ Weight initializer for the convolutional kernel. Can be:
1083
+
1084
+ - An initializer instance (e.g., brainstate.init.XavierNormal())
1085
+ - A callable that returns an array given a shape
1086
+ - A direct array matching the kernel shape
1087
+
1088
+ Default: XavierNormal().
1089
+ b_init : Callable or ArrayLike or None, optional
1090
+ Bias initializer. If None, no bias term is added to the output.
1091
+ Default: None.
1092
+ ws_gain : bool, optional
1093
+ Whether to include a learnable per-channel gain parameter in weight standardization.
1094
+ When True, adds a scaling factor that can be learned during training, improving
1095
+ model expressiveness. Highly recommended when using with Group Normalization.
1096
+ Default: True.
1097
+ eps : float, optional
1098
+ Small constant for numerical stability in weight standardization. Prevents division
1099
+ by zero when computing weight standard deviation. Typical values: 1e-4 to 1e-5.
1100
+ Default: 1e-4.
1101
+ w_mask : ArrayLike or Callable or None, optional
1102
+ Optional weight mask for structured sparsity or custom connectivity. The mask is
1103
+ element-wise multiplied with the standardized kernel weights during the forward pass.
1104
+ Default: None.
1105
+ name : str, optional
1106
+ Name identifier for this module instance.
1107
+ Default: None.
1108
+ param_type : type, optional
1109
+ The parameter state class to use for managing learnable parameters.
1110
+ Default: ParamState.
1111
+
1112
+ Attributes
1113
+ ----------
1114
+ in_size : tuple of int
1115
+ The input shape (H, W, C) without batch dimension.
1116
+ out_size : tuple of int
1117
+ The output shape (H_out, W_out, out_channels) without batch dimension.
1118
+ in_channels : int
1119
+ Number of input channels.
1120
+ out_channels : int
1121
+ Number of output channels.
1122
+ kernel_size : tuple of int
1123
+ Size of the convolving kernel (height, width).
1124
+ weight : ParamState
1125
+ The learnable weights (and bias if specified) of the module.
1126
+ eps : float
1127
+ Small constant for numerical stability in weight standardization.
1128
+
1129
+ Examples
1130
+ --------
1131
+ .. code-block:: python
1132
+
1133
+ >>> import brainstate
1134
+ >>> import jax.numpy as jnp
1135
+ >>>
1136
+ >>> # Create a 2D convolution with weight standardization
1137
+ >>> conv = brainstate.nn.ScaledWSConv2d(
1138
+ ... in_size=(64, 64, 3),
1139
+ ... out_channels=32,
1140
+ ... kernel_size=3
1141
+ ... )
1142
+ >>>
1143
+ >>> # Apply to input
1144
+ >>> x = jnp.ones((8, 64, 64, 3))
1145
+ >>> y = conv(x)
1146
+ >>> print(y.shape) # (8, 64, 64, 32)
1147
+ >>>
1148
+ >>> # Combine with custom settings for ResNet-style architecture
1149
+ >>> conv = brainstate.nn.ScaledWSConv2d(
1150
+ ... in_size=(224, 224, 3),
1151
+ ... out_channels=64,
1152
+ ... kernel_size=7,
1153
+ ... stride=2,
1154
+ ... padding='SAME',
1155
+ ... ws_gain=True,
1156
+ ... b_init=brainstate.init.ZeroInit()
1157
+ ... )
1158
+ >>>
1159
+ >>> # Depthwise separable convolution with weight standardization
1160
+ >>> conv = brainstate.nn.ScaledWSConv2d(
1161
+ ... in_size=(32, 32, 128),
1162
+ ... out_channels=128,
1163
+ ... kernel_size=3,
1164
+ ... groups=128,
1165
+ ... ws_gain=False
1166
+ ... )
1167
+
1168
+ Notes
1169
+ -----
1170
+ **Weight standardization formula:**
1171
+
1172
+ Weight standardization reparameterizes the convolutional weights as:
1173
+
1174
+ .. math::
1175
+ \\hat{W} = g \\cdot \\frac{W - \\mu_W}{\\sigma_W + \\epsilon}
1176
+
1177
+ where :math:`\\mu_W` and :math:`\\sigma_W` are the mean and standard deviation
1178
+ of the weights computed per output channel, :math:`g` is a learnable gain
1179
+ parameter (if ws_gain=True), and :math:`\\epsilon` is a small constant.
1180
+
1181
+ **Benefits:**
1182
+
1183
+ - Reduces internal covariate shift
1184
+ - Smooths the loss landscape
1185
+ - Works well with Group Normalization
1186
+ - Improves training stability with small batch sizes
1187
+ - Enables training deeper networks more easily
1188
+
1189
+ References
1190
+ ----------
1191
+ .. [1] Qiao, S., Wang, H., Liu, C., Shen, W., & Yuille, A. (2019).
1192
+ Weight Standardization. arXiv preprint arXiv:1903.10520.
1193
+ """
1194
+ __module__ = 'brainstate.nn'
1195
+ num_spatial_dims: int = 2
1196
+
1197
+
1198
+ class ScaledWSConv3d(_ScaledWSConv):
1199
+ """
1200
+ Three-dimensional convolution with weight standardization.
1201
+
1202
+ This layer applies weight standardization to the convolutional kernel before
1203
+ performing the 3D convolution operation. Weight standardization normalizes the
1204
+ weights to have zero mean and unit variance, which improves training dynamics
1205
+ especially for 3D networks that are typically deeper and more parameter-heavy.
1206
+
1207
+ The input should be a 5D array with the shape of ``[B, H, W, D, C]`` where B is batch size,
1208
+ H is height, W is width, D is depth, and C is the number of input channels (channels-last format).
1209
+
1210
+ Parameters
1211
+ ----------
1212
+ in_size : tuple of int
1213
+ The input shape without the batch dimension. For Conv3d: (H, W, D, C) where H is height,
1214
+ W is width, D is depth, and C is the number of input channels. This argument is important
1215
+ as it is used to evaluate the output shape.
1216
+ out_channels : int
1217
+ The number of output channels (also called filters or feature maps). These determine
1218
+ the depth of the output feature map.
1219
+ kernel_size : int or tuple of int
1220
+ The shape of the convolutional kernel. Can be:
1221
+
1222
+ - An integer (e.g., 3): creates a cubic kernel (3, 3, 3)
1223
+ - A tuple of three integers (e.g., (3, 5, 5)): creates a (height, width, depth) kernel
1224
+ stride : int or tuple of int, optional
1225
+ The stride of the convolution. Controls how much the kernel moves at each step.
1226
+ Can be:
1227
+
1228
+ - An integer: same stride for all dimensions
1229
+ - A tuple of three integers: (stride_h, stride_w, stride_d)
1230
+
1231
+ Default: 1.
1232
+ padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
1233
+ The padding strategy. Options:
1234
+
1235
+ - 'SAME': output spatial size equals input size when stride=1
1236
+ - 'VALID': no padding, output size reduced by kernel size
1237
+ - int: same symmetric padding for all dimensions
1238
+ - (pad_h, pad_w, pad_d): different padding for each dimension
1239
+ - [(pad_h_before, pad_h_after), (pad_w_before, pad_w_after), (pad_d_before, pad_d_after)]: explicit padding
1240
+
1241
+ Default: 'SAME'.
1242
+ lhs_dilation : int or tuple of int, optional
1243
+ The dilation factor for the input (left-hand side). Controls spacing between input elements.
1244
+ A value > 1 inserts zeros between input elements, equivalent to transposed convolution.
1245
+ Default: 1.
1246
+ rhs_dilation : int or tuple of int, optional
1247
+ The dilation factor for the kernel (right-hand side). Also known as atrous convolution
1248
+ or dilated convolution. Increases the receptive field without increasing parameters by
1249
+ inserting zeros between kernel elements. Particularly valuable for 3D to capture
1250
+ multi-scale temporal/spatial context efficiently.
1251
+ Default: 1.
1252
+ groups : int, optional
1253
+ Number of groups for grouped convolution. Must divide both `in_channels` and `out_channels`.
1254
+
1255
+ - groups=1: standard convolution (all-to-all connections)
1256
+ - groups>1: grouped convolution (critical for reducing 3D conv computational cost)
1257
+ - groups=in_channels: depthwise convolution (each input channel convolved separately)
1258
+
1259
+ Default: 1.
1260
+ w_init : Callable or ArrayLike, optional
1261
+ Weight initializer for the convolutional kernel. Can be:
1262
+
1263
+ - An initializer instance (e.g., brainstate.init.XavierNormal())
1264
+ - A callable that returns an array given a shape
1265
+ - A direct array matching the kernel shape
1266
+
1267
+ Default: XavierNormal().
1268
+ b_init : Callable or ArrayLike or None, optional
1269
+ Bias initializer. If None, no bias term is added to the output.
1270
+ Default: None.
1271
+ ws_gain : bool, optional
1272
+ Whether to include a learnable per-channel gain parameter in weight standardization.
1273
+ When True, adds a scaling factor that can be learned during training, improving
1274
+ model expressiveness. Particularly beneficial for deep 3D networks.
1275
+ Default: True.
1276
+ eps : float, optional
1277
+ Small constant for numerical stability in weight standardization. Prevents division
1278
+ by zero when computing weight standard deviation. Typical values: 1e-4 to 1e-5.
1279
+ Default: 1e-4.
1280
+ w_mask : ArrayLike or Callable or None, optional
1281
+ Optional weight mask for structured sparsity or custom connectivity. The mask is
1282
+ element-wise multiplied with the standardized kernel weights during the forward pass.
1283
+ Default: None.
1284
+ name : str, optional
1285
+ Name identifier for this module instance.
1286
+ Default: None.
1287
+ param_type : type, optional
1288
+ The parameter state class to use for managing learnable parameters.
1289
+ Default: ParamState.
1290
+
1291
+ Attributes
1292
+ ----------
1293
+ in_size : tuple of int
1294
+ The input shape (H, W, D, C) without batch dimension.
1295
+ out_size : tuple of int
1296
+ The output shape (H_out, W_out, D_out, out_channels) without batch dimension.
1297
+ in_channels : int
1298
+ Number of input channels.
1299
+ out_channels : int
1300
+ Number of output channels.
1301
+ kernel_size : tuple of int
1302
+ Size of the convolving kernel (height, width, depth).
1303
+ weight : ParamState
1304
+ The learnable weights (and bias if specified) of the module.
1305
+ eps : float
1306
+ Small constant for numerical stability in weight standardization.
1307
+
1308
+ Examples
1309
+ --------
1310
+ .. code-block:: python
1311
+
1312
+ >>> import brainstate
1313
+ >>> import jax.numpy as jnp
1314
+ >>>
1315
+ >>> # Create a 3D convolution with weight standardization for video
1316
+ >>> conv = brainstate.nn.ScaledWSConv3d(
1317
+ ... in_size=(16, 64, 64, 3),
1318
+ ... out_channels=32,
1319
+ ... kernel_size=3
1320
+ ... )
1321
+ >>>
1322
+ >>> # Apply to input
1323
+ >>> x = jnp.ones((4, 16, 64, 64, 3))
1324
+ >>> y = conv(x)
1325
+ >>> print(y.shape) # (4, 16, 64, 64, 32)
1326
+ >>>
1327
+ >>> # For medical imaging with custom parameters
1328
+ >>> conv = brainstate.nn.ScaledWSConv3d(
1329
+ ... in_size=(32, 32, 32, 1),
1330
+ ... out_channels=64,
1331
+ ... kernel_size=(3, 3, 3),
1332
+ ... stride=2,
1333
+ ... ws_gain=True,
1334
+ ... eps=1e-5,
1335
+ ... b_init=brainstate.init.Constant(0.01)
1336
+ ... )
1337
+ >>>
1338
+ >>> # 3D grouped convolution with weight standardization
1339
+ >>> conv = brainstate.nn.ScaledWSConv3d(
1340
+ ... in_size=(8, 16, 16, 64),
1341
+ ... out_channels=64,
1342
+ ... kernel_size=3,
1343
+ ... groups=8,
1344
+ ... ws_gain=False
1345
+ ... )
1346
+
1347
+ Notes
1348
+ -----
1349
+ **Weight standardization formula:**
1350
+
1351
+ Weight standardization reparameterizes the convolutional weights as:
1352
+
1353
+ .. math::
1354
+ \\hat{W} = g \\cdot \\frac{W - \\mu_W}{\\sigma_W + \\epsilon}
1355
+
1356
+ where :math:`\\mu_W` and :math:`\\sigma_W` are the mean and standard deviation
1357
+ of the weights, :math:`g` is a learnable gain parameter (if ws_gain=True),
1358
+ and :math:`\\epsilon` is a small constant for numerical stability.
1359
+
1360
+ **Why weight standardization for 3D:**
1361
+
1362
+ For 3D convolutions, weight standardization is particularly beneficial because:
1363
+
1364
+ - 3D networks are typically much deeper and harder to train
1365
+ - Reduces sensitivity to weight initialization
1366
+ - Improves gradient flow through very deep networks
1367
+ - Works well with limited computational resources (small batches)
1368
+ - Compatible with Group Normalization for batch-independent normalization
1369
+
1370
+ **Applications:**
1371
+
1372
+ Video understanding, medical imaging (CT, MRI scans), 3D object recognition,
1373
+ and temporal sequence modeling.
1374
+
1375
+ References
1376
+ ----------
1377
+ .. [1] Qiao, S., Wang, H., Liu, C., Shen, W., & Yuille, A. (2019).
1378
+ Weight Standardization. arXiv preprint arXiv:1903.10520.
1379
+ """
1380
+ __module__ = 'brainstate.nn'
1381
+ num_spatial_dims: int = 3
1382
+
1383
+
1384
+ class _ConvTranspose(_BaseConv):
1385
+ """Base class for transposed convolution layers."""
1386
+ num_spatial_dims: int = None
1387
+
1388
+ def __init__(
1389
+ self,
1390
+ in_size: Sequence[int],
1391
+ out_channels: int,
1392
+ kernel_size: Union[int, Tuple[int, ...]],
1393
+ stride: Union[int, Tuple[int, ...]] = 1,
1394
+ padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
1395
+ lhs_dilation: Union[int, Tuple[int, ...]] = 1,
1396
+ rhs_dilation: Union[int, Tuple[int, ...]] = 1,
1397
+ groups: int = 1,
1398
+ w_init: Union[Callable, ArrayLike] = init.XavierNormal(),
1399
+ b_init: Optional[Union[Callable, ArrayLike]] = None,
1400
+ w_mask: Optional[Union[ArrayLike, Callable]] = None,
1401
+ channel_first: bool = False,
1402
+ name: str = None,
1403
+ param_type: type = ParamState,
1404
+ ):
1405
+ # Initialize with transpose=True for dimension numbers
1406
+ Module.__init__(self, name=name)
1407
+
1408
+ # general parameters
1409
+ assert self.num_spatial_dims + 1 == len(in_size)
1410
+ self.in_size = tuple(in_size)
1411
+ self.channel_first = channel_first
1412
+ self.channels_last = not channel_first
1413
+
1414
+ # Determine in_channels based on channel_first
1415
+ if self.channel_first:
1416
+ self.in_channels = in_size[0]
1417
+ else:
1418
+ self.in_channels = in_size[-1]
1419
+
1420
+ self.out_channels = out_channels
1421
+ self.stride = replicate(stride, self.num_spatial_dims, 'stride')
1422
+ self.kernel_size = replicate(kernel_size, self.num_spatial_dims, 'kernel_size')
1423
+ self.lhs_dilation = replicate(lhs_dilation, self.num_spatial_dims, 'lhs_dilation')
1424
+ self.rhs_dilation = replicate(rhs_dilation, self.num_spatial_dims, 'rhs_dilation')
1425
+ self.groups = groups
1426
+ self.dimension_numbers = to_dimension_numbers(
1427
+ self.num_spatial_dims,
1428
+ channels_last=self.channels_last,
1429
+ transpose=True # Key difference from regular Conv
1430
+ )
1431
+
1432
+ # the padding parameter
1433
+ # For transposed convolution, string padding needs to be converted to explicit padding
1434
+ # when using lhs_dilation (stride) > 1
1435
+ if isinstance(padding, str):
1436
+ assert padding in ['SAME', 'VALID']
1437
+ self.padding_mode = padding
1438
+ # Compute explicit padding for transposed convolution
1439
+ if max(self.stride) > 1:
1440
+ # For transposed conv with stride, compute padding to achieve desired output size
1441
+ spatial_in_size = self.in_size[:-1] if not self.channel_first else self.in_size[1:]
1442
+ if padding == 'SAME':
1443
+ # For SAME padding with transposed conv: output_size = input_size * stride
1444
+ # Compute required padding to achieve this
1445
+ explicit_padding = []
1446
+ for i, (k, s, in_dim) in enumerate(zip(self.kernel_size, self.stride, spatial_in_size)):
1447
+ # Desired output size
1448
+ out_dim = in_dim * s
1449
+ # Calculate total padding needed
1450
+ # For transposed conv: out = (in - 1) * stride + kernel - 2 * pad
1451
+ # Solving for pad: pad = (kernel + (in-1) * stride - out) // 2
1452
+ total_pad = max(k + (in_dim - 1) * s - out_dim, 0)
1453
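+ # Since out_dim = in_dim * s, total_pad simplifies to max(k - s, 0),
+ # independent of the input size.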
+ pad_left = total_pad // 2
1454
+ pad_right = total_pad - pad_left
1455
+ explicit_padding.append((pad_left, pad_right))
1456
+ padding = tuple(explicit_padding)
1457
+ else: # 'VALID'
1458
+ # For VALID padding: no padding
1459
+ padding = tuple((0, 0) for _ in range(self.num_spatial_dims))
1460
+ # If stride is 1, keep string padding
1461
+ elif isinstance(padding, int):
1462
+ self.padding_mode = 'explicit'
1463
+ padding = tuple((padding, padding) for _ in range(self.num_spatial_dims))
1464
+ elif isinstance(padding, (tuple, list)):
1465
+ self.padding_mode = 'explicit'
1466
+ if isinstance(padding[0], int):
1467
+ padding = (padding,) * self.num_spatial_dims
1468
+ elif isinstance(padding[0], (tuple, list)):
1469
+ if len(padding) == 1:
1470
+ padding = tuple(padding) * self.num_spatial_dims
1471
+ else:
1472
+ if len(padding) != self.num_spatial_dims:
1473
+ raise ValueError(
1474
+ f"Padding {padding} must be a Tuple[int, int], "
1475
+ f"or sequence of Tuple[int, int] with length 1, "
1476
+ f"or sequence of Tuple[int, int] with length {self.num_spatial_dims}."
1477
+ )
1478
+ padding = tuple(padding)
1479
+ else:
1480
+ raise ValueError(f'Unsupported padding: {padding!r}')
1481
+ self.padding = padding
1482
+
1483
+ # the number of in-/out-channels
1484
+ assert self.out_channels % self.groups == 0, '"out_channels" should be divisible by groups'
1485
+ assert self.in_channels % self.groups == 0, '"in_channels" should be divisible by groups'
1486
+
1487
+ # kernel shape for transpose conv
1488
+ # When transpose=True in dimension_numbers, kernel is (spatial..., out_channels, in_channels // groups)
1489
+ # This matches JAX's expectation for transposed convolution
1490
+ kernel_shape = tuple(self.kernel_size) + (self.out_channels, self.in_channels // self.groups)
1491
+ self.kernel_shape = kernel_shape
1492
+ self.w_mask = init.param(w_mask, kernel_shape, allow_none=True)
1493
+
1494
+ self.w_initializer = w_init
1495
+ self.b_initializer = b_init
1496
+
1497
+ # --- weights --- #
1498
+ weight = init.param(self.w_initializer, self.kernel_shape, allow_none=False)
1499
+ params = dict(weight=weight)
1500
+ if self.b_initializer is not None:
1501
+ bias_shape = (1,) * len(self.kernel_size) + (self.out_channels,)
1502
+ bias = init.param(self.b_initializer, bias_shape, allow_none=True)
1503
+ params['bias'] = bias
1504
+
1505
+ # The weight operation
1506
+ self.weight = param_type(params)
1507
+
1508
+ # Evaluate the output shape
1509
+ test_input_shape = (128,) + self.in_size
1510
+ abstract_y = jax.eval_shape(
1511
+ self._conv_op,
1512
+ jax.ShapeDtypeStruct(test_input_shape, weight.dtype),
1513
+ params
1514
+ )
1515
+ y_shape = abstract_y.shape[1:]
1516
+ self.out_size = y_shape
1517
+
1518
+ def _conv_op(self, x, params):
1519
+ w = params['weight']
1520
+ if self.w_mask is not None:
1521
+ w = w * self.w_mask
1522
+ # For transposed convolution:
1523
+ # - window_strides should be (1,1,...) - no striding in the conv operation
1524
+ # - lhs_dilation should be the stride - this creates the upsampling effect
1525
+ window_strides = (1,) * self.num_spatial_dims
1526
+ y = jax.lax.conv_general_dilated(
1527
+ lhs=x,
1528
+ rhs=w,
1529
+ window_strides=window_strides,
1530
+ padding=self.padding,
1531
+ lhs_dilation=self.stride, # For transpose conv, use stride as lhs_dilation
1532
+ rhs_dilation=self.rhs_dilation,
1533
+ feature_group_count=self.groups,
1534
+ dimension_numbers=self.dimension_numbers
1535
+ )
1536
+ if 'bias' in params:
1537
+ y = y + params['bias']
1538
+ return y
1539
+
1540
+
1541
+ class ConvTranspose1d(_ConvTranspose):
1542
+ """
1543
+ One-dimensional transposed convolution layer (also known as deconvolution).
1544
+
1545
+ Applies a 1D transposed convolution over an input signal. Transposed convolution
1546
+ is used for upsampling, reversing the spatial transformation of a regular convolution.
1547
+ It's commonly used in autoencoders, GANs, and semantic segmentation networks.
1548
+
1549
+ The input should be a 3D array with the shape of ``[B, L, C]`` where B is batch size,
1550
+ L is the sequence length, and C is the number of input channels (channels-last format).
1551
+
1552
+ Parameters
1553
+ ----------
1554
+ in_size : tuple of int
1555
+ The input shape without the batch dimension. For ConvTranspose1d: (L, C) where L
1556
+ is the sequence length and C is the number of input channels.
1557
+ out_channels : int
1558
+ The number of output channels (feature maps) produced by the transposed convolution.
1559
+ kernel_size : int or tuple of int
1560
+ The shape of the convolutional kernel. For 1D, can be an integer or a single-element tuple.
1561
+ stride : int or tuple of int, optional
1562
+ The stride of the transposed convolution. Larger strides produce larger output sizes,
1563
+ which is the opposite behavior of regular convolution. Default: 1.
1564
+ padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
1565
+ The padding strategy. Options:
1566
+
1567
+ - 'SAME': output length approximately equals input_length * stride
1568
+ - 'VALID': no padding, maximum output size
1569
+ - int: symmetric padding
1570
+ - (pad_before, pad_after): explicit padding for the sequence dimension
1571
+
1572
+ Default: 'SAME'.
1573
+ lhs_dilation : int or tuple of int, optional
1574
+ The dilation factor for the input. For transposed convolution, this is typically
1575
+ set equal to stride internally. Default: 1.
1576
+ rhs_dilation : int or tuple of int, optional
1577
+ The dilation factor for the kernel. Increases the receptive field without increasing
1578
+ parameters by inserting zeros between kernel elements. Default: 1.
1579
+ groups : int, optional
1580
+ Number of groups for grouped transposed convolution. Both `in_channels` and
1581
+ `out_channels` must be divisible by `groups`. Default: 1.
1582
+ w_init : Callable or ArrayLike, optional
1583
+ The initializer for the convolutional kernel weights. Default: XavierNormal().
1584
+ b_init : Callable or ArrayLike or None, optional
1585
+ The initializer for the bias. If None, no bias is added. Default: None.
1586
+ w_mask : ArrayLike or Callable or None, optional
1587
+ An optional mask applied to the weights during forward pass. Default: None.
1588
+ channel_first : bool, optional
1589
+ If True, uses channels-first format (e.g., [B, C, L]). If False, uses channels-last
1590
+ format (e.g., [B, L, C]). Default: False (channels-last, JAX convention).
1591
+ name : str, optional
1592
+ The name of the module. Default: None.
1593
+ param_type : type, optional
1594
+ The type of parameter state to use. Default: ParamState.
1595
+
1596
+ Attributes
1597
+ ----------
1598
+ in_size : tuple of int
1599
+ The input shape (L, C) without batch dimension.
1600
+ out_size : tuple of int
1601
+ The output shape (L_out, out_channels) without batch dimension.
1602
+ in_channels : int
1603
+ Number of input channels.
1604
+ out_channels : int
1605
+ Number of output channels.
1606
+ kernel_size : tuple of int
1607
+ Size of the convolving kernel.
1608
+ weight : ParamState
1609
+ The learnable weights (and bias if specified) of the module.
1610
+
1611
+ Examples
1612
+ --------
1613
+ .. code-block:: python
1614
+
1615
+ >>> import brainstate
1616
+ >>> import jax.numpy as jnp
1617
+ >>>
1618
+ >>> # Create a 1D transposed convolution layer for upsampling
1619
+ >>> conv_transpose = brainstate.nn.ConvTranspose1d(
1620
+ ... in_size=(28, 16),
1621
+ ... out_channels=8,
1622
+ ... kernel_size=4,
1623
+ ... stride=2
1624
+ ... )
1625
+ >>>
1626
+ >>> # Apply to input: batch_size=2, length=28, channels=16
1627
+ >>> x = jnp.ones((2, 28, 16))
1628
+ >>> y = conv_transpose(x)
1629
+ >>> print(y.shape) # approximately (2, 56, 8): length upsampled by stride
1630
+ >>>
1631
+ >>> # Without batch dimension
1632
+ >>> x_single = jnp.ones((28, 16))
1633
+ >>> y_single = conv_transpose(x_single)
1634
+ >>>
1635
+ >>> # Channels-first format (PyTorch style)
1636
+ >>> conv_transpose = brainstate.nn.ConvTranspose1d(
1637
+ ... in_size=(16, 28),
1638
+ ... out_channels=8,
1639
+ ... kernel_size=4,
1640
+ ... stride=2,
1641
+ ... channel_first=True
1642
+ ... )
1643
+ >>> x = jnp.ones((2, 16, 28))
1644
+ >>> y = conv_transpose(x)
1645
+
1646
+ Notes
1647
+ -----
1648
+ **Output dimensions:**
1649
+
1650
+ Unlike regular convolution, transposed convolution increases spatial dimensions.
1651
+ With stride > 1, the output is larger than the input:
1652
+
1653
+ - output_length ≈ input_length * stride (depends on padding and kernel size)
1654
+
1655
+ **Relationship to regular convolution:**
1656
+
1657
+ Transposed convolution performs the gradient computation of a regular convolution
1658
+ with respect to its input. It's sometimes called "deconvolution" but this term
1659
+ is mathematically imprecise.
1660
+
1661
+ **Common use cases:**
1662
+
1663
+ - Upsampling in encoder-decoder architectures
1664
+ - Generative models (GANs, VAEs)
1665
+ - Semantic segmentation (U-Net, FCN)
1666
+ - Super-resolution networks
1667
+
1668
+ **Comparison with PyTorch:**
1669
+
1670
+ - PyTorch uses channels-first by default; BrainState uses channels-last
1671
+ - Set `channel_first=True` for PyTorch-compatible format
1672
+ - PyTorch's `output_padding` is handled through the padding parameter
1673
+ """
1674
+ __module__ = 'brainstate.nn'
1675
+ num_spatial_dims: int = 1
1676
+
1677
+
1678
+ class ConvTranspose2d(_ConvTranspose):
1679
+ """
1680
+ Two-dimensional transposed convolution layer (also known as deconvolution).
1681
+
1682
+ Applies a 2D transposed convolution over an input signal. Transposed convolution
1683
+ is the gradient of a regular convolution with respect to its input, commonly used
1684
+ for upsampling feature maps in encoder-decoder architectures, GANs, and segmentation.
1685
+
1686
+ The input should be a 4D array with the shape of ``[B, H, W, C]`` where B is batch size,
1687
+ H is height, W is width, and C is the number of input channels (channels-last format).
1688
+
1689
+ Parameters
1690
+ ----------
1691
+ in_size : tuple of int
1692
+ The input shape without the batch dimension. For ConvTranspose2d: (H, W, C) where
1693
+ H is height, W is width, and C is the number of input channels.
1694
+ out_channels : int
1695
+ The number of output channels (feature maps) produced by the transposed convolution.
1696
+ kernel_size : int or tuple of int
1697
+ The shape of the convolutional kernel. Can be:
1698
+
1699
+ - An integer (e.g., 4): creates a square kernel (4, 4)
1700
+ - A tuple of two integers (e.g., (4, 4)): creates a (height, width) kernel
1701
+ stride : int or tuple of int, optional
1702
+ The stride of the transposed convolution. Controls the upsampling factor.
1703
+ Can be:
1704
+
1705
+ - An integer: same stride for both dimensions
1706
+ - A tuple of two integers: (stride_height, stride_width)
1707
+
1708
+ Larger strides produce larger outputs. Default: 1.
1709
+ padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
1710
+ The padding strategy. Options:
1711
+
1712
+ - 'SAME': output size approximately equals input_size * stride
1713
+ - 'VALID': no padding, maximum output size
1714
+ - int: same symmetric padding for all dimensions
1715
+ - (pad_h, pad_w): different padding for each dimension
1716
+ - [(pad_h_before, pad_h_after), (pad_w_before, pad_w_after)]: explicit padding
1717
+
1718
+ Default: 'SAME'.
1719
+ lhs_dilation : int or tuple of int, optional
1720
+ The dilation factor for the input. For transposed convolution, this is typically
1721
+ set equal to stride internally. Default: 1.
1722
+ rhs_dilation : int or tuple of int, optional
1723
+ The dilation factor for the kernel. Increases the receptive field without increasing
1724
+ parameters by inserting zeros between kernel elements. Default: 1.
1725
+ groups : int, optional
1726
+ Number of groups for grouped transposed convolution. Must divide both `in_channels`
1727
+ and `out_channels`. Default: 1.
1728
+ w_init : Callable or ArrayLike, optional
1729
+ Weight initializer for the convolutional kernel. Default: XavierNormal().
1730
+ b_init : Callable or ArrayLike or None, optional
1731
+ Bias initializer. If None, no bias term is added. Default: None.
1732
+ w_mask : ArrayLike or Callable or None, optional
1733
+ Optional weight mask for structured sparsity. Default: None.
1734
+ channel_first : bool, optional
1735
+ If True, uses channels-first format (e.g., [B, C, H, W]). If False, uses channels-last
1736
+ format (e.g., [B, H, W, C]). Default: False (channels-last, JAX convention).
1737
+ name : str, optional
1738
+ Name identifier for this module instance. Default: None.
1739
+ param_type : type, optional
1740
+ The parameter state class to use. Default: ParamState.
1741
+
1742
+ Attributes
1743
+ ----------
1744
+ in_size : tuple of int
1745
+ The input shape (H, W, C) without batch dimension.
1746
+ out_size : tuple of int
1747
+ The output shape (H_out, W_out, out_channels) without batch dimension.
1748
+ in_channels : int
1749
+ Number of input channels.
1750
+ out_channels : int
1751
+ Number of output channels.
1752
+ kernel_size : tuple of int
1753
+ Size of the convolving kernel (height, width).
1754
+ weight : ParamState
1755
+ The learnable weights (and bias if specified) of the module.
1756
+
1757
+ Examples
1758
+ --------
1759
+ .. code-block:: python
1760
+
1761
+ >>> import brainstate
1762
+ >>> import jax.numpy as jnp
1763
+ >>>
1764
+ >>> # Create a 2D transposed convolution for upsampling
1765
+ >>> conv_transpose = brainstate.nn.ConvTranspose2d(
1766
+ ... in_size=(32, 32, 64),
1767
+ ... out_channels=32,
1768
+ ... kernel_size=4,
1769
+ ... stride=2
1770
+ ... )
1771
+ >>>
1772
+ >>> # Apply to input: batch_size=8, height=32, width=32, channels=64
1773
+ >>> x = jnp.ones((8, 32, 32, 64))
1774
+ >>> y = conv_transpose(x)
1775
+ >>> print(y.shape) # Output will be approximately (8, 64, 64, 32)
1776
+ >>>
1777
+ >>> # Without batch dimension
1778
+ >>> x_single = jnp.ones((32, 32, 64))
1779
+ >>> y_single = conv_transpose(x_single)
1780
+ >>>
1781
+ >>> # Decoder in autoencoder (upsampling path)
1782
+ >>> decoder = brainstate.nn.ConvTranspose2d(
1783
+ ... in_size=(16, 16, 128),
1784
+ ... out_channels=64,
1785
+ ... kernel_size=4,
1786
+ ... stride=2,
1787
+ ... padding='SAME',
1788
+ ... b_init=brainstate.init.Constant(0.0)
1789
+ ... )
1790
+ >>>
1791
+ >>> # Channels-first format (PyTorch style)
1792
+ >>> conv_transpose = brainstate.nn.ConvTranspose2d(
1793
+ ... in_size=(64, 32, 32),
1794
+ ... out_channels=32,
1795
+ ... kernel_size=4,
1796
+ ... stride=2,
1797
+ ... channel_first=True
1798
+ ... )
1799
+ >>> x = jnp.ones((8, 64, 32, 32))
1800
+ >>> y = conv_transpose(x)
1801
+
1802
+ Notes
1803
+ -----
1804
+ **Output dimensions:**
1805
+
1806
+ Transposed convolution increases spatial dimensions, with the upsampling factor
1807
+ primarily controlled by stride:
1808
+
1809
+ - output_size ≈ input_size * stride (exact size depends on padding and kernel size)
1810
+ - 'SAME' padding: output_size = input_size * stride
1811
+ - 'VALID' padding: output_size = input_size * stride + max(kernel_size - stride, 0)
1812
+
1813
+ **Relationship to regular convolution:**
1814
+
1815
+ Transposed convolution is the backward pass of a regular convolution. If a regular
1816
+ convolution reduces spatial dimensions from X to Y, a transposed convolution with
1817
+ the same parameters increases dimensions from Y back to approximately X.
1818
+
1819
+ **Common use cases:**
1820
+
1821
+ - Image segmentation (U-Net, SegNet, FCN)
1822
+ - Image-to-image translation (pix2pix, CycleGAN)
1823
+ - Generative models (DCGAN, VAE decoders)
1824
+ - Super-resolution networks
1825
+ - Autoencoders (decoder path)
1826
+
1827
+ **Comparison with PyTorch:**
1828
+
1829
+ - PyTorch uses channels-first by default; BrainState uses channels-last
1830
+ - Set `channel_first=True` for PyTorch-compatible format
1831
+ - Kernel shape convention: PyTorch stores (C_in, C_out, H, W), BrainState uses (H, W, C_out, C_in)
1832
+ - PyTorch's `output_padding` parameter controls output size; use padding parameter here
1833
+
1834
+ **Tips:**
1835
+
1836
+ - Use kernel_size=stride*2 for smooth upsampling (e.g., kernel_size=4, stride=2)
1837
+ - Initialize with bilinear upsampling weights for better convergence in segmentation
1838
+ - Combine with batch normalization or group normalization for stable training
1839
+ """
1840
+ __module__ = 'brainstate.nn'
1841
+ num_spatial_dims: int = 2
1842
+
1843
+
1844
+ class ConvTranspose3d(_ConvTranspose):
1845
+ """
1846
+ Three-dimensional transposed convolution layer (also known as deconvolution).
1847
+
1848
+ Applies a 3D transposed convolution over an input signal. Used for upsampling
1849
+ 3D feature maps in video generation, 3D segmentation, and volumetric reconstruction.
1850
+
1851
+ The input should be a 5D array with the shape of ``[B, H, W, D, C]`` where B is batch size,
1852
+ H is height, W is width, D is depth, and C is the number of input channels (channels-last format).
1853
+
1854
+ Parameters
1855
+ ----------
1856
+ in_size : tuple of int
1857
+ The input shape without the batch dimension. For ConvTranspose3d: (H, W, D, C) where
1858
+ H is height, W is width, D is depth, and C is the number of input channels.
1859
+ out_channels : int
1860
+ The number of output channels (feature maps) produced by the transposed convolution.
1861
+ kernel_size : int or tuple of int
1862
+ The shape of the convolutional kernel. Can be:
1863
+
1864
+ - An integer (e.g., 4): creates a cubic kernel (4, 4, 4)
1865
+ - A tuple of three integers (e.g., (4, 4, 4)): creates a (height, width, depth) kernel
1866
+ stride : int or tuple of int, optional
1867
+ The stride of the transposed convolution. Controls the upsampling factor.
1868
+ Can be:
1869
+
1870
+ - An integer: same stride for all dimensions
1871
+ - A tuple of three integers: (stride_h, stride_w, stride_d)
1872
+
1873
+ Larger strides produce larger outputs. Default: 1.
1874
+ padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
1875
+ The padding strategy. Options:
1876
+
1877
+ - 'SAME': output size approximately equals input_size * stride
1878
+ - 'VALID': no padding, maximum output size
1879
+ - int: same symmetric padding for all dimensions
1880
+ - (pad_h, pad_w, pad_d): different padding for each dimension
1881
+ - [(pad_h_before, pad_h_after), (pad_w_before, pad_w_after), (pad_d_before, pad_d_after)]: explicit
1882
+
1883
+ Default: 'SAME'.
1884
+ lhs_dilation : int or tuple of int, optional
1885
+ The dilation factor for the input. For transposed convolution, this is typically
1886
+ set equal to stride internally. Default: 1.
1887
+ rhs_dilation : int or tuple of int, optional
1888
+ The dilation factor for the kernel. Increases the receptive field without increasing
1889
+ parameters. Default: 1.
1890
+ groups : int, optional
1891
+ Number of groups for grouped transposed convolution. Must divide both `in_channels`
1892
+ and `out_channels`. Useful for reducing computational cost in 3D. Default: 1.
1893
+ w_init : Callable or ArrayLike, optional
1894
+ Weight initializer for the convolutional kernel. Default: XavierNormal().
1895
+ b_init : Callable or ArrayLike or None, optional
1896
+ Bias initializer. If None, no bias term is added. Default: None.
1897
+ w_mask : ArrayLike or Callable or None, optional
1898
+ Optional weight mask for structured sparsity. Default: None.
1899
+ channel_first : bool, optional
1900
+ If True, uses channels-first format (e.g., [B, C, H, W, D]). If False, uses channels-last
1901
+ format (e.g., [B, H, W, D, C]). Default: False (channels-last, JAX convention).
1902
+ name : str, optional
1903
+ Name identifier for this module instance. Default: None.
1904
+ param_type : type, optional
1905
+ The parameter state class to use. Default: ParamState.
1906
+
1907
+ Attributes
1908
+ ----------
1909
+ in_size : tuple of int
1910
+ The input shape (H, W, D, C) without batch dimension.
1911
+ out_size : tuple of int
1912
+ The output shape (H_out, W_out, D_out, out_channels) without batch dimension.
1913
+ in_channels : int
1914
+ Number of input channels.
1915
+ out_channels : int
1916
+ Number of output channels.
1917
+ kernel_size : tuple of int
1918
+ Size of the convolving kernel (height, width, depth).
1919
+ weight : ParamState
1920
+ The learnable weights (and bias if specified) of the module.
1921
+
1922
+ Examples
1923
+ --------
1924
+ .. code-block:: python
1925
+
1926
+ >>> import brainstate
1927
+ >>> import jax.numpy as jnp
1928
+ >>>
1929
+ >>> # Create a 3D transposed convolution for video upsampling
1930
+ >>> conv_transpose = brainstate.nn.ConvTranspose3d(
1931
+ ... in_size=(8, 16, 16, 64),
1932
+ ... out_channels=32,
1933
+ ... kernel_size=4,
1934
+ ... stride=2
1935
+ ... )
1936
+ >>>
1937
+ >>> # Apply to input: batch_size=4, frames=8, height=16, width=16, channels=64
1938
+ >>> x = jnp.ones((4, 8, 16, 16, 64))
1939
+ >>> y = conv_transpose(x)
1940
+ >>> print(y.shape) # Output will be approximately (4, 16, 32, 32, 32)
1941
+ >>>
1942
+ >>> # Without batch dimension
1943
+ >>> x_single = jnp.ones((8, 16, 16, 64))
1944
+ >>> y_single = conv_transpose(x_single)
1945
+ >>>
1946
+ >>> # For medical imaging reconstruction
1947
+ >>> decoder = brainstate.nn.ConvTranspose3d(
1948
+ ... in_size=(16, 16, 16, 128),
1949
+ ... out_channels=64,
1950
+ ... kernel_size=(4, 4, 4),
1951
+ ... stride=2,
1952
+ ... padding='SAME',
1953
+ ... b_init=brainstate.init.Constant(0.0)
1954
+ ... )
1955
+ >>>
1956
+ >>> # Channels-first format (PyTorch style)
1957
+ >>> conv_transpose = brainstate.nn.ConvTranspose3d(
1958
+ ... in_size=(64, 8, 16, 16),
1959
+ ... out_channels=32,
1960
+ ... kernel_size=4,
1961
+ ... stride=2,
1962
+ ... channel_first=True
1963
+ ... )
1964
+ >>> x = jnp.ones((4, 64, 8, 16, 16))
1965
+ >>> y = conv_transpose(x)
1966
+
1967
+ Notes
1968
+ -----
1969
+ **Output dimensions:**
1970
+
1971
+ Transposed convolution increases spatial dimensions:
1972
+
1973
+ - output_size ≈ input_size * stride (exact size depends on padding and kernel size)
1974
+ - 'SAME' padding: output_size = input_size * stride
1975
+ - 'VALID' padding: output_size = input_size * stride + max(kernel_size - stride, 0)
1976
+
1977
+ **Computational considerations:**
1978
+
1979
+ 3D transposed convolutions are very computationally expensive. Consider:
1980
+
1981
+ - Using grouped convolutions (groups > 1) to reduce parameters
1982
+ - Smaller kernel sizes
1983
+ - Progressive upsampling (multiple layers with stride=2)
1984
+ - Separable convolutions for large-scale applications
1985
+
1986
+ **Common use cases:**
1987
+
1988
+ - Video generation and prediction
1989
+ - 3D medical image segmentation (U-Net 3D)
1990
+ - Volumetric reconstruction
1991
+ - 3D super-resolution
1992
+ - Video frame interpolation
1993
+ - 3D VAE decoders
1994
+
1995
+ **Comparison with PyTorch:**
1996
+
1997
+ - PyTorch uses channels-first by default; BrainState uses channels-last
1998
+ - Set `channel_first=True` for PyTorch-compatible format
1999
+ - Kernel shape convention differs between frameworks
2000
+ - PyTorch's `output_padding` parameter is handled through the padding parameter here
2001
+
2002
+ **Tips:**
2003
+
2004
+ - Use kernel_size=stride*2 for smooth upsampling (e.g., kernel_size=4, stride=2)
2005
+ - Group normalization often works better than batch normalization for 3D
2006
+ - Consider using smaller batch sizes due to memory constraints
2007
+ - Progressive upsampling (2x at a time) is more stable than large strides (see the sketch below)
2008
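+
+     A sketch of progressive 2x upsampling by chaining two layers; the
+     ``out_size`` of the first stage feeds the second (shapes and names are
+     illustrative):
+
+     .. code-block:: python
+
+         >>> import brainstate
+         >>> import jax.numpy as jnp
+         >>> up1 = brainstate.nn.ConvTranspose3d(in_size=(8, 8, 8, 32), out_channels=16, kernel_size=4, stride=2)
+         >>> up2 = brainstate.nn.ConvTranspose3d(in_size=up1.out_size, out_channels=8, kernel_size=4, stride=2)
+         >>> x = jnp.ones((2, 8, 8, 8, 32))
+         >>> y = up2(up1(x))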
+ """
2009
+ __module__ = 'brainstate.nn'
2010
+ num_spatial_dims: int = 3