brainstate 0.1.8__py2.py3-none-any.whl → 0.1.10__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. brainstate/__init__.py +58 -51
  2. brainstate/_compatible_import.py +148 -148
  3. brainstate/_state.py +1605 -1663
  4. brainstate/_state_test.py +52 -52
  5. brainstate/_utils.py +47 -47
  6. brainstate/augment/__init__.py +30 -30
  7. brainstate/augment/_autograd.py +778 -778
  8. brainstate/augment/_autograd_test.py +1289 -1289
  9. brainstate/augment/_eval_shape.py +99 -99
  10. brainstate/augment/_eval_shape_test.py +38 -38
  11. brainstate/augment/_mapping.py +1060 -1060
  12. brainstate/augment/_mapping_test.py +597 -597
  13. brainstate/augment/_random.py +151 -151
  14. brainstate/compile/__init__.py +38 -38
  15. brainstate/compile/_ad_checkpoint.py +204 -204
  16. brainstate/compile/_ad_checkpoint_test.py +49 -49
  17. brainstate/compile/_conditions.py +256 -256
  18. brainstate/compile/_conditions_test.py +220 -220
  19. brainstate/compile/_error_if.py +92 -92
  20. brainstate/compile/_error_if_test.py +52 -52
  21. brainstate/compile/_jit.py +346 -346
  22. brainstate/compile/_jit_test.py +143 -143
  23. brainstate/compile/_loop_collect_return.py +536 -536
  24. brainstate/compile/_loop_collect_return_test.py +58 -58
  25. brainstate/compile/_loop_no_collection.py +184 -184
  26. brainstate/compile/_loop_no_collection_test.py +50 -50
  27. brainstate/compile/_make_jaxpr.py +888 -888
  28. brainstate/compile/_make_jaxpr_test.py +156 -156
  29. brainstate/compile/_progress_bar.py +202 -202
  30. brainstate/compile/_unvmap.py +159 -159
  31. brainstate/compile/_util.py +147 -147
  32. brainstate/environ.py +563 -563
  33. brainstate/environ_test.py +62 -62
  34. brainstate/functional/__init__.py +27 -26
  35. brainstate/graph/__init__.py +29 -29
  36. brainstate/graph/_graph_node.py +244 -244
  37. brainstate/graph/_graph_node_test.py +73 -73
  38. brainstate/graph/_graph_operation.py +1738 -1738
  39. brainstate/graph/_graph_operation_test.py +563 -563
  40. brainstate/init/__init__.py +26 -26
  41. brainstate/init/_base.py +52 -52
  42. brainstate/init/_generic.py +244 -244
  43. brainstate/init/_random_inits.py +553 -553
  44. brainstate/init/_random_inits_test.py +149 -149
  45. brainstate/init/_regular_inits.py +105 -105
  46. brainstate/init/_regular_inits_test.py +50 -50
  47. brainstate/mixin.py +365 -363
  48. brainstate/mixin_test.py +77 -73
  49. brainstate/nn/__init__.py +135 -131
  50. brainstate/{functional → nn}/_activations.py +808 -813
  51. brainstate/{functional → nn}/_activations_test.py +331 -331
  52. brainstate/nn/_collective_ops.py +514 -514
  53. brainstate/nn/_collective_ops_test.py +43 -43
  54. brainstate/nn/_common.py +178 -178
  55. brainstate/nn/_conv.py +501 -501
  56. brainstate/nn/_conv_test.py +238 -238
  57. brainstate/nn/_delay.py +588 -502
  58. brainstate/nn/_delay_test.py +238 -184
  59. brainstate/nn/_dropout.py +426 -426
  60. brainstate/nn/_dropout_test.py +100 -100
  61. brainstate/nn/_dynamics.py +1343 -1343
  62. brainstate/nn/_dynamics_test.py +78 -78
  63. brainstate/nn/_elementwise.py +1119 -1119
  64. brainstate/nn/_elementwise_test.py +169 -169
  65. brainstate/nn/_embedding.py +58 -58
  66. brainstate/nn/_exp_euler.py +92 -92
  67. brainstate/nn/_exp_euler_test.py +35 -35
  68. brainstate/nn/_fixedprob.py +239 -239
  69. brainstate/nn/_fixedprob_test.py +114 -114
  70. brainstate/nn/_inputs.py +608 -608
  71. brainstate/nn/_linear.py +424 -424
  72. brainstate/nn/_linear_mv.py +83 -83
  73. brainstate/nn/_linear_mv_test.py +120 -120
  74. brainstate/nn/_linear_test.py +107 -107
  75. brainstate/nn/_ltp.py +28 -28
  76. brainstate/nn/_module.py +377 -377
  77. brainstate/nn/_module_test.py +40 -40
  78. brainstate/nn/_neuron.py +705 -705
  79. brainstate/nn/_neuron_test.py +161 -161
  80. brainstate/nn/_normalizations.py +975 -918
  81. brainstate/nn/_normalizations_test.py +73 -73
  82. brainstate/{functional → nn}/_others.py +46 -46
  83. brainstate/nn/_poolings.py +1177 -1177
  84. brainstate/nn/_poolings_test.py +217 -217
  85. brainstate/nn/_projection.py +486 -486
  86. brainstate/nn/_rate_rnns.py +554 -554
  87. brainstate/nn/_rate_rnns_test.py +63 -63
  88. brainstate/nn/_readout.py +209 -209
  89. brainstate/nn/_readout_test.py +53 -53
  90. brainstate/nn/_stp.py +236 -236
  91. brainstate/nn/_synapse.py +505 -505
  92. brainstate/nn/_synapse_test.py +131 -131
  93. brainstate/nn/_synaptic_projection.py +423 -423
  94. brainstate/nn/_synouts.py +162 -162
  95. brainstate/nn/_synouts_test.py +57 -57
  96. brainstate/nn/_utils.py +89 -89
  97. brainstate/nn/metrics.py +388 -388
  98. brainstate/optim/__init__.py +38 -38
  99. brainstate/optim/_base.py +64 -64
  100. brainstate/optim/_lr_scheduler.py +448 -448
  101. brainstate/optim/_lr_scheduler_test.py +50 -50
  102. brainstate/optim/_optax_optimizer.py +152 -152
  103. brainstate/optim/_optax_optimizer_test.py +53 -53
  104. brainstate/optim/_sgd_optimizer.py +1104 -1104
  105. brainstate/random/__init__.py +24 -24
  106. brainstate/random/_rand_funs.py +3616 -3616
  107. brainstate/random/_rand_funs_test.py +567 -567
  108. brainstate/random/_rand_seed.py +210 -210
  109. brainstate/random/_rand_seed_test.py +48 -48
  110. brainstate/random/_rand_state.py +1409 -1409
  111. brainstate/random/_random_for_unit.py +52 -52
  112. brainstate/surrogate.py +1957 -1957
  113. brainstate/transform.py +23 -23
  114. brainstate/typing.py +304 -304
  115. brainstate/util/__init__.py +50 -50
  116. brainstate/util/caller.py +98 -98
  117. brainstate/util/error.py +55 -55
  118. brainstate/util/filter.py +469 -469
  119. brainstate/util/others.py +540 -540
  120. brainstate/util/pretty_pytree.py +945 -945
  121. brainstate/util/pretty_pytree_test.py +159 -159
  122. brainstate/util/pretty_repr.py +328 -328
  123. brainstate/util/pretty_table.py +2954 -2954
  124. brainstate/util/scaling.py +258 -258
  125. brainstate/util/struct.py +523 -523
  126. {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/METADATA +91 -99
  127. brainstate-0.1.10.dist-info/RECORD +130 -0
  128. {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/WHEEL +1 -1
  129. {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info/licenses}/LICENSE +202 -202
  130. brainstate/functional/_normalization.py +0 -81
  131. brainstate/functional/_spikes.py +0 -204
  132. brainstate-0.1.8.dist-info/RECORD +0 -132
  133. {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/top_level.txt +0 -0
brainstate/nn/_normalizations.py
@@ -1,918 +1,975 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
- from typing import Callable, Union, Sequence, Optional, Any
19
-
20
- import jax
21
- import jax.numpy as jnp
22
-
23
- from brainstate import environ, init
24
- from brainstate._state import ParamState, BatchState
25
- from brainstate.typing import DTypeLike, ArrayLike, Size, Axes
26
- from ._module import Module
27
-
28
- __all__ = [
29
- 'BatchNorm0d',
30
- 'BatchNorm1d',
31
- 'BatchNorm2d',
32
- 'BatchNorm3d',
33
- 'LayerNorm',
34
- 'RMSNorm',
35
- 'GroupNorm',
36
- ]
37
-
38
-
39
- def canonicalize_dtype(
40
- *args,
41
- dtype: jax.typing.DTypeLike | None = None,
42
- inexact: bool = True
43
- ) -> jax.typing.DTypeLike:
44
- """Canonicalize an optional dtype to the definitive dtype.
45
-
46
- If the ``dtype`` is None, this function will infer the dtype from the
47
- input arguments using ``jnp.result_type``. If it is not None, it will be
48
- returned unmodified, or an exception is raised if the dtype
49
- is invalid.
50
-
51
- Args:
52
- *args: JAX array compatible values. None values
53
- are ignored.
54
- dtype: Optional dtype override. If specified the arguments are cast to
55
- the specified dtype instead and dtype inference is disabled.
56
- inexact: When True, the output dtype must be a subdtype
57
- of `jnp.inexact`. Inexact dtypes are real or complex floating points. This
58
- is useful when you want to apply operations that don't work directly on
59
- integers like taking a mean for example.
60
- Returns:
61
- The dtype that *args should be cast to.
62
- """
63
- if dtype is None:
64
- args_filtered = [jnp.asarray(x) for x in args if x is not None]
65
- dtype = jnp.result_type(*args_filtered)
66
- if inexact and not jnp.issubdtype(dtype, jnp.inexact):
67
- dtype = jnp.promote_types(jnp.float32, dtype)
68
- if inexact and not jnp.issubdtype(dtype, jnp.inexact):
69
- raise ValueError(f'Dtype must be inexact: {dtype}')
70
- return dtype
71
-
72
-
73
- def _canonicalize_axes(ndim: int, feature_axes: Sequence[int]):
74
- axes = []
75
- for axis in feature_axes:
76
- if axis < 0:
77
- axis += ndim
78
- if axis < 0 or axis >= ndim:
79
- raise ValueError(f'Invalid axis {axis} for {ndim}D input')
80
- axes.append(axis)
81
- return tuple(axes)
82
-
83
-
84
- def _abs_sq(x):
85
- """Computes the elementwise square of the absolute value |x|^2."""
86
- if jnp.iscomplexobj(x):
87
- return jax.lax.square(jax.lax.real(x)) + jax.lax.square(jax.lax.imag(x))
88
- else:
89
- return jax.lax.square(x)
90
-
91
-
92
- class NormalizationParamState(ParamState):
93
- # This is a dummy class kept for compatibility with the
94
- # usage of `ETraceParam` for the layers in "brainetrace"
95
- def execute(self, x):
96
- param = self.value
97
- if 'scale' in param:
98
- x = x * param['scale']
99
- if 'bias' in param:
100
- x = x + param['bias']
101
- return x
102
-
103
-
104
- def _compute_stats(
105
- x: ArrayLike,
106
- axes: Sequence[int],
107
- dtype: DTypeLike,
108
- axis_name: Optional[str] = None,
109
- axis_index_groups: Optional[Sequence[int]] = None,
110
- use_mean: bool = True,
111
- use_fast_variance: bool = True,
112
- mask: Optional[jax.Array] = None,
113
- ):
114
- """
115
- Computes mean and variance statistics.
116
-
117
- This implementation takes care of a few important details:
118
- - Computes in float32 precision for stability in half precision training.
119
- - If ``use_fast_variance`` is ``True``, mean and variance are computed using
120
- Var = E[|x|^2] - |E[x]|^2, instead of Var = E[|x - E[x]|^2], in a single XLA fusion.
121
- - Clips negative variances to zero which can happen due to
122
- roundoff errors. This avoids downstream NaNs.
123
- - Supports averaging across a parallel axis and subgroups of a parallel axis
124
- with a single `lax.pmean` call to avoid latency.
125
-
126
- Arguments:
127
- x: Input array.
128
- axes: The axes in ``x`` to compute mean and variance statistics for.
129
- dtype: Optional dtype specifying the minimal precision. Statistics
130
- are always at least float32 for stability (default: dtype of x).
131
- axis_name: Optional name for the pmapped axis to compute mean over. Note,
132
- this is only used for pmap and shard map. For SPMD jit, you do not need to
133
- manually synchronize. Just make sure that the axes are correctly annotated
134
- and XLA:SPMD will insert the necessary collectives.
135
- axis_index_groups: Optional axis indices.
136
- use_mean: If true, calculate the mean from the input and use it when
137
- computing the variance. If false, set the mean to zero and compute
138
- the variance without subtracting the mean.
139
- use_fast_variance: If true, use a faster, but less numerically stable,
140
- calculation for the variance.
141
- mask: Binary array of shape broadcastable to ``inputs`` tensor, indicating
142
- the positions for which the mean and variance should be computed.
143
-
144
- Returns:
145
- A pair ``(mean, var)``.
146
- """
147
- if dtype is None:
148
- dtype = jax.numpy.result_type(x)
149
- # promote x to at least float32, this avoids half precision computation
150
- # but preserves double or complex floating points
151
- dtype = jax.numpy.promote_types(dtype, jnp.float32)
152
- x = jnp.asarray(x, dtype)
153
- axes = _canonicalize_axes(x.ndim, axes)
154
-
155
- def maybe_distributed_mean(*xs, mask=None):
156
- mus = tuple(x.mean(axes, where=mask) for x in xs)
157
- if axis_name is None:
158
- return mus if len(xs) > 1 else mus[0]
159
- else:
160
- # In the distributed case we stack multiple arrays to speed comms.
161
- if len(xs) > 1:
162
- reduced_mus = jax.lax.pmean(
163
- jnp.stack(mus, axis=0),
164
- axis_name,
165
- axis_index_groups=axis_index_groups,
166
- )
167
- return tuple(reduced_mus[i] for i in range(len(xs)))
168
- else:
169
- return jax.lax.pmean(
170
- mus[0],
171
- axis_name,
172
- axis_index_groups=axis_index_groups
173
- )
174
-
175
- if use_mean:
176
- if use_fast_variance:
177
- mu, mu2 = maybe_distributed_mean(x, _abs_sq(x), mask=mask)
178
- # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due
179
- # to floating point round-off errors.
180
- var = jnp.maximum(0.0, mu2 - _abs_sq(mu))
181
- else:
182
- mu = maybe_distributed_mean(x, mask=mask)
183
- var = maybe_distributed_mean(_abs_sq(x - jnp.expand_dims(mu, axes)), mask=mask)
184
- else:
185
- var = maybe_distributed_mean(_abs_sq(x), mask=mask)
186
- mu = jnp.zeros_like(var)
187
- return mu, var
188
-
189
-
190
- def _normalize(
191
- x: ArrayLike,
192
- mean: Optional[ArrayLike],
193
- var: Optional[ArrayLike],
194
- weights: Optional[NormalizationParamState],
195
- reduction_axes: Axes,
196
- feature_axes: Axes,
197
- dtype: DTypeLike,
198
- epsilon: jax.typing.ArrayLike,
199
- ):
200
- """Normalizes the input of a normalization layer and optionally applies a learned scale and bias.
201
-
202
- Arguments:
203
- x: The input.
204
- mean: Mean to use for normalization.
205
- var: Variance to use for normalization.
206
- weights: The scale and bias parameters.
207
- reduction_axes: The axes in ``x`` to reduce.
208
- feature_axes: The feature axes to apply the scale and bias.
209
- dtype: The dtype of the result (default: infer from input and params).
210
- epsilon: Normalization epsilon.
211
-
212
- Returns:
213
- The normalized input.
214
- """
215
- if mean is not None:
216
- assert var is not None, 'mean and val must be both None or not None.'
217
- reduction_axes = _canonicalize_axes(x.ndim, reduction_axes)
218
- feature_axes = _canonicalize_axes(x.ndim, feature_axes)
219
- stats_shape = list(x.shape)
220
- for axis in reduction_axes:
221
- stats_shape[axis] = 1
222
- mean = mean.reshape(stats_shape)
223
- var = var.reshape(stats_shape)
224
- feature_shape = [1] * x.ndim
225
- for ax in feature_axes:
226
- feature_shape[ax] = x.shape[ax]
227
- y = x - mean
228
- mul = jax.lax.rsqrt(var + epsilon)
229
- y = y * mul
230
- if weights is not None:
231
- y = weights.execute(y)
232
- dtype = canonicalize_dtype(x, *jax.tree.leaves(weights.value), dtype=dtype)
233
- else:
234
- assert var is None, 'mean and val must be both None or not None.'
235
- assert weights is None, 'scale and bias are not supported without mean and val'
236
- y = x
237
- return jnp.asarray(y, dtype)
238
-
239
-
240
- class _BatchNorm(Module):
241
- __module__ = 'brainstate.nn'
242
- num_spatial_dims: int
243
-
244
- def __init__(
245
- self,
246
- in_size: Size,
247
- feature_axis: Axes = -1,
248
- *,
249
- track_running_stats: bool = True,
250
- epsilon: float = 1e-5,
251
- momentum: float = 0.99,
252
- affine: bool = True,
253
- bias_initializer: Union[ArrayLike, Callable] = init.Constant(0.),
254
- scale_initializer: Union[ArrayLike, Callable] = init.Constant(1.),
255
- axis_name: Optional[Union[str, Sequence[str]]] = None,
256
- axis_index_groups: Optional[Sequence[Sequence[int]]] = None,
257
- use_fast_variance: bool = True,
258
- name: Optional[str] = None,
259
- dtype: Any = None,
260
- param_type: type = NormalizationParamState,
261
- mean_type: type = BatchState,
262
- ):
263
- super().__init__(name=name)
264
-
265
- # parameters
266
- self.in_size = in_size
267
- self.out_size = in_size
268
- self.affine = affine
269
- self.bias_initializer = bias_initializer
270
- self.scale_initializer = scale_initializer
271
- self.dtype = dtype or environ.dftype()
272
- self.track_running_stats = track_running_stats
273
- self.momentum = jnp.asarray(momentum, dtype=self.dtype)
274
- self.epsilon = jnp.asarray(epsilon, dtype=self.dtype)
275
- self.use_fast_variance = use_fast_variance
276
-
277
- # parameters about axis
278
- feature_axis = (feature_axis,) if isinstance(feature_axis, int) else feature_axis
279
- self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axis)
280
- self.axis_name = axis_name
281
- self.axis_index_groups = axis_index_groups
282
-
283
- # variables
284
- feature_shape = tuple([(ax if i in self.feature_axes else 1)
285
- for i, ax in enumerate(self.in_size)])
286
- if self.track_running_stats:
287
- self.running_mean = mean_type(jnp.zeros(feature_shape, dtype=self.dtype))
288
- self.running_var = mean_type(jnp.ones(feature_shape, dtype=self.dtype))
289
- else:
290
- self.running_mean = None
291
- self.running_var = None
292
-
293
- # parameters
294
- if self.affine:
295
- assert track_running_stats, "Affine parameters are not needed when track_running_stats is False."
296
- bias = init.param(self.bias_initializer, feature_shape)
297
- scale = init.param(self.scale_initializer, feature_shape)
298
- self.weight = param_type(dict(bias=bias, scale=scale))
299
- else:
300
- self.weight = None
301
-
302
- def update(self, x, mask: Optional[jax.Array] = None):
303
- # input shape and batch mode or not
304
- if x.ndim == self.num_spatial_dims + 2:
305
- x_shape = x.shape[1:]
306
- batch = True
307
- elif x.ndim == self.num_spatial_dims + 1:
308
- x_shape = x.shape
309
- batch = False
310
- else:
311
- raise ValueError(f"expected {self.num_spatial_dims + 2}D (with batch) or "
312
- f"{self.num_spatial_dims + 1}D (without batch) input (got {x.ndim}D input, {x.shape})")
313
- if self.in_size != x_shape:
314
- raise ValueError(f"The expected input shape is {self.in_size}, while we got {x_shape}.")
315
-
316
- # reduce the feature axis
317
- if batch:
318
- reduction_axes = tuple(i for i in range(x.ndim) if (i - 1) not in self.feature_axes)
319
- else:
320
- reduction_axes = tuple(i for i in range(x.ndim) if i not in self.feature_axes)
321
-
322
- # fitting phase
323
- fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
324
-
325
- # compute the running mean and variance
326
- if self.track_running_stats:
327
- if fit_phase:
328
- mean, var = _compute_stats(
329
- x,
330
- reduction_axes,
331
- dtype=self.dtype,
332
- axis_name=self.axis_name,
333
- axis_index_groups=self.axis_index_groups,
334
- use_fast_variance=self.use_fast_variance,
335
- mask=mask,
336
- )
337
- self.running_mean.value = self.momentum * self.running_mean.value + (1 - self.momentum) * mean
338
- self.running_var.value = self.momentum * self.running_var.value + (1 - self.momentum) * var
339
- else:
340
- mean = self.running_mean.value
341
- var = self.running_var.value
342
- else:
343
- mean, var = None, None
344
-
345
- # normalize
346
- return _normalize(
347
- x,
348
- mean=mean,
349
- var=var,
350
- weights=self.weight,
351
- reduction_axes=reduction_axes,
352
- feature_axes=self.feature_axes,
353
- dtype=self.dtype,
354
- epsilon=self.epsilon
355
- )
356
-
357
-
358
- class BatchNorm0d(_BatchNorm):
359
- r"""0-D batch normalization [1]_.
360
-
361
- The data should be of `(b, c)`, where `b` is the batch dimension, and `c` is the channel dimension.
362
-
363
- %s
364
- """
365
- __module__ = 'brainstate.nn'
366
- num_spatial_dims: int = 0
367
-
368
-
369
- class BatchNorm1d(_BatchNorm):
370
- r"""1-D batch normalization [1]_.
371
-
372
- The data should be of `(b, l, c)`, where `b` is the batch dimension,
373
- `l` is the layer dimension, and `c` is the channel dimension.
374
-
375
- %s
376
- """
377
- __module__ = 'brainstate.nn'
378
- num_spatial_dims: int = 1
379
-
380
-
381
- class BatchNorm2d(_BatchNorm):
382
- r"""2-D batch normalization [1]_.
383
-
384
- The data should be of `(b, h, w, c)`, where `b` is the batch dimension,
385
- `h` is the height dimension, `w` is the width dimension, and `c` is the
386
- channel dimension.
387
-
388
- %s
389
- """
390
- __module__ = 'brainstate.nn'
391
- num_spatial_dims: int = 2
392
-
393
-
394
- class BatchNorm3d(_BatchNorm):
395
- r"""3-D batch normalization [1]_.
396
-
397
- The data should be of `(b, h, w, d, c)`, where `b` is the batch dimension,
398
- `h` is the height dimension, `w` is the width dimension, `d` is the depth
399
- dimension, and `c` is the channel dimension.
400
-
401
- %s
402
- """
403
- __module__ = 'brainstate.nn'
404
- num_spatial_dims: int = 3
405
-
406
-
407
- _bn_doc = r'''
408
-
409
- This layer aims to reduce the internal covariate shift of data. It
410
- normalizes a batch of data by fixing the mean and variance of inputs
411
- on each feature (channel). Most commonly, the first axis of the data
412
- is the batch, and the last is the channel. However, users can specify
413
- the axes to be normalized.
414
-
415
- .. math::
416
- y=\frac{x-\mathrm{E}[x]}{\sqrt{\operatorname{Var}[x]+\epsilon}} * \gamma+\beta
417
-
418
- .. note::
419
- This :attr:`momentum` argument is different from one used in optimizer
420
- classes and the conventional notion of momentum. Mathematically, the
421
- update rule for running statistics here is
422
- :math:`\hat{x}_\text{new} = \text{momentum} \times \hat{x} + (1-\text{momentum}) \times x_t`,
423
- where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
424
- new observed value.
425
-
426
- Parameters
427
- ----------
428
- in_size: sequence of int
429
- The input shape, without batch size.
430
- feature_axis: int, tuple, list
431
- The feature or non-batch axis of the input.
432
- track_running_stats: bool
433
- A boolean value that when set to ``True``, this module tracks the running mean and variance,
434
- and when set to ``False``, this module does not track such statistics, and initializes
435
- statistics buffers ``running_mean`` and ``running_var`` as ``None``. When these buffers are ``None``,
436
- this module always uses batch statistics in both training and eval modes. Default: ``True``.
437
- momentum: float
438
- The value used for the ``running_mean`` and ``running_var`` computation. Default: 0.99
439
- epsilon: float
440
- A value added to the denominator for numerical stability. Default: 1e-5
441
- affine: bool
442
- A boolean value that when set to ``True``, this module has
443
- learnable affine parameters. Default: ``True``
444
- bias_initializer: ArrayLike, Callable
445
- An initializer generating the original translation matrix. If not ``None``, bias (beta) is added.
446
- Default: ``init.Constant(0.)``
447
- scale_initializer: ArrayLike, Callable
448
- An initializer generating the original scaling matrix. If not ``None``, multiply by scale (gamma).
449
- Default: ``init.Constant(1.)``
450
- axis_name: optional, str, sequence of str
451
- If not ``None``, it should be a string (or sequence of
452
- strings) representing the axis name(s) over which this module is being
453
- run within a jax map (e.g. ``jax.pmap`` or ``jax.vmap``). Supplying this
454
- argument means that batch statistics are calculated across all replicas
455
- on the named axes.
456
- axis_index_groups: optional, sequence
457
- Specifies how devices are grouped. Valid
458
- only within ``jax.pmap`` collectives.
459
- Groups of axis indices within that named axis
460
- representing subsets of devices to reduce over (default: None). For
461
- example, `[[0, 1], [2, 3]]` would independently batch-normalize over
462
- the examples on the first two and last two devices. See `jax.lax.psum`
463
- for more details.
464
- use_fast_variance: If true, use a faster, but less numerically stable,
465
- calculation for the variance.
466
-
467
-
468
- References
469
- ----------
470
- .. [1] Ioffe, Sergey and Christian Szegedy. “Batch Normalization: Accelerating Deep Network Training
471
- by Reducing Internal Covariate Shift.” ArXiv abs/1502.03167 (2015): n. pag.
472
-
473
- '''
474
-
475
- BatchNorm1d.__doc__ = BatchNorm1d.__doc__ % _bn_doc
476
- BatchNorm2d.__doc__ = BatchNorm2d.__doc__ % _bn_doc
477
- BatchNorm3d.__doc__ = BatchNorm3d.__doc__ % _bn_doc
478
-
479
-
480
- class LayerNorm(Module):
481
- """
482
- Layer normalization (https://arxiv.org/abs/1607.06450).
483
-
484
- LayerNorm normalizes the activations of the layer for each given example in a
485
- batch independently, rather than across a batch like Batch Normalization.
486
- i.e. applies a transformation that maintains the mean activation within
487
- each example close to 0 and the activation standard deviation close to 1.
488
-
489
- Example usage::
490
-
491
- >>> import brainstate as brainstate
492
- >>> x = brainstate.random.normal(size=(3, 4, 5, 6))
493
- >>> layer = brainstate.nn.LayerNorm(x.shape)
494
- >>> layer.states()
495
- >>> y = layer(x)
496
-
497
- Attributes:
498
- in_size: The input shape, without batch size.
499
- epsilon: A small float added to variance to avoid dividing by zero.
500
- dtype: the dtype of the result (default: infer from input and params).
501
- use_bias: If True, bias (beta) is added.
502
- use_scale: If True, multiply by scale (gamma). When the next layer is linear
503
- (also e.g. nn.relu), this can be disabled since the scaling will be done
504
- by the next layer.
505
- bias_init: Initializer for bias, by default, zero.
506
- scale_init: Initializer for scale, by default, one.
507
- reduction_axes: Axes for computing normalization statistics. It is recommended
508
- to use the negative integer, since when the batch dimension is used,
509
- the reduction_axes may be wrong when using the positive integer.
510
- feature_axes: Feature axes for learned bias and scaling.
511
- axis_name: the axis name used to combine batch statistics from multiple
512
- devices. See ``jax.pmap`` for a description of axis names (default: None).
513
- This is only needed if the model is subdivided across devices, i.e. the
514
- array being normalized is sharded across devices within a pmap.
515
- axis_index_groups: groups of axis indices within that named axis
516
- representing subsets of devices to reduce over (default: None). For
517
- example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over
518
- the examples on the first two and last two devices. See ``jax.lax.psum``
519
- for more details.
520
- use_fast_variance: If true, use a faster, but less numerically stable,
521
- calculation for the variance.
522
- """
523
-
524
- def __init__(
525
- self,
526
- in_size: Size,
527
- reduction_axes: Axes = -1,
528
- feature_axes: Axes = -1,
529
- *,
530
- epsilon: float = 1e-6,
531
- use_bias: bool = True,
532
- use_scale: bool = True,
533
- bias_init: Callable = init.ZeroInit(),
534
- scale_init: Callable = init.Constant(1.0),
535
- axis_name: Optional[str] = None,
536
- axis_index_groups: Any = None,
537
- use_fast_variance: bool = True,
538
- dtype: Optional[jax.typing.DTypeLike] = None,
539
- param_type: type = NormalizationParamState,
540
- ):
541
- super().__init__()
542
-
543
- self.in_size = in_size
544
- self.out_size = in_size
545
-
546
- # parameters about axis
547
- feature_axes = (feature_axes,) if isinstance(feature_axes, int) else feature_axes
548
- self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axes)
549
- self.reduction_axes = (reduction_axes,) if isinstance(reduction_axes, int) else reduction_axes
550
- self.axis_name = axis_name
551
- self.axis_index_groups = axis_index_groups
552
-
553
- # variables
554
- feature_shape = tuple([(ax if i in self.feature_axes else 1)
555
- for i, ax in enumerate(self.in_size)])
556
-
557
- weights = dict()
558
- if use_scale:
559
- weights['scale'] = init.param(scale_init, feature_shape)
560
- if use_bias:
561
- weights['bias'] = init.param(bias_init, feature_shape)
562
- if len(weights):
563
- self.weight = param_type(weights)
564
- else:
565
- self.weight = None
566
-
567
- # parameters
568
- self.epsilon = epsilon
569
- self.dtype = dtype or environ.dftype()
570
- self.use_bias = use_bias
571
- self.use_scale = use_scale
572
- self.bias_init = bias_init
573
- self.scale_init = scale_init
574
- self.use_fast_variance = use_fast_variance
575
-
576
- def update(self, x, *, mask: Optional[jax.Array] = None):
577
- """Applies layer normalization on the input.
578
-
579
- Args:
580
- x: the inputs
581
-
582
- Returns:
583
- Normalized inputs (the same shape as inputs).
584
- """
585
- mean, var = _compute_stats(
586
- x,
587
- self.reduction_axes,
588
- dtype=self.dtype,
589
- axis_name=self.axis_name,
590
- axis_index_groups=self.axis_index_groups,
591
- use_fast_variance=self.use_fast_variance,
592
- mask=mask,
593
- )
594
-
595
- return _normalize(
596
- x,
597
- mean=mean,
598
- var=var,
599
- weights=self.weight,
600
- reduction_axes=self.reduction_axes,
601
- feature_axes=self.feature_axes,
602
- dtype=self.dtype,
603
- epsilon=self.epsilon,
604
- )
605
-
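
To complement the docstring example above, a minimal usage sketch (shapes illustrative) that normalizes only over the last axis, so the same ``LayerNorm`` instance also handles batched input; ``brainstate.random.normal`` follows the usage shown in the docstring.

import brainstate

x = brainstate.random.normal(size=(8, 16, 32))   # (batch, time, features)
layer = brainstate.nn.LayerNorm((16, 32), reduction_axes=-1, feature_axes=-1)
y = layer(x)
print(y.shape)   # (8, 16, 32); mean ~0 and std ~1 along the last axis
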
606
-
607
- class RMSNorm(Module):
608
- """
609
- RMS Layer normalization (https://arxiv.org/abs/1910.07467).
610
-
611
- RMSNorm normalizes the activations of the layer for each given example in a
612
- batch independently, rather than across a batch like Batch Normalization.
613
- Unlike LayerNorm which re-centers the mean to be 0 and normalizes by the
614
- standard deviation of the activations, RMSNorm does not re-center at all
615
- and instead normalizes by the root mean square of the activations.
616
-
617
- Example usage::
618
-
619
- >>> import brainstate as brainstate
620
- >>> x = brainstate.random.normal(size=(5, 6))
621
- >>> layer = brainstate.nn.RMSNorm((6,))
622
- >>> layer.states()
623
- >>> y = layer(x)
624
-
625
- Attributes:
626
- in_size: The input shape, without batch size.
627
- epsilon: A small float added to variance to avoid dividing by zero.
628
- dtype: the dtype of the result (default: infer from input and params).
629
- use_scale: If True, multiply by scale (gamma). When the next layer is linear
630
- (also e.g. nn.relu), this can be disabled since the scaling will be done
631
- by the next layer.
632
- scale_init: Initializer for scale, by default, one.
633
- reduction_axes: Axes for computing normalization statistics. It is recommended
634
- to use the negative integer, since when the batch dimension is used,
635
- the reduction_axes may be wrong when using the positive integer.
636
- feature_axes: Feature axes for learned bias and scaling.
637
- axis_name: the axis name used to combine batch statistics from multiple
638
- devices. See ``jax.pmap`` for a description of axis names (default: None).
639
- This is only needed if the model is subdivided across devices, i.e. the
640
- array being normalized is sharded across devices within a pmap.
641
- axis_index_groups: groups of axis indices within that named axis
642
- representing subsets of devices to reduce over (default: None). For
643
- example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over
644
- the examples on the first two and last two devices. See ``jax.lax.psum``
645
- for more details.
646
- use_fast_variance: If true, use a faster, but less numerically stable,
647
- calculation for the variance.
648
- """
649
-
650
- def __init__(
651
- self,
652
- in_size: Size,
653
- *,
654
- epsilon: float = 1e-6,
655
- dtype: Optional[jax.typing.DTypeLike] = None,
656
- use_scale: bool = True,
657
- scale_init: Callable = init.Constant(1.0),
658
- reduction_axes: Axes = -1,
659
- feature_axes: Axes = -1,
660
- axis_name: Optional[str] = None,
661
- axis_index_groups: Any = None,
662
- use_fast_variance: bool = True,
663
- param_type: type = NormalizationParamState,
664
- ):
665
- super().__init__()
666
-
667
- self.in_size = in_size
668
- self.out_size = in_size
669
-
670
- # parameters about axis
671
- feature_axes = (feature_axes,) if isinstance(feature_axes, int) else feature_axes
672
- self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axes)
673
- self.reduction_axes = (reduction_axes,) if isinstance(reduction_axes, int) else reduction_axes
674
- self.axis_name = axis_name
675
- self.axis_index_groups = axis_index_groups
676
-
677
- # variables
678
- feature_shape = tuple([(ax if i in self.feature_axes else 1)
679
- for i, ax in enumerate(self.in_size)])
680
- if use_scale:
681
- self.scale = param_type({'scale': init.param(scale_init, feature_shape)})
682
- else:
683
- self.scale = None
684
-
685
- # parameters
686
- self.epsilon = epsilon
687
- self.dtype = dtype or environ.dftype()
688
- self.use_scale = use_scale
689
- self.scale_init = scale_init
690
- self.use_fast_variance = use_fast_variance
691
-
692
- def update(self, x, *, mask: Optional[jax.Array] = None):
693
- """Applies layer normalization on the input.
694
-
695
- Args:
696
- x: the inputs
697
- mask: the mask
698
-
699
- Returns:
700
- Normalized inputs (the same shape as inputs).
701
- """
702
- mean, var = _compute_stats(
703
- x,
704
- self.reduction_axes,
705
- dtype=self.dtype,
706
- axis_name=self.axis_name,
707
- axis_index_groups=self.axis_index_groups,
708
- use_mean=False,
709
- use_fast_variance=self.use_fast_variance,
710
- mask=mask,
711
- )
712
-
713
- return _normalize(
714
- x,
715
- mean=mean,
716
- var=var,
717
- weights=self.scale,
718
- reduction_axes=self.reduction_axes,
719
- feature_axes=self.feature_axes,
720
- dtype=self.dtype,
721
- epsilon=self.epsilon,
722
- )
723
-
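
A short sketch contrasting ``RMSNorm`` with ``LayerNorm`` on the same data: RMSNorm rescales by the root mean square without re-centering (a 1-element ``in_size`` tuple for the feature dimension is assumed).

import brainstate

x = brainstate.random.normal(size=(5, 6))
rms = brainstate.nn.RMSNorm((6,))(x)     # divides by sqrt(mean(x**2)) along the last axis
ln = brainstate.nn.LayerNorm((6,))(x)    # subtracts the mean, then divides by the std
print(rms.shape, ln.shape)               # (5, 6) (5, 6)
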
724
-
725
- class GroupNorm(Module):
726
- """
727
- Group normalization (arxiv.org/abs/1803.08494).
728
-
729
- This op is similar to batch normalization, but statistics are shared across
730
- equally-sized groups of channels and not shared across batch dimension.
731
- Thus, group normalization does not depend on the batch composition and does
732
- not require maintaining internal state for storing statistics.
733
- The user should either specify the total number of channel groups or the
734
- number of channels per group.
735
-
736
- .. note::
737
- LayerNorm is a special case of GroupNorm where ``num_groups=1``.
738
-
739
- Example usage::
740
-
741
- >>> import numpy as np
742
- >>> import brainstate as brainstate
743
- ...
744
- >>> x = brainstate.random.normal(size=(3, 4, 5, 6))
745
- >>> layer = brainstate.nn.GroupNorm(x.shape, num_groups=3)
746
- >>> layer.states()
747
- >>> y = layer(x)
748
- >>> y = brainstate.nn.GroupNorm(x.shape, num_groups=1)(x)
749
- >>> y2 = brainstate.nn.LayerNorm(x.shape, reduction_axes=(1, 2, 3))(x)
750
- >>> np.testing.assert_allclose(y, y2)
751
-
752
- Attributes:
753
- in_size: The input shape, without batch size.
754
- num_groups: the total number of channel groups. The default value of 32 is
755
- proposed by the original group normalization paper.
756
- group_size: the number of channels in a group.
757
- epsilon: A small float added to variance to avoid dividing by zero.
758
- dtype: the dtype of the result (default: infer from input and params).
759
- use_bias: If True, bias (beta) is added.
760
- use_scale: If True, multiply by scale (gamma). When the next layer is linear
761
- (also e.g. nn.relu), this can be disabled since the scaling will be done
762
- by the next layer.
763
- bias_init: Initializer for bias, by default, zero.
764
- scale_init: Initializer for scale, by default, one.
765
- reduction_axes: List of axes used for computing normalization statistics.
766
- This list must include the final dimension, which is assumed to be the
767
- feature axis. Furthermore, if the input used at call time has additional
768
- leading axes compared to the data used for initialisation, for example due
769
- to batching, then the reduction axes need to be defined explicitly.
770
- It is recommended to use the negative integer, since when the batch dimension is used,
771
- the reduction_axes may be wrong when using the positive integer.
772
- axis_name: the axis name used to combine batch statistics from multiple
773
- devices. See ``jax.pmap`` for a description of axis names (default: None).
774
- This is only needed if the model is subdivided across devices, i.e. the
775
- array being normalized is sharded across devices within a pmap or shard
776
- map. For SPMD jit, you do not need to manually synchronize. Just make sure
777
- that the axes are correctly annotated and XLA:SPMD will insert the
778
- necessary collectives.
779
- axis_index_groups: groups of axis indices within that named axis
780
- representing subsets of devices to reduce over (default: None). For
781
- example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over the
782
- examples on the first two and last two devices. See ``jax.lax.psum`` for
783
- more details.
784
- use_fast_variance: If true, use a faster, but less numerically stable,
785
- calculation for the variance.
786
- """
787
-
788
- def __init__(
789
- self,
790
- in_size: Size,
791
- feature_axis: Axes = -1,
792
- num_groups: Optional[int] = 32,
793
- group_size: Optional[int] = None,
794
- *,
795
- epsilon: float = 1e-6,
796
- dtype: Optional[jax.typing.DTypeLike] = None,
797
- use_bias: bool = True,
798
- use_scale: bool = True,
799
- bias_init: Callable = init.ZeroInit(),
800
- scale_init: Callable = init.Constant(1.),
801
- reduction_axes: Optional[Axes] = None,
802
- axis_name: Optional[str] = None,
803
- axis_index_groups: Any = None,
804
- use_fast_variance: bool = True,
805
- param_type: type = NormalizationParamState,
806
- ):
807
- super().__init__()
808
-
809
- self.in_size = in_size
810
- self.out_size = in_size
811
-
812
- # parameters about axis
813
- feature_axis = (feature_axis,) if isinstance(feature_axis, int) else feature_axis
814
- self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axis)
815
- self.reduction_axes = (reduction_axes,) if isinstance(reduction_axes, int) else reduction_axes
816
- self.axis_name = axis_name
817
- self.axis_index_groups = axis_index_groups
818
-
819
- if (num_groups is None and group_size is None) or (
820
- num_groups is not None and group_size is not None
821
- ):
822
- raise ValueError(
823
- 'Either `num_groups` or `group_size` should be '
824
- 'specified. If `group_size` is to be specified, '
825
- 'pass `num_groups=None` as argument to override '
826
- 'the default `num_groups` value of 32.'
827
- )
828
-
829
- feature_shape = tuple([(ax if i in self.feature_axes else 1)
830
- for i, ax in enumerate(self.in_size)])
831
- assert len(feature_shape) == 1, 'GroupNorm only supports 1D feature axis.'
832
- num_features = feature_shape[0]
833
- if group_size is not None:
834
- if num_features % group_size != 0:
835
- raise ValueError(
836
- 'Number of features ({}) is not multiple of the '
837
- 'group size ({}).'.format(num_features, group_size)
838
- )
839
- self.num_groups = num_features // group_size
840
- self.group_size = group_size
841
- else:
842
- if not isinstance(num_groups, int) or num_groups <= 0 or (
843
- num_features % num_groups != 0
844
- ):
845
- raise ValueError(
846
- 'Number of groups ({}) does not divide the number'
847
- ' of channels ({}).'.format(num_groups, num_features)
848
- )
849
- self.num_groups = num_groups
850
- self.group_size = num_features // num_groups
851
-
852
- # variables
853
- weights = dict()
854
- if use_scale:
855
- weights['scale'] = init.param(scale_init, feature_shape)
856
- if use_bias:
857
- weights['bias'] = init.param(bias_init, feature_shape)
858
- if len(weights):
859
- self.weight = param_type(weights)
860
- else:
861
- self.weight = None
862
-
863
- # parameters
864
- self.epsilon = epsilon
865
- self.dtype = dtype
866
- self.use_bias = use_bias
867
- self.use_scale = use_scale
868
- self.bias_init = bias_init
869
- self.scale_init = scale_init
870
- self.use_fast_variance = use_fast_variance
871
-
872
- def update(self, x, *, mask: Optional[jax.Array] = None):
873
- """Applies group normalization to the input (arxiv.org/abs/1803.08494).
874
-
875
- Args:
876
- x: the input of shape ``...self.num_features`` where ``self.num_features``
877
- is a channels dimension and ``...`` represents an arbitrary number of
878
- extra dimensions that can be used to accumulate statistics over. If no
879
- reduction axes have been specified then all additional dimensions ``...``
880
- will be used to accumulate statistics apart from the leading dimension
881
- which is assumed to represent the batch.
882
- mask: Binary array of shape broadcastable to ``inputs`` tensor, indicating
883
- the positions for which the mean and variance should be computed.
884
-
885
- Returns:
886
- Normalized inputs (the same shape as inputs).
887
- """
888
- if self.reduction_axes is not None:
889
- reduction_axes = self.reduction_axes
890
- else:
891
- reduction_axes = list(range(1, x.ndim - 1)) + [-1]
892
- reduction_axes = _canonicalize_axes(x.ndim, reduction_axes)
893
-
894
- group_shape = x.shape[:-1] + (self.num_groups, self.group_size)
895
- if mask is not None:
896
- mask = mask.reshape(mask.shape[:-1] + (self.num_groups, self.group_size))
897
-
898
- mean, var = _compute_stats(
899
- x.reshape(group_shape),
900
- list(reduction_axes[:-1]) + [-1],
901
- dtype=self.dtype,
902
- axis_name=self.axis_name,
903
- axis_index_groups=self.axis_index_groups,
904
- use_fast_variance=self.use_fast_variance,
905
- mask=mask,
906
- )
907
- mean = jnp.repeat(mean, self.group_size, axis=1)
908
- var = jnp.repeat(var, self.group_size, axis=1)
909
- return _normalize(
910
- x,
911
- mean=mean,
912
- var=var,
913
- weights=self.weight,
914
- reduction_axes=reduction_axes[:-1],
915
- feature_axes=self.feature_axes,
916
- dtype=self.dtype,
917
- epsilon=self.epsilon,
918
- )
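
And a sketch of the ``group_size`` alternative to ``num_groups`` for ``GroupNorm``; a 1-D ``in_size`` (the channel dimension only) is assumed here, matching the assertion in ``__init__``, and ``num_groups=None`` overrides its default of 32.

import brainstate

x = brainstate.random.normal(size=(2, 8, 8, 12))                          # (batch, h, w, channels)
layer = brainstate.nn.GroupNorm((12,), num_groups=None, group_size=4)     # 12 channels -> 3 groups
y = layer(x)
print(y.shape)   # (2, 8, 8, 12)
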
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+ from typing import Callable, Union, Sequence, Optional, Any
19
+
20
+ import jax
21
+ import jax.numpy as jnp
22
+
23
+ import brainunit as u
24
+ from brainstate import environ, init
25
+ from brainstate._state import ParamState, BatchState
26
+ from brainstate.typing import DTypeLike, ArrayLike, Size, Axes
27
+ from ._module import Module
28
+
29
+ __all__ = [
30
+ 'BatchNorm0d',
31
+ 'BatchNorm1d',
32
+ 'BatchNorm2d',
33
+ 'BatchNorm3d',
34
+ 'LayerNorm',
35
+ 'RMSNorm',
36
+ 'GroupNorm',
37
+ ]
38
+
39
+
40
+ def weight_standardization(
41
+ w: ArrayLike,
42
+ eps: float = 1e-4,
43
+ gain: Optional[jax.Array] = None,
44
+ out_axis: int = -1,
45
+ ) -> Union[jax.Array, u.Quantity]:
46
+ """
47
+ Scaled Weight Standardization,
48
+ see `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization <https://paperswithcode.com/paper/weight-standardization>`_.
49
+
50
+ Parameters
51
+ ----------
52
+ w : ArrayLike
53
+ The weight tensor.
54
+ eps : float
55
+ A small value to avoid division by zero.
56
+ gain : Array
57
+ The gain function, by default None.
58
+ out_axis : int
59
+ The output axis, by default -1.
60
+
61
+ Returns
62
+ -------
63
+ ArrayLike
64
+ The scaled weight tensor.
65
+ """
66
+ w = u.maybe_custom_array(w)
67
+ if out_axis < 0:
68
+ out_axis = w.ndim + out_axis
69
+ fan_in = 1 # get the fan-in of the weight tensor
70
+ axes = [] # get the axes of the weight tensor
71
+ for i in range(w.ndim):
72
+ if i != out_axis:
73
+ fan_in *= w.shape[i]
74
+ axes.append(i)
75
+ # normalize the weight
76
+ mean = u.math.mean(w, axis=axes, keepdims=True)
77
+ var = u.math.var(w, axis=axes, keepdims=True)
78
+
79
+ temp = u.math.maximum(var * fan_in, eps)
80
+ if isinstance(temp, u.Quantity):
81
+ unit = temp.unit
82
+ temp = temp.mantissa
83
+ if unit.is_unitless:
84
+ scale = jax.lax.rsqrt(temp)
85
+ else:
86
+ scale = u.Quantity(jax.lax.rsqrt(temp), unit=1 / unit ** 0.5)
87
+ else:
88
+ scale = jax.lax.rsqrt(temp)
89
+ if gain is not None:
90
+ scale = gain * scale
91
+ shift = mean * scale
92
+ return w * scale - shift
93
+
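
For reference, a minimal sketch of calling the newly added ``weight_standardization`` helper on a convolution kernel; the import path is an assumption, since the function is defined in this module but not listed in ``__all__``.

import jax.numpy as jnp
from brainstate.nn._normalizations import weight_standardization   # assumed import path

# Hypothetical 3x3 kernel with 4 input and 8 output channels (out_axis=-1 by default).
w = jnp.ones((3, 3, 4, 8))
# Each output-channel slice is shifted to zero mean and scaled by rsqrt(max(var * fan_in, eps)).
w_std = weight_standardization(w, eps=1e-4)
print(w_std.shape)   # (3, 3, 4, 8): same shape as the input kernel
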
94
+
95
+
96
+ def canonicalize_dtype(
97
+ *args,
98
+ dtype: jax.typing.DTypeLike | None = None,
99
+ inexact: bool = True
100
+ ) -> jax.typing.DTypeLike:
101
+ """Canonicalize an optional dtype to the definitive dtype.
102
+
103
+ If the ``dtype`` is None, this function will infer the dtype from the
104
+ input arguments using ``jnp.result_type``. If it is not None, it will be
105
+ returned unmodified, or an exception is raised if the dtype
106
+ is invalid.
107
+
108
+ Args:
109
+ *args: JAX array compatible values. None values
110
+ are ignored.
111
+ dtype: Optional dtype override. If specified the arguments are cast to
112
+ the specified dtype instead and dtype inference is disabled.
113
+ inexact: When True, the output dtype must be a subdtype
114
+ of `jnp.inexact`. Inexact dtypes are real or complex floating points. This
115
+ is useful when you want to apply operations that don't work directly on
116
+ integers like taking a mean for example.
117
+ Returns:
118
+ The dtype that *args should be cast to.
119
+ """
120
+ if dtype is None:
121
+ args_filtered = [jnp.asarray(x) for x in args if x is not None]
122
+ dtype = jnp.result_type(*args_filtered)
123
+ if inexact and not jnp.issubdtype(dtype, jnp.inexact):
124
+ dtype = jnp.promote_types(jnp.float32, dtype)
125
+ if inexact and not jnp.issubdtype(dtype, jnp.inexact):
126
+ raise ValueError(f'Dtype must be inexact: {dtype}')
127
+ return dtype
128
+
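
A small sketch of the dtype handling described above (module-level import path assumed, as before):

import jax.numpy as jnp
from brainstate.nn._normalizations import canonicalize_dtype   # assumed import path

# Integer inputs are promoted to an inexact dtype when inexact=True (the default).
print(canonicalize_dtype(jnp.arange(4, dtype=jnp.int32)))    # float32
# An explicit dtype disables inference and is returned as given if it passes the inexact check.
print(canonicalize_dtype(jnp.ones(3), dtype='float16'))      # float16
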
129
+
130
+ def _canonicalize_axes(ndim: int, feature_axes: Sequence[int]):
131
+ axes = []
132
+ for axis in feature_axes:
133
+ if axis < 0:
134
+ axis += ndim
135
+ if axis < 0 or axis >= ndim:
136
+ raise ValueError(f'Invalid axis {axis} for {ndim}D input')
137
+ axes.append(axis)
138
+ return tuple(axes)
139
+
140
+
141
+ def _abs_sq(x):
142
+ """Computes the elementwise square of the absolute value |x|^2."""
143
+ if jnp.iscomplexobj(x):
144
+ return jax.lax.square(jax.lax.real(x)) + jax.lax.square(jax.lax.imag(x))
145
+ else:
146
+ return jax.lax.square(x)
147
+
148
+
149
+ class NormalizationParamState(ParamState):
150
+ # This is a dummy class kept for compatibility with the
151
+ # usage of `ETraceParam` for the layers in "brainetrace"
152
+ def execute(self, x):
153
+ param = self.value
154
+ if 'scale' in param:
155
+ x = x * param['scale']
156
+ if 'bias' in param:
157
+ x = x + param['bias']
158
+ return x
159
+
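
``execute`` simply applies whichever of ``scale`` and ``bias`` are present in the parameter dict; a tiny illustrative sketch (import path assumed):

import jax.numpy as jnp
from brainstate.nn._normalizations import NormalizationParamState   # assumed import path

w = NormalizationParamState({'scale': jnp.full((3,), 2.0), 'bias': jnp.ones((3,))})
print(w.execute(jnp.arange(3.0)))   # [1. 3. 5.], i.e. x * scale + bias
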
160
+
161
+ def _compute_stats(
162
+ x: ArrayLike,
163
+ axes: Sequence[int],
164
+ dtype: DTypeLike,
165
+ axis_name: Optional[str] = None,
166
+ axis_index_groups: Optional[Sequence[int]] = None,
167
+ use_mean: bool = True,
168
+ use_fast_variance: bool = True,
169
+ mask: Optional[jax.Array] = None,
170
+ ):
171
+ """
172
+ Computes mean and variance statistics.
173
+
174
+ This implementation takes care of a few important details:
175
+ - Computes in float32 precision for stability in half precision training.
176
+ - If ``use_fast_variance`` is ``True``, mean and variance are computed using
177
+ Var = E[|x|^2] - |E[x]|^2, instead of Var = E[|x - E[x]|^2], in a single XLA fusion.
178
+ - Clips negative variances to zero which can happen due to
179
+ roundoff errors. This avoids downstream NaNs.
180
+ - Supports averaging across a parallel axis and subgroups of a parallel axis
181
+ with a single `lax.pmean` call to avoid latency.
182
+
183
+ Arguments:
184
+ x: Input array.
185
+ axes: The axes in ``x`` to compute mean and variance statistics for.
186
+ dtype: Optional dtype specifying the minimal precision. Statistics
187
+ are always at least float32 for stability (default: dtype of x).
188
+ axis_name: Optional name for the pmapped axis to compute mean over. Note,
189
+ this is only used for pmap and shard map. For SPMD jit, you do not need to
190
+ manually synchronize. Just make sure that the axes are correctly annotated
191
+ and XLA:SPMD will insert the necessary collectives.
192
+ axis_index_groups: Optional axis indices.
193
+ use_mean: If true, calculate the mean from the input and use it when
194
+ computing the variance. If false, set the mean to zero and compute
195
+ the variance without subtracting the mean.
196
+ use_fast_variance: If true, use a faster, but less numerically stable,
197
+ calculation for the variance.
198
+ mask: Binary array of shape broadcastable to ``inputs`` tensor, indicating
199
+ the positions for which the mean and variance should be computed.
200
+
201
+ Returns:
202
+ A pair ``(mean, var)``.
203
+ """
204
+ if dtype is None:
205
+ dtype = jax.numpy.result_type(x)
206
+ # promote x to at least float32, this avoids half precision computation
207
+ # but preserves double or complex floating points
208
+ dtype = jax.numpy.promote_types(dtype, jnp.float32)
209
+ x = jnp.asarray(x, dtype)
210
+ axes = _canonicalize_axes(x.ndim, axes)
211
+
212
+ def maybe_distributed_mean(*xs, mask=None):
213
+ mus = tuple(x.mean(axes, where=mask) for x in xs)
214
+ if axis_name is None:
215
+ return mus if len(xs) > 1 else mus[0]
216
+ else:
217
+ # In the distributed case we stack multiple arrays to speed comms.
218
+ if len(xs) > 1:
219
+ reduced_mus = jax.lax.pmean(
220
+ jnp.stack(mus, axis=0),
221
+ axis_name,
222
+ axis_index_groups=axis_index_groups,
223
+ )
224
+ return tuple(reduced_mus[i] for i in range(len(xs)))
225
+ else:
226
+ return jax.lax.pmean(
227
+ mus[0],
228
+ axis_name,
229
+ axis_index_groups=axis_index_groups
230
+ )
231
+
232
+ if use_mean:
233
+ if use_fast_variance:
234
+ mu, mu2 = maybe_distributed_mean(x, _abs_sq(x), mask=mask)
235
+ # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due
236
+ # to floating point round-off errors.
237
+ var = jnp.maximum(0.0, mu2 - _abs_sq(mu))
238
+ else:
239
+ mu = maybe_distributed_mean(x, mask=mask)
240
+ var = maybe_distributed_mean(_abs_sq(x - jnp.expand_dims(mu, axes)), mask=mask)
241
+ else:
242
+ var = maybe_distributed_mean(_abs_sq(x), mask=mask)
243
+ mu = jnp.zeros_like(var)
244
+ return mu, var
245
+
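
A quick sketch of what ``_compute_stats`` returns for a plain 2-D array (a private helper, so the import is shown for illustration only):

import jax.numpy as jnp
from brainstate.nn._normalizations import _compute_stats   # private helper, illustration only

x = jnp.arange(12.0).reshape(3, 4)
mean, var = _compute_stats(x, axes=(0,), dtype=None)
print(mean.shape, var.shape)   # (4,) (4,): per-column statistics, computed in float32
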
246
+
247
+ def _normalize(
248
+ x: ArrayLike,
249
+ mean: Optional[ArrayLike],
250
+ var: Optional[ArrayLike],
251
+ weights: Optional[NormalizationParamState],
252
+ reduction_axes: Axes,
253
+ feature_axes: Axes,
254
+ dtype: DTypeLike,
255
+ epsilon: jax.typing.ArrayLike,
256
+ ):
257
+ """Normalizes the input of a normalization layer and optionally applies a learned scale and bias.
258
+
259
+ Arguments:
260
+ x: The input.
261
+ mean: Mean to use for normalization.
262
+ var: Variance to use for normalization.
263
+ weights: The scale and bias parameters.
264
+ reduction_axes: The axes in ``x`` to reduce.
265
+ feature_axes: The feature axes to apply the scale and bias.
266
+ dtype: The dtype of the result (default: infer from input and params).
267
+ epsilon: Normalization epsilon.
268
+
269
+ Returns:
270
+ The normalized input.
271
+ """
272
+ if mean is not None:
273
+ assert var is not None, 'mean and val must be both None or not None.'
274
+ reduction_axes = _canonicalize_axes(x.ndim, reduction_axes)
275
+ feature_axes = _canonicalize_axes(x.ndim, feature_axes)
276
+ stats_shape = list(x.shape)
277
+ for axis in reduction_axes:
278
+ stats_shape[axis] = 1
279
+ mean = mean.reshape(stats_shape)
280
+ var = var.reshape(stats_shape)
281
+ feature_shape = [1] * x.ndim
282
+ for ax in feature_axes:
283
+ feature_shape[ax] = x.shape[ax]
284
+ y = x - mean
285
+ mul = jax.lax.rsqrt(var + epsilon)
286
+ y = y * mul
287
+ if weights is not None:
288
+ y = weights.execute(y)
289
+ dtype = canonicalize_dtype(x, *jax.tree.leaves(weights.value), dtype=dtype)
290
+ else:
291
+ assert var is None, 'mean and val must be both None or not None.'
292
+ assert weights is None, 'scale and bias are not supported without mean and val'
293
+ y = x
294
+ return jnp.asarray(y, dtype)
295
+
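
Chaining the two private helpers above reproduces the core of layer normalization (again, for illustration only):

import jax.numpy as jnp
from brainstate.nn._normalizations import _compute_stats, _normalize   # illustration only

x = jnp.arange(12.0).reshape(3, 4)
mean, var = _compute_stats(x, axes=(-1,), dtype=None)
y = _normalize(x, mean=mean, var=var, weights=None, reduction_axes=(-1,),
               feature_axes=(-1,), dtype=None, epsilon=1e-6)
print(y.mean(axis=-1))   # ~0 for every row
print(y.var(axis=-1))    # ~1 for every row
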
296
+
297
+ class _BatchNorm(Module):
298
+ __module__ = 'brainstate.nn'
299
+ num_spatial_dims: int
300
+
301
+ def __init__(
302
+ self,
303
+ in_size: Size,
304
+ feature_axis: Axes = -1,
305
+ *,
306
+ track_running_stats: bool = True,
307
+ epsilon: float = 1e-5,
308
+ momentum: float = 0.99,
309
+ affine: bool = True,
310
+ bias_initializer: Union[ArrayLike, Callable] = init.Constant(0.),
311
+ scale_initializer: Union[ArrayLike, Callable] = init.Constant(1.),
312
+ axis_name: Optional[Union[str, Sequence[str]]] = None,
313
+ axis_index_groups: Optional[Sequence[Sequence[int]]] = None,
314
+ use_fast_variance: bool = True,
315
+ name: Optional[str] = None,
316
+ dtype: Any = None,
317
+ param_type: type = NormalizationParamState,
318
+ mean_type: type = BatchState,
319
+ ):
320
+ super().__init__(name=name)
321
+
322
+ # parameters
323
+ self.in_size = in_size
324
+ self.out_size = in_size
325
+ self.affine = affine
326
+ self.bias_initializer = bias_initializer
327
+ self.scale_initializer = scale_initializer
328
+ self.dtype = dtype or environ.dftype()
329
+ self.track_running_stats = track_running_stats
330
+ self.momentum = jnp.asarray(momentum, dtype=self.dtype)
331
+ self.epsilon = jnp.asarray(epsilon, dtype=self.dtype)
332
+ self.use_fast_variance = use_fast_variance
333
+
334
+ # parameters about axis
335
+ feature_axis = (feature_axis,) if isinstance(feature_axis, int) else feature_axis
336
+ self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axis)
337
+ self.axis_name = axis_name
338
+ self.axis_index_groups = axis_index_groups
339
+
340
+ # variables
341
+ feature_shape = tuple([(ax if i in self.feature_axes else 1)
342
+ for i, ax in enumerate(self.in_size)])
343
+ if self.track_running_stats:
344
+ self.running_mean = mean_type(jnp.zeros(feature_shape, dtype=self.dtype))
345
+ self.running_var = mean_type(jnp.ones(feature_shape, dtype=self.dtype))
346
+ else:
347
+ self.running_mean = None
348
+ self.running_var = None
349
+
350
+ # parameters
351
+ if self.affine:
352
+ assert track_running_stats, "Affine parameters are not needed when track_running_stats is False."
353
+ bias = init.param(self.bias_initializer, feature_shape)
354
+ scale = init.param(self.scale_initializer, feature_shape)
355
+ self.weight = param_type(dict(bias=bias, scale=scale))
356
+ else:
357
+ self.weight = None
358
+
359
+ def update(self, x, mask: Optional[jax.Array] = None):
360
+ # input shape and batch mode or not
361
+ if x.ndim == self.num_spatial_dims + 2:
362
+ x_shape = x.shape[1:]
363
+ batch = True
364
+ elif x.ndim == self.num_spatial_dims + 1:
365
+ x_shape = x.shape
366
+ batch = False
367
+ else:
368
+ raise ValueError(f"expected {self.num_spatial_dims + 2}D (with batch) or "
369
+ f"{self.num_spatial_dims + 1}D (without batch) input (got {x.ndim}D input, {x.shape})")
370
+ if self.in_size != x_shape:
371
+ raise ValueError(f"The expected input shape is {self.in_size}, while we got {x_shape}.")
372
+
373
+ # reduce the feature axis
374
+ if batch:
375
+ reduction_axes = tuple(i for i in range(x.ndim) if (i - 1) not in self.feature_axes)
376
+ else:
377
+ reduction_axes = tuple(i for i in range(x.ndim) if i not in self.feature_axes)
378
+
379
+ # fitting phase
380
+ fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
381
+
382
+ # compute the running mean and variance
383
+ if self.track_running_stats:
384
+ if fit_phase:
385
+ mean, var = _compute_stats(
386
+ x,
387
+ reduction_axes,
388
+ dtype=self.dtype,
389
+ axis_name=self.axis_name,
390
+ axis_index_groups=self.axis_index_groups,
391
+ use_fast_variance=self.use_fast_variance,
392
+ mask=mask,
393
+ )
394
+ self.running_mean.value = self.momentum * self.running_mean.value + (1 - self.momentum) * mean
395
+ self.running_var.value = self.momentum * self.running_var.value + (1 - self.momentum) * var
396
+ else:
397
+ mean = self.running_mean.value
398
+ var = self.running_var.value
399
+ else:
400
+ mean, var = None, None
401
+
402
+ # normalize
403
+ return _normalize(
404
+ x,
405
+ mean=mean,
406
+ var=var,
407
+ weights=self.weight,
408
+ reduction_axes=reduction_axes,
409
+ feature_axes=self.feature_axes,
410
+ dtype=self.dtype,
411
+ epsilon=self.epsilon
412
+ )
413
+
414
+
415
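For reference, the running statistics above are an exponential moving average of the per-batch statistics, ``new = momentum * old + (1 - momentum) * batch_stat``, with ``momentum = 0.99`` by default. A minimal numeric sketch of that rule, using plain ``jax.numpy`` rather than the layer itself:

>>> import jax.numpy as jnp
>>> momentum = 0.99
>>> running_mean = jnp.zeros(3)                      # running statistic before the step
>>> batch_mean = jnp.array([1.0, 2.0, 3.0])          # statistic of the current batch
>>> running_mean = momentum * running_mean + (1 - momentum) * batch_mean
>>> # each step moves the running mean one percent of the way toward the batch mean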
+ class BatchNorm0d(_BatchNorm):
416
+ r"""0-D batch normalization [1]_.
417
+
418
+ The data should be of shape `(b, c)`, where `b` is the batch dimension, and `c` is the channel dimension.
419
+
420
+ %s
421
+ """
422
+ __module__ = 'brainstate.nn'
423
+ num_spatial_dims: int = 0
424
+
425
+
426
+ class BatchNorm1d(_BatchNorm):
427
+ r"""1-D batch normalization [1]_.
428
+
429
+ The data should be of shape `(b, l, c)`, where `b` is the batch dimension,
430
+ `l` is the length dimension, and `c` is the channel dimension.
431
+
432
+ %s
433
+ """
434
+ __module__ = 'brainstate.nn'
435
+ num_spatial_dims: int = 1
436
+
437
+
438
+ class BatchNorm2d(_BatchNorm):
439
+ r"""2-D batch normalization [1]_.
440
+
441
+ The data should be of shape `(b, h, w, c)`, where `b` is the batch dimension,
442
+ `h` is the height dimension, `w` is the width dimension, and `c` is the
443
+ channel dimension.
444
+
445
+ %s
446
+ """
447
+ __module__ = 'brainstate.nn'
448
+ num_spatial_dims: int = 2
449
+
450
+
451
+ class BatchNorm3d(_BatchNorm):
452
+ r"""3-D batch normalization [1]_.
453
+
454
+ The data should be of shape `(b, h, w, d, c)`, where `b` is the batch dimension,
455
+ `h` is the height dimension, `w` is the width dimension, `d` is the depth
456
+ dimension, and `c` is the channel dimension.
457
+
458
+ %s
459
+ """
460
+ __module__ = 'brainstate.nn'
461
+ num_spatial_dims: int = 3
462
+
463
+
464
+ _bn_doc = r'''
465
+
466
+ This layer aims to reduce the internal covariate shift of the data. It
467
+ normalizes a batch of data by fixing the mean and variance of inputs
468
+ on each feature (channel). Most commonly, the first axis of the data
469
+ is the batch, and the last is the channel. However, users can specify
470
+ the axes to be normalized.
471
+
472
+ .. math::
473
+ y=\frac{x-\mathrm{E}[x]}{\sqrt{\operatorname{Var}[x]+\epsilon}} * \gamma+\beta
474
+
475
+ .. note::
476
+ This :attr:`momentum` argument is different from one used in optimizer
477
+ classes and the conventional notion of momentum. Mathematically, the
478
+ update rule for running statistics here is
479
+ :math:`\hat{x}_\text{new} = \text{momentum} \times \hat{x} + (1-\text{momentum}) \times x_t`,
480
+ where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
481
+ new observed value.
482
+
483
+ Parameters
484
+ ----------
485
+ in_size: sequence of int
486
+ The input shape, without batch size.
487
+ feature_axis: int, tuple, list
488
+ The feature or non-batch axis of the input.
489
+ track_running_stats: bool
490
+ A boolean value that when set to ``True``, this module tracks the running mean and variance,
491
+ and when set to ``False``, this module does not track such statistics, and initializes
492
+ statistics buffers ``running_mean`` and ``running_var`` as ``None``. When these buffers are ``None``,
493
+ this module always uses batch statistics in both training and eval modes. Default: ``True``.
494
+ momentum: float
495
+ The value used for the ``running_mean`` and ``running_var`` computation. Default: 0.99
496
+ epsilon: float
497
+ A value added to the denominator for numerical stability. Default: 1e-5
498
+ affine: bool
499
+ A boolean value that when set to ``True``, this module has
500
+ learnable affine parameters. Default: ``True``
501
+ bias_initializer: ArrayLike, Callable
502
+ An initializer for the bias (beta) parameter.
503
+ Default: ``init.Constant(0.)``
504
+ scale_initializer: ArrayLike, Callable
505
+ An initializer for the scale (gamma) parameter.
506
+ Default: ``init.Constant(1.)``
507
+ axis_name: optional, str, sequence of str
508
+ If not ``None``, it should be a string (or sequence of
509
+ strings) representing the axis name(s) over which this module is being
510
+ run within a jax map (e.g. ``jax.pmap`` or ``jax.vmap``). Supplying this
511
+ argument means that batch statistics are calculated across all replicas
512
+ on the named axes.
513
+ axis_index_groups: optional, sequence
514
+ Specifies how devices are grouped. Valid
515
+ only within ``jax.pmap`` collectives.
516
+ Groups of axis indices within that named axis
517
+ representing subsets of devices to reduce over (default: None). For
518
+ example, `[[0, 1], [2, 3]]` would independently batch-normalize over
519
+ the examples on the first two and last two devices. See `jax.lax.psum`
520
+ for more details.
521
+ use_fast_variance: bool
+ If true, use a faster, but less numerically stable,
522
+ calculation for the variance.
523
+
524
+
525
+ References
526
+ ----------
527
+ .. [1] Ioffe, Sergey and Christian Szegedy. “Batch Normalization: Accelerating Deep Network Training
528
+ by Reducing Internal Covariate Shift.” ArXiv abs/1502.03167 (2015): n. pag.
529
+
530
+ '''
531
+
532
+ BatchNorm0d.__doc__ = BatchNorm0d.__doc__ % _bn_doc
+ BatchNorm1d.__doc__ = BatchNorm1d.__doc__ % _bn_doc
533
+ BatchNorm2d.__doc__ = BatchNorm2d.__doc__ % _bn_doc
534
+ BatchNorm3d.__doc__ = BatchNorm3d.__doc__ % _bn_doc
535
+
536
+
537
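A minimal usage sketch for the batch-normalization layers above (illustrative only): it assumes ``brainstate.environ.context`` is the intended way to set the ``'fit'`` flag that ``_BatchNorm.update`` reads via ``environ.get('fit', ...)``.

>>> import brainstate
>>> x = brainstate.random.normal(size=(8, 32, 32, 3))       # (batch, h, w, c)
>>> layer = brainstate.nn.BatchNorm2d((32, 32, 3))          # in_size excludes the batch axis
>>> with brainstate.environ.context(fit=True):              # training: running stats are updated
...     y_train = layer(x)
>>> with brainstate.environ.context(fit=False):             # evaluation: running stats are reused
...     y_eval = layer(x)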
+ class LayerNorm(Module):
538
+ """
539
+ Layer normalization (https://arxiv.org/abs/1607.06450).
540
+
541
+ LayerNorm normalizes the activations of the layer for each given example in a
542
+ batch independently, rather than across a batch like Batch Normalization;
543
+ i.e., it applies a transformation that maintains the mean activation within
544
+ each example close to 0 and the activation standard deviation close to 1.
545
+
546
+ Example usage::
547
+
548
+ >>> import brainstate
549
+ >>> x = brainstate.random.normal(size=(3, 4, 5, 6))
550
+ >>> layer = brainstate.nn.LayerNorm(x.shape)
551
+ >>> layer.states()
552
+ >>> y = layer(x)
553
+
554
+ Attributes:
555
+ in_size: The input shape, without batch size.
556
+ epsilon: A small float added to variance to avoid dividing by zero.
557
+ dtype: the dtype of the result (default: infer from input and params).
558
+ use_bias: If True, bias (beta) is added.
559
+ use_scale: If True, multiply by scale (gamma). When the next layer is linear
560
+ (also e.g. nn.relu), this can be disabled since the scaling will be done
561
+ by the next layer.
562
+ bias_init: Initializer for bias, by default, zero.
563
+ scale_init: Initializer for scale, by default, one.
564
+ reduction_axes: Axes for computing normalization statistics. It is recommended
565
+ to use negative integers, since positive axis indices can become
566
+ wrong once a batch dimension is prepended to the input.
567
+ feature_axes: Feature axes for learned bias and scaling.
568
+ axis_name: the axis name used to combine batch statistics from multiple
569
+ devices. See ``jax.pmap`` for a description of axis names (default: None).
570
+ This is only needed if the model is subdivided across devices, i.e. the
571
+ array being normalized is sharded across devices within a pmap.
572
+ axis_index_groups: groups of axis indices within that named axis
573
+ representing subsets of devices to reduce over (default: None). For
574
+ example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over
575
+ the examples on the first two and last two devices. See ``jax.lax.psum``
576
+ for more details.
577
+ use_fast_variance: If true, use a faster, but less numerically stable,
578
+ calculation for the variance.
579
+ """
580
+
581
+ def __init__(
582
+ self,
583
+ in_size: Size,
584
+ reduction_axes: Axes = -1,
585
+ feature_axes: Axes = -1,
586
+ *,
587
+ epsilon: float = 1e-6,
588
+ use_bias: bool = True,
589
+ use_scale: bool = True,
590
+ bias_init: Callable = init.ZeroInit(),
591
+ scale_init: Callable = init.Constant(1.0),
592
+ axis_name: Optional[str] = None,
593
+ axis_index_groups: Any = None,
594
+ use_fast_variance: bool = True,
595
+ dtype: Optional[jax.typing.DTypeLike] = None,
596
+ param_type: type = NormalizationParamState,
597
+ ):
598
+ super().__init__()
599
+
600
+ self.in_size = in_size
601
+ self.out_size = in_size
602
+
603
+ # parameters about axis
604
+ feature_axes = (feature_axes,) if isinstance(feature_axes, int) else feature_axes
605
+ self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axes)
606
+ self.reduction_axes = (reduction_axes,) if isinstance(reduction_axes, int) else reduction_axes
607
+ self.axis_name = axis_name
608
+ self.axis_index_groups = axis_index_groups
609
+
610
+ # variables
611
+ feature_shape = tuple([(ax if i in self.feature_axes else 1)
612
+ for i, ax in enumerate(self.in_size)])
613
+
614
+ weights = dict()
615
+ if use_scale:
616
+ weights['scale'] = init.param(scale_init, feature_shape)
617
+ if use_bias:
618
+ weights['bias'] = init.param(bias_init, feature_shape)
619
+ if len(weights):
620
+ self.weight = param_type(weights)
621
+ else:
622
+ self.weight = None
623
+
624
+ # parameters
625
+ self.epsilon = epsilon
626
+ self.dtype = dtype or environ.dftype()
627
+ self.use_bias = use_bias
628
+ self.use_scale = use_scale
629
+ self.bias_init = bias_init
630
+ self.scale_init = scale_init
631
+ self.use_fast_variance = use_fast_variance
632
+
633
+ def update(self, x, *, mask: Optional[jax.Array] = None):
634
+ """Applies layer normalization on the input.
635
+
636
+ Args:
637
+ x: the inputs
+ mask: the mask
638
+
639
+ Returns:
640
+ Normalized inputs (the same shape as inputs).
641
+ """
642
+ mean, var = _compute_stats(
643
+ x,
644
+ self.reduction_axes,
645
+ dtype=self.dtype,
646
+ axis_name=self.axis_name,
647
+ axis_index_groups=self.axis_index_groups,
648
+ use_fast_variance=self.use_fast_variance,
649
+ mask=mask,
650
+ )
651
+
652
+ return _normalize(
653
+ x,
654
+ mean=mean,
655
+ var=var,
656
+ weights=self.weight,
657
+ reduction_axes=self.reduction_axes,
658
+ feature_axes=self.feature_axes,
659
+ dtype=self.dtype,
660
+ epsilon=self.epsilon,
661
+ )
662
+
663
+
664
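A short sketch of what ``LayerNorm`` computes with the default initializers (scale of one, bias of zero): each example is standardized over ``reduction_axes`` (here the default last axis) before the learned affine transform is applied. Illustrative only.

>>> import jax.numpy as jnp
>>> import brainstate
>>> x = brainstate.random.normal(size=(4, 16))
>>> y = brainstate.nn.LayerNorm(x.shape)(x)
>>> # with the default parameters each row is approximately standardized
>>> assert jnp.allclose(y.mean(axis=-1), 0.0, atol=1e-5)
>>> assert jnp.allclose(y.std(axis=-1), 1.0, atol=1e-2)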
+ class RMSNorm(Module):
665
+ """
666
+ RMS Layer normalization (https://arxiv.org/abs/1910.07467).
667
+
668
+ RMSNorm normalizes the activations of the layer for each given example in a
669
+ batch independently, rather than across a batch like Batch Normalization.
670
+ Unlike LayerNorm which re-centers the mean to be 0 and normalizes by the
671
+ standard deviation of the activations, RMSNorm does not re-center at all
672
+ and instead normalizes by the root mean square of the activations.
673
+
674
+ Example usage::
675
+
676
+ >>> import brainstate
677
+ >>> x = brainstate.random.normal(size=(5, 6))
678
+ >>> layer = brainstate.nn.RMSNorm(in_size=x.shape)
679
+ >>> layer.states()
680
+ >>> y = layer(x)
681
+
682
+ Attributes:
683
+ in_size: The input shape, without batch size.
684
+ epsilon: A small float added to variance to avoid dividing by zero.
685
+ dtype: the dtype of the result (default: infer from input and params).
686
+ use_scale: If True, multiply by scale (gamma). When the next layer is linear
687
+ (also e.g. nn.relu), this can be disabled since the scaling will be done
688
+ by the next layer.
689
+ scale_init: Initializer for scale, by default, one.
690
+ reduction_axes: Axes for computing normalization statistics. It is recommended
691
+ to use negative integers, since positive axis indices can become
692
+ wrong once a batch dimension is prepended to the input.
693
+ feature_axes: Feature axes for the learned scaling.
694
+ axis_name: the axis name used to combine batch statistics from multiple
695
+ devices. See ``jax.pmap`` for a description of axis names (default: None).
696
+ This is only needed if the model is subdivided across devices, i.e. the
697
+ array being normalized is sharded across devices within a pmap.
698
+ axis_index_groups: groups of axis indices within that named axis
699
+ representing subsets of devices to reduce over (default: None). For
700
+ example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over
701
+ the examples on the first two and last two devices. See ``jax.lax.psum``
702
+ for more details.
703
+ use_fast_variance: If true, use a faster, but less numerically stable,
704
+ calculation for the variance.
705
+ """
706
+
707
+ def __init__(
708
+ self,
709
+ in_size: Size,
710
+ *,
711
+ epsilon: float = 1e-6,
712
+ dtype: Optional[jax.typing.DTypeLike] = None,
713
+ use_scale: bool = True,
714
+ scale_init: Callable = init.Constant(1.0),
715
+ reduction_axes: Axes = -1,
716
+ feature_axes: Axes = -1,
717
+ axis_name: Optional[str] = None,
718
+ axis_index_groups: Any = None,
719
+ use_fast_variance: bool = True,
720
+ param_type: type = NormalizationParamState,
721
+ ):
722
+ super().__init__()
723
+
724
+ self.in_size = in_size
725
+ self.out_size = in_size
726
+
727
+ # parameters about axis
728
+ feature_axes = (feature_axes,) if isinstance(feature_axes, int) else feature_axes
729
+ self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axes)
730
+ self.reduction_axes = (reduction_axes,) if isinstance(reduction_axes, int) else reduction_axes
731
+ self.axis_name = axis_name
732
+ self.axis_index_groups = axis_index_groups
733
+
734
+ # variables
735
+ feature_shape = tuple([(ax if i in self.feature_axes else 1)
736
+ for i, ax in enumerate(self.in_size)])
737
+ if use_scale:
738
+ self.scale = param_type({'scale': init.param(scale_init, feature_shape)})
739
+ else:
740
+ self.scale = None
741
+
742
+ # parameters
743
+ self.epsilon = epsilon
744
+ self.dtype = dtype or environ.dftype()
745
+ self.use_scale = use_scale
746
+ self.scale_init = scale_init
747
+ self.use_fast_variance = use_fast_variance
748
+
749
+ def update(self, x, *, mask: Optional[jax.Array] = None):
750
+ """Applies layer normalization on the input.
751
+
752
+ Args:
753
+ x: the inputs
754
+ mask: the mask
755
+
756
+ Returns:
757
+ Normalized inputs (the same shape as inputs).
758
+ """
759
+ mean, var = _compute_stats(
760
+ x,
761
+ self.reduction_axes,
762
+ dtype=self.dtype,
763
+ axis_name=self.axis_name,
764
+ axis_index_groups=self.axis_index_groups,
765
+ use_mean=False,
766
+ use_fast_variance=self.use_fast_variance,
767
+ mask=mask,
768
+ )
769
+
770
+ return _normalize(
771
+ x,
772
+ mean=mean,
773
+ var=var,
774
+ weights=self.scale,
775
+ reduction_axes=self.reduction_axes,
776
+ feature_axes=self.feature_axes,
777
+ dtype=self.dtype,
778
+ epsilon=self.epsilon,
779
+ )
780
+
781
+
782
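``RMSNorm`` above calls ``_compute_stats`` with ``use_mean=False``, so each example is divided by the root mean square of its activations rather than being re-centered. A minimal sketch of that relationship, assuming ``_normalize`` follows the usual ``(x - mean) / sqrt(var + epsilon)`` form and using the default scale of one:

>>> import jax.numpy as jnp
>>> import brainstate
>>> x = brainstate.random.normal(size=(4, 8))
>>> y = brainstate.nn.RMSNorm(in_size=x.shape)(x)
>>> rms = jnp.sqrt(jnp.mean(x ** 2, axis=-1, keepdims=True) + 1e-6)   # 1e-6 is the default epsilon
>>> assert jnp.allclose(y, x / rms, atol=1e-5)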
+ class GroupNorm(Module):
783
+ """
784
+ Group normalization (arxiv.org/abs/1803.08494).
785
+
786
+ This op is similar to batch normalization, but statistics are shared across
787
+ equally-sized groups of channels and not shared across batch dimension.
788
+ Thus, group normalization does not depend on the batch composition and does
789
+ not require maintaining internal state for storing statistics.
790
+ The user should either specify the total number of channel groups or the
791
+ number of channels per group.
792
+
793
+ .. note::
794
+ LayerNorm is a special case of GroupNorm where ``num_groups=1``.
795
+
796
+ Example usage::
797
+
798
+ >>> import numpy as np
799
+ >>> import brainstate
800
+ ...
801
+ >>> x = brainstate.random.normal(size=(3, 4, 5, 6))
802
+ >>> layer = brainstate.nn.GroupNorm(x.shape, num_groups=3)
803
+ >>> layer.states()
804
+ >>> y = layer(x)
805
+ >>> y = brainstate.nn.GroupNorm(x.shape, num_groups=1)(x)
806
+ >>> y2 = brainstate.nn.LayerNorm(x.shape, reduction_axes=(1, 2, 3))(x)
807
+ >>> np.testing.assert_allclose(y, y2)
808
+
809
+ Attributes:
810
+ in_size: The input shape, without batch size.
811
+ num_groups: the total number of channel groups. The default value of 32 is
812
+ proposed by the original group normalization paper.
813
+ group_size: the number of channels in a group.
814
+ epsilon: A small float added to variance to avoid dividing by zero.
815
+ dtype: the dtype of the result (default: infer from input and params).
816
+ use_bias: If True, bias (beta) is added.
817
+ use_scale: If True, multiply by scale (gamma). When the next layer is linear
818
+ (also e.g. nn.relu), this can be disabled since the scaling will be done
819
+ by the next layer.
820
+ bias_init: Initializer for bias, by default, zero.
821
+ scale_init: Initializer for scale, by default, one.
822
+ reduction_axes: List of axes used for computing normalization statistics.
823
+ This list must include the final dimension, which is assumed to be the
824
+ feature axis. Furthermore, if the input used at call time has additional
825
+ leading axes compared to the data used for initialisation, for example due
826
+ to batching, then the reduction axes need to be defined explicitly.
827
+ It is recommended to use negative integers, since positive axis indices can become
828
+ wrong once a batch dimension is prepended to the input.
829
+ axis_name: the axis name used to combine batch statistics from multiple
830
+ devices. See ``jax.pmap`` for a description of axis names (default: None).
831
+ This is only needed if the model is subdivided across devices, i.e. the
832
+ array being normalized is sharded across devices within a pmap or shard
833
+ map. For SPMD jit, you do not need to manually synchronize. Just make sure
834
+ that the axes are correctly annotated and XLA:SPMD will insert the
835
+ necessary collectives.
836
+ axis_index_groups: groups of axis indices within that named axis
837
+ representing subsets of devices to reduce over (default: None). For
838
+ example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over the
839
+ examples on the first two and last two devices. See ``jax.lax.psum`` for
840
+ more details.
841
+ use_fast_variance: If true, use a faster, but less numerically stable,
842
+ calculation for the variance.
843
+ """
844
+
845
+ def __init__(
846
+ self,
847
+ in_size: Size,
848
+ feature_axis: Axes = -1,
849
+ num_groups: Optional[int] = 32,
850
+ group_size: Optional[int] = None,
851
+ *,
852
+ epsilon: float = 1e-6,
853
+ dtype: Optional[jax.typing.DTypeLike] = None,
854
+ use_bias: bool = True,
855
+ use_scale: bool = True,
856
+ bias_init: Callable = init.ZeroInit(),
857
+ scale_init: Callable = init.Constant(1.),
858
+ reduction_axes: Optional[Axes] = None,
859
+ axis_name: Optional[str] = None,
860
+ axis_index_groups: Any = None,
861
+ use_fast_variance: bool = True,
862
+ param_type: type = NormalizationParamState,
863
+ ):
864
+ super().__init__()
865
+
866
+ self.in_size = in_size
867
+ self.out_size = in_size
868
+
869
+ # parameters about axis
870
+ feature_axis = (feature_axis,) if isinstance(feature_axis, int) else feature_axis
871
+ self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axis)
872
+ self.reduction_axes = (reduction_axes,) if isinstance(reduction_axes, int) else reduction_axes
873
+ self.axis_name = axis_name
874
+ self.axis_index_groups = axis_index_groups
875
+
876
+ if (num_groups is None and group_size is None) or (
877
+ num_groups is not None and group_size is not None
878
+ ):
879
+ raise ValueError(
880
+ 'Exactly one of `num_groups` or `group_size` must be '
881
+ 'specified. If `group_size` is to be specified, '
882
+ 'pass `num_groups=None` as argument to override '
883
+ 'the default `num_groups` value of 32.'
884
+ )
885
+
886
+ feature_shape = tuple([(ax if i in self.feature_axes else 1)
887
+ for i, ax in enumerate(self.in_size)])
888
+ assert len(self.feature_axes) == 1, 'GroupNorm only supports a single feature axis.'
889
+ num_features = self.in_size[self.feature_axes[0]]
890
+ if group_size is not None:
891
+ if num_features % group_size != 0:
892
+ raise ValueError(
893
+ 'Number of features ({}) is not multiple of the '
894
+ 'group size ({}).'.format(num_features, group_size)
895
+ )
896
+ self.num_groups = num_features // group_size
897
+ self.group_size = group_size
898
+ else:
899
+ if not isinstance(num_groups, int) or num_groups <= 0 or (
900
+ num_features % num_groups != 0
901
+ ):
902
+ raise ValueError(
903
+ 'Number of groups ({}) does not divide the number'
904
+ ' of channels ({}).'.format(num_groups, num_features)
905
+ )
906
+ self.num_groups = num_groups
907
+ self.group_size = num_features // num_groups
908
+
909
+ # variables
910
+ weights = dict()
911
+ if use_scale:
912
+ weights['scale'] = init.param(scale_init, feature_shape)
913
+ if use_bias:
914
+ weights['bias'] = init.param(bias_init, feature_shape)
915
+ if len(weights):
916
+ self.weight = param_type(weights)
917
+ else:
918
+ self.weight = None
919
+
920
+ # parameters
921
+ self.epsilon = epsilon
922
+ self.dtype = dtype
923
+ self.use_bias = use_bias
924
+ self.use_scale = use_scale
925
+ self.bias_init = bias_init
926
+ self.scale_init = scale_init
927
+ self.use_fast_variance = use_fast_variance
928
+
929
+ def update(self, x, *, mask: Optional[jax.Array] = None):
930
+ """Applies group normalization to the input (arxiv.org/abs/1803.08494).
931
+
932
+ Args:
933
+ x: the input of shape ``(..., num_features)``, where ``num_features``
934
+ is a channels dimension and ``...`` represents an arbitrary number of
935
+ extra dimensions that can be used to accumulate statistics over. If no
936
+ reduction axes have been specified then all additional dimensions ``...``
937
+ will be used to accumulate statistics apart from the leading dimension
938
+ which is assumed to represent the batch.
939
+ mask: Binary array of shape broadcastable to ``inputs`` tensor, indicating
940
+ the positions for which the mean and variance should be computed.
941
+
942
+ Returns:
943
+ Normalized inputs (the same shape as inputs).
944
+ """
945
+ if self.reduction_axes is not None:
946
+ reduction_axes = self.reduction_axes
947
+ else:
948
+ reduction_axes = list(range(1, x.ndim - 1)) + [-1]
949
+ reduction_axes = _canonicalize_axes(x.ndim, reduction_axes)
950
+
951
+ group_shape = x.shape[:-1] + (self.num_groups, self.group_size)
952
+ if mask is not None:
953
+ mask = mask.reshape(mask.shape[:-1] + (self.num_groups, self.group_size))
954
+
955
+ mean, var = _compute_stats(
956
+ x.reshape(group_shape),
957
+ list(reduction_axes[:-1]) + [-1],
958
+ dtype=self.dtype,
959
+ axis_name=self.axis_name,
960
+ axis_index_groups=self.axis_index_groups,
961
+ use_fast_variance=self.use_fast_variance,
962
+ mask=mask,
963
+ )
964
+ mean = jnp.repeat(mean, self.group_size, axis=1)
965
+ var = jnp.repeat(var, self.group_size, axis=1)
966
+ return _normalize(
967
+ x,
968
+ mean=mean,
969
+ var=var,
970
+ weights=self.weight,
971
+ reduction_axes=reduction_axes[:-1],
972
+ feature_axes=self.feature_axes,
973
+ dtype=self.dtype,
974
+ epsilon=self.epsilon,
975
+ )
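A small sketch of the grouping arithmetic in ``GroupNorm.__init__`` above: exactly one of ``num_groups`` / ``group_size`` is given, the other is derived, and the channel count must be divisible by it. The values below are hypothetical and no layer is constructed:

>>> num_features = 6
>>> num_groups = 3                                  # user passes num_groups
>>> group_size = num_features // num_groups
>>> num_groups, group_size
(3, 2)
>>> group_size = 3                                  # or the user passes group_size with num_groups=None
>>> num_groups = num_features // group_size
>>> num_groups, group_size
(2, 3)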