brainstate 0.2.0__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. brainstate/__init__.py +169 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2319 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +1652 -1652
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1624 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1433 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +137 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +633 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +154 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +477 -477
  32. brainstate/nn/_dynamics.py +1267 -1267
  33. brainstate/nn/_dynamics_test.py +67 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +384 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/_rand_funs.py +3938 -3938
  64. brainstate/random/_rand_funs_test.py +640 -640
  65. brainstate/random/_rand_seed.py +675 -675
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1617
  68. brainstate/random/_rand_state_test.py +551 -551
  69. brainstate/transform/__init__.py +59 -59
  70. brainstate/transform/_ad_checkpoint.py +176 -176
  71. brainstate/transform/_ad_checkpoint_test.py +49 -49
  72. brainstate/transform/_autograd.py +1025 -1025
  73. brainstate/transform/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -316
  75. brainstate/transform/_conditions_test.py +220 -220
  76. brainstate/transform/_error_if.py +94 -94
  77. brainstate/transform/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -145
  79. brainstate/transform/_eval_shape_test.py +38 -38
  80. brainstate/transform/_jit.py +399 -399
  81. brainstate/transform/_jit_test.py +143 -143
  82. brainstate/transform/_loop_collect_return.py +675 -675
  83. brainstate/transform/_loop_collect_return_test.py +58 -58
  84. brainstate/transform/_loop_no_collection.py +283 -283
  85. brainstate/transform/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -2016
  87. brainstate/transform/_make_jaxpr_test.py +1510 -1510
  88. brainstate/transform/_mapping.py +529 -529
  89. brainstate/transform/_mapping_test.py +194 -194
  90. brainstate/transform/_progress_bar.py +255 -255
  91. brainstate/transform/_random.py +171 -171
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -108
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate-0.2.0.dist-info/RECORD +0 -111
  111. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  112. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -1,1334 +1,1334 @@
1
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
- from typing import Callable, Union, Sequence, Optional, Any
19
-
20
- import brainunit as u
21
- import jax
22
- import jax.numpy as jnp
23
-
24
- from brainstate import environ
25
- from brainstate._state import ParamState, BatchState
26
- from brainstate.typing import DTypeLike, ArrayLike, Size, Axes
27
- from . import init as init
28
- from ._module import Module
29
-
30
- __all__ = [
31
- 'weight_standardization',
32
- 'BatchNorm0d',
33
- 'BatchNorm1d',
34
- 'BatchNorm2d',
35
- 'BatchNorm3d',
36
- 'LayerNorm',
37
- 'RMSNorm',
38
- 'GroupNorm',
39
- ]
40
-
41
-
42
- def weight_standardization(
43
- w: ArrayLike,
44
- eps: float = 1e-4,
45
- gain: Optional[jax.Array] = None,
46
- out_axis: int = -1,
47
- ) -> Union[jax.Array, u.Quantity]:
48
- """
49
- Scaled Weight Standardization.
50
-
51
- Applies weight standardization to improve training stability, as described in
52
- "Micro-Batch Training with Batch-Channel Normalization and Weight Standardization" [1]_.
53
-
54
- Parameters
55
- ----------
56
- w : ArrayLike
57
- The weight tensor to be standardized.
58
- eps : float, optional
59
- A small value added to variance to avoid division by zero. Default is 1e-4.
60
- gain : jax.Array, optional
61
- Optional gain parameter to scale the standardized weights. Default is None.
62
- out_axis : int, optional
63
- The output axis of the weight tensor. Default is -1.
64
-
65
- Returns
66
- -------
67
- jax.Array or u.Quantity
68
- The standardized weight tensor with the same shape as input.
69
-
70
- References
71
- ----------
72
- .. [1] Qiao, S., Wang, H., Liu, C., Shen, W., & Yuille, A. (2019).
73
- Micro-Batch Training with Batch-Channel Normalization and Weight Standardization.
74
- arXiv preprint arXiv:1903.10520.
75
-
76
- Examples
77
- --------
78
- .. code-block:: python
79
-
80
- >>> import brainstate as bst
81
- >>> import jax.numpy as jnp
82
- >>>
83
- >>> # Standardize a weight matrix
84
- >>> w = jnp.ones((3, 4))
85
- >>> w_std = bst.nn.weight_standardization(w)
86
- >>>
87
- >>> # With custom gain
88
- >>> gain = jnp.ones((4,))
89
- >>> w_std = bst.nn.weight_standardization(w, gain=gain)
90
- """
91
- w = u.maybe_custom_array(w)
92
- if out_axis < 0:
93
- out_axis = w.ndim + out_axis
94
- fan_in = 1 # get the fan-in of the weight tensor
95
- axes = [] # get the axes of the weight tensor
96
- for i in range(w.ndim):
97
- if i != out_axis:
98
- fan_in *= w.shape[i]
99
- axes.append(i)
100
- # normalize the weight
101
- mean = u.math.mean(w, axis=axes, keepdims=True)
102
- var = u.math.var(w, axis=axes, keepdims=True)
103
-
104
- temp = u.math.maximum(var * fan_in, eps)
105
- if isinstance(temp, u.Quantity):
106
- unit = temp.unit
107
- temp = temp.mantissa
108
- if unit.is_unitless:
109
- scale = jax.lax.rsqrt(temp)
110
- else:
111
- scale = u.Quantity(jax.lax.rsqrt(temp), unit=1 / unit ** 0.5)
112
- else:
113
- scale = jax.lax.rsqrt(temp)
114
- if gain is not None:
115
- scale = gain * scale
116
- shift = mean * scale
117
- return w * scale - shift
118
-
119
-
120
- def canonicalize_dtype(
121
- *args,
122
- dtype: jax.typing.DTypeLike | None = None,
123
- inexact: bool = True
124
- ) -> jax.typing.DTypeLike:
125
- """
126
- Canonicalize an optional dtype to the definitive dtype.
127
-
128
- If the ``dtype`` is None, this function will infer the dtype from the input
129
- arguments using ``jnp.result_type``. If it is not None, it will be returned
130
- unmodified or an exception is raised if the dtype is invalid.
131
-
132
- Parameters
133
- ----------
134
- *args : ArrayLike
135
- JAX array compatible values. None values are ignored.
136
- dtype : jax.typing.DTypeLike, optional
137
- Optional dtype override. If specified, the arguments are cast to the
138
- specified dtype and dtype inference is disabled. Default is None.
139
- inexact : bool, optional
140
- When True, the output dtype must be a subtype of ``jnp.inexact``.
141
- Inexact dtypes are real or complex floating points. This is useful
142
- when applying operations that don't work directly on integers like
143
- taking a mean. Default is True.
144
-
145
- Returns
146
- -------
147
- jax.typing.DTypeLike
148
- The dtype that ``*args`` should be cast to.
149
-
150
- Raises
151
- ------
152
- ValueError
153
- If ``inexact=True`` and the resulting dtype is not an inexact type.
154
-
155
- Examples
156
- --------
157
- .. code-block:: python
158
-
159
- >>> import jax.numpy as jnp
160
- >>>
161
- >>> # Infer dtype from arguments
162
- >>> x = jnp.array([1, 2, 3])
163
- >>> dtype = canonicalize_dtype(x)
164
- >>>
165
- >>> # Specify explicit dtype
166
- >>> dtype = canonicalize_dtype(x, dtype=jnp.float64)
167
- """
168
- if dtype is None:
169
- args_filtered = [jnp.asarray(x) for x in args if x is not None]
170
- dtype = jnp.result_type(*args_filtered)
171
- if inexact and not jnp.issubdtype(dtype, jnp.inexact):
172
- dtype = jnp.promote_types(jnp.float32, dtype)
173
- if inexact and not jnp.issubdtype(dtype, jnp.inexact):
174
- raise ValueError(f'Dtype must be inexact: {dtype}')
175
- return dtype
176
-
177
-
178
- def _canonicalize_axes(ndim: int, feature_axes: Sequence[int]):
179
- axes = []
180
- for axis in feature_axes:
181
- if axis < 0:
182
- axis += ndim
183
- if axis < 0 or axis >= ndim:
184
- raise ValueError(f'Invalid axis {axis} for {ndim}D input')
185
- axes.append(axis)
186
- return tuple(axes)
187
-
188
-
189
- def _abs_sq(x):
190
- """Computes the elementwise square of the absolute value |x|^2."""
191
- if jnp.iscomplexobj(x):
192
- return jax.lax.square(jax.lax.real(x)) + jax.lax.square(jax.lax.imag(x))
193
- else:
194
- return jax.lax.square(x)
195
-
196
-
197
- class NormalizationParamState(ParamState):
198
- # This is a dummy class to be used as a compatibility
199
- # usage of `ETraceParam` for the layers in "brainetrace"
200
- def execute(self, x):
201
- param = self.value
202
- if 'scale' in param:
203
- x = x * param['scale']
204
- if 'bias' in param:
205
- x = x + param['bias']
206
- return x
207
-
208
-
209
- def _compute_stats(
210
- x: ArrayLike,
211
- axes: Sequence[int],
212
- dtype: DTypeLike,
213
- axis_name: Optional[str] = None,
214
- axis_index_groups: Optional[Sequence[int]] = None,
215
- use_mean: bool = True,
216
- use_fast_variance: bool = True,
217
- mask: Optional[jax.Array] = None,
218
- ):
219
- """
220
- Compute mean and variance statistics for normalization.
221
-
222
- This implementation includes several optimizations:
223
-
224
- - Computes in float32 precision for stability in half precision training.
225
- - If ``use_fast_variance`` is True, uses the formula Var = E[|x|^2] - |E[x]|^2
226
- instead of Var = E[|x - E[x]|^2] in a single XLA fusion.
227
- - Clips negative variances to zero to avoid downstream NaNs from roundoff errors.
228
- - Supports averaging across parallel axes and subgroups with a single
229
- ``lax.pmean`` call to reduce latency.
230
-
231
- Parameters
232
- ----------
233
- x : ArrayLike
234
- Input array.
235
- axes : Sequence[int]
236
- The axes in ``x`` to compute mean and variance statistics for.
237
- dtype : DTypeLike
238
- Optional dtype specifying the minimal precision. Statistics are always
239
- at least float32 for stability. If None, uses the dtype of x.
240
- axis_name : str, optional
241
- Optional name for the pmapped axis to compute mean over. Only used for
242
- pmap and shard map. For SPMD jit, axes should be correctly annotated
243
- and XLA:SPMD will insert necessary collectives. Default is None.
244
- axis_index_groups : Sequence[int], optional
245
- Optional axis indices for grouped reductions. Default is None.
246
- use_mean : bool, optional
247
- If True, calculate the mean from the input and use it when computing
248
- the variance. If False, set the mean to zero and compute the variance
249
- without subtracting the mean. Default is True.
250
- use_fast_variance : bool, optional
251
- If True, use a faster but less numerically stable calculation for the
252
- variance. Default is True.
253
- mask : jax.Array, optional
254
- Binary array of shape broadcastable to ``x``, indicating the positions
255
- for which the mean and variance should be computed. Default is None.
256
-
257
- Returns
258
- -------
259
- tuple of jax.Array
260
- A pair ``(mean, var)`` containing the computed mean and variance.
261
- """
262
- if dtype is None:
263
- dtype = jax.numpy.result_type(x)
264
- # promote x to at least float32, this avoids half precision computation
265
- # but preserves double or complex floating points
266
- dtype = jax.numpy.promote_types(dtype, jnp.float32)
267
- x = jnp.asarray(x, dtype)
268
- axes = _canonicalize_axes(x.ndim, axes)
269
-
270
- def maybe_distributed_mean(*xs, mask=None):
271
- mus = tuple(x.mean(axes, where=mask) for x in xs)
272
- if axis_name is None:
273
- return mus if len(xs) > 1 else mus[0]
274
- else:
275
- # In the distributed case we stack multiple arrays to speed comms.
276
- if len(xs) > 1:
277
- reduced_mus = jax.lax.pmean(
278
- jnp.stack(mus, axis=0),
279
- axis_name,
280
- axis_index_groups=axis_index_groups,
281
- )
282
- return tuple(reduced_mus[i] for i in range(len(xs)))
283
- else:
284
- return jax.lax.pmean(
285
- mus[0],
286
- axis_name,
287
- axis_index_groups=axis_index_groups
288
- )
289
-
290
- if use_mean:
291
- if use_fast_variance:
292
- mu, mu2 = maybe_distributed_mean(x, _abs_sq(x), mask=mask)
293
- # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due
294
- # to floating point round-off errors.
295
- var = jnp.maximum(0.0, mu2 - _abs_sq(mu))
296
- else:
297
- mu = maybe_distributed_mean(x, mask=mask)
298
- var = maybe_distributed_mean(_abs_sq(x - jnp.expand_dims(mu, axes)), mask=mask)
299
- else:
300
- var = maybe_distributed_mean(_abs_sq(x), mask=mask)
301
- mu = jnp.zeros_like(var)
302
- return mu, var
303
-
304
-
305
- def _normalize(
306
- x: ArrayLike,
307
- mean: Optional[ArrayLike],
308
- var: Optional[ArrayLike],
309
- weights: Optional[NormalizationParamState],
310
- reduction_axes: Axes,
311
- feature_axes: Axes,
312
- dtype: DTypeLike,
313
- epsilon: jax.typing.ArrayLike,
314
- ):
315
- """
316
- Normalize the input and optionally apply learned scale and bias.
317
-
318
- Parameters
319
- ----------
320
- x : ArrayLike
321
- The input array.
322
- mean : ArrayLike, optional
323
- Mean to use for normalization. If None, normalization is skipped.
324
- var : ArrayLike, optional
325
- Variance to use for normalization. If None, normalization is skipped.
326
- weights : NormalizationParamState, optional
327
- The scale and bias parameters. If None, no affine transformation is applied.
328
- reduction_axes : Axes
329
- The axes in ``x`` to reduce.
330
- feature_axes : Axes
331
- The feature axes to apply the scale and bias.
332
- dtype : DTypeLike
333
- The dtype of the result. If None, inferred from input and parameters.
334
- epsilon : jax.typing.ArrayLike
335
- A small value added to variance to avoid division by zero.
336
-
337
- Returns
338
- -------
339
- jax.Array
340
- The normalized input array.
341
- """
342
- if mean is not None:
343
-         assert var is not None, 'mean and var must be both None or not None.'
344
- reduction_axes = _canonicalize_axes(x.ndim, reduction_axes)
345
- feature_axes = _canonicalize_axes(x.ndim, feature_axes)
346
- stats_shape = list(x.shape)
347
- for axis in reduction_axes:
348
- stats_shape[axis] = 1
349
- mean = mean.reshape(stats_shape)
350
- var = var.reshape(stats_shape)
351
- feature_shape = [1] * x.ndim
352
- for ax in feature_axes:
353
- feature_shape[ax] = x.shape[ax]
354
- y = x - mean
355
- mul = jax.lax.rsqrt(var + epsilon)
356
- y = y * mul
357
- if weights is not None:
358
- y = weights.execute(y)
359
- dtype = canonicalize_dtype(x, *jax.tree.leaves(weights.value), dtype=dtype)
360
- else:
361
-         assert var is None, 'mean and var must be both None or not None.'
362
-         assert weights is None, 'scale and bias are not supported without mean and var'
363
- y = x
364
- return jnp.asarray(y, dtype)
365
-
366
-
367
- class _BatchNorm(Module):
368
- __module__ = 'brainstate.nn'
369
- num_spatial_dims: int
370
-
371
- def __init__(
372
- self,
373
- in_size: Size,
374
- feature_axis: Axes = -1,
375
- *,
376
- track_running_stats: bool = True,
377
- epsilon: float = 1e-5,
378
- momentum: float = 0.99,
379
- affine: bool = True,
380
- bias_initializer: Union[ArrayLike, Callable] = init.Constant(0.),
381
- scale_initializer: Union[ArrayLike, Callable] = init.Constant(1.),
382
- axis_name: Optional[Union[str, Sequence[str]]] = None,
383
- axis_index_groups: Optional[Sequence[Sequence[int]]] = None,
384
- use_fast_variance: bool = True,
385
- name: Optional[str] = None,
386
- dtype: Any = None,
387
- param_type: type = NormalizationParamState,
388
- mean_type: type = BatchState,
389
- ):
390
- super().__init__(name=name)
391
-
392
- # parameters
393
- self.in_size = in_size
394
- self.out_size = in_size
395
- self.affine = affine
396
- self.bias_initializer = bias_initializer
397
- self.scale_initializer = scale_initializer
398
- self.dtype = dtype or environ.dftype()
399
- self.track_running_stats = track_running_stats
400
- self.momentum = jnp.asarray(momentum, dtype=self.dtype)
401
- self.epsilon = jnp.asarray(epsilon, dtype=self.dtype)
402
- self.use_fast_variance = use_fast_variance
403
-
404
- # parameters about axis
405
- feature_axis = (feature_axis,) if isinstance(feature_axis, int) else feature_axis
406
- self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axis)
407
- self.axis_name = axis_name
408
- self.axis_index_groups = axis_index_groups
409
-
410
- # variables
411
- feature_shape = tuple([(ax if i in self.feature_axes else 1)
412
- for i, ax in enumerate(self.in_size)])
413
- if self.track_running_stats:
414
- self.running_mean = mean_type(jnp.zeros(feature_shape, dtype=self.dtype))
415
- self.running_var = mean_type(jnp.ones(feature_shape, dtype=self.dtype))
416
- else:
417
- self.running_mean = None
418
- self.running_var = None
419
-
420
- # parameters
421
- if self.affine:
422
- assert track_running_stats, "Affine parameters are not needed when track_running_stats is False."
423
- bias = init.param(self.bias_initializer, feature_shape)
424
- scale = init.param(self.scale_initializer, feature_shape)
425
- self.weight = param_type(dict(bias=bias, scale=scale))
426
- else:
427
- self.weight = None
428
-
429
- def update(self, x, mask: Optional[jax.Array] = None):
430
- # input shape and batch mode or not
431
- if x.ndim == self.num_spatial_dims + 2:
432
- x_shape = x.shape[1:]
433
- batch = True
434
- elif x.ndim == self.num_spatial_dims + 1:
435
- x_shape = x.shape
436
- batch = False
437
- else:
438
- raise ValueError(f"expected {self.num_spatial_dims + 2}D (with batch) or "
439
- f"{self.num_spatial_dims + 1}D (without batch) input (got {x.ndim}D input, {x.shape})")
440
- if self.in_size != x_shape:
441
- raise ValueError(f"The expected input shape is {self.in_size}, while we got {x_shape}.")
442
-
443
- # reduce the feature axis
444
- if batch:
445
- reduction_axes = tuple(i for i in range(x.ndim) if (i - 1) not in self.feature_axes)
446
- else:
447
- reduction_axes = tuple(i for i in range(x.ndim) if i not in self.feature_axes)
448
-
449
- # fitting phase
450
- fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
451
-
452
- # compute the running mean and variance
453
- if self.track_running_stats:
454
- if fit_phase:
455
- mean, var = _compute_stats(
456
- x,
457
- reduction_axes,
458
- dtype=self.dtype,
459
- axis_name=self.axis_name,
460
- axis_index_groups=self.axis_index_groups,
461
- use_fast_variance=self.use_fast_variance,
462
- mask=mask,
463
- )
464
- self.running_mean.value = self.momentum * self.running_mean.value + (1 - self.momentum) * mean
465
- self.running_var.value = self.momentum * self.running_var.value + (1 - self.momentum) * var
466
- else:
467
- mean = self.running_mean.value
468
- var = self.running_var.value
469
- else:
470
- mean, var = None, None
471
-
472
- # normalize
473
- return _normalize(
474
- x,
475
- mean=mean,
476
- var=var,
477
- weights=self.weight,
478
- reduction_axes=reduction_axes,
479
- feature_axes=self.feature_axes,
480
- dtype=self.dtype,
481
- epsilon=self.epsilon
482
- )
483
-
484
-
485
- class BatchNorm0d(_BatchNorm):
486
- """
487
- 0-D batch normalization.
488
-
489
- Normalizes a batch of 0-D data (vectors) by fixing the mean and variance
490
- of inputs on each feature (channel). This layer aims to reduce the internal
491
- covariate shift of data.
492
-
493
- The input data should have shape ``(b, c)``, where ``b`` is the batch dimension
494
- and ``c`` is the channel dimension.
495
-
496
- The normalization is performed as:
497
-
498
- .. math::
499
- y = \\frac{x - \\mathrm{E}[x]}{\\sqrt{\\operatorname{Var}[x] + \\epsilon}} \\cdot \\gamma + \\beta
500
-
501
- where :math:`\\gamma` and :math:`\\beta` are learnable affine parameters (if ``affine=True``).
502
-
503
- Parameters
504
- ----------
505
- in_size : tuple of int
506
- The input shape, without batch dimension.
507
- feature_axis : int or tuple of int, optional
508
- The feature or non-batch axis of the input. Default is -1.
509
- track_running_stats : bool, optional
510
- If True, tracks the running mean and variance. If False, uses batch
511
- statistics in both training and eval modes. Default is True.
512
- epsilon : float, optional
513
- A value added to the denominator for numerical stability. Default is 1e-5.
514
- momentum : float, optional
515
- The momentum value used for the ``running_mean`` and ``running_var``
516
- computation. The update rule is:
517
- :math:`\\hat{x}_{\\text{new}} = \\text{momentum} \\times \\hat{x} + (1 - \\text{momentum}) \\times x_t`.
518
- Default is 0.99.
519
- affine : bool, optional
520
- If True, this module has learnable affine parameters (scale and bias).
521
- Default is True.
522
- bias_initializer : ArrayLike or Callable, optional
523
- Initializer for the bias (beta) parameter. Default is ``init.Constant(0.)``.
524
- scale_initializer : ArrayLike or Callable, optional
525
- Initializer for the scale (gamma) parameter. Default is ``init.Constant(1.)``.
526
- axis_name : str or sequence of str, optional
527
- The axis name(s) for parallel reduction using ``jax.pmap`` or ``jax.vmap``.
528
- If specified, batch statistics are calculated across all replicas on the
529
- named axes. Default is None.
530
- axis_index_groups : sequence of sequence of int, optional
531
- Groups of axis indices within the named axis representing subsets of
532
- devices to reduce over. For example, ``[[0, 1], [2, 3]]`` would
533
- independently batch-normalize over the first two and last two devices.
534
- See ``jax.lax.psum`` for more details. Default is None.
535
- use_fast_variance : bool, optional
536
- If True, use a faster but less numerically stable calculation for
537
- the variance. Default is True.
538
-
539
- Notes
540
- -----
541
- The ``momentum`` parameter is different from the conventional notion of
542
- momentum used in optimizers.
543
-
544
- References
545
- ----------
546
- .. [1] Ioffe, S., & Szegedy, C. (2015). Batch Normalization: Accelerating
547
- Deep Network Training by Reducing Internal Covariate Shift.
548
- In International Conference on Machine Learning (pp. 448-456).
549
-
550
- Examples
551
- --------
552
- .. code-block:: python
553
-
554
- >>> import brainstate as bst
555
- >>> import jax.numpy as jnp
556
- >>>
557
- >>> # Create a BatchNorm0d layer
558
- >>> layer = bst.nn.BatchNorm0d(in_size=(10,))
559
- >>>
560
- >>> # Apply normalization to a batch of data
561
- >>> x = jnp.ones((32, 10)) # batch_size=32, features=10
562
- >>> y = layer(x)
563
- >>>
564
- >>> # Check output shape
565
- >>> print(y.shape)
566
- (32, 10)
567
- """
568
- __module__ = 'brainstate.nn'
569
- num_spatial_dims: int = 0
570
-
571
-
572
- class BatchNorm1d(_BatchNorm):
573
- """
574
- 1-D batch normalization.
575
-
576
- Normalizes a batch of 1-D data by fixing the mean and variance of inputs
577
- on each feature (channel). This layer aims to reduce the internal covariate
578
- shift of data.
579
-
580
- The input data should have shape ``(b, l, c)``, where ``b`` is the batch
581
- dimension, ``l`` is the spatial/sequence dimension, and ``c`` is the channel
582
- dimension.
583
-
584
- Parameters
585
- ----------
586
- in_size : tuple of int
587
- The input shape, without batch dimension. For 1-D data, typically ``(l, c)``.
588
- feature_axis : int or tuple of int, optional
589
- The feature or non-batch axis of the input. Default is -1.
590
- track_running_stats : bool, optional
591
- If True, tracks the running mean and variance. If False, uses batch
592
- statistics in both training and eval modes. Default is True.
593
- epsilon : float, optional
594
- A value added to the denominator for numerical stability. Default is 1e-5.
595
- momentum : float, optional
596
- The momentum value for running statistics computation. Default is 0.99.
597
- affine : bool, optional
598
- If True, has learnable affine parameters (scale and bias). Default is True.
599
- bias_initializer : ArrayLike or Callable, optional
600
- Initializer for the bias parameter. Default is ``init.Constant(0.)``.
601
- scale_initializer : ArrayLike or Callable, optional
602
- Initializer for the scale parameter. Default is ``init.Constant(1.)``.
603
- axis_name : str or sequence of str, optional
604
- Axis name(s) for parallel reduction. Default is None.
605
- axis_index_groups : sequence of sequence of int, optional
606
- Groups of axis indices for device-grouped reduction. Default is None.
607
- use_fast_variance : bool, optional
608
- If True, use faster but less stable variance calculation. Default is True.
609
-
610
- References
611
- ----------
612
- .. [1] Ioffe, S., & Szegedy, C. (2015). Batch Normalization: Accelerating
613
- Deep Network Training by Reducing Internal Covariate Shift.
614
- In International Conference on Machine Learning (pp. 448-456).
615
-
616
- See Also
617
- --------
618
- BatchNorm0d : 0-D batch normalization
619
- BatchNorm2d : 2-D batch normalization
620
- BatchNorm3d : 3-D batch normalization
621
-
622
- Examples
623
- --------
624
- .. code-block:: python
625
-
626
- >>> import brainstate as bst
627
- >>> import jax.numpy as jnp
628
- >>>
629
- >>> # Create a BatchNorm1d layer for sequence data
630
- >>> layer = bst.nn.BatchNorm1d(in_size=(100, 64)) # length=100, channels=64
631
- >>>
632
- >>> # Apply normalization
633
- >>> x = jnp.ones((8, 100, 64)) # batch_size=8
634
- >>> y = layer(x)
635
- >>> print(y.shape)
636
- (8, 100, 64)
637
- """
638
- __module__ = 'brainstate.nn'
639
- num_spatial_dims: int = 1
640
-
641
-
642
- class BatchNorm2d(_BatchNorm):
643
- """
644
- 2-D batch normalization.
645
-
646
- Normalizes a batch of 2-D data (e.g., images) by fixing the mean and variance
647
- of inputs on each feature (channel). This layer aims to reduce the internal
648
- covariate shift of data.
649
-
650
- The input data should have shape ``(b, h, w, c)``, where ``b`` is the batch
651
- dimension, ``h`` is the height dimension, ``w`` is the width dimension, and
652
- ``c`` is the channel dimension.
653
-
654
- Parameters
655
- ----------
656
- in_size : tuple of int
657
- The input shape, without batch dimension. For 2-D data, typically ``(h, w, c)``.
658
- feature_axis : int or tuple of int, optional
659
- The feature or non-batch axis of the input. Default is -1.
660
- track_running_stats : bool, optional
661
- If True, tracks the running mean and variance. If False, uses batch
662
- statistics in both training and eval modes. Default is True.
663
- epsilon : float, optional
664
- A value added to the denominator for numerical stability. Default is 1e-5.
665
- momentum : float, optional
666
- The momentum value for running statistics computation. Default is 0.99.
667
- affine : bool, optional
668
- If True, has learnable affine parameters (scale and bias). Default is True.
669
- bias_initializer : ArrayLike or Callable, optional
670
- Initializer for the bias parameter. Default is ``init.Constant(0.)``.
671
- scale_initializer : ArrayLike or Callable, optional
672
- Initializer for the scale parameter. Default is ``init.Constant(1.)``.
673
- axis_name : str or sequence of str, optional
674
- Axis name(s) for parallel reduction. Default is None.
675
- axis_index_groups : sequence of sequence of int, optional
676
- Groups of axis indices for device-grouped reduction. Default is None.
677
- use_fast_variance : bool, optional
678
- If True, use faster but less stable variance calculation. Default is True.
679
-
680
- References
681
- ----------
682
- .. [1] Ioffe, S., & Szegedy, C. (2015). Batch Normalization: Accelerating
683
- Deep Network Training by Reducing Internal Covariate Shift.
684
- In International Conference on Machine Learning (pp. 448-456).
685
-
686
- See Also
687
- --------
688
- BatchNorm0d : 0-D batch normalization
689
- BatchNorm1d : 1-D batch normalization
690
- BatchNorm3d : 3-D batch normalization
691
-
692
- Examples
693
- --------
694
- .. code-block:: python
695
-
696
- >>> import brainstate as bst
697
- >>> import jax.numpy as jnp
698
- >>>
699
- >>> # Create a BatchNorm2d layer for image data
700
- >>> layer = bst.nn.BatchNorm2d(in_size=(28, 28, 3)) # 28x28 RGB images
701
- >>>
702
- >>> # Apply normalization
703
- >>> x = jnp.ones((16, 28, 28, 3)) # batch_size=16
704
- >>> y = layer(x)
705
- >>> print(y.shape)
706
- (16, 28, 28, 3)
707
- """
708
- __module__ = 'brainstate.nn'
709
- num_spatial_dims: int = 2
710
-
711
-
712
- class BatchNorm3d(_BatchNorm):
713
- """
714
- 3-D batch normalization.
715
-
716
- Normalizes a batch of 3-D data (e.g., video or volumetric data) by fixing
717
- the mean and variance of inputs on each feature (channel). This layer aims
718
- to reduce the internal covariate shift of data.
719
-
720
- The input data should have shape ``(b, h, w, d, c)``, where ``b`` is the
721
- batch dimension, ``h`` is the height dimension, ``w`` is the width dimension,
722
- ``d`` is the depth dimension, and ``c`` is the channel dimension.
723
-
724
- Parameters
725
- ----------
726
- in_size : tuple of int
727
- The input shape, without batch dimension. For 3-D data, typically ``(h, w, d, c)``.
728
- feature_axis : int or tuple of int, optional
729
- The feature or non-batch axis of the input. Default is -1.
730
- track_running_stats : bool, optional
731
- If True, tracks the running mean and variance. If False, uses batch
732
- statistics in both training and eval modes. Default is True.
733
- epsilon : float, optional
734
- A value added to the denominator for numerical stability. Default is 1e-5.
735
- momentum : float, optional
736
- The momentum value for running statistics computation. Default is 0.99.
737
- affine : bool, optional
738
- If True, has learnable affine parameters (scale and bias). Default is True.
739
- bias_initializer : ArrayLike or Callable, optional
740
- Initializer for the bias parameter. Default is ``init.Constant(0.)``.
741
- scale_initializer : ArrayLike or Callable, optional
742
- Initializer for the scale parameter. Default is ``init.Constant(1.)``.
743
- axis_name : str or sequence of str, optional
744
- Axis name(s) for parallel reduction. Default is None.
745
- axis_index_groups : sequence of sequence of int, optional
746
- Groups of axis indices for device-grouped reduction. Default is None.
747
- use_fast_variance : bool, optional
748
- If True, use faster but less stable variance calculation. Default is True.
749
-
750
- References
751
- ----------
752
- .. [1] Ioffe, S., & Szegedy, C. (2015). Batch Normalization: Accelerating
753
- Deep Network Training by Reducing Internal Covariate Shift.
754
- In International Conference on Machine Learning (pp. 448-456).
755
-
756
- See Also
757
- --------
758
- BatchNorm0d : 0-D batch normalization
759
- BatchNorm1d : 1-D batch normalization
760
- BatchNorm2d : 2-D batch normalization
761
-
762
- Examples
763
- --------
764
- .. code-block:: python
765
-
766
- >>> import brainstate as bst
767
- >>> import jax.numpy as jnp
768
- >>>
769
- >>> # Create a BatchNorm3d layer for volumetric data
770
- >>> layer = bst.nn.BatchNorm3d(in_size=(32, 32, 32, 1)) # 32x32x32 volumes
771
- >>>
772
- >>> # Apply normalization
773
- >>> x = jnp.ones((4, 32, 32, 32, 1)) # batch_size=4
774
- >>> y = layer(x)
775
- >>> print(y.shape)
776
- (4, 32, 32, 32, 1)
777
- """
778
- __module__ = 'brainstate.nn'
779
- num_spatial_dims: int = 3
780
-
781
-
782
- class LayerNorm(Module):
783
- """
784
- Layer normalization layer [1]_.
785
-
786
- LayerNorm normalizes the activations of the layer for each given example in
787
- a batch independently, rather than across a batch like Batch Normalization.
788
- It applies a transformation that maintains the mean activation within each
789
- example close to 0 and the activation standard deviation close to 1.
790
-
791
- Parameters
792
- ----------
793
- in_size : tuple of int
794
- The input shape, without batch dimension.
795
- reduction_axes : int or tuple of int, optional
796
- Axes for computing normalization statistics. It is recommended to use
797
- negative integers, as positive integers may cause issues when batch
798
- dimensions are present. Default is -1.
799
- feature_axes : int or tuple of int, optional
800
- Feature axes for learned bias and scaling. Default is -1.
801
- epsilon : float, optional
802
- A small value added to variance to avoid division by zero. Default is 1e-6.
803
- use_bias : bool, optional
804
- If True, bias (beta) is added. Default is True.
805
- use_scale : bool, optional
806
- If True, multiply by scale (gamma). When the next layer is linear
807
- (e.g., nn.relu), this can be disabled since scaling will be done by
808
- the next layer. Default is True.
809
- bias_init : Callable, optional
810
- Initializer for bias parameter. Default is ``init.ZeroInit()``.
811
- scale_init : Callable, optional
812
- Initializer for scale parameter. Default is ``init.Constant(1.0)``.
813
- axis_name : str, optional
814
- The axis name used to combine batch statistics from multiple devices.
815
- See ``jax.pmap`` for axis name description. Only needed if the model
816
- is subdivided across devices. Default is None.
817
- axis_index_groups : sequence, optional
818
- Groups of axis indices within the named axis representing subsets of
819
- devices to reduce over. For example, ``[[0, 1], [2, 3]]`` would
820
- independently normalize over the first two and last two devices.
821
- See ``jax.lax.psum`` for details. Default is None.
822
- use_fast_variance : bool, optional
823
- If True, use a faster but less numerically stable calculation for
824
- the variance. Default is True.
825
- dtype : jax.typing.DTypeLike, optional
826
- The dtype of the result. If None, inferred from input and parameters.
827
- Default is None.
828
-
829
- References
830
- ----------
831
- .. [1] Ba, J. L., Kiros, J. R., & Hinton, G. E. (2016). Layer normalization.
832
- arXiv preprint arXiv:1607.06450.
833
-
834
- See Also
835
- --------
836
- RMSNorm : Root Mean Square Layer Normalization
837
- GroupNorm : Group Normalization
838
- BatchNorm1d : 1-D Batch Normalization
839
-
840
- Examples
841
- --------
842
- .. code-block:: python
843
-
844
- >>> import brainstate as bst
845
- >>>
846
- >>> # Create a LayerNorm layer
847
- >>> x = bst.random.normal(size=(3, 4, 5, 6))
848
- >>> layer = bst.nn.LayerNorm(x.shape)
849
- >>>
850
- >>> # Apply normalization
851
- >>> y = layer(x)
852
- >>> print(y.shape)
853
- (3, 4, 5, 6)
854
- >>>
855
- >>> # Normalize only the last dimension
856
- >>> layer = bst.nn.LayerNorm((10, 20), reduction_axes=-1, feature_axes=-1)
857
- >>> x = bst.random.normal((5, 10, 20))
858
- >>> y = layer(x)
859
- """
860
-
861
- def __init__(
862
- self,
863
- in_size: Size,
864
- reduction_axes: Axes = -1,
865
- feature_axes: Axes = -1,
866
- *,
867
- epsilon: float = 1e-6,
868
- use_bias: bool = True,
869
- use_scale: bool = True,
870
- bias_init: Callable = init.ZeroInit(),
871
- scale_init: Callable = init.Constant(1.0),
872
- axis_name: Optional[str] = None,
873
- axis_index_groups: Any = None,
874
- use_fast_variance: bool = True,
875
- dtype: Optional[jax.typing.DTypeLike] = None,
876
- param_type: type = NormalizationParamState,
877
- ):
878
- super().__init__()
879
-
880
- self.in_size = in_size
881
- self.out_size = in_size
882
-
883
- # parameters about axis
884
- feature_axes = (feature_axes,) if isinstance(feature_axes, int) else feature_axes
885
- self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axes)
886
- self.reduction_axes = (reduction_axes,) if isinstance(reduction_axes, int) else reduction_axes
887
- self.axis_name = axis_name
888
- self.axis_index_groups = axis_index_groups
889
-
890
- # variables
891
- feature_shape = tuple([(ax if i in self.feature_axes else 1)
892
- for i, ax in enumerate(self.in_size)])
893
-
894
- weights = dict()
895
- if use_scale:
896
- weights['scale'] = init.param(scale_init, feature_shape)
897
- if use_bias:
898
- weights['bias'] = init.param(bias_init, feature_shape)
899
- if len(weights):
900
- self.weight = param_type(weights)
901
- else:
902
- self.weight = None
903
-
904
- # parameters
905
- self.epsilon = epsilon
906
- self.dtype = dtype or environ.dftype()
907
- self.use_bias = use_bias
908
- self.use_scale = use_scale
909
- self.bias_init = bias_init
910
- self.scale_init = scale_init
911
- self.use_fast_variance = use_fast_variance
912
-
913
- def update(self, x, *, mask: Optional[jax.Array] = None):
914
- """
915
- Apply layer normalization on the input.
916
-
917
- Parameters
918
- ----------
919
- x : jax.Array
920
- The input array.
921
- mask : jax.Array, optional
922
- Binary array of shape broadcastable to ``x``, indicating the
923
- positions for which normalization should be computed. Default is None.
924
-
925
- Returns
926
- -------
927
- jax.Array
928
- Normalized inputs with the same shape as the input.
929
- """
930
- mean, var = _compute_stats(
931
- x,
932
- self.reduction_axes,
933
- dtype=self.dtype,
934
- axis_name=self.axis_name,
935
- axis_index_groups=self.axis_index_groups,
936
- use_fast_variance=self.use_fast_variance,
937
- mask=mask,
938
- )
939
-
940
- return _normalize(
941
- x,
942
- mean=mean,
943
- var=var,
944
- weights=self.weight,
945
- reduction_axes=self.reduction_axes,
946
- feature_axes=self.feature_axes,
947
- dtype=self.dtype,
948
- epsilon=self.epsilon,
949
- )
950
-
951
-
952
- class RMSNorm(Module):
953
- """
954
- Root Mean Square Layer Normalization [1]_.
955
-
956
- RMSNorm normalizes the activations of the layer for each given example in a
957
- batch independently, rather than across a batch like Batch Normalization.
958
- Unlike LayerNorm which re-centers the mean to 0 and normalizes by the standard
959
- deviation, RMSNorm does not re-center at all and instead normalizes by the
960
- root mean square of the activations.
961
-
962
- Parameters
963
- ----------
964
- in_size : tuple of int
965
- The input shape, without batch dimension.
966
- epsilon : float, optional
967
- A small value added to variance to avoid division by zero. Default is 1e-6.
968
- dtype : jax.typing.DTypeLike, optional
969
- The dtype of the result. If None, inferred from input and parameters.
970
- Default is None.
971
- use_scale : bool, optional
972
- If True, multiply by scale (gamma). When the next layer is linear
973
- (e.g., nn.relu), this can be disabled since scaling will be done by
974
- the next layer. Default is True.
975
- scale_init : Callable, optional
976
- Initializer for scale parameter. Default is ``init.Constant(1.0)``.
977
- reduction_axes : int or tuple of int, optional
978
- Axes for computing normalization statistics. It is recommended to use
979
- negative integers. Default is -1.
980
- feature_axes : int or tuple of int, optional
981
- Feature axes for learned scaling. Default is -1.
982
- axis_name : str, optional
983
- The axis name used to combine batch statistics from multiple devices.
984
- See ``jax.pmap`` for details. Default is None.
985
- axis_index_groups : sequence, optional
986
- Groups of axis indices within the named axis representing subsets of
987
- devices to reduce over. For example, ``[[0, 1], [2, 3]]`` would
988
- independently normalize over the first two and last two devices.
989
- Default is None.
990
- use_fast_variance : bool, optional
991
- If True, use a faster but less numerically stable calculation for
992
- the variance. Default is True.
993
-
994
- References
995
- ----------
996
- .. [1] Zhang, B., & Sennrich, R. (2019). Root Mean Square Layer Normalization.
997
- Advances in Neural Information Processing Systems, 32.
998
-
999
- See Also
1000
- --------
1001
- LayerNorm : Layer Normalization
1002
- GroupNorm : Group Normalization
1003
-
1004
- Examples
1005
- --------
1006
- .. code-block:: python
1007
-
1008
- >>> import brainstate as bst
1009
- >>>
1010
- >>> # Create an RMSNorm layer
1011
- >>> x = bst.random.normal(size=(5, 6))
1012
- >>> layer = bst.nn.RMSNorm(in_size=(6,))
1013
- >>>
1014
- >>> # Apply normalization
1015
- >>> y = layer(x)
1016
- >>> print(y.shape)
1017
- (5, 6)
1018
- >>>
1019
- >>> # Without scaling
1020
- >>> layer = bst.nn.RMSNorm(in_size=(10,), use_scale=False)
1021
- >>> x = bst.random.normal((3, 10))
1022
- >>> y = layer(x)
1023
- """
1024
-
1025
- def __init__(
1026
- self,
1027
- in_size: Size,
1028
- *,
1029
- epsilon: float = 1e-6,
1030
- dtype: Optional[jax.typing.DTypeLike] = None,
1031
- use_scale: bool = True,
1032
- scale_init: Callable = init.Constant(1.0),
1033
- reduction_axes: Axes = -1,
1034
- feature_axes: Axes = -1,
1035
- axis_name: Optional[str] = None,
1036
- axis_index_groups: Any = None,
1037
- use_fast_variance: bool = True,
1038
- param_type: type = NormalizationParamState,
1039
- ):
1040
- super().__init__()
1041
-
1042
- self.in_size = in_size
1043
- self.out_size = in_size
1044
-
1045
- # parameters about axis
1046
- feature_axes = (feature_axes,) if isinstance(feature_axes, int) else feature_axes
1047
- self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axes)
1048
- self.reduction_axes = (reduction_axes,) if isinstance(reduction_axes, int) else reduction_axes
1049
- self.axis_name = axis_name
1050
- self.axis_index_groups = axis_index_groups
1051
-
1052
- # variables
1053
- feature_shape = tuple([(ax if i in self.feature_axes else 1)
1054
- for i, ax in enumerate(self.in_size)])
1055
- if use_scale:
1056
- self.scale = param_type({'scale': init.param(scale_init, feature_shape)})
1057
- else:
1058
- self.scale = None
1059
-
1060
- # parameters
1061
- self.epsilon = epsilon
1062
- self.dtype = dtype or environ.dftype()
1063
- self.use_scale = use_scale
1064
- self.scale_init = scale_init
1065
- self.use_fast_variance = use_fast_variance
1066
-
1067
- def update(self, x, *, mask: Optional[jax.Array] = None):
1068
- """
1069
- Apply RMS normalization on the input.
1070
-
1071
- Parameters
1072
- ----------
1073
- x : jax.Array
1074
- The input array.
1075
- mask : jax.Array, optional
1076
- Binary array of shape broadcastable to ``x``, indicating the
1077
- positions for which normalization should be computed. Default is None.
1078
-
1079
- Returns
1080
- -------
1081
- jax.Array
1082
- Normalized inputs with the same shape as the input.
1083
- """
1084
- mean, var = _compute_stats(
1085
- x,
1086
- self.reduction_axes,
1087
- dtype=self.dtype,
1088
- axis_name=self.axis_name,
1089
- axis_index_groups=self.axis_index_groups,
1090
- use_mean=False,
1091
- use_fast_variance=self.use_fast_variance,
1092
- mask=mask,
1093
- )
1094
-
1095
- return _normalize(
1096
- x,
1097
- mean=mean,
1098
- var=var,
1099
- weights=self.scale,
1100
- reduction_axes=self.reduction_axes,
1101
- feature_axes=self.feature_axes,
1102
- dtype=self.dtype,
1103
- epsilon=self.epsilon,
1104
- )
1105
-
1106
-
1107
- class GroupNorm(Module):
1108
- """
1109
- Group Normalization layer [1]_.
1110
-
1111
- Group normalization is similar to batch normalization, but statistics are
1112
- shared across equally-sized groups of channels and not shared across the
1113
- batch dimension. Thus, group normalization does not depend on the batch
1114
- composition and does not require maintaining internal state for storing statistics.
1115
-
1116
- The user should specify either the total number of channel groups (``num_groups``)
1117
- or the number of channels per group (``group_size``).
1118
-
1119
- Parameters
1120
- ----------
1121
- in_size : tuple of int
1122
- The input shape, without batch dimension.
1123
- feature_axis : int or tuple of int, optional
1124
- The feature axis of the input. Default is -1.
1125
- num_groups : int, optional
1126
- The total number of channel groups. The default value of 32 is proposed
1127
- by the original group normalization paper. Either ``num_groups`` or
1128
- ``group_size`` must be specified, but not both. Default is 32.
1129
- group_size : int, optional
1130
- The number of channels in each group. Either ``num_groups`` or
1131
- ``group_size`` must be specified, but not both. Default is None.
1132
- epsilon : float, optional
1133
- A small value added to variance to avoid division by zero. Default is 1e-6.
1134
- dtype : jax.typing.DTypeLike, optional
1135
- The dtype of the result. If None, inferred from input and parameters.
1136
- Default is None.
1137
- use_bias : bool, optional
1138
- If True, bias (beta) is added. Default is True.
1139
- use_scale : bool, optional
1140
- If True, multiply by scale (gamma). When the next layer is linear
1141
- (e.g., nn.relu), this can be disabled. Default is True.
1142
- bias_init : Callable, optional
1143
- Initializer for bias parameter. Default is ``init.ZeroInit()``.
1144
- scale_init : Callable, optional
1145
- Initializer for scale parameter. Default is ``init.Constant(1.)``.
1146
- reduction_axes : int or tuple of int, optional
1147
- List of axes used for computing normalization statistics. Must include
1148
- the final dimension (feature axis). It is recommended to use negative
1149
- integers. Default is None.
1150
- axis_name : str, optional
1151
- The axis name used to combine batch statistics from multiple devices.
1152
- See ``jax.pmap`` for details. Default is None.
1153
- axis_index_groups : sequence, optional
1154
- Groups of axis indices within the named axis representing subsets of
1155
- devices to reduce over. For example, ``[[0, 1], [2, 3]]`` would
1156
- independently normalize over the first two and last two devices.
1157
- Default is None.
1158
- use_fast_variance : bool, optional
1159
- If True, use a faster but less numerically stable calculation for
1160
- the variance. Default is True.
1161
-
1162
- Notes
1163
- -----
1164
- LayerNorm is a special case of GroupNorm where ``num_groups=1``.
1165
-
1166
- References
1167
- ----------
1168
- .. [1] Wu, Y., & He, K. (2018). Group Normalization.
1169
- In Proceedings of the European Conference on Computer Vision (ECCV)
1170
- (pp. 3-19).
1171
-
1172
- See Also
1173
- --------
1174
- LayerNorm : Layer Normalization
1175
- BatchNorm2d : 2-D Batch Normalization
1176
-
1177
- Examples
1178
- --------
1179
- .. code-block:: python
1180
-
1181
- >>> import numpy as np
1182
- >>> import brainstate as bst
1183
- >>>
1184
- >>> # Create a GroupNorm layer with 3 groups
1185
- >>> x = bst.random.normal(size=(3, 4, 5, 6))
1186
- >>> layer = bst.nn.GroupNorm(x.shape, num_groups=3)
1187
- >>> y = layer(x)
1188
- >>>
1189
- >>> # GroupNorm with num_groups=1 is equivalent to LayerNorm
1190
- >>> y1 = bst.nn.GroupNorm(x.shape, num_groups=1)(x)
1191
- >>> y2 = bst.nn.LayerNorm(x.shape, reduction_axes=(1, 2, 3))(x)
1192
- >>> np.testing.assert_allclose(y1, y2, rtol=1e-5)
1193
- >>>
1194
- >>> # Specify group_size instead of num_groups
1195
- >>> layer = bst.nn.GroupNorm((12,), num_groups=None, group_size=4)
1196
- """
1197
-
1198
- def __init__(
1199
- self,
1200
- in_size: Size,
1201
- feature_axis: Axes = -1,
1202
- num_groups: Optional[int] = 32,
1203
- group_size: Optional[int] = None,
1204
- *,
1205
- epsilon: float = 1e-6,
1206
- dtype: Optional[jax.typing.DTypeLike] = None,
1207
- use_bias: bool = True,
1208
- use_scale: bool = True,
1209
- bias_init: Callable = init.ZeroInit(),
1210
- scale_init: Callable = init.Constant(1.),
1211
- reduction_axes: Optional[Axes] = None,
1212
- axis_name: Optional[str] = None,
1213
- axis_index_groups: Any = None,
1214
- use_fast_variance: bool = True,
1215
- param_type: type = NormalizationParamState,
1216
- ):
1217
- super().__init__()
1218
-
1219
- self.in_size = in_size
1220
- self.out_size = in_size
1221
-
1222
- # parameters about axis
1223
- feature_axis = (feature_axis,) if isinstance(feature_axis, int) else feature_axis
1224
- self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axis)
1225
- self.reduction_axes = (reduction_axes,) if isinstance(reduction_axes, int) else reduction_axes
1226
- self.axis_name = axis_name
1227
- self.axis_index_groups = axis_index_groups
1228
-
1229
- if (num_groups is None and group_size is None) or (
1230
- num_groups is not None and group_size is not None
1231
- ):
1232
- raise ValueError(
1233
- 'Either `num_groups` or `group_size` should be '
1234
- 'specified. If `group_size` is to be specified, '
1235
- 'pass `num_groups=None` as argument to override '
1236
- 'the default `num_groups` value of 32.'
1237
- )
1238
-
1239
- feature_shape = tuple([(ax if i in self.feature_axes else 1)
1240
- for i, ax in enumerate(self.in_size)])
1241
- assert len(feature_shape) == 1, 'GroupNorm only supports 1D feature axis.'
1242
- num_features = feature_shape[0]
1243
- if group_size is not None:
1244
- if num_features % group_size != 0:
1245
- raise ValueError(
1246
- 'Number of features ({}) is not multiple of the '
1247
- 'group size ({}).'.format(num_features, group_size)
1248
- )
1249
- self.num_groups = num_features // group_size
1250
- self.group_size = group_size
1251
- else:
1252
- if not isinstance(num_groups, int) or num_groups <= 0 or (
1253
- num_features % num_groups != 0
1254
- ):
1255
- raise ValueError(
1256
- 'Number of groups ({}) does not divide the number'
1257
- ' of channels ({}).'.format(num_groups, num_features)
1258
- )
1259
- self.num_groups = num_groups
1260
- self.group_size = num_features // num_groups
1261
-
1262
- # variables
1263
- weights = dict()
1264
- if use_scale:
1265
- weights['scale'] = init.param(scale_init, feature_shape)
1266
- if use_bias:
1267
- weights['bias'] = init.param(bias_init, feature_shape)
1268
- if len(weights):
1269
- self.weight = param_type(weights)
1270
- else:
1271
- self.weight = None
1272
-
1273
- # parameters
1274
- self.epsilon = epsilon
1275
- self.dtype = dtype
1276
- self.use_bias = use_bias
1277
- self.use_scale = use_scale
1278
- self.bias_init = bias_init
1279
- self.scale_init = scale_init
1280
- self.use_fast_variance = use_fast_variance
1281
-
1282
- def update(self, x, *, mask: Optional[jax.Array] = None):
1283
- """
1284
- Apply group normalization to the input.
1285
-
1286
- Parameters
1287
- ----------
1288
- x : jax.Array
1289
- The input of shape ``...C`` where ``C`` is the channels dimension
1290
- and ``...`` represents an arbitrary number of extra dimensions. If no
1291
- reduction axes have been specified, all additional dimensions will be
1292
- used to accumulate statistics apart from the leading dimension which
1293
- is assumed to represent the batch.
1294
- mask : jax.Array, optional
1295
- Binary array of shape broadcastable to ``x``, indicating the
1296
- positions for which the mean and variance should be computed.
1297
- Default is None.
1298
-
1299
- Returns
1300
- -------
1301
- jax.Array
1302
- Normalized inputs with the same shape as the input.
1303
- """
1304
- if self.reduction_axes is not None:
1305
- reduction_axes = self.reduction_axes
1306
- else:
1307
- reduction_axes = list(range(1, x.ndim - 1)) + [-1]
1308
- reduction_axes = _canonicalize_axes(x.ndim, reduction_axes)
1309
-
1310
- group_shape = x.shape[:-1] + (self.num_groups, self.group_size)
1311
- if mask is not None:
1312
- mask = mask.reshape(mask.shape[:-1] + (self.num_groups, self.group_size))
1313
-
1314
- mean, var = _compute_stats(
1315
- x.reshape(group_shape),
1316
- list(reduction_axes[:-1]) + [-1],
1317
- dtype=self.dtype,
1318
- axis_name=self.axis_name,
1319
- axis_index_groups=self.axis_index_groups,
1320
- use_fast_variance=self.use_fast_variance,
1321
- mask=mask,
1322
- )
1323
- mean = jnp.repeat(mean, self.group_size, axis=1)
1324
- var = jnp.repeat(var, self.group_size, axis=1)
1325
- return _normalize(
1326
- x,
1327
- mean=mean,
1328
- var=var,
1329
- weights=self.weight,
1330
- reduction_axes=reduction_axes[:-1],
1331
- feature_axes=self.feature_axes,
1332
- dtype=self.dtype,
1333
- epsilon=self.epsilon,
1334
- )
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+ from typing import Callable, Union, Sequence, Optional, Any
19
+
20
+ import brainunit as u
21
+ import jax
22
+ import jax.numpy as jnp
23
+
24
+ from brainstate import environ
25
+ from brainstate._state import ParamState, BatchState
26
+ from brainstate.typing import DTypeLike, ArrayLike, Size, Axes
27
+ from . import init as init
28
+ from ._module import Module
29
+
30
+ __all__ = [
31
+ 'weight_standardization',
32
+ 'BatchNorm0d',
33
+ 'BatchNorm1d',
34
+ 'BatchNorm2d',
35
+ 'BatchNorm3d',
36
+ 'LayerNorm',
37
+ 'RMSNorm',
38
+ 'GroupNorm',
39
+ ]
40
+
41
+
42
+ def weight_standardization(
43
+ w: ArrayLike,
44
+ eps: float = 1e-4,
45
+ gain: Optional[jax.Array] = None,
46
+ out_axis: int = -1,
47
+ ) -> Union[jax.Array, u.Quantity]:
48
+ """
49
+ Scaled Weight Standardization.
50
+
51
+ Applies weight standardization to improve training stability, as described in
52
+ "Micro-Batch Training with Batch-Channel Normalization and Weight Standardization" [1]_.
53
+
54
+ Parameters
55
+ ----------
56
+ w : ArrayLike
57
+ The weight tensor to be standardized.
58
+ eps : float, optional
59
+ A small value added to variance to avoid division by zero. Default is 1e-4.
60
+ gain : jax.Array, optional
61
+ Optional gain parameter to scale the standardized weights. Default is None.
62
+ out_axis : int, optional
63
+ The output axis of the weight tensor. Default is -1.
64
+
65
+ Returns
66
+ -------
67
+ jax.Array or u.Quantity
68
+ The standardized weight tensor with the same shape as input.
69
+
70
+ References
71
+ ----------
72
+ .. [1] Qiao, S., Wang, H., Liu, C., Shen, W., & Yuille, A. (2019).
73
+ Micro-Batch Training with Batch-Channel Normalization and Weight Standardization.
74
+ arXiv preprint arXiv:1903.10520.
75
+
76
+ Examples
77
+ --------
78
+ .. code-block:: python
79
+
80
+ >>> import brainstate as bst
81
+ >>> import jax.numpy as jnp
82
+ >>>
83
+ >>> # Standardize a weight matrix
84
+ >>> w = jnp.ones((3, 4))
85
+ >>> w_std = bst.nn.weight_standardization(w)
86
+ >>>
87
+ >>> # With custom gain
88
+ >>> gain = jnp.ones((4,))
89
+ >>> w_std = bst.nn.weight_standardization(w, gain=gain)
90
+ """
91
+ w = u.maybe_custom_array(w)
92
+ if out_axis < 0:
93
+ out_axis = w.ndim + out_axis
94
+ fan_in = 1 # get the fan-in of the weight tensor
95
+ axes = [] # get the axes of the weight tensor
96
+ for i in range(w.ndim):
97
+ if i != out_axis:
98
+ fan_in *= w.shape[i]
99
+ axes.append(i)
100
+ # normalize the weight
101
+ mean = u.math.mean(w, axis=axes, keepdims=True)
102
+ var = u.math.var(w, axis=axes, keepdims=True)
103
+
104
+ temp = u.math.maximum(var * fan_in, eps)
105
+ if isinstance(temp, u.Quantity):
106
+ unit = temp.unit
107
+ temp = temp.mantissa
108
+ if unit.is_unitless:
109
+ scale = jax.lax.rsqrt(temp)
110
+ else:
111
+ scale = u.Quantity(jax.lax.rsqrt(temp), unit=1 / unit ** 0.5)
112
+ else:
113
+ scale = jax.lax.rsqrt(temp)
114
+ if gain is not None:
115
+ scale = gain * scale
116
+ shift = mean * scale
117
+ return w * scale - shift
118
+
119
+
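A quick sanity check of the scaling above, as a minimal sketch (not part of the packaged module): with the default ``out_axis=-1``, the standardized weights should have zero mean and variance ``1 / fan_in`` along each output channel, so a product with unit-variance inputs keeps roughly unit scale. It reuses ``bst.random.normal`` and the public ``bst.nn.weight_standardization`` shown in the docstring examples.

.. code-block:: python

    import jax.numpy as jnp
    import brainstate as bst

    w = bst.random.normal(size=(64, 16))      # fan_in = 64, 16 output channels
    w_std = bst.nn.weight_standardization(w)

    # per output channel: mean ~ 0 and var * fan_in ~ 1
    print(jnp.allclose(w_std.mean(axis=0), 0.0, atol=1e-5))
    print(jnp.allclose(w_std.var(axis=0) * 64, 1.0, atol=1e-2))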
120
+ def canonicalize_dtype(
121
+ *args,
122
+ dtype: jax.typing.DTypeLike | None = None,
123
+ inexact: bool = True
124
+ ) -> jax.typing.DTypeLike:
125
+ """
126
+ Canonicalize an optional dtype to the definitive dtype.
127
+
128
+ If the ``dtype`` is None, this function will infer the dtype from the input
129
+ arguments using ``jnp.result_type``. If it is not None, it will be returned
130
+ unmodified or an exception is raised if the dtype is invalid.
131
+
132
+ Parameters
133
+ ----------
134
+ *args : ArrayLike
135
+ JAX array compatible values. None values are ignored.
136
+ dtype : jax.typing.DTypeLike, optional
137
+ Optional dtype override. If specified, the arguments are cast to the
138
+ specified dtype and dtype inference is disabled. Default is None.
139
+ inexact : bool, optional
140
+ When True, the output dtype must be a subtype of ``jnp.inexact``.
141
+ Inexact dtypes are real or complex floating points. This is useful
142
+ when applying operations that don't work directly on integers like
143
+ taking a mean. Default is True.
144
+
145
+ Returns
146
+ -------
147
+ jax.typing.DTypeLike
148
+ The dtype that ``*args`` should be cast to.
149
+
150
+ Raises
151
+ ------
152
+ ValueError
153
+ If ``inexact=True`` and the resulting dtype is not an inexact type.
154
+
155
+ Examples
156
+ --------
157
+ .. code-block:: python
158
+
159
+ >>> import jax.numpy as jnp
160
+ >>>
161
+ >>> # Infer dtype from arguments
162
+ >>> x = jnp.array([1, 2, 3])
163
+ >>> dtype = canonicalize_dtype(x)
164
+ >>>
165
+ >>> # Specify explicit dtype
166
+ >>> dtype = canonicalize_dtype(x, dtype=jnp.float64)
167
+ """
168
+ if dtype is None:
169
+ args_filtered = [jnp.asarray(x) for x in args if x is not None]
170
+ dtype = jnp.result_type(*args_filtered)
171
+ if inexact and not jnp.issubdtype(dtype, jnp.inexact):
172
+ dtype = jnp.promote_types(jnp.float32, dtype)
173
+ if inexact and not jnp.issubdtype(dtype, jnp.inexact):
174
+ raise ValueError(f'Dtype must be inexact: {dtype}')
175
+ return dtype
176
+
177
+
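The integer-to-float promotion path described above can be exercised directly; a small sketch, assuming access to this module-level helper (it is not exported in ``__all__``):

.. code-block:: python

    import jax.numpy as jnp

    ints = jnp.arange(4)                            # int32 input
    print(canonicalize_dtype(ints))                 # promoted to float32 (inexact=True)
    print(canonicalize_dtype(ints, inexact=False))  # int32 is kept as-is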
178
+ def _canonicalize_axes(ndim: int, feature_axes: Sequence[int]):
179
+ axes = []
180
+ for axis in feature_axes:
181
+ if axis < 0:
182
+ axis += ndim
183
+ if axis < 0 or axis >= ndim:
184
+ raise ValueError(f'Invalid axis {axis} for {ndim}D input')
185
+ axes.append(axis)
186
+ return tuple(axes)
187
+
188
+
189
+ def _abs_sq(x):
190
+ """Computes the elementwise square of the absolute value |x|^2."""
191
+ if jnp.iscomplexobj(x):
192
+ return jax.lax.square(jax.lax.real(x)) + jax.lax.square(jax.lax.imag(x))
193
+ else:
194
+ return jax.lax.square(x)
195
+
196
+
197
+ class NormalizationParamState(ParamState):
198
+ # A thin ``ParamState`` subclass kept for compatibility with the
199
+ # ``ETraceParam`` usage of these layers in "brainetrace".
200
+ def execute(self, x):
201
+ param = self.value
202
+ if 'scale' in param:
203
+ x = x * param['scale']
204
+ if 'bias' in param:
205
+ x = x + param['bias']
206
+ return x
207
+
208
+
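A minimal illustration of how this state is consumed by ``_normalize`` below (a sketch, assuming the ``ParamState`` constructor takes the value dict as in the layer code): the value is a dict with optional ``'scale'`` and ``'bias'`` arrays, and ``execute`` applies them as ``y = x * scale + bias``.

.. code-block:: python

    import jax.numpy as jnp

    params = NormalizationParamState({'scale': jnp.full((3,), 2.0),
                                      'bias': jnp.ones((3,))})
    y = params.execute(jnp.zeros((4, 3)))
    print(y)   # all ones: 0 * 2 + 1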
209
+ def _compute_stats(
210
+ x: ArrayLike,
211
+ axes: Sequence[int],
212
+ dtype: DTypeLike,
213
+ axis_name: Optional[str] = None,
214
+ axis_index_groups: Optional[Sequence[int]] = None,
215
+ use_mean: bool = True,
216
+ use_fast_variance: bool = True,
217
+ mask: Optional[jax.Array] = None,
218
+ ):
219
+ """
220
+ Compute mean and variance statistics for normalization.
221
+
222
+ This implementation includes several optimizations:
223
+
224
+ - Computes in float32 precision for stability in half precision training.
225
+ - If ``use_fast_variance`` is True, uses the formula Var = E[|x|^2] - |E[x]|^2
226
+ instead of Var = E[|x - E[x]|^2] in a single XLA fusion.
227
+ - Clips negative variances to zero to avoid downstream NaNs from roundoff errors.
228
+ - Supports averaging across parallel axes and subgroups with a single
229
+ ``lax.pmean`` call to reduce latency.
230
+
231
+ Parameters
232
+ ----------
233
+ x : ArrayLike
234
+ Input array.
235
+ axes : Sequence[int]
236
+ The axes in ``x`` to compute mean and variance statistics for.
237
+ dtype : DTypeLike
238
+ Optional dtype specifying the minimal precision. Statistics are always
239
+ at least float32 for stability. If None, uses the dtype of x.
240
+ axis_name : str, optional
241
+ Optional name for the pmapped axis to compute mean over. Only used for
242
+ pmap and shard map. For SPMD jit, axes should be correctly annotated
243
+ and XLA:SPMD will insert necessary collectives. Default is None.
244
+ axis_index_groups : Sequence[int], optional
245
+ Optional axis indices for grouped reductions. Default is None.
246
+ use_mean : bool, optional
247
+ If True, calculate the mean from the input and use it when computing
248
+ the variance. If False, set the mean to zero and compute the variance
249
+ without subtracting the mean. Default is True.
250
+ use_fast_variance : bool, optional
251
+ If True, use a faster but less numerically stable calculation for the
252
+ variance. Default is True.
253
+ mask : jax.Array, optional
254
+ Binary array of shape broadcastable to ``x``, indicating the positions
255
+ for which the mean and variance should be computed. Default is None.
256
+
257
+ Returns
258
+ -------
259
+ tuple of jax.Array
260
+ A pair ``(mean, var)`` containing the computed mean and variance.
261
+ """
262
+ if dtype is None:
263
+ dtype = jax.numpy.result_type(x)
264
+ # promote x to at least float32; this avoids half-precision computation
266
+ # while preserving double and complex floating point types
266
+ dtype = jax.numpy.promote_types(dtype, jnp.float32)
267
+ x = jnp.asarray(x, dtype)
268
+ axes = _canonicalize_axes(x.ndim, axes)
269
+
270
+ def maybe_distributed_mean(*xs, mask=None):
271
+ mus = tuple(x.mean(axes, where=mask) for x in xs)
272
+ if axis_name is None:
273
+ return mus if len(xs) > 1 else mus[0]
274
+ else:
275
+ # In the distributed case we stack multiple arrays to speed comms.
276
+ if len(xs) > 1:
277
+ reduced_mus = jax.lax.pmean(
278
+ jnp.stack(mus, axis=0),
279
+ axis_name,
280
+ axis_index_groups=axis_index_groups,
281
+ )
282
+ return tuple(reduced_mus[i] for i in range(len(xs)))
283
+ else:
284
+ return jax.lax.pmean(
285
+ mus[0],
286
+ axis_name,
287
+ axis_index_groups=axis_index_groups
288
+ )
289
+
290
+ if use_mean:
291
+ if use_fast_variance:
292
+ mu, mu2 = maybe_distributed_mean(x, _abs_sq(x), mask=mask)
293
+ # mu2 - _abs_sq(mu) is not guaranteed to be non-negative due
294
+ # to floating point round-off errors.
295
+ var = jnp.maximum(0.0, mu2 - _abs_sq(mu))
296
+ else:
297
+ mu = maybe_distributed_mean(x, mask=mask)
298
+ var = maybe_distributed_mean(_abs_sq(x - jnp.expand_dims(mu, axes)), mask=mask)
299
+ else:
300
+ var = maybe_distributed_mean(_abs_sq(x), mask=mask)
301
+ mu = jnp.zeros_like(var)
302
+ return mu, var
303
+
304
+
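A quick numerical check of the fast-variance identity ``Var[x] = E[|x|^2] - |E[x]|^2`` used above, as a minimal sketch (the helper is module-private, so this assumes it is run inside this module or imported explicitly):

.. code-block:: python

    import jax.numpy as jnp

    x = jnp.array([[1., 2., 3.],
                   [4., 5., 6.]])
    mu, var = _compute_stats(x, axes=(0,), dtype=jnp.float32)
    print(jnp.allclose(mu, x.mean(axis=0)))   # True
    print(jnp.allclose(var, x.var(axis=0)))   # True (population variance)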
305
+ def _normalize(
306
+ x: ArrayLike,
307
+ mean: Optional[ArrayLike],
308
+ var: Optional[ArrayLike],
309
+ weights: Optional[NormalizationParamState],
310
+ reduction_axes: Axes,
311
+ feature_axes: Axes,
312
+ dtype: DTypeLike,
313
+ epsilon: jax.typing.ArrayLike,
314
+ ):
315
+ """
316
+ Normalize the input and optionally apply learned scale and bias.
317
+
318
+ Parameters
319
+ ----------
320
+ x : ArrayLike
321
+ The input array.
322
+ mean : ArrayLike, optional
323
+ Mean to use for normalization. If None, normalization is skipped.
324
+ var : ArrayLike, optional
325
+ Variance to use for normalization. If None, normalization is skipped.
326
+ weights : NormalizationParamState, optional
327
+ The scale and bias parameters. If None, no affine transformation is applied.
328
+ reduction_axes : Axes
329
+ The axes in ``x`` to reduce.
330
+ feature_axes : Axes
331
+ The feature axes to apply the scale and bias.
332
+ dtype : DTypeLike
333
+ The dtype of the result. If None, inferred from input and parameters.
334
+ epsilon : jax.typing.ArrayLike
335
+ A small value added to variance to avoid division by zero.
336
+
337
+ Returns
338
+ -------
339
+ jax.Array
340
+ The normalized input array.
341
+ """
342
+ if mean is not None:
343
+ assert var is not None, 'mean and var must both be None or both be non-None.'
344
+ reduction_axes = _canonicalize_axes(x.ndim, reduction_axes)
345
+ feature_axes = _canonicalize_axes(x.ndim, feature_axes)
346
+ stats_shape = list(x.shape)
347
+ for axis in reduction_axes:
348
+ stats_shape[axis] = 1
349
+ mean = mean.reshape(stats_shape)
350
+ var = var.reshape(stats_shape)
351
+ feature_shape = [1] * x.ndim
352
+ for ax in feature_axes:
353
+ feature_shape[ax] = x.shape[ax]
354
+ y = x - mean
355
+ mul = jax.lax.rsqrt(var + epsilon)
356
+ y = y * mul
357
+ if weights is not None:
358
+ y = weights.execute(y)
359
+ dtype = canonicalize_dtype(x, *jax.tree.leaves(weights.value), dtype=dtype)
360
+ else:
361
+ assert var is None, 'mean and var must both be None or both be non-None.'
362
+ assert weights is None, 'scale and bias are not supported without mean and var'
363
+ y = x
364
+ return jnp.asarray(y, dtype)
365
+
366
+
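With precomputed statistics and no affine weights, ``_normalize`` reduces to plain standardization ``(x - mean) / sqrt(var + epsilon)``; a small sketch under that assumption, reusing the private helpers from this module:

.. code-block:: python

    import jax.numpy as jnp

    x = jnp.array([[0., 2.],
                   [4., 6.]])
    mean, var = _compute_stats(x, axes=(0,), dtype=jnp.float32)
    y = _normalize(x, mean=mean, var=var, weights=None,
                   reduction_axes=(0,), feature_axes=(1,),
                   dtype=jnp.float32, epsilon=1e-5)
    print(y.mean(axis=0), y.std(axis=0))   # ~0 mean, ~1 std per feature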
367
+ class _BatchNorm(Module):
368
+ __module__ = 'brainstate.nn'
369
+ num_spatial_dims: int
370
+
371
+ def __init__(
372
+ self,
373
+ in_size: Size,
374
+ feature_axis: Axes = -1,
375
+ *,
376
+ track_running_stats: bool = True,
377
+ epsilon: float = 1e-5,
378
+ momentum: float = 0.99,
379
+ affine: bool = True,
380
+ bias_initializer: Union[ArrayLike, Callable] = init.Constant(0.),
381
+ scale_initializer: Union[ArrayLike, Callable] = init.Constant(1.),
382
+ axis_name: Optional[Union[str, Sequence[str]]] = None,
383
+ axis_index_groups: Optional[Sequence[Sequence[int]]] = None,
384
+ use_fast_variance: bool = True,
385
+ name: Optional[str] = None,
386
+ dtype: Any = None,
387
+ param_type: type = NormalizationParamState,
388
+ mean_type: type = BatchState,
389
+ ):
390
+ super().__init__(name=name)
391
+
392
+ # parameters
393
+ self.in_size = in_size
394
+ self.out_size = in_size
395
+ self.affine = affine
396
+ self.bias_initializer = bias_initializer
397
+ self.scale_initializer = scale_initializer
398
+ self.dtype = dtype or environ.dftype()
399
+ self.track_running_stats = track_running_stats
400
+ self.momentum = jnp.asarray(momentum, dtype=self.dtype)
401
+ self.epsilon = jnp.asarray(epsilon, dtype=self.dtype)
402
+ self.use_fast_variance = use_fast_variance
403
+
404
+ # parameters about axis
405
+ feature_axis = (feature_axis,) if isinstance(feature_axis, int) else feature_axis
406
+ self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axis)
407
+ self.axis_name = axis_name
408
+ self.axis_index_groups = axis_index_groups
409
+
410
+ # variables
411
+ feature_shape = tuple([(ax if i in self.feature_axes else 1)
412
+ for i, ax in enumerate(self.in_size)])
413
+ if self.track_running_stats:
414
+ self.running_mean = mean_type(jnp.zeros(feature_shape, dtype=self.dtype))
415
+ self.running_var = mean_type(jnp.ones(feature_shape, dtype=self.dtype))
416
+ else:
417
+ self.running_mean = None
418
+ self.running_var = None
419
+
420
+ # parameters
421
+ if self.affine:
422
+ assert track_running_stats, "Affine parameters require track_running_stats=True."
423
+ bias = init.param(self.bias_initializer, feature_shape)
424
+ scale = init.param(self.scale_initializer, feature_shape)
425
+ self.weight = param_type(dict(bias=bias, scale=scale))
426
+ else:
427
+ self.weight = None
428
+
429
+ def update(self, x, mask: Optional[jax.Array] = None):
430
+ # input shape and batch mode or not
431
+ if x.ndim == self.num_spatial_dims + 2:
432
+ x_shape = x.shape[1:]
433
+ batch = True
434
+ elif x.ndim == self.num_spatial_dims + 1:
435
+ x_shape = x.shape
436
+ batch = False
437
+ else:
438
+ raise ValueError(f"expected {self.num_spatial_dims + 2}D (with batch) or "
439
+ f"{self.num_spatial_dims + 1}D (without batch) input (got {x.ndim}D input, {x.shape})")
440
+ if self.in_size != x_shape:
441
+ raise ValueError(f"The expected input shape is {self.in_size}, while we got {x_shape}.")
442
+
443
+ # reduce the feature axis
444
+ if batch:
445
+ reduction_axes = tuple(i for i in range(x.ndim) if (i - 1) not in self.feature_axes)
446
+ else:
447
+ reduction_axes = tuple(i for i in range(x.ndim) if i not in self.feature_axes)
448
+
449
+ # fitting phase
450
+ fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
451
+
452
+ # compute the running mean and variance
453
+ if self.track_running_stats:
454
+ if fit_phase:
455
+ mean, var = _compute_stats(
456
+ x,
457
+ reduction_axes,
458
+ dtype=self.dtype,
459
+ axis_name=self.axis_name,
460
+ axis_index_groups=self.axis_index_groups,
461
+ use_fast_variance=self.use_fast_variance,
462
+ mask=mask,
463
+ )
464
+ self.running_mean.value = self.momentum * self.running_mean.value + (1 - self.momentum) * mean
465
+ self.running_var.value = self.momentum * self.running_var.value + (1 - self.momentum) * var
466
+ else:
467
+ mean = self.running_mean.value
468
+ var = self.running_var.value
469
+ else:
470
+ mean, var = None, None
471
+
472
+ # normalize
473
+ return _normalize(
474
+ x,
475
+ mean=mean,
476
+ var=var,
477
+ weights=self.weight,
478
+ reduction_axes=reduction_axes,
479
+ feature_axes=self.feature_axes,
480
+ dtype=self.dtype,
481
+ epsilon=self.epsilon
482
+ )
483
+
484
+
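The ``environ.get('fit', ...)`` branch above is what switches between batch statistics (training, which also updates the running statistics) and the tracked running statistics (evaluation). A usage sketch, assuming the ``brainstate.environ.context`` manager from this package is how the ``'fit'`` flag gets set; adjust if your code sets it differently:

.. code-block:: python

    import brainstate as bst

    layer = bst.nn.BatchNorm1d(in_size=(100, 64))
    x = bst.random.normal(size=(8, 100, 64))

    with bst.environ.context(fit=True):    # training: uses and updates batch stats
        y_train = layer(x)
    with bst.environ.context(fit=False):   # evaluation: uses running stats
        y_eval = layer(x)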
485
+ class BatchNorm0d(_BatchNorm):
486
+ """
487
+ 0-D batch normalization.
488
+
489
+ Normalizes a batch of 0-D data (vectors) by fixing the mean and variance
490
+ of inputs on each feature (channel). This layer aims to reduce the internal
491
+ covariate shift of data.
492
+
493
+ The input data should have shape ``(b, c)``, where ``b`` is the batch dimension
494
+ and ``c`` is the channel dimension.
495
+
496
+ The normalization is performed as:
497
+
498
+ .. math::
499
+ y = \\frac{x - \\mathrm{E}[x]}{\\sqrt{\\operatorname{Var}[x] + \\epsilon}} \\cdot \\gamma + \\beta
500
+
501
+ where :math:`\\gamma` and :math:`\\beta` are learnable affine parameters (if ``affine=True``).
502
+
503
+ Parameters
504
+ ----------
505
+ in_size : tuple of int
506
+ The input shape, without batch dimension.
507
+ feature_axis : int or tuple of int, optional
508
+ The feature or non-batch axis of the input. Default is -1.
509
+ track_running_stats : bool, optional
510
+ If True, tracks the running mean and variance. If False, uses batch
511
+ statistics in both training and eval modes. Default is True.
512
+ epsilon : float, optional
513
+ A value added to the denominator for numerical stability. Default is 1e-5.
514
+ momentum : float, optional
515
+ The momentum value used for the ``running_mean`` and ``running_var``
516
+ computation. The update rule is:
517
+ :math:`\\hat{x}_{\\text{new}} = \\text{momentum} \\times \\hat{x} + (1 - \\text{momentum}) \\times x_t`.
518
+ Default is 0.99.
519
+ affine : bool, optional
520
+ If True, this module has learnable affine parameters (scale and bias).
521
+ Default is True.
522
+ bias_initializer : ArrayLike or Callable, optional
523
+ Initializer for the bias (beta) parameter. Default is ``init.Constant(0.)``.
524
+ scale_initializer : ArrayLike or Callable, optional
525
+ Initializer for the scale (gamma) parameter. Default is ``init.Constant(1.)``.
526
+ axis_name : str or sequence of str, optional
527
+ The axis name(s) for parallel reduction using ``jax.pmap`` or ``jax.vmap``.
528
+ If specified, batch statistics are calculated across all replicas on the
529
+ named axes. Default is None.
530
+ axis_index_groups : sequence of sequence of int, optional
531
+ Groups of axis indices within the named axis representing subsets of
532
+ devices to reduce over. For example, ``[[0, 1], [2, 3]]`` would
533
+ independently batch-normalize over the first two and last two devices.
534
+ See ``jax.lax.psum`` for more details. Default is None.
535
+ use_fast_variance : bool, optional
536
+ If True, use a faster but less numerically stable calculation for
537
+ the variance. Default is True.
538
+
539
+ Notes
540
+ -----
541
+ The ``momentum`` parameter is different from the conventional notion of
542
+ momentum used in optimizers.
543
+
544
+ References
545
+ ----------
546
+ .. [1] Ioffe, S., & Szegedy, C. (2015). Batch Normalization: Accelerating
547
+ Deep Network Training by Reducing Internal Covariate Shift.
548
+ In International Conference on Machine Learning (pp. 448-456).
549
+
550
+ Examples
551
+ --------
552
+ .. code-block:: python
553
+
554
+ >>> import brainstate as bst
555
+ >>> import jax.numpy as jnp
556
+ >>>
557
+ >>> # Create a BatchNorm0d layer
558
+ >>> layer = bst.nn.BatchNorm0d(in_size=(10,))
559
+ >>>
560
+ >>> # Apply normalization to a batch of data
561
+ >>> x = jnp.ones((32, 10)) # batch_size=32, features=10
562
+ >>> y = layer(x)
563
+ >>>
564
+ >>> # Check output shape
565
+ >>> print(y.shape)
566
+ (32, 10)
567
+ """
568
+ __module__ = 'brainstate.nn'
569
+ num_spatial_dims: int = 0
570
+
571
+
572
+ class BatchNorm1d(_BatchNorm):
573
+ """
574
+ 1-D batch normalization.
575
+
576
+ Normalizes a batch of 1-D data by fixing the mean and variance of inputs
577
+ on each feature (channel). This layer aims to reduce the internal covariate
578
+ shift of data.
579
+
580
+ The input data should have shape ``(b, l, c)``, where ``b`` is the batch
581
+ dimension, ``l`` is the spatial/sequence dimension, and ``c`` is the channel
582
+ dimension.
583
+
584
+ Parameters
585
+ ----------
586
+ in_size : tuple of int
587
+ The input shape, without batch dimension. For 1-D data, typically ``(l, c)``.
588
+ feature_axis : int or tuple of int, optional
589
+ The feature or non-batch axis of the input. Default is -1.
590
+ track_running_stats : bool, optional
591
+ If True, tracks the running mean and variance. If False, uses batch
592
+ statistics in both training and eval modes. Default is True.
593
+ epsilon : float, optional
594
+ A value added to the denominator for numerical stability. Default is 1e-5.
595
+ momentum : float, optional
596
+ The momentum value for running statistics computation. Default is 0.99.
597
+ affine : bool, optional
598
+ If True, has learnable affine parameters (scale and bias). Default is True.
599
+ bias_initializer : ArrayLike or Callable, optional
600
+ Initializer for the bias parameter. Default is ``init.Constant(0.)``.
601
+ scale_initializer : ArrayLike or Callable, optional
602
+ Initializer for the scale parameter. Default is ``init.Constant(1.)``.
603
+ axis_name : str or sequence of str, optional
604
+ Axis name(s) for parallel reduction. Default is None.
605
+ axis_index_groups : sequence of sequence of int, optional
606
+ Groups of axis indices for device-grouped reduction. Default is None.
607
+ use_fast_variance : bool, optional
608
+ If True, use faster but less stable variance calculation. Default is True.
609
+
610
+ References
611
+ ----------
612
+ .. [1] Ioffe, S., & Szegedy, C. (2015). Batch Normalization: Accelerating
613
+ Deep Network Training by Reducing Internal Covariate Shift.
614
+ In International Conference on Machine Learning (pp. 448-456).
615
+
616
+ See Also
617
+ --------
618
+ BatchNorm0d : 0-D batch normalization
619
+ BatchNorm2d : 2-D batch normalization
620
+ BatchNorm3d : 3-D batch normalization
621
+
622
+ Examples
623
+ --------
624
+ .. code-block:: python
625
+
626
+ >>> import brainstate as bst
627
+ >>> import jax.numpy as jnp
628
+ >>>
629
+ >>> # Create a BatchNorm1d layer for sequence data
630
+ >>> layer = bst.nn.BatchNorm1d(in_size=(100, 64)) # length=100, channels=64
631
+ >>>
632
+ >>> # Apply normalization
633
+ >>> x = jnp.ones((8, 100, 64)) # batch_size=8
634
+ >>> y = layer(x)
635
+ >>> print(y.shape)
636
+ (8, 100, 64)
637
+ """
638
+ __module__ = 'brainstate.nn'
639
+ num_spatial_dims: int = 1
640
+
641
+
642
+ class BatchNorm2d(_BatchNorm):
643
+ """
644
+ 2-D batch normalization.
645
+
646
+ Normalizes a batch of 2-D data (e.g., images) by fixing the mean and variance
647
+ of inputs on each feature (channel). This layer aims to reduce the internal
648
+ covariate shift of data.
649
+
650
+ The input data should have shape ``(b, h, w, c)``, where ``b`` is the batch
651
+ dimension, ``h`` is the height dimension, ``w`` is the width dimension, and
652
+ ``c`` is the channel dimension.
653
+
654
+ Parameters
655
+ ----------
656
+ in_size : tuple of int
657
+ The input shape, without batch dimension. For 2-D data, typically ``(h, w, c)``.
658
+ feature_axis : int or tuple of int, optional
659
+ The feature or non-batch axis of the input. Default is -1.
660
+ track_running_stats : bool, optional
661
+ If True, tracks the running mean and variance. If False, uses batch
662
+ statistics in both training and eval modes. Default is True.
663
+ epsilon : float, optional
664
+ A value added to the denominator for numerical stability. Default is 1e-5.
665
+ momentum : float, optional
666
+ The momentum value for running statistics computation. Default is 0.99.
667
+ affine : bool, optional
668
+ If True, has learnable affine parameters (scale and bias). Default is True.
669
+ bias_initializer : ArrayLike or Callable, optional
670
+ Initializer for the bias parameter. Default is ``init.Constant(0.)``.
671
+ scale_initializer : ArrayLike or Callable, optional
672
+ Initializer for the scale parameter. Default is ``init.Constant(1.)``.
673
+ axis_name : str or sequence of str, optional
674
+ Axis name(s) for parallel reduction. Default is None.
675
+ axis_index_groups : sequence of sequence of int, optional
676
+ Groups of axis indices for device-grouped reduction. Default is None.
677
+ use_fast_variance : bool, optional
678
+ If True, use faster but less stable variance calculation. Default is True.
679
+
680
+ References
681
+ ----------
682
+ .. [1] Ioffe, S., & Szegedy, C. (2015). Batch Normalization: Accelerating
683
+ Deep Network Training by Reducing Internal Covariate Shift.
684
+ In International Conference on Machine Learning (pp. 448-456).
685
+
686
+ See Also
687
+ --------
688
+ BatchNorm0d : 0-D batch normalization
689
+ BatchNorm1d : 1-D batch normalization
690
+ BatchNorm3d : 3-D batch normalization
691
+
692
+ Examples
693
+ --------
694
+ .. code-block:: python
695
+
696
+ >>> import brainstate as bst
697
+ >>> import jax.numpy as jnp
698
+ >>>
699
+ >>> # Create a BatchNorm2d layer for image data
700
+ >>> layer = bst.nn.BatchNorm2d(in_size=(28, 28, 3)) # 28x28 RGB images
701
+ >>>
702
+ >>> # Apply normalization
703
+ >>> x = jnp.ones((16, 28, 28, 3)) # batch_size=16
704
+ >>> y = layer(x)
705
+ >>> print(y.shape)
706
+ (16, 28, 28, 3)
707
+ """
708
+ __module__ = 'brainstate.nn'
709
+ num_spatial_dims: int = 2
710
+
711
+
712
+ class BatchNorm3d(_BatchNorm):
713
+ """
714
+ 3-D batch normalization.
715
+
716
+ Normalizes a batch of 3-D data (e.g., video or volumetric data) by fixing
717
+ the mean and variance of inputs on each feature (channel). This layer aims
718
+ to reduce the internal covariate shift of data.
719
+
720
+ The input data should have shape ``(b, h, w, d, c)``, where ``b`` is the
721
+ batch dimension, ``h`` is the height dimension, ``w`` is the width dimension,
722
+ ``d`` is the depth dimension, and ``c`` is the channel dimension.
723
+
724
+ Parameters
725
+ ----------
726
+ in_size : tuple of int
727
+ The input shape, without batch dimension. For 3-D data, typically ``(h, w, d, c)``.
728
+ feature_axis : int or tuple of int, optional
729
+ The feature or non-batch axis of the input. Default is -1.
730
+ track_running_stats : bool, optional
731
+ If True, tracks the running mean and variance. If False, uses batch
732
+ statistics in both training and eval modes. Default is True.
733
+ epsilon : float, optional
734
+ A value added to the denominator for numerical stability. Default is 1e-5.
735
+ momentum : float, optional
736
+ The momentum value for running statistics computation. Default is 0.99.
737
+ affine : bool, optional
738
+ If True, has learnable affine parameters (scale and bias). Default is True.
739
+ bias_initializer : ArrayLike or Callable, optional
740
+ Initializer for the bias parameter. Default is ``init.Constant(0.)``.
741
+ scale_initializer : ArrayLike or Callable, optional
742
+ Initializer for the scale parameter. Default is ``init.Constant(1.)``.
743
+ axis_name : str or sequence of str, optional
744
+ Axis name(s) for parallel reduction. Default is None.
745
+ axis_index_groups : sequence of sequence of int, optional
746
+ Groups of axis indices for device-grouped reduction. Default is None.
747
+ use_fast_variance : bool, optional
748
+ If True, use faster but less stable variance calculation. Default is True.
749
+
750
+ References
751
+ ----------
752
+ .. [1] Ioffe, S., & Szegedy, C. (2015). Batch Normalization: Accelerating
753
+ Deep Network Training by Reducing Internal Covariate Shift.
754
+ In International Conference on Machine Learning (pp. 448-456).
755
+
756
+ See Also
757
+ --------
758
+ BatchNorm0d : 0-D batch normalization
759
+ BatchNorm1d : 1-D batch normalization
760
+ BatchNorm2d : 2-D batch normalization
761
+
762
+ Examples
763
+ --------
764
+ .. code-block:: python
765
+
766
+ >>> import brainstate as bst
767
+ >>> import jax.numpy as jnp
768
+ >>>
769
+ >>> # Create a BatchNorm3d layer for volumetric data
770
+ >>> layer = bst.nn.BatchNorm3d(in_size=(32, 32, 32, 1)) # 32x32x32 volumes
771
+ >>>
772
+ >>> # Apply normalization
773
+ >>> x = jnp.ones((4, 32, 32, 32, 1)) # batch_size=4
774
+ >>> y = layer(x)
775
+ >>> print(y.shape)
776
+ (4, 32, 32, 32, 1)
777
+ """
778
+ __module__ = 'brainstate.nn'
779
+ num_spatial_dims: int = 3
780
+
781
+
782
+ class LayerNorm(Module):
783
+ """
784
+ Layer normalization layer [1]_.
785
+
786
+ LayerNorm normalizes the activations of the layer for each given example in
787
+ a batch independently, rather than across a batch like Batch Normalization.
788
+ It applies a transformation that maintains the mean activation within each
789
+ example close to 0 and the activation standard deviation close to 1.
790
+
791
+ Parameters
792
+ ----------
793
+ in_size : tuple of int
794
+ The input shape, without batch dimension.
795
+ reduction_axes : int or tuple of int, optional
796
+ Axes for computing normalization statistics. It is recommended to use
797
+ negative integers, as positive integers may cause issues when batch
798
+ dimensions are present. Default is -1.
799
+ feature_axes : int or tuple of int, optional
800
+ Feature axes for learned bias and scaling. Default is -1.
801
+ epsilon : float, optional
802
+ A small value added to variance to avoid division by zero. Default is 1e-6.
803
+ use_bias : bool, optional
804
+ If True, bias (beta) is added. Default is True.
805
+ use_scale : bool, optional
806
+ If True, multiply by scale (gamma). When the next layer is linear
807
+ (e.g., nn.relu), this can be disabled since scaling will be done by
808
+ the next layer. Default is True.
809
+ bias_init : Callable, optional
810
+ Initializer for bias parameter. Default is ``init.ZeroInit()``.
811
+ scale_init : Callable, optional
812
+ Initializer for scale parameter. Default is ``init.Constant(1.0)``.
813
+ axis_name : str, optional
814
+ The axis name used to combine batch statistics from multiple devices.
815
+ See ``jax.pmap`` for axis name description. Only needed if the model
816
+ is subdivided across devices. Default is None.
817
+ axis_index_groups : sequence, optional
818
+ Groups of axis indices within the named axis representing subsets of
819
+ devices to reduce over. For example, ``[[0, 1], [2, 3]]`` would
820
+ independently normalize over the first two and last two devices.
821
+ See ``jax.lax.psum`` for details. Default is None.
822
+ use_fast_variance : bool, optional
823
+ If True, use a faster but less numerically stable calculation for
824
+ the variance. Default is True.
825
+ dtype : jax.typing.DTypeLike, optional
826
+ The dtype of the result. If None, inferred from input and parameters.
827
+ Default is None.
828
+
829
+ References
830
+ ----------
831
+ .. [1] Ba, J. L., Kiros, J. R., & Hinton, G. E. (2016). Layer normalization.
832
+ arXiv preprint arXiv:1607.06450.
833
+
834
+ See Also
835
+ --------
836
+ RMSNorm : Root Mean Square Layer Normalization
837
+ GroupNorm : Group Normalization
838
+ BatchNorm1d : 1-D Batch Normalization
839
+
840
+ Examples
841
+ --------
842
+ .. code-block:: python
843
+
844
+ >>> import brainstate as bst
845
+ >>>
846
+ >>> # Create a LayerNorm layer
847
+ >>> x = bst.random.normal(size=(3, 4, 5, 6))
848
+ >>> layer = bst.nn.LayerNorm(x.shape)
849
+ >>>
850
+ >>> # Apply normalization
851
+ >>> y = layer(x)
852
+ >>> print(y.shape)
853
+ (3, 4, 5, 6)
854
+ >>>
855
+ >>> # Normalize only the last dimension
856
+ >>> layer = bst.nn.LayerNorm((10, 20), reduction_axes=-1, feature_axes=-1)
857
+ >>> x = bst.random.normal((5, 10, 20))
858
+ >>> y = layer(x)
859
+ """
860
+
861
+ def __init__(
862
+ self,
863
+ in_size: Size,
864
+ reduction_axes: Axes = -1,
865
+ feature_axes: Axes = -1,
866
+ *,
867
+ epsilon: float = 1e-6,
868
+ use_bias: bool = True,
869
+ use_scale: bool = True,
870
+ bias_init: Callable = init.ZeroInit(),
871
+ scale_init: Callable = init.Constant(1.0),
872
+ axis_name: Optional[str] = None,
873
+ axis_index_groups: Any = None,
874
+ use_fast_variance: bool = True,
875
+ dtype: Optional[jax.typing.DTypeLike] = None,
876
+ param_type: type = NormalizationParamState,
877
+ ):
878
+ super().__init__()
879
+
880
+ self.in_size = in_size
881
+ self.out_size = in_size
882
+
883
+ # parameters about axis
884
+ feature_axes = (feature_axes,) if isinstance(feature_axes, int) else feature_axes
885
+ self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axes)
886
+ self.reduction_axes = (reduction_axes,) if isinstance(reduction_axes, int) else reduction_axes
887
+ self.axis_name = axis_name
888
+ self.axis_index_groups = axis_index_groups
889
+
890
+ # variables
891
+ feature_shape = tuple([(ax if i in self.feature_axes else 1)
892
+ for i, ax in enumerate(self.in_size)])
893
+
894
+ weights = dict()
895
+ if use_scale:
896
+ weights['scale'] = init.param(scale_init, feature_shape)
897
+ if use_bias:
898
+ weights['bias'] = init.param(bias_init, feature_shape)
899
+ if len(weights):
900
+ self.weight = param_type(weights)
901
+ else:
902
+ self.weight = None
903
+
904
+ # parameters
905
+ self.epsilon = epsilon
906
+ self.dtype = dtype or environ.dftype()
907
+ self.use_bias = use_bias
908
+ self.use_scale = use_scale
909
+ self.bias_init = bias_init
910
+ self.scale_init = scale_init
911
+ self.use_fast_variance = use_fast_variance
912
+
913
+ def update(self, x, *, mask: Optional[jax.Array] = None):
914
+ """
915
+ Apply layer normalization on the input.
916
+
917
+ Parameters
918
+ ----------
919
+ x : jax.Array
920
+ The input array.
921
+ mask : jax.Array, optional
922
+ Binary array of shape broadcastable to ``x``, indicating the
923
+ positions for which normalization should be computed. Default is None.
924
+
925
+ Returns
926
+ -------
927
+ jax.Array
928
+ Normalized inputs with the same shape as the input.
929
+ """
930
+ mean, var = _compute_stats(
931
+ x,
932
+ self.reduction_axes,
933
+ dtype=self.dtype,
934
+ axis_name=self.axis_name,
935
+ axis_index_groups=self.axis_index_groups,
936
+ use_fast_variance=self.use_fast_variance,
937
+ mask=mask,
938
+ )
939
+
940
+ return _normalize(
941
+ x,
942
+ mean=mean,
943
+ var=var,
944
+ weights=self.weight,
945
+ reduction_axes=self.reduction_axes,
946
+ feature_axes=self.feature_axes,
947
+ dtype=self.dtype,
948
+ epsilon=self.epsilon,
949
+ )
950
+
951
+
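With the affine parameters disabled, LayerNorm is plain per-example standardization over the reduction axes; a quick check, as a minimal sketch:

.. code-block:: python

    import brainstate as bst
    import jax.numpy as jnp

    x = bst.random.normal(size=(4, 16))
    layer = bst.nn.LayerNorm((16,), use_bias=False, use_scale=False)
    y = layer(x)
    print(jnp.allclose(y.mean(-1), 0.0, atol=1e-5))   # ~0 mean per example
    print(jnp.allclose(y.std(-1), 1.0, atol=1e-2))    # ~unit std per example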
952
+ class RMSNorm(Module):
953
+ """
954
+ Root Mean Square Layer Normalization [1]_.
955
+
956
+ RMSNorm normalizes the activations of the layer for each given example in a
957
+ batch independently, rather than across a batch like Batch Normalization.
958
+ Unlike LayerNorm which re-centers the mean to 0 and normalizes by the standard
959
+ deviation, RMSNorm does not re-center at all and instead normalizes by the
960
+ root mean square of the activations.
961
+
962
+ Parameters
963
+ ----------
964
+ in_size : tuple of int
965
+ The input shape, without batch dimension.
966
+ epsilon : float, optional
967
+ A small value added to variance to avoid division by zero. Default is 1e-6.
968
+ dtype : jax.typing.DTypeLike, optional
969
+ The dtype of the result. If None, inferred from input and parameters.
970
+ Default is None.
971
+ use_scale : bool, optional
972
+ If True, multiply by scale (gamma). When the next layer is linear
973
+ (e.g., nn.relu), this can be disabled since scaling will be done by
974
+ the next layer. Default is True.
975
+ scale_init : Callable, optional
976
+ Initializer for scale parameter. Default is ``init.Constant(1.0)``.
977
+ reduction_axes : int or tuple of int, optional
978
+ Axes for computing normalization statistics. It is recommended to use
979
+ negative integers. Default is -1.
980
+ feature_axes : int or tuple of int, optional
981
+ Feature axes for learned scaling. Default is -1.
982
+ axis_name : str, optional
983
+ The axis name used to combine batch statistics from multiple devices.
984
+ See ``jax.pmap`` for details. Default is None.
985
+ axis_index_groups : sequence, optional
986
+ Groups of axis indices within the named axis representing subsets of
987
+ devices to reduce over. For example, ``[[0, 1], [2, 3]]`` would
988
+ independently normalize over the first two and last two devices.
989
+ Default is None.
990
+ use_fast_variance : bool, optional
991
+ If True, use a faster but less numerically stable calculation for
992
+ the variance. Default is True.
993
+
994
+ References
995
+ ----------
996
+ .. [1] Zhang, B., & Sennrich, R. (2019). Root Mean Square Layer Normalization.
997
+ Advances in Neural Information Processing Systems, 32.
998
+
999
+ See Also
1000
+ --------
1001
+ LayerNorm : Layer Normalization
1002
+ GroupNorm : Group Normalization
1003
+
1004
+ Examples
1005
+ --------
1006
+ .. code-block:: python
1007
+
1008
+ >>> import brainstate as bst
1009
+ >>>
1010
+ >>> # Create an RMSNorm layer
1011
+ >>> x = bst.random.normal(size=(5, 6))
1012
+ >>> layer = bst.nn.RMSNorm(in_size=(6,))
1013
+ >>>
1014
+ >>> # Apply normalization
1015
+ >>> y = layer(x)
1016
+ >>> print(y.shape)
1017
+ (5, 6)
1018
+ >>>
1019
+ >>> # Without scaling
1020
+ >>> layer = bst.nn.RMSNorm(in_size=(10,), use_scale=False)
1021
+ >>> x = bst.random.normal((3, 10))
1022
+ >>> y = layer(x)
1023
+ """
1024
+
1025
+ def __init__(
1026
+ self,
1027
+ in_size: Size,
1028
+ *,
1029
+ epsilon: float = 1e-6,
1030
+ dtype: Optional[jax.typing.DTypeLike] = None,
1031
+ use_scale: bool = True,
1032
+ scale_init: Callable = init.Constant(1.0),
1033
+ reduction_axes: Axes = -1,
1034
+ feature_axes: Axes = -1,
1035
+ axis_name: Optional[str] = None,
1036
+ axis_index_groups: Any = None,
1037
+ use_fast_variance: bool = True,
1038
+ param_type: type = NormalizationParamState,
1039
+ ):
1040
+ super().__init__()
1041
+
1042
+ self.in_size = in_size
1043
+ self.out_size = in_size
1044
+
1045
+ # parameters about axis
1046
+ feature_axes = (feature_axes,) if isinstance(feature_axes, int) else feature_axes
1047
+ self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axes)
1048
+ self.reduction_axes = (reduction_axes,) if isinstance(reduction_axes, int) else reduction_axes
1049
+ self.axis_name = axis_name
1050
+ self.axis_index_groups = axis_index_groups
1051
+
1052
+ # variables
1053
+ feature_shape = tuple([(ax if i in self.feature_axes else 1)
1054
+ for i, ax in enumerate(self.in_size)])
1055
+ if use_scale:
1056
+ self.scale = param_type({'scale': init.param(scale_init, feature_shape)})
1057
+ else:
1058
+ self.scale = None
1059
+
1060
+ # parameters
1061
+ self.epsilon = epsilon
1062
+ self.dtype = dtype or environ.dftype()
1063
+ self.use_scale = use_scale
1064
+ self.scale_init = scale_init
1065
+ self.use_fast_variance = use_fast_variance
1066
+
1067
+ def update(self, x, *, mask: Optional[jax.Array] = None):
1068
+ """
1069
+ Apply RMS normalization on the input.
1070
+
1071
+ Parameters
1072
+ ----------
1073
+ x : jax.Array
1074
+ The input array.
1075
+ mask : jax.Array, optional
1076
+ Binary array of shape broadcastable to ``x``, indicating the
1077
+ positions for which normalization should be computed. Default is None.
1078
+
1079
+ Returns
1080
+ -------
1081
+ jax.Array
1082
+ Normalized inputs with the same shape as the input.
1083
+ """
1084
+ mean, var = _compute_stats(
1085
+ x,
1086
+ self.reduction_axes,
1087
+ dtype=self.dtype,
1088
+ axis_name=self.axis_name,
1089
+ axis_index_groups=self.axis_index_groups,
1090
+ use_mean=False,
1091
+ use_fast_variance=self.use_fast_variance,
1092
+ mask=mask,
1093
+ )
1094
+
1095
+ return _normalize(
1096
+ x,
1097
+ mean=mean,
1098
+ var=var,
1099
+ weights=self.scale,
1100
+ reduction_axes=self.reduction_axes,
1101
+ feature_axes=self.feature_axes,
1102
+ dtype=self.dtype,
1103
+ epsilon=self.epsilon,
1104
+ )
1105
+
1106
+
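Because RMSNorm skips the re-centering step, its output without the learned scale is simply ``x / sqrt(mean(x**2) + epsilon)``; a small verification sketch under the default ``epsilon=1e-6``:

.. code-block:: python

    import brainstate as bst
    import jax.numpy as jnp

    x = bst.random.normal(size=(5, 6))
    y = bst.nn.RMSNorm(in_size=(6,), use_scale=False)(x)
    rms = jnp.sqrt(jnp.mean(x ** 2, axis=-1, keepdims=True) + 1e-6)
    print(jnp.allclose(y, x / rms, atol=1e-5))   # True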
1107
+ class GroupNorm(Module):
1108
+ """
1109
+ Group Normalization layer [1]_.
1110
+
1111
+ Group normalization is similar to batch normalization, but statistics are
1112
+ shared across equally-sized groups of channels and not shared across the
1113
+ batch dimension. Thus, group normalization does not depend on the batch
1114
+ composition and does not require maintaining internal state for storing statistics.
1115
+
1116
+ The user should specify either the total number of channel groups (``num_groups``)
1117
+ or the number of channels per group (``group_size``).
1118
+
1119
+ Parameters
1120
+ ----------
1121
+ in_size : tuple of int
1122
+ The input shape, without batch dimension.
1123
+ feature_axis : int or tuple of int, optional
1124
+ The feature axis of the input. Default is -1.
1125
+ num_groups : int, optional
1126
+ The total number of channel groups. The default value of 32 is proposed
1127
+ by the original group normalization paper. Either ``num_groups`` or
1128
+ ``group_size`` must be specified, but not both. Default is 32.
1129
+ group_size : int, optional
1130
+ The number of channels in each group. Either ``num_groups`` or
1131
+ ``group_size`` must be specified, but not both. Default is None.
1132
+ epsilon : float, optional
1133
+ A small value added to variance to avoid division by zero. Default is 1e-6.
1134
+ dtype : jax.typing.DTypeLike, optional
1135
+ The dtype of the result. If None, inferred from input and parameters.
1136
+ Default is None.
1137
+ use_bias : bool, optional
1138
+ If True, bias (beta) is added. Default is True.
1139
+ use_scale : bool, optional
1140
+ If True, multiply by scale (gamma). When the next layer is linear
1141
+ (e.g., nn.relu), this can be disabled. Default is True.
1142
+ bias_init : Callable, optional
1143
+ Initializer for bias parameter. Default is ``init.ZeroInit()``.
1144
+ scale_init : Callable, optional
1145
+ Initializer for scale parameter. Default is ``init.Constant(1.)``.
1146
+ reduction_axes : int or tuple of int, optional
1147
+ List of axes used for computing normalization statistics. Must include
1148
+ the final dimension (feature axis). It is recommended to use negative
1149
+ integers. Default is None.
1150
+ axis_name : str, optional
1151
+ The axis name used to combine batch statistics from multiple devices.
1152
+ See ``jax.pmap`` for details. Default is None.
1153
+ axis_index_groups : sequence, optional
1154
+ Groups of axis indices within the named axis representing subsets of
1155
+ devices to reduce over. For example, ``[[0, 1], [2, 3]]`` would
1156
+ independently normalize over the first two and last two devices.
1157
+ Default is None.
1158
+ use_fast_variance : bool, optional
1159
+ If True, use a faster but less numerically stable calculation for
1160
+ the variance. Default is True.
1161
+
1162
+ Notes
1163
+ -----
1164
+ LayerNorm is a special case of GroupNorm where ``num_groups=1``.
1165
+
1166
+ References
1167
+ ----------
1168
+ .. [1] Wu, Y., & He, K. (2018). Group Normalization.
1169
+ In Proceedings of the European Conference on Computer Vision (ECCV)
1170
+ (pp. 3-19).
1171
+
1172
+ See Also
1173
+ --------
1174
+ LayerNorm : Layer Normalization
1175
+ BatchNorm2d : 2-D Batch Normalization
1176
+
1177
+ Examples
1178
+ --------
1179
+ .. code-block:: python
1180
+
1181
+ >>> import numpy as np
1182
+ >>> import brainstate as bst
1183
+ >>>
1184
+ >>> # Create a GroupNorm layer with 3 groups
1185
+ >>> x = bst.random.normal(size=(3, 4, 5, 6))
1186
+ >>> layer = bst.nn.GroupNorm(x.shape, num_groups=3)
1187
+ >>> y = layer(x)
1188
+ >>>
1189
+ >>> # GroupNorm with num_groups=1 is equivalent to LayerNorm
1190
+ >>> y1 = bst.nn.GroupNorm(x.shape, num_groups=1)(x)
1191
+ >>> y2 = bst.nn.LayerNorm(x.shape, reduction_axes=(1, 2, 3))(x)
1192
+ >>> np.testing.assert_allclose(y1, y2, rtol=1e-5)
1193
+ >>>
1194
+ >>> # Specify group_size instead of num_groups
1195
+ >>> layer = bst.nn.GroupNorm((12,), num_groups=None, group_size=4)
1196
+ """
1197
+
1198
+ def __init__(
1199
+ self,
1200
+ in_size: Size,
1201
+ feature_axis: Axes = -1,
1202
+ num_groups: Optional[int] = 32,
1203
+ group_size: Optional[int] = None,
1204
+ *,
1205
+ epsilon: float = 1e-6,
1206
+ dtype: Optional[jax.typing.DTypeLike] = None,
1207
+ use_bias: bool = True,
1208
+ use_scale: bool = True,
1209
+ bias_init: Callable = init.ZeroInit(),
1210
+ scale_init: Callable = init.Constant(1.),
1211
+ reduction_axes: Optional[Axes] = None,
1212
+ axis_name: Optional[str] = None,
1213
+ axis_index_groups: Any = None,
1214
+ use_fast_variance: bool = True,
1215
+ param_type: type = NormalizationParamState,
1216
+ ):
1217
+ super().__init__()
1218
+
1219
+ self.in_size = in_size
1220
+ self.out_size = in_size
1221
+
1222
+ # parameters about axis
1223
+ feature_axis = (feature_axis,) if isinstance(feature_axis, int) else feature_axis
1224
+ self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axis)
1225
+ self.reduction_axes = (reduction_axes,) if isinstance(reduction_axes, int) else reduction_axes
1226
+ self.axis_name = axis_name
1227
+ self.axis_index_groups = axis_index_groups
1228
+
1229
+ if (num_groups is None and group_size is None) or (
1230
+ num_groups is not None and group_size is not None
1231
+ ):
1232
+ raise ValueError(
1233
+ 'Either `num_groups` or `group_size` should be '
1234
+ 'specified. If `group_size` is to be specified, '
1235
+ 'pass `num_groups=None` as argument to override '
1236
+ 'the default `num_groups` value of 32.'
1237
+ )
1238
+
1239
+ feature_shape = tuple([(ax if i in self.feature_axes else 1)
1240
+ for i, ax in enumerate(self.in_size)])
1241
+ assert len(feature_shape) == 1, 'GroupNorm only supports 1D feature axis.'
1242
+ num_features = feature_shape[0]
1243
+ if group_size is not None:
1244
+ if num_features % group_size != 0:
1245
+ raise ValueError(
1246
+ 'Number of features ({}) is not multiple of the '
1247
+ 'group size ({}).'.format(num_features, group_size)
1248
+ )
1249
+ self.num_groups = num_features // group_size
1250
+ self.group_size = group_size
1251
+ else:
1252
+ if not isinstance(num_groups, int) or num_groups <= 0 or (
1253
+ num_features % num_groups != 0
1254
+ ):
1255
+ raise ValueError(
1256
+ 'Number of groups ({}) does not divide the number'
1257
+ ' of channels ({}).'.format(num_groups, num_features)
1258
+ )
1259
+ self.num_groups = num_groups
1260
+ self.group_size = num_features // num_groups
1261
+
1262
+ # variables
1263
+ weights = dict()
1264
+ if use_scale:
1265
+ weights['scale'] = init.param(scale_init, feature_shape)
1266
+ if use_bias:
1267
+ weights['bias'] = init.param(bias_init, feature_shape)
1268
+ if len(weights):
1269
+ self.weight = param_type(weights)
1270
+ else:
1271
+ self.weight = None
1272
+
1273
+ # parameters
1274
+ self.epsilon = epsilon
1275
+ self.dtype = dtype
1276
+ self.use_bias = use_bias
1277
+ self.use_scale = use_scale
1278
+ self.bias_init = bias_init
1279
+ self.scale_init = scale_init
1280
+ self.use_fast_variance = use_fast_variance
1281
+
1282
+ def update(self, x, *, mask: Optional[jax.Array] = None):
1283
+ """
1284
+ Apply group normalization to the input.
1285
+
1286
+ Parameters
1287
+ ----------
1288
+ x : jax.Array
1289
+ The input of shape ``...C`` where ``C`` is the channels dimension
1290
+ and ``...`` represents an arbitrary number of extra dimensions. If no
1291
+ reduction axes have been specified, all additional dimensions will be
1292
+ used to accumulate statistics apart from the leading dimension which
1293
+ is assumed to represent the batch.
1294
+ mask : jax.Array, optional
1295
+ Binary array of shape broadcastable to ``x``, indicating the
1296
+ positions for which the mean and variance should be computed.
1297
+ Default is None.
1298
+
1299
+ Returns
1300
+ -------
1301
+ jax.Array
1302
+ Normalized inputs with the same shape as the input.
1303
+ """
1304
+ if self.reduction_axes is not None:
1305
+ reduction_axes = self.reduction_axes
1306
+ else:
1307
+ reduction_axes = list(range(1, x.ndim - 1)) + [-1]
1308
+ reduction_axes = _canonicalize_axes(x.ndim, reduction_axes)
1309
+
1310
+ group_shape = x.shape[:-1] + (self.num_groups, self.group_size)
1311
+ if mask is not None:
1312
+ mask = mask.reshape(mask.shape[:-1] + (self.num_groups, self.group_size))
1313
+
1314
+ mean, var = _compute_stats(
1315
+ x.reshape(group_shape),
1316
+ list(reduction_axes[:-1]) + [-1],
1317
+ dtype=self.dtype,
1318
+ axis_name=self.axis_name,
1319
+ axis_index_groups=self.axis_index_groups,
1320
+ use_fast_variance=self.use_fast_variance,
1321
+ mask=mask,
1322
+ )
1323
+ mean = jnp.repeat(mean, self.group_size, axis=1)
1324
+ var = jnp.repeat(var, self.group_size, axis=1)
1325
+ return _normalize(
1326
+ x,
1327
+ mean=mean,
1328
+ var=var,
1329
+ weights=self.weight,
1330
+ reduction_axes=reduction_axes[:-1],
1331
+ feature_axes=self.feature_axes,
1332
+ dtype=self.dtype,
1333
+ epsilon=self.epsilon,
1334
+ )
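The reshape-and-repeat pattern in ``update`` above is the core of group normalization: channels are folded into ``(num_groups, group_size)``, statistics are computed per group, and then broadcast back over each group's channels. A simplified 2-D sketch of that bookkeeping (illustration only, not the layer itself):

.. code-block:: python

    import jax.numpy as jnp

    x = jnp.arange(24., dtype=jnp.float32).reshape(2, 12)   # (batch, channels)
    num_groups, group_size = 3, 4

    g = x.reshape(2, num_groups, group_size)           # fold channels into groups
    mean = g.mean(axis=-1)                              # one statistic per (batch, group)
    mean_full = jnp.repeat(mean, group_size, axis=1)    # back to 12 channels
    print(mean.shape, mean_full.shape)                  # (2, 3) (2, 12)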