brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184)
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
brainstate/nn/_interaction/_normalizations.py
@@ -0,0 +1,388 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ # -*- coding: utf-8 -*-
+
+ from __future__ import annotations
+
+ import numbers
+ from typing import Callable, Union, Sequence, Optional, Any
+
+ import jax
+ import jax.numpy as jnp
+
+ from brainstate import environ, init
+ from brainstate._state import LongTermState, ParamState
+ from brainstate.nn._module import Module
+ from brainstate.typing import DTypeLike, ArrayLike, Size, Axes
+
+ __all__ = [
+     'BatchNorm0d', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d',
+ ]
+
+
+ def _canonicalize_axes(ndim: int, feature_axes: Sequence[int]):
+     axes = []
+     for axis in feature_axes:
+         if axis < 0:
+             axis += ndim
+         if axis < 0 or axis >= ndim:
+             raise ValueError(f'Invalid axis {axis} for {ndim}D input')
+         axes.append(axis)
+     return tuple(axes)
+
+
+ def _abs_sq(x):
+     """Computes the elementwise square of the absolute value |x|^2."""
+     if jnp.iscomplexobj(x):
+         return jax.lax.square(jax.lax.real(x)) + jax.lax.square(jax.lax.imag(x))
+     else:
+         return jax.lax.square(x)
+
+
+ def _compute_stats(
+     x: ArrayLike,
+     axes: Sequence[int],
+     dtype: DTypeLike,
+     axis_name: Optional[str] = None,
+     axis_index_groups: Optional[Sequence[Sequence[int]]] = None,
+     use_mean: bool = True,
+ ):
+     """Computes mean and variance statistics.
+
+     This implementation takes care of a few important details:
+     - Computes in float32 precision for stability in half precision training.
+     - Mean and variance are computable in a single XLA fusion,
+       by using Var = E[|x|^2] - |E[x]|^2 instead of Var = E[|x - E[x]|^2].
+     - Clips negative variances to zero, which can happen due to
+       roundoff errors. This avoids downstream NaNs.
+     - Supports averaging across a parallel axis and subgroups of a parallel axis
+       with a single `lax.pmean` call to avoid latency.
+
+     Arguments:
+       x: Input array.
+       axes: The axes in ``x`` to compute mean and variance statistics for.
+       dtype: Optional dtype specifying the minimal precision. Statistics
+         are always at least float32 for stability (default: dtype of x).
+       axis_name: Optional name for the pmapped axis to compute mean over.
+       axis_index_groups: Optional groups of axis indices within the named
+         axis, representing subsets of devices to reduce over.
+       use_mean: If true, calculate the mean from the input and use it when
+         computing the variance. If false, set the mean to zero and compute
+         the variance without subtracting the mean.
+
+     Returns:
+       A pair ``(mean, var)``.
+     """
+     if dtype is None:
+         dtype = jax.numpy.result_type(x)
+     # Promote x to at least float32. This avoids half-precision computation
+     # but preserves double or complex floating points.
+     dtype = jax.numpy.promote_types(dtype, environ.dftype())
+     x = jnp.asarray(x, dtype)
+
+     # Compute mean and mean of squared values.
+     mean2 = jnp.mean(_abs_sq(x), axes)
+     if use_mean:
+         mean = jnp.mean(x, axes)
+     else:
+         mean = jnp.zeros(mean2.shape, dtype=dtype)
+
+     # If axis_name is provided, average mean and mean2 across all replicas
+     # on the named axis with a single pmean call.
+     if axis_name is not None:
+         concatenated_mean = jnp.concatenate([mean, mean2])
+         mean, mean2 = jnp.split(
+             jax.lax.pmean(
+                 concatenated_mean,
+                 axis_name=axis_name,
+                 axis_index_groups=axis_index_groups,
+             ),
+             2,
+         )
+
+     # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due
+     # to floating-point round-off errors.
+     var = jnp.maximum(0.0, mean2 - _abs_sq(mean))
+     return mean, var
+
+
+ def _normalize(
+     x: ArrayLike,
+     mean: Optional[ArrayLike],
+     var: Optional[ArrayLike],
+     weights: Optional[ParamState],
+     reduction_axes: Sequence[int],
+     dtype: DTypeLike,
+     epsilon: Union[numbers.Number, jax.Array],
+ ):
+     """Normalizes the input of a normalization layer and optionally applies a learned scale and bias.
+
+     Arguments:
+       x: The input.
+       mean: Mean to use for normalization.
+       var: Variance to use for normalization.
+       weights: The scale and bias parameters.
+       reduction_axes: The axes in ``x`` to reduce.
+       dtype: The dtype of the result (default: infer from input and params).
+       epsilon: Normalization epsilon.
+
+     Returns:
+       The normalized input.
+     """
+     if mean is not None:
+         assert var is not None, 'mean and var must be both None or not None.'
+         stats_shape = list(x.shape)
+         for axis in reduction_axes:
+             stats_shape[axis] = 1
+         mean = mean.reshape(stats_shape)
+         var = var.reshape(stats_shape)
+         y = x - mean
+         mul = jax.lax.rsqrt(var + jnp.asarray(epsilon, dtype))
+         y = y * mul
+         if weights is not None:
+             y = _scale_operation(y, weights.value)
+     else:
+         assert var is None, 'mean and var must be both None or not None.'
+         assert weights is None, 'scale and bias are not supported without mean and var.'
+         y = x
+     return jnp.asarray(y, dtype)
+
+
+ def _scale_operation(x, param):
+     if 'scale' in param:
+         x = x * param['scale']
+     if 'bias' in param:
+         x = x + param['bias']
+     return x
+
+
+ class _BatchNorm(Module):
+     __module__ = 'brainstate.nn'
+     num_spatial_dims: int
+
+     def __init__(
+         self,
+         in_size: Size,
+         feature_axis: Axes = -1,
+         track_running_stats: bool = True,
+         epsilon: float = 1e-5,
+         momentum: float = 0.99,
+         affine: bool = True,
+         bias_initializer: Union[ArrayLike, Callable] = init.Constant(0.),
+         scale_initializer: Union[ArrayLike, Callable] = init.Constant(1.),
+         axis_name: Optional[Union[str, Sequence[str]]] = None,
+         axis_index_groups: Optional[Sequence[Sequence[int]]] = None,
+         name: Optional[str] = None,
+         dtype: Any = None,
+     ):
+         super().__init__(name=name)
+
+         # parameters
+         self.in_size = tuple(in_size)
+         self.out_size = tuple(in_size)
+         self.affine = affine
+         self.bias_initializer = bias_initializer
+         self.scale_initializer = scale_initializer
+         self.dtype = dtype or environ.dftype()
+         self.track_running_stats = track_running_stats
+         self.momentum = jnp.asarray(momentum, dtype=self.dtype)
+         self.epsilon = jnp.asarray(epsilon, dtype=self.dtype)
+
+         # parameters about axis
+         feature_axis = (feature_axis,) if isinstance(feature_axis, int) else feature_axis
+         self.feature_axis = _canonicalize_axes(len(in_size), feature_axis)
+         self.axis_name = axis_name
+         self.axis_index_groups = axis_index_groups
+
+         # variables
+         feature_shape = tuple([ax if i in self.feature_axis else 1 for i, ax in enumerate(in_size)])
+         if self.track_running_stats:
+             self.running_mean = LongTermState(jnp.zeros(feature_shape, dtype=self.dtype))
+             self.running_var = LongTermState(jnp.ones(feature_shape, dtype=self.dtype))
+         else:
+             self.running_mean = None
+             self.running_var = None
+
+         # affine parameters
+         if self.affine:
+             assert track_running_stats, "Affine parameters require track_running_stats=True."
+             bias = init.param(self.bias_initializer, feature_shape)
+             scale = init.param(self.scale_initializer, feature_shape)
+             self.weight = ParamState(dict(bias=bias, scale=scale))
+         else:
+             self.weight = None
+
+     def update(self, x):
+         # Check the input shape and whether a batch dimension is present.
+         if x.ndim == self.num_spatial_dims + 2:
+             x_shape = x.shape[1:]
+             batch = True
+         elif x.ndim == self.num_spatial_dims + 1:
+             x_shape = x.shape
+             batch = False
+         else:
+             raise ValueError(f"expected {self.num_spatial_dims + 2}D (with batch) or "
+                              f"{self.num_spatial_dims + 1}D (without batch) input (got {x.ndim}D input, {x.shape})")
+         if self.in_size != x_shape:
+             raise ValueError(f"The expected input shape is {self.in_size}, while we got {x_shape}.")
+
+         # The axes to reduce over: all axes except the feature axes.
+         if batch:
+             reduction_axes = tuple(i for i in range(x.ndim) if (i - 1) not in self.feature_axis)
+         else:
+             reduction_axes = tuple(i for i in range(x.ndim) if i not in self.feature_axis)
+
+         # fitting phase
+         fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
+
+         # compute the running mean and variance
+         if self.track_running_stats:
+             if fit_phase:
+                 mean, var = _compute_stats(
+                     x,
+                     reduction_axes,
+                     dtype=self.dtype,
+                     axis_name=self.axis_name,
+                     axis_index_groups=self.axis_index_groups,
+                 )
+                 self.running_mean.value = self.momentum * self.running_mean.value + (1 - self.momentum) * mean
+                 self.running_var.value = self.momentum * self.running_var.value + (1 - self.momentum) * var
+             else:
+                 mean = self.running_mean.value
+                 var = self.running_var.value
+         else:
+             mean, var = None, None
+
+         # normalize
+         return _normalize(x, mean, var, self.weight, reduction_axes, self.dtype, self.epsilon)
+
+
+ class BatchNorm0d(_BatchNorm):
+     r"""0-D batch normalization [1]_.
+
+     The data should be of `(b, c)`, where `b` is the batch dimension
+     and `c` is the channel dimension.
+
+     %s
+     """
+     __module__ = 'brainstate.nn'
+     num_spatial_dims: int = 0
+
+
+ class BatchNorm1d(_BatchNorm):
+     r"""1-D batch normalization [1]_.
+
+     The data should be of `(b, l, c)`, where `b` is the batch dimension,
+     `l` is the layer dimension, and `c` is the channel dimension.
+
+     %s
+     """
+     __module__ = 'brainstate.nn'
+     num_spatial_dims: int = 1
+
+
+ class BatchNorm2d(_BatchNorm):
+     r"""2-D batch normalization [1]_.
+
+     The data should be of `(b, h, w, c)`, where `b` is the batch dimension,
+     `h` is the height dimension, `w` is the width dimension, and `c` is the
+     channel dimension.
+
+     %s
+     """
+     __module__ = 'brainstate.nn'
+     num_spatial_dims: int = 2
+
+
+ class BatchNorm3d(_BatchNorm):
+     r"""3-D batch normalization [1]_.
+
+     The data should be of `(b, h, w, d, c)`, where `b` is the batch dimension,
+     `h` is the height dimension, `w` is the width dimension, `d` is the depth
+     dimension, and `c` is the channel dimension.
+
+     %s
+     """
+     __module__ = 'brainstate.nn'
+     num_spatial_dims: int = 3
+
+
+ _bn_doc = r'''
+
+     This layer aims to reduce the internal covariate shift of data. It
+     normalizes a batch of data by fixing the mean and variance of inputs
+     on each feature (channel). Most commonly, the first axis of the data
+     is the batch, and the last is the channel. However, users can specify
+     the axes to be normalized.
+
+     .. math::
+        y=\frac{x-\mathrm{E}[x]}{\sqrt{\operatorname{Var}[x]+\epsilon}} * \gamma+\beta
+
+     .. note::
+         The :attr:`momentum` argument is different from the one used in optimizer
+         classes and from the conventional notion of momentum. Mathematically, the
+         update rule for running statistics here is
+         :math:`\hat{x}_\text{new} = \text{momentum} \times \hat{x} + (1-\text{momentum}) \times x_t`,
+         where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
+         new observed value.
+
+     Parameters
+     ----------
+     in_size: sequence of int
+         The input shape, without the batch size.
+     feature_axis: int, tuple, list
+         The feature or non-batch axis of the input.
+     track_running_stats: bool
+         A boolean value that when set to ``True``, this module tracks the running mean and variance,
+         and when set to ``False``, this module does not track such statistics and initializes
+         the statistics buffers ``running_mean`` and ``running_var`` as ``None``. When these buffers
+         are ``None``, this module always uses batch statistics in both training and eval modes.
+         Default: ``True``.
+     momentum: float
+         The value used for the ``running_mean`` and ``running_var`` computation. Default: 0.99
+     epsilon: float
+         A value added to the denominator for numerical stability. Default: 1e-5
+     affine: bool
+         A boolean value that when set to ``True``, this module has
+         learnable affine parameters. Default: ``True``
+     bias_initializer: ArrayLike, Callable
+         An initializer generating the original translation matrix. If not ``None``, bias (beta) is added.
+         Default: ``init.Constant(0.)``
+     scale_initializer: ArrayLike, Callable
+         An initializer generating the original scaling matrix. If not ``None``, multiply by scale (gamma).
+         Default: ``init.Constant(1.)``
+     axis_name: optional, str, sequence of str
+         If not ``None``, it should be a string (or sequence of
+         strings) representing the axis name(s) over which this module is being
+         run within a jax map (e.g. ``jax.pmap`` or ``jax.vmap``). Supplying this
+         argument means that batch statistics are calculated across all replicas
+         on the named axes.
+     axis_index_groups: optional, sequence
+         Specifies how devices are grouped. Valid
+         only within ``jax.pmap`` collectives.
+         Groups of axis indices within that named axis
+         representing subsets of devices to reduce over (default: None). For
+         example, `[[0, 1], [2, 3]]` would independently batch-normalize over
+         the examples on the first two and last two devices. See `jax.lax.psum`
+         for more details.
+
+     References
+     ----------
+     .. [1] Ioffe, Sergey and Christian Szegedy. "Batch Normalization: Accelerating Deep Network
+            Training by Reducing Internal Covariate Shift." ArXiv abs/1502.03167 (2015).
+
+ '''
+
+ BatchNorm0d.__doc__ = BatchNorm0d.__doc__ % _bn_doc
+ BatchNorm1d.__doc__ = BatchNorm1d.__doc__ % _bn_doc
+ BatchNorm2d.__doc__ = BatchNorm2d.__doc__ % _bn_doc
+ BatchNorm3d.__doc__ = BatchNorm3d.__doc__ % _bn_doc
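
Two numerical details of the file above are worth a quick illustration: _compute_stats evaluates the variance in a single fusion via Var[x] = E[|x|^2] - |E[x]|^2 and clips it at zero, and the running statistics weight the *old* estimate by ``momentum``. Below is a minimal sketch assuming only ``jax.numpy``; the helper ``single_pass_var`` is a local illustration, not a brainstate API.

import jax.numpy as jnp

def single_pass_var(x, axes):
    # Var = E[|x|^2] - |E[x]|^2, clipped at zero to absorb round-off,
    # mirroring _compute_stats above.
    mean = jnp.mean(x, axes)
    mean2 = jnp.mean(jnp.square(jnp.abs(x)), axes)
    return mean, jnp.maximum(0.0, mean2 - jnp.square(jnp.abs(mean)))

x = jnp.arange(12.0).reshape(3, 4)
mean, var = single_pass_var(x, axes=(0,))
assert jnp.allclose(var, jnp.var(x, axis=0))  # matches the two-pass definition

# Running-statistics update: momentum=0.99 keeps 99% of the old estimate,
# unlike the momentum of SGD-style optimizers.
running_mean = 0.99 * jnp.zeros(4) + (1 - 0.99) * mean

The one-pass form lets XLA compute both statistics from a single reduction over the data, which is why the docstring advertises a single fusion.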
brainstate/nn/_interaction/_normalizations_test.py
@@ -0,0 +1,75 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from __future__ import annotations
+
+ from absl.testing import absltest
+ from absl.testing import parameterized
+
+ import brainstate as bst
+
+
+ class Test_Normalization(parameterized.TestCase):
+     @parameterized.product(
+         fit=[True, False],
+     )
+     def test_BatchNorm1d(self, fit):
+         net = bst.nn.BatchNorm1d((3, 10))
+         bst.environ.set(fit=fit)
+         input = bst.random.randn(1, 3, 10)
+         output = net(input)
+
+     @parameterized.product(
+         fit=[True, False]
+     )
+     def test_BatchNorm2d(self, fit):
+         net = bst.nn.BatchNorm2d([3, 4, 10])
+         bst.environ.set(fit=fit)
+         input = bst.random.randn(1, 3, 4, 10)
+         output = net(input)
+
+     @parameterized.product(
+         fit=[True, False]
+     )
+     def test_BatchNorm3d(self, fit):
+         net = bst.nn.BatchNorm3d([3, 4, 5, 10])
+         bst.environ.set(fit=fit)
+         input = bst.random.randn(1, 3, 4, 5, 10)
+         output = net(input)
+
+     # @parameterized.product(
+     #     normalized_shape=(10, [5, 10])
+     # )
+     # def test_LayerNorm(self, normalized_shape):
+     #     net = bst.nn.LayerNorm(normalized_shape)
+     #     input = bst.random.randn(20, 5, 10)
+     #     output = net(input)
+     #
+     # @parameterized.product(
+     #     num_groups=[1, 2, 3, 6]
+     # )
+     # def test_GroupNorm(self, num_groups):
+     #     input = bst.random.randn(20, 10, 10, 6)
+     #     net = bst.nn.GroupNorm(num_groups=num_groups, num_channels=6)
+     #     output = net(input)
+     #
+     # def test_InstanceNorm(self):
+     #     input = bst.random.randn(20, 10, 10, 6)
+     #     net = bst.nn.InstanceNorm(num_channels=6)
+     #     output = net(input)
+
+
+ if __name__ == '__main__':
+     absltest.main()
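
For reference, a usage sketch mirroring the tests above. It assumes only the public names that appear in this diff (``brainstate.nn.BatchNorm2d``, ``brainstate.environ.set``, ``brainstate.random.randn``) and is illustrative rather than normative.

import brainstate as bst

net = bst.nn.BatchNorm2d([3, 4, 10])   # in_size excludes the batch axis
x = bst.random.randn(1, 3, 4, 10)      # (batch, h, w, channels)

bst.environ.set(fit=True)              # training: batch stats update the running stats
y_train = net(x)

bst.environ.set(fit=False)             # evaluation: the tracked running stats are used
y_eval = net(x)

Note that the training/evaluation switch is a global ``environ`` flag rather than a per-module attribute, which is why each test sets ``fit`` before calling the network.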