brainstate 0.0.2.post20241009__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1360 -1318
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241009.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,254 @@
+ # -*- coding: utf-8 -*-
+
+ from __future__ import annotations
+
+ import jax.numpy as jnp
+ import pytest
+ from absl.testing import absltest
+ from absl.testing import parameterized
+
+ import brainstate as bst
+
+
+ class TestConv(parameterized.TestCase):
+     def test_Conv2D_img(self):
+         img = jnp.zeros((2, 200, 198, 4))
+         for k in range(4):
+             x = 30 + 60 * k
+             y = 20 + 60 * k
+             img = img.at[0, x:x + 10, y:y + 10, k].set(1.0)
+             img = img.at[1, x:x + 20, y:y + 20, k].set(3.0)
+
+         net = bst.nn.Conv2d((200, 198, 4), out_channels=32, kernel_size=(3, 3),
+                             stride=(2, 1), padding='VALID', groups=4)
+         out = net(img)
+         print("out shape: ", out.shape)
+         self.assertEqual(out.shape, (2, 99, 196, 32))
+         # print("First output channel:")
+         # plt.figure(figsize=(10, 10))
+         # plt.imshow(np.array(img)[0, :, :, 0])
+         # plt.show()
+
+     def test_conv1D(self):
+         model = bst.nn.Conv1d((5, 3), out_channels=32, kernel_size=(3,))
+         input = jnp.ones((2, 5, 3))
+         out = model(input)
+         print("out shape: ", out.shape)
+         self.assertEqual(out.shape, (2, 5, 32))
+         # print("First output channel:")
+         # plt.figure(figsize=(10, 10))
+         # plt.imshow(np.array(out)[0, :, :])
+         # plt.show()
+
+     def test_conv2D(self):
+         model = bst.nn.Conv2d((5, 5, 3), out_channels=32, kernel_size=(3, 3))
+         input = jnp.ones((2, 5, 5, 3))
+
+         out = model(input)
+         print("out shape: ", out.shape)
+         self.assertEqual(out.shape, (2, 5, 5, 32))
+
+     def test_conv3D(self):
+         model = bst.nn.Conv3d((5, 5, 5, 3), out_channels=32, kernel_size=(3, 3, 3))
+         input = jnp.ones((2, 5, 5, 5, 3))
+         out = model(input)
+         print("out shape: ", out.shape)
+         self.assertEqual(out.shape, (2, 5, 5, 5, 32))
+
+
+ @pytest.mark.skip(reason="not implemented yet")
+ class TestConvTranspose1d(parameterized.TestCase):
+     def test_conv_transpose(self):
+
+         x = jnp.ones((1, 8, 3))
+         for use_bias in [True, False]:
+             conv_transpose_module = bst.nn.ConvTranspose1d(
+                 in_channels=3,
+                 out_channels=4,
+                 kernel_size=(3,),
+                 padding='VALID',
+                 w_initializer=bst.init.Constant(1.),
+                 b_initializer=bst.init.Constant(1.) if use_bias else None,
+             )
+             self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4))
+             y = conv_transpose_module(x)
+             print(y.shape)
+             correct_ans = jnp.array([[[4., 4., 4., 4.],
+                                       [7., 7., 7., 7.],
+                                       [10., 10., 10., 10.],
+                                       [10., 10., 10., 10.],
+                                       [10., 10., 10., 10.],
+                                       [10., 10., 10., 10.],
+                                       [10., 10., 10., 10.],
+                                       [10., 10., 10., 10.],
+                                       [7., 7., 7., 7.],
+                                       [4., 4., 4., 4.]]])
+             if not use_bias:
+                 correct_ans -= 1.
+             self.assertTrue(jnp.allclose(y, correct_ans))
+
+     def test_single_input_masked_conv_transpose(self):
+
+         x = jnp.ones((1, 8, 3))
+         m = jnp.tril(jnp.ones((3, 3, 4)))
+         conv_transpose_module = bst.nn.ConvTranspose1d(
+             in_channels=3,
+             out_channels=4,
+             kernel_size=(3,),
+             padding='VALID',
+             mask=m,
+             w_initializer=bst.init.Constant(),
+             b_initializer=bst.init.Constant(),
+         )
+         self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4))
+         y = conv_transpose_module(x)
+         print(y.shape)
+         correct_ans = jnp.array([[[4., 3., 2., 1.],
+                                   [7., 5., 3., 1.],
+                                   [10., 7., 4., 1.],
+                                   [10., 7., 4., 1.],
+                                   [10., 7., 4., 1.],
+                                   [10., 7., 4., 1.],
+                                   [10., 7., 4., 1.],
+                                   [10., 7., 4., 1.],
+                                   [7., 5., 3., 1.],
+                                   [4., 3., 2., 1.]]])
+         self.assertTrue(jnp.allclose(y, correct_ans))
+
+     def test_computation_padding_same(self):
+
+         data = jnp.ones([1, 3, 1])
+         for use_bias in [True, False]:
+             net = bst.nn.ConvTranspose1d(
+                 in_channels=1,
+                 out_channels=1,
+                 kernel_size=3,
+                 stride=1,
+                 padding="SAME",
+                 w_initializer=bst.init.Constant(),
+                 b_initializer=bst.init.Constant() if use_bias else None,
+             )
+             out = net(data)
+             self.assertEqual(out.shape, (1, 3, 1))
+             out = jnp.squeeze(out, axis=(0, 2))
+             expected_out = jnp.asarray([2, 3, 2])
+             if use_bias:
+                 expected_out += 1
+             self.assertTrue(jnp.allclose(out, expected_out, rtol=1e-5))
+
+
+ @pytest.mark.skip(reason="not implemented yet")
+ class TestConvTranspose2d(parameterized.TestCase):
+     def test_conv_transpose(self):
+
+         x = jnp.ones((1, 8, 8, 3))
+         for use_bias in [True, False]:
+             conv_transpose_module = bst.nn.ConvTranspose2d(
+                 in_channels=3,
+                 out_channels=4,
+                 kernel_size=(3, 3),
+                 padding='VALID',
+                 w_initializer=bst.init.Constant(),
+                 b_initializer=bst.init.Constant() if use_bias else None,
+             )
+             self.assertEqual(conv_transpose_module.w.shape, (3, 3, 3, 4))
+             y = conv_transpose_module(x)
+             print(y.shape)
+
+     def test_single_input_masked_conv_transpose(self):
+
+         x = jnp.ones((1, 8, 8, 3))
+         m = jnp.tril(jnp.ones((3, 3, 3, 4)))
+         conv_transpose_module = bst.nn.ConvTranspose2d(
+             in_channels=3,
+             out_channels=4,
+             kernel_size=(3, 3),
+             padding='VALID',
+             mask=m,
+             w_initializer=bst.init.Constant(),
+         )
+         y = conv_transpose_module(x)
+         print(y.shape)
+
+     def test_computation_padding_same(self):
+
+         x = jnp.ones((1, 8, 8, 3))
+         for use_bias in [True, False]:
+             conv_transpose_module = bst.nn.ConvTranspose2d(
+                 in_channels=3,
+                 out_channels=4,
+                 kernel_size=(3, 3),
+                 stride=1,
+                 padding='SAME',
+                 w_initializer=bst.init.Constant(),
+                 b_initializer=bst.init.Constant() if use_bias else None,
+             )
+             y = conv_transpose_module(x)
+             print(y.shape)
+
+
+ @pytest.mark.skip(reason="not implemented yet")
+ class TestConvTranspose3d(parameterized.TestCase):
+     def test_conv_transpose(self):
+
+         x = jnp.ones((1, 8, 8, 8, 3))
+         for use_bias in [True, False]:
+             conv_transpose_module = bst.nn.ConvTranspose3d(
+                 in_channels=3,
+                 out_channels=4,
+                 kernel_size=(3, 3, 3),
+                 padding='VALID',
+                 w_initializer=bst.init.Constant(),
+                 b_initializer=bst.init.Constant() if use_bias else None,
+             )
+             y = conv_transpose_module(x)
+             print(y.shape)
+
+     def test_single_input_masked_conv_transpose(self):
+
+         x = jnp.ones((1, 8, 8, 8, 3))
+         m = jnp.tril(jnp.ones((3, 3, 3, 3, 4)))
+         conv_transpose_module = bst.nn.ConvTranspose3d(
+             in_channels=3,
+             out_channels=4,
+             kernel_size=(3, 3, 3),
+             padding='VALID',
+             mask=m,
+             w_initializer=bst.init.Constant(),
+         )
+         y = conv_transpose_module(x)
+         print(y.shape)
+
+     def test_computation_padding_same(self):
+
+         x = jnp.ones((1, 8, 8, 8, 3))
+         for use_bias in [True, False]:
+             conv_transpose_module = bst.nn.ConvTranspose3d(
+                 in_channels=3,
+                 out_channels=4,
+                 kernel_size=(3, 3, 3),
+                 stride=1,
+                 padding='SAME',
+                 w_initializer=bst.init.Constant(),
+                 b_initializer=bst.init.Constant() if use_bias else None,
+             )
+             y = conv_transpose_module(x)
+             print(y.shape)
+
+
+ class TestDense(parameterized.TestCase):
+     @parameterized.product(
+         size=[(10,),
+               (20, 10),
+               (5, 8, 10)],
+         num_out=[20, ]
+     )
+     def test_Dense1(self, size, num_out):
+         f = bst.nn.Linear(10, num_out)
+         x = bst.random.random(size)
+         y = f(x)
+         self.assertTrue(y.shape == size[:-1] + (num_out,))
+
+
+ if __name__ == '__main__':
+     absltest.main()
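
Note on usage: the tests above exercise the relocated `brainstate.nn` convolution layers, which take the per-sample input shape (without the batch axis) as the first argument and operate on channels-last data; the layer instance is then called directly on a batched array. A minimal sketch of that pattern, mirroring `test_conv2D` above (names and defaults as they appear in these tests; anything beyond that is an assumption):

    import jax.numpy as jnp
    import brainstate as bst

    # Per-sample input shape is (H, W, C); the batch axis is added only at call time.
    layer = bst.nn.Conv2d((28, 28, 1), out_channels=8, kernel_size=(3, 3))
    x = jnp.ones((4, 28, 28, 1))   # a batch of 4 channels-last images
    y = layer(x)                   # as in test_conv2D, the default padding preserves spatial size
    print(y.shape)                 # expected: (4, 28, 28, 8)
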
@@ -0,0 +1,59 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ from __future__ import annotations
+
+ from typing import Optional, Callable, Union
+
+ from brainstate import init
+ from brainstate._state import ParamState
+ from brainstate.nn._module import Module
+ from brainstate.typing import ArrayLike
+
+ __all__ = [
+     'Embedding',
+ ]
+
+
+ class Embedding(Module):
+     r"""
+     A simple lookup table that stores embeddings of a fixed size.
+
+     Args:
+         num_embeddings: Size of embedding dictionary. Must be non-negative.
+         embedding_size: Size of each embedding vector. Must be non-negative.
+         embedding_init: The initializer for the embedding lookup table, of shape `(num_embeddings, embedding_size)`.
+     """
+
+     def __init__(
+         self,
+         num_embeddings: int,
+         embedding_size: int,
+         embedding_init: Union[Callable, ArrayLike] = init.LecunUniform(),
+         name: Optional[str] = None,
+     ):
+         super().__init__(name=name)
+         if num_embeddings < 0:
+             raise ValueError("num_embeddings must not be negative.")
+         if embedding_size < 0:
+             raise ValueError("embedding_size must not be negative.")
+         self.num_embeddings = num_embeddings
+         self.embedding_size = embedding_size
+         self.out_size = (embedding_size,)
+
+         weight = init.param(embedding_init, (self.num_embeddings, self.embedding_size))
+         self.weight = ParamState(weight)
+
+     def update(self, indices: ArrayLike):
+         return self.weight.value[indices]
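
The new `Embedding` module above is a plain lookup table: `update` indexes the `(num_embeddings, embedding_size)` weight with integer indices, so the output appends the embedding size to the index shape. A minimal sketch, assuming the class is re-exported as `brainstate.nn.Embedding` and that modules are callable via `update` as in the convolution tests:

    import jax.numpy as jnp
    import brainstate as bst

    emb = bst.nn.Embedding(num_embeddings=1000, embedding_size=16)
    token_ids = jnp.array([[1, 5, 9], [2, 2, 7]])   # integer indices into the table
    vectors = emb(token_ids)                        # weight.value[token_ids]; shape (2, 3, 16)
    print(vectors.shape)
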
@@ -0,0 +1,388 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ # -*- coding: utf-8 -*-
+
+ from __future__ import annotations
+
+ import numbers
+ from typing import Callable, Union, Sequence, Optional, Any
+
+ import jax
+ import jax.numpy as jnp
+
+ from brainstate import environ, init
+ from brainstate._state import LongTermState, ParamState
+ from brainstate.nn._module import Module
+ from brainstate.typing import DTypeLike, ArrayLike, Size, Axes
+
+ __all__ = [
+     'BatchNorm0d', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d',
+ ]
+
+
+ def _canonicalize_axes(ndim: int, feature_axes: Sequence[int]):
+     axes = []
+     for axis in feature_axes:
+         if axis < 0:
+             axis += ndim
+         if axis < 0 or axis >= ndim:
+             raise ValueError(f'Invalid axis {axis} for {ndim}D input')
+         axes.append(axis)
+     return tuple(axes)
+
+
+ def _abs_sq(x):
+     """Computes the elementwise square of the absolute value |x|^2."""
+     if jnp.iscomplexobj(x):
+         return jax.lax.square(jax.lax.real(x)) + jax.lax.square(jax.lax.imag(x))
+     else:
+         return jax.lax.square(x)
+
+
+ def _compute_stats(
+     x: ArrayLike,
+     axes: Sequence[int],
+     dtype: DTypeLike,
+     axis_name: Optional[str] = None,
+     axis_index_groups: Optional[Sequence[int]] = None,
+     use_mean: bool = True,
+ ):
+     """Computes mean and variance statistics.
+
+     This implementation takes care of a few important details:
+     - Computes in float32 precision for stability in half precision training.
+     - Mean and variance are computable in a single XLA fusion,
+       by using Var = E[|x|^2] - |E[x]|^2 instead of Var = E[|x - E[x]|^2].
+     - Clips negative variances to zero, which can happen due to
+       roundoff errors. This avoids downstream NaNs.
+     - Supports averaging across a parallel axis and subgroups of a parallel axis
+       with a single `lax.pmean` call to avoid latency.
+
+     Arguments:
+       x: Input array.
+       axes: The axes in ``x`` to compute mean and variance statistics for.
+       dtype: Optional dtype specifying the minimal precision. Statistics
+         are always at least float32 for stability (default: dtype of x).
+       axis_name: Optional name for the pmapped axis to compute mean over.
+       axis_index_groups: Optional axis indices.
+       use_mean: If true, calculate the mean from the input and use it when
+         computing the variance. If false, set the mean to zero and compute
+         the variance without subtracting the mean.
+
+     Returns:
+       A pair ``(mean, var)``.
+     """
+     if dtype is None:
+         dtype = jax.numpy.result_type(x)
+     # promote x to at least float32, this avoids half precision computation
+     # but preserves double or complex floating points
+     dtype = jax.numpy.promote_types(dtype, environ.dftype())
+     x = jnp.asarray(x, dtype)
+
+     # Compute mean and mean of squared values.
+     mean2 = jnp.mean(_abs_sq(x), axes)
+     if use_mean:
+         mean = jnp.mean(x, axes)
+     else:
+         mean = jnp.zeros(mean2.shape, dtype=dtype)
+
+     # If axis_name is provided, we need to average the mean and mean2 across devices.
+     if axis_name is not None:
+         concatenated_mean = jnp.concatenate([mean, mean2])
+         mean, mean2 = jnp.split(
+             jax.lax.pmean(
+                 concatenated_mean,
+                 axis_name=axis_name,
+                 axis_index_groups=axis_index_groups,
+             ),
+             2,
+         )
+
+     # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due
+     # to floating point round-off errors.
+     var = jnp.maximum(0.0, mean2 - _abs_sq(mean))
+     return mean, var
+
+
+ def _normalize(
+     x: ArrayLike,
+     mean: Optional[ArrayLike],
+     var: Optional[ArrayLike],
+     weights: Optional[ParamState],
+     reduction_axes: Sequence[int],
+     dtype: DTypeLike,
+     epsilon: Union[numbers.Number, jax.Array],
+ ):
+     """Normalizes the input of a normalization layer and optionally applies a learned scale and bias.
+
+     Arguments:
+       x: The input.
+       mean: Mean to use for normalization.
+       var: Variance to use for normalization.
+       weights: The scale and bias parameters.
+       reduction_axes: The axes in ``x`` to reduce.
+       dtype: The dtype of the result (default: infer from input and params).
+       epsilon: Normalization epsilon.
+
+     Returns:
+       The normalized input.
+     """
+     if mean is not None:
+         assert var is not None, 'mean and var must both be None or not None.'
+         stats_shape = list(x.shape)
+         for axis in reduction_axes:
+             stats_shape[axis] = 1
+         mean = mean.reshape(stats_shape)
+         var = var.reshape(stats_shape)
+         y = x - mean
+         mul = jax.lax.rsqrt(var + jnp.asarray(epsilon, dtype))
+         y = y * mul
+         if weights is not None:
+             y = _scale_operation(y, weights.value)
+     else:
+         assert var is None, 'mean and var must both be None or not None.'
+         assert weights is None, 'scale and bias are not supported without mean and var'
+         y = x
+     return jnp.asarray(y, dtype)
+
+
+ def _scale_operation(x, param):
+     if 'scale' in param:
+         x = x * param['scale']
+     if 'bias' in param:
+         x = x + param['bias']
+     return x
+
+
+ class _BatchNorm(Module):
+     __module__ = 'brainstate.nn'
+     num_spatial_dims: int
+
+     def __init__(
+         self,
+         in_size: Size,
+         feature_axis: Axes = -1,
+         track_running_stats: bool = True,
+         epsilon: float = 1e-5,
+         momentum: float = 0.99,
+         affine: bool = True,
+         bias_initializer: Union[ArrayLike, Callable] = init.Constant(0.),
+         scale_initializer: Union[ArrayLike, Callable] = init.Constant(1.),
+         axis_name: Optional[Union[str, Sequence[str]]] = None,
+         axis_index_groups: Optional[Sequence[Sequence[int]]] = None,
+         name: Optional[str] = None,
+         dtype: Any = None,
+     ):
+         super().__init__(name=name)
+
+         # parameters
+         self.in_size = tuple(in_size)
+         self.out_size = tuple(in_size)
+         self.affine = affine
+         self.bias_initializer = bias_initializer
+         self.scale_initializer = scale_initializer
+         self.dtype = dtype or environ.dftype()
+         self.track_running_stats = track_running_stats
+         self.momentum = jnp.asarray(momentum, dtype=self.dtype)
+         self.epsilon = jnp.asarray(epsilon, dtype=self.dtype)
+
+         # parameters about axis
+         feature_axis = (feature_axis,) if isinstance(feature_axis, int) else feature_axis
+         self.feature_axis = _canonicalize_axes(len(in_size), feature_axis)
+         self.axis_name = axis_name
+         self.axis_index_groups = axis_index_groups
+
+         # variables
+         feature_shape = tuple([ax if i in self.feature_axis else 1 for i, ax in enumerate(in_size)])
+         if self.track_running_stats:
+             self.running_mean = LongTermState(jnp.zeros(feature_shape, dtype=self.dtype))
+             self.running_var = LongTermState(jnp.ones(feature_shape, dtype=self.dtype))
+         else:
+             self.running_mean = None
+             self.running_var = None
+
+         # parameters
+         if self.affine:
+             assert track_running_stats, "Affine parameters are not needed when track_running_stats is False."
+             bias = init.param(self.bias_initializer, feature_shape)
+             scale = init.param(self.scale_initializer, feature_shape)
+             self.weight = ParamState(dict(bias=bias, scale=scale))
+         else:
+             self.weight = None
+
+     def update(self, x):
+         # input shape and batch mode or not
+         if x.ndim == self.num_spatial_dims + 2:
+             x_shape = x.shape[1:]
+             batch = True
+         elif x.ndim == self.num_spatial_dims + 1:
+             x_shape = x.shape
+             batch = False
+         else:
+             raise ValueError(f"expected {self.num_spatial_dims + 2}D (with batch) or "
+                              f"{self.num_spatial_dims + 1}D (without batch) input (got {x.ndim}D input, {x.shape})")
+         if self.in_size != x_shape:
+             raise ValueError(f"The expected input shape is {self.in_size}, while we got {x_shape}.")
+
+         # reduce the feature axis
+         if batch:
+             reduction_axes = tuple(i for i in range(x.ndim) if (i - 1) not in self.feature_axis)
+         else:
+             reduction_axes = tuple(i for i in range(x.ndim) if i not in self.feature_axis)
+
+         # fitting phase
+         fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
+
+         # compute the running mean and variance
+         if self.track_running_stats:
+             if fit_phase:
+                 mean, var = _compute_stats(
+                     x,
+                     reduction_axes,
+                     dtype=self.dtype,
+                     axis_name=self.axis_name,
+                     axis_index_groups=self.axis_index_groups,
+                 )
+                 self.running_mean.value = self.momentum * self.running_mean.value + (1 - self.momentum) * mean
+                 self.running_var.value = self.momentum * self.running_var.value + (1 - self.momentum) * var
+             else:
+                 mean = self.running_mean.value
+                 var = self.running_var.value
+         else:
+             mean, var = None, None
+
+         # normalize
+         return _normalize(x, mean, var, self.weight, reduction_axes, self.dtype, self.epsilon)
+
+
+ class BatchNorm0d(_BatchNorm):
+     r"""0-D batch normalization [1]_.
+
+     The data should be of `(b, c)`, where `b` is the batch dimension,
+     and `c` is the channel dimension.
+
+     %s
+     """
+     __module__ = 'brainstate.nn'
+     num_spatial_dims: int = 0
+
+
+ class BatchNorm1d(_BatchNorm):
+     r"""1-D batch normalization [1]_.
+
+     The data should be of `(b, l, c)`, where `b` is the batch dimension,
+     `l` is the layer dimension, and `c` is the channel dimension.
+
+     %s
+     """
+     __module__ = 'brainstate.nn'
+     num_spatial_dims: int = 1
+
+
+ class BatchNorm2d(_BatchNorm):
+     r"""2-D batch normalization [1]_.
+
+     The data should be of `(b, h, w, c)`, where `b` is the batch dimension,
+     `h` is the height dimension, `w` is the width dimension, and `c` is the
+     channel dimension.
+
+     %s
+     """
+     __module__ = 'brainstate.nn'
+     num_spatial_dims: int = 2
+
+
+ class BatchNorm3d(_BatchNorm):
+     r"""3-D batch normalization [1]_.
+
+     The data should be of `(b, h, w, d, c)`, where `b` is the batch dimension,
+     `h` is the height dimension, `w` is the width dimension, `d` is the depth
+     dimension, and `c` is the channel dimension.
+
+     %s
+     """
+     __module__ = 'brainstate.nn'
+     num_spatial_dims: int = 3
+
+
+ _bn_doc = r'''
+
+ This layer aims to reduce the internal covariate shift of data. It
+ normalizes a batch of data by fixing the mean and variance of inputs
+ on each feature (channel). Most commonly, the first axis of the data
+ is the batch, and the last is the channel. However, users can specify
+ the axes to be normalized.
+
+ .. math::
+     y=\frac{x-\mathrm{E}[x]}{\sqrt{\operatorname{Var}[x]+\epsilon}} * \gamma+\beta
+
+ .. note::
+     This :attr:`momentum` argument is different from one used in optimizer
+     classes and the conventional notion of momentum. Mathematically, the
+     update rule for running statistics here is
+     :math:`\hat{x}_\text{new} = \text{momentum} \times \hat{x} + (1-\text{momentum}) \times x_t`,
+     where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
+     new observed value.
+
+ Parameters
+ ----------
+ in_size: sequence of int
+     The input shape, without batch size.
+ feature_axis: int, tuple, list
+     The feature or non-batch axis of the input.
+ track_running_stats: bool
+     A boolean value that when set to ``True``, this module tracks the running mean and variance,
+     and when set to ``False``, this module does not track such statistics, and initializes
+     statistics buffers ``running_mean`` and ``running_var`` as ``None``. When these buffers are ``None``,
+     this module always uses batch statistics in both training and eval modes. Default: ``True``.
+ momentum: float
+     The value used for the ``running_mean`` and ``running_var`` computation. Default: 0.99
+ epsilon: float
+     A value added to the denominator for numerical stability. Default: 1e-5
+ affine: bool
+     A boolean value that when set to ``True``, this module has
+     learnable affine parameters. Default: ``True``
+ bias_initializer: ArrayLike, Callable
+     An initializer generating the original translation matrix. If not ``None``, bias (beta) is added.
+     Default: ``init.Constant(0.)``
+ scale_initializer: ArrayLike, Callable
+     An initializer generating the original scaling matrix. If not ``None``, multiply by scale (gamma).
+     Default: ``init.Constant(1.)``
+ axis_name: optional, str, sequence of str
+     If not ``None``, it should be a string (or sequence of
+     strings) representing the axis name(s) over which this module is being
+     run within a jax map (e.g. ``jax.pmap`` or ``jax.vmap``). Supplying this
+     argument means that batch statistics are calculated across all replicas
+     on the named axes.
+ axis_index_groups: optional, sequence
+     Specifies how devices are grouped. Valid
+     only within ``jax.pmap`` collectives.
+     Groups of axis indices within that named axis
+     representing subsets of devices to reduce over (default: None). For
+     example, `[[0, 1], [2, 3]]` would independently batch-normalize over
+     the examples on the first two and last two devices. See `jax.lax.psum`
+     for more details.
+
+ References
+ ----------
+ .. [1] Ioffe, Sergey and Christian Szegedy. “Batch Normalization: Accelerating Deep Network Training
+        by Reducing Internal Covariate Shift.” ArXiv abs/1502.03167 (2015): n. pag.
+
+ '''
+
+ BatchNorm1d.__doc__ = BatchNorm1d.__doc__ % _bn_doc
+ BatchNorm2d.__doc__ = BatchNorm2d.__doc__ % _bn_doc
+ BatchNorm3d.__doc__ = BatchNorm3d.__doc__ % _bn_doc
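
As `_BatchNorm.update` above shows, the layer switches on the `fit` flag read from `brainstate.environ`: during fitting it computes batch statistics and folds them into the running estimates (running <- momentum * running + (1 - momentum) * batch), otherwise it normalizes with the stored running statistics. A rough sketch of driving that switch; `environ.context(fit=...)` is an assumed way to set the flag (only `environ.get('fit', ...)` is visible in this diff), so the exact call may differ:

    import brainstate as bst

    norm = bst.nn.BatchNorm2d((8, 8, 3))      # per-sample shape (h, w, c)
    x = bst.random.random((16, 8, 8, 3))      # batch of 16

    with bst.environ.context(fit=True):       # assumed API: fitting -> batch stats, running stats updated
        y_train = norm(x)

    with bst.environ.context(fit=False):      # assumed API: evaluation -> stored running stats
        y_eval = norm(x)

    print(y_train.shape, y_eval.shape)        # both (16, 8, 8, 3)
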