brainstate 0.0.2.post20241009__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1360 -1318
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241009.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
brainstate/nn/_embedding.py DELETED
@@ -1,66 +0,0 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- from typing import Optional, Callable, Union
-
- from ._base import DnnLayer
- from .. import init
- from brainstate._state import ParamState
- from brainstate.mixin import Mode, Training
- from brainstate.typing import ArrayLike
-
- __all__ = [
-     'Embedding',
- ]
-
-
- class Embedding(DnnLayer):
-     r"""
-     A simple lookup table that stores embeddings of a fixed size.
-
-     Args:
-         num_embeddings: Size of embedding dictionary. Must be non-negative.
-         embedding_size: Size of each embedding vector. Must be non-negative.
-         embed_init: The initializer for the embedding lookup table, of shape `(num_embeddings, embedding_size)`.
-
-     """
-
-     def __init__(
-         self,
-         num_embeddings: int,
-         embedding_size: int,
-         embed_init: Union[Callable, ArrayLike] = init.LecunUniform(),
-         name: Optional[str] = None,
-         mode: Optional[Mode] = None,
-     ):
-         super().__init__(name=name, mode=mode)
-         if num_embeddings < 0:
-             raise ValueError("num_embeddings must not be negative.")
-         if embedding_size < 0:
-             raise ValueError("embedding_size must not be negative.")
-         self.num_embeddings = num_embeddings
-         self.embedding_size = embedding_size
-         self.out_size = (embedding_size,)
-
-         weight = init.param(embed_init, (self.num_embeddings, self.embedding_size))
-         if self.mode.has(Training):
-             self.weight = ParamState(weight)
-         else:
-             self.weight = weight
-
-     def update(self, indices: ArrayLike):
-         if self.mode.has(Training):
-             return self.weight.value[indices]
-         return self.weight[indices]
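The deleted `Embedding` above is a plain lookup table: `update()` indexes a `(num_embeddings, embedding_size)` weight matrix with integer indices. A minimal self-contained sketch of that lookup, using plain JAX rather than the removed brainstate classes (shapes, seed, and values are illustrative):

    import jax
    import jax.numpy as jnp

    num_embeddings, embedding_size = 10, 4
    # stand-in for the LecunUniform-initialized lookup table
    weight = jax.random.uniform(jax.random.PRNGKey(0), (num_embeddings, embedding_size))

    indices = jnp.array([1, 3, 3, 7])
    embedded = weight[indices]     # what Embedding.update(indices) returns
    print(embedded.shape)          # (4, 4): one embedding vector per index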
brainstate/nn/_misc.py DELETED
@@ -1,133 +0,0 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
-
- from __future__ import annotations
-
- from enum import Enum
- from functools import wraps
- from typing import Sequence, Callable
-
- import brainunit as bu
- import jax.numpy as jnp
-
- from .. import environ
- from .._state import State
- from ..transform import vector_grad
-
- __all__ = [
-     # 'exp_euler',
-     'exp_euler_step',
- ]
-
- git_issue_addr = 'https://github.com/brainpy/brainscale/issues'
-
-
- def state_traceback(states: Sequence[State]):
-     """
-     Traceback the states of the brain model.
-
-     Parameters
-     ----------
-     states : Sequence[bst.State]
-         The states of the brain model.
-
-     Returns
-     -------
-     str
-         The traceback information of the states.
-     """
-     state_info = []
-     for i, state in enumerate(states):
-         state_info.append(f'State {i}: {state}\n'
-                           f'defined at \n'
-                           f'{state.source_info.traceback}\n')
-     return '\n'.join(state_info)
-
-
- class BaseEnum(Enum):
-     @classmethod
-     def get_by_name(cls, name: str):
-         for item in cls:
-             if item.name == name:
-                 return item
-         raise ValueError(f'Cannot find the {cls.__name__} type {name}.')
-
-     @classmethod
-     def get(cls, type_: str | Enum):
-         if isinstance(type_, cls):
-             return type_
-         elif isinstance(type_, str):
-             return cls.get_by_name(type_)
-         else:
-             raise ValueError(f'Cannot find the {cls.__name__} type {type_}.')
-
-
- def exp_euler(fun):
-     """
-     Exponential Euler method for solving ODEs.
-
-     Args:
-         fun: Callable. The function to be solved.
-
-     Returns:
-         The integral function.
-     """
-
-     @wraps(fun)
-     def integral(*args, **kwargs):
-         assert len(args) > 0, 'The input arguments should not be empty.'
-         if args[0].dtype not in [jnp.float32, jnp.float64, jnp.float16, jnp.bfloat16]:
-             raise ValueError(
-                 'The input data type should be float32, float64, float16, or bfloat16 '
-                 'when using Exponential Euler method.'
-                 f'But we got {args[0].dtype}.'
-             )
-         dt = environ.get('dt')
-         linear, derivative = vector_grad(fun, argnums=0, return_value=True)(*args, **kwargs)
-         phi = bu.math.exprel(dt * linear)
-         return args[0] + dt * phi * derivative
-
-     return integral
-
-
- def exp_euler_step(fun: Callable, *args, **kwargs):
-     """
-     Exponential Euler method for solving ODEs.
-
-     Examples
-     --------
-     >>> def fun(x, t):
-     ...     return -x
-     >>> x = 1.0
-     >>> exp_euler_step(fun, x, None)
-
-     Args:
-         fun: Callable. The function to be solved.
-
-     Returns:
-         The integral function.
-     """
-     assert len(args) > 0, 'The input arguments should not be empty.'
-     if args[0].dtype not in [jnp.float32, jnp.float64, jnp.float16, jnp.bfloat16]:
-         raise ValueError(
-             'The input data type should be float32, float64, float16, or bfloat16 '
-             'when using Exponential Euler method.'
-             f'But we got {args[0].dtype}.'
-         )
-     dt = environ.get('dt')
-     linear, derivative = vector_grad(fun, argnums=0, return_value=True)(*args, **kwargs)
-     phi = bu.math.exprel(dt * linear)
-     return args[0] + dt * phi * derivative
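Both removed integrators implement the same exponential Euler update: x_new = x + dt * phi(dt * J) * f(x), with phi(z) = (exp(z) - 1) / z and J the diagonal linearization obtained from `vector_grad`. A standalone sketch of that update for the docstring's example f(x) = -x, written in plain JAX for a scalar state and a hypothetical dt value, without the brainstate/brainunit helpers:

    import jax
    import jax.numpy as jnp

    def f(x, t):
        return -x                      # ODE right-hand side from the docstring example

    def exp_euler_update(x, t, dt):
        # f(x) and df/dx at x; the deleted code used vector_grad(..., return_value=True)
        deriv, linear = jax.value_and_grad(lambda v: f(v, t))(x)
        # exprel(dt * linear); brainunit's exprel also handles dt * linear == 0,
        # this sketch assumes it is nonzero
        phi = jnp.expm1(dt * linear) / (dt * linear)
        return x + dt * phi * deriv

    x = jnp.asarray(1.0)
    print(exp_euler_update(x, None, dt=0.1))   # ~0.9048, i.e. exp(-0.1) * x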
brainstate/nn/_normalizations.py DELETED
@@ -1,389 +0,0 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- # -*- coding: utf-8 -*-
-
- from __future__ import annotations
-
- import numbers
- from typing import Callable, Union, Sequence, Optional, Any
-
- import jax
- import jax.numpy as jnp
-
- from ._base import DnnLayer
- from .. import environ, init
- from brainstate._state import LongTermState, ParamState
- from brainstate.mixin import Mode
- from brainstate.typing import DTypeLike, ArrayLike, Size, Axes
-
- __all__ = [
-     'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d',
- ]
-
-
- def _canonicalize_axes(ndim: int, feature_axes: Sequence[int]):
-     axes = []
-     for axis in feature_axes:
-         if axis < 0:
-             axis += ndim
-         if axis < 0 or axis >= ndim:
-             raise ValueError(f'Invalid axis {axis} for {ndim}D input')
-         axes.append(axis)
-     return tuple(axes)
-
-
- def _abs_sq(x):
-     """Computes the elementwise square of the absolute value |x|^2."""
-     if jnp.iscomplexobj(x):
-         return jax.lax.square(jax.lax.real(x)) + jax.lax.square(jax.lax.imag(x))
-     else:
-         return jax.lax.square(x)
-
-
- def _compute_stats(
-     x: ArrayLike,
-     axes: Sequence[int],
-     dtype: DTypeLike,
-     axis_name: Optional[str] = None,
-     axis_index_groups: Optional[Sequence[int]] = None,
-     use_mean: bool = True,
- ):
-     """Computes mean and variance statistics.
-
-     This implementation takes care of a few important details:
-     - Computes in float32 precision for stability in half precision training.
-     - mean and variance are computable in a single XLA fusion,
-       by using Var = E[|x|^2] - |E[x]|^2 instead of Var = E[|x - E[x]|^2]).
-     - Clips negative variances to zero which can happen due to
-       roundoff errors. This avoids downstream NaNs.
-     - Supports averaging across a parallel axis and subgroups of a parallel axis
-       with a single `lax.pmean` call to avoid latency.
-
-     Arguments:
-       x: Input array.
-       axes: The axes in ``x`` to compute mean and variance statistics for.
-       dtype: tp.Optional dtype specifying the minimal precision. Statistics
-         are always at least float32 for stability (default: dtype of x).
-       axis_name: tp.Optional name for the pmapped axis to compute mean over.
-       axis_index_groups: tp.Optional axis indices.
-       use_mean: If true, calculate the mean from the input and use it when
-         computing the variance. If false, set the mean to zero and compute
-         the variance without subtracting the mean.
-
-     Returns:
-       A pair ``(mean, var)``.
-     """
-     if dtype is None:
-         dtype = jax.numpy.result_type(x)
-     # promote x to at least float32, this avoids half precision computation
-     # but preserves double or complex floating points
-     dtype = jax.numpy.promote_types(dtype, environ.dftype())
-     x = jnp.asarray(x, dtype)
-
-     # Compute mean and mean of squared values.
-     mean2 = jnp.mean(_abs_sq(x), axes)
-     if use_mean:
-         mean = jnp.mean(x, axes)
-     else:
-         mean = jnp.zeros(mean2.shape, dtype=dtype)
-
-     # If axis_name is provided, we need to average the mean and mean2 across
-     if axis_name is not None:
-         concatenated_mean = jnp.concatenate([mean, mean2])
-         mean, mean2 = jnp.split(
-             jax.lax.pmean(
-                 concatenated_mean,
-                 axis_name=axis_name,
-                 axis_index_groups=axis_index_groups,
-             ),
-             2,
-         )
-
-     # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due
-     # to floating point round-off errors.
-     var = jnp.maximum(0.0, mean2 - _abs_sq(mean))
-     return mean, var
-
-
- def _normalize(
-     x: ArrayLike,
-     mean: Optional[ArrayLike],
-     var: Optional[ArrayLike],
-     weights: Optional[ParamState],
-     reduction_axes: Sequence[int],
-     dtype: DTypeLike,
-     epsilon: Union[numbers.Number, jax.Array],
- ):
-     """Normalizes the input of a normalization layer and optionally applies a learned scale and bias.
-
-     Arguments:
-       x: The input.
-       mean: Mean to use for normalization.
-       var: Variance to use for normalization.
-       weights: The scale and bias parameters.
-       reduction_axes: The axes in ``x`` to reduce.
-       dtype: The dtype of the result (default: infer from input and params).
-       epsilon: Normalization epsilon.
-
-     Returns:
-       The normalized input.
-     """
-     if mean is not None:
-         assert var is not None, 'mean and var must be both None or not None.'
-         stats_shape = list(x.shape)
-         for axis in reduction_axes:
-             stats_shape[axis] = 1
-         mean = mean.reshape(stats_shape)
-         var = var.reshape(stats_shape)
-         y = x - mean
-         mul = jax.lax.rsqrt(var + jnp.asarray(epsilon, dtype))
-         y = y * mul
-         if weights is not None:
-             y = _scale_operation(y, weights.value)
-     else:
-         assert var is None, 'mean and var must be both None or not None.'
-         assert weights is None, 'scale and bias are not supported without mean and var'
-         y = x
-     return jnp.asarray(y, dtype)
-
-
- def _scale_operation(x, param):
-     if 'scale' in param:
-         x = x * param['scale']
-     if 'bias' in param:
-         x = x + param['bias']
-     return x
-
-
- class _BatchNorm(DnnLayer):
-     __module__ = 'brainstate.nn'
-     num_spatial_dims: int
-
-     def __init__(
-         self,
-         in_size: Size,
-         feature_axis: Axes = -1,
-         track_running_stats: bool = True,
-         epsilon: float = 1e-5,
-         momentum: float = 0.99,
-         affine: bool = True,
-         bias_initializer: Union[ArrayLike, Callable] = init.Constant(0.),
-         scale_initializer: Union[ArrayLike, Callable] = init.Constant(1.),
-         axis_name: Optional[Union[str, Sequence[str]]] = None,
-         axis_index_groups: Optional[Sequence[Sequence[int]]] = None,
-         mode: Optional[Mode] = None,
-         name: Optional[str] = None,
-         dtype: Any = None,
-     ):
-         super().__init__(name=name, mode=mode)
-
-         # parameters
-         self.in_size = tuple(in_size)
-         self.out_size = tuple(in_size)
-         self.affine = affine
-         self.bias_initializer = bias_initializer
-         self.scale_initializer = scale_initializer
-         self.dtype = dtype or environ.dftype()
-         self.track_running_stats = track_running_stats
-         self.momentum = jnp.asarray(momentum, dtype=self.dtype)
-         self.epsilon = jnp.asarray(epsilon, dtype=self.dtype)
-
-         # parameters about axis
-         feature_axis = (feature_axis,) if isinstance(feature_axis, int) else feature_axis
-         self.feature_axis = _canonicalize_axes(len(in_size), feature_axis)
-         self.axis_name = axis_name
-         self.axis_index_groups = axis_index_groups
-
-         # variables
-         feature_shape = tuple([ax if i in self.feature_axis else 1 for i, ax in enumerate(in_size)])
-         if self.track_running_stats:
-             self.running_mean = LongTermState(jnp.zeros(feature_shape, dtype=self.dtype))
-             self.running_var = LongTermState(jnp.ones(feature_shape, dtype=self.dtype))
-         else:
-             self.running_mean = None
-             self.running_var = None
-
-         # parameters
-         if self.affine:
-             assert track_running_stats, "Affine parameters are not needed when track_running_stats is False."
-             bias = init.param(self.bias_initializer, feature_shape)
-             scale = init.param(self.scale_initializer, feature_shape)
-             self.weight = ParamState(dict(bias=bias, scale=scale))
-         else:
-             self.weight = None
-
-     def _check_input_dim(self, x):
-         if x.ndim == self.num_spatial_dims + 2:
-             x_shape = x.shape[1:]
-         elif x.ndim == self.num_spatial_dims + 1:
-             x_shape = x.shape
-         else:
-             raise ValueError(f"expected {self.num_spatial_dims + 2}D (with batch) or "
-                              f"{self.num_spatial_dims + 1}D (without batch) input (got {x.ndim}D input, {x.shape})")
-         if self.in_size != x_shape:
-             raise ValueError(f"The expected input shape is {self.in_size}, while we got {x_shape}.")
-
-     def update(self, x):
-         # input shape and batch mode or not
-         if x.ndim == self.num_spatial_dims + 2:
-             x_shape = x.shape[1:]
-             batch = True
-         elif x.ndim == self.num_spatial_dims + 1:
-             x_shape = x.shape
-             batch = False
-         else:
-             raise ValueError(f"expected {self.num_spatial_dims + 2}D (with batch) or "
-                              f"{self.num_spatial_dims + 1}D (without batch) input (got {x.ndim}D input, {x.shape})")
-         if self.in_size != x_shape:
-             raise ValueError(f"The expected input shape is {self.in_size}, while we got {x_shape}.")
-
-         # reduce the feature axis
-         if batch:
-             reduction_axes = tuple(i for i in range(x.ndim) if (i - 1) not in self.feature_axis)
-         else:
-             reduction_axes = tuple(i for i in range(x.ndim) if i not in self.feature_axis)
-
-         # fitting phase
-         fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
-
-         # compute the running mean and variance
-         if self.track_running_stats:
-             if fit_phase:
-                 mean, var = _compute_stats(
-                     x,
-                     reduction_axes,
-                     dtype=self.dtype,
-                     axis_name=self.axis_name,
-                     axis_index_groups=self.axis_index_groups,
-                 )
-                 self.running_mean.value = self.momentum * self.running_mean.value + (1 - self.momentum) * mean
-                 self.running_var.value = self.momentum * self.running_var.value + (1 - self.momentum) * var
-             else:
-                 mean = self.running_mean.value
-                 var = self.running_var.value
-         else:
-             mean, var = None, None
-
-         # normalize
-         return _normalize(x, mean, var, self.weight, reduction_axes, self.dtype, self.epsilon)
-
-
- class BatchNorm1d(_BatchNorm):
-     r"""1-D batch normalization [1]_.
-
-     The data should be of `(b, l, c)`, where `b` is the batch dimension,
-     `l` is the layer dimension, and `c` is the channel dimension.
-
-     %s
-     """
-     __module__ = 'brainstate.nn'
-     num_spatial_dims: int = 1
-
-
- class BatchNorm2d(_BatchNorm):
-     r"""2-D batch normalization [1]_.
-
-     The data should be of `(b, h, w, c)`, where `b` is the batch dimension,
-     `h` is the height dimension, `w` is the width dimension, and `c` is the
-     channel dimension.
-
-     %s
-     """
-     __module__ = 'brainstate.nn'
-     num_spatial_dims: int = 2
-
-
- class BatchNorm3d(_BatchNorm):
-     r"""3-D batch normalization [1]_.
-
-     The data should be of `(b, h, w, d, c)`, where `b` is the batch dimension,
-     `h` is the height dimension, `w` is the width dimension, `d` is the depth
-     dimension, and `c` is the channel dimension.
-
-     %s
-     """
-     __module__ = 'brainstate.nn'
-     num_spatial_dims: int = 3
-
-
- _bn_doc = r'''
-
-   This layer aims to reduce the internal covariant shift of data. It
-   normalizes a batch of data by fixing the mean and variance of inputs
-   on each feature (channel). Most commonly, the first axis of the data
-   is the batch, and the last is the channel. However, users can specify
-   the axes to be normalized.
-
-   .. math::
-      y=\frac{x-\mathrm{E}[x]}{\sqrt{\operatorname{Var}[x]+\epsilon}} * \gamma+\beta
-
-   .. note::
-       This :attr:`momentum` argument is different from one used in optimizer
-       classes and the conventional notion of momentum. Mathematically, the
-       update rule for running statistics here is
-       :math:`\hat{x}_\text{new} = \text{momentum} \times \hat{x} + (1-\text{momentum}) \times x_t`,
-       where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
-       new observed value.
-
-   Parameters
-   ----------
-   in_size: sequence of int
-     The input shape, without batch size.
-   feature_axis: int, tuple, list
-     The feature or non-batch axis of the input.
-   track_running_stats: bool
-     A boolean value that when set to ``True``, this module tracks the running mean and variance,
-     and when set to ``False``, this module does not track such statistics, and initializes
-     statistics buffers ``running_mean`` and ``running_var`` as ``None``. When these buffers are ``None``,
-     this module always uses batch statistics. in both training and eval modes. Default: ``True``.
-   momentum: float
-     The value used for the ``running_mean`` and ``running_var`` computation. Default: 0.99
-   epsilon: float
-     A value added to the denominator for numerical stability. Default: 1e-5
-   affine: bool
-     A boolean value that when set to ``True``, this module has
-     learnable affine parameters. Default: ``True``
-   bias_initializer: ArrayLike, Callable
-     An initializer generating the original translation matrix. If not ``None``, bias (beta) is added.
-     Default: ``init.Constant(0.)``
-   scale_initializer: ArrayLike, Callable
-     An initializer generating the original scaling matrix. If not ``None``, multiply by scale (gamma).
-     Default: ``init.Constant(1.)``
-   axis_name: optional, str, sequence of str
-     If not ``None``, it should be a string (or sequence of
-     strings) representing the axis name(s) over which this module is being
-     run within a jax map (e.g. ``jax.pmap`` or ``jax.vmap``). Supplying this
-     argument means that batch statistics are calculated across all replicas
-     on the named axes.
-   axis_index_groups: optional, sequence
-     Specifies how devices are grouped. Valid
-     only within ``jax.pmap`` collectives.
-     Groups of axis indices within that named axis
-     representing subsets of devices to reduce over (default: None). For
-     example, `[[0, 1], [2, 3]]` would independently batch-normalize over
-     the examples on the first two and last two devices. See `jax.lax.psum`
-     for more details.
-
-   References
-   ----------
-   .. [1] Ioffe, Sergey and Christian Szegedy. “Batch Normalization: Accelerating Deep Network Training
-          by Reducing Internal Covariate Shift.” ArXiv abs/1502.03167 (2015): n. pag.
-
-   '''
-
- BatchNorm1d.__doc__ = BatchNorm1d.__doc__ % _bn_doc
- BatchNorm2d.__doc__ = BatchNorm2d.__doc__ % _bn_doc
- BatchNorm3d.__doc__ = BatchNorm3d.__doc__ % _bn_doc
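The batch-norm math in the deleted file is the standard y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta, with running statistics updated as stat_new = momentum * stat + (1 - momentum) * batch_stat and Var computed as E[x^2] - E[x]^2 clipped at zero. A small self-contained sketch of one training-step update over the batch axis, in plain JAX (feature count, momentum, and eps are illustrative):

    import jax
    import jax.numpy as jnp

    eps, momentum = 1e-5, 0.99
    x = jax.random.normal(jax.random.PRNGKey(0), (8, 3))   # (batch, features)
    gamma, beta = jnp.ones(3), jnp.zeros(3)
    running_mean, running_var = jnp.zeros(3), jnp.ones(3)

    # batch statistics over the batch axis, Var = E[x^2] - E[x]^2 clipped at 0
    mean = x.mean(axis=0)
    var = jnp.maximum(0.0, (x ** 2).mean(axis=0) - mean ** 2)

    # normalize, then apply the learned scale and bias
    y = (x - mean) * jax.lax.rsqrt(var + eps) * gamma + beta

    # running-statistics update used by the deleted layer
    running_mean = momentum * running_mean + (1 - momentum) * mean
    running_var = momentum * running_var + (1 - momentum) * var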
brainstate/nn/_others.py DELETED
@@ -1,101 +0,0 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
-
- from __future__ import annotations
-
- from functools import partial
- from typing import Optional
-
- import brainunit as bu
- import jax.numpy as jnp
-
- from ._base import DnnLayer
- from brainstate.mixin import Mode
- from brainstate import random, environ, typing, init
-
- __all__ = [
-     'DropoutFixed',
- ]
-
-
- class DropoutFixed(DnnLayer):
-     """
-     A dropout layer with the fixed dropout mask along the time axis once after initialized.
-
-     In training, to compensate for the fraction of input values dropped (`rate`),
-     all surviving values are multiplied by `1 / (1 - rate)`.
-
-     This layer is active only during training (``mode=brainstate.mixin.Training``). In other
-     circumstances it is a no-op.
-
-     .. [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent
-            neural networks from overfitting." The journal of machine learning
-            research 15.1 (2014): 1929-1958.
-
-     .. admonition:: Tip
-         :class: tip
-
-         This kind of Dropout is firstly described in `Enabling Spike-based Backpropagation for Training Deep Neural
-         Network Architectures <https://arxiv.org/abs/1903.06379>`_:
-
-         There is a subtle difference in the way dropout is applied in SNNs compared to ANNs. In ANNs, each epoch of
-         training has several iterations of mini-batches. In each iteration, randomly selected units (with dropout ratio of :math:`p`)
-         are disconnected from the network while weighting by its posterior probability (:math:`1-p`). However, in SNNs, each
-         iteration has more than one forward propagation depending on the time length of the spike train. We back-propagate
-         the output error and modify the network parameters only at the last time step. For dropout to be effective in
-         our training method, it has to be ensured that the set of connected units within an iteration of mini-batch
-         data is not changed, such that the neural network is constituted by the same random subset of units during
-         each forward propagation within a single iteration. On the other hand, if the units are randomly connected at
-         each time-step, the effect of dropout will be averaged out over the entire forward propagation time within an
-         iteration. Then, the dropout effect would fade-out once the output error is propagated backward and the parameters
-         are updated at the last time step. Therefore, we need to keep the set of randomly connected units for the entire
-         time window within an iteration.
-
-     Args:
-       in_size: The size of the input tensor.
-       prob: Probability to keep element of the tensor.
-       mode: Mode. The computation mode of the object.
-       name: str. The name of the dynamic system.
-     """
-     __module__ = 'brainstate.nn'
-
-     def __init__(
-         self,
-         in_size: typing.Size,
-         prob: float = 0.5,
-         mode: Optional[Mode] = None,
-         name: Optional[str] = None
-     ) -> None:
-         super().__init__(mode=mode, name=name)
-         assert 0. <= prob < 1., f"Dropout probability must be in the range [0, 1). But got {prob}."
-         self.prob = prob
-         self.in_size = in_size
-         self.out_size = in_size
-
-     def init_state(self, batch_size=None, **kwargs):
-         self.mask = init.param(partial(random.bernoulli, self.prob), self.in_size, batch_size)
-
-     def update(self, x):
-         dtype = bu.math.get_dtype(x)
-         fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
-         if fit_phase:
-             assert self.mask.shape == x.shape, (f"Input shape {x.shape} does not match the mask shape {self.mask.shape}. "
-                                                 f"Please call `init_state()` method first.")
-             return jnp.where(self.mask,
-                              jnp.asarray(x / self.prob, dtype=dtype),
-                              jnp.asarray(0., dtype=dtype))
-         else:
-             return x
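The core idea of the deleted `DropoutFixed` is that the Bernoulli keep-mask is sampled once (in `init_state`) and then reused at every time step of a trial, with survivors rescaled by 1 / prob. A minimal standalone sketch of that behaviour in plain JAX, without the brainstate state/mode machinery (`prob`, shapes, and the number of time steps are illustrative):

    import jax
    import jax.numpy as jnp

    prob = 0.5                                     # probability of keeping each element
    key = jax.random.PRNGKey(0)
    mask = jax.random.bernoulli(key, prob, (4,))   # sampled once, as in init_state()

    def dropout_fixed(x, training=True):
        # same mask at every time step; survivors are rescaled by 1 / prob
        return jnp.where(mask, x / prob, 0.0) if training else x

    for t in range(3):                             # mask reused across the whole time window
        x_t = jnp.ones(4) * (t + 1)
        print(dropout_fixed(x_t))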