brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +95 -29
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.9.dist-info/RECORD +0 -130
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
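The rename entries above imply a top-level reorganization in 0.2.0: the former `brainstate.augment` and `brainstate.compile` subpackages are consolidated into `brainstate.transform`, and the `brainstate.init` package moves under `brainstate.nn`. Below is a minimal migration sketch for user imports, assuming the file renames map one-to-one onto public import paths; the names `jit`, `grad`, and `XavierNormal` are inferred from the module names and are not verified against the 0.2.0 API:

    # brainstate 0.1.9 (old paths, per the rename list; names assumed)
    # from brainstate.compile import jit          # compile/_jit.py
    # from brainstate.augment import grad         # augment/_autograd.py
    # from brainstate.init import XavierNormal    # init/_random_inits.py

    # brainstate 0.2.0 (new paths)
    from brainstate.transform import jit, grad    # {compile, augment} -> transform
    from brainstate.nn.init import XavierNormal   # init/_random_inits.py -> nn/init.py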
brainstate/nn/_conv.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2024
+# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,19 +18,22 @@
 import collections.abc
 from typing import Callable, Tuple, Union, Sequence, Optional, TypeVar
 
+import brainunit as u
 import jax
 import jax.numpy as jnp
 
-from brainstate import init, functional
 from brainstate._state import ParamState
 from brainstate.typing import ArrayLike
+from . import init as init
 from ._module import Module
+from ._normalizations import weight_standardization
 
 T = TypeVar('T')
 
 __all__ = [
     'Conv1d', 'Conv2d', 'Conv3d',
     'ScaledWSConv1d', 'ScaledWSConv2d', 'ScaledWSConv3d',
+    'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d',
 ]
 
 
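The new `import brainunit as u` above is what lets `update` (later in this diff) handle unbatched inputs uniformly: `u.math.expand_dims` and `u.math.squeeze`, both used further down, behave like their `jax.numpy` counterparts on plain arrays while also passing through brainunit's unit-carrying quantities. A minimal sketch of the plain-array case:

    import brainunit as u
    import jax.numpy as jnp

    x = jnp.ones((28, 3))           # unbatched input of shape (L, C)
    xb = u.math.expand_dims(x, 0)   # (1, 28, 3), as jnp.expand_dims for plain arrays
    y = u.math.squeeze(xb, axis=0)  # back to (28, 3)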
@@ -39,7 +42,40 @@ def to_dimension_numbers(
     channels_last: bool,
     transpose: bool
 ) -> jax.lax.ConvDimensionNumbers:
-    """
+    """
+    Create a `lax.ConvDimensionNumbers` for the given inputs.
+
+    This function generates the dimension specification needed for JAX's convolution
+    operations based on the number of spatial dimensions and data format.
+
+    Parameters
+    ----------
+    num_spatial_dims : int
+        The number of spatial dimensions (e.g., 1 for Conv1d, 2 for Conv2d, 3 for Conv3d).
+    channels_last : bool
+
+        - If True, the input format is channels-last (e.g., [B, H, W, C] for 2D).
+        - If False, the input format is channels-first (e.g., [B, C, H, W] for 2D).
+    transpose : bool
+
+        - If True, creates dimension numbers for transposed convolution.
+        - If False, creates dimension numbers for standard convolution.
+
+    Returns
+    -------
+    jax.lax.ConvDimensionNumbers
+        A named tuple specifying the dimension layout for lhs (input), rhs (kernel),
+        and output of the convolution operation.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> # For 2D convolution with channels-last format
+        >>> dim_nums = to_dimension_numbers(num_spatial_dims=2, channels_last=True, transpose=False)
+        >>> print(dim_nums.lhs_spec)  # Input layout: (batch, spatial_1, spatial_2, channel)
+        (0, 3, 1, 2)
+    """
     num_dims = num_spatial_dims + 2
     if channels_last:
         spatial_dims = tuple(range(1, num_dims - 1))
@@ -61,7 +97,48 @@ def replicate(
     num_replicate: int,
     name: str,
 ) -> Tuple[T, ...]:
-    """
+    """
+    Replicates entry in `element` `num_replicate` times if needed.
+
+    This utility function ensures that parameters like kernel_size, stride, etc.
+    are properly formatted as tuples with the correct length for multi-dimensional
+    convolutions.
+
+    Parameters
+    ----------
+    element : T or Sequence[T]
+        The element to replicate. Can be a scalar, string, or sequence.
+    num_replicate : int
+        The number of times to replicate the element.
+    name : str
+        The name of the parameter (used for error messages).
+
+    Returns
+    -------
+    tuple of T
+        A tuple containing the replicated elements.
+
+    Raises
+    ------
+    TypeError
+        If the element is a sequence with length not equal to 1 or `num_replicate`.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> # Replicate a scalar value
+        >>> replicate(3, 2, 'kernel_size')
+        (3, 3)
+        >>>
+        >>> # Keep a sequence as is if already correct length
+        >>> replicate((3, 5), 2, 'kernel_size')
+        (3, 5)
+        >>>
+        >>> # Replicate a single-element sequence
+        >>> replicate([3], 2, 'kernel_size')
+        (3, 3)
+    """
     if isinstance(element, (str, bytes)) or not isinstance(element, collections.abc.Sequence):
         return (element,) * num_replicate
     elif len(element) == 1:
@@ -91,6 +168,7 @@ class _BaseConv(Module):
         rhs_dilation: Union[int, Tuple[int, ...]] = 1,
         groups: int = 1,
         w_mask: Optional[Union[ArrayLike, Callable]] = None,
+        channel_first: bool = False,
         name: str = None,
     ):
         super().__init__(name=name)
@@ -98,14 +176,26 @@ class _BaseConv(Module):
         # general parameters
         assert self.num_spatial_dims + 1 == len(in_size)
         self.in_size = tuple(in_size)
-        self.
+        self.channel_first = channel_first
+        self.channels_last = not channel_first
+
+        # Determine in_channels based on channel_first
+        if self.channel_first:
+            self.in_channels = in_size[0]
+        else:
+            self.in_channels = in_size[-1]
+
         self.out_channels = out_channels
         self.stride = replicate(stride, self.num_spatial_dims, 'stride')
         self.kernel_size = replicate(kernel_size, self.num_spatial_dims, 'kernel_size')
         self.lhs_dilation = replicate(lhs_dilation, self.num_spatial_dims, 'lhs_dilation')
         self.rhs_dilation = replicate(rhs_dilation, self.num_spatial_dims, 'rhs_dilation')
         self.groups = groups
-        self.dimension_numbers = to_dimension_numbers(
+        self.dimension_numbers = to_dimension_numbers(
+            self.num_spatial_dims,
+            channels_last=self.channels_last,
+            transpose=False
+        )
 
         # the padding parameter
         if isinstance(padding, str):
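The hunk above threads the new `channel_first` flag through `_BaseConv`: with `channel_first=True`, `in_channels` is read from `in_size[0]` instead of `in_size[-1]`, and the dimension numbers are built with `channels_last=False`. An illustrative sketch of what that means at construction time (shapes follow the rules in the code above; this is not a verbatim example from the package):

    import brainstate

    # channels-last (default): in_size = (H, W, C), so in_channels = in_size[-1] = 3
    conv_nhwc = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=8, kernel_size=3)

    # channels-first: in_size = (C, H, W), so in_channels = in_size[0] = 3
    conv_nchw = brainstate.nn.Conv2d(in_size=(3, 32, 32), out_channels=8, kernel_size=3,
                                     channel_first=True)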
@@ -147,30 +237,30 @@ class _BaseConv(Module):
         else:
             raise ValueError(f"expected {self.num_spatial_dims + 2}D (with batch) or "
                              f"{self.num_spatial_dims + 1}D (without batch) input (got {x.ndim}D input, {x.shape})")
-
-
+
+        # Check shape matches expected in_size
+        if self.channel_first:
+            # For channels-first, expected shape is already (C, spatial...)
+            expected_shape = self.in_size
+        else:
+            # For channels-last, expected shape is (spatial..., C)
+            expected_shape = self.in_size
+
+        if expected_shape != x_shape:
+            raise ValueError(f"The expected input shape is {expected_shape}, while we got {x_shape}.")
 
     def update(self, x):
         self._check_input_dim(x)
         non_batching = False
         if x.ndim == self.num_spatial_dims + 1:
-            x =
+            x = u.math.expand_dims(x, 0)
             non_batching = True
         y = self._conv_op(x, self.weight.value)
-        return y
+        return u.math.squeeze(y, axis=0) if non_batching else y
 
     def _conv_op(self, x, params):
         raise NotImplementedError
 
-    def __repr__(self):
-        return (f'{self.__class__.__name__}('
-                f'in_channels={self.in_channels}, '
-                f'out_channels={self.out_channels}, '
-                f'kernel_size={self.kernel_size}, '
-                f'stride={self.stride}, '
-                f'padding={self.padding}, '
-                f'groups={self.groups})')
-
 
 class _Conv(_BaseConv):
     num_spatial_dims: int = None
@@ -188,6 +278,7 @@ class _Conv(_BaseConv):
         w_init: Union[Callable, ArrayLike] = init.XavierNormal(),
         b_init: Optional[Union[Callable, ArrayLike]] = None,
         w_mask: Optional[Union[ArrayLike, Callable]] = None,
+        channel_first: bool = False,
         name: str = None,
         param_type: type = ParamState,
     ):
@@ -201,6 +292,7 @@ class _Conv(_BaseConv):
             rhs_dilation=rhs_dilation,
             groups=groups,
             w_mask=w_mask,
+            channel_first=channel_first,
             name=name
         )
 
@@ -219,9 +311,10 @@ class _Conv(_BaseConv):
         self.weight = param_type(params)
 
         # Evaluate the output shape
+        test_input_shape = (128,) + self.in_size
         abstract_y = jax.eval_shape(
             self._conv_op,
-            jax.ShapeDtypeStruct(
+            jax.ShapeDtypeStruct(test_input_shape, weight.dtype),
             params
         )
         y_shape = abstract_y.shape[1:]
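The output-shape evaluation above leans on `jax.eval_shape`, which traces the convolution abstractly: the dummy batch of 128 is never materialized, and only shapes and dtypes flow through. A standalone sketch of the same pattern (`conv_fn` here stands in for the layer's `_conv_op` with its parameters already bound):

    import jax
    import jax.numpy as jnp

    def infer_out_size(conv_fn, in_size, dtype=jnp.float32):
        # eval_shape spends no FLOPs and allocates no memory for the dummy batch;
        # it only propagates shape/dtype information through conv_fn.
        x_spec = jax.ShapeDtypeStruct((128,) + tuple(in_size), dtype)
        return jax.eval_shape(conv_fn, x_spec).shape[1:]  # drop the dummy batch dim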
@@ -247,89 +340,433 @@ class _Conv(_BaseConv):
 
 
 class Conv1d(_Conv):
-    """
+    """
+    One-dimensional convolution layer.
 
-
+    Applies a 1D convolution over an input signal composed of several input planes.
+    The input should be a 3D array with the shape of ``[B, L, C]`` where B is batch size,
+    L is the sequence length, and C is the number of input channels.
+
+    This layer creates a convolution kernel that is convolved with the layer input
+    over a single spatial dimension to produce a tensor of outputs.
 
     Parameters
     ----------
-
+    in_size : tuple of int
+        The input shape without the batch dimension. This argument is important as it is
+        used to evaluate the output shape. For Conv1d: (L, C), Conv2d: (H, W, C), Conv3d: (H, W, D, C).
+    out_channels : int
+        The number of output channels (also called filters or feature maps).
+    kernel_size : int or tuple of int
+        The shape of the convolutional kernel. For 1D convolution, the kernel size can be
+        passed as an integer. For 2D and 3D convolutions, it should be a tuple of integers
+        or a single integer (which will be replicated for all spatial dimensions).
+    stride : int or tuple of int, optional
+        The stride of the convolution. An integer or a sequence of `n` integers, representing
+        the inter-window strides along each spatial dimension. Default: 1.
+    padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
+        The padding strategy. Can be:
+
+        - 'SAME': pads the input so the output has the same shape as input when stride=1
+        - 'VALID': no padding
+        - int: symmetric padding applied to all spatial dimensions
+        - tuple of (low, high): padding for each dimension
+        - sequence of tuples: explicit padding for each spatial dimension
+
+        Default: 'SAME'.
+    lhs_dilation : int or tuple of int, optional
+        The dilation factor for the input. An integer or a sequence of `n` integers, giving
+        the dilation factor to apply in each spatial dimension of inputs. Convolution with
+        input dilation `d` is equivalent to transposed convolution with stride `d`.
+        Default: 1.
+    rhs_dilation : int or tuple of int, optional
+        The dilation factor for the kernel. An integer or a sequence of `n` integers, giving
+        the dilation factor to apply in each spatial dimension of the convolution kernel.
+        Convolution with kernel dilation is also known as 'atrous convolution', which increases
+        the receptive field without increasing the number of parameters. Default: 1.
+    groups : int, optional
+        Number of groups for grouped convolution. Controls the connections between inputs and
+        outputs. Both `in_channels` and `out_channels` must be divisible by `groups`. When
+        groups=1 (default), all inputs are convolved to all outputs. When groups>1, the input
+        and output channels are divided into groups, and each group is convolved independently.
+        When groups=in_channels, this becomes a depthwise convolution. Default: 1.
+    w_init : Callable or ArrayLike, optional
+        The initializer for the convolutional kernel weights. Can be an initializer instance
+        or a direct array. Default: XavierNormal().
+    b_init : Callable or ArrayLike or None, optional
+        The initializer for the bias. If None, no bias is added. Default: None.
+    w_mask : ArrayLike or Callable or None, optional
+        An optional mask applied to the weights during forward pass. Useful for implementing
+        structured sparsity or custom connectivity patterns. Default: None.
+    name : str, optional
+        The name of the module. Default: None.
+    param_type : type, optional
+        The type of parameter state to use. Default: ParamState.
+
+    Attributes
+    ----------
+    in_size : tuple of int
+        The input shape (L, C) without batch dimension.
+    out_size : tuple of int
+        The output shape (L_out, out_channels) without batch dimension.
+    in_channels : int
+        Number of input channels.
+    out_channels : int
+        Number of output channels.
+    kernel_size : tuple of int
+        Size of the convolving kernel.
+    weight : ParamState
+        The learnable weights (and bias if specified) of the module.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate as brainstate
+        >>> import jax.numpy as jnp
+        >>>
+        >>> # Create a 1D convolution layer
+        >>> conv = brainstate.nn.Conv1d(in_size=(28, 3), out_channels=16, kernel_size=5)
+        >>>
+        >>> # Apply to input: batch_size=2, length=28, channels=3
+        >>> x = jnp.ones((2, 28, 3))
+        >>> y = conv(x)
+        >>> print(y.shape)  # (2, 28, 16) with 'SAME' padding
+        >>>
+        >>> # Without batch dimension
+        >>> x_single = jnp.ones((28, 3))
+        >>> y_single = conv(x_single)
+        >>> print(y_single.shape)  # (28, 16)
+        >>>
+        >>> # With custom parameters
+        >>> conv = brainstate.nn.Conv1d(
+        ...     in_size=(100, 8),
+        ...     out_channels=32,
+        ...     kernel_size=3,
+        ...     stride=2,
+        ...     padding='VALID',
+        ...     b_init=brainstate.init.ZeroInit()
+        ... )
+
+    Notes
+    -----
+    **Output dimensions:**
+
+    The output shape depends on the padding mode:
+
+    - 'SAME': output length = ceil(input_length / stride)
+    - 'VALID': output length = ceil((input_length - kernel_size + 1) / stride)
+
+    **Grouped convolution:**
+
+    When groups > 1, the convolution becomes a grouped convolution where input and
+    output channels are divided into groups, reducing computational cost.
     """
     __module__ = 'brainstate.nn'
     num_spatial_dims: int = 1
 
 
 class Conv2d(_Conv):
-    """
+    """
+    Two-dimensional convolution layer.
 
-
+    Applies a 2D convolution over an input signal composed of several input planes.
+    The input should be a 4D array with the shape of ``[B, H, W, C]`` where B is batch size,
+    H is height, W is width, and C is the number of input channels (channels-last format).
+
+    This layer creates a convolution kernel that is convolved with the layer input
+    to produce a tensor of outputs. It is commonly used in computer vision tasks.
 
     Parameters
     ----------
-
+    in_size : tuple of int
+        The input shape without the batch dimension. For Conv2d: (H, W, C) where H is height,
+        W is width, and C is the number of input channels. This argument is important as it is
+        used to evaluate the output shape.
+    out_channels : int
+        The number of output channels (also called filters or feature maps). These determine
+        the depth of the output feature map.
+    kernel_size : int or tuple of int
+        The shape of the convolutional kernel. Can be:
+
+        - An integer (e.g., 3): creates a square kernel (3, 3)
+        - A tuple of two integers (e.g., (3, 5)): creates a (height, width) kernel
+    stride : int or tuple of int, optional
+        The stride of the convolution. Controls how much the kernel moves at each step.
+        Can be:
+
+        - An integer: same stride for both dimensions
+        - A tuple of two integers: (stride_height, stride_width)
+
+        Default: 1.
+    padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
+        The padding strategy. Options:
+
+        - 'SAME': output spatial size equals input size when stride=1
+        - 'VALID': no padding, output size reduced by kernel size
+        - int: same symmetric padding for all dimensions
+        - (pad_h, pad_w): different padding for each dimension
+        - [(pad_h_before, pad_h_after), (pad_w_before, pad_w_after)]: explicit padding
+
+        Default: 'SAME'.
+    lhs_dilation : int or tuple of int, optional
+        The dilation factor for the input (left-hand side). Controls spacing between input elements.
+        A value > 1 inserts zeros between input elements, equivalent to transposed convolution.
+        Default: 1.
+    rhs_dilation : int or tuple of int, optional
+        The dilation factor for the kernel (right-hand side). Also known as atrous convolution
+        or dilated convolution. Increases the receptive field without increasing parameters by
+        inserting zeros between kernel elements. Useful for capturing multi-scale context.
+        Default: 1.
+    groups : int, optional
+        Number of groups for grouped convolution. Must divide both `in_channels` and `out_channels`.
+
+        - groups=1: standard convolution (all-to-all connections)
+        - groups>1: grouped convolution (reduces parameters by factor of groups)
+        - groups=in_channels: depthwise convolution (each input channel convolved separately)
+
+        Default: 1.
+    w_init : Callable or ArrayLike, optional
+        Weight initializer for the convolutional kernel. Can be:
+
+        - An initializer instance (e.g., brainstate.init.XavierNormal())
+        - A callable that returns an array given a shape
+        - A direct array matching the kernel shape
+
+        Default: XavierNormal().
+    b_init : Callable or ArrayLike or None, optional
+        Bias initializer. If None, no bias term is added to the output.
+        Default: None.
+    w_mask : ArrayLike or Callable or None, optional
+        Optional weight mask for structured sparsity or custom connectivity. The mask is
+        element-wise multiplied with the kernel weights during the forward pass.
+        Default: None.
+    name : str, optional
+        Name identifier for this module instance.
+        Default: None.
+    param_type : type, optional
+        The parameter state class to use for managing learnable parameters.
+        Default: ParamState.
+
+    Attributes
+    ----------
+    in_size : tuple of int
+        The input shape (H, W, C) without batch dimension.
+    out_size : tuple of int
+        The output shape (H_out, W_out, out_channels) without batch dimension.
+    in_channels : int
+        Number of input channels.
+    out_channels : int
+        Number of output channels.
+    kernel_size : tuple of int
+        Size of the convolving kernel (height, width).
+    weight : ParamState
+        The learnable weights (and bias if specified) of the module.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate as brainstate
+        >>> import jax.numpy as jnp
+        >>>
+        >>> # Create a 2D convolution layer
+        >>> conv = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=64, kernel_size=3)
+        >>>
+        >>> # Apply to input: batch_size=8, height=32, width=32, channels=3
+        >>> x = jnp.ones((8, 32, 32, 3))
+        >>> y = conv(x)
+        >>> print(y.shape)  # (8, 32, 32, 64) with 'SAME' padding
+        >>>
+        >>> # Without batch dimension
+        >>> x_single = jnp.ones((32, 32, 3))
+        >>> y_single = conv(x_single)
+        >>> print(y_single.shape)  # (32, 32, 64)
+        >>>
+        >>> # With custom kernel size and stride
+        >>> conv = brainstate.nn.Conv2d(
+        ...     in_size=(224, 224, 3),
+        ...     out_channels=128,
+        ...     kernel_size=(5, 5),
+        ...     stride=2,
+        ...     padding='VALID'
+        ... )
+        >>>
+        >>> # Depthwise convolution (groups = in_channels)
+        >>> conv = brainstate.nn.Conv2d(
+        ...     in_size=(64, 64, 32),
+        ...     out_channels=32,
+        ...     kernel_size=3,
+        ...     groups=32
+        ... )
+
+    Notes
+    -----
+    **Output dimensions:**
+
+    The output spatial dimensions depend on the padding mode:
+
+    - 'SAME': output_size = ceil(input_size / stride)
+    - 'VALID': output_size = ceil((input_size - kernel_size + 1) / stride)
+
+    **Grouped convolution:**
+
+    When groups > 1, the input and output channels are divided into groups.
+    Each group is convolved independently, which can significantly reduce
+    computational cost while maintaining representational power.
     """
     __module__ = 'brainstate.nn'
     num_spatial_dims: int = 2
 
 
 class Conv3d(_Conv):
-    """
+    """
+    Three-dimensional convolution layer.
 
-
+    Applies a 3D convolution over an input signal composed of several input planes.
+    The input should be a 5D array with the shape of ``[B, H, W, D, C]`` where B is batch size,
+    H is height, W is width, D is depth, and C is the number of input channels (channels-last format).
+
+    This layer is commonly used for processing 3D data such as video sequences or
+    volumetric medical imaging data.
 
     Parameters
     ----------
-
+    in_size : tuple of int
+        The input shape without the batch dimension. For Conv3d: (H, W, D, C) where H is height,
+        W is width, D is depth, and C is the number of input channels. This argument is important
+        as it is used to evaluate the output shape.
+    out_channels : int
+        The number of output channels (also called filters or feature maps). These determine
+        the depth of the output feature map.
+    kernel_size : int or tuple of int
+        The shape of the convolutional kernel. Can be:
+
+        - An integer (e.g., 3): creates a cubic kernel (3, 3, 3)
+        - A tuple of three integers (e.g., (3, 5, 5)): creates a (height, width, depth) kernel
+    stride : int or tuple of int, optional
+        The stride of the convolution. Controls how much the kernel moves at each step.
+        Can be:
+
+        - An integer: same stride for all dimensions
+        - A tuple of three integers: (stride_h, stride_w, stride_d)
+        Default: 1.
+    padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
+        The padding strategy. Options:
+
+        - 'SAME': output spatial size equals input size when stride=1
+        - 'VALID': no padding, output size reduced by kernel size
+        - int: same symmetric padding for all dimensions
+        - (pad_h, pad_w, pad_d): different padding for each dimension
+        - [(pad_h_before, pad_h_after), (pad_w_before, pad_w_after), (pad_d_before, pad_d_after)]: explicit padding
+
+        Default: 'SAME'.
+    lhs_dilation : int or tuple of int, optional
+        The dilation factor for the input (left-hand side). Controls spacing between input elements.
+        A value > 1 inserts zeros between input elements, equivalent to transposed convolution.
+        Default: 1.
+    rhs_dilation : int or tuple of int, optional
+        The dilation factor for the kernel (right-hand side). Also known as atrous convolution
+        or dilated convolution. Increases the receptive field without increasing parameters by
+        inserting zeros between kernel elements. Particularly useful for 3D data to capture
+        larger temporal/spatial context.
+        Default: 1.
+    groups : int, optional
+        Number of groups for grouped convolution. Must divide both `in_channels` and `out_channels`.
+
+        - groups=1: standard convolution (all-to-all connections)
+        - groups>1: grouped convolution (significantly reduces parameters and computation for 3D)
+        - groups=in_channels: depthwise convolution (each input channel convolved separately)
+
+        Default: 1.
+    w_init : Callable or ArrayLike, optional
+        Weight initializer for the convolutional kernel. Can be:
+
+        - An initializer instance (e.g., brainstate.init.XavierNormal())
+        - A callable that returns an array given a shape
+        - A direct array matching the kernel shape
+
+        Default: XavierNormal().
+    b_init : Callable or ArrayLike or None, optional
+        Bias initializer. If None, no bias term is added to the output.
+        Default: None.
+    w_mask : ArrayLike or Callable or None, optional
+        Optional weight mask for structured sparsity or custom connectivity. The mask is
+        element-wise multiplied with the kernel weights during the forward pass.
+        Default: None.
+    name : str, optional
+        Name identifier for this module instance.
+        Default: None.
+    param_type : type, optional
+        The parameter state class to use for managing learnable parameters.
+        Default: ParamState.
+
+    Attributes
+    ----------
+    in_size : tuple of int
+        The input shape (H, W, D, C) without batch dimension.
+    out_size : tuple of int
+        The output shape (H_out, W_out, D_out, out_channels) without batch dimension.
+    in_channels : int
+        Number of input channels.
+    out_channels : int
+        Number of output channels.
+    kernel_size : tuple of int
+        Size of the convolving kernel (height, width, depth).
+    weight : ParamState
+        The learnable weights (and bias if specified) of the module.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate as brainstate
+        >>> import jax.numpy as jnp
+        >>>
+        >>> # Create a 3D convolution layer for video data
+        >>> conv = brainstate.nn.Conv3d(in_size=(16, 64, 64, 3), out_channels=32, kernel_size=3)
+        >>>
+        >>> # Apply to input: batch_size=4, frames=16, height=64, width=64, channels=3
+        >>> x = jnp.ones((4, 16, 64, 64, 3))
+        >>> y = conv(x)
+        >>> print(y.shape)  # (4, 16, 64, 64, 32) with 'SAME' padding
+        >>>
+        >>> # Without batch dimension
+        >>> x_single = jnp.ones((16, 64, 64, 3))
+        >>> y_single = conv(x_single)
+        >>> print(y_single.shape)  # (16, 64, 64, 32)
+        >>>
+        >>> # For medical imaging with custom parameters
+        >>> conv = brainstate.nn.Conv3d(
+        ...     in_size=(32, 32, 32, 1),
+        ...     out_channels=64,
+        ...     kernel_size=(3, 3, 3),
+        ...     stride=2,
+        ...     padding='VALID',
+        ...     b_init=brainstate.init.Constant(0.1)
+        ... )
+
+    Notes
+    -----
+    **Output dimensions:**
+
+    The output spatial dimensions depend on the padding mode:
+
+    - 'SAME': output_size = ceil(input_size / stride)
+    - 'VALID': output_size = ceil((input_size - kernel_size + 1) / stride)
+
+    **Performance considerations:**
+
+    3D convolutions are computationally expensive. Consider using:
+
+    - Smaller kernel sizes
+    - Grouped convolutions (groups > 1)
+    - Separable convolutions for large-scale applications
     """
     __module__ = 'brainstate.nn'
     num_spatial_dims: int = 3
 
 
-_conv_doc = '''
-  in_size: tuple of int
-    The input shape, without the batch size. This argument is important, since it is
-    used to evaluate the shape of the output.
-  out_channels: int
-    The number of output channels.
-  kernel_size: int, sequence of int
-    The shape of the convolutional kernel.
-    For 1D convolution, the kernel size can be passed as an integer.
-    For all other cases, it must be a sequence of integers.
-  stride: int, sequence of int
-    An integer or a sequence of `n` integers, representing the inter-window strides (default: 1).
-  padding: str, int, sequence of int, sequence of tuple
-    Either the string `'SAME'`, the string `'VALID'`, or a sequence of n `(low,
-    high)` integer pairs that give the padding to apply before and after each
-    spatial dimension.
-  lhs_dilation: int, sequence of int
-    An integer or a sequence of `n` integers, giving the
-    dilation factor to apply in each spatial dimension of `inputs`
-    (default: 1). Convolution with input dilation `d` is equivalent to
-    transposed convolution with stride `d`.
-  rhs_dilation: int, sequence of int
-    An integer or a sequence of `n` integers, giving the
-    dilation factor to apply in each spatial dimension of the convolution
-    kernel (default: 1). Convolution with kernel dilation
-    is also known as 'atrous convolution'.
-  groups: int
-    If specified, divides the input features into groups. default 1.
-  w_init: Callable, ArrayLike, Initializer
-    The initializer for the convolutional kernel.
-  b_init: Optional, Callable, ArrayLike, Initializer
-    The initializer for the bias.
-  w_mask: ArrayLike, Callable, Optional
-    The optional mask of the weights.
-  mode: Mode
-    The computation mode of the current object. Default it is `training`.
-  name: str, Optional
-    The name of the object.
-'''
-
-Conv1d.__doc__ = Conv1d.__doc__ % _conv_doc
-Conv2d.__doc__ = Conv2d.__doc__ % _conv_doc
-Conv3d.__doc__ = Conv3d.__doc__ % _conv_doc
-
-
 class _ScaledWSConv(_BaseConv):
     def __init__(
         self,
@@ -346,19 +783,23 @@ class _ScaledWSConv(_BaseConv):
         w_init: Union[Callable, ArrayLike] = init.XavierNormal(),
         b_init: Optional[Union[Callable, ArrayLike]] = None,
         w_mask: Optional[Union[ArrayLike, Callable]] = None,
+        channel_first: bool = False,
         name: str = None,
         param_type: type = ParamState,
     ):
-        super().__init__(
-
-
-
-
-
-
-
-
-
+        super().__init__(
+            in_size=in_size,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            lhs_dilation=lhs_dilation,
+            rhs_dilation=rhs_dilation,
+            groups=groups,
+            w_mask=w_mask,
+            channel_first=channel_first,
+            name=name,
+        )
 
         self.w_initializer = w_init
         self.b_initializer = b_init
@@ -384,9 +825,14 @@ class _ScaledWSConv(_BaseConv):
         self.weight = param_type(params)
 
         # Evaluate the output shape
+        if self.channel_first:
+            test_input_shape = (128,) + self.in_size
+        else:
+            test_input_shape = (128,) + self.in_size
+
         abstract_y = jax.eval_shape(
             self._conv_op,
-            jax.ShapeDtypeStruct(
+            jax.ShapeDtypeStruct(test_input_shape, weight.dtype),
             params
         )
         y_shape = abstract_y.shape[1:]
@@ -394,7 +840,7 @@ class _ScaledWSConv(_BaseConv):
 
     def _conv_op(self, x, params):
         w = params['weight']
-        w =
+        w = weight_standardization(w, self.eps, params.get('gain', None))
         if self.w_mask is not None:
             w = w * self.w_mask
         y = jax.lax.conv_general_dilated(
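`_conv_op` now delegates to the `weight_standardization` helper imported from `._normalizations` (the old inline right-hand side is truncated in this rendering). A minimal sketch of what such a helper computes, assuming per-output-channel statistics as described in the ScaledWSConv docstrings below; the authoritative implementation lives in `brainstate/nn/_normalizations.py`:

    import jax.numpy as jnp

    def weight_standardization_sketch(w, eps=1e-4, gain=None):
        # Standardize over every axis except the trailing output-channel axis.
        axes = tuple(range(w.ndim - 1))
        mean = jnp.mean(w, axis=axes, keepdims=True)
        std = jnp.std(w, axis=axes, keepdims=True)
        w_hat = (w - mean) / (std + eps)
        return w_hat if gain is None else w_hat * gain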
@@ -413,89 +859,1152 @@ class _ScaledWSConv(_BaseConv):
|
|
413
859
|
|
414
860
|
|
415
861
|
class ScaledWSConv1d(_ScaledWSConv):
|
416
|
-
"""
|
862
|
+
"""
|
863
|
+
One-dimensional convolution with weight standardization.
|
417
864
|
|
418
|
-
|
865
|
+
This layer applies weight standardization to the convolutional kernel before
|
866
|
+
performing the convolution operation. Weight standardization normalizes the
|
867
|
+
weights to have zero mean and unit variance, which can accelerate training
|
868
|
+
and improve model performance, especially when combined with group normalization.
|
869
|
+
|
870
|
+
The input should be a 3D array with the shape of ``[B, L, C]`` where B is batch size,
|
871
|
+
L is the sequence length, and C is the number of input channels.
|
419
872
|
|
420
873
|
Parameters
|
421
874
|
----------
|
422
|
-
|
875
|
+
in_size : tuple of int
|
876
|
+
The input shape without the batch dimension. For Conv1d: (L, C) where L is the sequence
|
877
|
+
length and C is the number of input channels. This argument is important as it is used
|
878
|
+
to evaluate the output shape.
|
879
|
+
out_channels : int
|
880
|
+
The number of output channels (also called filters or feature maps). These determine
|
881
|
+
the depth of the output feature map.
|
882
|
+
kernel_size : int or tuple of int
|
883
|
+
The shape of the convolutional kernel. For 1D convolution, can be:
|
884
|
+
|
885
|
+
- An integer (e.g., 5): creates a kernel of size 5
|
886
|
+
- A tuple with one integer (e.g., (5,)): equivalent to the above
|
887
|
+
stride : int or tuple of int, optional
|
888
|
+
The stride of the convolution. Controls how much the kernel moves at each step.
|
889
|
+
Default: 1.
|
890
|
+
padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
|
891
|
+
The padding strategy. Options:
|
892
|
+
|
893
|
+
- 'SAME': output length equals input length when stride=1
|
894
|
+
- 'VALID': no padding, output length reduced by kernel size
|
895
|
+
- int: symmetric padding
|
896
|
+
- (pad_before, pad_after): explicit padding for the sequence dimension
|
897
|
+
|
898
|
+
Default: 'SAME'.
|
899
|
+
lhs_dilation : int or tuple of int, optional
|
900
|
+
The dilation factor for the input (left-hand side). Controls spacing between input elements.
|
901
|
+
A value > 1 inserts zeros between input elements, equivalent to transposed convolution.
|
902
|
+
Default: 1.
|
903
|
+
rhs_dilation : int or tuple of int, optional
|
904
|
+
The dilation factor for the kernel (right-hand side). Also known as atrous convolution
|
905
|
+
or dilated convolution. Increases the receptive field without increasing parameters by
|
906
|
+
inserting zeros between kernel elements. Useful for capturing long-range dependencies.
|
907
|
+
Default: 1.
|
908
|
+
groups : int, optional
|
909
|
+
Number of groups for grouped convolution. Must divide both `in_channels` and `out_channels`.
|
910
|
+
|
911
|
+
- groups=1: standard convolution (all-to-all connections)
|
912
|
+
- groups>1: grouped convolution (reduces parameters by factor of groups)
|
913
|
+
- groups=in_channels: depthwise convolution (each input channel convolved separately)
|
914
|
+
|
915
|
+
Default: 1.
|
916
|
+
w_init : Callable or ArrayLike, optional
|
917
|
+
Weight initializer for the convolutional kernel. Can be:
|
918
|
+
|
919
|
+
- An initializer instance (e.g., brainstate.init.XavierNormal())
|
920
|
+
- A callable that returns an array given a shape
|
921
|
+
- A direct array matching the kernel shape
|
922
|
+
|
923
|
+
Default: XavierNormal().
|
924
|
+
b_init : Callable or ArrayLike or None, optional
|
925
|
+
Bias initializer. If None, no bias term is added to the output.
|
926
|
+
Default: None.
|
927
|
+
ws_gain : bool, optional
|
928
|
+
Whether to include a learnable per-channel gain parameter in weight standardization.
|
929
|
+
When True, adds a scaling factor that can be learned during training, improving
|
930
|
+
model expressiveness. Recommended for most applications.
|
931
|
+
Default: True.
|
932
|
+
eps : float, optional
|
933
|
+
Small constant for numerical stability in weight standardization. Prevents division
|
934
|
+
by zero when computing weight standard deviation. Typical values: 1e-4 to 1e-5.
|
935
|
+
Default: 1e-4.
|
936
|
+
w_mask : ArrayLike or Callable or None, optional
|
937
|
+
Optional weight mask for structured sparsity or custom connectivity. The mask is
|
938
|
+
element-wise multiplied with the standardized kernel weights during the forward pass.
|
939
|
+
Default: None.
|
940
|
+
name : str, optional
|
941
|
+
Name identifier for this module instance.
|
942
|
+
Default: None.
|
943
|
+
param_type : type, optional
|
944
|
+
The parameter state class to use for managing learnable parameters.
|
945
|
+
Default: ParamState.
|
946
|
+
|
947
|
+
Attributes
|
948
|
+
----------
|
949
|
+
in_size : tuple of int
|
950
|
+
The input shape (L, C) without batch dimension.
|
951
|
+
out_size : tuple of int
|
952
|
+
The output shape (L_out, out_channels) without batch dimension.
|
953
|
+
in_channels : int
|
954
|
+
Number of input channels.
|
955
|
+
out_channels : int
|
956
|
+
Number of output channels.
|
957
|
+
kernel_size : tuple of int
|
958
|
+
Size of the convolving kernel.
|
959
|
+
weight : ParamState
|
960
|
+
The learnable weights (and bias if specified) of the module.
|
961
|
+
eps : float
|
962
|
+
Small constant for numerical stability in weight standardization.
|
963
|
+
|
964
|
+
Examples
|
965
|
+
--------
|
966
|
+
.. code-block:: python
|
967
|
+
|
968
|
+
>>> import brainstate as brainstate
|
969
|
+
>>> import jax.numpy as jnp
|
970
|
+
>>>
|
971
|
+
>>> # Create a 1D convolution with weight standardization
|
972
|
+
>>> conv = brainstate.nn.ScaledWSConv1d(
|
973
|
+
... in_size=(100, 16),
|
974
|
+
... out_channels=32,
|
975
|
+
... kernel_size=5
|
976
|
+
... )
|
977
|
+
>>>
|
978
|
+
>>> # Apply to input
|
979
|
+
>>> x = jnp.ones((4, 100, 16))
|
980
|
+
>>> y = conv(x)
|
981
|
+
>>> print(y.shape) # (4, 100, 32)
|
982
|
+
>>>
|
983
|
+
>>> # With custom epsilon and no gain
|
984
|
+
>>> conv = brainstate.nn.ScaledWSConv1d(
|
985
|
+
... in_size=(50, 8),
|
986
|
+
... out_channels=16,
|
987
|
+
... kernel_size=3,
|
988
|
+
... ws_gain=False,
|
989
|
+
... eps=1e-5
|
990
|
+
... )
|
991
|
+
|
992
|
+
Notes
|
993
|
+
-----
|
994
|
+
**Weight standardization formula:**
|
995
|
+
|
996
|
+
Weight standardization reparameterizes the convolutional weights as:
|
997
|
+
|
998
|
+
.. math::
|
999
|
+
\\hat{W} = g \\cdot \\frac{W - \\mu_W}{\\sigma_W + \\epsilon}
|
1000
|
+
|
1001
|
+
where :math:`\\mu_W` and :math:`\\sigma_W` are the mean and standard deviation
|
1002
|
+
of the weights, :math:`g` is a learnable gain parameter (if ws_gain=True),
|
1003
|
+
and :math:`\\epsilon` is a small constant for numerical stability.
|
1004
|
+
|
1005
|
+
**When to use:**
|
1006
|
+
|
1007
|
+
This technique is particularly effective when used with Group Normalization
|
1008
|
+
instead of Batch Normalization, as it reduces the dependence on batch statistics.
|
1009
|
+
|
1010
|
+
References
|
1011
|
+
----------
|
1012
|
+
.. [1] Qiao, S., Wang, H., Liu, C., Shen, W., & Yuille, A. (2019).
|
1013
|
+
Weight Standardization. arXiv preprint arXiv:1903.10520.
|
423
1014
|
"""
|
424
1015
|
__module__ = 'brainstate.nn'
|
425
1016
|
num_spatial_dims: int = 1
|
426
1017
|
|
427
1018
|
|
428
1019
|
class ScaledWSConv2d(_ScaledWSConv):
|
429
|
-
"""
|
1020
|
+
"""
|
1021
|
+
Two-dimensional convolution with weight standardization.
|
1022
|
+
|
1023
|
+
This layer applies weight standardization to the convolutional kernel before
|
1024
|
+
performing the convolution operation. Weight standardization normalizes the
|
1025
|
+
weights to have zero mean and unit variance, improving training dynamics and
|
1026
|
+
model generalization, particularly in combination with group normalization.
|
430
1027
|
|
431
|
-
The input should be a
|
1028
|
+
The input should be a 4D array with the shape of ``[B, H, W, C]`` where B is batch size,
|
1029
|
+
H is height, W is width, and C is the number of input channels (channels-last format).
|
432
1030
|
|
433
1031
|
Parameters
|
434
1032
|
----------
|
435
|
-
|
1033
|
+
in_size : tuple of int
|
1034
|
+
The input shape without the batch dimension. For Conv2d: (H, W, C) where H is height,
|
1035
|
+
W is width, and C is the number of input channels. This argument is important as it is
|
1036
|
+
used to evaluate the output shape.
|
1037
|
+
out_channels : int
|
1038
|
+
The number of output channels (also called filters or feature maps). These determine
|
1039
|
+
the depth of the output feature map.
|
1040
|
+
kernel_size : int or tuple of int
|
1041
|
+
The shape of the convolutional kernel. Can be:
|
1042
|
+
|
1043
|
+
- An integer (e.g., 3): creates a square kernel (3, 3)
|
1044
|
+
- A tuple of two integers (e.g., (3, 5)): creates a (height, width) kernel
|
1045
|
+
stride : int or tuple of int, optional
|
1046
|
+
The stride of the convolution. Controls how much the kernel moves at each step.
|
1047
|
+
Can be:
|
1048
|
+
|
1049
|
+
- An integer: same stride for both dimensions
|
1050
|
+
- A tuple of two integers: (stride_height, stride_width)
|
1051
|
+
|
1052
|
+
Default: 1.
|
1053
|
+
padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
|
1054
|
+
The padding strategy. Options:
|
1055
|
+
|
1056
|
+
- 'SAME': output spatial size equals input size when stride=1
|
1057
|
+
- 'VALID': no padding, output size reduced by kernel size
|
1058
|
+
- int: same symmetric padding for all dimensions
|
1059
|
+
- (pad_h, pad_w): different padding for each dimension
|
1060
|
+
- [(pad_h_before, pad_h_after), (pad_w_before, pad_w_after)]: explicit padding
|
1061
|
+
|
1062
|
+
Default: 'SAME'.
|
1063
|
+
lhs_dilation : int or tuple of int, optional
|
1064
|
+
The dilation factor for the input (left-hand side). Controls spacing between input elements.
|
1065
|
+
A value > 1 inserts zeros between input elements, equivalent to transposed convolution.
|
1066
|
+
Default: 1.
|
1067
|
+
rhs_dilation : int or tuple of int, optional
|
1068
|
+
The dilation factor for the kernel (right-hand side). Also known as atrous convolution
|
1069
|
+
or dilated convolution. Increases the receptive field without increasing parameters by
|
1070
|
+
inserting zeros between kernel elements. Useful for semantic segmentation and dense
|
1071
|
+
prediction tasks.
|
1072
|
+
Default: 1.
|
1073
|
+
groups : int, optional
|
1074
|
+
Number of groups for grouped convolution. Must divide both `in_channels` and `out_channels`.
|
1075
|
+
|
1076
|
+
- groups=1: standard convolution (all-to-all connections)
|
1077
|
+
- groups>1: grouped convolution (reduces parameters by factor of groups)
|
1078
|
+
- groups=in_channels: depthwise convolution (each input channel convolved separately)
|
1079
|
+
|
1080
|
+
Default: 1.
|
1081
|
+
w_init : Callable or ArrayLike, optional
|
1082
|
+
Weight initializer for the convolutional kernel. Can be:
|
1083
|
+
|
1084
|
+
- An initializer instance (e.g., brainstate.init.XavierNormal())
|
1085
|
+
- A callable that returns an array given a shape
|
1086
|
+
- A direct array matching the kernel shape
|
1087
|
+
|
1088
|
+
Default: XavierNormal().
|
1089
|
+
b_init : Callable or ArrayLike or None, optional
|
1090
|
+
Bias initializer. If None, no bias term is added to the output.
|
1091
|
+
Default: None.
|
1092
|
+
ws_gain : bool, optional
|
1093
|
+
Whether to include a learnable per-channel gain parameter in weight standardization.
|
1094
|
+
When True, adds a scaling factor that can be learned during training, improving
|
1095
|
+
model expressiveness. Highly recommended when using with Group Normalization.
|
1096
|
+
Default: True.
|
1097
|
+
eps : float, optional
|
1098
|
+
Small constant for numerical stability in weight standardization. Prevents division
|
1099
|
+
by zero when computing weight standard deviation. Typical values: 1e-4 to 1e-5.
|
1100
|
+
Default: 1e-4.
|
1101
|
+
w_mask : ArrayLike or Callable or None, optional
|
1102
|
+
Optional weight mask for structured sparsity or custom connectivity. The mask is
|
1103
|
+
element-wise multiplied with the standardized kernel weights during the forward pass.
|
1104
|
+
Default: None.
|
1105
|
+
name : str, optional
|
1106
|
+
Name identifier for this module instance.
|
1107
|
+
Default: None.
|
1108
|
+
param_type : type, optional
|
1109
|
+
The parameter state class to use for managing learnable parameters.
|
1110
|
+
Default: ParamState.
|
1111
|
+
|
1112
|
+
Attributes
|
1113
|
+
----------
|
1114
|
+
in_size : tuple of int
|
1115
|
+
The input shape (H, W, C) without batch dimension.
|
1116
|
+
out_size : tuple of int
|
1117
|
+
The output shape (H_out, W_out, out_channels) without batch dimension.
|
1118
|
+
in_channels : int
|
1119
|
+
Number of input channels.
|
1120
|
+
out_channels : int
|
1121
|
+
Number of output channels.
|
1122
|
+
kernel_size : tuple of int
|
1123
|
+
Size of the convolving kernel (height, width).
|
1124
|
+
weight : ParamState
|
1125
|
+
The learnable weights (and bias if specified) of the module.
|
1126
|
+
eps : float
|
1127
|
+
Small constant for numerical stability in weight standardization.
|
1128
|
+
|
1129
|
+
Examples
|
1130
|
+
--------
|
1131
|
+
.. code-block:: python
|
1132
|
+
|
1133
|
+
>>> import brainstate as brainstate
|
1134
|
+
>>> import jax.numpy as jnp
|
1135
|
+
>>>
|
1136
|
+
>>> # Create a 2D convolution with weight standardization
|
1137
|
+
>>> conv = brainstate.nn.ScaledWSConv2d(
|
1138
|
+
... in_size=(64, 64, 3),
|
1139
|
+
... out_channels=32,
|
1140
|
+
... kernel_size=3
|
1141
|
+
... )
|
1142
|
+
>>>
|
1143
|
+
>>> # Apply to input
|
1144
|
+
>>> x = jnp.ones((8, 64, 64, 3))
|
1145
|
+
>>> y = conv(x)
|
1146
|
+
>>> print(y.shape) # (8, 64, 64, 32)
|
1147
|
+
>>>
|
1148
|
+
>>> # Combine with custom settings for ResNet-style architecture
|
1149
|
+
>>> conv = brainstate.nn.ScaledWSConv2d(
|
1150
|
+
... in_size=(224, 224, 3),
|
1151
|
+
... out_channels=64,
|
1152
|
+
... kernel_size=7,
|
1153
|
+
... stride=2,
|
1154
|
+
... padding='SAME',
|
1155
|
+
... ws_gain=True,
|
1156
|
+
... b_init=brainstate.init.ZeroInit()
|
1157
|
+
... )
|
1158
|
+
>>>
|
1159
|
+
>>> # Depthwise separable convolution with weight standardization
|
1160
|
+
>>> conv = brainstate.nn.ScaledWSConv2d(
|
1161
|
+
... in_size=(32, 32, 128),
|
1162
|
+
... out_channels=128,
|
1163
|
+
... kernel_size=3,
|
1164
|
+
... groups=128,
|
1165
|
+
... ws_gain=False
|
1166
|
+
... )
|
1167
|
+
|
1168
|
+
Notes
|
1169
|
+
-----
|
1170
|
+
**Weight standardization formula:**
|
1171
|
+
|
1172
|
+
Weight standardization reparameterizes the convolutional weights as:
|
1173
|
+
|
1174
|
+
.. math::
|
1175
|
+
\\hat{W} = g \\cdot \\frac{W - \\mu_W}{\\sigma_W + \\epsilon}
|
1176
|
+
|
1177
|
+
where :math:`\\mu_W` and :math:`\\sigma_W` are the mean and standard deviation
|
1178
|
+
of the weights computed per output channel, :math:`g` is a learnable gain
|
1179
|
+
parameter (if ws_gain=True), and :math:`\\epsilon` is a small constant.
|
1180
|
+
|
1181
|
+
**Benefits:**
|
1182
|
+
|
1183
|
+
- Reduces internal covariate shift
|
1184
|
+
- Smooths the loss landscape
|
1185
|
+
- Works well with Group Normalization
|
1186
|
+
- Improves training stability with small batch sizes
|
1187
|
+
- Enables training deeper networks more easily
|
1188
|
+
|
1189
|
+
References
|
1190
|
+
----------
|
1191
|
+
.. [1] Qiao, S., Wang, H., Liu, C., Shen, W., & Yuille, A. (2019).
|
1192
|
+
Weight Standardization. arXiv preprint arXiv:1903.10520.
|
436
1193
|
"""
|
437
1194
|
__module__ = 'brainstate.nn'
|
438
1195
|
num_spatial_dims: int = 2


 class ScaledWSConv3d(_ScaledWSConv):
-    """
+    """
+    Three-dimensional convolution with weight standardization.
+
+    This layer applies weight standardization to the convolutional kernel before
+    performing the 3D convolution operation. Weight standardization normalizes the
+    weights to have zero mean and unit variance, which improves training dynamics,
+    especially for 3D networks, which are typically deeper and more parameter-heavy.

-    The input should be a
+    The input should be a 5D array with the shape of ``[B, H, W, D, C]`` where B is batch size,
+    H is height, W is width, D is depth, and C is the number of input channels (channels-last format).

     Parameters
     ----------
-
+    in_size : tuple of int
+        The input shape without the batch dimension. For ScaledWSConv3d: (H, W, D, C) where H is height,
+        W is width, D is depth, and C is the number of input channels. This argument is important
+        as it is used to evaluate the output shape.
+    out_channels : int
+        The number of output channels (also called filters or feature maps). These determine
+        the depth of the output feature map.
+    kernel_size : int or tuple of int
+        The shape of the convolutional kernel. Can be:
+
+        - An integer (e.g., 3): creates a cubic kernel (3, 3, 3)
+        - A tuple of three integers (e.g., (3, 5, 5)): creates a (height, width, depth) kernel
+    stride : int or tuple of int, optional
+        The stride of the convolution. Controls how much the kernel moves at each step.
+        Can be:
+
+        - An integer: same stride for all dimensions
+        - A tuple of three integers: (stride_h, stride_w, stride_d)
+
+        Default: 1.
+    padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
+        The padding strategy. Options:
+
+        - 'SAME': output spatial size equals input size when stride=1
+        - 'VALID': no padding, output size reduced by kernel size
+        - int: same symmetric padding for all dimensions
+        - (pad_h, pad_w, pad_d): different padding for each dimension
+        - [(pad_h_before, pad_h_after), (pad_w_before, pad_w_after), (pad_d_before, pad_d_after)]: explicit padding
+
+        Default: 'SAME'.
+    lhs_dilation : int or tuple of int, optional
+        The dilation factor for the input (left-hand side). Controls spacing between input elements.
+        A value > 1 inserts zeros between input elements, equivalent to transposed convolution.
+        Default: 1.
+    rhs_dilation : int or tuple of int, optional
+        The dilation factor for the kernel (right-hand side). Also known as atrous or
+        dilated convolution. It increases the receptive field without increasing parameters by
+        inserting zeros between kernel elements, which is particularly valuable in 3D for
+        capturing multi-scale temporal/spatial context efficiently.
+        Default: 1.
+    groups : int, optional
+        Number of groups for grouped convolution. Must divide both `in_channels` and `out_channels`.
+
+        - groups=1: standard convolution (all-to-all connections)
+        - groups>1: grouped convolution (critical for reducing 3D conv computational cost;
+          see the parameter-count sketch after this class)
+        - groups=in_channels: depthwise convolution (each input channel convolved separately)
+
+        Default: 1.
+    w_init : Callable or ArrayLike, optional
+        Weight initializer for the convolutional kernel. Can be:
+
+        - An initializer instance (e.g., brainstate.init.XavierNormal())
+        - A callable that returns an array given a shape
+        - A direct array matching the kernel shape
+
+        Default: XavierNormal().
+    b_init : Callable or ArrayLike or None, optional
+        Bias initializer. If None, no bias term is added to the output.
+        Default: None.
+    ws_gain : bool, optional
+        Whether to include a learnable per-channel gain parameter in weight standardization.
+        When True, adds a scaling factor that can be learned during training, improving
+        model expressiveness. Particularly beneficial for deep 3D networks.
+        Default: True.
+    eps : float, optional
+        Small constant for numerical stability in weight standardization. Prevents division
+        by zero when computing the weight standard deviation. Typical values: 1e-4 to 1e-5.
+        Default: 1e-4.
+    w_mask : ArrayLike or Callable or None, optional
+        Optional weight mask for structured sparsity or custom connectivity. The mask is
+        element-wise multiplied with the standardized kernel weights during the forward pass.
+        Default: None.
+    name : str, optional
+        Name identifier for this module instance.
+        Default: None.
+    param_type : type, optional
+        The parameter state class to use for managing learnable parameters.
+        Default: ParamState.
+
+    Attributes
+    ----------
+    in_size : tuple of int
+        The input shape (H, W, D, C) without batch dimension.
+    out_size : tuple of int
+        The output shape (H_out, W_out, D_out, out_channels) without batch dimension.
+    in_channels : int
+        Number of input channels.
+    out_channels : int
+        Number of output channels.
+    kernel_size : tuple of int
+        Size of the convolving kernel (height, width, depth).
+    weight : ParamState
+        The learnable weights (and bias if specified) of the module.
+    eps : float
+        Small constant for numerical stability in weight standardization.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate
+        >>> import jax.numpy as jnp
+        >>>
+        >>> # Create a 3D convolution with weight standardization for video
+        >>> conv = brainstate.nn.ScaledWSConv3d(
+        ...     in_size=(16, 64, 64, 3),
+        ...     out_channels=32,
+        ...     kernel_size=3
+        ... )
+        >>>
+        >>> # Apply to input
+        >>> x = jnp.ones((4, 16, 64, 64, 3))
+        >>> y = conv(x)
+        >>> print(y.shape)  # (4, 16, 64, 64, 32)
+        >>>
+        >>> # For medical imaging with custom parameters
+        >>> conv = brainstate.nn.ScaledWSConv3d(
+        ...     in_size=(32, 32, 32, 1),
+        ...     out_channels=64,
+        ...     kernel_size=(3, 3, 3),
+        ...     stride=2,
+        ...     ws_gain=True,
+        ...     eps=1e-5,
+        ...     b_init=brainstate.init.Constant(0.01)
+        ... )
+        >>>
+        >>> # 3D grouped convolution with weight standardization
+        >>> conv = brainstate.nn.ScaledWSConv3d(
+        ...     in_size=(8, 16, 16, 64),
+        ...     out_channels=64,
+        ...     kernel_size=3,
+        ...     groups=8,
+        ...     ws_gain=False
+        ... )
+
+    Notes
+    -----
+    **Weight standardization formula:**
+
+    Weight standardization reparameterizes the convolutional weights as:
+
+    .. math::
+        \\hat{W} = g \\cdot \\frac{W - \\mu_W}{\\sigma_W + \\epsilon}
+
+    where :math:`\\mu_W` and :math:`\\sigma_W` are the mean and standard deviation
+    of the weights, :math:`g` is a learnable gain parameter (if ws_gain=True),
+    and :math:`\\epsilon` is a small constant for numerical stability.
+
+    **Why weight standardization for 3D:**
+
+    For 3D convolutions, weight standardization is particularly beneficial because:
+
+    - 3D networks are typically much deeper and harder to train
+    - It reduces sensitivity to weight initialization
+    - It improves gradient flow through very deep networks
+    - It works well with limited computational resources (small batches)
+    - It is compatible with Group Normalization for batch-independent normalization
+
+    **Applications:**
+
+    Video understanding, medical imaging (CT, MRI scans), 3D object recognition,
+    and temporal sequence modeling.
+
+    References
+    ----------
+    .. [1] Qiao, S., Wang, H., Liu, C., Shen, W., & Yuille, A. (2019).
+           Weight Standardization. arXiv preprint arXiv:1903.10520.
     """
     __module__ = 'brainstate.nn'
     num_spatial_dims: int = 3
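
The cost claim behind the grouped example above can be made concrete with a parameter count. A sketch, assuming the kernel layout (k, k, k, out_channels, in_channels // groups) used for the transposed-conv kernels later in this file; regular conv layouts order the channel axes differently, but the count is the same:

    k, cin, cout = 3, 64, 64
    for groups in (1, 8, 64):
        n_params = k ** 3 * cout * (cin // groups)
        print(groups, n_params)
    # groups=1  -> 110592
    # groups=8  -> 13824   (the example above: 8x fewer weights)
    # groups=64 -> 1728    (depthwise)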


-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+class _ConvTranspose(_BaseConv):
+    """Base class for transposed convolution layers."""
+    num_spatial_dims: int = None
+
+    def __init__(
+        self,
+        in_size: Sequence[int],
+        out_channels: int,
+        kernel_size: Union[int, Tuple[int, ...]],
+        stride: Union[int, Tuple[int, ...]] = 1,
+        padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
+        lhs_dilation: Union[int, Tuple[int, ...]] = 1,
+        rhs_dilation: Union[int, Tuple[int, ...]] = 1,
+        groups: int = 1,
+        w_init: Union[Callable, ArrayLike] = init.XavierNormal(),
+        b_init: Optional[Union[Callable, ArrayLike]] = None,
+        w_mask: Optional[Union[ArrayLike, Callable]] = None,
+        channel_first: bool = False,
+        name: str = None,
+        param_type: type = ParamState,
+    ):
+        # Initialize with transpose=True for dimension numbers
+        Module.__init__(self, name=name)
+
+        # general parameters
+        assert self.num_spatial_dims + 1 == len(in_size)
+        self.in_size = tuple(in_size)
+        self.channel_first = channel_first
+        self.channels_last = not channel_first
+
+        # Determine in_channels based on channel_first
+        if self.channel_first:
+            self.in_channels = in_size[0]
+        else:
+            self.in_channels = in_size[-1]
+
+        self.out_channels = out_channels
+        self.stride = replicate(stride, self.num_spatial_dims, 'stride')
+        self.kernel_size = replicate(kernel_size, self.num_spatial_dims, 'kernel_size')
+        self.lhs_dilation = replicate(lhs_dilation, self.num_spatial_dims, 'lhs_dilation')
+        self.rhs_dilation = replicate(rhs_dilation, self.num_spatial_dims, 'rhs_dilation')
+        self.groups = groups
+        self.dimension_numbers = to_dimension_numbers(
+            self.num_spatial_dims,
+            channels_last=self.channels_last,
+            transpose=True  # Key difference from regular Conv
+        )
+
+        # the padding parameter
+        # For transposed convolution, string padding needs to be converted to explicit padding
+        # when using lhs_dilation (stride) > 1
+        if isinstance(padding, str):
+            assert padding in ['SAME', 'VALID']
+            self.padding_mode = padding
+            # Compute explicit padding for transposed convolution
+            if max(self.stride) > 1:
+                # For transposed conv with stride, compute padding to achieve the desired output size
+                spatial_in_size = self.in_size[:-1] if not self.channel_first else self.in_size[1:]
+                if padding == 'SAME':
+                    # For SAME padding with transposed conv: output_size = input_size * stride
+                    # Compute the required padding to achieve this
+                    explicit_padding = []
+                    for i, (k, s, in_dim) in enumerate(zip(self.kernel_size, self.stride, spatial_in_size)):
+                        # Desired output size
+                        out_dim = in_dim * s
+                        # Calculate total padding needed
+                        # For transposed conv: out = (in - 1) * stride + kernel - 2 * pad
+                        # Solving for pad: pad = (kernel + (in - 1) * stride - out) // 2
+                        total_pad = max(k + (in_dim - 1) * s - out_dim, 0)
+                        pad_left = total_pad // 2
+                        pad_right = total_pad - pad_left
+                        explicit_padding.append((pad_left, pad_right))
+                    padding = tuple(explicit_padding)
+                else:  # 'VALID'
+                    # For VALID padding: no padding
+                    padding = tuple((0, 0) for _ in range(self.num_spatial_dims))
+            # If stride is 1, keep string padding
+        elif isinstance(padding, int):
+            self.padding_mode = 'explicit'
+            padding = tuple((padding, padding) for _ in range(self.num_spatial_dims))
+        elif isinstance(padding, (tuple, list)):
+            self.padding_mode = 'explicit'
+            if isinstance(padding[0], int):
+                padding = (padding,) * self.num_spatial_dims
+            elif isinstance(padding[0], (tuple, list)):
+                if len(padding) == 1:
+                    padding = tuple(padding) * self.num_spatial_dims
+                else:
+                    if len(padding) != self.num_spatial_dims:
+                        raise ValueError(
+                            f"Padding {padding} must be a Tuple[int, int], "
+                            f"or sequence of Tuple[int, int] with length 1, "
+                            f"or sequence of Tuple[int, int] with length {self.num_spatial_dims}."
+                        )
+                    padding = tuple(padding)
+        else:
+            raise ValueError
+        self.padding = padding
+
+        # the number of in-/out-channels
+        assert self.out_channels % self.groups == 0, '"out_channels" should be divisible by groups'
+        assert self.in_channels % self.groups == 0, '"in_channels" should be divisible by groups'
+
+        # kernel shape for transpose conv
+        # When transpose=True in dimension_numbers, the kernel is (spatial..., out_channels, in_channels // groups)
+        # This matches JAX's expectation for transposed convolution
+        kernel_shape = tuple(self.kernel_size) + (self.out_channels, self.in_channels // self.groups)
+        self.kernel_shape = kernel_shape
+        self.w_mask = init.param(w_mask, kernel_shape, allow_none=True)
+
+        self.w_initializer = w_init
+        self.b_initializer = b_init
+
+        # --- weights --- #
+        weight = init.param(self.w_initializer, self.kernel_shape, allow_none=False)
+        params = dict(weight=weight)
+        if self.b_initializer is not None:
+            bias_shape = (1,) * len(self.kernel_size) + (self.out_channels,)
+            bias = init.param(self.b_initializer, bias_shape, allow_none=True)
+            params['bias'] = bias
+
+        # The weight operation
+        self.weight = param_type(params)
+
+        # Evaluate the output shape
+        test_input_shape = (128,) + self.in_size
+        abstract_y = jax.eval_shape(
+            self._conv_op,
+            jax.ShapeDtypeStruct(test_input_shape, weight.dtype),
+            params
+        )
+        y_shape = abstract_y.shape[1:]
+        self.out_size = y_shape
+
+    def _conv_op(self, x, params):
+        w = params['weight']
+        if self.w_mask is not None:
+            w = w * self.w_mask
+        # For transposed convolution:
+        # - window_strides should be (1, 1, ...): no striding in the conv operation
+        # - lhs_dilation should be the stride: this creates the upsampling effect
+        window_strides = (1,) * self.num_spatial_dims
+        y = jax.lax.conv_general_dilated(
+            lhs=x,
+            rhs=w,
+            window_strides=window_strides,
+            padding=self.padding,
+            lhs_dilation=self.stride,  # For transpose conv, use stride as lhs_dilation
+            rhs_dilation=self.rhs_dilation,
+            feature_group_count=self.groups,
+            dimension_numbers=self.dimension_numbers
+        )
+        if 'bias' in params:
+            y = y + params['bias']
+        return y
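
The lhs_dilation trick in _conv_op can be exercised in isolation with a raw jax.lax.conv_general_dilated call. A 1D sketch; the (2, 2) padding here is not copied from the constructor above but chosen from out = dilated_len + total_pad - kernel + 1 so that the output length is exactly input * stride:

    import jax
    import jax.numpy as jnp

    x = jnp.ones((1, 1, 8))   # (batch, channel, length), JAX's default layout
    w = jnp.ones((1, 1, 4))   # (out_channels, in_channels, kernel)
    y = jax.lax.conv_general_dilated(
        lhs=x,
        rhs=w,
        window_strides=(1,),  # no striding in the convolution itself
        padding=[(2, 2)],     # total padding = stride + kernel - 2 = 4
        lhs_dilation=(2,),    # the "stride" of the transposed convolution
    )
    print(y.shape)  # (1, 1, 16): length 8 upsampled by a factor of 2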
+
+
+class ConvTranspose1d(_ConvTranspose):
+    """
+    One-dimensional transposed convolution layer (also known as deconvolution).
+
+    Applies a 1D transposed convolution over an input signal. Transposed convolution
+    is used for upsampling, reversing the spatial transformation of a regular convolution.
+    It is commonly used in autoencoders, GANs, and semantic segmentation networks.
+
+    The input should be a 3D array with the shape of ``[B, L, C]`` where B is batch size,
+    L is the sequence length, and C is the number of input channels (channels-last format).
+
+    Parameters
+    ----------
+    in_size : tuple of int
+        The input shape without the batch dimension. For ConvTranspose1d: (L, C) where L
+        is the sequence length and C is the number of input channels.
+    out_channels : int
+        The number of output channels (feature maps) produced by the transposed convolution.
+    kernel_size : int or tuple of int
+        The shape of the convolutional kernel. For 1D, can be an integer or a single-element tuple.
+    stride : int or tuple of int, optional
+        The stride of the transposed convolution. Larger strides produce larger output sizes,
+        which is the opposite behavior of regular convolution. Default: 1.
+    padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
+        The padding strategy. Options:
+
+        - 'SAME': output length approximately equals input_length * stride
+        - 'VALID': no padding, maximum output size
+        - int: symmetric padding
+        - (pad_before, pad_after): explicit padding for the sequence dimension
+
+        Default: 'SAME'.
+    lhs_dilation : int or tuple of int, optional
+        The dilation factor for the input. For transposed convolution, this is typically
+        set equal to stride internally. Default: 1.
+    rhs_dilation : int or tuple of int, optional
+        The dilation factor for the kernel. Increases the receptive field without increasing
+        parameters by inserting zeros between kernel elements. Default: 1.
+    groups : int, optional
+        Number of groups for grouped transposed convolution. Both `in_channels` and
+        `out_channels` must be divisible by `groups`. Default: 1.
+    w_init : Callable or ArrayLike, optional
+        The initializer for the convolutional kernel weights. Default: XavierNormal().
+    b_init : Callable or ArrayLike or None, optional
+        The initializer for the bias. If None, no bias is added. Default: None.
+    w_mask : ArrayLike or Callable or None, optional
+        An optional mask applied to the weights during the forward pass. Default: None.
+    channel_first : bool, optional
+        If True, uses channels-first format (e.g., [B, C, L]). If False, uses channels-last
+        format (e.g., [B, L, C]). Default: False (channels-last, JAX convention).
+    name : str, optional
+        The name of the module. Default: None.
+    param_type : type, optional
+        The type of parameter state to use. Default: ParamState.
+
+    Attributes
+    ----------
+    in_size : tuple of int
+        The input shape (L, C) without batch dimension.
+    out_size : tuple of int
+        The output shape (L_out, out_channels) without batch dimension.
+    in_channels : int
+        Number of input channels.
+    out_channels : int
+        Number of output channels.
+    kernel_size : tuple of int
+        Size of the convolving kernel.
+    weight : ParamState
+        The learnable weights (and bias if specified) of the module.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate
+        >>> import jax.numpy as jnp
+        >>>
+        >>> # Create a 1D transposed convolution layer for upsampling
+        >>> conv_transpose = brainstate.nn.ConvTranspose1d(
+        ...     in_size=(28, 16),
+        ...     out_channels=8,
+        ...     kernel_size=4,
+        ...     stride=2
+        ... )
+        >>>
+        >>> # Apply to input: batch_size=2, length=28, channels=16
+        >>> x = jnp.ones((2, 28, 16))
+        >>> y = conv_transpose(x)
+        >>> print(y.shape)  # Output will be upsampled
+        >>>
+        >>> # Without batch dimension
+        >>> x_single = jnp.ones((28, 16))
+        >>> y_single = conv_transpose(x_single)
+        >>>
+        >>> # Channels-first format (PyTorch style)
+        >>> conv_transpose = brainstate.nn.ConvTranspose1d(
+        ...     in_size=(16, 28),
+        ...     out_channels=8,
+        ...     kernel_size=4,
+        ...     stride=2,
+        ...     channel_first=True
+        ... )
+        >>> x = jnp.ones((2, 16, 28))
+        >>> y = conv_transpose(x)
+
+    Notes
+    -----
+    **Output dimensions:**
+
+    Unlike regular convolution, transposed convolution increases spatial dimensions.
+    With stride > 1, the output is larger than the input:
+
+    - output_length ≈ input_length * stride (depends on padding and kernel size;
+      a worked check follows this class)
+
+    **Relationship to regular convolution:**
+
+    Transposed convolution performs the gradient computation of a regular convolution
+    with respect to its input. It is sometimes called "deconvolution", but this term
+    is mathematically imprecise.
+
+    **Common use cases:**
+
+    - Upsampling in encoder-decoder architectures
+    - Generative models (GANs, VAEs)
+    - Semantic segmentation (U-Net, FCN)
+    - Super-resolution networks
+
+    **Comparison with PyTorch:**
+
+    - PyTorch uses channels-first by default; BrainState uses channels-last
+    - Set `channel_first=True` for a PyTorch-compatible format
+    - PyTorch's `output_padding` is handled through the padding parameter
+    """
+    __module__ = 'brainstate.nn'
+    num_spatial_dims: int = 1
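
A worked check of the upsampling example above, using the target-size formulas quoted in these docstrings (the realized length can differ by a few elements, since the explicit padding computed in _ConvTranspose.__init__ is derived from the forward-conv relation):

    in_len, stride, kernel = 28, 2, 4
    same_target = in_len * stride                              # 56
    valid_target = in_len * stride + max(kernel - stride, 0)   # 58
    print(same_target, valid_target)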
+
+
+class ConvTranspose2d(_ConvTranspose):
+    """
+    Two-dimensional transposed convolution layer (also known as deconvolution).
+
+    Applies a 2D transposed convolution over an input signal. Transposed convolution
+    is the gradient of a regular convolution with respect to its input, commonly used
+    for upsampling feature maps in encoder-decoder architectures, GANs, and segmentation.
+
+    The input should be a 4D array with the shape of ``[B, H, W, C]`` where B is batch size,
+    H is height, W is width, and C is the number of input channels (channels-last format).
+
+    Parameters
+    ----------
+    in_size : tuple of int
+        The input shape without the batch dimension. For ConvTranspose2d: (H, W, C) where
+        H is height, W is width, and C is the number of input channels.
+    out_channels : int
+        The number of output channels (feature maps) produced by the transposed convolution.
+    kernel_size : int or tuple of int
+        The shape of the convolutional kernel. Can be:
+
+        - An integer (e.g., 4): creates a square kernel (4, 4)
+        - A tuple of two integers (e.g., (4, 4)): creates a (height, width) kernel
+    stride : int or tuple of int, optional
+        The stride of the transposed convolution. Controls the upsampling factor.
+        Can be:
+
+        - An integer: same stride for both dimensions
+        - A tuple of two integers: (stride_height, stride_width)
+
+        Larger strides produce larger outputs. Default: 1.
+    padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
+        The padding strategy. Options:
+
+        - 'SAME': output size approximately equals input_size * stride
+        - 'VALID': no padding, maximum output size
+        - int: same symmetric padding for all dimensions
+        - (pad_h, pad_w): different padding for each dimension
+        - [(pad_h_before, pad_h_after), (pad_w_before, pad_w_after)]: explicit padding
+
+        Default: 'SAME'.
+    lhs_dilation : int or tuple of int, optional
+        The dilation factor for the input. For transposed convolution, this is typically
+        set equal to stride internally. Default: 1.
+    rhs_dilation : int or tuple of int, optional
+        The dilation factor for the kernel. Increases the receptive field without increasing
+        parameters by inserting zeros between kernel elements. Default: 1.
+    groups : int, optional
+        Number of groups for grouped transposed convolution. Must divide both `in_channels`
+        and `out_channels`. Default: 1.
+    w_init : Callable or ArrayLike, optional
+        Weight initializer for the convolutional kernel. Default: XavierNormal().
+    b_init : Callable or ArrayLike or None, optional
+        Bias initializer. If None, no bias term is added. Default: None.
+    w_mask : ArrayLike or Callable or None, optional
+        Optional weight mask for structured sparsity. Default: None.
+    channel_first : bool, optional
+        If True, uses channels-first format (e.g., [B, C, H, W]). If False, uses channels-last
+        format (e.g., [B, H, W, C]). Default: False (channels-last, JAX convention).
+    name : str, optional
+        Name identifier for this module instance. Default: None.
+    param_type : type, optional
+        The parameter state class to use. Default: ParamState.
+
+    Attributes
+    ----------
+    in_size : tuple of int
+        The input shape (H, W, C) without batch dimension.
+    out_size : tuple of int
+        The output shape (H_out, W_out, out_channels) without batch dimension.
+    in_channels : int
+        Number of input channels.
+    out_channels : int
+        Number of output channels.
+    kernel_size : tuple of int
+        Size of the convolving kernel (height, width).
+    weight : ParamState
+        The learnable weights (and bias if specified) of the module.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate
+        >>> import jax.numpy as jnp
+        >>>
+        >>> # Create a 2D transposed convolution for upsampling
+        >>> conv_transpose = brainstate.nn.ConvTranspose2d(
+        ...     in_size=(32, 32, 64),
+        ...     out_channels=32,
+        ...     kernel_size=4,
+        ...     stride=2
+        ... )
+        >>>
+        >>> # Apply to input: batch_size=8, height=32, width=32, channels=64
+        >>> x = jnp.ones((8, 32, 32, 64))
+        >>> y = conv_transpose(x)
+        >>> print(y.shape)  # Output will be approximately (8, 64, 64, 32)
+        >>>
+        >>> # Without batch dimension
+        >>> x_single = jnp.ones((32, 32, 64))
+        >>> y_single = conv_transpose(x_single)
+        >>>
+        >>> # Decoder in autoencoder (upsampling path)
+        >>> decoder = brainstate.nn.ConvTranspose2d(
+        ...     in_size=(16, 16, 128),
+        ...     out_channels=64,
+        ...     kernel_size=4,
+        ...     stride=2,
+        ...     padding='SAME',
+        ...     b_init=brainstate.init.Constant(0.0)
+        ... )
+        >>>
+        >>> # Channels-first format (PyTorch style)
+        >>> conv_transpose = brainstate.nn.ConvTranspose2d(
+        ...     in_size=(64, 32, 32),
+        ...     out_channels=32,
+        ...     kernel_size=4,
+        ...     stride=2,
+        ...     channel_first=True
+        ... )
+        >>> x = jnp.ones((8, 64, 32, 32))
+        >>> y = conv_transpose(x)
+
+    Notes
+    -----
+    **Output dimensions:**
+
+    Transposed convolution increases spatial dimensions, with the upsampling factor
+    primarily controlled by stride:
+
+    - output_size ≈ input_size * stride (exact size depends on padding and kernel size)
+    - 'SAME' padding: output_size = input_size * stride
+    - 'VALID' padding: output_size = input_size * stride + max(kernel_size - stride, 0)
+
+    **Relationship to regular convolution:**
+
+    Transposed convolution is the backward pass of a regular convolution. If a regular
+    convolution reduces spatial dimensions from X to Y, a transposed convolution with
+    the same parameters increases dimensions from Y back to approximately X.
+
+    **Common use cases:**
+
+    - Image segmentation (U-Net, SegNet, FCN)
+    - Image-to-image translation (pix2pix, CycleGAN)
+    - Generative models (DCGAN, VAE decoders)
+    - Super-resolution networks
+    - Autoencoders (decoder path)
+
+    **Comparison with PyTorch:**
+
+    - PyTorch uses channels-first by default; BrainState uses channels-last
+    - Set `channel_first=True` for a PyTorch-compatible format
+    - Kernel shape convention: PyTorch stores (C_in, C_out, H, W), BrainState uses (H, W, C_out, C_in)
+    - PyTorch's `output_padding` parameter controls output size; use the padding parameter here
+
+    **Tips:**
+
+    - Use kernel_size=stride*2 for smooth upsampling (e.g., kernel_size=4, stride=2)
+    - Initialize with bilinear upsampling weights for better convergence in segmentation
+      (a sketch follows this class)
+    - Combine with batch normalization or group normalization for stable training
+    """
+    __module__ = 'brainstate.nn'
+    num_spatial_dims: int = 2
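
The bilinear-initialization tip above is a standard trick from the FCN segmentation literature. A sketch of building such a kernel and passing it as a direct-array w_init; the (H, W, C_out, C_in) layout follows the transposed-conv kernel shape defined earlier in this file, and the commented ConvTranspose2d call is hypothetical usage, not taken from the package tests:

    import numpy as np
    import jax.numpy as jnp

    def bilinear_kernel(k, out_ch, in_ch):
        # Classic bilinear upsampling weights, mapped channel-to-channel.
        factor = (k + 1) // 2
        center = factor - 1 if k % 2 == 1 else factor - 0.5
        og = np.ogrid[:k, :k]
        filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
        w = np.zeros((k, k, out_ch, in_ch), dtype=np.float32)
        for c in range(min(out_ch, in_ch)):
            w[:, :, c, c] = filt
        return jnp.asarray(w)

    # conv_t = brainstate.nn.ConvTranspose2d(
    #     in_size=(16, 16, 64), out_channels=64, kernel_size=4, stride=2,
    #     w_init=bilinear_kernel(4, 64, 64))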
+
+
+class ConvTranspose3d(_ConvTranspose):
+    """
+    Three-dimensional transposed convolution layer (also known as deconvolution).
+
+    Applies a 3D transposed convolution over an input signal. Used for upsampling
+    3D feature maps in video generation, 3D segmentation, and volumetric reconstruction.
+
+    The input should be a 5D array with the shape of ``[B, H, W, D, C]`` where B is batch size,
+    H is height, W is width, D is depth, and C is the number of input channels (channels-last format).
+
+    Parameters
+    ----------
+    in_size : tuple of int
+        The input shape without the batch dimension. For ConvTranspose3d: (H, W, D, C) where
+        H is height, W is width, D is depth, and C is the number of input channels.
+    out_channels : int
+        The number of output channels (feature maps) produced by the transposed convolution.
+    kernel_size : int or tuple of int
+        The shape of the convolutional kernel. Can be:
+
+        - An integer (e.g., 4): creates a cubic kernel (4, 4, 4)
+        - A tuple of three integers (e.g., (4, 4, 4)): creates a (height, width, depth) kernel
+    stride : int or tuple of int, optional
+        The stride of the transposed convolution. Controls the upsampling factor.
+        Can be:
+
+        - An integer: same stride for all dimensions
+        - A tuple of three integers: (stride_h, stride_w, stride_d)
+
+        Larger strides produce larger outputs. Default: 1.
+    padding : {'SAME', 'VALID'} or int or tuple of int or sequence of tuple, optional
+        The padding strategy. Options:
+
+        - 'SAME': output size approximately equals input_size * stride
+        - 'VALID': no padding, maximum output size
+        - int: same symmetric padding for all dimensions
+        - (pad_h, pad_w, pad_d): different padding for each dimension
+        - [(pad_h_before, pad_h_after), (pad_w_before, pad_w_after), (pad_d_before, pad_d_after)]: explicit
+
+        Default: 'SAME'.
+    lhs_dilation : int or tuple of int, optional
+        The dilation factor for the input. For transposed convolution, this is typically
+        set equal to stride internally. Default: 1.
+    rhs_dilation : int or tuple of int, optional
+        The dilation factor for the kernel. Increases the receptive field without increasing
+        parameters. Default: 1.
+    groups : int, optional
+        Number of groups for grouped transposed convolution. Must divide both `in_channels`
+        and `out_channels`. Useful for reducing computational cost in 3D. Default: 1.
+    w_init : Callable or ArrayLike, optional
+        Weight initializer for the convolutional kernel. Default: XavierNormal().
+    b_init : Callable or ArrayLike or None, optional
+        Bias initializer. If None, no bias term is added. Default: None.
+    w_mask : ArrayLike or Callable or None, optional
+        Optional weight mask for structured sparsity. Default: None.
+    channel_first : bool, optional
+        If True, uses channels-first format (e.g., [B, C, H, W, D]). If False, uses channels-last
+        format (e.g., [B, H, W, D, C]). Default: False (channels-last, JAX convention).
+    name : str, optional
+        Name identifier for this module instance. Default: None.
+    param_type : type, optional
+        The parameter state class to use. Default: ParamState.
+
+    Attributes
+    ----------
+    in_size : tuple of int
+        The input shape (H, W, D, C) without batch dimension.
+    out_size : tuple of int
+        The output shape (H_out, W_out, D_out, out_channels) without batch dimension.
+    in_channels : int
+        Number of input channels.
+    out_channels : int
+        Number of output channels.
+    kernel_size : tuple of int
+        Size of the convolving kernel (height, width, depth).
+    weight : ParamState
+        The learnable weights (and bias if specified) of the module.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate
+        >>> import jax.numpy as jnp
+        >>>
+        >>> # Create a 3D transposed convolution for video upsampling
+        >>> conv_transpose = brainstate.nn.ConvTranspose3d(
+        ...     in_size=(8, 16, 16, 64),
+        ...     out_channels=32,
+        ...     kernel_size=4,
+        ...     stride=2
+        ... )
+        >>>
+        >>> # Apply to input: batch_size=4, frames=8, height=16, width=16, channels=64
+        >>> x = jnp.ones((4, 8, 16, 16, 64))
+        >>> y = conv_transpose(x)
+        >>> print(y.shape)  # Output will be approximately (4, 16, 32, 32, 32)
+        >>>
+        >>> # Without batch dimension
+        >>> x_single = jnp.ones((8, 16, 16, 64))
+        >>> y_single = conv_transpose(x_single)
+        >>>
+        >>> # For medical imaging reconstruction
+        >>> decoder = brainstate.nn.ConvTranspose3d(
+        ...     in_size=(16, 16, 16, 128),
+        ...     out_channels=64,
+        ...     kernel_size=(4, 4, 4),
+        ...     stride=2,
+        ...     padding='SAME',
+        ...     b_init=brainstate.init.Constant(0.0)
+        ... )
+        >>>
+        >>> # Channels-first format (PyTorch style)
+        >>> conv_transpose = brainstate.nn.ConvTranspose3d(
+        ...     in_size=(64, 8, 16, 16),
+        ...     out_channels=32,
+        ...     kernel_size=4,
+        ...     stride=2,
+        ...     channel_first=True
+        ... )
+        >>> x = jnp.ones((4, 64, 8, 16, 16))
+        >>> y = conv_transpose(x)
+
+    Notes
+    -----
+    **Output dimensions:**
+
+    Transposed convolution increases spatial dimensions:
+
+    - output_size ≈ input_size * stride (exact size depends on padding and kernel size)
+    - 'SAME' padding: output_size = input_size * stride
+    - 'VALID' padding: output_size = input_size * stride + max(kernel_size - stride, 0)
+
+    **Computational considerations:**
+
+    3D transposed convolutions are very computationally expensive. Consider:
+
+    - Using grouped convolutions (groups > 1) to reduce parameters
+    - Smaller kernel sizes
+    - Progressive upsampling (multiple layers with stride=2)
+    - Separable convolutions for large-scale applications
+
+    **Common use cases:**
+
+    - Video generation and prediction
+    - 3D medical image segmentation (U-Net 3D)
+    - Volumetric reconstruction
+    - 3D super-resolution
+    - Video frame interpolation
+    - 3D VAE decoders
+
+    **Comparison with PyTorch:**
+
+    - PyTorch uses channels-first by default; BrainState uses channels-last
+    - Set `channel_first=True` for a PyTorch-compatible format
+    - Kernel shape conventions differ between the frameworks
+    - PyTorch's `output_padding` parameter is handled through padding here
+
+    **Tips:**
+
+    - Use kernel_size=stride*2 for smooth upsampling (e.g., kernel_size=4, stride=2)
+    - Group normalization often works better than batch normalization for 3D
+    - Consider using smaller batch sizes due to memory constraints
+    - Progressive upsampling (2x at a time) is more stable than large strides (sketched below)
+    """
+    __module__ = 'brainstate.nn'
+    num_spatial_dims: int = 3
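
The progressive-upsampling tip can be sketched by chaining two stride-2 layers instead of one stride-4 layer, relying only on the in_size/out_size attributes and call behavior documented above. Shapes are printed rather than asserted, since the exact sizes depend on the padding arithmetic in _ConvTranspose:

    import jax.numpy as jnp
    import brainstate

    up1 = brainstate.nn.ConvTranspose3d(
        in_size=(8, 16, 16, 64), out_channels=32, kernel_size=4, stride=2)
    up2 = brainstate.nn.ConvTranspose3d(
        in_size=up1.out_size, out_channels=16, kernel_size=4, stride=2)

    x = jnp.ones((2, 8, 16, 16, 64))
    y = up2(up1(x))
    print(up1.out_size, up2.out_size, y.shape)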