brainstate 0.0.2.post20241009__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1360 -1318
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241009.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
brainstate/nn/_interaction/_connections.py
@@ -0,0 +1,726 @@
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

from __future__ import annotations

import collections.abc
import numbers
from typing import Callable, Tuple, Union, Sequence, Optional, TypeVar

import brainunit as u
import jax
import jax.numpy as jnp

from brainstate import init, functional
from brainstate._state import ParamState
from brainstate.nn._module import Module
from brainstate.typing import ArrayLike

T = TypeVar('T')

__all__ = [
    'Linear', 'ScaledWSLinear', 'SignedWLinear', 'CSRLinear',
    'Conv1d', 'Conv2d', 'Conv3d',
    'ScaledWSConv1d', 'ScaledWSConv2d', 'ScaledWSConv3d',
    'AllToAll',
]


def to_dimension_numbers(
    num_spatial_dims: int,
    channels_last: bool,
    transpose: bool
) -> jax.lax.ConvDimensionNumbers:
    """Create a `lax.ConvDimensionNumbers` for the given inputs."""
    num_dims = num_spatial_dims + 2
    if channels_last:
        spatial_dims = tuple(range(1, num_dims - 1))
        image_dn = (0, num_dims - 1) + spatial_dims
    else:
        spatial_dims = tuple(range(2, num_dims))
        image_dn = (0, 1) + spatial_dims
    if transpose:
        kernel_dn = (num_dims - 2, num_dims - 1) + tuple(range(num_dims - 2))
    else:
        kernel_dn = (num_dims - 1, num_dims - 2) + tuple(range(num_dims - 2))
    return jax.lax.ConvDimensionNumbers(lhs_spec=image_dn,
                                        rhs_spec=kernel_dn,
                                        out_spec=image_dn)
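
For instance, for a 2D convolution with channels-last inputs this helper yields the usual NHWC/HWIO layout (a quick check using only the function above):

    dn = to_dimension_numbers(num_spatial_dims=2, channels_last=True, transpose=False)
    # dn.lhs_spec == (0, 3, 1, 2)   # batch, feature, then spatial dims (NHWC)
    # dn.rhs_spec == (3, 2, 0, 1)   # output feature, input feature, then spatial dims (HWIO)
    # dn.out_spec == (0, 3, 1, 2)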

def replicate(
    element: Union[T, Sequence[T]],
    num_replicate: int,
    name: str,
) -> Tuple[T, ...]:
    """Replicates the entry in `element` `num_replicate` times if needed."""
    if isinstance(element, (str, bytes)) or not isinstance(element, collections.abc.Sequence):
        return (element,) * num_replicate
    elif len(element) == 1:
        return tuple(list(element) * num_replicate)
    elif len(element) == num_replicate:
        return tuple(element)
    else:
        raise TypeError(f"{name} must be a scalar, a sequence of length 1, "
                        f"or a sequence of length {num_replicate}.")
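
`replicate` normalizes scalar-or-sequence hyperparameters such as `stride` and `kernel_size`; a few illustrative calls (a sketch, not from the source):

    replicate(3, 2, 'stride')          # -> (3, 3)
    replicate((3,), 2, 'stride')       # -> (3, 3)
    replicate((3, 1), 2, 'stride')     # -> (3, 1)
    replicate((3, 1, 2), 2, 'stride')  # raises TypeError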

class Linear(Module):
    """
    Linear layer.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Union[int, Sequence[int]],
        out_size: Union[int, Sequence[int]],
        w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
        b_init: Optional[Union[Callable, ArrayLike]] = init.ZeroInit(),
        w_mask: Optional[Union[ArrayLike, Callable]] = None,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)

        # input and output shape
        self.in_size = (in_size,) if isinstance(in_size, numbers.Integral) else tuple(in_size)
        self.out_size = (out_size,) if isinstance(out_size, numbers.Integral) else tuple(out_size)

        # w_mask
        self.w_mask = init.param(w_mask, self.in_size + self.out_size)

        # weights
        params = dict(weight=init.param(w_init, (self.in_size[-1], self.out_size[-1]), allow_none=False))
        if b_init is not None:
            params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)

        # weight + op
        self.weight = ParamState(params)

    def update(self, x):
        params = self.weight.value
        weight = params['weight']
        if self.w_mask is not None:
            weight = weight * self.w_mask
        y = u.math.dot(x, weight)
        if 'bias' in params:
            y = y + params['bias']
        return y
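
A minimal usage sketch for `Linear` (assuming the `brainstate.nn` export indicated by `__module__` above):

    import jax.numpy as jnp
    import brainstate

    layer = brainstate.nn.Linear(4, 8)  # KaimingNormal weights, zero bias, stored in one ParamState
    x = jnp.ones((32, 4))
    y = layer.update(x)                 # y.shape == (32, 8)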

class SignedWLinear(Module):
    """
    Linear layer with signed weights.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Union[int, Sequence[int]],
        out_size: Union[int, Sequence[int]],
        w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
        w_sign: Optional[ArrayLike] = None,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)

        # input and output shape
        self.in_size = (in_size,) if isinstance(in_size, numbers.Integral) else tuple(in_size)
        self.out_size = (out_size,) if isinstance(out_size, numbers.Integral) else tuple(out_size)

        # w_sign
        self.w_sign = w_sign

        # weights
        weight = init.param(w_init, self.in_size + self.out_size, allow_none=False)
        self.weight = ParamState(weight)

    def _operation(self, x, w):
        if self.w_sign is None:
            return jnp.matmul(x, jnp.abs(w))
        else:
            return jnp.matmul(x, jnp.abs(w) * self.w_sign)

    def update(self, x):
        return self._operation(x, self.weight.value)
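
`SignedWLinear` keeps the weight magnitudes learnable while `w_sign` fixes each entry's sign, e.g. for Dale's-law networks with excitatory and inhibitory presynaptic rows; a hypothetical sketch:

    import jax.numpy as jnp
    import brainstate

    # rows 0-7 excitatory (+1), rows 8-9 inhibitory (-1); mask shape matches the (10, 10) weight
    w_sign = jnp.concatenate([jnp.ones((8, 10)), -jnp.ones((2, 10))], axis=0)
    layer = brainstate.nn.SignedWLinear(10, 10, w_sign=w_sign)
    y = layer.update(jnp.ones((1, 10)))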

class ScaledWSLinear(Module):
    """
    Linear layer with weight standardization.

    Applies weight standardization to the weights of the linear layer.

    Parameters
    ----------
    in_size: int, sequence of int
        The input size.
    out_size: int, sequence of int
        The output size.
    w_init: Callable, ArrayLike
        The initializer for the weights.
    b_init: Callable, ArrayLike
        The initializer for the bias.
    w_mask: ArrayLike, Callable
        The optional mask of the weights.
    ws_gain: bool
        Whether to use a gain for the weights. The default is True.
    eps: float
        The epsilon value for the weight standardization.
    name: str
        The name of the object.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Union[int, Sequence[int]],
        out_size: Union[int, Sequence[int]],
        w_init: Callable = init.KaimingNormal(),
        b_init: Callable = init.ZeroInit(),
        w_mask: Optional[Union[ArrayLike, Callable]] = None,
        ws_gain: bool = True,
        eps: float = 1e-4,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)

        # input and output shape
        self.in_size = (in_size,) if isinstance(in_size, numbers.Integral) else tuple(in_size)
        self.out_size = (out_size,) if isinstance(out_size, numbers.Integral) else tuple(out_size)

        # w_mask
        self.w_mask = init.param(w_mask, (self.in_size[0], 1))

        # parameters
        self.eps = eps

        # weights
        params = dict(weight=init.param(w_init, self.in_size + self.out_size, allow_none=False))
        if b_init is not None:
            params['bias'] = init.param(b_init, self.out_size, allow_none=False)
        # gain
        if ws_gain:
            s = params['weight'].shape
            params['gain'] = jnp.ones((1,) * (len(s) - 1) + (s[-1],), dtype=params['weight'].dtype)

        # weight operation
        self.weight = ParamState(params)

    def update(self, x):
        return self._operation(x, self.weight.value)

    def _operation(self, x, params):
        w = params['weight']
        w = functional.weight_standardization(w, self.eps, params.get('gain', None))
        if self.w_mask is not None:
            w = w * self.w_mask
        y = jnp.dot(x, w)
        if 'bias' in params:
            y = y + params['bias']
        return y
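
`functional.weight_standardization(w, eps, gain)` is the call this layer relies on; conceptually it standardizes each output unit's fan-in, roughly as in this sketch (the usual formulation, not necessarily brainstate's exact implementation):

    def weight_standardization_sketch(w, eps, gain=None):
        axes = tuple(range(w.ndim - 1))        # all axes but the output axis
        mean = w.mean(axis=axes, keepdims=True)
        var = w.var(axis=axes, keepdims=True)
        w = (w - mean) / jnp.sqrt(var + eps)   # zero mean, unit variance per output unit
        return w if gain is None else w * gain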

class CSRLinear(Module):
    __module__ = 'brainstate.nn'


class _BaseConv(Module):
    # the number of spatial dimensions
    num_spatial_dims: int

    # the weight and its operations
    weight: ParamState

    def __init__(
        self,
        in_size: Sequence[int],
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
        lhs_dilation: Union[int, Tuple[int, ...]] = 1,
        rhs_dilation: Union[int, Tuple[int, ...]] = 1,
        groups: int = 1,
        w_mask: Optional[Union[ArrayLike, Callable]] = None,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)

        # general parameters
        assert self.num_spatial_dims + 1 == len(in_size)
        self.in_size = tuple(in_size)
        self.in_channels = in_size[-1]
        self.out_channels = out_channels
        self.stride = replicate(stride, self.num_spatial_dims, 'stride')
        self.kernel_size = replicate(kernel_size, self.num_spatial_dims, 'kernel_size')
        self.lhs_dilation = replicate(lhs_dilation, self.num_spatial_dims, 'lhs_dilation')
        self.rhs_dilation = replicate(rhs_dilation, self.num_spatial_dims, 'rhs_dilation')
        self.groups = groups
        self.dimension_numbers = to_dimension_numbers(self.num_spatial_dims, channels_last=True, transpose=False)

        # the padding parameter
        if isinstance(padding, str):
            assert padding in ['SAME', 'VALID']
        elif isinstance(padding, int):
            padding = tuple((padding, padding) for _ in range(self.num_spatial_dims))
        elif isinstance(padding, (tuple, list)):
            if isinstance(padding[0], int):
                padding = (padding,) * self.num_spatial_dims
            elif isinstance(padding[0], (tuple, list)):
                if len(padding) == 1:
                    padding = tuple(padding) * self.num_spatial_dims
                else:
                    if len(padding) != self.num_spatial_dims:
                        raise ValueError(
                            f"Padding {padding} must be a Tuple[int, int], "
                            f"or sequence of Tuple[int, int] with length 1, "
                            f"or sequence of Tuple[int, int] with length {self.num_spatial_dims}."
                        )
                    padding = tuple(padding)
        else:
            raise ValueError(f'Unsupported padding: {padding!r}.')
        self.padding = padding

        # the number of in-/out-channels
        assert self.out_channels % self.groups == 0, '"out_channels" should be divisible by groups'
        assert self.in_channels % self.groups == 0, '"in_channels" should be divisible by groups'

        # kernel shape and w_mask
        kernel_shape = tuple(self.kernel_size) + (self.in_channels // self.groups, self.out_channels)
        self.kernel_shape = kernel_shape
        self.w_mask = init.param(w_mask, kernel_shape, allow_none=True)

    def _check_input_dim(self, x):
        if x.ndim == self.num_spatial_dims + 2:
            x_shape = x.shape[1:]
        elif x.ndim == self.num_spatial_dims + 1:
            x_shape = x.shape
        else:
            raise ValueError(f"expected {self.num_spatial_dims + 2}D (with batch) or "
                             f"{self.num_spatial_dims + 1}D (without batch) input (got {x.ndim}D input, {x.shape})")
        if self.in_size != x_shape:
            raise ValueError(f"The expected input shape is {self.in_size}, while we got {x_shape}.")

    def update(self, x):
        self._check_input_dim(x)
        non_batching = False
        if x.ndim == self.num_spatial_dims + 1:
            x = jnp.expand_dims(x, 0)
            non_batching = True
        y = self._conv_op(x, self.weight.value)
        return y[0] if non_batching else y

    def _conv_op(self, x, params):
        raise NotImplementedError

    def __repr__(self):
        return (f'{self.__class__.__name__}('
                f'in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size}, '
                f'stride={self.stride}, '
                f'padding={self.padding}, '
                f'groups={self.groups})')
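
The padding branch above normalizes every accepted spelling to per-dimension `(low, high)` pairs; for a 2D convolution:

    # padding='SAME'            -> kept as the string 'SAME'
    # padding=1                 -> ((1, 1), (1, 1))
    # padding=(1, 2)            -> ((1, 2), (1, 2))   # one (low, high) pair, replicated
    # padding=[(1, 2), (0, 0)]  -> ((1, 2), (0, 0))   # one pair per spatial dimension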

class _Conv(_BaseConv):
    num_spatial_dims: Optional[int] = None

    def __init__(
        self,
        in_size: Sequence[int],
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
        lhs_dilation: Union[int, Tuple[int, ...]] = 1,
        rhs_dilation: Union[int, Tuple[int, ...]] = 1,
        groups: int = 1,
        w_init: Union[Callable, ArrayLike] = init.XavierNormal(),
        b_init: Optional[Union[Callable, ArrayLike]] = None,
        w_mask: Optional[Union[ArrayLike, Callable]] = None,
        name: Optional[str] = None,
    ):
        super().__init__(in_size=in_size,
                         out_channels=out_channels,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         lhs_dilation=lhs_dilation,
                         rhs_dilation=rhs_dilation,
                         groups=groups,
                         w_mask=w_mask,
                         name=name)

        self.w_initializer = w_init
        self.b_initializer = b_init

        # --- weights --- #
        weight = init.param(self.w_initializer, self.kernel_shape, allow_none=False)
        params = dict(weight=weight)
        if self.b_initializer is not None:
            bias_shape = (1,) * len(self.kernel_size) + (self.out_channels,)
            bias = init.param(self.b_initializer, bias_shape, allow_none=True)
            params['bias'] = bias

        # The weight operation
        self.weight = ParamState(params)

        # Evaluate the output shape
        abstract_y = jax.eval_shape(
            self._conv_op,
            jax.ShapeDtypeStruct((128,) + self.in_size, weight.dtype),
            params
        )
        y_shape = abstract_y.shape[1:]
        self.out_size = y_shape

    def _conv_op(self, x, params):
        w = params['weight']
        if self.w_mask is not None:
            w = w * self.w_mask
        y = jax.lax.conv_general_dilated(
            lhs=x,
            rhs=w,
            window_strides=self.stride,
            padding=self.padding,
            lhs_dilation=self.lhs_dilation,
            rhs_dilation=self.rhs_dilation,
            feature_group_count=self.groups,
            dimension_numbers=self.dimension_numbers
        )
        if 'bias' in params:
            y = y + params['bias']
        return y
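
Note that `jax.eval_shape` traces `_conv_op` abstractly, so the dummy batch of 128 allocates nothing; it exists only to infer `out_size` at construction time. A standalone sketch of the same trick:

    import jax
    import jax.numpy as jnp

    def f(x):
        return jnp.dot(x, jnp.ones((4, 8)))

    out = jax.eval_shape(f, jax.ShapeDtypeStruct((128, 4), jnp.float32))
    print(out.shape)  # (128, 8), computed without running f on real arrays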

class Conv1d(_Conv):
    """One-dimensional convolution.

    The input should be a 3d array with the shape of ``[B, H, C]``.

    Parameters
    ----------
    %s
    """
    __module__ = 'brainstate.nn'
    num_spatial_dims: int = 1


class Conv2d(_Conv):
    """Two-dimensional convolution.

    The input should be a 4d array with the shape of ``[B, H, W, C]``.

    Parameters
    ----------
    %s
    """
    __module__ = 'brainstate.nn'
    num_spatial_dims: int = 2


class Conv3d(_Conv):
    """Three-dimensional convolution.

    The input should be a 5d array with the shape of ``[B, H, W, D, C]``.

    Parameters
    ----------
    %s
    """
    __module__ = 'brainstate.nn'
    num_spatial_dims: int = 3


_conv_doc = '''
in_size: tuple of int
    The input shape, without the batch size. This argument is important, since it is
    used to evaluate the shape of the output.
out_channels: int
    The number of output channels.
kernel_size: int, sequence of int
    The shape of the convolutional kernel.
    For 1D convolution, the kernel size can be passed as an integer.
    For all other cases, it must be a sequence of integers.
stride: int, sequence of int
    An integer or a sequence of `n` integers, representing the inter-window strides (default: 1).
padding: str, int, sequence of int, sequence of tuple
    Either the string `'SAME'`, the string `'VALID'`, or a sequence of `n` `(low,
    high)` integer pairs that give the padding to apply before and after each
    spatial dimension.
lhs_dilation: int, sequence of int
    An integer or a sequence of `n` integers, giving the
    dilation factor to apply in each spatial dimension of `inputs`
    (default: 1). Convolution with input dilation `d` is equivalent to
    transposed convolution with stride `d`.
rhs_dilation: int, sequence of int
    An integer or a sequence of `n` integers, giving the
    dilation factor to apply in each spatial dimension of the convolution
    kernel (default: 1). Convolution with kernel dilation
    is also known as 'atrous convolution'.
groups: int
    If specified, divides the input features into groups. Default is 1.
w_init: Callable, ArrayLike, Initializer
    The initializer for the convolutional kernel.
b_init: Optional, Callable, ArrayLike, Initializer
    The initializer for the bias.
w_mask: ArrayLike, Callable, Optional
    The optional mask of the weights.
name: str, Optional
    The name of the object.
'''

Conv1d.__doc__ = Conv1d.__doc__ % _conv_doc
Conv2d.__doc__ = Conv2d.__doc__ % _conv_doc
Conv3d.__doc__ = Conv3d.__doc__ % _conv_doc
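
A minimal `Conv2d` sketch (again assuming the `brainstate.nn` export):

    import jax.numpy as jnp
    import brainstate

    conv = brainstate.nn.Conv2d(in_size=(28, 28, 3), out_channels=16, kernel_size=3)
    print(conv.out_size)            # (28, 28, 16): inferred at construction ('SAME' padding, stride 1)
    x = jnp.zeros((10, 28, 28, 3))  # batched NHWC input; an unbatched (28, 28, 3) array also works
    y = conv.update(x)              # y.shape == (10, 28, 28, 16)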

class _ScaledWSConv(_BaseConv):
    def __init__(
        self,
        in_size: Sequence[int],
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
        lhs_dilation: Union[int, Tuple[int, ...]] = 1,
        rhs_dilation: Union[int, Tuple[int, ...]] = 1,
        groups: int = 1,
        ws_gain: bool = True,
        eps: float = 1e-4,
        w_init: Union[Callable, ArrayLike] = init.XavierNormal(),
        b_init: Optional[Union[Callable, ArrayLike]] = None,
        w_mask: Optional[Union[ArrayLike, Callable]] = None,
        name: Optional[str] = None,
    ):
        super().__init__(in_size=in_size,
                         out_channels=out_channels,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         lhs_dilation=lhs_dilation,
                         rhs_dilation=rhs_dilation,
                         groups=groups,
                         w_mask=w_mask,
                         name=name)

        self.w_initializer = w_init
        self.b_initializer = b_init

        # --- weights --- #
        weight = init.param(self.w_initializer, self.kernel_shape, allow_none=False)
        params = dict(weight=weight)
        if self.b_initializer is not None:
            bias_shape = (1,) * len(self.kernel_size) + (self.out_channels,)
            bias = init.param(self.b_initializer, bias_shape, allow_none=True)
            params['bias'] = bias

        # gain
        if ws_gain:
            gain_size = (1,) * len(self.kernel_size) + (1, self.out_channels)
            params['gain'] = jnp.ones(gain_size, dtype=params['weight'].dtype)

        # Epsilon, a small constant to avoid dividing by zero.
        self.eps = eps

        # The weight operation
        self.weight = ParamState(params)

        # Evaluate the output shape
        abstract_y = jax.eval_shape(
            self._conv_op,
            jax.ShapeDtypeStruct((128,) + self.in_size, weight.dtype),
            params
        )
        y_shape = abstract_y.shape[1:]
        self.out_size = y_shape

    def _conv_op(self, x, params):
        w = params['weight']
        w = functional.weight_standardization(w, self.eps, params.get('gain', None))
        if self.w_mask is not None:
            w = w * self.w_mask
        y = jax.lax.conv_general_dilated(
            lhs=x,
            rhs=w,
            window_strides=self.stride,
            padding=self.padding,
            lhs_dilation=self.lhs_dilation,
            rhs_dilation=self.rhs_dilation,
            feature_group_count=self.groups,
            dimension_numbers=self.dimension_numbers
        )
        if 'bias' in params:
            y = y + params['bias']
        return y


class ScaledWSConv1d(_ScaledWSConv):
    """One-dimensional convolution with weight standardization.

    The input should be a 3d array with the shape of ``[B, H, C]``.

    Parameters
    ----------
    %s
    """
    __module__ = 'brainstate.nn'
    num_spatial_dims: int = 1


class ScaledWSConv2d(_ScaledWSConv):
    """Two-dimensional convolution with weight standardization.

    The input should be a 4d array with the shape of ``[B, H, W, C]``.

    Parameters
    ----------
    %s
    """
    __module__ = 'brainstate.nn'
    num_spatial_dims: int = 2


class ScaledWSConv3d(_ScaledWSConv):
    """Three-dimensional convolution with weight standardization.

    The input should be a 5d array with the shape of ``[B, H, W, D, C]``.

    Parameters
    ----------
    %s
    """
    __module__ = 'brainstate.nn'
    num_spatial_dims: int = 3


_ws_conv_doc = '''
in_size: tuple of int
    The input shape, without the batch size. This argument is important, since it is
    used to evaluate the shape of the output.
out_channels: int
    The number of output channels.
kernel_size: int, sequence of int
    The shape of the convolutional kernel.
    For 1D convolution, the kernel size can be passed as an integer.
    For all other cases, it must be a sequence of integers.
stride: int, sequence of int
    An integer or a sequence of `n` integers, representing the inter-window strides (default: 1).
padding: str, int, sequence of int, sequence of tuple
    Either the string `'SAME'`, the string `'VALID'`, or a sequence of `n` `(low,
    high)` integer pairs that give the padding to apply before and after each
    spatial dimension.
lhs_dilation: int, sequence of int
    An integer or a sequence of `n` integers, giving the
    dilation factor to apply in each spatial dimension of `inputs`
    (default: 1). Convolution with input dilation `d` is equivalent to
    transposed convolution with stride `d`.
rhs_dilation: int, sequence of int
    An integer or a sequence of `n` integers, giving the
    dilation factor to apply in each spatial dimension of the convolution
    kernel (default: 1). Convolution with kernel dilation
    is also known as 'atrous convolution'.
groups: int
    If specified, divides the input features into groups. Default is 1.
w_init: Callable, ArrayLike, Initializer
    The initializer for the convolutional kernel.
b_init: Optional, Callable, ArrayLike, Initializer
    The initializer for the bias.
ws_gain: bool
    Whether to add a gain term for the weight standardization. The default is `True`.
eps: float
    The epsilon value for numerical stability.
w_mask: ArrayLike, Callable, Optional
    The optional mask of the weights.
name: str, Optional
    The name of the object.
'''

ScaledWSConv1d.__doc__ = ScaledWSConv1d.__doc__ % _ws_conv_doc
ScaledWSConv2d.__doc__ = ScaledWSConv2d.__doc__ % _ws_conv_doc
ScaledWSConv3d.__doc__ = ScaledWSConv3d.__doc__ % _ws_conv_doc

class AllToAll(Module):
    """Synaptic matrix multiplication with all-to-all connections.

    Args:
        in_size: int. The number of neurons in the presynaptic neuron group.
        out_size: int. The number of neurons in the postsynaptic neuron group.
        w_init: The synaptic weights.
        include_self: bool. Whether to connect each neuron to the one at the same position.
        name: str. The object name.
    """

    def __init__(
        self,
        in_size: Union[int, Sequence[int]],
        out_size: Union[int, Sequence[int]],
        w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
        include_self: bool = True,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)

        # input and output shape
        self.in_size = (in_size,) if isinstance(in_size, numbers.Integral) else tuple(in_size)
        self.out_size = (out_size,) if isinstance(out_size, numbers.Integral) else tuple(out_size)
        assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
                                                         'and "out_size" must be the same.')

        # weights
        self.weight = ParamState(init.param(w_init, (self.in_size[-1], self.out_size[-1]), allow_none=False))

        # others
        self.include_self = include_self

    def update(self, pre_val):
        if u.math.ndim(self.weight.value) == 0:  # weight is a scalar
            if pre_val.ndim == 1:
                post_val = u.math.sum(pre_val)
            else:
                post_val = u.math.sum(pre_val, keepdims=True, axis=-1)
            if not self.include_self:
                if self.in_size == self.out_size:
                    post_val = post_val - pre_val
                elif self.in_size[-1] > self.out_size[-1]:
                    val = pre_val[..., :self.out_size[-1]]
                    post_val = post_val - val
                else:
                    size = list(self.out_size)
                    size[-1] = self.out_size[-1] - self.in_size[-1]
                    val = u.math.concatenate([pre_val, u.math.zeros(size, dtype=pre_val.dtype)])
                    post_val = post_val - val
            post_val = self.weight.value * post_val

        else:  # weight is a matrix
            assert u.math.ndim(self.weight.value) == 2, '"weight" must be a 2D matrix.'
            if not self.include_self:
                post_val = pre_val @ u.math.fill_diagonal(self.weight.value, 0.)
            else:
                post_val = pre_val @ self.weight.value
        return post_val
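
A sketch of `AllToAll` with self-connections excluded, exercising the 2D-weight branch above:

    import jax.numpy as jnp
    import brainstate

    conn = brainstate.nn.AllToAll(5, 5, include_self=False)  # (5, 5) KaimingNormal weight matrix
    pre = jnp.arange(5.0)
    post = conn.update(pre)  # pre @ weight with the diagonal zeroed, so no unit drives itself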