brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff compares two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +15 -28
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
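The file list above amounts to a package-level reorganization: the `augment` and `compile` subpackages merge into `transform`, the `init` package moves under `nn` (as `brainstate/nn/init.py`), the event-driven layers gain an `_event_` prefix, and several subpackages (`optim`, `functional`, `surrogate`, among others) disappear from the wheel. The snippet below is a minimal compatibility sketch inferred only from these file moves; the exact names re-exported by 0.2.0 are an assumption to verify against the release notes.

# Hypothetical import shim based only on the file renames listed above;
# the 0.2.0 re-exports are assumptions, not confirmed API.
try:
    from brainstate import transform            # 0.2.0: augment/compile merged here
    from brainstate.nn import init               # 0.2.0: init lives under nn
except ImportError:                              # fall back to the 0.1.10 layout
    from brainstate import compile as transform
    from brainstate import init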
brainstate/nn/_normalizations.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2024
+# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,16 +17,18 @@
 
 from typing import Callable, Union, Sequence, Optional, Any
 
+import brainunit as u
 import jax
 import jax.numpy as jnp
 
-
-from brainstate import environ, init
+from brainstate import environ
 from brainstate._state import ParamState, BatchState
 from brainstate.typing import DTypeLike, ArrayLike, Size, Axes
+from . import init as init
 from ._module import Module
 
 __all__ = [
+    'weight_standardization',
     'BatchNorm0d',
     'BatchNorm1d',
     'BatchNorm2d',
@@ -44,24 +46,47 @@ def weight_standardization(
     out_axis: int = -1,
 ) -> Union[jax.Array, u.Quantity]:
     """
-    Scaled Weight Standardization
-
+    Scaled Weight Standardization.
+
+    Applies weight standardization to improve training stability, as described in
+    "Micro-Batch Training with Batch-Channel Normalization and Weight Standardization" [1]_.
 
     Parameters
     ----------
     w : ArrayLike
-        The weight tensor.
-    eps : float
-        A small value to avoid division by zero.
-    gain : Array
-
-    out_axis : int
-        The output axis
+        The weight tensor to be standardized.
+    eps : float, optional
+        A small value added to variance to avoid division by zero. Default is 1e-4.
+    gain : jax.Array, optional
+        Optional gain parameter to scale the standardized weights. Default is None.
+    out_axis : int, optional
+        The output axis of the weight tensor. Default is -1.
 
     Returns
     -------
-
-    The
+    jax.Array or u.Quantity
+        The standardized weight tensor with the same shape as input.
+
+    References
+    ----------
+    .. [1] Qiao, S., Wang, H., Liu, C., Shen, W., & Yuille, A. (2019).
+        Micro-Batch Training with Batch-Channel Normalization and Weight Standardization.
+        arXiv preprint arXiv:1903.10520.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate as bst
+        >>> import jax.numpy as jnp
+        >>>
+        >>> # Standardize a weight matrix
+        >>> w = jnp.ones((3, 4))
+        >>> w_std = bst.nn.weight_standardization(w)
+        >>>
+        >>> # With custom gain
+        >>> gain = jnp.ones((4,))
+        >>> w_std = bst.nn.weight_standardization(w, gain=gain)
     """
     w = u.maybe_custom_array(w)
     if out_axis < 0:
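As a side note on the behaviour documented in this hunk: the standardization itself is a per-output-channel whitening of the weights plus an optional gain. The following is a minimal stand-alone sketch of that computation (not the brainstate implementation, which may also include fan-in scaling); `eps`, `gain`, and `out_axis` follow the docstring above.

import jax.numpy as jnp

def weight_standardization_sketch(w, eps=1e-4, gain=None, out_axis=-1):
    # Standardize over every axis except the output-channel axis.
    axes = tuple(i for i in range(w.ndim) if i != out_axis % w.ndim)
    mean = jnp.mean(w, axis=axes, keepdims=True)
    var = jnp.var(w, axis=axes, keepdims=True)
    w = (w - mean) / jnp.sqrt(var + eps)
    return w if gain is None else w * gain

w = jnp.ones((3, 4))
print(weight_standardization_sketch(w).shape)  # (3, 4)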
@@ -92,30 +117,53 @@ def weight_standardization(
     return w * scale - shift
 
 
-
 def canonicalize_dtype(
     *args,
     dtype: jax.typing.DTypeLike | None = None,
     inexact: bool = True
 ) -> jax.typing.DTypeLike:
-    """
-
-
-
-    is
-
-
-
-
-
-
-
-
-
-
-
-
-
+    """
+    Canonicalize an optional dtype to the definitive dtype.
+
+    If the ``dtype`` is None, this function will infer the dtype from the input
+    arguments using ``jnp.result_type``. If it is not None, it will be returned
+    unmodified or an exception is raised if the dtype is invalid.
+
+    Parameters
+    ----------
+    *args : ArrayLike
+        JAX array compatible values. None values are ignored.
+    dtype : jax.typing.DTypeLike, optional
+        Optional dtype override. If specified, the arguments are cast to the
+        specified dtype and dtype inference is disabled. Default is None.
+    inexact : bool, optional
+        When True, the output dtype must be a subtype of ``jnp.inexact``.
+        Inexact dtypes are real or complex floating points. This is useful
+        when applying operations that don't work directly on integers like
+        taking a mean. Default is True.
+
+    Returns
+    -------
+    jax.typing.DTypeLike
+        The dtype that ``*args`` should be cast to.
+
+    Raises
+    ------
+    ValueError
+        If ``inexact=True`` and the resulting dtype is not an inexact type.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import jax.numpy as jnp
+        >>>
+        >>> # Infer dtype from arguments
+        >>> x = jnp.array([1, 2, 3])
+        >>> dtype = canonicalize_dtype(x)
+        >>>
+        >>> # Specify explicit dtype
+        >>> dtype = canonicalize_dtype(x, dtype=jnp.float64)
     """
     if dtype is None:
         args_filtered = [jnp.asarray(x) for x in args if x is not None]
@@ -169,37 +217,47 @@ def _compute_stats(
     mask: Optional[jax.Array] = None,
 ):
     """
-
+    Compute mean and variance statistics for normalization.
+
+    This implementation includes several optimizations:
 
-    This implementation takes care of a few important details:
     - Computes in float32 precision for stability in half precision training.
-    - If ``use_fast_variance`` is
-
-    - Clips negative variances to zero
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    - If ``use_fast_variance`` is True, uses the formula Var = E[|x|^2] - |E[x]|^2
+      instead of Var = E[|x - E[x]|^2] in a single XLA fusion.
+    - Clips negative variances to zero to avoid downstream NaNs from roundoff errors.
+    - Supports averaging across parallel axes and subgroups with a single
+      ``lax.pmean`` call to reduce latency.
+
+    Parameters
+    ----------
+    x : ArrayLike
+        Input array.
+    axes : Sequence[int]
+        The axes in ``x`` to compute mean and variance statistics for.
+    dtype : DTypeLike
+        Optional dtype specifying the minimal precision. Statistics are always
+        at least float32 for stability. If None, uses the dtype of x.
+    axis_name : str, optional
+        Optional name for the pmapped axis to compute mean over. Only used for
+        pmap and shard map. For SPMD jit, axes should be correctly annotated
+        and XLA:SPMD will insert necessary collectives. Default is None.
+    axis_index_groups : Sequence[int], optional
+        Optional axis indices for grouped reductions. Default is None.
+    use_mean : bool, optional
+        If True, calculate the mean from the input and use it when computing
+        the variance. If False, set the mean to zero and compute the variance
+        without subtracting the mean. Default is True.
+    use_fast_variance : bool, optional
+        If True, use a faster but less numerically stable calculation for the
+        variance. Default is True.
+    mask : jax.Array, optional
+        Binary array of shape broadcastable to ``x``, indicating the positions
+        for which the mean and variance should be computed. Default is None.
+
+    Returns
+    -------
+    tuple of jax.Array
+        A pair ``(mean, var)`` containing the computed mean and variance.
     """
     if dtype is None:
         dtype = jax.numpy.result_type(x)
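The fast-variance optimization described in this hunk is the classic single-pass identity Var[x] = E[x^2] - E[x]^2, which is cheaper but can turn slightly negative through round-off, hence the clipping. A brainstate-independent sketch:

import jax.numpy as jnp

def fast_mean_var(x, axes):
    # One fused pass: Var = E[x**2] - E[x]**2, clipped at zero to absorb round-off.
    mean = jnp.mean(x, axis=axes)
    mean2 = jnp.mean(jnp.square(x), axis=axes)
    var = jnp.maximum(0.0, mean2 - jnp.square(mean))
    return mean, var

x = jnp.arange(12.0).reshape(3, 4)
print(fast_mean_var(x, axes=(0,)))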
@@ -254,20 +312,32 @@ def _normalize(
     dtype: DTypeLike,
     epsilon: jax.typing.ArrayLike,
 ):
-    """
-
-
-
-
-
-
-
-
-
-
-
-
-
+    """
+    Normalize the input and optionally apply learned scale and bias.
+
+    Parameters
+    ----------
+    x : ArrayLike
+        The input array.
+    mean : ArrayLike, optional
+        Mean to use for normalization. If None, normalization is skipped.
+    var : ArrayLike, optional
+        Variance to use for normalization. If None, normalization is skipped.
+    weights : NormalizationParamState, optional
+        The scale and bias parameters. If None, no affine transformation is applied.
+    reduction_axes : Axes
+        The axes in ``x`` to reduce.
+    feature_axes : Axes
+        The feature axes to apply the scale and bias.
+    dtype : DTypeLike
+        The dtype of the result. If None, inferred from input and parameters.
+    epsilon : jax.typing.ArrayLike
+        A small value added to variance to avoid division by zero.
+
+    Returns
+    -------
+    jax.Array
+        The normalized input array.
     """
     if mean is not None:
         assert var is not None, 'mean and val must be both None or not None.'
@@ -413,169 +483,379 @@ class _BatchNorm(Module):
 
 
 class BatchNorm0d(_BatchNorm):
-
+    """
+    0-D batch normalization.
 
-
+    Normalizes a batch of 0-D data (vectors) by fixing the mean and variance
+    of inputs on each feature (channel). This layer aims to reduce the internal
+    covariate shift of data.
 
-
+    The input data should have shape ``(b, c)``, where ``b`` is the batch dimension
+    and ``c`` is the channel dimension.
+
+    The normalization is performed as:
+
+    .. math::
+        y = \\frac{x - \\mathrm{E}[x]}{\\sqrt{\\operatorname{Var}[x] + \\epsilon}} \\cdot \\gamma + \\beta
+
+    where :math:`\\gamma` and :math:`\\beta` are learnable affine parameters (if ``affine=True``).
+
+    Parameters
+    ----------
+    in_size : tuple of int
+        The input shape, without batch dimension.
+    feature_axis : int or tuple of int, optional
+        The feature or non-batch axis of the input. Default is -1.
+    track_running_stats : bool, optional
+        If True, tracks the running mean and variance. If False, uses batch
+        statistics in both training and eval modes. Default is True.
+    epsilon : float, optional
+        A value added to the denominator for numerical stability. Default is 1e-5.
+    momentum : float, optional
+        The momentum value used for the ``running_mean`` and ``running_var``
+        computation. The update rule is:
+        :math:`\\hat{x}_{\\text{new}} = \\text{momentum} \\times \\hat{x} + (1 - \\text{momentum}) \\times x_t`.
+        Default is 0.99.
+    affine : bool, optional
+        If True, this module has learnable affine parameters (scale and bias).
+        Default is True.
+    bias_initializer : ArrayLike or Callable, optional
+        Initializer for the bias (beta) parameter. Default is ``init.Constant(0.)``.
+    scale_initializer : ArrayLike or Callable, optional
+        Initializer for the scale (gamma) parameter. Default is ``init.Constant(1.)``.
+    axis_name : str or sequence of str, optional
+        The axis name(s) for parallel reduction using ``jax.pmap`` or ``jax.vmap``.
+        If specified, batch statistics are calculated across all replicas on the
+        named axes. Default is None.
+    axis_index_groups : sequence of sequence of int, optional
+        Groups of axis indices within the named axis representing subsets of
+        devices to reduce over. For example, ``[[0, 1], [2, 3]]`` would
+        independently batch-normalize over the first two and last two devices.
+        See ``jax.lax.psum`` for more details. Default is None.
+    use_fast_variance : bool, optional
+        If True, use a faster but less numerically stable calculation for
+        the variance. Default is True.
+
+    Notes
+    -----
+    The ``momentum`` parameter is different from the conventional notion of
+    momentum used in optimizers.
+
+    References
+    ----------
+    .. [1] Ioffe, S., & Szegedy, C. (2015). Batch Normalization: Accelerating
+        Deep Network Training by Reducing Internal Covariate Shift.
+        In International Conference on Machine Learning (pp. 448-456).
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate as bst
+        >>> import jax.numpy as jnp
+        >>>
+        >>> # Create a BatchNorm0d layer
+        >>> layer = bst.nn.BatchNorm0d(in_size=(10,))
+        >>>
+        >>> # Apply normalization to a batch of data
+        >>> x = jnp.ones((32, 10))  # batch_size=32, features=10
+        >>> y = layer(x)
+        >>>
+        >>> # Check output shape
+        >>> print(y.shape)
+        (32, 10)
     """
     __module__ = 'brainstate.nn'
     num_spatial_dims: int = 0
 
 
 class BatchNorm1d(_BatchNorm):
-
+    """
+    1-D batch normalization.
+
+    Normalizes a batch of 1-D data by fixing the mean and variance of inputs
+    on each feature (channel). This layer aims to reduce the internal covariate
+    shift of data.
 
-    The data should
-
+    The input data should have shape ``(b, l, c)``, where ``b`` is the batch
+    dimension, ``l`` is the spatial/sequence dimension, and ``c`` is the channel
+    dimension.
 
-
+    Parameters
+    ----------
+    in_size : tuple of int
+        The input shape, without batch dimension. For 1-D data, typically ``(l, c)``.
+    feature_axis : int or tuple of int, optional
+        The feature or non-batch axis of the input. Default is -1.
+    track_running_stats : bool, optional
+        If True, tracks the running mean and variance. If False, uses batch
+        statistics in both training and eval modes. Default is True.
+    epsilon : float, optional
+        A value added to the denominator for numerical stability. Default is 1e-5.
+    momentum : float, optional
+        The momentum value for running statistics computation. Default is 0.99.
+    affine : bool, optional
+        If True, has learnable affine parameters (scale and bias). Default is True.
+    bias_initializer : ArrayLike or Callable, optional
+        Initializer for the bias parameter. Default is ``init.Constant(0.)``.
+    scale_initializer : ArrayLike or Callable, optional
+        Initializer for the scale parameter. Default is ``init.Constant(1.)``.
+    axis_name : str or sequence of str, optional
+        Axis name(s) for parallel reduction. Default is None.
+    axis_index_groups : sequence of sequence of int, optional
+        Groups of axis indices for device-grouped reduction. Default is None.
+    use_fast_variance : bool, optional
+        If True, use faster but less stable variance calculation. Default is True.
+
+    References
+    ----------
+    .. [1] Ioffe, S., & Szegedy, C. (2015). Batch Normalization: Accelerating
+        Deep Network Training by Reducing Internal Covariate Shift.
+        In International Conference on Machine Learning (pp. 448-456).
+
+    See Also
+    --------
+    BatchNorm0d : 0-D batch normalization
+    BatchNorm2d : 2-D batch normalization
+    BatchNorm3d : 3-D batch normalization
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate as bst
+        >>> import jax.numpy as jnp
+        >>>
+        >>> # Create a BatchNorm1d layer for sequence data
+        >>> layer = bst.nn.BatchNorm1d(in_size=(100, 64))  # length=100, channels=64
+        >>>
+        >>> # Apply normalization
+        >>> x = jnp.ones((8, 100, 64))  # batch_size=8
+        >>> y = layer(x)
+        >>> print(y.shape)
+        (8, 100, 64)
     """
     __module__ = 'brainstate.nn'
     num_spatial_dims: int = 1
 
 
 class BatchNorm2d(_BatchNorm):
-
+    """
+    2-D batch normalization.
+
+    Normalizes a batch of 2-D data (e.g., images) by fixing the mean and variance
+    of inputs on each feature (channel). This layer aims to reduce the internal
+    covariate shift of data.
 
-    The data should
-
-    channel dimension.
+    The input data should have shape ``(b, h, w, c)``, where ``b`` is the batch
+    dimension, ``h`` is the height dimension, ``w`` is the width dimension, and
+    ``c`` is the channel dimension.
 
-
+    Parameters
+    ----------
+    in_size : tuple of int
+        The input shape, without batch dimension. For 2-D data, typically ``(h, w, c)``.
+    feature_axis : int or tuple of int, optional
+        The feature or non-batch axis of the input. Default is -1.
+    track_running_stats : bool, optional
+        If True, tracks the running mean and variance. If False, uses batch
+        statistics in both training and eval modes. Default is True.
+    epsilon : float, optional
+        A value added to the denominator for numerical stability. Default is 1e-5.
+    momentum : float, optional
+        The momentum value for running statistics computation. Default is 0.99.
+    affine : bool, optional
+        If True, has learnable affine parameters (scale and bias). Default is True.
+    bias_initializer : ArrayLike or Callable, optional
+        Initializer for the bias parameter. Default is ``init.Constant(0.)``.
+    scale_initializer : ArrayLike or Callable, optional
+        Initializer for the scale parameter. Default is ``init.Constant(1.)``.
+    axis_name : str or sequence of str, optional
+        Axis name(s) for parallel reduction. Default is None.
+    axis_index_groups : sequence of sequence of int, optional
+        Groups of axis indices for device-grouped reduction. Default is None.
+    use_fast_variance : bool, optional
+        If True, use faster but less stable variance calculation. Default is True.
+
+    References
+    ----------
+    .. [1] Ioffe, S., & Szegedy, C. (2015). Batch Normalization: Accelerating
+        Deep Network Training by Reducing Internal Covariate Shift.
+        In International Conference on Machine Learning (pp. 448-456).
+
+    See Also
+    --------
+    BatchNorm0d : 0-D batch normalization
+    BatchNorm1d : 1-D batch normalization
+    BatchNorm3d : 3-D batch normalization
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate as bst
+        >>> import jax.numpy as jnp
+        >>>
+        >>> # Create a BatchNorm2d layer for image data
+        >>> layer = bst.nn.BatchNorm2d(in_size=(28, 28, 3))  # 28x28 RGB images
+        >>>
+        >>> # Apply normalization
+        >>> x = jnp.ones((16, 28, 28, 3))  # batch_size=16
+        >>> y = layer(x)
+        >>> print(y.shape)
+        (16, 28, 28, 3)
     """
     __module__ = 'brainstate.nn'
     num_spatial_dims: int = 2
 
 
 class BatchNorm3d(_BatchNorm):
-
+    """
+    3-D batch normalization.
+
+    Normalizes a batch of 3-D data (e.g., video or volumetric data) by fixing
+    the mean and variance of inputs on each feature (channel). This layer aims
+    to reduce the internal covariate shift of data.
 
-    The data should
-
-    dimension, and
+    The input data should have shape ``(b, h, w, d, c)``, where ``b`` is the
+    batch dimension, ``h`` is the height dimension, ``w`` is the width dimension,
+    ``d`` is the depth dimension, and ``c`` is the channel dimension.
 
-
+    Parameters
+    ----------
+    in_size : tuple of int
+        The input shape, without batch dimension. For 3-D data, typically ``(h, w, d, c)``.
+    feature_axis : int or tuple of int, optional
+        The feature or non-batch axis of the input. Default is -1.
+    track_running_stats : bool, optional
+        If True, tracks the running mean and variance. If False, uses batch
+        statistics in both training and eval modes. Default is True.
+    epsilon : float, optional
+        A value added to the denominator for numerical stability. Default is 1e-5.
+    momentum : float, optional
+        The momentum value for running statistics computation. Default is 0.99.
+    affine : bool, optional
+        If True, has learnable affine parameters (scale and bias). Default is True.
+    bias_initializer : ArrayLike or Callable, optional
+        Initializer for the bias parameter. Default is ``init.Constant(0.)``.
+    scale_initializer : ArrayLike or Callable, optional
+        Initializer for the scale parameter. Default is ``init.Constant(1.)``.
+    axis_name : str or sequence of str, optional
+        Axis name(s) for parallel reduction. Default is None.
+    axis_index_groups : sequence of sequence of int, optional
+        Groups of axis indices for device-grouped reduction. Default is None.
+    use_fast_variance : bool, optional
+        If True, use faster but less stable variance calculation. Default is True.
+
+    References
+    ----------
+    .. [1] Ioffe, S., & Szegedy, C. (2015). Batch Normalization: Accelerating
+        Deep Network Training by Reducing Internal Covariate Shift.
+        In International Conference on Machine Learning (pp. 448-456).
+
+    See Also
+    --------
+    BatchNorm0d : 0-D batch normalization
+    BatchNorm1d : 1-D batch normalization
+    BatchNorm2d : 2-D batch normalization
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate as bst
+        >>> import jax.numpy as jnp
+        >>>
+        >>> # Create a BatchNorm3d layer for volumetric data
+        >>> layer = bst.nn.BatchNorm3d(in_size=(32, 32, 32, 1))  # 32x32x32 volumes
+        >>>
+        >>> # Apply normalization
+        >>> x = jnp.ones((4, 32, 32, 32, 1))  # batch_size=4
+        >>> y = layer(x)
+        >>> print(y.shape)
+        (4, 32, 32, 32, 1)
     """
     __module__ = 'brainstate.nn'
     num_spatial_dims: int = 3
 
 
|
464
|
-
_bn_doc = r'''
|
465
|
-
|
466
|
-
This layer aims to reduce the internal covariant shift of data. It
|
467
|
-
normalizes a batch of data by fixing the mean and variance of inputs
|
468
|
-
on each feature (channel). Most commonly, the first axis of the data
|
469
|
-
is the batch, and the last is the channel. However, users can specify
|
470
|
-
the axes to be normalized.
|
471
|
-
|
472
|
-
.. math::
|
473
|
-
y=\frac{x-\mathrm{E}[x]}{\sqrt{\operatorname{Var}[x]+\epsilon}} * \gamma+\beta
|
474
|
-
|
475
|
-
.. note::
|
476
|
-
This :attr:`momentum` argument is different from one used in optimizer
|
477
|
-
classes and the conventional notion of momentum. Mathematically, the
|
478
|
-
update rule for running statistics here is
|
479
|
-
:math:`\hat{x}_\text{new} = \text{momentum} \times \hat{x} + (1-\text{momentum}) \times x_t`,
|
480
|
-
where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
|
481
|
-
new observed value.
|
482
|
-
|
483
|
-
Parameters
|
484
|
-
----------
|
485
|
-
in_size: sequence of int
|
486
|
-
The input shape, without batch size.
|
487
|
-
feature_axis: int, tuple, list
|
488
|
-
The feature or non-batch axis of the input.
|
489
|
-
track_running_stats: bool
|
490
|
-
A boolean value that when set to ``True``, this module tracks the running mean and variance,
|
491
|
-
and when set to ``False``, this module does not track such statistics, and initializes
|
492
|
-
statistics buffers ``running_mean`` and ``running_var`` as ``None``. When these buffers are ``None``,
|
493
|
-
this module always uses batch statistics. in both training and eval modes. Default: ``True``.
|
494
|
-
momentum: float
|
495
|
-
The value used for the ``running_mean`` and ``running_var`` computation. Default: 0.99
|
496
|
-
epsilon: float
|
497
|
-
A value added to the denominator for numerical stability. Default: 1e-5
|
498
|
-
affine: bool
|
499
|
-
A boolean value that when set to ``True``, this module has
|
500
|
-
learnable affine parameters. Default: ``True``
|
501
|
-
bias_initializer: ArrayLike, Callable
|
502
|
-
An initializer generating the original translation matrix. If not ``None``, bias (beta) is added.
|
503
|
-
Default: ``init.Constant(0.)``
|
504
|
-
scale_initializer: ArrayLike, Callable
|
505
|
-
An initializer generating the original scaling matrix. If not ``None``, multiply by scale (gamma).
|
506
|
-
Default: ``init.Constant(1.)``
|
507
|
-
axis_name: optional, str, sequence of str
|
508
|
-
If not ``None``, it should be a string (or sequence of
|
509
|
-
strings) representing the axis name(s) over which this module is being
|
510
|
-
run within a jax map (e.g. ``jax.pmap`` or ``jax.vmap``). Supplying this
|
511
|
-
argument means that batch statistics are calculated across all replicas
|
512
|
-
on the named axes.
|
513
|
-
axis_index_groups: optional, sequence
|
514
|
-
Specifies how devices are grouped. Valid
|
515
|
-
only within ``jax.pmap`` collectives.
|
516
|
-
Groups of axis indices within that named axis
|
517
|
-
representing subsets of devices to reduce over (default: None). For
|
518
|
-
example, `[[0, 1], [2, 3]]` would independently batch-normalize over
|
519
|
-
the examples on the first two and last two devices. See `jax.lax.psum`
|
520
|
-
for more details.
|
521
|
-
use_fast_variance: If true, use a faster, but less numerically stable,
|
522
|
-
calculation for the variance.
|
523
|
-
|
524
|
-
|
525
|
-
References
|
526
|
-
----------
|
527
|
-
.. [1] Ioffe, Sergey and Christian Szegedy. “Batch Normalization: Accelerating Deep Network Training
|
528
|
-
by Reducing Internal Covariate Shift.” ArXiv abs/1502.03167 (2015): n. pag.
|
529
|
-
|
530
|
-
'''
|
531
|
-
|
532
|
-
BatchNorm1d.__doc__ = BatchNorm1d.__doc__ % _bn_doc
|
533
|
-
BatchNorm2d.__doc__ = BatchNorm2d.__doc__ % _bn_doc
|
534
|
-
BatchNorm3d.__doc__ = BatchNorm3d.__doc__ % _bn_doc
|
535
|
-
|
536
|
-
|
537
782
|
class LayerNorm(Module):
|
     """
-    Layer normalization
+    Layer normalization layer [1]_.
 
-    LayerNorm normalizes the activations of the layer for each given example in
-    batch independently, rather than across a batch like Batch Normalization.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    LayerNorm normalizes the activations of the layer for each given example in
+    a batch independently, rather than across a batch like Batch Normalization.
+    It applies a transformation that maintains the mean activation within each
+    example close to 0 and the activation standard deviation close to 1.
+
+    Parameters
+    ----------
+    in_size : tuple of int
+        The input shape, without batch dimension.
+    reduction_axes : int or tuple of int, optional
+        Axes for computing normalization statistics. It is recommended to use
+        negative integers, as positive integers may cause issues when batch
+        dimensions are present. Default is -1.
+    feature_axes : int or tuple of int, optional
+        Feature axes for learned bias and scaling. Default is -1.
+    epsilon : float, optional
+        A small value added to variance to avoid division by zero. Default is 1e-6.
+    use_bias : bool, optional
+        If True, bias (beta) is added. Default is True.
+    use_scale : bool, optional
+        If True, multiply by scale (gamma). When the next layer is linear
+        (e.g., nn.relu), this can be disabled since scaling will be done by
+        the next layer. Default is True.
+    bias_init : Callable, optional
+        Initializer for bias parameter. Default is ``init.ZeroInit()``.
+    scale_init : Callable, optional
+        Initializer for scale parameter. Default is ``init.Constant(1.0)``.
+    axis_name : str, optional
+        The axis name used to combine batch statistics from multiple devices.
+        See ``jax.pmap`` for axis name description. Only needed if the model
+        is subdivided across devices. Default is None.
+    axis_index_groups : sequence, optional
+        Groups of axis indices within the named axis representing subsets of
+        devices to reduce over. For example, ``[[0, 1], [2, 3]]`` would
+        independently normalize over the first two and last two devices.
+        See ``jax.lax.psum`` for details. Default is None.
+    use_fast_variance : bool, optional
+        If True, use a faster but less numerically stable calculation for
+        the variance. Default is True.
+    dtype : jax.typing.DTypeLike, optional
+        The dtype of the result. If None, inferred from input and parameters.
+        Default is None.
+
+    References
+    ----------
+    .. [1] Ba, J. L., Kiros, J. R., & Hinton, G. E. (2016). Layer normalization.
+        arXiv preprint arXiv:1607.06450.
+
+    See Also
+    --------
+    RMSNorm : Root Mean Square Layer Normalization
+    GroupNorm : Group Normalization
+    BatchNorm1d : 1-D Batch Normalization
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate as bst
+        >>>
+        >>> # Create a LayerNorm layer
+        >>> x = bst.random.normal(size=(3, 4, 5, 6))
+        >>> layer = bst.nn.LayerNorm(x.shape)
+        >>>
+        >>> # Apply normalization
+        >>> y = layer(x)
+        >>> print(y.shape)
+        (3, 4, 5, 6)
+        >>>
+        >>> # Normalize only the last dimension
+        >>> layer = bst.nn.LayerNorm((10, 20), reduction_axes=-1, feature_axes=-1)
+        >>> x = bst.random.normal((5, 10, 20))
+        >>> y = layer(x)
     """
 
     def __init__(
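One detail worth isolating from the new BatchNorm docstrings is the momentum convention, which differs from optimizer momentum: the running statistic is updated as momentum times the old estimate plus (1 - momentum) times the batch statistic. A standalone sketch of just that rule (not the brainstate code):

import jax.numpy as jnp

def update_running_stat(running, batch_stat, momentum=0.99):
    # x_hat_new = momentum * x_hat + (1 - momentum) * x_t
    return momentum * running + (1.0 - momentum) * batch_stat

running_mean = jnp.zeros(10)
batch_mean = jnp.ones(10)
running_mean = update_running_stat(running_mean, batch_mean)
print(running_mean[0])  # 0.01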
@@ -631,13 +911,21 @@ class LayerNorm(Module):
         self.use_fast_variance = use_fast_variance
 
     def update(self, x, *, mask: Optional[jax.Array] = None):
-        """
-
-
-
-
-
-
+        """
+        Apply layer normalization on the input.
+
+        Parameters
+        ----------
+        x : jax.Array
+            The input array.
+        mask : jax.Array, optional
+            Binary array of shape broadcastable to ``x``, indicating the
+            positions for which normalization should be computed. Default is None.
+
+        Returns
+        -------
+        jax.Array
+            Normalized inputs with the same shape as the input.
         """
         mean, var = _compute_stats(
             x,
@@ -663,45 +951,75 @@ class LayerNorm(Module):
 
 class RMSNorm(Module):
     """
-
+    Root Mean Square Layer Normalization [1]_.
 
     RMSNorm normalizes the activations of the layer for each given example in a
     batch independently, rather than across a batch like Batch Normalization.
-    Unlike LayerNorm which re-centers the mean to
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    Unlike LayerNorm which re-centers the mean to 0 and normalizes by the standard
+    deviation, RMSNorm does not re-center at all and instead normalizes by the
+    root mean square of the activations.
+
+    Parameters
+    ----------
+    in_size : tuple of int
+        The input shape, without batch dimension.
+    epsilon : float, optional
+        A small value added to variance to avoid division by zero. Default is 1e-6.
+    dtype : jax.typing.DTypeLike, optional
+        The dtype of the result. If None, inferred from input and parameters.
+        Default is None.
+    use_scale : bool, optional
+        If True, multiply by scale (gamma). When the next layer is linear
+        (e.g., nn.relu), this can be disabled since scaling will be done by
+        the next layer. Default is True.
+    scale_init : Callable, optional
+        Initializer for scale parameter. Default is ``init.Constant(1.0)``.
+    reduction_axes : int or tuple of int, optional
+        Axes for computing normalization statistics. It is recommended to use
+        negative integers. Default is -1.
+    feature_axes : int or tuple of int, optional
+        Feature axes for learned scaling. Default is -1.
+    axis_name : str, optional
+        The axis name used to combine batch statistics from multiple devices.
+        See ``jax.pmap`` for details. Default is None.
+    axis_index_groups : sequence, optional
+        Groups of axis indices within the named axis representing subsets of
+        devices to reduce over. For example, ``[[0, 1], [2, 3]]`` would
+        independently normalize over the first two and last two devices.
+        Default is None.
+    use_fast_variance : bool, optional
+        If True, use a faster but less numerically stable calculation for
+        the variance. Default is True.
+
+    References
+    ----------
+    .. [1] Zhang, B., & Sennrich, R. (2019). Root Mean Square Layer Normalization.
+        Advances in Neural Information Processing Systems, 32.
+
+    See Also
+    --------
+    LayerNorm : Layer Normalization
+    GroupNorm : Group Normalization
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate as bst
+        >>>
+        >>> # Create an RMSNorm layer
+        >>> x = bst.random.normal(size=(5, 6))
+        >>> layer = bst.nn.RMSNorm(in_size=(6,))
+        >>>
+        >>> # Apply normalization
+        >>> y = layer(x)
+        >>> print(y.shape)
+        (5, 6)
+        >>>
+        >>> # Without scaling
+        >>> layer = bst.nn.RMSNorm(in_size=(10,), use_scale=False)
+        >>> x = bst.random.normal((3, 10))
+        >>> y = layer(x)
     """
 
     def __init__(
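The contrast with LayerNorm drawn above comes down to skipping the mean subtraction and dividing by the root mean square. A minimal sketch under the same epsilon and scale conventions as the docstring (not the brainstate implementation):

import jax.numpy as jnp

def rms_norm_sketch(x, scale=None, eps=1e-6, axis=-1):
    # Normalize by the root mean square; no mean subtraction.
    ms = jnp.mean(jnp.square(x), axis=axis, keepdims=True)
    y = x / jnp.sqrt(ms + eps)
    return y if scale is None else y * scale

x = jnp.ones((5, 6))
print(rms_norm_sketch(x).shape)  # (5, 6)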
@@ -747,14 +1065,21 @@ class RMSNorm(Module):
         self.use_fast_variance = use_fast_variance
 
     def update(self, x, *, mask: Optional[jax.Array] = None):
-        """
-
-
-
-
-
-
-
+        """
+        Apply RMS normalization on the input.
+
+        Parameters
+        ----------
+        x : jax.Array
+            The input array.
+        mask : jax.Array, optional
+            Binary array of shape broadcastable to ``x``, indicating the
+            positions for which normalization should be computed. Default is None.
+
+        Returns
+        -------
+        jax.Array
+            Normalized inputs with the same shape as the input.
         """
         mean, var = _compute_stats(
             x,
@@ -781,65 +1106,93 @@ class RMSNorm(Module):
 
 class GroupNorm(Module):
     """
-    Group
-
-
-    equally-sized groups of channels and not shared across
-    Thus, group normalization does not depend on the batch
-    not require maintaining internal state for storing statistics.
-
-    number of
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    devices
-
-
-
-
-
-
-
-
-
-
-
-
+    Group Normalization layer [1]_.
+
+    Group normalization is similar to batch normalization, but statistics are
+    shared across equally-sized groups of channels and not shared across the
+    batch dimension. Thus, group normalization does not depend on the batch
+    composition and does not require maintaining internal state for storing statistics.
+
+    The user should specify either the total number of channel groups (``num_groups``)
+    or the number of channels per group (``group_size``).
+
+    Parameters
+    ----------
+    in_size : tuple of int
+        The input shape, without batch dimension.
+    feature_axis : int or tuple of int, optional
+        The feature axis of the input. Default is -1.
+    num_groups : int, optional
+        The total number of channel groups. The default value of 32 is proposed
+        by the original group normalization paper. Either ``num_groups`` or
+        ``group_size`` must be specified, but not both. Default is 32.
+    group_size : int, optional
+        The number of channels in each group. Either ``num_groups`` or
+        ``group_size`` must be specified, but not both. Default is None.
+    epsilon : float, optional
+        A small value added to variance to avoid division by zero. Default is 1e-6.
+    dtype : jax.typing.DTypeLike, optional
+        The dtype of the result. If None, inferred from input and parameters.
+        Default is None.
+    use_bias : bool, optional
+        If True, bias (beta) is added. Default is True.
+    use_scale : bool, optional
+        If True, multiply by scale (gamma). When the next layer is linear
+        (e.g., nn.relu), this can be disabled. Default is True.
+    bias_init : Callable, optional
+        Initializer for bias parameter. Default is ``init.ZeroInit()``.
+    scale_init : Callable, optional
+        Initializer for scale parameter. Default is ``init.Constant(1.)``.
+    reduction_axes : int or tuple of int, optional
+        List of axes used for computing normalization statistics. Must include
+        the final dimension (feature axis). It is recommended to use negative
+        integers. Default is None.
+    axis_name : str, optional
+        The axis name used to combine batch statistics from multiple devices.
+        See ``jax.pmap`` for details. Default is None.
+    axis_index_groups : sequence, optional
+        Groups of axis indices within the named axis representing subsets of
+        devices to reduce over. For example, ``[[0, 1], [2, 3]]`` would
+        independently normalize over the first two and last two devices.
+        Default is None.
+    use_fast_variance : bool, optional
+        If True, use a faster but less numerically stable calculation for
+        the variance. Default is True.
+
+    Notes
+    -----
+    LayerNorm is a special case of GroupNorm where ``num_groups=1``.
+
+    References
+    ----------
+    .. [1] Wu, Y., & He, K. (2018). Group Normalization.
+        In Proceedings of the European Conference on Computer Vision (ECCV)
+        (pp. 3-19).
+
+    See Also
+    --------
+    LayerNorm : Layer Normalization
+    BatchNorm2d : 2-D Batch Normalization
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import numpy as np
+        >>> import brainstate as bst
+        >>>
+        >>> # Create a GroupNorm layer with 3 groups
+        >>> x = bst.random.normal(size=(3, 4, 5, 6))
+        >>> layer = bst.nn.GroupNorm(x.shape, num_groups=3)
+        >>> y = layer(x)
+        >>>
+        >>> # GroupNorm with num_groups=1 is equivalent to LayerNorm
+        >>> y1 = bst.nn.GroupNorm(x.shape, num_groups=1)(x)
+        >>> y2 = bst.nn.LayerNorm(x.shape, reduction_axes=(1, 2, 3))(x)
+        >>> np.testing.assert_allclose(y1, y2, rtol=1e-5)
+        >>>
+        >>> # Specify group_size instead of num_groups
+        >>> layer = bst.nn.GroupNorm((12,), num_groups=None, group_size=4)
     """
 
     def __init__(
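The grouping described above can be pictured as a reshape: split the channel axis into `num_groups` groups and normalize within each (example, group) pair. A simplified sketch assuming a trailing channel axis that divides evenly into the groups (not the brainstate implementation):

import jax.numpy as jnp

def group_norm_sketch(x, num_groups, eps=1e-6):
    # x: (batch, ..., channels); channels must divide evenly into num_groups.
    c = x.shape[-1]
    g = x.reshape(x.shape[0], -1, num_groups, c // num_groups)
    mean = jnp.mean(g, axis=(1, 3), keepdims=True)
    var = jnp.var(g, axis=(1, 3), keepdims=True)
    g = (g - mean) / jnp.sqrt(var + eps)
    return g.reshape(x.shape)

x = jnp.ones((2, 4, 5, 6))
print(group_norm_sketch(x, num_groups=3).shape)  # (2, 4, 5, 6)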
@@ -927,20 +1280,26 @@ class GroupNorm(Module):
         self.use_fast_variance = use_fast_variance
 
     def update(self, x, *, mask: Optional[jax.Array] = None):
-        """
-
-
-
-
-
-
-
-
-
-
-
-
-
+        """
+        Apply group normalization to the input.
+
+        Parameters
+        ----------
+        x : jax.Array
+            The input of shape ``...C`` where ``C`` is the channels dimension
+            and ``...`` represents an arbitrary number of extra dimensions. If no
+            reduction axes have been specified, all additional dimensions will be
+            used to accumulate statistics apart from the leading dimension which
+            is assumed to represent the batch.
+        mask : jax.Array, optional
+            Binary array of shape broadcastable to ``x``, indicating the
+            positions for which the mean and variance should be computed.
+            Default is None.
+
+        Returns
+        -------
+        jax.Array
+            Normalized inputs with the same shape as the input.
         """
         if self.reduction_axes is not None:
             reduction_axes = self.reduction_axes