brainstate 0.1.0.post20241122__py2.py3-none-any.whl → 0.1.0.post20241129__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -17,8 +17,7 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
- import numbers
21
- from typing import Callable, Union, Sequence, Optional, Any
20
+ from typing import Callable, Union, Sequence, Optional, Any, Dict
22
21
 
23
22
  import jax
24
23
  import jax.numpy as jnp
@@ -29,10 +28,50 @@ from brainstate.nn._module import Module
29
28
  from brainstate.typing import DTypeLike, ArrayLike, Size, Axes
30
29
 
31
30
  __all__ = [
32
- 'BatchNorm0d', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d',
31
+ 'BatchNorm0d',
32
+ 'BatchNorm1d',
33
+ 'BatchNorm2d',
34
+ 'BatchNorm3d',
35
+ 'LayerNorm',
36
+ 'RMSNorm',
37
+ 'GroupNorm',
33
38
  ]
34
39
 
35
40
 
41
+ def canonicalize_dtype(
42
+ *args,
43
+ dtype: jax.typing.DTypeLike | None = None,
44
+ inexact: bool = True
45
+ ) -> jax.typing.DTypeLike:
46
+ """Canonicalize an optional dtype to the definitive dtype.
47
+
48
+ If the ``dtype`` is None, this function will infer the dtype
49
+ from the input arguments using ``jnp.result_type``. If it is not None,
50
+ it will be returned unmodified, or an exception is raised if the dtype
51
+ is invalid.
52
+
53
+ Args:
54
+ *args: JAX array compatible values. None values
55
+ are ignored.
56
+ dtype: Optional dtype override. If specified the arguments are cast to
57
+ the specified dtype instead and dtype inference is disabled.
58
+ inexact: When True, the output dtype must be a subdtype
59
+ of `jnp.inexact`. Inexact dtypes are real or complex floating points. This
60
+ is useful when you want to apply operations that do not work directly on
61
+ integers, such as taking a mean.
62
+ Returns:
63
+ The dtype that *args should be cast to.
64
+ """
65
+ if dtype is None:
66
+ args_filtered = [jnp.asarray(x) for x in args if x is not None]
67
+ dtype = jnp.result_type(*args_filtered)
68
+ if inexact and not jnp.issubdtype(dtype, jnp.inexact):
69
+ dtype = jnp.promote_types(jnp.float32, dtype)
70
+ if inexact and not jnp.issubdtype(dtype, jnp.inexact):
71
+ raise ValueError(f'Dtype must be inexact: {dtype}')
72
+ return dtype
73
+
74
+
36
75
  def _canonicalize_axes(ndim: int, feature_axes: Sequence[int]):
37
76
  axes = []
38
77
  for axis in feature_axes:
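For reference, a minimal standalone sketch of how the canonicalize_dtype helper added above behaves (assuming the helper is in scope; it is module-internal and not exported via __all__):

    import jax.numpy as jnp

    x = jnp.ones((3,), dtype=jnp.int32)
    w = jnp.ones((3,), dtype=jnp.float32)

    # dtype=None: infer with jnp.result_type, then promote to an inexact dtype.
    canonicalize_dtype(x, w)                                # -> float32
    canonicalize_dtype(x)                                   # -> float32 (int32 promoted)
    # An explicit dtype is returned unchanged, but must be inexact when inexact=True.
    canonicalize_dtype(x, dtype=jnp.bfloat16)               # -> bfloat16
    canonicalize_dtype(x, dtype=jnp.int32, inexact=False)   # -> int32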
@@ -59,28 +98,38 @@ def _compute_stats(
59
98
  axis_name: Optional[str] = None,
60
99
  axis_index_groups: Optional[Sequence[int]] = None,
61
100
  use_mean: bool = True,
101
+ use_fast_variance: bool = True,
102
+ mask: Optional[jax.Array] = None,
62
103
  ):
63
- """Computes mean and variance statistics.
104
+ """
105
+ Computes mean and variance statistics.
64
106
 
65
107
  This implementation takes care of a few important details:
66
108
  - Computes in float32 precision for stability in half precision training.
67
- - mean and variance are computable in a single XLA fusion,
68
- by using Var = E[|x|^2] - |E[x]|^2 instead of Var = E[|x - E[x]|^2]).
109
+ - If ``use_fast_variance`` is ``True``, mean and variance are computed using
110
+ Var = E[|x|^2] - |E[x]|^2, instead of Var = E[|x - E[x]|^2], in a single XLA fusion.
69
111
  - Clips negative variances to zero which can happen due to
70
112
  roundoff errors. This avoids downstream NaNs.
71
113
  - Supports averaging across a parallel axis and subgroups of a parallel axis
72
114
  with a single `lax.pmean` call to avoid latency.
73
115
 
74
116
  Arguments:
75
- x: Input array.
76
- axes: The axes in ``x`` to compute mean and variance statistics for.
77
- dtype: tp.Optional dtype specifying the minimal precision. Statistics
78
- are always at least float32 for stability (default: dtype of x).
79
- axis_name: tp.Optional name for the pmapped axis to compute mean over.
80
- axis_index_groups: tp.Optional axis indices.
81
- use_mean: If true, calculate the mean from the input and use it when
82
- computing the variance. If false, set the mean to zero and compute
83
- the variance without subtracting the mean.
117
+ x: Input array.
118
+ axes: The axes in ``x`` to compute mean and variance statistics for.
119
+ dtype: Optional dtype specifying the minimal precision. Statistics
120
+ are always at least float32 for stability (default: dtype of x).
121
+ axis_name: Optional name for the pmapped axis to compute mean over. Note,
122
+ this is only used for pmap and shard map. For SPMD jit, you do not need to
123
+ manually synchronize. Just make sure that the axes are correctly annotated
124
+ and XLA:SPMD will insert the necessary collectives.
125
+ axis_index_groups: Optional axis indices.
126
+ use_mean: If true, calculate the mean from the input and use it when
127
+ computing the variance. If false, set the mean to zero and compute
128
+ the variance without subtracting the mean.
129
+ use_fast_variance: If true, use a faster, but less numerically stable,
130
+ calculation for the variance.
131
+ mask: Binary array of shape broadcastable to ``inputs`` tensor, indicating
132
+ the positions for which the mean and variance should be computed.
84
133
 
85
134
  Returns:
86
135
  A pair ``(mean, val)``.
@@ -89,32 +138,38 @@ def _compute_stats(
89
138
  dtype = jax.numpy.result_type(x)
90
139
  # promote x to at least float32, this avoids half precision computation
91
140
  # but preserves double or complex floating points
92
- dtype = jax.numpy.promote_types(dtype, environ.dftype())
141
+ dtype = jax.numpy.promote_types(dtype, jnp.float32)
93
142
  x = jnp.asarray(x, dtype)
143
+ axes = _canonicalize_axes(x.ndim, axes)
144
+
145
+ def maybe_distributed_mean(*xs, mask=None):
146
+ mus = tuple(x.mean(axes, where=mask) for x in xs)
147
+ if axis_name is None:
148
+ return mus if len(xs) > 1 else mus[0]
149
+ else:
150
+ # In the distributed case we stack multiple arrays to speed comms.
151
+ if len(xs) > 1:
152
+ reduced_mus = jax.lax.pmean(
153
+ jnp.stack(mus, axis=0), axis_name,
154
+ axis_index_groups=axis_index_groups,
155
+ )
156
+ return tuple(reduced_mus[i] for i in range(len(xs)))
157
+ else:
158
+ return jax.lax.pmean(mus[0], axis_name, axis_index_groups=axis_index_groups)
94
159
 
95
- # Compute mean and mean of squared values.
96
- mean2 = jnp.mean(_abs_sq(x), axes)
97
160
  if use_mean:
98
- mean = jnp.mean(x, axes)
161
+ if use_fast_variance:
162
+ mu, mu2 = maybe_distributed_mean(x, _abs_sq(x), mask=mask)
163
+ # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due
164
+ # to floating point round-off errors.
165
+ var = jnp.maximum(0.0, mu2 - _abs_sq(mu))
166
+ else:
167
+ mu = maybe_distributed_mean(x, mask=mask)
168
+ var = maybe_distributed_mean(_abs_sq(x - jnp.expand_dims(mu, axes)), mask=mask)
99
169
  else:
100
- mean = jnp.zeros(mean2.shape, dtype=dtype)
101
-
102
- # If axis_name is provided, we need to average the mean and mean2 across
103
- if axis_name is not None:
104
- concatenated_mean = jnp.concatenate([mean, mean2])
105
- mean, mean2 = jnp.split(
106
- jax.lax.pmean(
107
- concatenated_mean,
108
- axis_name=axis_name,
109
- axis_index_groups=axis_index_groups,
110
- ),
111
- 2,
112
- )
113
-
114
- # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due
115
- # to floating point round-off errors.
116
- var = jnp.maximum(0.0, mean2 - _abs_sq(mean))
117
- return mean, var
170
+ var = maybe_distributed_mean(_abs_sq(x), mask=mask)
171
+ mu = jnp.zeros_like(var)
172
+ return mu, var
118
173
 
119
174
 
120
175
  def _normalize(
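A quick standalone check of the two variance formulas that the rewritten _compute_stats switches between: with use_fast_variance=True it uses Var = E[x^2] - E[x]^2 (one fused pass), otherwise Var = E[(x - E[x])^2]; both agree up to floating-point round-off:

    import jax.numpy as jnp

    x = jnp.arange(6.0).reshape(2, 3)
    mu = x.mean(axis=1)
    fast = jnp.maximum(0.0, (x ** 2).mean(axis=1) - mu ** 2)  # single fused pass
    exact = ((x - mu[:, None]) ** 2).mean(axis=1)             # subtract-mean form
    assert jnp.allclose(fast, exact)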
@@ -122,9 +177,10 @@ def _normalize(
122
177
  mean: Optional[ArrayLike],
123
178
  var: Optional[ArrayLike],
124
179
  weights: Optional[ParamState],
125
- reduction_axes: Sequence[int],
180
+ reduction_axes: Axes,
181
+ feature_axes: Axes,
126
182
  dtype: DTypeLike,
127
- epsilon: Union[numbers.Number, jax.Array],
183
+ epsilon: jax.typing.ArrayLike,
128
184
  ):
129
185
  """Normalizes the input of a normalization layer and optionally applies a learned scale and bias.
130
186
 
@@ -134,6 +190,7 @@ def _normalize(
134
190
  var: Variance to use for normalization.
135
191
  weights: The scale and bias parameters.
136
192
  reduction_axes: The axes in ``x`` to reduce.
193
+ feature_axes: The feature axes to apply the scale and bias.
137
194
  dtype: The dtype of the result (default: infer from input and params).
138
195
  epsilon: Normalization epsilon.
139
196
 
@@ -142,16 +199,23 @@ def _normalize(
142
199
  """
143
200
  if mean is not None:
144
201
  assert var is not None, 'mean and val must be both None or not None.'
202
+ reduction_axes = _canonicalize_axes(x.ndim, reduction_axes)
203
+ feature_axes = _canonicalize_axes(x.ndim, feature_axes)
145
204
  stats_shape = list(x.shape)
146
205
  for axis in reduction_axes:
147
206
  stats_shape[axis] = 1
148
207
  mean = mean.reshape(stats_shape)
149
208
  var = var.reshape(stats_shape)
209
+ feature_shape = [1] * x.ndim
210
+ for ax in feature_axes:
211
+ feature_shape[ax] = x.shape[ax]
150
212
  y = x - mean
151
- mul = jax.lax.rsqrt(var + jnp.asarray(epsilon, dtype))
213
+ mul = jax.lax.rsqrt(var + epsilon)
152
214
  y = y * mul
215
+ args = []
153
216
  if weights is not None:
154
- y = _scale_operation(y, weights.value)
217
+ y, args = _scale_operation(y, weights.value)
218
+ dtype = canonicalize_dtype(x, *args, dtype=dtype)
155
219
  else:
156
220
  assert var is None, 'mean and val must be both None or not None.'
157
221
  assert weights is None, 'scale and bias are not supported without mean and val'
@@ -159,12 +223,15 @@ def _normalize(
159
223
  return jnp.asarray(y, dtype)
160
224
 
161
225
 
162
- def _scale_operation(x, param):
226
+ def _scale_operation(x: jax.Array, param: Dict):
227
+ args = []
163
228
  if 'scale' in param:
164
229
  x = x * param['scale']
230
+ args.append(param['scale'])
165
231
  if 'bias' in param:
166
232
  x = x + param['bias']
167
- return x
233
+ args.append(param['bias'])
234
+ return x, args
168
235
 
169
236
 
170
237
  class _BatchNorm(Module):
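When mean, var, and weights are all provided, the updated _normalize/_scale_operation pair reduces to the usual affine normalization; a minimal reference sketch (normalize_reference is a hypothetical name, and the dtype canonicalization step is omitted):

    import jax
    import jax.numpy as jnp

    def normalize_reference(x, mean, var, scale, bias, eps=1e-5):
        # y = (x - mean) / sqrt(var + eps), then learned scale and shift
        y = (x - mean) * jax.lax.rsqrt(var + eps)
        return y * scale + bias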
@@ -175,6 +242,7 @@ class _BatchNorm(Module):
175
242
  self,
176
243
  in_size: Size,
177
244
  feature_axis: Axes = -1,
245
+ *,
178
246
  track_running_stats: bool = True,
179
247
  epsilon: float = 1e-5,
180
248
  momentum: float = 0.99,
@@ -183,14 +251,15 @@ class _BatchNorm(Module):
183
251
  scale_initializer: Union[ArrayLike, Callable] = init.Constant(1.),
184
252
  axis_name: Optional[Union[str, Sequence[str]]] = None,
185
253
  axis_index_groups: Optional[Sequence[Sequence[int]]] = None,
254
+ use_fast_variance: bool = True,
186
255
  name: Optional[str] = None,
187
256
  dtype: Any = None,
188
257
  ):
189
258
  super().__init__(name=name)
190
259
 
191
260
  # parameters
192
- self.in_size = tuple(in_size)
193
- self.out_size = tuple(in_size)
261
+ self.in_size = in_size
262
+ self.out_size = in_size
194
263
  self.affine = affine
195
264
  self.bias_initializer = bias_initializer
196
265
  self.scale_initializer = scale_initializer
@@ -198,15 +267,17 @@ class _BatchNorm(Module):
198
267
  self.track_running_stats = track_running_stats
199
268
  self.momentum = jnp.asarray(momentum, dtype=self.dtype)
200
269
  self.epsilon = jnp.asarray(epsilon, dtype=self.dtype)
270
+ self.use_fast_variance = use_fast_variance
201
271
 
202
272
  # parameters about axis
203
273
  feature_axis = (feature_axis,) if isinstance(feature_axis, int) else feature_axis
204
- self.feature_axis = _canonicalize_axes(len(in_size), feature_axis)
274
+ self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axis)
205
275
  self.axis_name = axis_name
206
276
  self.axis_index_groups = axis_index_groups
207
277
 
208
278
  # variables
209
- feature_shape = tuple([ax if i in self.feature_axis else 1 for i, ax in enumerate(in_size)])
279
+ feature_shape = tuple([(ax if i in self.feature_axes else 1)
280
+ for i, ax in enumerate(self.in_size)])
210
281
  if self.track_running_stats:
211
282
  self.running_mean = LongTermState(jnp.zeros(feature_shape, dtype=self.dtype))
212
283
  self.running_var = LongTermState(jnp.ones(feature_shape, dtype=self.dtype))
@@ -223,7 +294,7 @@ class _BatchNorm(Module):
223
294
  else:
224
295
  self.weight = None
225
296
 
226
- def update(self, x):
297
+ def update(self, x, mask: Optional[jax.Array] = None):
227
298
  # input shape and batch mode or not
228
299
  if x.ndim == self.num_spatial_dims + 2:
229
300
  x_shape = x.shape[1:]
@@ -239,9 +310,9 @@ class _BatchNorm(Module):
239
310
 
240
311
  # reduce the feature axis
241
312
  if batch:
242
- reduction_axes = tuple(i for i in range(x.ndim) if (i - 1) not in self.feature_axis)
313
+ reduction_axes = tuple(i for i in range(x.ndim) if (i - 1) not in self.feature_axes)
243
314
  else:
244
- reduction_axes = tuple(i for i in range(x.ndim) if i not in self.feature_axis)
315
+ reduction_axes = tuple(i for i in range(x.ndim) if i not in self.feature_axes)
245
316
 
246
317
  # fitting phase
247
318
  fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
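To illustrate the reduction-axis logic above with hypothetical shapes: for an unbatched in_size of (l, c) with feature_axes == (1,), a batched input of shape (b, l, c) is reduced over every axis except the feature axis shifted by the batch dimension:

    feature_axes = (1,)   # channel axis of the unbatched shape (l, c)
    x_ndim = 3            # batched input (b, l, c)
    reduction_axes = tuple(i for i in range(x_ndim) if (i - 1) not in feature_axes)
    assert reduction_axes == (0, 1)   # reduce over batch and length, keep channels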
@@ -255,6 +326,8 @@ class _BatchNorm(Module):
255
326
  dtype=self.dtype,
256
327
  axis_name=self.axis_name,
257
328
  axis_index_groups=self.axis_index_groups,
329
+ use_fast_variance=self.use_fast_variance,
330
+ mask=mask,
258
331
  )
259
332
  self.running_mean.value = self.momentum * self.running_mean.value + (1 - self.momentum) * mean
260
333
  self.running_var.value = self.momentum * self.running_var.value + (1 - self.momentum) * var
@@ -265,14 +338,22 @@ class _BatchNorm(Module):
265
338
  mean, var = None, None
266
339
 
267
340
  # normalize
268
- return _normalize(x, mean, var, self.weight, reduction_axes, self.dtype, self.epsilon)
341
+ return _normalize(
342
+ x,
343
+ mean=mean,
344
+ var=var,
345
+ weights=self.weight,
346
+ reduction_axes=reduction_axes,
347
+ feature_axes=self.feature_axes,
348
+ dtype=self.dtype,
349
+ epsilon=self.epsilon
350
+ )
269
351
 
270
352
 
271
353
  class BatchNorm0d(_BatchNorm):
272
- r"""1-D batch normalization [1]_.
354
+ r"""0-D batch normalization [1]_.
273
355
 
274
- The data should be of `(b, l, c)`, where `b` is the batch dimension,
275
- `l` is the layer dimension, and `c` is the channel dimension.
356
+ The data should be of `(b, c)`, where `b` is the batch dimension, and `c` is the channel dimension.
276
357
 
277
358
  %s
278
359
  """
@@ -375,7 +456,10 @@ _bn_doc = r'''
375
456
  example, `[[0, 1], [2, 3]]` would independently batch-normalize over
376
457
  the examples on the first two and last two devices. See `jax.lax.psum`
377
458
  for more details.
378
-
459
+ use_fast_variance: If true, use a faster, but less numerically stable,
460
+ calculation for the variance.
461
+
462
+
379
463
  References
380
464
  ----------
381
465
  .. [1] Ioffe, Sergey and Christian Szegedy. “Batch Normalization: Accelerating Deep Network Training
@@ -386,3 +470,441 @@ _bn_doc = r'''
386
470
  BatchNorm1d.__doc__ = BatchNorm1d.__doc__ % _bn_doc
387
471
  BatchNorm2d.__doc__ = BatchNorm2d.__doc__ % _bn_doc
388
472
  BatchNorm3d.__doc__ = BatchNorm3d.__doc__ % _bn_doc
473
+
474
+
475
+ class LayerNorm(Module):
476
+ """
477
+ Layer normalization (https://arxiv.org/abs/1607.06450).
478
+
479
+ LayerNorm normalizes the activations of the layer for each given example in a
480
+ batch independently, rather than across a batch like Batch Normalization.
481
+ i.e. applies a transformation that maintains the mean activation within
482
+ each example close to 0 and the activation standard deviation close to 1.
483
+
484
+ Example usage::
485
+
486
+ >>> import brainstate as bst
487
+ >>> x = bst.random.normal(size=(3, 4, 5, 6))
488
+ >>> layer = bst.nn.LayerNorm(x.shape)
489
+ >>> layer.states()
490
+ >>> y = layer(x)
491
+
492
+ Attributes:
493
+ in_size: The input shape, without batch size.
494
+ epsilon: A small float added to variance to avoid dividing by zero.
495
+ dtype: the dtype of the result (default: infer from input and params).
496
+ use_bias: If True, bias (beta) is added.
497
+ use_scale: If True, multiply by scale (gamma). When the next layer is linear
498
+ (also e.g. nn.relu), this can be disabled since the scaling will be done
499
+ by the next layer.
500
+ bias_init: Initializer for bias, by default, zero.
501
+ scale_init: Initializer for scale, by default, one.
502
+ reduction_axes: Axes for computing normalization statistics. It is recommended
503
+ to use negative integers, since positive axis indices may no longer refer to
504
+ the intended axes once a batch dimension is added at call time.
505
+ feature_axes: Feature axes for learned bias and scaling.
506
+ axis_name: the axis name used to combine batch statistics from multiple
507
+ devices. See ``jax.pmap`` for a description of axis names (default: None).
508
+ This is only needed if the model is subdivided across devices, i.e. the
509
+ array being normalized is sharded across devices within a pmap.
510
+ axis_index_groups: groups of axis indices within that named axis
511
+ representing subsets of devices to reduce over (default: None). For
512
+ example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over
513
+ the examples on the first two and last two devices. See ``jax.lax.psum``
514
+ for more details.
515
+ use_fast_variance: If true, use a faster, but less numerically stable,
516
+ calculation for the variance.
517
+ """
518
+
519
+ def __init__(
520
+ self,
521
+ in_size: Size,
522
+ reduction_axes: Axes = -1,
523
+ feature_axes: Axes = -1,
524
+ *,
525
+ epsilon: float = 1e-6,
526
+ use_bias: bool = True,
527
+ use_scale: bool = True,
528
+ bias_init: Callable = init.ZeroInit(),
529
+ scale_init: Callable = init.Constant(1.0),
530
+ axis_name: Optional[str] = None,
531
+ axis_index_groups: Any = None,
532
+ use_fast_variance: bool = True,
533
+ dtype: Optional[jax.typing.DTypeLike] = None,
534
+ ):
535
+ super().__init__()
536
+
537
+ self.in_size = in_size
538
+ self.out_size = in_size
539
+
540
+ # parameters about axis
541
+ feature_axes = (feature_axes,) if isinstance(feature_axes, int) else feature_axes
542
+ self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axes)
543
+ self.reduction_axes = (reduction_axes,) if isinstance(reduction_axes, int) else reduction_axes
544
+ self.axis_name = axis_name
545
+ self.axis_index_groups = axis_index_groups
546
+
547
+ # variables
548
+ feature_shape = tuple([(ax if i in self.feature_axes else 1)
549
+ for i, ax in enumerate(self.in_size)])
550
+
551
+ weights = dict()
552
+ if use_scale:
553
+ weights['scale'] = init.param(scale_init, feature_shape)
554
+ if use_bias:
555
+ weights['bias'] = init.param(bias_init, feature_shape)
556
+ if len(weights):
557
+ self.weight = ParamState(weights)
558
+ else:
559
+ self.weight = None
560
+
561
+ # parameters
562
+ self.epsilon = epsilon
563
+ self.dtype = dtype or environ.dftype()
564
+ self.use_bias = use_bias
565
+ self.use_scale = use_scale
566
+ self.bias_init = bias_init
567
+ self.scale_init = scale_init
568
+ self.use_fast_variance = use_fast_variance
569
+
570
+ def update(self, x, *, mask: Optional[jax.Array] = None):
571
+ """Applies layer normalization on the input.
572
+
573
+ Args:
574
+ x: the inputs
575
+
576
+ Returns:
577
+ Normalized inputs (the same shape as inputs).
578
+ """
579
+ mean, var = _compute_stats(
580
+ x,
581
+ self.reduction_axes,
582
+ dtype=self.dtype,
583
+ axis_name=self.axis_name,
584
+ axis_index_groups=self.axis_index_groups,
585
+ use_fast_variance=self.use_fast_variance,
586
+ mask=mask,
587
+ )
588
+
589
+ return _normalize(
590
+ x,
591
+ mean=mean,
592
+ var=var,
593
+ weights=self.weight,
594
+ reduction_axes=self.reduction_axes,
595
+ feature_axes=self.feature_axes,
596
+ dtype=self.dtype,
597
+ epsilon=self.epsilon,
598
+ )
599
+
600
+
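Following the recommendation above to prefer negative axis indices, a small hypothetical example that normalizes a batched input over its last two axes:

    import brainstate as bst

    x = bst.random.normal(size=(32, 10, 6))   # (batch, time, features)
    layer = bst.nn.LayerNorm((10, 6), reduction_axes=(-2, -1), feature_axes=(-2, -1))
    y = layer(x)                              # per-example statistics over the last two axes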
601
+ class RMSNorm(Module):
602
+ """
603
+ RMS Layer normalization (https://arxiv.org/abs/1910.07467).
604
+
605
+ RMSNorm normalizes the activations of the layer for each given example in a
606
+ batch independently, rather than across a batch like Batch Normalization.
607
+ Unlike LayerNorm which re-centers the mean to be 0 and normalizes by the
608
+ standard deviation of the activations, RMSNorm does not re-center at all
609
+ and instead normalizes by the root mean square of the activations.
610
+
611
+ Example usage::
612
+
613
+ >>> import brainstate as bst
614
+ >>> x = bst.random.normal(size=(5, 6))
615
+ >>> layer = bst.nn.RMSNorm(in_size=(6,))
616
+ >>> layer.states()
617
+ >>> y = layer(x)
618
+
619
+ Attributes:
620
+ in_size: The input shape, without batch size.
621
+ epsilon: A small float added to variance to avoid dividing by zero.
622
+ dtype: the dtype of the result (default: infer from input and params).
623
+ use_scale: If True, multiply by scale (gamma). When the next layer is linear
624
+ (also e.g. nn.relu), this can be disabled since the scaling will be done
625
+ by the next layer.
626
+ scale_init: Initializer for scale, by default, one.
627
+ reduction_axes: Axes for computing normalization statistics. It is recommended
628
+ to use negative integers, since positive axis indices may no longer refer to
629
+ the intended axes once a batch dimension is added at call time.
630
+ feature_axes: Feature axes for learned bias and scaling.
631
+ axis_name: the axis name used to combine batch statistics from multiple
632
+ devices. See ``jax.pmap`` for a description of axis names (default: None).
633
+ This is only needed if the model is subdivided across devices, i.e. the
634
+ array being normalized is sharded across devices within a pmap.
635
+ axis_index_groups: groups of axis indices within that named axis
636
+ representing subsets of devices to reduce over (default: None). For
637
+ example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over
638
+ the examples on the first two and last two devices. See ``jax.lax.psum``
639
+ for more details.
640
+ use_fast_variance: If true, use a faster, but less numerically stable,
641
+ calculation for the variance.
642
+ """
643
+
644
+ def __init__(
645
+ self,
646
+ in_size: Size,
647
+ *,
648
+ epsilon: float = 1e-6,
649
+ dtype: Optional[jax.typing.DTypeLike] = None,
650
+ use_scale: bool = True,
651
+ scale_init: Callable = init.Constant(1.0),
652
+ reduction_axes: Axes = -1,
653
+ feature_axes: Axes = -1,
654
+ axis_name: Optional[str] = None,
655
+ axis_index_groups: Any = None,
656
+ use_fast_variance: bool = True,
657
+ ):
658
+ super().__init__()
659
+
660
+ self.in_size = in_size
661
+ self.out_size = in_size
662
+
663
+ # parameters about axis
664
+ feature_axes = (feature_axes,) if isinstance(feature_axes, int) else feature_axes
665
+ self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axes)
666
+ self.reduction_axes = (reduction_axes, ) if isinstance(reduction_axes, int) else reduction_axes
667
+ self.axis_name = axis_name
668
+ self.axis_index_groups = axis_index_groups
669
+
670
+ # variables
671
+ feature_shape = tuple([(ax if i in self.feature_axes else 1)
672
+ for i, ax in enumerate(self.in_size)])
673
+ if use_scale:
674
+ self.scale = ParamState({'scale': init.param(scale_init, feature_shape)})
675
+ else:
676
+ self.scale = None
677
+
678
+ # parameters
679
+ self.epsilon = epsilon
680
+ self.dtype = dtype or environ.dftype()
681
+ self.use_scale = use_scale
682
+ self.scale_init = scale_init
683
+ self.use_fast_variance = use_fast_variance
684
+
685
+ def update(self, x, *, mask: Optional[jax.Array] = None):
686
+ """Applies layer normalization on the input.
687
+
688
+ Args:
689
+ x: the inputs
690
+ mask: the mask
691
+
692
+ Returns:
693
+ Normalized inputs (the same shape as inputs).
694
+ """
695
+ mean, var = _compute_stats(
696
+ x,
697
+ self.reduction_axes,
698
+ dtype=self.dtype,
699
+ axis_name=self.axis_name,
700
+ axis_index_groups=self.axis_index_groups,
701
+ use_mean=False,
702
+ use_fast_variance=self.use_fast_variance,
703
+ mask=mask,
704
+ )
705
+
706
+ return _normalize(
707
+ x,
708
+ mean=mean,
709
+ var=var,
710
+ weights=self.scale,
711
+ reduction_axes=self.reduction_axes,
712
+ feature_axes=self.feature_axes,
713
+ dtype=self.dtype,
714
+ epsilon=self.epsilon,
715
+ )
716
+
717
+
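Written out directly, the computation RMSNorm performs (a reference sketch assuming normalization over the last axis and a learned scale only; rms_norm_reference is a hypothetical name):

    import jax.numpy as jnp

    def rms_norm_reference(x, scale, eps=1e-6):
        ms = jnp.mean(x ** 2, axis=-1, keepdims=True)   # mean square, no mean subtraction
        return x / jnp.sqrt(ms + eps) * scale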
718
+ class GroupNorm(Module):
719
+ """
720
+ Group normalization (arxiv.org/abs/1803.08494).
721
+
722
+ This op is similar to batch normalization, but statistics are shared across
723
+ equally-sized groups of channels and not shared across batch dimension.
724
+ Thus, group normalization does not depend on the batch composition and does
725
+ not require maintaining internal state for storing statistics.
726
+ The user should either specify the total number of channel groups or the
727
+ number of channels per group.
728
+
729
+ .. note::
730
+ LayerNorm is a special case of GroupNorm where ``num_groups=1``.
731
+
732
+ Example usage::
733
+
734
+ >>> import numpy as np
735
+ >>> import brainstate as bst
736
+ ...
737
+ >>> x = bst.random.normal(size=(3, 4, 5, 6))
738
+ >>> layer = bst.nn.GroupNorm((6,), num_groups=3)
739
+ >>> layer.states()
740
+ >>> y = layer(x)
741
+ >>> y = bst.nn.GroupNorm((6,), num_groups=1)(x)
742
+ >>> y2 = bst.nn.LayerNorm(x.shape, reduction_axes=(1, 2, 3))(x)
743
+ >>> np.testing.assert_allclose(y, y2)
744
+
745
+ Attributes:
746
+ in_size: The feature (channel) shape, without batch or spatial dimensions; GroupNorm only supports a single feature axis.
747
+ num_groups: the total number of channel groups. The default value of 32 is
748
+ proposed by the original group normalization paper.
749
+ group_size: the number of channels in a group.
750
+ epsilon: A small float added to variance to avoid dividing by zero.
751
+ dtype: the dtype of the result (default: infer from input and params).
752
+ use_bias: If True, bias (beta) is added.
753
+ use_scale: If True, multiply by scale (gamma). When the next layer is linear
754
+ (also e.g. nn.relu), this can be disabled since the scaling will be done
755
+ by the next layer.
756
+ bias_init: Initializer for bias, by default, zero.
757
+ scale_init: Initializer for scale, by default, one.
758
+ reduction_axes: List of axes used for computing normalization statistics.
759
+ This list must include the final dimension, which is assumed to be the
760
+ feature axis. Furthermore, if the input used at call time has additional
761
+ leading axes compared to the data used for initialisation, for example due
762
+ to batching, then the reduction axes need to be defined explicitly.
763
+ It is recommended to use negative integers, since positive axis indices may no
764
+ longer refer to the intended axes once a batch dimension is added at call time.
765
+ axis_name: the axis name used to combine batch statistics from multiple
766
+ devices. See ``jax.pmap`` for a description of axis names (default: None).
767
+ This is only needed if the model is subdivided across devices, i.e. the
768
+ array being normalized is sharded across devices within a pmap or shard
769
+ map. For SPMD jit, you do not need to manually synchronize. Just make sure
770
+ that the axes are correctly annotated and XLA:SPMD will insert the
771
+ necessary collectives.
772
+ axis_index_groups: groups of axis indices within that named axis
773
+ representing subsets of devices to reduce over (default: None). For
774
+ example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over the
775
+ examples on the first two and last two devices. See ``jax.lax.psum`` for
776
+ more details.
777
+ use_fast_variance: If true, use a faster, but less numerically stable,
778
+ calculation for the variance.
779
+ """
780
+
781
+ def __init__(
782
+ self,
783
+ in_size: Size,
784
+ feature_axis: Axes = -1,
785
+ num_groups: Optional[int] = 32,
786
+ group_size: Optional[int] = None,
787
+ *,
788
+ epsilon: float = 1e-6,
789
+ dtype: Optional[jax.typing.DTypeLike] = None,
790
+ use_bias: bool = True,
791
+ use_scale: bool = True,
792
+ bias_init: Callable = init.ZeroInit(),
793
+ scale_init: Callable = init.Constant(1.),
794
+ reduction_axes: Optional[Axes] = None,
795
+ axis_name: Optional[str] = None,
796
+ axis_index_groups: Any = None,
797
+ use_fast_variance: bool = True,
798
+ ):
799
+ super().__init__()
800
+
801
+ self.in_size = in_size
802
+ self.out_size = in_size
803
+
804
+ # parameters about axis
805
+ feature_axis = (feature_axis,) if isinstance(feature_axis, int) else feature_axis
806
+ self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axis)
807
+ self.reduction_axes = (reduction_axes,) if isinstance(reduction_axes, int) else reduction_axes
808
+ self.axis_name = axis_name
809
+ self.axis_index_groups = axis_index_groups
810
+
811
+ if (num_groups is None and group_size is None) or (
812
+ num_groups is not None and group_size is not None
813
+ ):
814
+ raise ValueError(
815
+ 'Either `num_groups` or `group_size` should be '
816
+ 'specified. If `group_size` is to be specified, '
817
+ 'pass `num_groups=None` as argument to override '
818
+ 'the default `num_groups` value of 32.'
819
+ )
820
+
821
+ feature_shape = tuple([(ax if i in self.feature_axes else 1)
822
+ for i, ax in enumerate(self.in_size)])
823
+ assert len(feature_shape) == 1, 'GroupNorm only supports 1D feature axis.'
824
+ num_features = feature_shape[0]
825
+ if group_size is not None:
826
+ if num_features % group_size != 0:
827
+ raise ValueError(
828
+ 'Number of features ({}) is not multiple of the '
829
+ 'group size ({}).'.format(num_features, group_size)
830
+ )
831
+ self.num_groups = num_features // group_size
832
+ self.group_size = group_size
833
+ else:
834
+ if not isinstance(num_groups, int) or num_groups <= 0 or (
835
+ num_features % num_groups != 0
836
+ ):
837
+ raise ValueError(
838
+ 'Number of groups ({}) does not divide the number'
839
+ ' of channels ({}).'.format(num_groups, num_features)
840
+ )
841
+ self.num_groups = num_groups
842
+ self.group_size = num_features // num_groups
843
+
844
+ # variables
845
+ weights = dict()
846
+ if use_scale:
847
+ weights['scale'] = init.param(scale_init, feature_shape)
848
+ if use_bias:
849
+ weights['bias'] = init.param(bias_init, feature_shape)
850
+ if len(weights):
851
+ self.weight = ParamState(weights)
852
+ else:
853
+ self.weight = None
854
+
855
+ # parameters
856
+ self.epsilon = epsilon
857
+ self.dtype = dtype
858
+ self.use_bias = use_bias
859
+ self.use_scale = use_scale
860
+ self.bias_init = bias_init
861
+ self.scale_init = scale_init
862
+ self.use_fast_variance = use_fast_variance
863
+
864
+ def update(self, x, *, mask: Optional[jax.Array] = None):
865
+ """Applies group normalization to the input (arxiv.org/abs/1803.08494).
866
+
867
+ Args:
868
+ x: the input of shape ``...self.num_features`` where ``self.num_features``
869
+ is a channels dimension and ``...`` represents an arbitrary number of
870
+ extra dimensions that can be used to accumulate statistics over. If no
871
+ reduction axes have been specified then all additional dimensions ``...``
872
+ will be used to accumulate statistics apart from the leading dimension
873
+ which is assumed to represent the batch.
874
+ mask: Binary array of shape broadcastable to ``inputs`` tensor, indicating
875
+ the positions for which the mean and variance should be computed.
876
+
877
+ Returns:
878
+ Normalized inputs (the same shape as inputs).
879
+ """
880
+ if self.reduction_axes is not None:
881
+ reduction_axes = self.reduction_axes
882
+ else:
883
+ reduction_axes = list(range(1, x.ndim - 1)) + [-1]
884
+ reduction_axes = _canonicalize_axes(x.ndim, reduction_axes)
885
+
886
+ group_shape = x.shape[:-1] + (self.num_groups, self.group_size)
887
+ if mask is not None:
888
+ mask = mask.reshape(mask.shape[:-1] + (self.num_groups, self.group_size))
889
+
890
+ mean, var = _compute_stats(
891
+ x.reshape(group_shape),
892
+ list(reduction_axes[:-1]) + [-1],
893
+ dtype=self.dtype,
894
+ axis_name=self.axis_name,
895
+ axis_index_groups=self.axis_index_groups,
896
+ use_fast_variance=self.use_fast_variance,
897
+ mask=mask,
898
+ )
899
+ mean = jnp.repeat(mean, self.group_size, axis=1)
900
+ var = jnp.repeat(var, self.group_size, axis=1)
901
+ return _normalize(
902
+ x,
903
+ mean=mean,
904
+ var=var,
905
+ weights=self.weight,
906
+ reduction_axes=reduction_axes[:-1],
907
+ feature_axes=self.feature_axes,
908
+ dtype=self.dtype,
909
+ epsilon=self.epsilon,
910
+ )
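Finally, a reference sketch of how GroupNorm's update() groups the channel axis before computing statistics (group_stats_reference is a hypothetical helper that mirrors group_shape and the reduction axes used above):

    import jax.numpy as jnp

    def group_stats_reference(x, num_groups):
        # Reshape (batch, *spatial, C) -> (batch, *spatial, G, C // G) and reduce over
        # the spatial axes plus the within-group axis, keeping one statistic per (batch, G).
        b, *spatial, c = x.shape
        g = x.reshape(b, *spatial, num_groups, c // num_groups)
        axes = tuple(range(1, g.ndim - 2)) + (g.ndim - 1,)
        mean = g.mean(axis=axes)
        var = jnp.maximum(0.0, (g ** 2).mean(axis=axes) - mean ** 2)
        return mean, var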