brainstate 0.1.0.post20241125__py2.py3-none-any.whl → 0.1.0.post20241209__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -17,22 +17,61 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
- import numbers
21
- from typing import Callable, Union, Sequence, Optional, Any
20
+ from typing import Callable, Union, Sequence, Optional, Any, Dict
22
21
 
23
22
  import jax
24
23
  import jax.numpy as jnp
25
24
 
26
25
  from brainstate import environ, init
27
- from brainstate._state import LongTermState, ParamState
26
+ from brainstate._state import ParamState, BatchState
28
27
  from brainstate.nn._module import Module
29
28
  from brainstate.typing import DTypeLike, ArrayLike, Size, Axes
30
29
 
31
30
  __all__ = [
32
- 'BatchNorm0d', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d',
31
+ 'BatchNorm0d',
32
+ 'BatchNorm1d',
33
+ 'BatchNorm2d',
34
+ 'BatchNorm3d',
35
+ 'LayerNorm',
36
+ 'RMSNorm',
37
+ 'GroupNorm',
33
38
  ]
34
39
 
35
40
 
41
+ def canonicalize_dtype(
42
+ *args,
43
+ dtype: jax.typing.DTypeLike | None = None,
44
+ inexact: bool = True
45
+ ) -> jax.typing.DTypeLike:
46
+ """Canonicalize an optional dtype to the definitive dtype.
47
+
48
+ If the ``dtype`` is None, this function will infer the dtype from the
49
+ input arguments using ``jnp.result_type``. If it is not None, it is
50
+ returned unmodified, or an exception is raised if the dtype is
51
+ invalid.
52
+
53
+ Args:
54
+ *args: JAX array compatible values. None values
55
+ are ignored.
56
+ dtype: Optional dtype override. If specified the arguments are cast to
57
+ the specified dtype instead and dtype inference is disabled.
58
+ inexact: When True, the output dtype must be a subdtype
59
+ of `jnp.inexact`. Inexact dtypes are real or complex floating points. This
60
+ is useful when you want to apply operations that don't work directly on
61
+ integers, such as taking a mean.
62
+ Returns:
63
+ The dtype that *args should be cast to.
64
+ """
65
+ if dtype is None:
66
+ args_filtered = [jnp.asarray(x) for x in args if x is not None]
67
+ dtype = jnp.result_type(*args_filtered)
68
+ if inexact and not jnp.issubdtype(dtype, jnp.inexact):
69
+ dtype = jnp.promote_types(jnp.float32, dtype)
70
+ if inexact and not jnp.issubdtype(dtype, jnp.inexact):
71
+ raise ValueError(f'Dtype must be inexact: {dtype}')
72
+ return dtype
73
+
74
+
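
For readers skimming the diff, a minimal sketch of what the new ``canonicalize_dtype`` helper does when ``dtype`` is left as ``None`` (plain ``jnp`` calls only; the promotion result is stated as the typical outcome under JAX's rules, not guaranteed for every configuration)::

    import jax.numpy as jnp

    x16 = jnp.ones((3,), dtype=jnp.float16)
    xi = jnp.ones((3,), dtype=jnp.int32)

    # dtype=None: inferred with jnp.result_type over the non-None args.
    # float16 is already a subdtype of jnp.inexact, so it is kept.
    print(jnp.result_type(x16))                                  # float16

    # An integer result is promoted with jnp.promote_types(jnp.float32, ...),
    # which under JAX's promotion rules typically yields float32.
    print(jnp.promote_types(jnp.float32, jnp.result_type(xi)))
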
36
75
  def _canonicalize_axes(ndim: int, feature_axes: Sequence[int]):
37
76
  axes = []
38
77
  for axis in feature_axes:
@@ -52,6 +91,18 @@ def _abs_sq(x):
52
91
  return jax.lax.square(x)
53
92
 
54
93
 
94
+ class NormalizationParamState(ParamState):
95
+ # A compatibility class that mirrors the way `ETraceParam` is
96
+ # used by the layers in "brainetrace".
97
+ def execute(self, x):
98
+ param = self.value
99
+ if 'scale' in param:
100
+ x = x * param['scale']
101
+ if 'bias' in param:
102
+ x = x + param['bias']
103
+ return x
104
+
105
+
55
106
  def _compute_stats(
56
107
  x: ArrayLike,
57
108
  axes: Sequence[int],
@@ -59,28 +110,38 @@ def _compute_stats(
59
110
  axis_name: Optional[str] = None,
60
111
  axis_index_groups: Optional[Sequence[int]] = None,
61
112
  use_mean: bool = True,
113
+ use_fast_variance: bool = True,
114
+ mask: Optional[jax.Array] = None,
62
115
  ):
63
- """Computes mean and variance statistics.
116
+ """
117
+ Computes mean and variance statistics.
64
118
 
65
119
  This implementation takes care of a few important details:
66
120
  - Computes in float32 precision for stability in half precision training.
67
- - mean and variance are computable in a single XLA fusion,
68
- by using Var = E[|x|^2] - |E[x]|^2 instead of Var = E[|x - E[x]|^2]).
121
+ - If ``use_fast_variance`` is ``True``, mean and variance are computed using
122
+ Var = E[|x|^2] - |E[x]|^2 instead of Var = E[|x - E[x]|^2], in a single XLA fusion.
69
123
  - Clips negative variances to zero which can happen due to
70
124
  roundoff errors. This avoids downstream NaNs.
71
125
  - Supports averaging across a parallel axis and subgroups of a parallel axis
72
126
  with a single `lax.pmean` call to avoid latency.
73
127
 
74
128
  Arguments:
75
- x: Input array.
76
- axes: The axes in ``x`` to compute mean and variance statistics for.
77
- dtype: tp.Optional dtype specifying the minimal precision. Statistics
78
- are always at least float32 for stability (default: dtype of x).
79
- axis_name: tp.Optional name for the pmapped axis to compute mean over.
80
- axis_index_groups: tp.Optional axis indices.
81
- use_mean: If true, calculate the mean from the input and use it when
82
- computing the variance. If false, set the mean to zero and compute
83
- the variance without subtracting the mean.
129
+ x: Input array.
130
+ axes: The axes in ``x`` to compute mean and variance statistics for.
131
+ dtype: Optional dtype specifying the minimal precision. Statistics
132
+ are always at least float32 for stability (default: dtype of x).
133
+ axis_name: Optional name for the pmapped axis to compute mean over. Note,
134
+ this is only used for pmap and shard map. For SPMD jit, you do not need to
135
+ manually synchronize. Just make sure that the axes are correctly annotated
136
+ and XLA:SPMD will insert the necessary collectives.
137
+ axis_index_groups: Optional axis indices.
138
+ use_mean: If true, calculate the mean from the input and use it when
139
+ computing the variance. If false, set the mean to zero and compute
140
+ the variance without subtracting the mean.
141
+ use_fast_variance: If true, use a faster, but less numerically stable,
142
+ calculation for the variance.
143
+ mask: Binary array of shape broadcastable to ``inputs`` tensor, indicating
144
+ the positions for which the mean and variance should be computed.
84
145
 
85
146
  Returns:
86
147
  A pair ``(mean, val)``.
@@ -89,42 +150,54 @@ def _compute_stats(
89
150
  dtype = jax.numpy.result_type(x)
90
151
  # promote x to at least float32, this avoids half precision computation
91
152
  # but preserves double or complex floating points
92
- dtype = jax.numpy.promote_types(dtype, environ.dftype())
153
+ dtype = jax.numpy.promote_types(dtype, jnp.float32)
93
154
  x = jnp.asarray(x, dtype)
155
+ axes = _canonicalize_axes(x.ndim, axes)
156
+
157
+ def maybe_distributed_mean(*xs, mask=None):
158
+ mus = tuple(x.mean(axes, where=mask) for x in xs)
159
+ if axis_name is None:
160
+ return mus if len(xs) > 1 else mus[0]
161
+ else:
162
+ # In the distributed case we stack multiple arrays to speed comms.
163
+ if len(xs) > 1:
164
+ reduced_mus = jax.lax.pmean(
165
+ jnp.stack(mus, axis=0),
166
+ axis_name,
167
+ axis_index_groups=axis_index_groups,
168
+ )
169
+ return tuple(reduced_mus[i] for i in range(len(xs)))
170
+ else:
171
+ return jax.lax.pmean(
172
+ mus[0],
173
+ axis_name,
174
+ axis_index_groups=axis_index_groups
175
+ )
94
176
 
95
- # Compute mean and mean of squared values.
96
- mean2 = jnp.mean(_abs_sq(x), axes)
97
177
  if use_mean:
98
- mean = jnp.mean(x, axes)
178
+ if use_fast_variance:
179
+ mu, mu2 = maybe_distributed_mean(x, _abs_sq(x), mask=mask)
180
+ # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due
181
+ # to floating point round-off errors.
182
+ var = jnp.maximum(0.0, mu2 - _abs_sq(mu))
183
+ else:
184
+ mu = maybe_distributed_mean(x, mask=mask)
185
+ var = maybe_distributed_mean(_abs_sq(x - jnp.expand_dims(mu, axes)), mask=mask)
99
186
  else:
100
- mean = jnp.zeros(mean2.shape, dtype=dtype)
101
-
102
- # If axis_name is provided, we need to average the mean and mean2 across
103
- if axis_name is not None:
104
- concatenated_mean = jnp.concatenate([mean, mean2])
105
- mean, mean2 = jnp.split(
106
- jax.lax.pmean(
107
- concatenated_mean,
108
- axis_name=axis_name,
109
- axis_index_groups=axis_index_groups,
110
- ),
111
- 2,
112
- )
113
-
114
- # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due
115
- # to floating point round-off errors.
116
- var = jnp.maximum(0.0, mean2 - _abs_sq(mean))
117
- return mean, var
187
+ var = maybe_distributed_mean(_abs_sq(x), mask=mask)
188
+ mu = jnp.zeros_like(var)
189
+ return mu, var
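
The ``use_fast_variance`` path relies on the identity Var[x] = E[|x|^2] - |E[x]|^2, clipped at zero because round-off can make the difference slightly negative. A quick check that it matches the directly computed population variance (plain ``jnp``)::

    import jax.numpy as jnp

    x = jnp.arange(12.0).reshape(3, 4)
    axes = (0,)

    mu = x.mean(axes)
    mu2 = (x * x).mean(axes)                   # _abs_sq(x) for real inputs
    fast_var = jnp.maximum(0.0, mu2 - mu * mu)

    direct_var = ((x - mu) ** 2).mean(axes)    # E[|x - E[x]|^2]
    # fast_var and direct_var agree up to floating-point round-off.
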
118
190
 
119
191
 
120
192
  def _normalize(
121
193
  x: ArrayLike,
122
194
  mean: Optional[ArrayLike],
123
195
  var: Optional[ArrayLike],
124
- weights: Optional[ParamState],
125
- reduction_axes: Sequence[int],
196
+ weights: Optional[NormalizationParamState],
197
+ reduction_axes: Axes,
198
+ feature_axes: Axes,
126
199
  dtype: DTypeLike,
127
- epsilon: Union[numbers.Number, jax.Array],
200
+ epsilon: jax.typing.ArrayLike,
128
201
  ):
129
202
  """Normalizes the input of a normalization layer and optionally applies a learned scale and bias.
130
203
 
@@ -134,6 +207,7 @@ def _normalize(
134
207
  var: Variance to use for normalization.
135
208
  weights: The scale and bias parameters.
136
209
  reduction_axes: The axes in ``x`` to reduce.
210
+ feature_axes: The feature axes to apply the scale and bias.
137
211
  dtype: The dtype of the result (default: infer from input and params).
138
212
  epsilon: Normalization epsilon.
139
213
 
@@ -142,16 +216,22 @@ def _normalize(
142
216
  """
143
217
  if mean is not None:
144
218
  assert var is not None, 'mean and val must be both None or not None.'
219
+ reduction_axes = _canonicalize_axes(x.ndim, reduction_axes)
220
+ feature_axes = _canonicalize_axes(x.ndim, feature_axes)
145
221
  stats_shape = list(x.shape)
146
222
  for axis in reduction_axes:
147
223
  stats_shape[axis] = 1
148
224
  mean = mean.reshape(stats_shape)
149
225
  var = var.reshape(stats_shape)
226
+ feature_shape = [1] * x.ndim
227
+ for ax in feature_axes:
228
+ feature_shape[ax] = x.shape[ax]
150
229
  y = x - mean
151
- mul = jax.lax.rsqrt(var + jnp.asarray(epsilon, dtype))
230
+ mul = jax.lax.rsqrt(var + epsilon)
152
231
  y = y * mul
153
232
  if weights is not None:
154
- y = _scale_operation(y, weights.value)
233
+ y = weights.execute(y)
234
+ dtype = canonicalize_dtype(x, *jax.tree.leaves(weights.value), dtype=dtype)
155
235
  else:
156
236
  assert var is None, 'mean and val must be both None or not None.'
157
237
  assert weights is None, 'scale and bias are not supported without mean and val'
@@ -159,14 +239,6 @@ def _normalize(
159
239
  return jnp.asarray(y, dtype)
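
Tying the pieces together, ``_normalize`` standardizes with a reciprocal square root and then hands the result to the parameter state for scale and bias. A plain-``jnp`` sketch of that arithmetic for a single trailing feature axis (illustrative only)::

    import jax
    import jax.numpy as jnp

    x = jnp.linspace(0.0, 1.0, 8).reshape(2, 4)
    eps = 1e-5

    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)

    y = (x - mean) * jax.lax.rsqrt(var + eps)   # the core of _normalize
    scale = jnp.ones((4,))
    bias = jnp.zeros((4,))
    y = y * scale + bias                        # what weights.execute then applies
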
160
240
 
161
241
 
162
- def _scale_operation(x, param):
163
- if 'scale' in param:
164
- x = x * param['scale']
165
- if 'bias' in param:
166
- x = x + param['bias']
167
- return x
168
-
169
-
170
242
  class _BatchNorm(Module):
171
243
  __module__ = 'brainstate.nn'
172
244
  num_spatial_dims: int
@@ -175,6 +247,7 @@ class _BatchNorm(Module):
175
247
  self,
176
248
  in_size: Size,
177
249
  feature_axis: Axes = -1,
250
+ *,
178
251
  track_running_stats: bool = True,
179
252
  epsilon: float = 1e-5,
180
253
  momentum: float = 0.99,
@@ -183,14 +256,17 @@ class _BatchNorm(Module):
183
256
  scale_initializer: Union[ArrayLike, Callable] = init.Constant(1.),
184
257
  axis_name: Optional[Union[str, Sequence[str]]] = None,
185
258
  axis_index_groups: Optional[Sequence[Sequence[int]]] = None,
259
+ use_fast_variance: bool = True,
186
260
  name: Optional[str] = None,
187
261
  dtype: Any = None,
262
+ param_type: type = NormalizationParamState,
263
+ mean_type: type = BatchState,
188
264
  ):
189
265
  super().__init__(name=name)
190
266
 
191
267
  # parameters
192
- self.in_size = tuple(in_size)
193
- self.out_size = tuple(in_size)
268
+ self.in_size = in_size
269
+ self.out_size = in_size
194
270
  self.affine = affine
195
271
  self.bias_initializer = bias_initializer
196
272
  self.scale_initializer = scale_initializer
@@ -198,18 +274,20 @@ class _BatchNorm(Module):
198
274
  self.track_running_stats = track_running_stats
199
275
  self.momentum = jnp.asarray(momentum, dtype=self.dtype)
200
276
  self.epsilon = jnp.asarray(epsilon, dtype=self.dtype)
277
+ self.use_fast_variance = use_fast_variance
201
278
 
202
279
  # parameters about axis
203
280
  feature_axis = (feature_axis,) if isinstance(feature_axis, int) else feature_axis
204
- self.feature_axis = _canonicalize_axes(len(in_size), feature_axis)
281
+ self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axis)
205
282
  self.axis_name = axis_name
206
283
  self.axis_index_groups = axis_index_groups
207
284
 
208
285
  # variables
209
- feature_shape = tuple([ax if i in self.feature_axis else 1 for i, ax in enumerate(in_size)])
286
+ feature_shape = tuple([(ax if i in self.feature_axes else 1)
287
+ for i, ax in enumerate(self.in_size)])
210
288
  if self.track_running_stats:
211
- self.running_mean = LongTermState(jnp.zeros(feature_shape, dtype=self.dtype))
212
- self.running_var = LongTermState(jnp.ones(feature_shape, dtype=self.dtype))
289
+ self.running_mean = mean_type(jnp.zeros(feature_shape, dtype=self.dtype))
290
+ self.running_var = mean_type(jnp.ones(feature_shape, dtype=self.dtype))
213
291
  else:
214
292
  self.running_mean = None
215
293
  self.running_var = None
@@ -219,11 +297,11 @@ class _BatchNorm(Module):
219
297
  assert track_running_stats, "Affine parameters are not needed when track_running_stats is False."
220
298
  bias = init.param(self.bias_initializer, feature_shape)
221
299
  scale = init.param(self.scale_initializer, feature_shape)
222
- self.weight = ParamState(dict(bias=bias, scale=scale))
300
+ self.weight = param_type(dict(bias=bias, scale=scale))
223
301
  else:
224
302
  self.weight = None
225
303
 
226
- def update(self, x):
304
+ def update(self, x, mask: Optional[jax.Array] = None):
227
305
  # input shape and batch mode or not
228
306
  if x.ndim == self.num_spatial_dims + 2:
229
307
  x_shape = x.shape[1:]
@@ -239,9 +317,9 @@ class _BatchNorm(Module):
239
317
 
240
318
  # reduce the feature axis
241
319
  if batch:
242
- reduction_axes = tuple(i for i in range(x.ndim) if (i - 1) not in self.feature_axis)
320
+ reduction_axes = tuple(i for i in range(x.ndim) if (i - 1) not in self.feature_axes)
243
321
  else:
244
- reduction_axes = tuple(i for i in range(x.ndim) if i not in self.feature_axis)
322
+ reduction_axes = tuple(i for i in range(x.ndim) if i not in self.feature_axes)
245
323
 
246
324
  # fitting phase
247
325
  fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
@@ -255,6 +333,8 @@ class _BatchNorm(Module):
255
333
  dtype=self.dtype,
256
334
  axis_name=self.axis_name,
257
335
  axis_index_groups=self.axis_index_groups,
336
+ use_fast_variance=self.use_fast_variance,
337
+ mask=mask,
258
338
  )
259
339
  self.running_mean.value = self.momentum * self.running_mean.value + (1 - self.momentum) * mean
260
340
  self.running_var.value = self.momentum * self.running_var.value + (1 - self.momentum) * var
@@ -265,14 +345,22 @@ class _BatchNorm(Module):
265
345
  mean, var = None, None
266
346
 
267
347
  # normalize
268
- return _normalize(x, mean, var, self.weight, reduction_axes, self.dtype, self.epsilon)
348
+ return _normalize(
349
+ x,
350
+ mean=mean,
351
+ var=var,
352
+ weights=self.weight,
353
+ reduction_axes=reduction_axes,
354
+ feature_axes=self.feature_axes,
355
+ dtype=self.dtype,
356
+ epsilon=self.epsilon
357
+ )
269
358
 
270
359
 
271
360
  class BatchNorm0d(_BatchNorm):
272
- r"""1-D batch normalization [1]_.
361
+ r"""0-D batch normalization [1]_.
273
362
 
274
- The data should be of `(b, l, c)`, where `b` is the batch dimension,
275
- `l` is the layer dimension, and `c` is the channel dimension.
363
+ The data should be of `(b, c)`, where `b` is the batch dimension, and `c` is the channel dimension.
276
364
 
277
365
  %s
278
366
  """
@@ -375,7 +463,10 @@ _bn_doc = r'''
375
463
  example, `[[0, 1], [2, 3]]` would independently batch-normalize over
376
464
  the examples on the first two and last two devices. See `jax.lax.psum`
377
465
  for more details.
378
-
466
+ use_fast_variance: If true, use a faster, but less numerically stable,
467
+ calculation for the variance.
468
+
469
+
379
470
  References
380
471
  ----------
381
472
  .. [1] Ioffe, Sergey and Christian Szegedy. “Batch Normalization: Accelerating Deep Network Training
@@ -386,3 +477,444 @@ _bn_doc = r'''
386
477
  BatchNorm1d.__doc__ = BatchNorm1d.__doc__ % _bn_doc
387
478
  BatchNorm2d.__doc__ = BatchNorm2d.__doc__ % _bn_doc
388
479
  BatchNorm3d.__doc__ = BatchNorm3d.__doc__ % _bn_doc
480
+
481
+
482
+ class LayerNorm(Module):
483
+ """
484
+ Layer normalization (https://arxiv.org/abs/1607.06450).
485
+
486
+ LayerNorm normalizes the activations of the layer for each given example in a
487
+ batch independently, rather than across a batch like Batch Normalization.
488
+ That is, it applies a transformation that maintains the mean activation within
489
+ each example close to 0 and the activation standard deviation close to 1.
490
+
491
+ Example usage::
492
+
493
+ >>> import brainstate as bst
494
+ >>> x = bst.random.normal(size=(3, 4, 5, 6))
495
+ >>> layer = bst.nn.LayerNorm(x.shape)
496
+ >>> layer.states()
497
+ >>> y = layer(x)
498
+
499
+ Attributes:
500
+ in_size: The input shape, without batch size.
501
+ epsilon: A small float added to variance to avoid dividing by zero.
502
+ dtype: the dtype of the result (default: infer from input and params).
503
+ use_bias: If True, bias (beta) is added.
504
+ use_scale: If True, multiply by scale (gamma). When the next layer is linear
505
+ (or, e.g., relu), this can be disabled since the scaling will be done
506
+ by the next layer.
507
+ bias_init: Initializer for bias, by default, zero.
508
+ scale_init: Initializer for scale, by default, one.
509
+ reduction_axes: Axes for computing normalization statistics. It is recommended
510
+ to use negative integers, since positive indices can point at the wrong
511
+ axis once a batch dimension is added.
512
+ feature_axes: Feature axes for learned bias and scaling.
513
+ axis_name: the axis name used to combine batch statistics from multiple
514
+ devices. See ``jax.pmap`` for a description of axis names (default: None).
515
+ This is only needed if the model is subdivided across devices, i.e. the
516
+ array being normalized is sharded across devices within a pmap.
517
+ axis_index_groups: groups of axis indices within that named axis
518
+ representing subsets of devices to reduce over (default: None). For
519
+ example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over
520
+ the examples on the first two and last two devices. See ``jax.lax.psum``
521
+ for more details.
522
+ use_fast_variance: If true, use a faster, but less numerically stable,
523
+ calculation for the variance.
524
+ """
525
+
526
+ def __init__(
527
+ self,
528
+ in_size: Size,
529
+ reduction_axes: Axes = -1,
530
+ feature_axes: Axes = -1,
531
+ *,
532
+ epsilon: float = 1e-6,
533
+ use_bias: bool = True,
534
+ use_scale: bool = True,
535
+ bias_init: Callable = init.ZeroInit(),
536
+ scale_init: Callable = init.Constant(1.0),
537
+ axis_name: Optional[str] = None,
538
+ axis_index_groups: Any = None,
539
+ use_fast_variance: bool = True,
540
+ dtype: Optional[jax.typing.DTypeLike] = None,
541
+ param_type: type = NormalizationParamState,
542
+ ):
543
+ super().__init__()
544
+
545
+ self.in_size = in_size
546
+ self.out_size = in_size
547
+
548
+ # parameters about axis
549
+ feature_axes = (feature_axes,) if isinstance(feature_axes, int) else feature_axes
550
+ self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axes)
551
+ self.reduction_axes = (reduction_axes,) if isinstance(reduction_axes, int) else reduction_axes
552
+ self.axis_name = axis_name
553
+ self.axis_index_groups = axis_index_groups
554
+
555
+ # variables
556
+ feature_shape = tuple([(ax if i in self.feature_axes else 1)
557
+ for i, ax in enumerate(self.in_size)])
558
+
559
+ weights = dict()
560
+ if use_scale:
561
+ weights['scale'] = init.param(scale_init, feature_shape)
562
+ if use_bias:
563
+ weights['bias'] = init.param(bias_init, feature_shape)
564
+ if len(weights):
565
+ self.weight = param_type(weights)
566
+ else:
567
+ self.weight = None
568
+
569
+ # parameters
570
+ self.epsilon = epsilon
571
+ self.dtype = dtype or environ.dftype()
572
+ self.use_bias = use_bias
573
+ self.use_scale = use_scale
574
+ self.bias_init = bias_init
575
+ self.scale_init = scale_init
576
+ self.use_fast_variance = use_fast_variance
577
+
578
+ def update(self, x, *, mask: Optional[jax.Array] = None):
579
+ """Applies layer normalization on the input.
580
+
581
+ Args:
582
+ x: the inputs
583
+
584
+ Returns:
585
+ Normalized inputs (the same shape as inputs).
586
+ """
587
+ mean, var = _compute_stats(
588
+ x,
589
+ self.reduction_axes,
590
+ dtype=self.dtype,
591
+ axis_name=self.axis_name,
592
+ axis_index_groups=self.axis_index_groups,
593
+ use_fast_variance=self.use_fast_variance,
594
+ mask=mask,
595
+ )
596
+
597
+ return _normalize(
598
+ x,
599
+ mean=mean,
600
+ var=var,
601
+ weights=self.weight,
602
+ reduction_axes=self.reduction_axes,
603
+ feature_axes=self.feature_axes,
604
+ dtype=self.dtype,
605
+ epsilon=self.epsilon,
606
+ )
607
+
608
+
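
Extending the docstring example: with the defaults (``reduction_axes=-1``, zero bias, unit scale) the output is standardized along the last axis. A usage sketch, assuming as in the docstring that ``brainstate`` is imported as ``bst`` and that calling the layer dispatches to ``update``::

    import brainstate as bst

    x = bst.random.normal(size=(3, 4, 5, 6))
    layer = bst.nn.LayerNorm(x.shape)       # in_size given without a batch axis
    y = layer(x)

    # Along the last axis the mean is close to 0 and the standard deviation
    # close to 1 (up to epsilon and floating-point round-off).
    print(y.mean(-1))
    print(y.std(-1))
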
609
+ class RMSNorm(Module):
610
+ """
611
+ RMS Layer normalization (https://arxiv.org/abs/1910.07467).
612
+
613
+ RMSNorm normalizes the activations of the layer for each given example in a
614
+ batch independently, rather than across a batch like Batch Normalization.
615
+ Unlike LayerNorm which re-centers the mean to be 0 and normalizes by the
616
+ standard deviation of the activations, RMSNorm does not re-center at all
617
+ and instead normalizes by the root mean square of the activations.
618
+
619
+ Example usage::
620
+
621
+ >>> import brainstate as bst
622
+ >>> x = bst.random.normal(size=(5, 6))
623
+ >>> layer = bst.nn.RMSNorm(in_size=(6,))
624
+ >>> layer.states()
625
+ >>> y = layer(x)
626
+
627
+ Attributes:
628
+ in_size: The input shape, without batch size.
629
+ epsilon: A small float added to variance to avoid dividing by zero.
630
+ dtype: the dtype of the result (default: infer from input and params).
631
+ use_scale: If True, multiply by scale (gamma). When the next layer is linear
632
+ (also e.g. nn.relu), this can be disabled since the scaling will be done
633
+ by the next layer.
634
+ scale_init: Initializer for scale, by default, one.
635
+ reduction_axes: Axes for computing normalization statistics. It is recommended
636
+ to use negative integers, since positive indices can point at the wrong
637
+ axis once a batch dimension is added.
638
+ feature_axes: Feature axes for learned bias and scaling.
639
+ axis_name: the axis name used to combine batch statistics from multiple
640
+ devices. See ``jax.pmap`` for a description of axis names (default: None).
641
+ This is only needed if the model is subdivided across devices, i.e. the
642
+ array being normalized is sharded across devices within a pmap.
643
+ axis_index_groups: groups of axis indices within that named axis
644
+ representing subsets of devices to reduce over (default: None). For
645
+ example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over
646
+ the examples on the first two and last two devices. See ``jax.lax.psum``
647
+ for more details.
648
+ use_fast_variance: If true, use a faster, but less numerically stable,
649
+ calculation for the variance.
650
+ """
651
+
652
+ def __init__(
653
+ self,
654
+ in_size: Size,
655
+ *,
656
+ epsilon: float = 1e-6,
657
+ dtype: Optional[jax.typing.DTypeLike] = None,
658
+ use_scale: bool = True,
659
+ scale_init: Callable = init.Constant(1.0),
660
+ reduction_axes: Axes = -1,
661
+ feature_axes: Axes = -1,
662
+ axis_name: Optional[str] = None,
663
+ axis_index_groups: Any = None,
664
+ use_fast_variance: bool = True,
665
+ param_type: type = NormalizationParamState,
666
+ ):
667
+ super().__init__()
668
+
669
+ self.in_size = in_size
670
+ self.out_size = in_size
671
+
672
+ # parameters about axis
673
+ feature_axes = (feature_axes,) if isinstance(feature_axes, int) else feature_axes
674
+ self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axes)
675
+ self.reduction_axes = (reduction_axes,) if isinstance(reduction_axes, int) else reduction_axes
676
+ self.axis_name = axis_name
677
+ self.axis_index_groups = axis_index_groups
678
+
679
+ # variables
680
+ feature_shape = tuple([(ax if i in self.feature_axes else 1)
681
+ for i, ax in enumerate(self.in_size)])
682
+ if use_scale:
683
+ self.scale = param_type({'scale': init.param(scale_init, feature_shape)})
684
+ else:
685
+ self.scale = None
686
+
687
+ # parameters
688
+ self.epsilon = epsilon
689
+ self.dtype = dtype or environ.dftype()
690
+ self.use_scale = use_scale
691
+ self.scale_init = scale_init
692
+ self.use_fast_variance = use_fast_variance
693
+
694
+ def update(self, x, *, mask: Optional[jax.Array] = None):
695
+ """Applies layer normalization on the input.
696
+
697
+ Args:
698
+ x: the inputs
699
+ mask: the mask
700
+
701
+ Returns:
702
+ Normalized inputs (the same shape as inputs).
703
+ """
704
+ mean, var = _compute_stats(
705
+ x,
706
+ self.reduction_axes,
707
+ dtype=self.dtype,
708
+ axis_name=self.axis_name,
709
+ axis_index_groups=self.axis_index_groups,
710
+ use_mean=False,
711
+ use_fast_variance=self.use_fast_variance,
712
+ mask=mask,
713
+ )
714
+
715
+ return _normalize(
716
+ x,
717
+ mean=mean,
718
+ var=var,
719
+ weights=self.scale,
720
+ reduction_axes=self.reduction_axes,
721
+ feature_axes=self.feature_axes,
722
+ dtype=self.dtype,
723
+ epsilon=self.epsilon,
724
+ )
725
+
726
+
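
Because RMSNorm passes ``use_mean=False`` to ``_compute_stats``, the statistic reduces to the mean square of the activations and no re-centering happens. The equivalent arithmetic over the last axis in plain ``jnp``::

    import jax
    import jax.numpy as jnp

    x = jnp.linspace(-1.0, 1.0, 12).reshape(2, 6)
    eps = 1e-6

    ms = (x * x).mean(axis=-1, keepdims=True)   # mean square, no mean subtraction
    y = x * jax.lax.rsqrt(ms + eps)             # divide by the root mean square
    scale = jnp.ones((6,))
    y = y * scale                               # learned scale, when use_scale=True
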
727
+ class GroupNorm(Module):
728
+ """
729
+ Group normalization (arxiv.org/abs/1803.08494).
730
+
731
+ This op is similar to batch normalization, but statistics are shared across
732
+ equally-sized groups of channels and not shared across the batch dimension.
733
+ Thus, group normalization does not depend on the batch composition and does
734
+ not require maintaining internal state for storing statistics.
735
+ The user should either specify the total number of channel groups or the
736
+ number of channels per group.
737
+
738
+ .. note::
739
+ LayerNorm is a special case of GroupNorm where ``num_groups=1``.
740
+
741
+ Example usage::
742
+
743
+ >>> import numpy as np
744
+ >>> import brainstate as bst
745
+ ...
746
+ >>> x = bst.random.normal(size=(3, 4, 5, 6))
747
+ >>> layer = bst.nn.GroupNorm(x.shape, num_groups=3)
748
+ >>> layer.states()
749
+ >>> y = layer(x)
750
+ >>> y = bst.nn.GroupNorm(x.shape, num_groups=1)(x)
751
+ >>> y2 = bst.nn.LayerNorm(x.shape, reduction_axes=(1, 2, 3))(x)
752
+ >>> np.testing.assert_allclose(y, y2)
753
+
754
+ Attributes:
755
+ in_size: The input shape, without batch size.
756
+ num_groups: the total number of channel groups. The default value of 32 is
757
+ proposed by the original group normalization paper.
758
+ group_size: the number of channels in a group.
759
+ epsilon: A small float added to variance to avoid dividing by zero.
760
+ dtype: the dtype of the result (default: infer from input and params).
761
+ use_bias: If True, bias (beta) is added.
762
+ use_scale: If True, multiply by scale (gamma). When the next layer is linear
763
+ (also e.g. nn.relu), this can be disabled since the scaling will be done
764
+ by the next layer.
765
+ bias_init: Initializer for bias, by default, zero.
766
+ scale_init: Initializer for scale, by default, one.
767
+ reduction_axes: List of axes used for computing normalization statistics.
768
+ This list must include the final dimension, which is assumed to be the
769
+ feature axis. Furthermore, if the input used at call time has additional
770
+ leading axes compared to the data used for initialisation, for example due
771
+ to batching, then the reduction axes need to be defined explicitly.
772
+ It is recommended to use negative integers, since positive indices can
773
+ point at the wrong axis once a batch dimension is added.
774
+ axis_name: the axis name used to combine batch statistics from multiple
775
+ devices. See ``jax.pmap`` for a description of axis names (default: None).
776
+ This is only needed if the model is subdivided across devices, i.e. the
777
+ array being normalized is sharded across devices within a pmap or shard
778
+ map. For SPMD jit, you do not need to manually synchronize. Just make sure
779
+ that the axes are correctly annotated and XLA:SPMD will insert the
780
+ necessary collectives.
781
+ axis_index_groups: groups of axis indices within that named axis
782
+ representing subsets of devices to reduce over (default: None). For
783
+ example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over the
784
+ examples on the first two and last two devices. See ``jax.lax.psum`` for
785
+ more details.
786
+ use_fast_variance: If true, use a faster, but less numerically stable,
787
+ calculation for the variance.
788
+ """
789
+
790
+ def __init__(
791
+ self,
792
+ in_size: Size,
793
+ feature_axis: Axes = -1,
794
+ num_groups: Optional[int] = 32,
795
+ group_size: Optional[int] = None,
796
+ *,
797
+ epsilon: float = 1e-6,
798
+ dtype: Optional[jax.typing.DTypeLike] = None,
799
+ use_bias: bool = True,
800
+ use_scale: bool = True,
801
+ bias_init: Callable = init.ZeroInit(),
802
+ scale_init: Callable = init.Constant(1.),
803
+ reduction_axes: Optional[Axes] = None,
804
+ axis_name: Optional[str] = None,
805
+ axis_index_groups: Any = None,
806
+ use_fast_variance: bool = True,
807
+ param_type: type = NormalizationParamState,
808
+ ):
809
+ super().__init__()
810
+
811
+ self.in_size = in_size
812
+ self.out_size = in_size
813
+
814
+ # parameters about axis
815
+ feature_axis = (feature_axis,) if isinstance(feature_axis, int) else feature_axis
816
+ self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axis)
817
+ self.reduction_axes = (reduction_axes,) if isinstance(reduction_axes, int) else reduction_axes
818
+ self.axis_name = axis_name
819
+ self.axis_index_groups = axis_index_groups
820
+
821
+ if (num_groups is None and group_size is None) or (
822
+ num_groups is not None and group_size is not None
823
+ ):
824
+ raise ValueError(
825
+ 'Either `num_groups` or `group_size` should be '
826
+ 'specified. If `group_size` is to be specified, '
827
+ 'pass `num_groups=None` as argument to override '
828
+ 'the default `num_groups` value of 32.'
829
+ )
830
+
831
+ feature_shape = tuple([(ax if i in self.feature_axes else 1)
832
+ for i, ax in enumerate(self.in_size)])
833
+ assert len(feature_shape) == 1, 'GroupNorm only supports 1D feature axis.'
834
+ num_features = feature_shape[0]
835
+ if group_size is not None:
836
+ if num_features % group_size != 0:
837
+ raise ValueError(
838
+ 'Number of features ({}) is not multiple of the '
839
+ 'group size ({}).'.format(num_features, group_size)
840
+ )
841
+ self.num_groups = num_features // group_size
842
+ self.group_size = group_size
843
+ else:
844
+ if not isinstance(num_groups, int) or num_groups <= 0 or (
845
+ num_features % num_groups != 0
846
+ ):
847
+ raise ValueError(
848
+ 'Number of groups ({}) does not divide the number'
849
+ ' of channels ({}).'.format(num_groups, num_features)
850
+ )
851
+ self.num_groups = num_groups
852
+ self.group_size = num_features // num_groups
853
+
854
+ # variables
855
+ weights = dict()
856
+ if use_scale:
857
+ weights['scale'] = init.param(scale_init, feature_shape)
858
+ if use_bias:
859
+ weights['bias'] = init.param(bias_init, feature_shape)
860
+ if len(weights):
861
+ self.weight = param_type(weights)
862
+ else:
863
+ self.weight = None
864
+
865
+ # parameters
866
+ self.epsilon = epsilon
867
+ self.dtype = dtype
868
+ self.use_bias = use_bias
869
+ self.use_scale = use_scale
870
+ self.bias_init = bias_init
871
+ self.scale_init = scale_init
872
+ self.use_fast_variance = use_fast_variance
873
+
874
+ def update(self, x, *, mask: Optional[jax.Array] = None):
875
+ """Applies group normalization to the input (arxiv.org/abs/1803.08494).
876
+
877
+ Args:
878
+ x: the input of shape ``...self.num_features`` where ``self.num_features``
879
+ is a channels dimension and ``...`` represents an arbitrary number of
880
+ extra dimensions that can be used to accumulate statistics over. If no
881
+ reduction axes have been specified then all additional dimensions ``...``
882
+ will be used to accumulate statistics apart from the leading dimension
883
+ which is assumed to represent the batch.
884
+ mask: Binary array of shape broadcastable to ``inputs`` tensor, indicating
885
+ the positions for which the mean and variance should be computed.
886
+
887
+ Returns:
888
+ Normalized inputs (the same shape as inputs).
889
+ """
890
+ if self.reduction_axes is not None:
891
+ reduction_axes = self.reduction_axes
892
+ else:
893
+ reduction_axes = list(range(1, x.ndim - 1)) + [-1]
894
+ reduction_axes = _canonicalize_axes(x.ndim, reduction_axes)
895
+
896
+ group_shape = x.shape[:-1] + (self.num_groups, self.group_size)
897
+ if mask is not None:
898
+ mask = mask.reshape(mask.shape[:-1] + (self.num_groups, self.group_size))
899
+
900
+ mean, var = _compute_stats(
901
+ x.reshape(group_shape),
902
+ list(reduction_axes[:-1]) + [-1],
903
+ dtype=self.dtype,
904
+ axis_name=self.axis_name,
905
+ axis_index_groups=self.axis_index_groups,
906
+ use_fast_variance=self.use_fast_variance,
907
+ mask=mask,
908
+ )
909
+ mean = jnp.repeat(mean, self.group_size, axis=1)
910
+ var = jnp.repeat(var, self.group_size, axis=1)
911
+ return _normalize(
912
+ x,
913
+ mean=mean,
914
+ var=var,
915
+ weights=self.weight,
916
+ reduction_axes=reduction_axes[:-1],
917
+ feature_axes=self.feature_axes,
918
+ dtype=self.dtype,
919
+ epsilon=self.epsilon,
920
+ )
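
To make the grouping in ``GroupNorm.update`` concrete: the channel axis is reshaped into ``(num_groups, group_size)`` and one mean/variance pair is computed per group and per example. A plain-``jnp`` sketch with hypothetical sizes (illustrative, not the layer's exact code path)::

    import jax.numpy as jnp

    num_groups, group_size = 3, 2               # 6 channels split into 3 groups
    x = jnp.arange(2 * 4 * 4 * 6, dtype=jnp.float32).reshape(2, 4, 4, 6)

    g = x.reshape(x.shape[:-1] + (num_groups, group_size))   # same as group_shape

    # Reduce over the spatial axes and the within-group axis, keeping the
    # batch and group axes: one statistic per (example, group).
    mean = g.mean(axis=(1, 2, 4))
    var = g.var(axis=(1, 2, 4))
    # mean.shape == var.shape == (2, 3)
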