brainstate 0.1.8__py2.py3-none-any.whl → 0.1.9__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +58 -51
- brainstate/_compatible_import.py +148 -148
- brainstate/_state.py +1605 -1663
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/augment/__init__.py +30 -30
- brainstate/augment/_autograd.py +778 -778
- brainstate/augment/_autograd_test.py +1289 -1289
- brainstate/augment/_eval_shape.py +99 -99
- brainstate/augment/_eval_shape_test.py +38 -38
- brainstate/augment/_mapping.py +1060 -1060
- brainstate/augment/_mapping_test.py +597 -597
- brainstate/augment/_random.py +151 -151
- brainstate/compile/__init__.py +38 -38
- brainstate/compile/_ad_checkpoint.py +204 -204
- brainstate/compile/_ad_checkpoint_test.py +49 -49
- brainstate/compile/_conditions.py +256 -256
- brainstate/compile/_conditions_test.py +220 -220
- brainstate/compile/_error_if.py +92 -92
- brainstate/compile/_error_if_test.py +52 -52
- brainstate/compile/_jit.py +346 -346
- brainstate/compile/_jit_test.py +143 -143
- brainstate/compile/_loop_collect_return.py +536 -536
- brainstate/compile/_loop_collect_return_test.py +58 -58
- brainstate/compile/_loop_no_collection.py +184 -184
- brainstate/compile/_loop_no_collection_test.py +50 -50
- brainstate/compile/_make_jaxpr.py +888 -888
- brainstate/compile/_make_jaxpr_test.py +156 -156
- brainstate/compile/_progress_bar.py +202 -202
- brainstate/compile/_unvmap.py +159 -159
- brainstate/compile/_util.py +147 -147
- brainstate/environ.py +563 -563
- brainstate/environ_test.py +62 -62
- brainstate/functional/__init__.py +27 -26
- brainstate/graph/__init__.py +29 -29
- brainstate/graph/_graph_node.py +244 -244
- brainstate/graph/_graph_node_test.py +73 -73
- brainstate/graph/_graph_operation.py +1738 -1738
- brainstate/graph/_graph_operation_test.py +563 -563
- brainstate/init/__init__.py +26 -26
- brainstate/init/_base.py +52 -52
- brainstate/init/_generic.py +244 -244
- brainstate/init/_random_inits.py +553 -553
- brainstate/init/_random_inits_test.py +149 -149
- brainstate/init/_regular_inits.py +105 -105
- brainstate/init/_regular_inits_test.py +50 -50
- brainstate/mixin.py +365 -363
- brainstate/mixin_test.py +77 -73
- brainstate/nn/__init__.py +135 -131
- brainstate/{functional → nn}/_activations.py +808 -813
- brainstate/{functional → nn}/_activations_test.py +331 -331
- brainstate/nn/_collective_ops.py +514 -514
- brainstate/nn/_collective_ops_test.py +43 -43
- brainstate/nn/_common.py +178 -178
- brainstate/nn/_conv.py +501 -501
- brainstate/nn/_conv_test.py +238 -238
- brainstate/nn/_delay.py +509 -502
- brainstate/nn/_delay_test.py +238 -184
- brainstate/nn/_dropout.py +426 -426
- brainstate/nn/_dropout_test.py +100 -100
- brainstate/nn/_dynamics.py +1343 -1343
- brainstate/nn/_dynamics_test.py +78 -78
- brainstate/nn/_elementwise.py +1119 -1119
- brainstate/nn/_elementwise_test.py +169 -169
- brainstate/nn/_embedding.py +58 -58
- brainstate/nn/_exp_euler.py +92 -92
- brainstate/nn/_exp_euler_test.py +35 -35
- brainstate/nn/_fixedprob.py +239 -239
- brainstate/nn/_fixedprob_test.py +114 -114
- brainstate/nn/_inputs.py +608 -608
- brainstate/nn/_linear.py +424 -424
- brainstate/nn/_linear_mv.py +83 -83
- brainstate/nn/_linear_mv_test.py +120 -120
- brainstate/nn/_linear_test.py +107 -107
- brainstate/nn/_ltp.py +28 -28
- brainstate/nn/_module.py +377 -377
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_neuron.py +705 -705
- brainstate/nn/_neuron_test.py +161 -161
- brainstate/nn/_normalizations.py +975 -918
- brainstate/nn/_normalizations_test.py +73 -73
- brainstate/{functional → nn}/_others.py +46 -46
- brainstate/nn/_poolings.py +1177 -1177
- brainstate/nn/_poolings_test.py +217 -217
- brainstate/nn/_projection.py +486 -486
- brainstate/nn/_rate_rnns.py +554 -554
- brainstate/nn/_rate_rnns_test.py +63 -63
- brainstate/nn/_readout.py +209 -209
- brainstate/nn/_readout_test.py +53 -53
- brainstate/nn/_stp.py +236 -236
- brainstate/nn/_synapse.py +505 -505
- brainstate/nn/_synapse_test.py +131 -131
- brainstate/nn/_synaptic_projection.py +423 -423
- brainstate/nn/_synouts.py +162 -162
- brainstate/nn/_synouts_test.py +57 -57
- brainstate/nn/_utils.py +89 -89
- brainstate/nn/metrics.py +388 -388
- brainstate/optim/__init__.py +38 -38
- brainstate/optim/_base.py +64 -64
- brainstate/optim/_lr_scheduler.py +448 -448
- brainstate/optim/_lr_scheduler_test.py +50 -50
- brainstate/optim/_optax_optimizer.py +152 -152
- brainstate/optim/_optax_optimizer_test.py +53 -53
- brainstate/optim/_sgd_optimizer.py +1104 -1104
- brainstate/random/__init__.py +24 -24
- brainstate/random/_rand_funs.py +3616 -3616
- brainstate/random/_rand_funs_test.py +567 -567
- brainstate/random/_rand_seed.py +210 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1409 -1409
- brainstate/random/_random_for_unit.py +52 -52
- brainstate/surrogate.py +1957 -1957
- brainstate/transform.py +23 -23
- brainstate/typing.py +304 -304
- brainstate/util/__init__.py +50 -50
- brainstate/util/caller.py +98 -98
- brainstate/util/error.py +55 -55
- brainstate/util/filter.py +469 -469
- brainstate/util/others.py +540 -540
- brainstate/util/pretty_pytree.py +945 -945
- brainstate/util/pretty_pytree_test.py +159 -159
- brainstate/util/pretty_repr.py +328 -328
- brainstate/util/pretty_table.py +2954 -2954
- brainstate/util/scaling.py +258 -258
- brainstate/util/struct.py +523 -523
- {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/METADATA +91 -99
- brainstate-0.1.9.dist-info/RECORD +130 -0
- {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/WHEEL +1 -1
- {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info/licenses}/LICENSE +202 -202
- brainstate/functional/_normalization.py +0 -81
- brainstate/functional/_spikes.py +0 -204
- brainstate-0.1.8.dist-info/RECORD +0 -132
- {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/top_level.txt +0 -0
brainstate/nn/_conv.py
CHANGED
@@ -1,501 +1,501 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
# -*- coding: utf-8 -*-
|
17
|
-
|
18
|
-
import collections.abc
|
19
|
-
from typing import Callable, Tuple, Union, Sequence, Optional, TypeVar
|
20
|
-
|
21
|
-
import jax
|
22
|
-
import jax.numpy as jnp
|
23
|
-
|
24
|
-
from brainstate import init, functional
|
25
|
-
from brainstate._state import ParamState
|
26
|
-
from brainstate.typing import ArrayLike
|
27
|
-
from ._module import Module
|
28
|
-
|
29
|
-
# Generic element type accepted and returned by ``replicate``.
T = TypeVar('T')

# Public API of this module.
__all__ = [
    # plain convolutions
    'Conv1d',
    'Conv2d',
    'Conv3d',
    # convolutions with scaled weight standardization
    'ScaledWSConv1d',
    'ScaledWSConv2d',
    'ScaledWSConv3d',
]
|
35
|
-
|
36
|
-
|
37
|
-
def to_dimension_numbers(
    num_spatial_dims: int,
    channels_last: bool,
    transpose: bool
) -> jax.lax.ConvDimensionNumbers:
    """Build the ``jax.lax.ConvDimensionNumbers`` describing the array layouts.

    Parameters
    ----------
    num_spatial_dims : int
        Number of spatial axes of the convolution.
    channels_last : bool
        If True, inputs/outputs are laid out as ``(batch, *spatial, feature)``;
        otherwise as ``(batch, feature, *spatial)``.
    transpose : bool
        Chooses the order of the two trailing kernel axes. The kernel is always
        ``(*spatial, a, b)``. With ``transpose=False``, ``a`` is the input-feature
        axis and ``b`` the output-feature axis. With ``transpose=True`` the
        roles are swapped.

    Returns
    -------
    jax.lax.ConvDimensionNumbers
        ``lhs_spec`` and ``out_spec`` share the same image layout.
    """
    total = num_spatial_dims + 2
    if channels_last:
        # batch axis first, feature axis last, spatial axes in between
        image_spec = (0, total - 1, *range(1, total - 1))
    else:
        # batch axis first, feature axis second, spatial axes after
        image_spec = (0, 1, *range(2, total))

    # rhs_spec is (out-feature axis, in-feature axis, *spatial axes)
    if transpose:
        out_axis, in_axis = total - 2, total - 1
    else:
        out_axis, in_axis = total - 1, total - 2
    kernel_spec = (out_axis, in_axis) + tuple(range(total - 2))

    return jax.lax.ConvDimensionNumbers(
        lhs_spec=image_spec,
        rhs_spec=kernel_spec,
        out_spec=image_spec,
    )
|
57
|
-
|
58
|
-
|
59
|
-
def replicate(
    element: Union[T, Sequence[T]],
    num_replicate: int,
    name: str,
) -> Tuple[T, ...]:
    """Expand ``element`` into a tuple of exactly ``num_replicate`` entries.

    A scalar (anything that is not a sequence, as well as ``str``/``bytes``)
    is repeated ``num_replicate`` times. So is the single item of a length-1
    sequence. A sequence that already has ``num_replicate`` items is returned
    as a tuple.

    Raises
    ------
    TypeError
        If ``element`` is a sequence of any other length. ``name`` is used in
        the error message.
    """
    # Strings and bytes are Sequences, but here they count as single values.
    treat_as_scalar = (
        isinstance(element, (str, bytes))
        or not isinstance(element, collections.abc.Sequence)
    )
    if treat_as_scalar:
        return (element,) * num_replicate

    size = len(element)
    if size == 1:
        return tuple(element) * num_replicate
    if size == num_replicate:
        return tuple(element)
    raise TypeError(f"{name} must be a scalar or sequence of length 1 or "
                    f"sequence of length {num_replicate}.")
|
74
|
-
|
75
|
-
|
76
|
-
class _BaseConv(Module):
    """Common base of the channels-last convolution layers in this module.

    It validates and normalizes the convolution hyper-parameters, checks the
    input shape and handles unbatched inputs. Subclasses must set the class
    attribute ``num_spatial_dims``, create ``self.weight`` (whose ``.value`` is
    passed to :meth:`_conv_op`) and implement :meth:`_conv_op`.

    Parameters
    ----------
    in_size : sequence of int
        Input shape without the batch axis, ``(*spatial, in_channels)``.
    out_channels : int
        Number of output channels.
    kernel_size, stride, lhs_dilation, rhs_dilation : int or sequence of int
        A single value used for every spatial dimension, a length-1 sequence,
        or one value per spatial dimension.
    padding : str, int, (int, int) or sequence of (int, int)
        ``'SAME'``/``'VALID'``, an int ``p`` meaning ``(p, p)`` on every spatial
        dimension, a single ``(low, high)`` pair for every spatial dimension,
        or one ``(low, high)`` pair per spatial dimension.
    groups : int
        Number of feature groups; must divide ``in_channels`` and ``out_channels``.
    w_mask : array-like, callable or None
        Optional kernel mask, materialized with ``init.param`` for the kernel shape.
    name : str, optional
        Name of the module.
    """

    # the number of spatial dimensions
    num_spatial_dims: int

    # the weight and its operations
    weight: ParamState

    def __init__(
        self,
        in_size: Sequence[int],
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
        lhs_dilation: Union[int, Tuple[int, ...]] = 1,
        rhs_dilation: Union[int, Tuple[int, ...]] = 1,
        groups: int = 1,
        w_mask: Optional[Union[ArrayLike, Callable]] = None,
        name: str = None,
    ):
        super().__init__(name=name)

        # general parameters
        assert self.num_spatial_dims + 1 == len(in_size), (
            f'"in_size" must contain {self.num_spatial_dims + 1} entries '
            f'({self.num_spatial_dims} spatial dimension(s) + channels), but got {in_size}.'
        )
        self.in_size = tuple(in_size)
        self.in_channels = in_size[-1]
        self.out_channels = out_channels
        self.stride = replicate(stride, self.num_spatial_dims, 'stride')
        self.kernel_size = replicate(kernel_size, self.num_spatial_dims, 'kernel_size')
        self.lhs_dilation = replicate(lhs_dilation, self.num_spatial_dims, 'lhs_dilation')
        self.rhs_dilation = replicate(rhs_dilation, self.num_spatial_dims, 'rhs_dilation')
        self.groups = groups
        # Channels-last inputs/outputs; kernel stored as
        # (*kernel_size, in_channels // groups, out_channels).
        self.dimension_numbers = to_dimension_numbers(self.num_spatial_dims, channels_last=True, transpose=False)

        # the padding parameter
        self.padding = self._normalize_padding(padding)

        # the number of in-/out-channels
        assert self.out_channels % self.groups == 0, '"out_channels" should be divisible by groups'
        assert self.in_channels % self.groups == 0, '"in_channels" should be divisible by groups'

        # kernel shape and w_mask
        kernel_shape = tuple(self.kernel_size) + (self.in_channels // self.groups, self.out_channels)
        self.kernel_shape = kernel_shape
        self.w_mask = init.param(w_mask, kernel_shape, allow_none=True)

    def _normalize_padding(self, padding):
        """Convert ``padding`` into a form accepted by ``jax.lax.conv_general_dilated``.

        Returns
        -------
        str or tuple of tuple
            ``'SAME'``/``'VALID'`` unchanged; otherwise a tuple holding one
            ``(low, high)`` pair per spatial dimension.

        Raises
        ------
        AssertionError
            If ``padding`` is a string other than ``'SAME'`` or ``'VALID'``.
        ValueError
            If ``padding`` is an empty or malformed sequence, or has an
            unsupported type.
        """
        n_dims = self.num_spatial_dims

        if isinstance(padding, str):
            # Kept as an assertion so callers see the same exception type as before.
            assert padding in ['SAME', 'VALID'], (
                f"String padding must be 'SAME' or 'VALID', but got {padding!r}."
            )
            return padding

        if isinstance(padding, int):
            # Symmetric padding of the same size on every spatial dimension.
            return tuple((padding, padding) for _ in range(n_dims))

        if not isinstance(padding, (tuple, list)):
            raise ValueError(
                f"Padding must be a str, an int, a (low, high) pair or a sequence of "
                f"(low, high) pairs, but got {type(padding).__name__}: {padding!r}."
            )
        if len(padding) == 0:
            raise ValueError("Padding must not be an empty sequence.")

        if isinstance(padding[0], int):
            # A single (low, high) pair shared by all spatial dimensions.
            if len(padding) != 2 or not all(isinstance(p, int) for p in padding):
                raise ValueError(
                    f"Padding {padding} given as a flat sequence must be a "
                    f"(low, high) pair of two ints."
                )
            return (tuple(padding),) * n_dims

        if isinstance(padding[0], (tuple, list)):
            if len(padding) == 1:
                pairs = tuple(padding) * n_dims
            elif len(padding) == n_dims:
                pairs = tuple(padding)
            else:
                raise ValueError(
                    f"Padding {padding} must be a Tuple[int, int], "
                    f"or sequence of Tuple[int, int] with length 1, "
                    f"or sequence of Tuple[int, int] with length {self.num_spatial_dims}."
                )
            for pair in pairs:
                if not isinstance(pair, (tuple, list)) or len(pair) != 2:
                    raise ValueError(
                        f"Every padding entry must be a (low, high) pair, but got {pair!r} "
                        f"in {padding}."
                    )
            return tuple(tuple(pair) for pair in pairs)

        raise ValueError(
            f"Padding {padding} must be a Tuple[int, int] or a sequence of "
            f"Tuple[int, int], but its first entry is {padding[0]!r}."
        )

    def _check_input_dim(self, x):
        """Validate that ``x`` has shape ``(batch, *in_size)`` or ``in_size``.

        Raises
        ------
        ValueError
            If the rank of ``x`` is wrong or its non-batch shape differs from
            ``self.in_size``.
        """
        if x.ndim == self.num_spatial_dims + 2:
            x_shape = x.shape[1:]
        elif x.ndim == self.num_spatial_dims + 1:
            x_shape = x.shape
        else:
            raise ValueError(f"expected {self.num_spatial_dims + 2}D (with batch) or "
                             f"{self.num_spatial_dims + 1}D (without batch) input (got {x.ndim}D input, {x.shape})")
        if self.in_size != x_shape:
            raise ValueError(f"The expected input shape is {self.in_size}, while we got {x_shape}.")

    def update(self, x):
        """Apply the convolution to ``x``.

        ``x`` may be batched, ``(batch, *in_size)``, or unbatched, ``in_size``.
        An unbatched input gets a temporary batch axis, which is removed from
        the result again.
        """
        self._check_input_dim(x)
        non_batching = False
        if x.ndim == self.num_spatial_dims + 1:
            x = jnp.expand_dims(x, 0)
            non_batching = True
        y = self._conv_op(x, self.weight.value)
        return y[0] if non_batching else y

    def _conv_op(self, x, params):
        """Compute the convolution of batched ``x`` with ``params``; subclasses must override."""
        raise NotImplementedError

    def __repr__(self):
        return (f'{self.__class__.__name__}('
                f'in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size}, '
                f'stride={self.stride}, '
                f'padding={self.padding}, '
                f'groups={self.groups})')
|
173
|
-
|
174
|
-
|
175
|
-
class _Conv(_BaseConv):
    """Convolution layer with an optional additive bias.

    ``self.weight`` wraps a dict holding ``'weight'`` (shape ``kernel_shape``).
    When ``b_init`` is given, the dict also holds ``'bias'``, of shape
    ``(1, ..., 1, out_channels)``. The per-sample output shape is inferred at
    construction time and stored in ``self.out_size``.
    """
    num_spatial_dims: int = None

    def __init__(
        self,
        in_size: Sequence[int],
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
        lhs_dilation: Union[int, Tuple[int, ...]] = 1,
        rhs_dilation: Union[int, Tuple[int, ...]] = 1,
        groups: int = 1,
        w_init: Union[Callable, ArrayLike] = init.XavierNormal(),
        b_init: Optional[Union[Callable, ArrayLike]] = None,
        w_mask: Optional[Union[ArrayLike, Callable]] = None,
        name: str = None,
        param_type: type = ParamState,
    ):
        super().__init__(
            in_size=in_size,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            lhs_dilation=lhs_dilation,
            rhs_dilation=rhs_dilation,
            groups=groups,
            w_mask=w_mask,
            name=name,
        )
        self.w_initializer = w_init
        self.b_initializer = b_init

        # --- parameters --- #
        kernel = init.param(w_init, self.kernel_shape, allow_none=False)
        params = {'weight': kernel}
        if b_init is not None:
            # One bias per output channel, broadcast over batch and spatial axes.
            leading_ones = (1,) * len(self.kernel_size)
            params['bias'] = init.param(b_init, leading_ones + (self.out_channels,), allow_none=True)
        self.weight = param_type(params)

        # Abstractly trace the convolution with a dummy batch of 128 to learn
        # the per-sample output shape without doing any real computation.
        probe = jax.ShapeDtypeStruct((128,) + self.in_size, kernel.dtype)
        self.out_size = jax.eval_shape(self._conv_op, probe, params).shape[1:]

    def _conv_op(self, x, params):
        """Convolve the batched input ``x`` with ``params['weight']`` (masked if needed) and add the bias."""
        kernel = params['weight']
        if self.w_mask is not None:
            kernel = kernel * self.w_mask
        out = jax.lax.conv_general_dilated(
            lhs=x,
            rhs=kernel,
            window_strides=self.stride,
            padding=self.padding,
            lhs_dilation=self.lhs_dilation,
            rhs_dilation=self.rhs_dilation,
            feature_group_count=self.groups,
            dimension_numbers=self.dimension_numbers,
        )
        return out + params['bias'] if 'bias' in params else out
|
247
|
-
|
248
|
-
|
249
|
-
class Conv1d(_Conv):
    """One-dimensional convolution layer.

    Expects channels-last input of shape ``[B, H, C]``, or ``[H, C]``
    without the batch axis.

    Parameters
    ----------
    %s
    """
    num_spatial_dims: int = 1
    __module__ = 'brainstate.nn'
|
260
|
-
|
261
|
-
|
262
|
-
class Conv2d(_Conv):
    """Two-dimensional convolution layer.

    Expects channels-last input of shape ``[B, H, W, C]``, or ``[H, W, C]``
    without the batch axis.

    Parameters
    ----------
    %s
    """
    num_spatial_dims: int = 2
    __module__ = 'brainstate.nn'
|
273
|
-
|
274
|
-
|
275
|
-
class Conv3d(_Conv):
    """Three-dimensional convolution layer.

    Expects channels-last input of shape ``[B, H, W, D, C]``, or
    ``[H, W, D, C]`` without the batch axis.

    Parameters
    ----------
    %s
    """
    num_spatial_dims: int = 3
    __module__ = 'brainstate.nn'
|
286
|
-
|
287
|
-
|
288
|
-
_conv_doc = '''
|
289
|
-
in_size: tuple of int
|
290
|
-
The input shape, without the batch size. This argument is important, since it is
|
291
|
-
used to evaluate the shape of the output.
|
292
|
-
out_channels: int
|
293
|
-
The number of output channels.
|
294
|
-
kernel_size: int, sequence of int
|
295
|
-
The shape of the convolutional kernel.
|
296
|
-
For 1D convolution, the kernel size can be passed as an integer.
|
297
|
-
For all other cases, it must be a sequence of integers.
|
298
|
-
stride: int, sequence of int
|
299
|
-
An integer or a sequence of `n` integers, representing the inter-window strides (default: 1).
|
300
|
-
padding: str, int, sequence of int, sequence of tuple
|
301
|
-
Either the string `'SAME'`, the string `'VALID'`, or a sequence of n `(low,
|
302
|
-
high)` integer pairs that give the padding to apply before and after each
|
303
|
-
spatial dimension.
|
304
|
-
lhs_dilation: int, sequence of int
|
305
|
-
An integer or a sequence of `n` integers, giving the
|
306
|
-
dilation factor to apply in each spatial dimension of `inputs`
|
307
|
-
(default: 1). Convolution with input dilation `d` is equivalent to
|
308
|
-
transposed convolution with stride `d`.
|
309
|
-
rhs_dilation: int, sequence of int
|
310
|
-
An integer or a sequence of `n` integers, giving the
|
311
|
-
dilation factor to apply in each spatial dimension of the convolution
|
312
|
-
kernel (default: 1). Convolution with kernel dilation
|
313
|
-
is also known as 'atrous convolution'.
|
314
|
-
groups: int
|
315
|
-
If specified, divides the input features into groups. default 1.
|
316
|
-
w_init: Callable, ArrayLike, Initializer
|
317
|
-
The initializer for the convolutional kernel.
|
318
|
-
b_init: Optional, Callable, ArrayLike, Initializer
|
319
|
-
The initializer for the bias.
|
320
|
-
w_mask: ArrayLike, Callable, Optional
|
321
|
-
The optional mask of the weights.
|
322
|
-
mode: Mode
|
323
|
-
The computation mode of the current object. Default it is `training`.
|
324
|
-
name: str, Optional
|
325
|
-
The name of the object.
|
326
|
-
'''
|
327
|
-
|
328
|
-
# Fill the ``%s`` placeholder of each class docstring with the shared
# parameter documentation. Docstrings are stripped (``__doc__ is None``)
# under ``python -OO``; skip the substitution then instead of raising a
# TypeError at import time.
for _conv_cls in (Conv1d, Conv2d, Conv3d):
    if _conv_cls.__doc__ is not None:
        _conv_cls.__doc__ = _conv_cls.__doc__ % _conv_doc
del _conv_cls
|
331
|
-
|
332
|
-
|
333
|
-
class _ScaledWSConv(_BaseConv):
    """Convolution layer whose kernel is weight-standardized before every use.

    ``self.weight`` wraps a dict holding:
    - ``'weight'`` (shape ``kernel_shape``), always present;
    - ``'bias'`` (shape ``(1, ..., 1, out_channels)``), when ``b_init`` is given;
    - ``'gain'`` (shape ``(1, ..., 1, 1, out_channels)``, initialized to ones),
      when ``ws_gain`` is truthy.

    The per-sample output shape is stored in ``self.out_size``.
    """

    def __init__(
        self,
        in_size: Sequence[int],
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
        lhs_dilation: Union[int, Tuple[int, ...]] = 1,
        rhs_dilation: Union[int, Tuple[int, ...]] = 1,
        groups: int = 1,
        ws_gain: bool = True,
        eps: float = 1e-4,
        w_init: Union[Callable, ArrayLike] = init.XavierNormal(),
        b_init: Optional[Union[Callable, ArrayLike]] = None,
        w_mask: Optional[Union[ArrayLike, Callable]] = None,
        name: str = None,
        param_type: type = ParamState,
    ):
        super().__init__(
            in_size=in_size,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            lhs_dilation=lhs_dilation,
            rhs_dilation=rhs_dilation,
            groups=groups,
            w_mask=w_mask,
            name=name,
        )
        self.w_initializer = w_init
        self.b_initializer = b_init

        # --- parameters --- #
        kernel = init.param(w_init, self.kernel_shape, allow_none=False)
        params = {'weight': kernel}
        leading_ones = (1,) * len(self.kernel_size)
        if b_init is not None:
            params['bias'] = init.param(b_init, leading_ones + (self.out_channels,), allow_none=True)
        if ws_gain:
            # One scale per output channel, broadcastable against the kernel.
            params['gain'] = jnp.ones(leading_ones + (1, self.out_channels), dtype=kernel.dtype)

        # Small constant used by the weight standardization to avoid division by zero.
        self.eps = eps

        self.weight = param_type(params)

        # Abstractly trace the convolution with a dummy batch of 128 to learn
        # the per-sample output shape without doing any real computation.
        probe = jax.ShapeDtypeStruct((128,) + self.in_size, kernel.dtype)
        self.out_size = jax.eval_shape(self._conv_op, probe, params).shape[1:]

    def _conv_op(self, x, params):
        """Standardize (and optionally mask) the kernel, convolve batched ``x`` and add the bias."""
        kernel = functional.weight_standardization(params['weight'], self.eps, params.get('gain', None))
        if self.w_mask is not None:
            kernel = kernel * self.w_mask
        out = jax.lax.conv_general_dilated(
            lhs=x,
            rhs=kernel,
            window_strides=self.stride,
            padding=self.padding,
            lhs_dilation=self.lhs_dilation,
            rhs_dilation=self.rhs_dilation,
            feature_group_count=self.groups,
            dimension_numbers=self.dimension_numbers,
        )
        return out + params['bias'] if 'bias' in params else out
|
413
|
-
|
414
|
-
|
415
|
-
class ScaledWSConv1d(_ScaledWSConv):
    """One-dimensional convolution with scaled weight standardization.

    Expects channels-last input of shape ``[B, H, C]``, or ``[H, C]``
    without the batch axis.

    Parameters
    ----------
    %s
    """
    num_spatial_dims: int = 1
    __module__ = 'brainstate.nn'
|
426
|
-
|
427
|
-
|
428
|
-
class ScaledWSConv2d(_ScaledWSConv):
    """Two-dimensional convolution with scaled weight standardization.

    Expects channels-last input of shape ``[B, H, W, C]``, or ``[H, W, C]``
    without the batch axis.

    Parameters
    ----------
    %s
    """
    num_spatial_dims: int = 2
    __module__ = 'brainstate.nn'
|
439
|
-
|
440
|
-
|
441
|
-
class ScaledWSConv3d(_ScaledWSConv):
    """Three-dimensional convolution with scaled weight standardization.

    Expects channels-last input of shape ``[B, H, W, D, C]``, or
    ``[H, W, D, C]`` without the batch axis.

    Parameters
    ----------
    %s
    """
    num_spatial_dims: int = 3
    __module__ = 'brainstate.nn'
|
452
|
-
|
453
|
-
|
454
|
-
_ws_conv_doc = '''
|
455
|
-
in_size: tuple of int
|
456
|
-
The input shape, without the batch size. This argument is important, since it is
|
457
|
-
used to evaluate the shape of the output.
|
458
|
-
out_channels: int
|
459
|
-
The number of output channels.
|
460
|
-
kernel_size: int, sequence of int
|
461
|
-
The shape of the convolutional kernel.
|
462
|
-
For 1D convolution, the kernel size can be passed as an integer.
|
463
|
-
For all other cases, it must be a sequence of integers.
|
464
|
-
stride: int, sequence of int
|
465
|
-
An integer or a sequence of `n` integers, representing the inter-window strides (default: 1).
|
466
|
-
padding: str, int, sequence of int, sequence of tuple
|
467
|
-
Either the string `'SAME'`, the string `'VALID'`, or a sequence of n `(low,
|
468
|
-
high)` integer pairs that give the padding to apply before and after each
|
469
|
-
spatial dimension.
|
470
|
-
lhs_dilation: int, sequence of int
|
471
|
-
An integer or a sequence of `n` integers, giving the
|
472
|
-
dilation factor to apply in each spatial dimension of `inputs`
|
473
|
-
(default: 1). Convolution with input dilation `d` is equivalent to
|
474
|
-
transposed convolution with stride `d`.
|
475
|
-
rhs_dilation: int, sequence of int
|
476
|
-
An integer or a sequence of `n` integers, giving the
|
477
|
-
dilation factor to apply in each spatial dimension of the convolution
|
478
|
-
kernel (default: 1). Convolution with kernel dilation
|
479
|
-
is also known as 'atrous convolution'.
|
480
|
-
groups: int
|
481
|
-
If specified, divides the input features into groups. default 1.
|
482
|
-
w_init: Callable, ArrayLike, Initializer
|
483
|
-
The initializer for the convolutional kernel.
|
484
|
-
b_init: Optional, Callable, ArrayLike, Initializer
|
485
|
-
The initializer for the bias.
|
486
|
-
ws_gain: bool
|
487
|
-
Whether to add a gain term for the weight standarization. The default is `True`.
|
488
|
-
eps: float
|
489
|
-
The epsilon value for numerical stability.
|
490
|
-
w_mask: ArrayLike, Callable, Optional
|
491
|
-
The optional mask of the weights.
|
492
|
-
mode: Mode
|
493
|
-
The computation mode of the current object. Default it is `training`.
|
494
|
-
name: str, Optional
|
495
|
-
The name of the object.
|
496
|
-
|
497
|
-
'''
|
498
|
-
|
499
|
-
# Fill the ``%s`` placeholder of each class docstring with the shared
# parameter documentation. Docstrings are stripped (``__doc__ is None``)
# under ``python -OO``; skip the substitution then instead of raising a
# TypeError at import time.
for _ws_conv_cls in (ScaledWSConv1d, ScaledWSConv2d, ScaledWSConv3d):
    if _ws_conv_cls.__doc__ is not None:
        _ws_conv_cls.__doc__ = _ws_conv_cls.__doc__ % _ws_conv_doc
del _ws_conv_cls
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
# -*- coding: utf-8 -*-
|
17
|
+
|
18
|
+
import collections.abc
|
19
|
+
from typing import Callable, Tuple, Union, Sequence, Optional, TypeVar
|
20
|
+
|
21
|
+
import jax
|
22
|
+
import jax.numpy as jnp
|
23
|
+
|
24
|
+
from brainstate import init, functional
|
25
|
+
from brainstate._state import ParamState
|
26
|
+
from brainstate.typing import ArrayLike
|
27
|
+
from ._module import Module
|
28
|
+
|
29
|
+
# Generic element type accepted and returned by ``replicate``.
T = TypeVar('T')

# Public API of this module.
__all__ = [
    # plain convolutions
    'Conv1d',
    'Conv2d',
    'Conv3d',
    # convolutions with scaled weight standardization
    'ScaledWSConv1d',
    'ScaledWSConv2d',
    'ScaledWSConv3d',
]
|
35
|
+
|
36
|
+
|
37
|
+
def to_dimension_numbers(
    num_spatial_dims: int,
    channels_last: bool,
    transpose: bool
) -> jax.lax.ConvDimensionNumbers:
    """Build the ``jax.lax.ConvDimensionNumbers`` describing the array layouts.

    Parameters
    ----------
    num_spatial_dims : int
        Number of spatial axes of the convolution.
    channels_last : bool
        If True, inputs/outputs are laid out as ``(batch, *spatial, feature)``;
        otherwise as ``(batch, feature, *spatial)``.
    transpose : bool
        Chooses the order of the two trailing kernel axes. The kernel is always
        ``(*spatial, a, b)``. With ``transpose=False``, ``a`` is the input-feature
        axis and ``b`` the output-feature axis. With ``transpose=True`` the
        roles are swapped.

    Returns
    -------
    jax.lax.ConvDimensionNumbers
        ``lhs_spec`` and ``out_spec`` share the same image layout.
    """
    total = num_spatial_dims + 2
    if channels_last:
        # batch axis first, feature axis last, spatial axes in between
        image_spec = (0, total - 1, *range(1, total - 1))
    else:
        # batch axis first, feature axis second, spatial axes after
        image_spec = (0, 1, *range(2, total))

    # rhs_spec is (out-feature axis, in-feature axis, *spatial axes)
    if transpose:
        out_axis, in_axis = total - 2, total - 1
    else:
        out_axis, in_axis = total - 1, total - 2
    kernel_spec = (out_axis, in_axis) + tuple(range(total - 2))

    return jax.lax.ConvDimensionNumbers(
        lhs_spec=image_spec,
        rhs_spec=kernel_spec,
        out_spec=image_spec,
    )
|
57
|
+
|
58
|
+
|
59
|
+
def replicate(
    element: Union[T, Sequence[T]],
    num_replicate: int,
    name: str,
) -> Tuple[T, ...]:
    """Expand ``element`` into a tuple of exactly ``num_replicate`` entries.

    A scalar (anything that is not a sequence, as well as ``str``/``bytes``)
    is repeated ``num_replicate`` times. So is the single item of a length-1
    sequence. A sequence that already has ``num_replicate`` items is returned
    as a tuple.

    Raises
    ------
    TypeError
        If ``element`` is a sequence of any other length. ``name`` is used in
        the error message.
    """
    # Strings and bytes are Sequences, but here they count as single values.
    treat_as_scalar = (
        isinstance(element, (str, bytes))
        or not isinstance(element, collections.abc.Sequence)
    )
    if treat_as_scalar:
        return (element,) * num_replicate

    size = len(element)
    if size == 1:
        return tuple(element) * num_replicate
    if size == num_replicate:
        return tuple(element)
    raise TypeError(f"{name} must be a scalar or sequence of length 1 or "
                    f"sequence of length {num_replicate}.")
|
74
|
+
|
75
|
+
|
76
|
+
class _BaseConv(Module):
|
77
|
+
# the number of spatial dimensions
|
78
|
+
num_spatial_dims: int
|
79
|
+
|
80
|
+
# the weight and its operations
|
81
|
+
weight: ParamState
|
82
|
+
|
83
|
+
def __init__(
|
84
|
+
self,
|
85
|
+
in_size: Sequence[int],
|
86
|
+
out_channels: int,
|
87
|
+
kernel_size: Union[int, Tuple[int, ...]],
|
88
|
+
stride: Union[int, Tuple[int, ...]] = 1,
|
89
|
+
padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
|
90
|
+
lhs_dilation: Union[int, Tuple[int, ...]] = 1,
|
91
|
+
rhs_dilation: Union[int, Tuple[int, ...]] = 1,
|
92
|
+
groups: int = 1,
|
93
|
+
w_mask: Optional[Union[ArrayLike, Callable]] = None,
|
94
|
+
name: str = None,
|
95
|
+
):
|
96
|
+
super().__init__(name=name)
|
97
|
+
|
98
|
+
# general parameters
|
99
|
+
assert self.num_spatial_dims + 1 == len(in_size)
|
100
|
+
self.in_size = tuple(in_size)
|
101
|
+
self.in_channels = in_size[-1]
|
102
|
+
self.out_channels = out_channels
|
103
|
+
self.stride = replicate(stride, self.num_spatial_dims, 'stride')
|
104
|
+
self.kernel_size = replicate(kernel_size, self.num_spatial_dims, 'kernel_size')
|
105
|
+
self.lhs_dilation = replicate(lhs_dilation, self.num_spatial_dims, 'lhs_dilation')
|
106
|
+
self.rhs_dilation = replicate(rhs_dilation, self.num_spatial_dims, 'rhs_dilation')
|
107
|
+
self.groups = groups
|
108
|
+
self.dimension_numbers = to_dimension_numbers(self.num_spatial_dims, channels_last=True, transpose=False)
|
109
|
+
|
110
|
+
# the padding parameter
|
111
|
+
if isinstance(padding, str):
|
112
|
+
assert padding in ['SAME', 'VALID']
|
113
|
+
elif isinstance(padding, int):
|
114
|
+
padding = tuple((padding, padding) for _ in range(self.num_spatial_dims))
|
115
|
+
elif isinstance(padding, (tuple, list)):
|
116
|
+
if isinstance(padding[0], int):
|
117
|
+
padding = (padding,) * self.num_spatial_dims
|
118
|
+
elif isinstance(padding[0], (tuple, list)):
|
119
|
+
if len(padding) == 1:
|
120
|
+
padding = tuple(padding) * self.num_spatial_dims
|
121
|
+
else:
|
122
|
+
if len(padding) != self.num_spatial_dims:
|
123
|
+
raise ValueError(
|
124
|
+
f"Padding {padding} must be a Tuple[int, int], "
|
125
|
+
f"or sequence of Tuple[int, int] with length 1, "
|
126
|
+
f"or sequence of Tuple[int, int] with length {self.num_spatial_dims}."
|
127
|
+
)
|
128
|
+
padding = tuple(padding)
|
129
|
+
else:
|
130
|
+
raise ValueError
|
131
|
+
self.padding = padding
|
132
|
+
|
133
|
+
# the number of in-/out-channels
|
134
|
+
assert self.out_channels % self.groups == 0, '"out_channels" should be divisible by groups'
|
135
|
+
assert self.in_channels % self.groups == 0, '"in_channels" should be divisible by groups'
|
136
|
+
|
137
|
+
# kernel shape and w_mask
|
138
|
+
kernel_shape = tuple(self.kernel_size) + (self.in_channels // self.groups, self.out_channels)
|
139
|
+
self.kernel_shape = kernel_shape
|
140
|
+
self.w_mask = init.param(w_mask, kernel_shape, allow_none=True)
|
141
|
+
|
142
|
+
def _check_input_dim(self, x):
|
143
|
+
if x.ndim == self.num_spatial_dims + 2:
|
144
|
+
x_shape = x.shape[1:]
|
145
|
+
elif x.ndim == self.num_spatial_dims + 1:
|
146
|
+
x_shape = x.shape
|
147
|
+
else:
|
148
|
+
raise ValueError(f"expected {self.num_spatial_dims + 2}D (with batch) or "
|
149
|
+
f"{self.num_spatial_dims + 1}D (without batch) input (got {x.ndim}D input, {x.shape})")
|
150
|
+
if self.in_size != x_shape:
|
151
|
+
raise ValueError(f"The expected input shape is {self.in_size}, while we got {x_shape}.")
|
152
|
+
|
153
|
+
def update(self, x):
    """Apply the convolution to ``x``, with or without a batch axis.

    The input is validated first. If it lacks the batch axis (rank
    ``num_spatial_dims + 1``), a singleton batch axis is added before calling
    ``_conv_op`` and removed from the result afterwards.
    """
    self._check_input_dim(x)
    unbatched = x.ndim == self.num_spatial_dims + 1
    if unbatched:
        x = jnp.expand_dims(x, 0)
    out = self._conv_op(x, self.weight.value)
    if unbatched:
        return out[0]
    return out
|
161
|
+
|
162
|
+
def _conv_op(self, x, params):
|
163
|
+
raise NotImplementedError
|
164
|
+
|
165
|
+
def __repr__(self):
|
166
|
+
return (f'{self.__class__.__name__}('
|
167
|
+
f'in_channels={self.in_channels}, '
|
168
|
+
f'out_channels={self.out_channels}, '
|
169
|
+
f'kernel_size={self.kernel_size}, '
|
170
|
+
f'stride={self.stride}, '
|
171
|
+
f'padding={self.padding}, '
|
172
|
+
f'groups={self.groups})')
|
173
|
+
|
174
|
+
|
175
|
+
class _Conv(_BaseConv):
    """Plain convolution layer built on ``jax.lax.conv_general_dilated``.

    The kernel, and the bias when ``b_init`` is given, are stored together in
    one parameter dict wrapped by ``param_type``. The per-sample output shape
    is inferred once at construction time with ``jax.eval_shape`` and stored
    in ``out_size``. Concrete subclasses only set ``num_spatial_dims``.
    """
    num_spatial_dims: int = None

    def __init__(
        self,
        in_size: Sequence[int],
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
        lhs_dilation: Union[int, Tuple[int, ...]] = 1,
        rhs_dilation: Union[int, Tuple[int, ...]] = 1,
        groups: int = 1,
        w_init: Union[Callable, ArrayLike] = init.XavierNormal(),
        b_init: Optional[Union[Callable, ArrayLike]] = None,
        w_mask: Optional[Union[ArrayLike, Callable]] = None,
        name: str = None,
        param_type: type = ParamState,
    ):
        super().__init__(
            in_size=in_size,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            lhs_dilation=lhs_dilation,
            rhs_dilation=rhs_dilation,
            groups=groups,
            w_mask=w_mask,
            name=name,
        )
        self.w_initializer = w_init
        self.b_initializer = b_init

        # Parameters: the kernel is always created, the bias only on request.
        kernel = init.param(self.w_initializer, self.kernel_shape, allow_none=False)
        params = {'weight': kernel}
        if self.b_initializer is not None:
            # Singleton spatial axes let the bias broadcast over the
            # channels-last convolution output.
            n_spatial = len(self.kernel_size)
            params['bias'] = init.param(
                self.b_initializer,
                (1,) * n_spatial + (self.out_channels,),
                allow_none=True,
            )
        self.weight = param_type(params)

        # Infer the output shape with a dummy batch of 128, then drop the
        # batch axis.
        probe = jax.ShapeDtypeStruct((128,) + self.in_size, kernel.dtype)
        self.out_size = jax.eval_shape(self._conv_op, probe, params).shape[1:]

    def _conv_op(self, x, params):
        """Convolve the batched input ``x`` with the kernel in ``params``.

        If ``w_mask`` is set, the kernel is first multiplied by it.
        ``params['bias']`` is added to the result when present.
        """
        kernel = params['weight']
        if self.w_mask is not None:
            kernel = kernel * self.w_mask
        out = jax.lax.conv_general_dilated(
            lhs=x,
            rhs=kernel,
            window_strides=self.stride,
            padding=self.padding,
            lhs_dilation=self.lhs_dilation,
            rhs_dilation=self.rhs_dilation,
            feature_group_count=self.groups,
            dimension_numbers=self.dimension_numbers,
        )
        return (out + params['bias']) if 'bias' in params else out
|
247
|
+
|
248
|
+
|
249
|
+
class Conv1d(_Conv):
    """One-dimensional convolution.

    The input should be a 3d array with the shape of ``[B, H, C]``
    (channels-last layout).

    Parameters
    ----------
    %s
    """
    num_spatial_dims: int = 1
    __module__ = 'brainstate.nn'
|
260
|
+
|
261
|
+
|
262
|
+
class Conv2d(_Conv):
    """Two-dimensional convolution.

    The input should be a 4d array with the shape of ``[B, H, W, C]``
    (channels-last layout).

    Parameters
    ----------
    %s
    """
    num_spatial_dims: int = 2
    __module__ = 'brainstate.nn'
|
273
|
+
|
274
|
+
|
275
|
+
class Conv3d(_Conv):
    """Three-dimensional convolution.

    The input should be a 5d array with the shape of ``[B, H, W, D, C]``
    (channels-last layout).

    Parameters
    ----------
    %s
    """
    num_spatial_dims: int = 3
    __module__ = 'brainstate.nn'
|
286
|
+
|
287
|
+
|
288
|
+
_conv_doc = '''
|
289
|
+
in_size: tuple of int
|
290
|
+
The input shape, without the batch size. This argument is important, since it is
|
291
|
+
used to evaluate the shape of the output.
|
292
|
+
out_channels: int
|
293
|
+
The number of output channels.
|
294
|
+
kernel_size: int, sequence of int
|
295
|
+
The shape of the convolutional kernel.
|
296
|
+
For 1D convolution, the kernel size can be passed as an integer.
|
297
|
+
For all other cases, it must be a sequence of integers.
|
298
|
+
stride: int, sequence of int
|
299
|
+
An integer or a sequence of `n` integers, representing the inter-window strides (default: 1).
|
300
|
+
padding: str, int, sequence of int, sequence of tuple
|
301
|
+
Either the string `'SAME'`, the string `'VALID'`, or a sequence of n `(low,
|
302
|
+
high)` integer pairs that give the padding to apply before and after each
|
303
|
+
spatial dimension.
|
304
|
+
lhs_dilation: int, sequence of int
|
305
|
+
An integer or a sequence of `n` integers, giving the
|
306
|
+
dilation factor to apply in each spatial dimension of `inputs`
|
307
|
+
(default: 1). Convolution with input dilation `d` is equivalent to
|
308
|
+
transposed convolution with stride `d`.
|
309
|
+
rhs_dilation: int, sequence of int
|
310
|
+
An integer or a sequence of `n` integers, giving the
|
311
|
+
dilation factor to apply in each spatial dimension of the convolution
|
312
|
+
kernel (default: 1). Convolution with kernel dilation
|
313
|
+
is also known as 'atrous convolution'.
|
314
|
+
groups: int
|
315
|
+
If specified, divides the input features into groups. default 1.
|
316
|
+
w_init: Callable, ArrayLike, Initializer
|
317
|
+
The initializer for the convolutional kernel.
|
318
|
+
b_init: Optional, Callable, ArrayLike, Initializer
|
319
|
+
The initializer for the bias.
|
320
|
+
w_mask: ArrayLike, Callable, Optional
|
321
|
+
The optional mask of the weights.
|
322
|
+
mode: Mode
|
323
|
+
The computation mode of the current object. Default it is `training`.
|
324
|
+
name: str, Optional
|
325
|
+
The name of the object.
|
326
|
+
'''
|
327
|
+
|
328
|
+
# Fill the shared parameter section into each class docstring. Under
# ``python -OO`` docstrings are stripped and ``__doc__`` is ``None``;
# ``None % str`` would raise TypeError at import time, so skip those.
for _cls in (Conv1d, Conv2d, Conv3d):
    if _cls.__doc__ is not None:
        _cls.__doc__ = _cls.__doc__ % _conv_doc
del _cls
|
331
|
+
|
332
|
+
|
333
|
+
class _ScaledWSConv(_BaseConv):
    """Convolution layer with (scaled) weight standardization.

    On every forward pass the kernel is standardized with
    ``functional.weight_standardization``. The standardization uses ``eps``
    and, when ``ws_gain`` is true, a per-output-channel gain stored in the
    parameter dict. The optional mask is then applied and the convolution
    computed. Concrete subclasses only set ``num_spatial_dims``.
    """
    # Declared here as in ``_Conv``; concrete subclasses override it.
    num_spatial_dims: int = None

    def __init__(
        self,
        in_size: Sequence[int],
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
        lhs_dilation: Union[int, Tuple[int, ...]] = 1,
        rhs_dilation: Union[int, Tuple[int, ...]] = 1,
        groups: int = 1,
        ws_gain: bool = True,
        eps: float = 1e-4,
        w_init: Union[Callable, ArrayLike] = init.XavierNormal(),
        b_init: Optional[Union[Callable, ArrayLike]] = None,
        w_mask: Optional[Union[ArrayLike, Callable]] = None,
        name: str = None,
        param_type: type = ParamState,
    ):
        super().__init__(in_size=in_size,
                         out_channels=out_channels,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         lhs_dilation=lhs_dilation,
                         rhs_dilation=rhs_dilation,
                         groups=groups,
                         w_mask=w_mask,
                         name=name, )

        self.w_initializer = w_init
        self.b_initializer = b_init

        # --- weights --- #
        weight = init.param(self.w_initializer, self.kernel_shape, allow_none=False)
        params = dict(weight=weight)
        if self.b_initializer is not None:
            # Singleton spatial axes let the bias broadcast over the output.
            bias_shape = (1,) * len(self.kernel_size) + (self.out_channels,)
            bias = init.param(self.b_initializer, bias_shape, allow_none=True)
            params['bias'] = bias

        # Optional gain: one scale per output channel, initialized to ones.
        # A separate local name keeps the boolean ``ws_gain`` flag from being
        # rebound to an array.
        if ws_gain:
            gain_shape = (1,) * len(self.kernel_size) + (1, self.out_channels)
            params['gain'] = jnp.ones(gain_shape, dtype=params['weight'].dtype)

        # Epsilon, a small constant to avoid dividing by zero. It must be set
        # before the shape evaluation below, which calls ``_conv_op``.
        self.eps = eps

        # The weight operation
        self.weight = param_type(params)

        # Evaluate the output shape with a dummy batch of 128, then drop the
        # batch axis.
        abstract_y = jax.eval_shape(
            self._conv_op,
            jax.ShapeDtypeStruct((128,) + self.in_size, weight.dtype),
            params
        )
        self.out_size = abstract_y.shape[1:]

    def _conv_op(self, x, params):
        """Standardize the kernel, apply the mask, and convolve batched ``x``.

        ``params['gain']`` (if present) is passed to the weight
        standardization. ``params['bias']`` (if present) is added to the
        result.
        """
        w = params['weight']
        w = functional.weight_standardization(w, self.eps, params.get('gain', None))
        if self.w_mask is not None:
            w = w * self.w_mask
        y = jax.lax.conv_general_dilated(
            lhs=x,
            rhs=w,
            window_strides=self.stride,
            padding=self.padding,
            lhs_dilation=self.lhs_dilation,
            rhs_dilation=self.rhs_dilation,
            feature_group_count=self.groups,
            dimension_numbers=self.dimension_numbers
        )
        if 'bias' in params:
            y = y + params['bias']
        return y
|
413
|
+
|
414
|
+
|
415
|
+
class ScaledWSConv1d(_ScaledWSConv):
    """One-dimensional convolution with weight standardization.

    The input should be a 3d array with the shape of ``[B, H, C]``
    (channels-last layout).

    Parameters
    ----------
    %s
    """
    num_spatial_dims: int = 1
    __module__ = 'brainstate.nn'
|
426
|
+
|
427
|
+
|
428
|
+
class ScaledWSConv2d(_ScaledWSConv):
    """Two-dimensional convolution with weight standardization.

    The input should be a 4d array with the shape of ``[B, H, W, C]``
    (channels-last layout).

    Parameters
    ----------
    %s
    """
    num_spatial_dims: int = 2
    __module__ = 'brainstate.nn'
|
439
|
+
|
440
|
+
|
441
|
+
class ScaledWSConv3d(_ScaledWSConv):
    """Three-dimensional convolution with weight standardization.

    The input should be a 5d array with the shape of ``[B, H, W, D, C]``
    (channels-last layout).

    Parameters
    ----------
    %s
    """
    num_spatial_dims: int = 3
    __module__ = 'brainstate.nn'
|
452
|
+
|
453
|
+
|
454
|
+
_ws_conv_doc = '''
|
455
|
+
in_size: tuple of int
|
456
|
+
The input shape, without the batch size. This argument is important, since it is
|
457
|
+
used to evaluate the shape of the output.
|
458
|
+
out_channels: int
|
459
|
+
The number of output channels.
|
460
|
+
kernel_size: int, sequence of int
|
461
|
+
The shape of the convolutional kernel.
|
462
|
+
For 1D convolution, the kernel size can be passed as an integer.
|
463
|
+
For all other cases, it must be a sequence of integers.
|
464
|
+
stride: int, sequence of int
|
465
|
+
An integer or a sequence of `n` integers, representing the inter-window strides (default: 1).
|
466
|
+
padding: str, int, sequence of int, sequence of tuple
|
467
|
+
Either the string `'SAME'`, the string `'VALID'`, or a sequence of n `(low,
|
468
|
+
high)` integer pairs that give the padding to apply before and after each
|
469
|
+
spatial dimension.
|
470
|
+
lhs_dilation: int, sequence of int
|
471
|
+
An integer or a sequence of `n` integers, giving the
|
472
|
+
dilation factor to apply in each spatial dimension of `inputs`
|
473
|
+
(default: 1). Convolution with input dilation `d` is equivalent to
|
474
|
+
transposed convolution with stride `d`.
|
475
|
+
rhs_dilation: int, sequence of int
|
476
|
+
An integer or a sequence of `n` integers, giving the
|
477
|
+
dilation factor to apply in each spatial dimension of the convolution
|
478
|
+
kernel (default: 1). Convolution with kernel dilation
|
479
|
+
is also known as 'atrous convolution'.
|
480
|
+
groups: int
|
481
|
+
If specified, divides the input features into groups. default 1.
|
482
|
+
w_init: Callable, ArrayLike, Initializer
|
483
|
+
The initializer for the convolutional kernel.
|
484
|
+
b_init: Optional, Callable, ArrayLike, Initializer
|
485
|
+
The initializer for the bias.
|
486
|
+
ws_gain: bool
|
487
|
+
Whether to add a gain term for the weight standarization. The default is `True`.
|
488
|
+
eps: float
|
489
|
+
The epsilon value for numerical stability.
|
490
|
+
w_mask: ArrayLike, Callable, Optional
|
491
|
+
The optional mask of the weights.
|
492
|
+
mode: Mode
|
493
|
+
The computation mode of the current object. Default it is `training`.
|
494
|
+
name: str, Optional
|
495
|
+
The name of the object.
|
496
|
+
|
497
|
+
'''
|
498
|
+
|
499
|
+
# Fill the shared parameter section into each class docstring. Under
# ``python -OO`` docstrings are stripped and ``__doc__`` is ``None``;
# ``None % str`` would raise TypeError at import time, so skip those.
for _cls in (ScaledWSConv1d, ScaledWSConv2d, ScaledWSConv3d):
    if _cls.__doc__ is not None:
        _cls.__doc__ = _cls.__doc__ % _ws_conv_doc
del _cls
|