brainstate 0.1.0.post20241122__py2.py3-none-any.whl → 0.1.0.post20241129__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_state.py +1 -1
- brainstate/augment/_autograd.py +112 -114
- brainstate/augment/_autograd_test.py +97 -0
- brainstate/event/__init__.py +6 -6
- brainstate/event/_csr_mv_benchmark.py +14 -0
- brainstate/event/{_linear_test.py → _linear_mv_test.py} +1 -1
- brainstate/event/_xla_custom_op.py +5 -8
- brainstate/nn/_elementwise/_dropout.py +4 -4
- brainstate/nn/_interaction/_linear.py +2 -2
- brainstate/nn/_interaction/_normalizations.py +577 -55
- brainstate/optim/_optax_optimizer.py +1 -0
- {brainstate-0.1.0.post20241122.dist-info → brainstate-0.1.0.post20241129.dist-info}/METADATA +2 -2
- {brainstate-0.1.0.post20241122.dist-info → brainstate-0.1.0.post20241129.dist-info}/RECORD +23 -22
- /brainstate/event/{_csr.py → _csr_mv.py} +0 -0
- /brainstate/event/{_csr_test.py → _csr_mv_test.py} +0 -0
- /brainstate/event/{_fixed_probability.py → _fixedprob_mv.py} +0 -0
- /brainstate/event/{_fixed_probability_benchmark.py → _fixedprob_mv_benchmark.py} +0 -0
- /brainstate/event/{_fixed_probability_test.py → _fixedprob_mv_test.py} +0 -0
- /brainstate/event/{_linear.py → _linear_mv.py} +0 -0
- /brainstate/event/{_linear_benckmark.py → _linear_mv_benckmark.py} +0 -0
- {brainstate-0.1.0.post20241122.dist-info → brainstate-0.1.0.post20241129.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20241122.dist-info → brainstate-0.1.0.post20241129.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20241122.dist-info → brainstate-0.1.0.post20241129.dist-info}/top_level.txt +0 -0
brainstate/nn/_interaction/_normalizations.py

@@ -17,8 +17,7 @@
 
 from __future__ import annotations
 
-import
-from typing import Callable, Union, Sequence, Optional, Any
+from typing import Callable, Union, Sequence, Optional, Any, Dict
 
 import jax
 import jax.numpy as jnp
@@ -29,10 +28,50 @@ from brainstate.nn._module import Module
 from brainstate.typing import DTypeLike, ArrayLike, Size, Axes
 
 __all__ = [
-    'BatchNorm0d',
+    'BatchNorm0d',
+    'BatchNorm1d',
+    'BatchNorm2d',
+    'BatchNorm3d',
+    'LayerNorm',
+    'RMSNorm',
+    'GroupNorm',
 ]
 
 
+def canonicalize_dtype(
+    *args,
+    dtype: jax.typing.DTypeLike | None = None,
+    inexact: bool = True
+) -> jax.typing.DTypeLike:
+    """Canonicalize an optional dtype to the definitive dtype.
+
+    If the ``dtype`` is None this function will infer the dtype. If it is not
+    None it will be returned unmodified or an exceptions is raised if the dtype
+    is invalid.
+    from the input arguments using ``jnp.result_type``.
+
+    Args:
+      *args: JAX array compatible values. None values
+        are ignored.
+      dtype: Optional dtype override. If specified the arguments are cast to
+        the specified dtype instead and dtype inference is disabled.
+      inexact: When True, the output dtype must be a subdtype
+        of `jnp.inexact`. Inexact dtypes are real or complex floating points. This
+        is useful when you want to apply operations that don't work directly on
+        integers like taking a mean for example.
+    Returns:
+      The dtype that *args should be cast to.
+    """
+    if dtype is None:
+        args_filtered = [jnp.asarray(x) for x in args if x is not None]
+        dtype = jnp.result_type(*args_filtered)
+        if inexact and not jnp.issubdtype(dtype, jnp.inexact):
+            dtype = jnp.promote_types(jnp.float32, dtype)
+    if inexact and not jnp.issubdtype(dtype, jnp.inexact):
+        raise ValueError(f'Dtype must be inexact: {dtype}')
+    return dtype
+
+
 def _canonicalize_axes(ndim: int, feature_axes: Sequence[int]):
     axes = []
     for axis in feature_axes:
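The `canonicalize_dtype` helper added above mirrors the Flax-style promotion rule: infer a result type from the arguments and, when an inexact type is required, promote non-float types against float32. A minimal sketch of that rule with plain `jax.numpy` (the array values here are illustrative, not part of the package):

```python
import jax.numpy as jnp

# An integer argument with dtype=None: the inferred type (int32) is not a
# subdtype of jnp.inexact, so it is promoted against float32.
x = jnp.arange(4)                                 # int32
inferred = jnp.result_type(x)                     # int32
promoted = jnp.promote_types(jnp.float32, inferred)
print(inferred, promoted)                         # int32 float32
```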
@@ -59,28 +98,38 @@ def _compute_stats(
     axis_name: Optional[str] = None,
     axis_index_groups: Optional[Sequence[int]] = None,
     use_mean: bool = True,
+    use_fast_variance: bool = True,
+    mask: Optional[jax.Array] = None,
 ):
-    """
+    """
+    Computes mean and variance statistics.
 
     This implementation takes care of a few important details:
     - Computes in float32 precision for stability in half precision training.
-    - mean and variance are
-
+    - If ``use_fast_variance`` is ``True``, mean and variance are computed using
+      Var = E[|x|^2] - |E[x]|^2, instead of Var = E[|x - E[x]|^2]), in a single XLA fusion.
     - Clips negative variances to zero which can happen due to
       roundoff errors. This avoids downstream NaNs.
     - Supports averaging across a parallel axis and subgroups of a parallel axis
       with a single `lax.pmean` call to avoid latency.
 
     Arguments:
-
-
-
-
-
-
-
-
-
+      x: Input array.
+      axes: The axes in ``x`` to compute mean and variance statistics for.
+      dtype: tp.Optional dtype specifying the minimal precision. Statistics
+        are always at least float32 for stability (default: dtype of x).
+      axis_name: Optional name for the pmapped axis to compute mean over. Note,
+        this is only used for pmap and shard map. For SPMD jit, you do not need to
+        manually synchronize. Just make sure that the axes are correctly annotated
+        and XLA:SPMD will insert the necessary collectives.
+      axis_index_groups: Optional axis indices.
+      use_mean: If true, calculate the mean from the input and use it when
+        computing the variance. If false, set the mean to zero and compute
+        the variance without subtracting the mean.
+      use_fast_variance: If true, use a faster, but less numerically stable,
+        calculation for the variance.
+      mask: Binary array of shape broadcastable to ``inputs`` tensor, indicating
+        the positions for which the mean and variance should be computed.
 
     Returns:
       A pair ``(mean, val)``.
@@ -89,32 +138,38 @@ def _compute_stats(
     dtype = jax.numpy.result_type(x)
     # promote x to at least float32, this avoids half precision computation
     # but preserves double or complex floating points
-    dtype = jax.numpy.promote_types(dtype,
+    dtype = jax.numpy.promote_types(dtype, jnp.float32)
     x = jnp.asarray(x, dtype)
+    axes = _canonicalize_axes(x.ndim, axes)
+
+    def maybe_distributed_mean(*xs, mask=None):
+        mus = tuple(x.mean(axes, where=mask) for x in xs)
+        if axis_name is None:
+            return mus if len(xs) > 1 else mus[0]
+        else:
+            # In the distributed case we stack multiple arrays to speed comms.
+            if len(xs) > 1:
+                reduced_mus = jax.lax.pmean(
+                    jnp.stack(mus, axis=0), axis_name,
+                    axis_index_groups=axis_index_groups,
+                )
+                return tuple(reduced_mus[i] for i in range(len(xs)))
+            else:
+                return jax.lax.pmean(mus[0], axis_name, axis_index_groups=axis_index_groups)
 
-    # Compute mean and mean of squared values.
-    mean2 = jnp.mean(_abs_sq(x), axes)
     if use_mean:
-
+        if use_fast_variance:
+            mu, mu2 = maybe_distributed_mean(x, _abs_sq(x), mask=mask)
+            # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due
+            # to floating point round-off errors.
+            var = jnp.maximum(0.0, mu2 - _abs_sq(mu))
+        else:
+            mu = maybe_distributed_mean(x, mask=mask)
+            var = maybe_distributed_mean(_abs_sq(x - jnp.expand_dims(mu, axes)), mask=mask)
     else:
-
-
-
-        if axis_name is not None:
-            concatenated_mean = jnp.concatenate([mean, mean2])
-            mean, mean2 = jnp.split(
-                jax.lax.pmean(
-                    concatenated_mean,
-                    axis_name=axis_name,
-                    axis_index_groups=axis_index_groups,
-                ),
-                2,
-            )
-
-    # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due
-    # to floating point round-off errors.
-    var = jnp.maximum(0.0, mean2 - _abs_sq(mean))
-    return mean, var
+        var = maybe_distributed_mean(_abs_sq(x), mask=mask)
+        mu = jnp.zeros_like(var)
+    return mu, var
 
 
 def _normalize(
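As the updated docstring notes, `use_fast_variance=True` uses the one-pass identity Var = E[|x|^2] - |E[x]|^2, clipped at zero, while `use_fast_variance=False` subtracts the mean first. A small numerical sketch of the two paths with plain `jax.numpy` (illustrative data; the package's private `_abs_sq` and `maybe_distributed_mean` helpers are not used here):

```python
import jax.numpy as jnp

x = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])
axes = (0,)
mu = x.mean(axes)

# Fast path: single pass, Var = E[x^2] - E[x]^2, clipped at zero to avoid
# small negative values caused by round-off.
var_fast = jnp.maximum(0.0, (x * x).mean(axes) - mu * mu)

# Slow path: two passes, Var = E[(x - E[x])^2].
var_slow = ((x - mu) ** 2).mean(axes)

print(var_fast, var_slow)   # both [2.25 2.25 2.25]
```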
@@ -122,9 +177,10 @@ def _normalize(
     mean: Optional[ArrayLike],
     var: Optional[ArrayLike],
     weights: Optional[ParamState],
-    reduction_axes:
+    reduction_axes: Axes,
+    feature_axes: Axes,
     dtype: DTypeLike,
-    epsilon:
+    epsilon: jax.typing.ArrayLike,
 ):
     """Normalizes the input of a normalization layer and optionally applies a learned scale and bias.
 
@@ -134,6 +190,7 @@ def _normalize(
       var: Variance to use for normalization.
       weights: The scale and bias parameters.
       reduction_axes: The axes in ``x`` to reduce.
+      feature_axes: The feature axes to apply the scale and bias.
       dtype: The dtype of the result (default: infer from input and params).
       epsilon: Normalization epsilon.
 
@@ -142,16 +199,23 @@ def _normalize(
     """
     if mean is not None:
         assert var is not None, 'mean and val must be both None or not None.'
+        reduction_axes = _canonicalize_axes(x.ndim, reduction_axes)
+        feature_axes = _canonicalize_axes(x.ndim, feature_axes)
         stats_shape = list(x.shape)
         for axis in reduction_axes:
             stats_shape[axis] = 1
         mean = mean.reshape(stats_shape)
         var = var.reshape(stats_shape)
+        feature_shape = [1] * x.ndim
+        for ax in feature_axes:
+            feature_shape[ax] = x.shape[ax]
         y = x - mean
-        mul = jax.lax.rsqrt(var +
+        mul = jax.lax.rsqrt(var + epsilon)
         y = y * mul
+        args = []
        if weights is not None:
-            y = _scale_operation(y, weights.value)
+            y, args = _scale_operation(y, weights.value)
+            dtype = canonicalize_dtype(x, *args, dtype=dtype)
     else:
         assert var is None, 'mean and val must be both None or not None.'
         assert weights is None, 'scale and bias are not supported without mean and val'
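The new `feature_axes` argument lets `_normalize` build a broadcast shape for the learned scale and bias that is 1 everywhere except on the feature axes, exactly as the added `feature_shape` loop does. An illustrative broadcast with hypothetical shapes (plain `jax.numpy`, not the brainstate API):

```python
import jax.numpy as jnp

x = jnp.ones((8, 16, 32))      # (batch, length, channels)
feature_axes = (2,)            # channels is the feature axis

feature_shape = [1] * x.ndim   # -> [1, 1, 32]
for ax in feature_axes:
    feature_shape[ax] = x.shape[ax]

scale = jnp.full(feature_shape, 2.0)   # broadcasts over batch and length
bias = jnp.zeros(feature_shape)
print((x * scale + bias).shape)        # (8, 16, 32)
```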
@@ -159,12 +223,15 @@ def _normalize(
     return jnp.asarray(y, dtype)
 
 
-def _scale_operation(x, param):
+def _scale_operation(x: jax.Array, param: Dict):
+    args = []
     if 'scale' in param:
         x = x * param['scale']
+        args.append(param['scale'])
     if 'bias' in param:
         x = x + param['bias']
-
+        args.append(param['bias'])
+    return x, args
 
 
 class _BatchNorm(Module):
@@ -175,6 +242,7 @@ class _BatchNorm(Module):
         self,
         in_size: Size,
         feature_axis: Axes = -1,
+        *,
         track_running_stats: bool = True,
         epsilon: float = 1e-5,
         momentum: float = 0.99,
@@ -183,14 +251,15 @@ class _BatchNorm(Module):
         scale_initializer: Union[ArrayLike, Callable] = init.Constant(1.),
         axis_name: Optional[Union[str, Sequence[str]]] = None,
         axis_index_groups: Optional[Sequence[Sequence[int]]] = None,
+        use_fast_variance: bool = True,
         name: Optional[str] = None,
         dtype: Any = None,
     ):
         super().__init__(name=name)
 
         # parameters
-        self.in_size =
-        self.out_size =
+        self.in_size = in_size
+        self.out_size = in_size
         self.affine = affine
         self.bias_initializer = bias_initializer
         self.scale_initializer = scale_initializer
@@ -198,15 +267,17 @@ class _BatchNorm(Module):
         self.track_running_stats = track_running_stats
         self.momentum = jnp.asarray(momentum, dtype=self.dtype)
         self.epsilon = jnp.asarray(epsilon, dtype=self.dtype)
+        self.use_fast_variance = use_fast_variance
 
         # parameters about axis
         feature_axis = (feature_axis,) if isinstance(feature_axis, int) else feature_axis
-        self.
+        self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axis)
         self.axis_name = axis_name
         self.axis_index_groups = axis_index_groups
 
         # variables
-        feature_shape = tuple([ax if i in self.
+        feature_shape = tuple([(ax if i in self.feature_axes else 1)
+                               for i, ax in enumerate(self.in_size)])
         if self.track_running_stats:
             self.running_mean = LongTermState(jnp.zeros(feature_shape, dtype=self.dtype))
             self.running_var = LongTermState(jnp.ones(feature_shape, dtype=self.dtype))
@@ -223,7 +294,7 @@ class _BatchNorm(Module):
         else:
             self.weight = None
 
-    def update(self, x):
+    def update(self, x, mask: Optional[jax.Array] = None):
         # input shape and batch mode or not
         if x.ndim == self.num_spatial_dims + 2:
             x_shape = x.shape[1:]
@@ -239,9 +310,9 @@ class _BatchNorm(Module):
 
         # reduce the feature axis
         if batch:
-            reduction_axes = tuple(i for i in range(x.ndim) if (i - 1) not in self.
+            reduction_axes = tuple(i for i in range(x.ndim) if (i - 1) not in self.feature_axes)
         else:
-            reduction_axes = tuple(i for i in range(x.ndim) if i not in self.
+            reduction_axes = tuple(i for i in range(x.ndim) if i not in self.feature_axes)
 
         # fitting phase
         fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
@@ -255,6 +326,8 @@ class _BatchNorm(Module):
                 dtype=self.dtype,
                 axis_name=self.axis_name,
                 axis_index_groups=self.axis_index_groups,
+                use_fast_variance=self.use_fast_variance,
+                mask=mask,
             )
             self.running_mean.value = self.momentum * self.running_mean.value + (1 - self.momentum) * mean
             self.running_var.value = self.momentum * self.running_var.value + (1 - self.momentum) * var
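With these two keyword arguments, `_BatchNorm.update` threads `use_fast_variance` and an optional `mask` straight into `_compute_stats`. A hedged usage sketch: the `bst.environ.context(fit=True)` call is an assumption based on the `environ.get('fit', ...)` lookup in the hunk above, and the shapes and all-ones mask are made up:

```python
import brainstate as bst
import jax.numpy as jnp

# Hypothetical 1-D batch norm over inputs of shape (batch, length, channels).
layer = bst.nn.BatchNorm1d((10, 4), use_fast_variance=True)

x = bst.random.normal(size=(32, 10, 4))
mask = jnp.ones((32, 10, 4), dtype=bool)   # include every position in the statistics

with bst.environ.context(fit=True):        # assumed way to flag the fitting phase
    y = layer.update(x, mask=mask)
print(y.shape)                             # (32, 10, 4)
```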
@@ -265,14 +338,22 @@ class _BatchNorm(Module):
            mean, var = None, None
 
         # normalize
-        return _normalize(
+        return _normalize(
+            x,
+            mean=mean,
+            var=var,
+            weights=self.weight,
+            reduction_axes=reduction_axes,
+            feature_axes=self.feature_axes,
+            dtype=self.dtype,
+            epsilon=self.epsilon
+        )
 
 
 class BatchNorm0d(_BatchNorm):
-    r"""
+    r"""0-D batch normalization [1]_.
 
-    The data should be of `(b,
-    `l` is the layer dimension, and `c` is the channel dimension.
+    The data should be of `(b, c)`, where `b` is the batch dimension, and `c` is the channel dimension.
 
     %s
     """
@@ -375,7 +456,10 @@ _bn_doc = r'''
     example, `[[0, 1], [2, 3]]` would independently batch-normalize over
     the examples on the first two and last two devices. See `jax.lax.psum`
     for more details.
-
+  use_fast_variance: If true, use a faster, but less numerically stable,
+    calculation for the variance.
+
+
 References
 ----------
 .. [1] Ioffe, Sergey and Christian Szegedy. “Batch Normalization: Accelerating Deep Network Training
@@ -386,3 +470,441 @@ _bn_doc = r'''
 BatchNorm1d.__doc__ = BatchNorm1d.__doc__ % _bn_doc
 BatchNorm2d.__doc__ = BatchNorm2d.__doc__ % _bn_doc
 BatchNorm3d.__doc__ = BatchNorm3d.__doc__ % _bn_doc
+
+
+class LayerNorm(Module):
+    """
+    Layer normalization (https://arxiv.org/abs/1607.06450).
+
+    LayerNorm normalizes the activations of the layer for each given example in a
+    batch independently, rather than across a batch like Batch Normalization.
+    i.e. applies a transformation that maintains the mean activation within
+    each example close to 0 and the activation standard deviation close to 1.
+
+    Example usage::
+
+      >>> import brainstate as bst
+      >>> x = bst.random.normal(size=(3, 4, 5, 6))
+      >>> layer = bst.nn.LayerNorm(x.shape)
+      >>> layer.states()
+      >>> y = layer(x)
+
+    Attributes:
+      in_size: The input shape, without batch size.
+      epsilon: A small float added to variance to avoid dividing by zero.
+      dtype: the dtype of the result (default: infer from input and params).
+      use_bias: If True, bias (beta) is added.
+      use_scale: If True, multiply by scale (gamma). When the next layer is linear
+        (also e.g. nnx.relu), this can be disabled since the scaling will be done
+        by the next layer.
+      bias_init: Initializer for bias, by default, zero.
+      scale_init: Initializer for scale, by default, one.
+      reduction_axes: Axes for computing normalization statistics. It is recommended
+        to use the negative integer, since when the batch dimension is used,
+        the reduction_axes may be wrong when using the positive integer.
+      feature_axes: Feature axes for learned bias and scaling.
+      axis_name: the axis name used to combine batch statistics from multiple
+        devices. See ``jax.pmap`` for a description of axis names (default: None).
+        This is only needed if the model is subdivided across devices, i.e. the
+        array being normalized is sharded across devices within a pmap.
+      axis_index_groups: groups of axis indices within that named axis
+        representing subsets of devices to reduce over (default: None). For
+        example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over
+        the examples on the first two and last two devices. See ``jax.lax.psum``
+        for more details.
+      use_fast_variance: If true, use a faster, but less numerically stable,
+        calculation for the variance.
+    """
+
+    def __init__(
+        self,
+        in_size: Size,
+        reduction_axes: Axes = -1,
+        feature_axes: Axes = -1,
+        *,
+        epsilon: float = 1e-6,
+        use_bias: bool = True,
+        use_scale: bool = True,
+        bias_init: Callable = init.ZeroInit(),
+        scale_init: Callable = init.Constant(1.0),
+        axis_name: Optional[str] = None,
+        axis_index_groups: Any = None,
+        use_fast_variance: bool = True,
+        dtype: Optional[jax.typing.DTypeLike] = None,
+    ):
+        super().__init__()
+
+        self.in_size = in_size
+        self.out_size = in_size
+
+        # parameters about axis
+        feature_axes = (feature_axes,) if isinstance(feature_axes, int) else feature_axes
+        self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axes)
+        self.reduction_axes = (reduction_axes,) if isinstance(reduction_axes, int) else reduction_axes
+        self.axis_name = axis_name
+        self.axis_index_groups = axis_index_groups
+
+        # variables
+        feature_shape = tuple([(ax if i in self.feature_axes else 1)
+                               for i, ax in enumerate(self.in_size)])
+
+        weights = dict()
+        if use_scale:
+            weights['scale'] = init.param(scale_init, feature_shape)
+        if use_bias:
+            weights['bias'] = init.param(bias_init, feature_shape)
+        if len(weights):
+            self.weight = ParamState(weights)
+        else:
+            self.weight = None
+
+        # parameters
+        self.epsilon = epsilon
+        self.dtype = dtype or environ.dftype()
+        self.use_bias = use_bias
+        self.use_scale = use_scale
+        self.bias_init = bias_init
+        self.scale_init = scale_init
+        self.use_fast_variance = use_fast_variance
+
+    def update(self, x, *, mask: Optional[jax.Array] = None):
+        """Applies layer normalization on the input.
+
+        Args:
+          x: the inputs
+
+        Returns:
+          Normalized inputs (the same shape as inputs).
+        """
+        mean, var = _compute_stats(
+            x,
+            self.reduction_axes,
+            dtype=self.dtype,
+            axis_name=self.axis_name,
+            axis_index_groups=self.axis_index_groups,
+            use_fast_variance=self.use_fast_variance,
+            mask=mask,
+        )
+
+        return _normalize(
+            x,
+            mean=mean,
+            var=var,
+            weights=self.weight,
+            reduction_axes=self.reduction_axes,
+            feature_axes=self.feature_axes,
+            dtype=self.dtype,
+            epsilon=self.epsilon,
+        )
+
+
+class RMSNorm(Module):
+    """
+    RMS Layer normalization (https://arxiv.org/abs/1910.07467).
+
+    RMSNorm normalizes the activations of the layer for each given example in a
+    batch independently, rather than across a batch like Batch Normalization.
+    Unlike LayerNorm which re-centers the mean to be 0 and normalizes by the
+    standard deviation of the activations, RMSNorm does not re-center at all
+    and instead normalizes by the root mean square of the activations.
+
+    Example usage::
+
+      >>> import brainstate as bst
+      >>> x = bst.random.normal(size=(5, 6))
+      >>> layer = bst.nn.RMSNorm(num_features=6)
+      >>> layer.states()
+      >>> y = layer(x)
+
+    Attributes:
+      in_size: The input shape, without batch size.
+      epsilon: A small float added to variance to avoid dividing by zero.
+      dtype: the dtype of the result (default: infer from input and params).
+      use_scale: If True, multiply by scale (gamma). When the next layer is linear
+        (also e.g. nn.relu), this can be disabled since the scaling will be done
+        by the next layer.
+      scale_init: Initializer for scale, by default, one.
+      reduction_axes: Axes for computing normalization statistics. It is recommended
+        to use the negative integer, since when the batch dimension is used,
+        the reduction_axes may be wrong when using the positive integer.
+      feature_axes: Feature axes for learned bias and scaling.
+      axis_name: the axis name used to combine batch statistics from multiple
+        devices. See ``jax.pmap`` for a description of axis names (default: None).
+        This is only needed if the model is subdivided across devices, i.e. the
+        array being normalized is sharded across devices within a pmap.
+      axis_index_groups: groups of axis indices within that named axis
+        representing subsets of devices to reduce over (default: None). For
+        example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over
+        the examples on the first two and last two devices. See ``jax.lax.psum``
+        for more details.
+      use_fast_variance: If true, use a faster, but less numerically stable,
+        calculation for the variance.
+    """
+
+    def __init__(
+        self,
+        in_size: Size,
+        *,
+        epsilon: float = 1e-6,
+        dtype: Optional[jax.typing.DTypeLike] = None,
+        use_scale: bool = True,
+        scale_init: Callable = init.Constant(1.0),
+        reduction_axes: Axes = -1,
+        feature_axes: Axes = -1,
+        axis_name: Optional[str] = None,
+        axis_index_groups: Any = None,
+        use_fast_variance: bool = True,
+    ):
+        super().__init__()
+
+        self.in_size = in_size
+        self.out_size = in_size
+
+        # parameters about axis
+        feature_axes = (feature_axes,) if isinstance(feature_axes, int) else feature_axes
+        self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axes)
+        self.reduction_axes = (reduction_axes, ) if isinstance(reduction_axes, int) else reduction_axes
+        self.axis_name = axis_name
+        self.axis_index_groups = axis_index_groups
+
+        # variables
+        feature_shape = tuple([(ax if i in self.feature_axes else 1)
+                               for i, ax in enumerate(self.in_size)])
+        if use_scale:
+            self.scale = ParamState({'scale': init.param(scale_init, feature_shape)})
+        else:
+            self.scale = None
+
+        # parameters
+        self.epsilon = epsilon
+        self.dtype = dtype or environ.dftype()
+        self.use_scale = use_scale
+        self.scale_init = scale_init
+        self.use_fast_variance = use_fast_variance
+
+    def update(self, x, *, mask: Optional[jax.Array] = None):
+        """Applies layer normalization on the input.
+
+        Args:
+          x: the inputs
+          mask: the mask
+
+        Returns:
+          Normalized inputs (the same shape as inputs).
+        """
+        mean, var = _compute_stats(
+            x,
+            self.reduction_axes,
+            dtype=self.dtype,
+            axis_name=self.axis_name,
+            axis_index_groups=self.axis_index_groups,
+            use_mean=False,
+            use_fast_variance=self.use_fast_variance,
+            mask=mask,
+        )
+
+        return _normalize(
+            x,
+            mean=mean,
+            var=var,
+            weights=self.scale,
+            reduction_axes=self.reduction_axes,
+            feature_axes=self.feature_axes,
+            dtype=self.dtype,
+            epsilon=self.epsilon,
+        )
+
+
+class GroupNorm(Module):
+    """
+    Group normalization (arxiv.org/abs/1803.08494).
+
+    This op is similar to batch normalization, but statistics are shared across
+    equally-sized groups of channels and not shared across batch dimension.
+    Thus, group normalization does not depend on the batch composition and does
+    not require maintaining internal state for storing statistics.
+    The user should either specify the total number of channel groups or the
+    number of channels per group.
+
+    .. note::
+      LayerNorm is a special case of GroupNorm where ``num_groups=1``.
+
+    Example usage::
+
+      >>> import numpy as np
+      >>> import brainstate as bst
+      ...
+      >>> x = bst.random.normal(size=(3, 4, 5, 6))
+      >>> layer = bst.nn.GroupNorm(x.shape, num_groups=3)
+      >>> layer.states()
+      >>> y = layer(x)
+      >>> y = bst.nn.GroupNorm(x.shape, num_groups=1)(x)
+      >>> y2 = bst.nn.LayerNorm(x.shape, reduction_axes=(1, 2, 3))(x)
+      >>> np.testing.assert_allclose(y, y2)
+
+    Attributes:
+      in_size: The input shape, without batch size.
+      num_groups: the total number of channel groups. The default value of 32 is
+        proposed by the original group normalization paper.
+      group_size: the number of channels in a group.
+      epsilon: A small float added to variance to avoid dividing by zero.
+      dtype: the dtype of the result (default: infer from input and params).
+      use_bias: If True, bias (beta) is added.
+      use_scale: If True, multiply by scale (gamma). When the next layer is linear
+        (also e.g. nn.relu), this can be disabled since the scaling will be done
+        by the next layer.
+      bias_init: Initializer for bias, by default, zero.
+      scale_init: Initializer for scale, by default, one.
+      reduction_axes: List of axes used for computing normalization statistics.
+        This list must include the final dimension, which is assumed to be the
+        feature axis. Furthermore, if the input used at call time has additional
+        leading axes compared to the data used for initialisation, for example due
+        to batching, then the reduction axes need to be defined explicitly.
+        It is recommended to use the negative integer, since when the batch dimension is used,
+        the reduction_axes may be wrong when using the positive integer.
+      axis_name: the axis name used to combine batch statistics from multiple
+        devices. See ``jax.pmap`` for a description of axis names (default: None).
+        This is only needed if the model is subdivided across devices, i.e. the
+        array being normalized is sharded across devices within a pmap or shard
+        map. For SPMD jit, you do not need to manually synchronize. Just make sure
+        that the axes are correctly annotated and XLA:SPMD will insert the
+        necessary collectives.
+      axis_index_groups: groups of axis indices within that named axis
+        representing subsets of devices to reduce over (default: None). For
+        example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over the
+        examples on the first two and last two devices. See ``jax.lax.psum`` for
+        more details.
+      use_fast_variance: If true, use a faster, but less numerically stable,
+        calculation for the variance.
+    """
+
+    def __init__(
+        self,
+        in_size: Size,
+        feature_axis: Axes = -1,
+        num_groups: Optional[int] = 32,
+        group_size: Optional[int] = None,
+        *,
+        epsilon: float = 1e-6,
+        dtype: Optional[jax.typing.DTypeLike] = None,
+        use_bias: bool = True,
+        use_scale: bool = True,
+        bias_init: Callable = init.ZeroInit(),
+        scale_init: Callable = init.Constant(1.),
+        reduction_axes: Optional[Axes] = None,
+        axis_name: Optional[str] = None,
+        axis_index_groups: Any = None,
+        use_fast_variance: bool = True,
+    ):
+        super().__init__()
+
+        self.in_size = in_size
+        self.out_size = in_size
+
+        # parameters about axis
+        feature_axis = (feature_axis,) if isinstance(feature_axis, int) else feature_axis
+        self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axis)
+        self.reduction_axes = (reduction_axes,) if isinstance(reduction_axes, int) else reduction_axes
+        self.axis_name = axis_name
+        self.axis_index_groups = axis_index_groups
+
+        if (num_groups is None and group_size is None) or (
+            num_groups is not None and group_size is not None
+        ):
+            raise ValueError(
+                'Either `num_groups` or `group_size` should be '
+                'specified. If `group_size` is to be specified, '
+                'pass `num_groups=None` as argument to override '
+                'the default `num_groups` value of 32.'
+            )
+
+        feature_shape = tuple([(ax if i in self.feature_axes else 1)
+                               for i, ax in enumerate(self.in_size)])
+        assert len(feature_shape) == 1, 'GroupNorm only supports 1D feature axis.'
+        num_features = feature_shape[0]
+        if group_size is not None:
+            if num_features % group_size != 0:
+                raise ValueError(
+                    'Number of features ({}) is not multiple of the '
+                    'group size ({}).'.format(num_features, group_size)
+                )
+            self.num_groups = num_features // group_size
+            self.group_size = group_size
+        else:
+            if not isinstance(num_groups, int) or num_groups <= 0 or (
+                num_features % num_groups != 0
+            ):
+                raise ValueError(
+                    'Number of groups ({}) does not divide the number'
+                    ' of channels ({}).'.format(num_groups, num_features)
+                )
+            self.num_groups = num_groups
+            self.group_size = num_features // num_groups
+
+        # variables
+        weights = dict()
+        if use_scale:
+            weights['scale'] = init.param(scale_init, feature_shape)
+        if use_bias:
+            weights['bias'] = init.param(bias_init, feature_shape)
+        if len(weights):
+            self.weight = ParamState(weights)
+        else:
+            self.weight = None
+
+        # parameters
+        self.epsilon = epsilon
+        self.dtype = dtype
+        self.use_bias = use_bias
+        self.use_scale = use_scale
+        self.bias_init = bias_init
+        self.scale_init = scale_init
+        self.use_fast_variance = use_fast_variance
+
+    def update(self, x, *, mask: Optional[jax.Array] = None):
+        """Applies group normalization to the input (arxiv.org/abs/1803.08494).
+
+        Args:
+          x: the input of shape ``...self.num_features`` where ``self.num_features``
+            is a channels dimension and ``...`` represents an arbitrary number of
+            extra dimensions that can be used to accumulate statistics over. If no
+            reduction axes have been specified then all additional dimensions ``...``
+            will be used to accumulate statistics apart from the leading dimension
+            which is assumed to represent the batch.
+          mask: Binary array of shape broadcastable to ``inputs`` tensor, indicating
+            the positions for which the mean and variance should be computed.
+
+        Returns:
+          Normalized inputs (the same shape as inputs).
+        """
+        if self.reduction_axes is not None:
+            reduction_axes = self.reduction_axes
+        else:
+            reduction_axes = list(range(1, x.ndim - 1)) + [-1]
+        reduction_axes = _canonicalize_axes(x.ndim, reduction_axes)
+
+        group_shape = x.shape[:-1] + (self.num_groups, self.group_size)
+        if mask is not None:
+            mask = mask.reshape(mask.shape[:-1] + (self.num_groups, self.group_size))
+
+        mean, var = _compute_stats(
+            x.reshape(group_shape),
+            list(reduction_axes[:-1]) + [-1],
+            dtype=self.dtype,
+            axis_name=self.axis_name,
+            axis_index_groups=self.axis_index_groups,
+            use_fast_variance=self.use_fast_variance,
+            mask=mask,
+        )
+        mean = jnp.repeat(mean, self.group_size, axis=1)
+        var = jnp.repeat(var, self.group_size, axis=1)
+        return _normalize(
+            x,
+            mean=mean,
+            var=var,
+            weights=self.weight,
+            reduction_axes=reduction_axes[:-1],
+            feature_axes=self.feature_axes,
+            dtype=self.dtype,
+            epsilon=self.epsilon,
+        )
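Taken together, the new classes follow the usage shown in their own docstrings. A combined sketch, assuming `Module.__call__` dispatches to `update` as the doctest examples above do; the shapes are illustrative:

```python
import brainstate as bst

x = bst.random.normal(size=(3, 4, 5, 6))

ln = bst.nn.LayerNorm(x.shape)        # normalizes over the last axis by default
rms = bst.nn.RMSNorm(x.shape)         # no re-centering, RMS scaling only
# GroupNorm asserts a 1-D feature axis, so only the channel dimension is passed here.
gn = bst.nn.GroupNorm((6,), num_groups=3)

y1, y2, y3 = ln(x), rms(x), gn(x)
print(y1.shape, y2.shape, y3.shape)   # each (3, 4, 5, 6)
```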