brainstate 0.1.0.post20241125__py2.py3-none-any.whl → 0.1.0.post20241209__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_state.py +15 -1
- brainstate/augment/_mapping_test.py +84 -0
- brainstate/compile/_loop_collect_return.py +5 -1
- brainstate/compile/_make_jaxpr.py +30 -25
- brainstate/compile/_progress_bar.py +30 -12
- brainstate/functional/_activations.py +4 -12
- brainstate/graph/_graph_operation.py +4 -1
- brainstate/nn/_collective_ops.py +18 -2
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +1 -1
- brainstate/nn/_elementwise/_dropout.py +31 -22
- brainstate/nn/_interaction/_normalizations.py +598 -66
- brainstate/util/_tracers.py +0 -7
- {brainstate-0.1.0.post20241125.dist-info → brainstate-0.1.0.post20241209.dist-info}/METADATA +1 -1
- {brainstate-0.1.0.post20241125.dist-info → brainstate-0.1.0.post20241209.dist-info}/RECORD +17 -19
- {brainstate-0.1.0.post20241125.dist-info → brainstate-0.1.0.post20241209.dist-info}/top_level.txt +0 -1
- benchmark/COBA_2005.py +0 -125
- benchmark/CUBA_2005.py +0 -149
- {brainstate-0.1.0.post20241125.dist-info → brainstate-0.1.0.post20241209.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20241125.dist-info → brainstate-0.1.0.post20241209.dist-info}/WHEEL +0 -0
brainstate/nn/_interaction/_normalizations.py

@@ -17,22 +17,61 @@

 from __future__ import annotations

-import
-from typing import Callable, Union, Sequence, Optional, Any
+from typing import Callable, Union, Sequence, Optional, Any, Dict

 import jax
 import jax.numpy as jnp

 from brainstate import environ, init
-from brainstate._state import
+from brainstate._state import ParamState, BatchState
 from brainstate.nn._module import Module
 from brainstate.typing import DTypeLike, ArrayLike, Size, Axes

 __all__ = [
-    'BatchNorm0d',
+    'BatchNorm0d',
+    'BatchNorm1d',
+    'BatchNorm2d',
+    'BatchNorm3d',
+    'LayerNorm',
+    'RMSNorm',
+    'GroupNorm',
 ]


+def canonicalize_dtype(
+    *args,
+    dtype: jax.typing.DTypeLike | None = None,
+    inexact: bool = True
+) -> jax.typing.DTypeLike:
+    """Canonicalize an optional dtype to the definitive dtype.
+
+    If the ``dtype`` is None this function will infer the dtype. If it is not
+    None it will be returned unmodified or an exceptions is raised if the dtype
+    is invalid.
+    from the input arguments using ``jnp.result_type``.
+
+    Args:
+      *args: JAX array compatible values. None values
+        are ignored.
+      dtype: Optional dtype override. If specified the arguments are cast to
+        the specified dtype instead and dtype inference is disabled.
+      inexact: When True, the output dtype must be a subdtype
+        of `jnp.inexact`. Inexact dtypes are real or complex floating points. This
+        is useful when you want to apply operations that don't work directly on
+        integers like taking a mean for example.
+    Returns:
+      The dtype that *args should be cast to.
+    """
+    if dtype is None:
+        args_filtered = [jnp.asarray(x) for x in args if x is not None]
+        dtype = jnp.result_type(*args_filtered)
+        if inexact and not jnp.issubdtype(dtype, jnp.inexact):
+            dtype = jnp.promote_types(jnp.float32, dtype)
+    if inexact and not jnp.issubdtype(dtype, jnp.inexact):
+        raise ValueError(f'Dtype must be inexact: {dtype}')
+    return dtype
+
+
 def _canonicalize_axes(ndim: int, feature_axes: Sequence[int]):
     axes = []
     for axis in feature_axes:
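
The new `canonicalize_dtype` helper is the one piece of this hunk with non-obvious behavior: integer inputs are promoted to an inexact dtype so that later means and variances are well defined, while floating inputs keep their precision. A minimal standalone sketch of that promotion logic (the name `infer_inexact_dtype` below is illustrative, not part of the package):

import jax.numpy as jnp

def infer_inexact_dtype(*args, dtype=None, inexact=True):
    # Mirror the diff: infer from the arguments, then force an inexact result.
    if dtype is None:
        dtype = jnp.result_type(*[jnp.asarray(x) for x in args if x is not None])
        if inexact and not jnp.issubdtype(dtype, jnp.inexact):
            dtype = jnp.promote_types(jnp.float32, dtype)
    if inexact and not jnp.issubdtype(dtype, jnp.inexact):
        raise ValueError(f'Dtype must be inexact: {dtype}')
    return dtype

print(infer_inexact_dtype(jnp.arange(3)))             # int32 input -> float32
print(infer_inexact_dtype(jnp.ones(3, jnp.float16)))  # float16 stays float16
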
@@ -52,6 +91,18 @@ def _abs_sq(x):
         return jax.lax.square(x)


+class NormalizationParamState(ParamState):
+    # This is a dummy class to be used as a compatibility
+    # usage of `ETraceParam` for the layers in "brainetrace"
+    def execute(self, x):
+        param = self.value
+        if 'scale' in param:
+            x = x * param['scale']
+        if 'bias' in param:
+            x = x + param['bias']
+        return x
+
+
 def _compute_stats(
     x: ArrayLike,
     axes: Sequence[int],
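
`NormalizationParamState.execute` takes over the job of the module-level `_scale_operation` helper that is removed further down in this diff, moving the scale/bias application onto the parameter state itself. A small self-contained sketch of the same transformation, using a plain dict in place of the state object:

import jax.numpy as jnp

def apply_scale_bias(x, param):
    # Same logic as execute(): optional elementwise scale, then optional bias.
    if 'scale' in param:
        x = x * param['scale']
    if 'bias' in param:
        x = x + param['bias']
    return x

y = apply_scale_bias(jnp.ones((2, 3)),
                     {'scale': jnp.full((1, 3), 2.0), 'bias': jnp.zeros((1, 3))})
print(y.shape)  # (2, 3)
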
@@ -59,28 +110,38 @@ def _compute_stats(
     axis_name: Optional[str] = None,
     axis_index_groups: Optional[Sequence[int]] = None,
     use_mean: bool = True,
+    use_fast_variance: bool = True,
+    mask: Optional[jax.Array] = None,
 ):
-    """
+    """
+    Computes mean and variance statistics.

     This implementation takes care of a few important details:
     - Computes in float32 precision for stability in half precision training.
-    - mean and variance are
-
+    - If ``use_fast_variance`` is ``True``, mean and variance are computed using
+      Var = E[|x|^2] - |E[x]|^2, instead of Var = E[|x - E[x]|^2]), in a single XLA fusion.
     - Clips negative variances to zero which can happen due to
       roundoff errors. This avoids downstream NaNs.
     - Supports averaging across a parallel axis and subgroups of a parallel axis
       with a single `lax.pmean` call to avoid latency.

     Arguments:
-
-
-
-
-
-
-
-
-
+      x: Input array.
+      axes: The axes in ``x`` to compute mean and variance statistics for.
+      dtype: tp.Optional dtype specifying the minimal precision. Statistics
+        are always at least float32 for stability (default: dtype of x).
+      axis_name: Optional name for the pmapped axis to compute mean over. Note,
+        this is only used for pmap and shard map. For SPMD jit, you do not need to
+        manually synchronize. Just make sure that the axes are correctly annotated
+        and XLA:SPMD will insert the necessary collectives.
+      axis_index_groups: Optional axis indices.
+      use_mean: If true, calculate the mean from the input and use it when
+        computing the variance. If false, set the mean to zero and compute
+        the variance without subtracting the mean.
+      use_fast_variance: If true, use a faster, but less numerically stable,
+        calculation for the variance.
+      mask: Binary array of shape broadcastable to ``inputs`` tensor, indicating
+        the positions for which the mean and variance should be computed.

     Returns:
       A pair ``(mean, val)``.
@@ -89,42 +150,54 @@ def _compute_stats(
         dtype = jax.numpy.result_type(x)
     # promote x to at least float32, this avoids half precision computation
     # but preserves double or complex floating points
-    dtype = jax.numpy.promote_types(dtype,
+    dtype = jax.numpy.promote_types(dtype, jnp.float32)
     x = jnp.asarray(x, dtype)
+    axes = _canonicalize_axes(x.ndim, axes)
+
+    def maybe_distributed_mean(*xs, mask=None):
+        mus = tuple(x.mean(axes, where=mask) for x in xs)
+        if axis_name is None:
+            return mus if len(xs) > 1 else mus[0]
+        else:
+            # In the distributed case we stack multiple arrays to speed comms.
+            if len(xs) > 1:
+                reduced_mus = jax.lax.pmean(
+                    jnp.stack(mus, axis=0),
+                    axis_name,
+                    axis_index_groups=axis_index_groups,
+                )
+                return tuple(reduced_mus[i] for i in range(len(xs)))
+            else:
+                return jax.lax.pmean(
+                    mus[0],
+                    axis_name,
+                    axis_index_groups=axis_index_groups
+                )

-    # Compute mean and mean of squared values.
-    mean2 = jnp.mean(_abs_sq(x), axes)
     if use_mean:
-
+        if use_fast_variance:
+            mu, mu2 = maybe_distributed_mean(x, _abs_sq(x), mask=mask)
+            # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due
+            # to floating point round-off errors.
+            var = jnp.maximum(0.0, mu2 - _abs_sq(mu))
+        else:
+            mu = maybe_distributed_mean(x, mask=mask)
+            var = maybe_distributed_mean(_abs_sq(x - jnp.expand_dims(mu, axes)), mask=mask)
     else:
-
-
-
-    if axis_name is not None:
-        concatenated_mean = jnp.concatenate([mean, mean2])
-        mean, mean2 = jnp.split(
-            jax.lax.pmean(
-                concatenated_mean,
-                axis_name=axis_name,
-                axis_index_groups=axis_index_groups,
-            ),
-            2,
-        )
-
-    # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due
-    # to floating point round-off errors.
-    var = jnp.maximum(0.0, mean2 - _abs_sq(mean))
-    return mean, var
+        var = maybe_distributed_mean(_abs_sq(x), mask=mask)
+        mu = jnp.zeros_like(var)
+    return mu, var


 def _normalize(
     x: ArrayLike,
     mean: Optional[ArrayLike],
     var: Optional[ArrayLike],
-    weights: Optional[
-    reduction_axes:
+    weights: Optional[NormalizationParamState],
+    reduction_axes: Axes,
+    feature_axes: Axes,
     dtype: DTypeLike,
-    epsilon:
+    epsilon: jax.typing.ArrayLike,
 ):
     """Normalizes the input of a normalization layer and optionally applies a learned scale and bias.

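
The rewritten `_compute_stats` routes all reductions through `maybe_distributed_mean` and selects between the one-pass formula Var = E[|x|^2] - |E[x]|^2 (`use_fast_variance=True`) and the two-pass Var = E[|x - E[x]|^2]. A short sketch, with illustrative data, of why the slower path is kept as an option:

import jax.numpy as jnp

# A large constant offset stresses the one-pass formula in float32.
x = jnp.linspace(0.0, 1.0, 16, dtype=jnp.float32) + 1000.0
mu = x.mean()
var_fast = jnp.maximum(0.0, (x * x).mean() - mu * mu)  # clipped at zero, as in the diff
var_stable = ((x - mu) ** 2).mean()
print(var_fast, var_stable)  # the fast path can lose precision for large offsets
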
@@ -134,6 +207,7 @@ def _normalize(
       var: Variance to use for normalization.
       weights: The scale and bias parameters.
       reduction_axes: The axes in ``x`` to reduce.
+      feature_axes: The feature axes to apply the scale and bias.
       dtype: The dtype of the result (default: infer from input and params).
       epsilon: Normalization epsilon.

@@ -142,16 +216,22 @@ def _normalize(
     """
     if mean is not None:
         assert var is not None, 'mean and val must be both None or not None.'
+        reduction_axes = _canonicalize_axes(x.ndim, reduction_axes)
+        feature_axes = _canonicalize_axes(x.ndim, feature_axes)
         stats_shape = list(x.shape)
         for axis in reduction_axes:
             stats_shape[axis] = 1
         mean = mean.reshape(stats_shape)
         var = var.reshape(stats_shape)
+        feature_shape = [1] * x.ndim
+        for ax in feature_axes:
+            feature_shape[ax] = x.shape[ax]
         y = x - mean
-        mul = jax.lax.rsqrt(var +
+        mul = jax.lax.rsqrt(var + epsilon)
         y = y * mul
         if weights is not None:
-            y =
+            y = weights.execute(y)
+            dtype = canonicalize_dtype(x, *jax.tree.leaves(weights.value), dtype=dtype)
     else:
         assert var is None, 'mean and val must be both None or not None.'
         assert weights is None, 'scale and bias are not supported without mean and val'
@@ -159,14 +239,6 @@ def _normalize(
     return jnp.asarray(y, dtype)


-def _scale_operation(x, param):
-    if 'scale' in param:
-        x = x * param['scale']
-    if 'bias' in param:
-        x = x + param['bias']
-    return x
-
-
 class _BatchNorm(Module):
     __module__ = 'brainstate.nn'
     num_spatial_dims: int
@@ -175,6 +247,7 @@ class _BatchNorm(Module):
         self,
         in_size: Size,
         feature_axis: Axes = -1,
+        *,
         track_running_stats: bool = True,
         epsilon: float = 1e-5,
         momentum: float = 0.99,
@@ -183,14 +256,17 @@ class _BatchNorm(Module):
         scale_initializer: Union[ArrayLike, Callable] = init.Constant(1.),
         axis_name: Optional[Union[str, Sequence[str]]] = None,
         axis_index_groups: Optional[Sequence[Sequence[int]]] = None,
+        use_fast_variance: bool = True,
         name: Optional[str] = None,
         dtype: Any = None,
+        param_type: type = NormalizationParamState,
+        mean_type: type = BatchState,
     ):
         super().__init__(name=name)

         # parameters
-        self.in_size =
-        self.out_size =
+        self.in_size = in_size
+        self.out_size = in_size
         self.affine = affine
         self.bias_initializer = bias_initializer
         self.scale_initializer = scale_initializer
@@ -198,18 +274,20 @@ class _BatchNorm(Module):
         self.track_running_stats = track_running_stats
         self.momentum = jnp.asarray(momentum, dtype=self.dtype)
         self.epsilon = jnp.asarray(epsilon, dtype=self.dtype)
+        self.use_fast_variance = use_fast_variance

         # parameters about axis
         feature_axis = (feature_axis,) if isinstance(feature_axis, int) else feature_axis
-        self.
+        self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axis)
         self.axis_name = axis_name
         self.axis_index_groups = axis_index_groups

         # variables
-        feature_shape = tuple([ax if i in self.
+        feature_shape = tuple([(ax if i in self.feature_axes else 1)
+                               for i, ax in enumerate(self.in_size)])
         if self.track_running_stats:
-            self.running_mean =
-            self.running_var =
+            self.running_mean = mean_type(jnp.zeros(feature_shape, dtype=self.dtype))
+            self.running_var = mean_type(jnp.ones(feature_shape, dtype=self.dtype))
         else:
             self.running_mean = None
             self.running_var = None
@@ -219,11 +297,11 @@ class _BatchNorm(Module):
             assert track_running_stats, "Affine parameters are not needed when track_running_stats is False."
             bias = init.param(self.bias_initializer, feature_shape)
             scale = init.param(self.scale_initializer, feature_shape)
-            self.weight =
+            self.weight = param_type(dict(bias=bias, scale=scale))
         else:
             self.weight = None

-    def update(self, x):
+    def update(self, x, mask: Optional[jax.Array] = None):
         # input shape and batch mode or not
         if x.ndim == self.num_spatial_dims + 2:
             x_shape = x.shape[1:]
@@ -239,9 +317,9 @@ class _BatchNorm(Module):

         # reduce the feature axis
         if batch:
-            reduction_axes = tuple(i for i in range(x.ndim) if (i - 1) not in self.
+            reduction_axes = tuple(i for i in range(x.ndim) if (i - 1) not in self.feature_axes)
         else:
-            reduction_axes = tuple(i for i in range(x.ndim) if i not in self.
+            reduction_axes = tuple(i for i in range(x.ndim) if i not in self.feature_axes)

         # fitting phase
         fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
@@ -255,6 +333,8 @@ class _BatchNorm(Module):
                 dtype=self.dtype,
                 axis_name=self.axis_name,
                 axis_index_groups=self.axis_index_groups,
+                use_fast_variance=self.use_fast_variance,
+                mask=mask,
             )
             self.running_mean.value = self.momentum * self.running_mean.value + (1 - self.momentum) * mean
             self.running_var.value = self.momentum * self.running_var.value + (1 - self.momentum) * var
@@ -265,14 +345,22 @@ class _BatchNorm(Module):
             mean, var = None, None

         # normalize
-        return _normalize(
+        return _normalize(
+            x,
+            mean=mean,
+            var=var,
+            weights=self.weight,
+            reduction_axes=reduction_axes,
+            feature_axes=self.feature_axes,
+            dtype=self.dtype,
+            epsilon=self.epsilon
+        )


 class BatchNorm0d(_BatchNorm):
-    r"""
+    r"""0-D batch normalization [1]_.

-    The data should be of `(b,
-    `l` is the layer dimension, and `c` is the channel dimension.
+    The data should be of `(b, c)`, where `b` is the batch dimension, and `c` is the channel dimension.

     %s
     """
@@ -375,7 +463,10 @@ _bn_doc = r'''
     example, `[[0, 1], [2, 3]]` would independently batch-normalize over
     the examples on the first two and last two devices. See `jax.lax.psum`
     for more details.
-
+  use_fast_variance: If true, use a faster, but less numerically stable,
+    calculation for the variance.
+
+
   References
   ----------
   .. [1] Ioffe, Sergey and Christian Szegedy. “Batch Normalization: Accelerating Deep Network Training
@@ -386,3 +477,444 @@ _bn_doc = r'''
 BatchNorm1d.__doc__ = BatchNorm1d.__doc__ % _bn_doc
 BatchNorm2d.__doc__ = BatchNorm2d.__doc__ % _bn_doc
 BatchNorm3d.__doc__ = BatchNorm3d.__doc__ % _bn_doc
+
+
+class LayerNorm(Module):
+    """
+    Layer normalization (https://arxiv.org/abs/1607.06450).
+
+    LayerNorm normalizes the activations of the layer for each given example in a
+    batch independently, rather than across a batch like Batch Normalization.
+    i.e. applies a transformation that maintains the mean activation within
+    each example close to 0 and the activation standard deviation close to 1.
+
+    Example usage::
+
+      >>> import brainstate as bst
+      >>> x = bst.random.normal(size=(3, 4, 5, 6))
+      >>> layer = bst.nn.LayerNorm(x.shape)
+      >>> layer.states()
+      >>> y = layer(x)
+
+    Attributes:
+      in_size: The input shape, without batch size.
+      epsilon: A small float added to variance to avoid dividing by zero.
+      dtype: the dtype of the result (default: infer from input and params).
+      use_bias: If True, bias (beta) is added.
+      use_scale: If True, multiply by scale (gamma). When the next layer is linear
+        (also e.g. nnx.relu), this can be disabled since the scaling will be done
+        by the next layer.
+      bias_init: Initializer for bias, by default, zero.
+      scale_init: Initializer for scale, by default, one.
+      reduction_axes: Axes for computing normalization statistics. It is recommended
+        to use the negative integer, since when the batch dimension is used,
+        the reduction_axes may be wrong when using the positive integer.
+      feature_axes: Feature axes for learned bias and scaling.
+      axis_name: the axis name used to combine batch statistics from multiple
+        devices. See ``jax.pmap`` for a description of axis names (default: None).
+        This is only needed if the model is subdivided across devices, i.e. the
+        array being normalized is sharded across devices within a pmap.
+      axis_index_groups: groups of axis indices within that named axis
+        representing subsets of devices to reduce over (default: None). For
+        example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over
+        the examples on the first two and last two devices. See ``jax.lax.psum``
+        for more details.
+      use_fast_variance: If true, use a faster, but less numerically stable,
+        calculation for the variance.
+    """
+
+    def __init__(
+        self,
+        in_size: Size,
+        reduction_axes: Axes = -1,
+        feature_axes: Axes = -1,
+        *,
+        epsilon: float = 1e-6,
+        use_bias: bool = True,
+        use_scale: bool = True,
+        bias_init: Callable = init.ZeroInit(),
+        scale_init: Callable = init.Constant(1.0),
+        axis_name: Optional[str] = None,
+        axis_index_groups: Any = None,
+        use_fast_variance: bool = True,
+        dtype: Optional[jax.typing.DTypeLike] = None,
+        param_type: type = NormalizationParamState,
+    ):
+        super().__init__()
+
+        self.in_size = in_size
+        self.out_size = in_size
+
+        # parameters about axis
+        feature_axes = (feature_axes,) if isinstance(feature_axes, int) else feature_axes
+        self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axes)
+        self.reduction_axes = (reduction_axes,) if isinstance(reduction_axes, int) else reduction_axes
+        self.axis_name = axis_name
+        self.axis_index_groups = axis_index_groups
+
+        # variables
+        feature_shape = tuple([(ax if i in self.feature_axes else 1)
+                               for i, ax in enumerate(self.in_size)])
+
+        weights = dict()
+        if use_scale:
+            weights['scale'] = init.param(scale_init, feature_shape)
+        if use_bias:
+            weights['bias'] = init.param(bias_init, feature_shape)
+        if len(weights):
+            self.weight = param_type(weights)
+        else:
+            self.weight = None
+
+        # parameters
+        self.epsilon = epsilon
+        self.dtype = dtype or environ.dftype()
+        self.use_bias = use_bias
+        self.use_scale = use_scale
+        self.bias_init = bias_init
+        self.scale_init = scale_init
+        self.use_fast_variance = use_fast_variance
+
+    def update(self, x, *, mask: Optional[jax.Array] = None):
+        """Applies layer normalization on the input.
+
+        Args:
+          x: the inputs
+
+        Returns:
+          Normalized inputs (the same shape as inputs).
+        """
+        mean, var = _compute_stats(
+            x,
+            self.reduction_axes,
+            dtype=self.dtype,
+            axis_name=self.axis_name,
+            axis_index_groups=self.axis_index_groups,
+            use_fast_variance=self.use_fast_variance,
+            mask=mask,
+        )
+
+        return _normalize(
+            x,
+            mean=mean,
+            var=var,
+            weights=self.weight,
+            reduction_axes=self.reduction_axes,
+            feature_axes=self.feature_axes,
+            dtype=self.dtype,
+            epsilon=self.epsilon,
+        )
+
+
+class RMSNorm(Module):
+    """
+    RMS Layer normalization (https://arxiv.org/abs/1910.07467).
+
+    RMSNorm normalizes the activations of the layer for each given example in a
+    batch independently, rather than across a batch like Batch Normalization.
+    Unlike LayerNorm which re-centers the mean to be 0 and normalizes by the
+    standard deviation of the activations, RMSNorm does not re-center at all
+    and instead normalizes by the root mean square of the activations.
+
+    Example usage::
+
+      >>> import brainstate as bst
+      >>> x = bst.random.normal(size=(5, 6))
+      >>> layer = bst.nn.RMSNorm(num_features=6)
+      >>> layer.states()
+      >>> y = layer(x)
+
+    Attributes:
+      in_size: The input shape, without batch size.
+      epsilon: A small float added to variance to avoid dividing by zero.
+      dtype: the dtype of the result (default: infer from input and params).
+      use_scale: If True, multiply by scale (gamma). When the next layer is linear
+        (also e.g. nn.relu), this can be disabled since the scaling will be done
+        by the next layer.
+      scale_init: Initializer for scale, by default, one.
+      reduction_axes: Axes for computing normalization statistics. It is recommended
+        to use the negative integer, since when the batch dimension is used,
+        the reduction_axes may be wrong when using the positive integer.
+      feature_axes: Feature axes for learned bias and scaling.
+      axis_name: the axis name used to combine batch statistics from multiple
+        devices. See ``jax.pmap`` for a description of axis names (default: None).
+        This is only needed if the model is subdivided across devices, i.e. the
+        array being normalized is sharded across devices within a pmap.
+      axis_index_groups: groups of axis indices within that named axis
+        representing subsets of devices to reduce over (default: None). For
+        example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over
+        the examples on the first two and last two devices. See ``jax.lax.psum``
+        for more details.
+      use_fast_variance: If true, use a faster, but less numerically stable,
+        calculation for the variance.
+    """
+
+    def __init__(
+        self,
+        in_size: Size,
+        *,
+        epsilon: float = 1e-6,
+        dtype: Optional[jax.typing.DTypeLike] = None,
+        use_scale: bool = True,
+        scale_init: Callable = init.Constant(1.0),
+        reduction_axes: Axes = -1,
+        feature_axes: Axes = -1,
+        axis_name: Optional[str] = None,
+        axis_index_groups: Any = None,
+        use_fast_variance: bool = True,
+        param_type: type = NormalizationParamState,
+    ):
+        super().__init__()
+
+        self.in_size = in_size
+        self.out_size = in_size
+
+        # parameters about axis
+        feature_axes = (feature_axes,) if isinstance(feature_axes, int) else feature_axes
+        self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axes)
+        self.reduction_axes = (reduction_axes,) if isinstance(reduction_axes, int) else reduction_axes
+        self.axis_name = axis_name
+        self.axis_index_groups = axis_index_groups
+
+        # variables
+        feature_shape = tuple([(ax if i in self.feature_axes else 1)
+                               for i, ax in enumerate(self.in_size)])
+        if use_scale:
+            self.scale = param_type({'scale': init.param(scale_init, feature_shape)})
+        else:
+            self.scale = None
+
+        # parameters
+        self.epsilon = epsilon
+        self.dtype = dtype or environ.dftype()
+        self.use_scale = use_scale
+        self.scale_init = scale_init
+        self.use_fast_variance = use_fast_variance
+
+    def update(self, x, *, mask: Optional[jax.Array] = None):
+        """Applies layer normalization on the input.
+
+        Args:
+          x: the inputs
+          mask: the mask
+
+        Returns:
+          Normalized inputs (the same shape as inputs).
+        """
+        mean, var = _compute_stats(
+            x,
+            self.reduction_axes,
+            dtype=self.dtype,
+            axis_name=self.axis_name,
+            axis_index_groups=self.axis_index_groups,
+            use_mean=False,
+            use_fast_variance=self.use_fast_variance,
+            mask=mask,
+        )
+
+        return _normalize(
+            x,
+            mean=mean,
+            var=var,
+            weights=self.scale,
+            reduction_axes=self.reduction_axes,
+            feature_axes=self.feature_axes,
+            dtype=self.dtype,
+            epsilon=self.epsilon,
+        )
+
+
+class GroupNorm(Module):
+    """
+    Group normalization (arxiv.org/abs/1803.08494).
+
+    This op is similar to batch normalization, but statistics are shared across
+    equally-sized groups of channels and not shared across batch dimension.
+    Thus, group normalization does not depend on the batch composition and does
+    not require maintaining internal state for storing statistics.
+    The user should either specify the total number of channel groups or the
+    number of channels per group.
+
+    .. note::
+      LayerNorm is a special case of GroupNorm where ``num_groups=1``.
+
+    Example usage::
+
+      >>> import numpy as np
+      >>> import brainstate as bst
+      ...
+      >>> x = bst.random.normal(size=(3, 4, 5, 6))
+      >>> layer = bst.nn.GroupNorm(x.shape, num_groups=3)
+      >>> layer.states()
+      >>> y = layer(x)
+      >>> y = bst.nn.GroupNorm(x.shape, num_groups=1)(x)
+      >>> y2 = bst.nn.LayerNorm(x.shape, reduction_axes=(1, 2, 3))(x)
+      >>> np.testing.assert_allclose(y, y2)
+
+    Attributes:
+      in_size: The input shape, without batch size.
+      num_groups: the total number of channel groups. The default value of 32 is
+        proposed by the original group normalization paper.
+      group_size: the number of channels in a group.
+      epsilon: A small float added to variance to avoid dividing by zero.
+      dtype: the dtype of the result (default: infer from input and params).
+      use_bias: If True, bias (beta) is added.
+      use_scale: If True, multiply by scale (gamma). When the next layer is linear
+        (also e.g. nn.relu), this can be disabled since the scaling will be done
+        by the next layer.
+      bias_init: Initializer for bias, by default, zero.
+      scale_init: Initializer for scale, by default, one.
+      reduction_axes: List of axes used for computing normalization statistics.
+        This list must include the final dimension, which is assumed to be the
+        feature axis. Furthermore, if the input used at call time has additional
+        leading axes compared to the data used for initialisation, for example due
+        to batching, then the reduction axes need to be defined explicitly.
+        It is recommended to use the negative integer, since when the batch dimension is used,
+        the reduction_axes may be wrong when using the positive integer.
+      axis_name: the axis name used to combine batch statistics from multiple
+        devices. See ``jax.pmap`` for a description of axis names (default: None).
+        This is only needed if the model is subdivided across devices, i.e. the
+        array being normalized is sharded across devices within a pmap or shard
+        map. For SPMD jit, you do not need to manually synchronize. Just make sure
+        that the axes are correctly annotated and XLA:SPMD will insert the
+        necessary collectives.
+      axis_index_groups: groups of axis indices within that named axis
+        representing subsets of devices to reduce over (default: None). For
+        example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over the
+        examples on the first two and last two devices. See ``jax.lax.psum`` for
+        more details.
+      use_fast_variance: If true, use a faster, but less numerically stable,
+        calculation for the variance.
+    """
+
+    def __init__(
+        self,
+        in_size: Size,
+        feature_axis: Axes = -1,
+        num_groups: Optional[int] = 32,
+        group_size: Optional[int] = None,
+        *,
+        epsilon: float = 1e-6,
+        dtype: Optional[jax.typing.DTypeLike] = None,
+        use_bias: bool = True,
+        use_scale: bool = True,
+        bias_init: Callable = init.ZeroInit(),
+        scale_init: Callable = init.Constant(1.),
+        reduction_axes: Optional[Axes] = None,
+        axis_name: Optional[str] = None,
+        axis_index_groups: Any = None,
+        use_fast_variance: bool = True,
+        param_type: type = NormalizationParamState,
+    ):
+        super().__init__()
+
+        self.in_size = in_size
+        self.out_size = in_size
+
+        # parameters about axis
+        feature_axis = (feature_axis,) if isinstance(feature_axis, int) else feature_axis
+        self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axis)
+        self.reduction_axes = (reduction_axes,) if isinstance(reduction_axes, int) else reduction_axes
+        self.axis_name = axis_name
+        self.axis_index_groups = axis_index_groups
+
+        if (num_groups is None and group_size is None) or (
+            num_groups is not None and group_size is not None
+        ):
+            raise ValueError(
+                'Either `num_groups` or `group_size` should be '
+                'specified. If `group_size` is to be specified, '
+                'pass `num_groups=None` as argument to override '
+                'the default `num_groups` value of 32.'
+            )
+
+        feature_shape = tuple([(ax if i in self.feature_axes else 1)
+                               for i, ax in enumerate(self.in_size)])
+        assert len(feature_shape) == 1, 'GroupNorm only supports 1D feature axis.'
+        num_features = feature_shape[0]
+        if group_size is not None:
+            if num_features % group_size != 0:
+                raise ValueError(
+                    'Number of features ({}) is not multiple of the '
+                    'group size ({}).'.format(num_features, group_size)
+                )
+            self.num_groups = num_features // group_size
+            self.group_size = group_size
+        else:
+            if not isinstance(num_groups, int) or num_groups <= 0 or (
+                num_features % num_groups != 0
+            ):
+                raise ValueError(
+                    'Number of groups ({}) does not divide the number'
+                    ' of channels ({}).'.format(num_groups, num_features)
+                )
+            self.num_groups = num_groups
+            self.group_size = num_features // num_groups
+
+        # variables
+        weights = dict()
+        if use_scale:
+            weights['scale'] = init.param(scale_init, feature_shape)
+        if use_bias:
+            weights['bias'] = init.param(bias_init, feature_shape)
+        if len(weights):
+            self.weight = param_type(weights)
+        else:
+            self.weight = None
+
+        # parameters
+        self.epsilon = epsilon
+        self.dtype = dtype
+        self.use_bias = use_bias
+        self.use_scale = use_scale
+        self.bias_init = bias_init
+        self.scale_init = scale_init
+        self.use_fast_variance = use_fast_variance
+
+    def update(self, x, *, mask: Optional[jax.Array] = None):
+        """Applies group normalization to the input (arxiv.org/abs/1803.08494).
+
+        Args:
+          x: the input of shape ``...self.num_features`` where ``self.num_features``
+            is a channels dimension and ``...`` represents an arbitrary number of
+            extra dimensions that can be used to accumulate statistics over. If no
+            reduction axes have been specified then all additional dimensions ``...``
+            will be used to accumulate statistics apart from the leading dimension
+            which is assumed to represent the batch.
+          mask: Binary array of shape broadcastable to ``inputs`` tensor, indicating
+            the positions for which the mean and variance should be computed.
+
+        Returns:
+          Normalized inputs (the same shape as inputs).
+        """
+        if self.reduction_axes is not None:
+            reduction_axes = self.reduction_axes
+        else:
+            reduction_axes = list(range(1, x.ndim - 1)) + [-1]
+        reduction_axes = _canonicalize_axes(x.ndim, reduction_axes)
+
+        group_shape = x.shape[:-1] + (self.num_groups, self.group_size)
+        if mask is not None:
+            mask = mask.reshape(mask.shape[:-1] + (self.num_groups, self.group_size))
+
+        mean, var = _compute_stats(
+            x.reshape(group_shape),
+            list(reduction_axes[:-1]) + [-1],
+            dtype=self.dtype,
+            axis_name=self.axis_name,
+            axis_index_groups=self.axis_index_groups,
+            use_fast_variance=self.use_fast_variance,
+            mask=mask,
+        )
+        mean = jnp.repeat(mean, self.group_size, axis=1)
+        var = jnp.repeat(var, self.group_size, axis=1)
+        return _normalize(
+            x,
+            mean=mean,
+            var=var,
+            weights=self.weight,
+            reduction_axes=reduction_axes[:-1],
+            feature_axes=self.feature_axes,
+            dtype=self.dtype,
+            epsilon=self.epsilon,
+        )
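
The three new layers are used the same way as the docstring examples embedded above. A hedged end-to-end sketch that restates the GroupNorm docstring's equivalence check (one group reproduces LayerNorm) as a standalone snippet; the entry points `bst.nn.GroupNorm` and `bst.nn.LayerNorm` and the call convention `layer(x)` are taken from those docstrings:

import numpy as np
import brainstate as bst

x = bst.random.normal(size=(3, 4, 5, 6))
y1 = bst.nn.GroupNorm(x.shape, num_groups=1)(x)
y2 = bst.nn.LayerNorm(x.shape, reduction_axes=(1, 2, 3))(x)
# GroupNorm with a single group should match LayerNorm over the same axes
np.testing.assert_allclose(y1, y2, rtol=1e-5)
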