brainstate 0.0.2.post20241009__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +608 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/{nn/event → event}/__init__.py +6 -6
- brainstate/event/_csr.py +308 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +271 -0
- brainstate/event/_fixed_probability_test.py +128 -0
- brainstate/event/_linear.py +219 -0
- brainstate/event/_linear_test.py +112 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +32 -0
- brainstate/nn/_interaction/_connections.py +726 -0
- brainstate/nn/_interaction/_connections_test.py +254 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +103 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1360 -1318
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/util/_error.py +28 -0
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
- brainstate-0.1.0.dist-info/RECORD +135 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241009.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,726 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
# -*- coding: utf-8 -*-
|
17
|
+
|
18
|
+
from __future__ import annotations
|
19
|
+
|
20
|
+
import collections.abc
|
21
|
+
import numbers
|
22
|
+
from typing import Callable, Tuple, Union, Sequence, Optional, TypeVar
|
23
|
+
|
24
|
+
import brainunit as u
|
25
|
+
import jax
|
26
|
+
import jax.numpy as jnp
|
27
|
+
|
28
|
+
from brainstate import init, functional
|
29
|
+
from brainstate._state import ParamState
|
30
|
+
from brainstate.nn._module import Module
|
31
|
+
from brainstate.typing import ArrayLike
|
32
|
+
|
33
|
+
# Generic element type used in the annotations of the ``replicate`` helper.
T = TypeVar('T')

# Names exported from this module.
__all__ = [
    'Linear', 'ScaledWSLinear', 'SignedWLinear', 'CSRLinear',
    'Conv1d', 'Conv2d', 'Conv3d',
    'ScaledWSConv1d', 'ScaledWSConv2d', 'ScaledWSConv3d',
    'AllToAll',
]
|
41
|
+
|
42
|
+
|
43
|
+
def to_dimension_numbers(
    num_spatial_dims: int,
    channels_last: bool,
    transpose: bool
) -> jax.lax.ConvDimensionNumbers:
    """Build the ``lax.ConvDimensionNumbers`` for a convolution layout.

    Parameters
    ----------
    num_spatial_dims: int
        Number of spatial axes (1 for 1D, 2 for 2D, ...).
    channels_last: bool
        If True the feature axis of inputs/outputs is the last axis,
        otherwise it is axis 1 (right after the batch axis).
    transpose: bool
        Selects which of the two trailing kernel axes is listed first in
        the kernel spec.

    Returns
    -------
    jax.lax.ConvDimensionNumbers
        The same spec is used for the input (``lhs_spec``) and the output
        (``out_spec``); ``rhs_spec`` describes the kernel.
    """
    # batch axis + feature axis + spatial axes
    rank = num_spatial_dims + 2

    # Input/output spec: (batch axis, feature axis, *spatial axes).
    if channels_last:
        image_spec = (0, rank - 1) + tuple(range(1, rank - 1))
    else:
        image_spec = (0, 1) + tuple(range(2, rank))

    # Kernel spec: the two trailing kernel axes come first (their order
    # depends on ``transpose``), followed by the leading spatial axes.
    channel_axes = (rank - 2, rank - 1) if transpose else (rank - 1, rank - 2)
    kernel_spec = channel_axes + tuple(range(rank - 2))

    return jax.lax.ConvDimensionNumbers(
        lhs_spec=image_spec,
        rhs_spec=kernel_spec,
        out_spec=image_spec,
    )
|
63
|
+
|
64
|
+
|
65
|
+
def replicate(
    element: Union[T, Sequence[T]],
    num_replicate: int,
    name: str,
) -> Tuple[T, ...]:
    """Expand ``element`` into a tuple with ``num_replicate`` entries.

    A scalar (anything that is not a sequence, plus ``str``/``bytes``) or a
    length-1 sequence is repeated ``num_replicate`` times; a sequence that
    already has ``num_replicate`` entries is returned as a tuple.

    Raises
    ------
    TypeError
        If ``element`` is a sequence of any other length. ``name`` is used
        in the error message.
    """
    # Strings are sequences too, but are treated as single values here.
    treat_as_scalar = (
        isinstance(element, (str, bytes))
        or not isinstance(element, collections.abc.Sequence)
    )
    if treat_as_scalar:
        return (element,) * num_replicate

    length = len(element)
    if length == 1:
        return tuple(element) * num_replicate
    if length == num_replicate:
        return tuple(element)

    raise TypeError(f"{name} must be a scalar or sequence of length 1 or "
                    f"sequence of length {num_replicate}.")
|
80
|
+
|
81
|
+
|
82
|
+
class Linear(Module):
    """
    Linear layer computing ``y = x @ W (+ b)``.

    Parameters
    ----------
    in_size: int, sequence of int
        The input size. Only its last entry sizes the weight matrix.
    out_size: int, sequence of int
        The output size. Only its last entry sizes the weight matrix.
    w_init: Callable, ArrayLike
        The initializer for the weight of shape ``(in_size[-1], out_size[-1])``.
    b_init: Callable, ArrayLike, optional
        The initializer for the bias of size ``out_size[-1]``. ``None``
        creates a layer without bias.
    w_mask: ArrayLike, Callable, optional
        Optional mask multiplied element-wise with the weight before the
        matrix product. It is created with the same shape as the weight.
    name: str, optional
        The name of the object.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Union[int, Sequence[int]],
        out_size: Union[int, Sequence[int]],
        w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
        b_init: Optional[Union[Callable, ArrayLike]] = init.ZeroInit(),
        w_mask: Optional[Union[ArrayLike, Callable]] = None,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)

        # input and output shape
        self.in_size = (in_size,) if isinstance(in_size, numbers.Integral) else tuple(in_size)
        self.out_size = (out_size,) if isinstance(out_size, numbers.Integral) else tuple(out_size)

        # The weight only spans the trailing (feature) entries of the sizes.
        weight_shape = (self.in_size[-1], self.out_size[-1])

        # w_mask: it is multiplied element-wise with the weight in ``update``,
        # so it has to be created with the weight's shape. (For 1-D sizes this
        # equals ``in_size + out_size``.)
        self.w_mask = init.param(w_mask, weight_shape)

        # weights
        params = dict(weight=init.param(w_init, weight_shape, allow_none=False))
        if b_init is not None:
            params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)

        # weight + op: weight and (optional) bias live in a single state
        self.weight = ParamState(params)

    def update(self, x):
        """Apply the (optionally masked) linear map to ``x``."""
        params = self.weight.value
        weight = params['weight']
        if self.w_mask is not None:
            weight = weight * self.w_mask
        y = u.math.dot(x, weight)
        # the 'bias' entry only exists when ``b_init`` was not None
        if 'bias' in params:
            y = y + params['bias']
        return y
|
123
|
+
|
124
|
+
|
125
|
+
class SignedWLinear(Module):
    """
    Linear layer with signed weights.

    The stored weight is passed through ``abs`` before use, so the effective
    weight is non-negative unless ``w_sign`` is given, in which case the
    absolute weight is multiplied element-wise by ``w_sign``.

    Parameters
    ----------
    in_size: int, sequence of int
        The input size.
    out_size: int, sequence of int
        The output size.
    w_init: Callable, ArrayLike
        The initializer for the weight of shape ``in_size + out_size``.
    w_sign: ArrayLike, optional
        Factor multiplied with the absolute weight.
    name: str, optional
        The name of the object.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Union[int, Sequence[int]],
        out_size: Union[int, Sequence[int]],
        w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
        w_sign: Optional[ArrayLike] = None,
        name: Optional[str] = None,

    ):
        super().__init__(name=name)

        # normalise both sizes to tuples
        self.in_size = tuple(in_size) if not isinstance(in_size, numbers.Integral) else (in_size,)
        self.out_size = tuple(out_size) if not isinstance(out_size, numbers.Integral) else (out_size,)

        # sign factor applied to the absolute weight (may be None)
        self.w_sign = w_sign

        # trainable weight
        self.weight = ParamState(
            init.param(w_init, self.in_size + self.out_size, allow_none=False)
        )

    def _operation(self, x, w):
        """Multiply ``x`` with ``|w|``, scaled by ``w_sign`` when it is set."""
        effective = jnp.abs(w)
        if self.w_sign is not None:
            effective = effective * self.w_sign
        return jnp.matmul(x, effective)

    def update(self, x):
        """Apply the layer to ``x`` using the current weight value."""
        return self._operation(x, self.weight.value)
|
161
|
+
|
162
|
+
|
163
|
+
class ScaledWSLinear(Module):
    """
    Linear Layer with Weight Standardization.

    The weight is standardized (and optionally scaled by a learnable gain)
    before it is used in the matrix product.

    Parameters
    ----------
    in_size: int, sequence of int
      The input size.
    out_size: int, sequence of int
      The output size.
    w_init: Callable, ArrayLike
      The initializer for the weights.
    b_init: Callable, ArrayLike
      The initializer for the bias. ``None`` creates a layer without bias.
    w_mask: ArrayLike, Callable
      The optional mask of the weights.
    ws_gain: bool
      Whether to use gain for the weights. The default is True.
    eps: float
      The epsilon value for the weight standardization.
    name: str
      The name of the object.

    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Union[int, Sequence[int]],
        out_size: Union[int, Sequence[int]],
        w_init: Callable = init.KaimingNormal(),
        b_init: Callable = init.ZeroInit(),
        w_mask: Optional[Union[ArrayLike, Callable]] = None,
        ws_gain: bool = True,
        eps: float = 1e-4,
        name: str = None,
    ):
        super().__init__(name=name)

        # normalise both sizes to tuples
        self.in_size = tuple(in_size) if not isinstance(in_size, numbers.Integral) else (in_size,)
        self.out_size = tuple(out_size) if not isinstance(out_size, numbers.Integral) else (out_size,)

        # optional mask, created with shape (in_size[0], 1)
        self.w_mask = init.param(w_mask, (self.in_size[0], 1))

        # epsilon forwarded to the weight standardization
        self.eps = eps

        # trainable entries, all stored in one state
        kernel = init.param(w_init, self.in_size + self.out_size, allow_none=False)
        params = {'weight': kernel}
        if b_init is not None:
            params['bias'] = init.param(b_init, self.out_size, allow_none=False)
        if ws_gain:
            # one gain value per entry of the last weight axis, with leading
            # singleton axes so it broadcasts against the weight
            kernel_shape = kernel.shape
            gain_shape = (1,) * (len(kernel_shape) - 1) + (kernel_shape[-1],)
            params['gain'] = jnp.ones(gain_shape, dtype=kernel.dtype)

        self.weight = ParamState(params)

    def update(self, x):
        """Apply the layer to ``x`` using the current parameter values."""
        return self._operation(x, self.weight.value)

    def _operation(self, x, params):
        """Standardize the weight, apply the mask, then compute ``x @ w (+ b)``."""
        kernel = functional.weight_standardization(
            params['weight'], self.eps, params.get('gain', None)
        )
        if self.w_mask is not None:
            kernel = kernel * self.w_mask
        out = jnp.dot(x, kernel)
        if 'bias' not in params:
            return out
        return out + params['bias']
|
238
|
+
|
239
|
+
|
240
|
+
class CSRLinear(Module):
    # Placeholder: the class body defines no parameters and no ``update``
    # logic; it only overrides the module path it reports.
    __module__ = 'brainstate.nn'
|
242
|
+
|
243
|
+
|
244
|
+
class _BaseConv(Module):
    """Shared constructor logic and input handling for the convolution layers.

    Subclasses set ``num_spatial_dims``, create ``self.weight`` and implement
    ``_conv_op``. Inputs are channels-last: ``in_size`` is the unbatched
    input shape and its last entry is the number of input channels.
    """
    # the number of spatial dimensions
    num_spatial_dims: int

    # the weight and its operations
    weight: ParamState

    def __init__(
        self,
        in_size: Sequence[int],
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
        lhs_dilation: Union[int, Tuple[int, ...]] = 1,
        rhs_dilation: Union[int, Tuple[int, ...]] = 1,
        groups: int = 1,
        w_mask: Optional[Union[ArrayLike, Callable]] = None,
        name: str = None,
    ):
        super().__init__(name=name)

        # general parameters
        # ``in_size`` holds the spatial sizes followed by the channel count
        assert self.num_spatial_dims + 1 == len(in_size)
        self.in_size = tuple(in_size)
        self.in_channels = in_size[-1]
        self.out_channels = out_channels
        self.stride = replicate(stride, self.num_spatial_dims, 'stride')
        self.kernel_size = replicate(kernel_size, self.num_spatial_dims, 'kernel_size')
        self.lhs_dilation = replicate(lhs_dilation, self.num_spatial_dims, 'lhs_dilation')
        self.rhs_dilation = replicate(rhs_dilation, self.num_spatial_dims, 'rhs_dilation')
        self.groups = groups
        self.dimension_numbers = to_dimension_numbers(self.num_spatial_dims, channels_last=True, transpose=False)

        # the padding parameter
        self.padding = self._normalize_padding(padding)

        # the number of in-/out-channels
        assert self.out_channels % self.groups == 0, '"out_channels" should be divisible by groups'
        assert self.in_channels % self.groups == 0, '"in_channels" should be divisible by groups'

        # kernel shape and w_mask
        kernel_shape = tuple(self.kernel_size) + (self.in_channels // self.groups, self.out_channels)
        self.kernel_shape = kernel_shape
        self.w_mask = init.param(w_mask, kernel_shape, allow_none=True)

    def _normalize_padding(self, padding):
        """Return ``padding`` as ``'SAME'``/``'VALID'`` or as one ``(low, high)`` pair per spatial axis.

        Raises
        ------
        ValueError
            If ``padding`` is not a string, an integer, a ``(low, high)`` pair
            or a sequence of such pairs with length 1 or ``num_spatial_dims``.
        """
        n_spatial = self.num_spatial_dims

        if isinstance(padding, str):
            assert padding in ['SAME', 'VALID']
            return padding

        # a single integer pads every spatial axis symmetrically
        if isinstance(padding, numbers.Integral):
            return ((int(padding), int(padding)),) * n_spatial

        if isinstance(padding, (tuple, list)):
            if len(padding) == 0:
                raise ValueError(f"Padding must not be an empty sequence, but got {padding!r}.")
            first = padding[0]

            # one (low, high) pair shared by every spatial axis
            if isinstance(first, numbers.Integral):
                if len(padding) != 2:
                    raise ValueError(
                        f"Padding given as a sequence of integers must be a "
                        f"(low, high) pair, but got {padding!r}."
                    )
                return (tuple(padding),) * n_spatial

            # a sequence of (low, high) pairs
            if isinstance(first, (tuple, list)):
                if len(padding) == 1:
                    return tuple(padding) * n_spatial
                if len(padding) != n_spatial:
                    raise ValueError(
                        f"Padding {padding} must be a Tuple[int, int], "
                        f"or sequence of Tuple[int, int] with length 1, "
                        f"or sequence of Tuple[int, int] with length {self.num_spatial_dims}."
                    )
                return tuple(padding)

            raise ValueError(
                f"Padding sequence must contain integers or (low, high) pairs, "
                f"but got {padding!r}."
            )

        raise ValueError(
            f"Padding must be 'SAME', 'VALID', an integer, a (low, high) pair or "
            f"a sequence of (low, high) pairs, but got {padding!r}."
        )

    def _check_input_dim(self, x):
        """Raise ``ValueError`` unless ``x`` (with or without a leading batch axis) matches ``in_size``."""
        if x.ndim == self.num_spatial_dims + 2:
            # batched input: drop the leading batch axis before comparing
            x_shape = x.shape[1:]
        elif x.ndim == self.num_spatial_dims + 1:
            x_shape = x.shape
        else:
            raise ValueError(f"expected {self.num_spatial_dims + 2}D (with batch) or "
                             f"{self.num_spatial_dims + 1}D (without batch) input (got {x.ndim}D input, {x.shape})")
        if self.in_size != x_shape:
            raise ValueError(f"The expected input shape is {self.in_size}, while we got {x_shape}.")

    def update(self, x):
        """Validate ``x``, run ``_conv_op`` and return an output batched like the input."""
        self._check_input_dim(x)
        non_batching = False
        if x.ndim == self.num_spatial_dims + 1:
            # add a temporary batch axis; it is removed again below
            x = jnp.expand_dims(x, 0)
            non_batching = True
        y = self._conv_op(x, self.weight.value)
        return y[0] if non_batching else y

    def _conv_op(self, x, params):
        """Compute the convolution of batched ``x`` with ``params``; implemented by subclasses."""
        raise NotImplementedError

    def __repr__(self):
        return (f'{self.__class__.__name__}('
                f'in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size}, '
                f'stride={self.stride}, '
                f'padding={self.padding}, '
                f'groups={self.groups})')
|
341
|
+
|
342
|
+
|
343
|
+
class _Conv(_BaseConv):
    """Convolution with a trainable kernel and an optional bias.

    Concrete subclasses only set ``num_spatial_dims``.
    """
    num_spatial_dims: int = None

    def __init__(
        self,
        in_size: Sequence[int],
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
        lhs_dilation: Union[int, Tuple[int, ...]] = 1,
        rhs_dilation: Union[int, Tuple[int, ...]] = 1,
        groups: int = 1,
        w_init: Union[Callable, ArrayLike] = init.XavierNormal(),
        b_init: Optional[Union[Callable, ArrayLike]] = None,
        w_mask: Optional[Union[ArrayLike, Callable]] = None,
        name: str = None,
    ):
        super().__init__(
            in_size=in_size,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            lhs_dilation=lhs_dilation,
            rhs_dilation=rhs_dilation,
            groups=groups,
            w_mask=w_mask,
            name=name,
        )

        # keep the initializers around
        self.w_initializer = w_init
        self.b_initializer = b_init

        # trainable entries: the kernel and, if requested, the bias
        kernel = init.param(self.w_initializer, self.kernel_shape, allow_none=False)
        params = {'weight': kernel}
        if self.b_initializer is not None:
            # one bias value per output channel, with singleton spatial axes
            bias_shape = (1,) * len(self.kernel_size) + (self.out_channels,)
            params['bias'] = init.param(self.b_initializer, bias_shape, allow_none=True)

        self.weight = ParamState(params)

        # Trace the convolution abstractly on a dummy batch of 128 to get the
        # unbatched output shape without computing anything.
        dummy_input = jax.ShapeDtypeStruct((128,) + self.in_size, kernel.dtype)
        traced = jax.eval_shape(self._conv_op, dummy_input, params)
        self.out_size = traced.shape[1:]

    def _conv_op(self, x, params):
        """Convolve batched ``x`` with the (optionally masked) kernel and add the bias if present."""
        kernel = params['weight']
        if self.w_mask is not None:
            kernel = kernel * self.w_mask
        out = jax.lax.conv_general_dilated(
            lhs=x,
            rhs=kernel,
            window_strides=self.stride,
            padding=self.padding,
            lhs_dilation=self.lhs_dilation,
            rhs_dilation=self.rhs_dilation,
            feature_group_count=self.groups,
            dimension_numbers=self.dimension_numbers,
        )
        if 'bias' not in params:
            return out
        return out + params['bias']
|
412
|
+
|
413
|
+
|
414
|
+
class Conv1d(_Conv):
    """One-dimensional convolution.

    The input should be a 3d array with the shape of ``[B, H, C]``.
    An input without the batch axis, ``[H, C]``, is accepted as well.

    Parameters
    ----------
    %s
    """
    # NOTE: the ``%s`` placeholder in the docstring above is filled with the
    # shared parameter documentation after the class definitions.
    __module__ = 'brainstate.nn'
    # number of spatial axes handled by this layer
    num_spatial_dims: int = 1
|
425
|
+
|
426
|
+
|
427
|
+
class Conv2d(_Conv):
    """Two-dimensional convolution.

    The input should be a 4d array with the shape of ``[B, H, W, C]``.
    An input without the batch axis, ``[H, W, C]``, is accepted as well.

    Parameters
    ----------
    %s
    """
    # NOTE: the ``%s`` placeholder in the docstring above is filled with the
    # shared parameter documentation after the class definitions.
    __module__ = 'brainstate.nn'
    # number of spatial axes handled by this layer
    num_spatial_dims: int = 2
|
438
|
+
|
439
|
+
|
440
|
+
class Conv3d(_Conv):
    """Three-dimensional convolution.

    The input should be a 5d array with the shape of ``[B, H, W, D, C]``.
    An input without the batch axis, ``[H, W, D, C]``, is accepted as well.

    Parameters
    ----------
    %s
    """
    # NOTE: the ``%s`` placeholder in the docstring above is filled with the
    # shared parameter documentation after the class definitions.
    __module__ = 'brainstate.nn'
    # number of spatial axes handled by this layer
    num_spatial_dims: int = 3
|
451
|
+
|
452
|
+
|
453
|
+
# Shared ``Parameters`` text substituted into the ``%s`` placeholder of the
# Conv1d / Conv2d / Conv3d docstrings.
_conv_doc = '''
    in_size: tuple of int
        The input shape, without the batch size. This argument is important, since it is
        used to evaluate the shape of the output.
    out_channels: int
        The number of output channels.
    kernel_size: int, sequence of int
        The shape of the convolutional kernel.
        For 1D convolution, the kernel size can be passed as an integer.
        For all other cases, it must be a sequence of integers.
    stride: int, sequence of int
        An integer or a sequence of `n` integers, representing the inter-window strides (default: 1).
    padding: str, int, sequence of int, sequence of tuple
        Either the string `'SAME'`, the string `'VALID'`, or a sequence of n `(low,
        high)` integer pairs that give the padding to apply before and after each
        spatial dimension. A single integer or a single `(low, high)` pair is
        applied to every spatial dimension.
    lhs_dilation: int, sequence of int
        An integer or a sequence of `n` integers, giving the
        dilation factor to apply in each spatial dimension of `inputs`
        (default: 1). Convolution with input dilation `d` is equivalent to
        transposed convolution with stride `d`.
    rhs_dilation: int, sequence of int
        An integer or a sequence of `n` integers, giving the
        dilation factor to apply in each spatial dimension of the convolution
        kernel (default: 1). Convolution with kernel dilation
        is also known as 'atrous convolution'.
    groups: int
        If specified, divides the input features into groups. default 1.
    w_init: Callable, ArrayLike, Initializer
        The initializer for the convolutional kernel.
    b_init: Optional, Callable, ArrayLike, Initializer
        The initializer for the bias. No bias is used when it is None (default).
    w_mask: ArrayLike, Callable, Optional
        The optional mask of the weights.
    name: str, Optional
        The name of the object.
'''

# Docstrings are stripped under ``python -OO`` (``__doc__`` is None); skip the
# substitution in that case instead of raising a TypeError at import time.
for _conv_cls in (Conv1d, Conv2d, Conv3d):
    if _conv_cls.__doc__ is not None:
        _conv_cls.__doc__ = _conv_cls.__doc__ % _conv_doc
del _conv_cls
|
496
|
+
|
497
|
+
|
498
|
+
class _ScaledWSConv(_BaseConv):
    """Convolution whose kernel is weight-standardized before use.

    Besides the kernel and the optional bias, an optional gain (one value per
    output channel) is stored and passed to the weight standardization.
    Concrete subclasses only set ``num_spatial_dims``.
    """

    def __init__(
        self,
        in_size: Sequence[int],
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
        lhs_dilation: Union[int, Tuple[int, ...]] = 1,
        rhs_dilation: Union[int, Tuple[int, ...]] = 1,
        groups: int = 1,
        ws_gain: bool = True,
        eps: float = 1e-4,
        w_init: Union[Callable, ArrayLike] = init.XavierNormal(),
        b_init: Optional[Union[Callable, ArrayLike]] = None,
        w_mask: Optional[Union[ArrayLike, Callable]] = None,
        name: str = None,
    ):
        super().__init__(
            in_size=in_size,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            lhs_dilation=lhs_dilation,
            rhs_dilation=rhs_dilation,
            groups=groups,
            w_mask=w_mask,
            name=name,
        )

        # keep the initializers around
        self.w_initializer = w_init
        self.b_initializer = b_init

        # trainable entries: kernel, optional bias, optional gain
        kernel = init.param(self.w_initializer, self.kernel_shape, allow_none=False)
        params = {'weight': kernel}
        n_spatial = len(self.kernel_size)
        if self.b_initializer is not None:
            # one bias value per output channel, with singleton spatial axes
            bias_shape = (1,) * n_spatial + (self.out_channels,)
            params['bias'] = init.param(self.b_initializer, bias_shape, allow_none=True)
        if ws_gain:
            # gain broadcasts over the spatial and input-channel kernel axes
            gain_shape = (1,) * n_spatial + (1, self.out_channels)
            params['gain'] = jnp.ones(gain_shape, dtype=kernel.dtype)

        # Epsilon, a small constant to avoid dividing by zero.
        self.eps = eps

        self.weight = ParamState(params)

        # Trace the convolution abstractly on a dummy batch of 128 to get the
        # unbatched output shape without computing anything.
        dummy_input = jax.ShapeDtypeStruct((128,) + self.in_size, kernel.dtype)
        traced = jax.eval_shape(self._conv_op, dummy_input, params)
        self.out_size = traced.shape[1:]

    def _conv_op(self, x, params):
        """Standardize and mask the kernel, convolve batched ``x`` and add the bias if present."""
        kernel = functional.weight_standardization(
            params['weight'], self.eps, params.get('gain', None)
        )
        if self.w_mask is not None:
            kernel = kernel * self.w_mask
        out = jax.lax.conv_general_dilated(
            lhs=x,
            rhs=kernel,
            window_strides=self.stride,
            padding=self.padding,
            lhs_dilation=self.lhs_dilation,
            rhs_dilation=self.rhs_dilation,
            feature_group_count=self.groups,
            dimension_numbers=self.dimension_numbers,
        )
        if 'bias' not in params:
            return out
        return out + params['bias']
|
577
|
+
|
578
|
+
|
579
|
+
class ScaledWSConv1d(_ScaledWSConv):
    """One-dimensional convolution with weight standardization.

    The input should be a 3d array with the shape of ``[B, H, C]``.
    An input without the batch axis, ``[H, C]``, is accepted as well.

    Parameters
    ----------
    %s
    """
    # NOTE(review): the ``%s`` placeholder is presumably filled from the shared
    # parameter documentation defined below the classes - keep it intact.
    __module__ = 'brainstate.nn'
    # number of spatial axes handled by this layer
    num_spatial_dims: int = 1
|
590
|
+
|
591
|
+
|
592
|
+
class ScaledWSConv2d(_ScaledWSConv):
    """Two-dimensional convolution with weight standardization.

    The input should be a 4d array with the shape of ``[B, H, W, C]``.
    An input without the batch axis, ``[H, W, C]``, is accepted as well.

    Parameters
    ----------
    %s
    """
    # NOTE(review): the ``%s`` placeholder is presumably filled from the shared
    # parameter documentation defined below the classes - keep it intact.
    __module__ = 'brainstate.nn'
    # number of spatial axes handled by this layer
    num_spatial_dims: int = 2
|
603
|
+
|
604
|
+
|
605
|
+
class ScaledWSConv3d(_ScaledWSConv):
    """Three-dimensional convolution with weight standardization.

    The input should be a 5d array with the shape of ``[B, H, W, D, C]``.

    Parameters
    ----------
    %s
    """
    # NOTE: the ``%s`` placeholder in the docstring above is filled with the
    # shared ``_ws_conv_doc`` parameter text at module level, after the class
    # definitions. Keep exactly one ``%s`` and no other percent signs there.

    # Make the class report the public package as its defining module.
    __module__ = 'brainstate.nn'
    # Number of spatial axes of this variant; presumably read by the base
    # class when it builds the kernel shape -- TODO confirm in _ScaledWSConv.
    num_spatial_dims: int = 3
|
616
|
+
|
617
|
+
|
618
|
+
# Shared "Parameters" text that is substituted into the ``%s`` placeholder of
# the ScaledWSConv1d / ScaledWSConv2d / ScaledWSConv3d class docstrings.
_ws_conv_doc = '''
    in_size: tuple of int
        The input shape, without the batch size. This argument is important, since it is
        used to evaluate the shape of the output.
    out_channels: int
        The number of output channels.
    kernel_size: int, sequence of int
        The shape of the convolutional kernel.
        For 1D convolution, the kernel size can be passed as an integer.
        For all other cases, it must be a sequence of integers.
    stride: int, sequence of int
        An integer or a sequence of `n` integers, representing the inter-window strides (default: 1).
    padding: str, int, sequence of int, sequence of tuple
        Either the string `'SAME'`, the string `'VALID'`, or a sequence of n `(low,
        high)` integer pairs that give the padding to apply before and after each
        spatial dimension.
    lhs_dilation: int, sequence of int
        An integer or a sequence of `n` integers, giving the
        dilation factor to apply in each spatial dimension of `inputs`
        (default: 1). Convolution with input dilation `d` is equivalent to
        transposed convolution with stride `d`.
    rhs_dilation: int, sequence of int
        An integer or a sequence of `n` integers, giving the
        dilation factor to apply in each spatial dimension of the convolution
        kernel (default: 1). Convolution with kernel dilation
        is also known as 'atrous convolution'.
    groups: int
        If specified, divides the input features into groups. The default is 1.
    w_init: Callable, ArrayLike, Initializer
        The initializer for the convolutional kernel.
    b_init: Optional, Callable, ArrayLike, Initializer
        The initializer for the bias.
    ws_gain: bool
        Whether to add a gain term for the weight standardization. The default is `True`.
    eps: float
        The epsilon value for numerical stability.
    w_mask: ArrayLike, Callable, Optional
        The optional mask of the weights.
    mode: Mode
        The computation mode of the current object. The default is `training`.
    name: str, Optional
        The name of the object.

'''

# Fill the ``%s`` placeholder of each class docstring with the shared text.
# When Python runs with ``-OO`` docstrings are stripped and ``__doc__`` is
# ``None``; ``None % str`` would raise ``TypeError`` at import time, so the
# substitution is skipped in that case.
for _ws_conv_cls in (ScaledWSConv1d, ScaledWSConv2d, ScaledWSConv3d):
    if _ws_conv_cls.__doc__ is not None:
        _ws_conv_cls.__doc__ = _ws_conv_cls.__doc__ % _ws_conv_doc
del _ws_conv_cls  # do not leak the loop variable into the module namespace
|
666
|
+
|
667
|
+
|
668
|
+
class AllToAll(Module):
    """Synaptic matrix multiplication with All2All connections.

    Args:
      in_size: int or sequence of int. The number of neurons in the presynaptic
        neuron group. For a sequence, the last entry is the neuron count and
        the leading entries must equal those of ``out_size``.
      out_size: int or sequence of int. The number of neurons in the
        postsynaptic neuron group.
      w_init: The synaptic weights, or an initializer for them. The weight is
        created with shape ``(in_size[-1], out_size[-1])``.
      include_self: bool. Whether to connect neurons at the same position
        (self-connections).
      name: str. The object name.
    """

    def __init__(
        self,
        in_size: Union[int, Sequence[int]],
        out_size: Union[int, Sequence[int]],
        w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
        include_self: bool = True,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)

        # input and output shape (always stored as tuples)
        self.in_size = (in_size,) if isinstance(in_size, numbers.Integral) else tuple(in_size)
        self.out_size = (out_size,) if isinstance(out_size, numbers.Integral) else tuple(out_size)
        assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
                                                         'and "out_size" must be the same.')

        # weights
        self.weight = ParamState(init.param(w_init, (self.in_size[-1], self.out_size[-1]), allow_none=False))

        # others
        self.include_self = include_self

    def update(self, pre_val):
        """Compute the postsynaptic values for the presynaptic values ``pre_val``.

        Two cases are handled, depending on the stored weight:

        * scalar weight: the presynaptic values are summed over the last axis
          (a 1-D input is reduced to a scalar sum); when ``include_self`` is
          false each neuron's own contribution is subtracted again, and the
          result is scaled by the weight;
        * 2-D weight: ``pre_val @ weight``, using a copy of the weight with a
          zeroed diagonal when ``include_self`` is false.

        Args:
          pre_val: The presynaptic values; the last axis indexes the
            presynaptic neurons.

        Returns:
          The postsynaptic values.
        """
        if u.math.ndim(self.weight.value) == 0:  # weight is a scalar
            if pre_val.ndim == 1:
                post_val = u.math.sum(pre_val)
            else:
                post_val = u.math.sum(pre_val, keepdims=True, axis=-1)
            if not self.include_self:
                if self.in_size == self.out_size:
                    post_val = post_val - pre_val
                elif self.in_size[-1] > self.out_size[-1]:
                    # More inputs than outputs: only the first ``out`` neurons
                    # have a same-position partner.
                    val = pre_val[..., :self.out_size[-1]]
                    post_val = post_val - val
                else:
                    # Fewer inputs than outputs: pad the self-contribution with
                    # zeros along the neuron (last) axis. The padding takes its
                    # leading dimensions from ``pre_val`` itself so that batched
                    # or multi-dimensional inputs line up, and the arrays are
                    # joined on the last axis rather than the default axis 0.
                    pad_shape = tuple(pre_val.shape[:-1]) + (self.out_size[-1] - self.in_size[-1],)
                    val = u.math.concatenate(
                        [pre_val, u.math.zeros(pad_shape, dtype=pre_val.dtype)],
                        axis=-1,
                    )
                    post_val = post_val - val
            post_val = self.weight.value * post_val

        else:  # weight is a matrix
            assert u.math.ndim(self.weight.value) == 2, '"weight" must be a 2D matrix.'
            if not self.include_self:
                # Remove self-connections by zeroing the diagonal of the weight.
                post_val = pre_val @ u.math.fill_diagonal(self.weight.value, 0.)
            else:
                post_val = pre_val @ self.weight.value
        return post_val
|