brainstate-0.0.1-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. brainstate/__init__.py +45 -0
  2. brainstate/_module.py +1466 -0
  3. brainstate/_module_test.py +133 -0
  4. brainstate/_state.py +378 -0
  5. brainstate/_state_test.py +41 -0
  6. brainstate/_utils.py +21 -0
  7. brainstate/environ.py +375 -0
  8. brainstate/functional/__init__.py +25 -0
  9. brainstate/functional/_activations.py +754 -0
  10. brainstate/functional/_normalization.py +69 -0
  11. brainstate/functional/_spikes.py +90 -0
  12. brainstate/init/__init__.py +26 -0
  13. brainstate/init/_base.py +36 -0
  14. brainstate/init/_generic.py +175 -0
  15. brainstate/init/_random_inits.py +489 -0
  16. brainstate/init/_regular_inits.py +109 -0
  17. brainstate/math/__init__.py +21 -0
  18. brainstate/math/_einops.py +787 -0
  19. brainstate/math/_einops_parsing.py +169 -0
  20. brainstate/math/_einops_parsing_test.py +126 -0
  21. brainstate/math/_einops_test.py +346 -0
  22. brainstate/math/_misc.py +298 -0
  23. brainstate/math/_misc_test.py +58 -0
  24. brainstate/mixin.py +373 -0
  25. brainstate/mixin_test.py +73 -0
  26. brainstate/nn/__init__.py +68 -0
  27. brainstate/nn/_base.py +248 -0
  28. brainstate/nn/_connections.py +686 -0
  29. brainstate/nn/_dynamics.py +406 -0
  30. brainstate/nn/_elementwise.py +1437 -0
  31. brainstate/nn/_misc.py +132 -0
  32. brainstate/nn/_normalizations.py +389 -0
  33. brainstate/nn/_others.py +100 -0
  34. brainstate/nn/_poolings.py +1228 -0
  35. brainstate/nn/_poolings_test.py +231 -0
  36. brainstate/nn/_projection/__init__.py +32 -0
  37. brainstate/nn/_projection/_align_post.py +528 -0
  38. brainstate/nn/_projection/_align_pre.py +599 -0
  39. brainstate/nn/_projection/_delta.py +241 -0
  40. brainstate/nn/_projection/_utils.py +17 -0
  41. brainstate/nn/_projection/_vanilla.py +101 -0
  42. brainstate/nn/_rate_rnns.py +393 -0
  43. brainstate/nn/_readout.py +130 -0
  44. brainstate/nn/_synouts.py +166 -0
  45. brainstate/nn/functional/__init__.py +25 -0
  46. brainstate/nn/functional/_activations.py +754 -0
  47. brainstate/nn/functional/_normalization.py +69 -0
  48. brainstate/nn/functional/_spikes.py +90 -0
  49. brainstate/nn/init/__init__.py +26 -0
  50. brainstate/nn/init/_base.py +36 -0
  51. brainstate/nn/init/_generic.py +175 -0
  52. brainstate/nn/init/_random_inits.py +489 -0
  53. brainstate/nn/init/_regular_inits.py +109 -0
  54. brainstate/nn/surrogate.py +1740 -0
  55. brainstate/optim/__init__.py +23 -0
  56. brainstate/optim/_lr_scheduler.py +486 -0
  57. brainstate/optim/_lr_scheduler_test.py +36 -0
  58. brainstate/optim/_sgd_optimizer.py +1148 -0
  59. brainstate/random.py +5148 -0
  60. brainstate/random_test.py +576 -0
  61. brainstate/surrogate.py +1740 -0
  62. brainstate/transform/__init__.py +36 -0
  63. brainstate/transform/_autograd.py +585 -0
  64. brainstate/transform/_autograd_test.py +1183 -0
  65. brainstate/transform/_control.py +665 -0
  66. brainstate/transform/_controls_test.py +220 -0
  67. brainstate/transform/_jit.py +239 -0
  68. brainstate/transform/_jit_error.py +158 -0
  69. brainstate/transform/_jit_test.py +102 -0
  70. brainstate/transform/_make_jaxpr.py +573 -0
  71. brainstate/transform/_make_jaxpr_test.py +133 -0
  72. brainstate/transform/_progress_bar.py +113 -0
  73. brainstate/typing.py +69 -0
  74. brainstate/util.py +747 -0
  75. brainstate-0.0.1.dist-info/LICENSE +202 -0
  76. brainstate-0.0.1.dist-info/METADATA +101 -0
  77. brainstate-0.0.1.dist-info/RECORD +79 -0
  78. brainstate-0.0.1.dist-info/WHEEL +6 -0
  79. brainstate-0.0.1.dist-info/top_level.txt +1 -0
brainstate/nn/_misc.py ADDED
@@ -0,0 +1,132 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+
+ from __future__ import annotations
+
+ from enum import Enum
+ from functools import wraps
+ from typing import Sequence, Callable
+
+ import jax.numpy as jnp
+
+ from .. import environ, math
+ from .._state import State
+ from ..transform import vector_grad
+
+ __all__ = [
+   # 'exp_euler',
+   'exp_euler_step',
+ ]
+
+ git_issue_addr = 'https://github.com/brainpy/brainscale/issues'
+
+
+ def state_traceback(states: Sequence[State]):
+   """
+   Trace back the states of the brain model.
+
+   Parameters
+   ----------
+   states : Sequence[State]
+     The states of the brain model.
+
+   Returns
+   -------
+   str
+     The traceback information of the states.
+   """
+   state_info = []
+   for i, state in enumerate(states):
+     state_info.append(f'State {i}: {state}\n'
+                       f'defined at \n'
+                       f'{state.source_info.traceback}\n')
+   return '\n'.join(state_info)
+
+
+ class BaseEnum(Enum):
+   @classmethod
+   def get_by_name(cls, name: str):
+     for item in cls:
+       if item.name == name:
+         return item
+     raise ValueError(f'Cannot find the {cls.__name__} type {name}.')
+
+   @classmethod
+   def get(cls, type_: str | Enum):
+     if isinstance(type_, cls):
+       return type_
+     elif isinstance(type_, str):
+       return cls.get_by_name(type_)
+     else:
+       raise ValueError(f'Cannot find the {cls.__name__} type {type_}.')
+
+
+ def exp_euler(fun):
+   """
+   Exponential Euler method for solving ODEs.
+
+   Args:
+     fun: Callable. The derivative function to be integrated.
+
+   Returns:
+     The integral function.
+   """
+
+   @wraps(fun)
+   def integral(*args, **kwargs):
+     assert len(args) > 0, 'The input arguments should not be empty.'
+     if args[0].dtype not in [jnp.float32, jnp.float64, jnp.float16, jnp.bfloat16]:
+       raise ValueError(
+         'The input data type should be float32, float64, float16, or bfloat16 '
+         'when using the Exponential Euler method. '
+         f'But we got {args[0].dtype}.'
+       )
+     dt = environ.get('dt')
+     linear, derivative = vector_grad(fun, argnums=0, return_value=True)(*args, **kwargs)
+     phi = math.exprel(dt * linear)
+     return args[0] + dt * phi * derivative
+
+   return integral
+
+
+ def exp_euler_step(fun: Callable, *args, **kwargs):
+   """
+   One step of the Exponential Euler method for solving ODEs.
+
+   Examples
+   --------
+   >>> def fun(x, t):
+   ...   return -x
+   >>> x = jnp.asarray(1.0)
+   >>> exp_euler_step(fun, x, None)
+
+   Args:
+     fun: Callable. The derivative function to be integrated.
+     *args: The positional arguments passed to ``fun``; the first one is the state to be integrated.
+     **kwargs: The keyword arguments passed to ``fun``.
+
+   Returns:
+     The state after one integration step.
+   """
+   assert len(args) > 0, 'The input arguments should not be empty.'
+   if args[0].dtype not in [jnp.float32, jnp.float64, jnp.float16, jnp.bfloat16]:
+     raise ValueError(
+       'The input data type should be float32, float64, float16, or bfloat16 '
+       'when using the Exponential Euler method. '
+       f'But we got {args[0].dtype}.'
+     )
+   dt = environ.get('dt')
+   linear, derivative = vector_grad(fun, argnums=0, return_value=True)(*args, **kwargs)
+   phi = math.exprel(dt * linear)
+   return args[0] + dt * phi * derivative
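
The step above boils down to x_new = x + dt * exprel(dt * J) * f(x), where J is the derivative of f at x and exprel(z) = (exp(z) - 1) / z. The following is a minimal, self-contained sketch of that rule in plain JAX; it deliberately avoids the brainstate `environ`/`vector_grad` machinery, so the names `exp_euler_step_sketch`, `f`, and the chosen constants are illustrative only, not part of the package API. For a linear ODE such as this leaky membrane, the step reproduces the exact solution.

import jax
import jax.numpy as jnp


def f(v, v_rest=-65.0, tau=10.0):
  # Leaky dynamics: dv/dt = -(v - v_rest) / tau
  return -(v - v_rest) / tau


def exp_euler_step_sketch(f, v, dt):
  dfdv = jax.grad(f)(v)           # local linearization J = df/dv at v
  z = dt * dfdv
  phi = jnp.expm1(z) / z          # exprel(z) = (exp(z) - 1) / z; needs care as z -> 0
  return v + dt * phi * f(v)


v0, dt = 0.0, 0.5
v1 = exp_euler_step_sketch(f, v0, dt)
# For this linear ODE the exponential Euler step is exact:
v_exact = -65.0 + (v0 + 65.0) * jnp.exp(-dt / 10.0)
print(v1, v_exact)                # both are approximately -3.17

Using exprel of the local Jacobian is what lets the step treat the linear part of the dynamics exactly, which is why the method tolerates larger `dt` than plain Euler for stiff, decay-dominated equations.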
brainstate/nn/_normalizations.py ADDED
@@ -0,0 +1,389 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ # -*- coding: utf-8 -*-
+
+ from __future__ import annotations
+
+ import numbers
+ from typing import Callable, Union, Sequence, Optional, Any
+
+ import jax
+ import jax.numpy as jnp
+
+ from ._base import DnnLayer
+ from .. import environ, init
+ from .._state import LongTermState, ParamState
+ from ..mixin import Mode
+ from ..typing import DTypeLike, ArrayLike, Size, Axes
+
+ __all__ = [
+   'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d',
+ ]
+
+
+ def _canonicalize_axes(ndim: int, feature_axes: Sequence[int]):
+   axes = []
+   for axis in feature_axes:
+     if axis < 0:
+       axis += ndim
+     if axis < 0 or axis >= ndim:
+       raise ValueError(f'Invalid axis {axis} for {ndim}D input')
+     axes.append(axis)
+   return tuple(axes)
+
+
+ def _abs_sq(x):
+   """Computes the elementwise square of the absolute value |x|^2."""
+   if jnp.iscomplexobj(x):
+     return jax.lax.square(jax.lax.real(x)) + jax.lax.square(jax.lax.imag(x))
+   else:
+     return jax.lax.square(x)
+
+
+ def _compute_stats(
+     x: ArrayLike,
+     axes: Sequence[int],
+     dtype: DTypeLike,
+     axis_name: Optional[str] = None,
+     axis_index_groups: Optional[Sequence[int]] = None,
+     use_mean: bool = True,
+ ):
+   """Computes mean and variance statistics.
+
+   This implementation takes care of a few important details:
+   - Computes in float32 precision for stability in half precision training.
+   - The mean and variance are computable in a single XLA fusion,
+     by using Var = E[|x|^2] - |E[x]|^2 instead of Var = E[|x - E[x]|^2].
+   - Clips negative variances to zero, which can happen due to
+     roundoff errors. This avoids downstream NaNs.
+   - Supports averaging across a parallel axis and subgroups of a parallel axis
+     with a single `lax.pmean` call to avoid latency.
+
+   Arguments:
+     x: Input array.
+     axes: The axes in ``x`` to compute mean and variance statistics for.
+     dtype: Optional dtype specifying the minimal precision. Statistics
+       are always at least float32 for stability (default: dtype of x).
+     axis_name: Optional name for the pmapped axis to compute mean over.
+     axis_index_groups: Optional groups of axis indices within the named axis,
+       representing subsets of devices to reduce over.
+     use_mean: If true, calculate the mean from the input and use it when
+       computing the variance. If false, set the mean to zero and compute
+       the variance without subtracting the mean.
+
+   Returns:
+     A pair ``(mean, var)``.
+   """
+   if dtype is None:
+     dtype = jax.numpy.result_type(x)
+   # promote x to at least float32, this avoids half precision computation
+   # but preserves double or complex floating points
+   dtype = jax.numpy.promote_types(dtype, environ.dftype())
+   x = jnp.asarray(x, dtype)
+
+   # Compute mean and mean of squared values.
+   mean2 = jnp.mean(_abs_sq(x), axes)
+   if use_mean:
+     mean = jnp.mean(x, axes)
+   else:
+     mean = jnp.zeros(mean2.shape, dtype=dtype)
+
+   # If axis_name is provided, average mean and mean2 across the named
+   # (pmapped) axis, optionally within the given device groups.
+   if axis_name is not None:
+     concatenated_mean = jnp.concatenate([mean, mean2])
+     mean, mean2 = jnp.split(
+       jax.lax.pmean(
+         concatenated_mean,
+         axis_name=axis_name,
+         axis_index_groups=axis_index_groups,
+       ),
+       2,
+     )
+
+   # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due
+   # to floating point round-off errors.
+   var = jnp.maximum(0.0, mean2 - _abs_sq(mean))
+   return mean, var
+
+
+ def _normalize(
+     x: ArrayLike,
+     mean: Optional[ArrayLike],
+     var: Optional[ArrayLike],
+     weights: Optional[ParamState],
+     reduction_axes: Sequence[int],
+     dtype: DTypeLike,
+     epsilon: Union[numbers.Number, jax.Array],
+ ):
+   """Normalizes the input of a normalization layer and optionally applies a learned scale and bias.
+
+   Arguments:
+     x: The input.
+     mean: Mean to use for normalization.
+     var: Variance to use for normalization.
+     weights: The scale and bias parameters.
+     reduction_axes: The axes in ``x`` to reduce.
+     dtype: The dtype of the result (default: infer from input and params).
+     epsilon: Normalization epsilon.
+
+   Returns:
+     The normalized input.
+   """
+   if mean is not None:
+     assert var is not None, 'mean and var must be both None or not None.'
+     stats_shape = list(x.shape)
+     for axis in reduction_axes:
+       stats_shape[axis] = 1
+     mean = mean.reshape(stats_shape)
+     var = var.reshape(stats_shape)
+     y = x - mean
+     mul = jax.lax.rsqrt(var + jnp.asarray(epsilon, dtype))
+     y = y * mul
+     if weights is not None:
+       y = _scale_operation(y, weights.value)
+   else:
+     assert var is None, 'mean and var must be both None or not None.'
+     assert weights is None, 'scale and bias are not supported without mean and var'
+     y = x
+   return jnp.asarray(y, dtype)
+
+
+ def _scale_operation(x, param):
+   if 'scale' in param:
+     x = x * param['scale']
+   if 'bias' in param:
+     x = x + param['bias']
+   return x
+
+
+ class _BatchNorm(DnnLayer):
+   __module__ = 'brainstate.nn'
+   num_spatial_dims: int
+
+   def __init__(
+       self,
+       in_size: Size,
+       feature_axis: Axes = -1,
+       track_running_stats: bool = True,
+       epsilon: float = 1e-5,
+       momentum: float = 0.99,
+       affine: bool = True,
+       bias_initializer: Union[ArrayLike, Callable] = init.Constant(0.),
+       scale_initializer: Union[ArrayLike, Callable] = init.Constant(1.),
+       axis_name: Optional[Union[str, Sequence[str]]] = None,
+       axis_index_groups: Optional[Sequence[Sequence[int]]] = None,
+       mode: Optional[Mode] = None,
+       name: Optional[str] = None,
+       dtype: Any = None,
+   ):
+     super().__init__(name=name, mode=mode)
+
+     # parameters
+     self.in_size = tuple(in_size)
+     self.out_size = tuple(in_size)
+     self.affine = affine
+     self.bias_initializer = bias_initializer
+     self.scale_initializer = scale_initializer
+     self.dtype = dtype or environ.dftype()
+     self.track_running_stats = track_running_stats
+     self.momentum = jnp.asarray(momentum, dtype=self.dtype)
+     self.epsilon = jnp.asarray(epsilon, dtype=self.dtype)
+
+     # parameters about axis
+     feature_axis = (feature_axis,) if isinstance(feature_axis, int) else feature_axis
+     self.feature_axis = _canonicalize_axes(len(in_size), feature_axis)
+     self.axis_name = axis_name
+     self.axis_index_groups = axis_index_groups
+
+     # variables
+     feature_shape = tuple([ax if i in self.feature_axis else 1 for i, ax in enumerate(in_size)])
+     if self.track_running_stats:
+       self.running_mean = LongTermState(jnp.zeros(feature_shape, dtype=self.dtype))
+       self.running_var = LongTermState(jnp.ones(feature_shape, dtype=self.dtype))
+     else:
+       self.running_mean = None
+       self.running_var = None
+
+     # parameters
+     if self.affine:
+       assert track_running_stats, "Affine parameters are not needed when track_running_stats is False."
+       bias = init.param(self.bias_initializer, feature_shape)
+       scale = init.param(self.scale_initializer, feature_shape)
+       self.weight = ParamState(dict(bias=bias, scale=scale))
+     else:
+       self.weight = None
+
+   def _check_input_dim(self, x):
+     if x.ndim == self.num_spatial_dims + 2:
+       x_shape = x.shape[1:]
+     elif x.ndim == self.num_spatial_dims + 1:
+       x_shape = x.shape
+     else:
+       raise ValueError(f"expected {self.num_spatial_dims + 2}D (with batch) or "
+                        f"{self.num_spatial_dims + 1}D (without batch) input (got {x.ndim}D input, {x.shape})")
+     if self.in_size != x_shape:
+       raise ValueError(f"The expected input shape is {self.in_size}, while we got {x_shape}.")
+
+   def update(self, x):
+     # input shape and batch mode or not
+     if x.ndim == self.num_spatial_dims + 2:
+       x_shape = x.shape[1:]
+       batch = True
+     elif x.ndim == self.num_spatial_dims + 1:
+       x_shape = x.shape
+       batch = False
+     else:
+       raise ValueError(f"expected {self.num_spatial_dims + 2}D (with batch) or "
+                        f"{self.num_spatial_dims + 1}D (without batch) input (got {x.ndim}D input, {x.shape})")
+     if self.in_size != x_shape:
+       raise ValueError(f"The expected input shape is {self.in_size}, while we got {x_shape}.")
+
+     # reduce the feature axis
+     if batch:
+       reduction_axes = tuple(i for i in range(x.ndim) if (i - 1) not in self.feature_axis)
+     else:
+       reduction_axes = tuple(i for i in range(x.ndim) if i not in self.feature_axis)
+
+     # fitting phase
+     fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
+
+     # compute the running mean and variance
+     if self.track_running_stats:
+       if fit_phase:
+         mean, var = _compute_stats(
+           x,
+           reduction_axes,
+           dtype=self.dtype,
+           axis_name=self.axis_name,
+           axis_index_groups=self.axis_index_groups,
+         )
+         self.running_mean.value = self.momentum * self.running_mean.value + (1 - self.momentum) * mean
+         self.running_var.value = self.momentum * self.running_var.value + (1 - self.momentum) * var
+       else:
+         mean = self.running_mean.value
+         var = self.running_var.value
+     else:
+       mean, var = None, None
+
+     # normalize
+     return _normalize(x, mean, var, self.weight, reduction_axes, self.dtype, self.epsilon)
+
+
+ class BatchNorm1d(_BatchNorm):
+   r"""1-D batch normalization [1]_.
+
+   The data should be of shape `(b, l, c)`, where `b` is the batch dimension,
+   `l` is the length dimension, and `c` is the channel dimension.
+
+   %s
+   """
+   __module__ = 'brainstate.nn'
+   num_spatial_dims: int = 1
+
+
+ class BatchNorm2d(_BatchNorm):
+   r"""2-D batch normalization [1]_.
+
+   The data should be of shape `(b, h, w, c)`, where `b` is the batch dimension,
+   `h` is the height dimension, `w` is the width dimension, and `c` is the
+   channel dimension.
+
+   %s
+   """
+   __module__ = 'brainstate.nn'
+   num_spatial_dims: int = 2
+
+
+ class BatchNorm3d(_BatchNorm):
+   r"""3-D batch normalization [1]_.
+
+   The data should be of shape `(b, h, w, d, c)`, where `b` is the batch dimension,
+   `h` is the height dimension, `w` is the width dimension, `d` is the depth
+   dimension, and `c` is the channel dimension.
+
+   %s
+   """
+   __module__ = 'brainstate.nn'
+   num_spatial_dims: int = 3
+
+
+ _bn_doc = r'''
+
+ This layer aims to reduce the internal covariate shift of the data. It
+ normalizes a batch of data by fixing the mean and variance of the inputs
+ on each feature (channel). Most commonly, the first axis of the data
+ is the batch, and the last is the channel. However, users can specify
+ the axes to be normalized.
+
+ .. math::
+    y=\frac{x-\mathrm{E}[x]}{\sqrt{\operatorname{Var}[x]+\epsilon}} * \gamma+\beta
+
+ .. note::
+     This :attr:`momentum` argument is different from the one used in optimizer
+     classes and the conventional notion of momentum. Mathematically, the
+     update rule for the running statistics here is
+     :math:`\hat{x}_\text{new} = \text{momentum} \times \hat{x} + (1-\text{momentum}) \times x_t`,
+     where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
+     new observed value.
+
+ Parameters
+ ----------
+ in_size: sequence of int
+   The input shape, without the batch size.
+ feature_axis: int, tuple, list
+   The feature or non-batch axis of the input.
+ track_running_stats: bool
+   A boolean value that when set to ``True``, this module tracks the running mean and variance,
+   and when set to ``False``, this module does not track such statistics and initializes
+   the statistics buffers ``running_mean`` and ``running_var`` as ``None``. When these buffers are ``None``,
+   this module always uses batch statistics in both training and eval modes. Default: ``True``.
+ momentum: float
+   The value used for the ``running_mean`` and ``running_var`` computation. Default: 0.99
+ epsilon: float
+   A value added to the denominator for numerical stability. Default: 1e-5
+ affine: bool
+   A boolean value that when set to ``True``, this module has
+   learnable affine parameters. Default: ``True``
+ bias_initializer: ArrayLike, Callable
+   An initializer generating the original translation matrix. If not ``None``, bias (beta) is added.
+   Default: ``init.Constant(0.)``
+ scale_initializer: ArrayLike, Callable
+   An initializer generating the original scaling matrix. If not ``None``, multiply by scale (gamma).
+   Default: ``init.Constant(1.)``
+ axis_name: optional, str, sequence of str
+   If not ``None``, it should be a string (or sequence of
+   strings) representing the axis name(s) over which this module is being
+   run within a jax map (e.g. ``jax.pmap`` or ``jax.vmap``). Supplying this
+   argument means that batch statistics are calculated across all replicas
+   on the named axes.
+ axis_index_groups: optional, sequence
+   Specifies how devices are grouped; valid only within ``jax.pmap`` collectives.
+   Groups of axis indices within that named axis
+   representing subsets of devices to reduce over (default: None). For
+   example, `[[0, 1], [2, 3]]` would independently batch-normalize over
+   the examples on the first two and last two devices. See `jax.lax.psum`
+   for more details.
+
+ References
+ ----------
+ .. [1] Ioffe, Sergey and Christian Szegedy. “Batch Normalization: Accelerating Deep Network Training
+        by Reducing Internal Covariate Shift.” ArXiv abs/1502.03167 (2015): n. pag.
+
+ '''
+
+ BatchNorm1d.__doc__ = BatchNorm1d.__doc__ % _bn_doc
+ BatchNorm2d.__doc__ = BatchNorm2d.__doc__ % _bn_doc
+ BatchNorm3d.__doc__ = BatchNorm3d.__doc__ % _bn_doc
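
A small sketch of the statistics convention used by `_compute_stats` and `_BatchNorm.update` above, written in plain `jax.numpy` rather than through the layer classes; the array shape and values are made up for illustration. For a batched 1-D input of shape `(batch, length, channels)` with `feature_axis=-1`, the reduction runs over every axis except the channel axis, the variance is formed as E[|x|^2] - |E[x]|^2 and clipped at zero, and the running statistics follow the momentum rule quoted in the class docstring.

import jax.numpy as jnp

x = jnp.arange(2 * 4 * 3, dtype=jnp.float32).reshape(2, 4, 3)
reduction_axes = (0, 1)                      # every axis except the channel axis

mean = jnp.mean(x, reduction_axes)
mean2 = jnp.mean(x * x, reduction_axes)      # |x|^2 for real inputs
var = jnp.maximum(0.0, mean2 - mean * mean)  # single fused pass, clipped at zero

assert jnp.allclose(var, jnp.var(x, axis=reduction_axes))

# Running statistics follow the note in the class docstring:
# new_running = momentum * running + (1 - momentum) * batch_stat
momentum, running_mean = 0.99, jnp.zeros(3)
running_mean = momentum * running_mean + (1 - momentum) * mean

With `momentum=0.99`, each batch contributes only 1% to the running estimate, which is why the stored `running_mean`/`running_var` change slowly relative to the per-batch statistics used during the fitting phase.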
brainstate/nn/_others.py ADDED
@@ -0,0 +1,100 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+
+ from __future__ import annotations
+
+ from functools import partial
+ from typing import Optional
+
+ import jax.numpy as jnp
+
+ from ._base import DnnLayer
+ from .. import random, math, environ, typing, init
+ from ..mixin import Mode
+
+ __all__ = [
+   'DropoutFixed',
+ ]
+
+
+ class DropoutFixed(DnnLayer):
+   """
+   A dropout layer whose dropout mask is fixed along the time axis once it has been initialized.
+
+   In training, to compensate for the fraction of input values dropped (``1 - prob``),
+   all surviving values are multiplied by ``1 / prob``.
+
+   This layer is active only during training (``mode=brainstate.mixin.Training``). In other
+   circumstances it is a no-op.
+
+   .. [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent
+          neural networks from overfitting." The journal of machine learning
+          research 15.1 (2014): 1929-1958.
+
+   .. admonition:: Tip
+      :class: tip
+
+      This kind of Dropout was first described in `Enabling Spike-based Backpropagation for Training Deep Neural
+      Network Architectures <https://arxiv.org/abs/1903.06379>`_:
+
+      There is a subtle difference in the way dropout is applied in SNNs compared to ANNs. In ANNs, each epoch of
+      training has several iterations of mini-batches. In each iteration, randomly selected units (with dropout ratio of :math:`p`)
+      are disconnected from the network while weighting by its posterior probability (:math:`1-p`). However, in SNNs, each
+      iteration has more than one forward propagation depending on the time length of the spike train. We back-propagate
+      the output error and modify the network parameters only at the last time step. For dropout to be effective in
+      our training method, it has to be ensured that the set of connected units within an iteration of mini-batch
+      data is not changed, such that the neural network is constituted by the same random subset of units during
+      each forward propagation within a single iteration. On the other hand, if the units are randomly connected at
+      each time-step, the effect of dropout will be averaged out over the entire forward propagation time within an
+      iteration. Then, the dropout effect would fade out once the output error is propagated backward and the parameters
+      are updated at the last time step. Therefore, we need to keep the set of randomly connected units for the entire
+      time window within an iteration.
+
+   Args:
+     in_size: The size of the input tensor.
+     prob: The probability to keep each element of the tensor.
+     mode: Mode. The computation mode of the object.
+     name: str. The name of the dynamic system.
+   """
+   __module__ = 'brainstate.nn'
+
+   def __init__(
+       self,
+       in_size: typing.Size,
+       prob: float = 0.5,
+       mode: Optional[Mode] = None,
+       name: Optional[str] = None
+   ) -> None:
+     super().__init__(mode=mode, name=name)
+     assert 0. <= prob < 1., f"The keep probability must be in the range [0, 1). But got {prob}."
+     self.prob = prob
+     self.in_size = in_size
+     self.out_size = in_size
+
+   def init_state(self, batch_size=None, **kwargs):
+     self.mask = init.param(partial(random.bernoulli, self.prob), self.in_size, batch_size)
+
+   def update(self, x):
+     dtype = math.get_dtype(x)
+     fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
+     if fit_phase:
+       assert self.mask.shape == x.shape, (f"Input shape {x.shape} does not match the mask shape {self.mask.shape}. "
+                                           f"Please call `init_state()` method first.")
+       return jnp.where(self.mask,
+                        jnp.asarray(x / self.prob, dtype=dtype),
+                        jnp.asarray(0., dtype=dtype))
+     else:
+       return x
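
The behaviour of `DropoutFixed` can be summarized with a small plain-JAX sketch; it sidesteps brainstate's `State`, `Mode`, and `environ` machinery, so the helper name `dropout_fixed` and the chosen shapes are illustrative only. One Bernoulli keep-mask is drawn per mini-batch, reused at every time step of the spike train, and surviving values are scaled by `1 / keep_prob`.

import jax
import jax.numpy as jnp

keep_prob = 0.5
key = jax.random.PRNGKey(0)
mask = jax.random.bernoulli(key, keep_prob, shape=(4, 10))   # sampled once per mini-batch


def dropout_fixed(x, mask, keep_prob, training=True):
  if not training:
    return x                                  # no-op outside the fitting phase
  return jnp.where(mask, x / keep_prob, 0.0)


xs = jnp.ones((20, 4, 10))                    # (time, batch, features) spike train
ys = jax.vmap(lambda x: dropout_fixed(x, mask, keep_prob))(xs)

# The same units are dropped at every time step:
assert bool(jnp.all((ys[0] == 0) == (ys[-1] == 0)))

Because the mask is sampled once in `init_state` rather than on every call, the same subset of units stays connected for the whole time window of an iteration, which is exactly the requirement quoted from the SNN dropout paper in the class docstring.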