brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163) hide show
  1. brainstate/__init__.py +169 -58
  2. brainstate/_compatible_import.py +340 -148
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +45 -55
  7. brainstate/_state.py +1652 -1605
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -563
  11. brainstate/environ_test.py +1223 -62
  12. brainstate/graph/__init__.py +22 -29
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1433 -365
  18. brainstate/mixin_test.py +1017 -77
  19. brainstate/nn/__init__.py +137 -135
  20. brainstate/nn/_activations.py +1100 -808
  21. brainstate/nn/_activations_test.py +354 -331
  22. brainstate/nn/_collective_ops.py +633 -514
  23. brainstate/nn/_collective_ops_test.py +774 -43
  24. brainstate/nn/_common.py +226 -178
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +2010 -501
  27. brainstate/nn/_conv_test.py +849 -238
  28. brainstate/nn/_delay.py +575 -588
  29. brainstate/nn/_delay_test.py +243 -238
  30. brainstate/nn/_dropout.py +618 -426
  31. brainstate/nn/_dropout_test.py +477 -100
  32. brainstate/nn/_dynamics.py +1267 -1343
  33. brainstate/nn/_dynamics_test.py +67 -78
  34. brainstate/nn/_elementwise.py +1298 -1119
  35. brainstate/nn/_elementwise_test.py +830 -169
  36. brainstate/nn/_embedding.py +408 -58
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
  42. brainstate/nn/_exp_euler.py +254 -92
  43. brainstate/nn/_exp_euler_test.py +377 -35
  44. brainstate/nn/_linear.py +744 -424
  45. brainstate/nn/_linear_test.py +475 -107
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +384 -377
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -975
  51. brainstate/nn/_normalizations_test.py +699 -73
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +2239 -1177
  55. brainstate/nn/_poolings_test.py +953 -217
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +216 -89
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +809 -553
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
  62. brainstate/random/__init__.py +270 -24
  63. brainstate/random/_rand_funs.py +3938 -3616
  64. brainstate/random/_rand_funs_test.py +640 -567
  65. brainstate/random/_rand_seed.py +675 -210
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1409
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
  72. brainstate/{augment → transform}/_autograd.py +1025 -778
  73. brainstate/{augment → transform}/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +220 -220
  76. brainstate/{compile → transform}/_error_if.py +94 -92
  77. brainstate/{compile → transform}/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +38 -38
  80. brainstate/{compile → transform}/_jit.py +399 -346
  81. brainstate/{compile → transform}/_jit_test.py +143 -143
  82. brainstate/{compile → transform}/_loop_collect_return.py +675 -536
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
  84. brainstate/{compile → transform}/_loop_no_collection.py +283 -184
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +255 -202
  91. brainstate/{augment → transform}/_random.py +171 -151
  92. brainstate/{compile → transform}/_unvmap.py +256 -159
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +837 -304
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +27 -50
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +945 -469
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +910 -523
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/augment/__init__.py +0 -30
  111. brainstate/augment/_eval_shape.py +0 -99
  112. brainstate/augment/_mapping.py +0 -1060
  113. brainstate/augment/_mapping_test.py +0 -597
  114. brainstate/compile/__init__.py +0 -38
  115. brainstate/compile/_ad_checkpoint.py +0 -204
  116. brainstate/compile/_conditions.py +0 -256
  117. brainstate/compile/_make_jaxpr.py +0 -888
  118. brainstate/compile/_make_jaxpr_test.py +0 -156
  119. brainstate/compile/_util.py +0 -147
  120. brainstate/functional/__init__.py +0 -27
  121. brainstate/graph/_graph_node.py +0 -244
  122. brainstate/graph/_graph_node_test.py +0 -73
  123. brainstate/graph/_graph_operation_test.py +0 -563
  124. brainstate/init/__init__.py +0 -26
  125. brainstate/init/_base.py +0 -52
  126. brainstate/init/_generic.py +0 -244
  127. brainstate/init/_regular_inits.py +0 -105
  128. brainstate/init/_regular_inits_test.py +0 -50
  129. brainstate/nn/_inputs.py +0 -608
  130. brainstate/nn/_ltp.py +0 -28
  131. brainstate/nn/_neuron.py +0 -705
  132. brainstate/nn/_neuron_test.py +0 -161
  133. brainstate/nn/_others.py +0 -46
  134. brainstate/nn/_projection.py +0 -486
  135. brainstate/nn/_rate_rnns_test.py +0 -63
  136. brainstate/nn/_readout.py +0 -209
  137. brainstate/nn/_readout_test.py +0 -53
  138. brainstate/nn/_stp.py +0 -236
  139. brainstate/nn/_synapse.py +0 -505
  140. brainstate/nn/_synapse_test.py +0 -131
  141. brainstate/nn/_synaptic_projection.py +0 -423
  142. brainstate/nn/_synouts.py +0 -162
  143. brainstate/nn/_synouts_test.py +0 -57
  144. brainstate/nn/metrics.py +0 -388
  145. brainstate/optim/__init__.py +0 -38
  146. brainstate/optim/_base.py +0 -64
  147. brainstate/optim/_lr_scheduler.py +0 -448
  148. brainstate/optim/_lr_scheduler_test.py +0 -50
  149. brainstate/optim/_optax_optimizer.py +0 -152
  150. brainstate/optim/_optax_optimizer_test.py +0 -53
  151. brainstate/optim/_sgd_optimizer.py +0 -1104
  152. brainstate/random/_random_for_unit.py +0 -52
  153. brainstate/surrogate.py +0 -1957
  154. brainstate/transform.py +0 -23
  155. brainstate/util/caller.py +0 -98
  156. brainstate/util/others.py +0 -540
  157. brainstate/util/pretty_pytree.py +0 -945
  158. brainstate/util/pretty_pytree_test.py +0 -159
  159. brainstate/util/pretty_table.py +0 -2954
  160. brainstate/util/scaling.py +0 -258
  161. brainstate-0.1.10.dist-info/RECORD +0 -130
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -1,73 +1,699 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from absl.testing import absltest
17
- from absl.testing import parameterized
18
-
19
- import brainstate
20
-
21
-
22
- class Test_Normalization(parameterized.TestCase):
23
- @parameterized.product(
24
- fit=[True, False],
25
- )
26
- def test_BatchNorm1d(self, fit):
27
- net = brainstate.nn.BatchNorm1d((3, 10))
28
- brainstate.environ.set(fit=fit)
29
- input = brainstate.random.randn(1, 3, 10)
30
- output = net(input)
31
-
32
- @parameterized.product(
33
- fit=[True, False]
34
- )
35
- def test_BatchNorm2d(self, fit):
36
- net = brainstate.nn.BatchNorm2d([3, 4, 10])
37
- brainstate.environ.set(fit=fit)
38
- input = brainstate.random.randn(1, 3, 4, 10)
39
- output = net(input)
40
-
41
- @parameterized.product(
42
- fit=[True, False]
43
- )
44
- def test_BatchNorm3d(self, fit):
45
- net = brainstate.nn.BatchNorm3d([3, 4, 5, 10])
46
- brainstate.environ.set(fit=fit)
47
- input = brainstate.random.randn(1, 3, 4, 5, 10)
48
- output = net(input)
49
-
50
- # @parameterized.product(
51
- # normalized_shape=(10, [5, 10])
52
- # )
53
- # def test_LayerNorm(self, normalized_shape):
54
- # net = brainstate.nn.LayerNorm(normalized_shape, )
55
- # input = brainstate.random.randn(20, 5, 10)
56
- # output = net(input)
57
- #
58
- # @parameterized.product(
59
- # num_groups=[1, 2, 3, 6]
60
- # )
61
- # def test_GroupNorm(self, num_groups):
62
- # input = brainstate.random.randn(20, 10, 10, 6)
63
- # net = brainstate.nn.GroupNorm(num_groups=num_groups, num_channels=6, )
64
- # output = net(input)
65
- #
66
- # def test_InstanceNorm(self):
67
- # input = brainstate.random.randn(20, 10, 10, 6)
68
- # net = brainstate.nn.InstanceNorm(num_channels=6, )
69
- # output = net(input)
70
-
71
-
72
- if __name__ == '__main__':
73
- absltest.main()
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from absl.testing import absltest
17
+ from absl.testing import parameterized
18
+ import jax.numpy as jnp
19
+ import numpy as np
20
+
21
+ import brainstate
22
+
23
+
24
+ class TestBatchNorm0d(parameterized.TestCase):
25
+ """Test BatchNorm0d with various configurations."""
26
+
27
+ @parameterized.product(
28
+ fit=[True, False],
29
+ feature_axis=[-1, 0],
30
+ track_running_stats=[True, False],
31
+ )
32
+ def test_batchnorm0d_with_batch(self, fit, feature_axis, track_running_stats):
33
+ """Test BatchNorm0d with batched input."""
34
+ batch_size = 8
35
+ channels = 10
36
+
37
+ # Channel last: (batch, channels)
38
+ if feature_axis == -1:
39
+ in_size = (channels,)
40
+ input_shape = (batch_size, channels)
41
+ # Channel first: (batch, channels) - same for 0D
42
+ else:
43
+ in_size = (channels,)
44
+ input_shape = (batch_size, channels)
45
+
46
+ # affine can only be True when track_running_stats is True
47
+ affine = track_running_stats
48
+
49
+ net = brainstate.nn.BatchNorm0d(
50
+ in_size,
51
+ feature_axis=feature_axis,
52
+ track_running_stats=track_running_stats,
53
+ affine=affine
54
+ )
55
+ brainstate.environ.set(fit=fit)
56
+
57
+ x = brainstate.random.randn(*input_shape)
58
+ output = net(x)
59
+
60
+ # Check output shape matches input
61
+ self.assertEqual(output.shape, input_shape)
62
+
63
+ # Check that output has approximately zero mean and unit variance when fitting
64
+ if fit and track_running_stats:
65
+ # Stats should be computed along batch dimension
66
+ mean = jnp.mean(output, axis=0)
67
+ var = jnp.var(output, axis=0)
68
+ np.testing.assert_allclose(mean, 0.0, atol=1e-5)
69
+ np.testing.assert_allclose(var, 1.0, atol=1e-1)
70
+
71
+ def test_batchnorm0d_without_batch(self):
72
+ """Test BatchNorm0d without batching."""
73
+ channels = 10
74
+ in_size = (channels,)
75
+
76
+ net = brainstate.nn.BatchNorm0d(in_size, track_running_stats=True)
77
+ brainstate.environ.set(fit=False) # Use running stats
78
+
79
+ # Run with batch first to populate running stats
80
+ brainstate.environ.set(fit=True)
81
+ x_batch = brainstate.random.randn(16, channels)
82
+ _ = net(x_batch)
83
+
84
+ # Now test without batch
85
+ brainstate.environ.set(fit=False)
86
+ x_single = brainstate.random.randn(channels)
87
+ output = net(x_single)
88
+
89
+ self.assertEqual(output.shape, (channels,))
90
+
91
+ def test_batchnorm0d_affine(self):
92
+ """Test BatchNorm0d with and without affine parameters."""
93
+ channels = 10
94
+ in_size = (channels,)
95
+
96
+ # With affine
97
+ net_affine = brainstate.nn.BatchNorm0d(in_size, affine=True)
98
+ self.assertIsNotNone(net_affine.weight)
99
+
100
+ # Without affine (track_running_stats must be False)
101
+ net_no_affine = brainstate.nn.BatchNorm0d(
102
+ in_size, affine=False, track_running_stats=False
103
+ )
104
+ self.assertIsNone(net_no_affine.weight)
105
+
106
+
107
+ class TestBatchNorm1d(parameterized.TestCase):
108
+ """Test BatchNorm1d with various configurations."""
109
+
110
+ @parameterized.product(
111
+ fit=[True, False],
112
+ feature_axis=[-1, 0],
113
+ track_running_stats=[True, False],
114
+ )
115
+ def test_batchnorm1d_with_batch(self, fit, feature_axis, track_running_stats):
116
+ """Test BatchNorm1d with batched input."""
117
+ batch_size = 8
118
+ length = 20
119
+ channels = 10
120
+
121
+ # Channel last: (batch, length, channels)
122
+ if feature_axis == -1:
123
+ in_size = (length, channels)
124
+ input_shape = (batch_size, length, channels)
125
+ feature_axis_param = -1
126
+ # Channel first: (batch, channels, length)
127
+ else:
128
+ in_size = (channels, length)
129
+ input_shape = (batch_size, channels, length)
130
+ feature_axis_param = 0
131
+
132
+ # affine can only be True when track_running_stats is True
133
+ affine = track_running_stats
134
+
135
+ net = brainstate.nn.BatchNorm1d(
136
+ in_size,
137
+ feature_axis=feature_axis_param,
138
+ track_running_stats=track_running_stats,
139
+ affine=affine
140
+ )
141
+ brainstate.environ.set(fit=fit)
142
+
143
+ x = brainstate.random.randn(*input_shape)
144
+ output = net(x)
145
+
146
+ # Check output shape matches input
147
+ self.assertEqual(output.shape, input_shape)
148
+
149
+ def test_batchnorm1d_without_batch(self):
150
+ """Test BatchNorm1d without batching."""
151
+ length = 20
152
+ channels = 10
153
+ in_size = (length, channels)
154
+
155
+ net = brainstate.nn.BatchNorm1d(in_size, track_running_stats=True)
156
+
157
+ # Populate running stats first
158
+ brainstate.environ.set(fit=True)
159
+ x_batch = brainstate.random.randn(8, length, channels)
160
+ _ = net(x_batch)
161
+
162
+ # Test without batch
163
+ brainstate.environ.set(fit=False)
164
+ x_single = brainstate.random.randn(length, channels)
165
+ output = net(x_single)
166
+
167
+ self.assertEqual(output.shape, (length, channels))
168
+
169
+ @parameterized.product(
170
+ feature_axis=[-1, 0],
171
+ )
172
+ def test_batchnorm1d_channel_consistency(self, feature_axis):
173
+ """Test that normalization is consistent across different channel configurations."""
174
+ batch_size = 16
175
+ length = 20
176
+ channels = 10
177
+
178
+ if feature_axis == -1:
179
+ in_size = (length, channels)
180
+ input_shape = (batch_size, length, channels)
181
+ else:
182
+ in_size = (channels, length)
183
+ input_shape = (batch_size, channels, length)
184
+
185
+ net = brainstate.nn.BatchNorm1d(in_size, feature_axis=feature_axis)
186
+ brainstate.environ.set(fit=True)
187
+
188
+ x = brainstate.random.randn(*input_shape)
189
+ output = net(x)
190
+
191
+ # Output should have same shape as input
192
+ self.assertEqual(output.shape, input_shape)
193
+
194
+
195
+ class TestBatchNorm2d(parameterized.TestCase):
196
+ """Test BatchNorm2d with various configurations."""
197
+
198
+ @parameterized.product(
199
+ fit=[True, False],
200
+ feature_axis=[-1, 0],
201
+ track_running_stats=[True, False],
202
+ )
203
+ def test_batchnorm2d_with_batch(self, fit, feature_axis, track_running_stats):
204
+ """Test BatchNorm2d with batched input (images)."""
205
+ batch_size = 4
206
+ height, width = 28, 28
207
+ channels = 3
208
+
209
+ # Channel last: (batch, height, width, channels)
210
+ if feature_axis == -1:
211
+ in_size = (height, width, channels)
212
+ input_shape = (batch_size, height, width, channels)
213
+ feature_axis_param = -1
214
+ # Channel first: (batch, channels, height, width)
215
+ else:
216
+ in_size = (channels, height, width)
217
+ input_shape = (batch_size, channels, height, width)
218
+ feature_axis_param = 0
219
+
220
+ # affine can only be True when track_running_stats is True
221
+ affine = track_running_stats
222
+
223
+ net = brainstate.nn.BatchNorm2d(
224
+ in_size,
225
+ feature_axis=feature_axis_param,
226
+ track_running_stats=track_running_stats,
227
+ affine=affine
228
+ )
229
+ brainstate.environ.set(fit=fit)
230
+
231
+ x = brainstate.random.randn(*input_shape)
232
+ output = net(x)
233
+
234
+ # Check output shape matches input
235
+ self.assertEqual(output.shape, input_shape)
236
+
237
+ # Check normalization properties during training
238
+ if fit and track_running_stats:
239
+ # For channel last: normalize over (batch, height, width)
240
+ # For channel first: normalize over (batch, height, width)
241
+ if feature_axis == -1:
242
+ axes = (0, 1, 2)
243
+ else:
244
+ axes = (0, 2, 3)
245
+
246
+ mean = jnp.mean(output, axis=axes)
247
+ var = jnp.var(output, axis=axes)
248
+
249
+ # Each channel should be approximately normalized
250
+ np.testing.assert_allclose(mean, 0.0, atol=1e-5)
251
+ np.testing.assert_allclose(var, 1.0, atol=1e-1)
252
+
253
+ def test_batchnorm2d_without_batch(self):
254
+ """Test BatchNorm2d without batching."""
255
+ height, width = 28, 28
256
+ channels = 3
257
+ in_size = (height, width, channels)
258
+
259
+ net = brainstate.nn.BatchNorm2d(in_size, track_running_stats=True)
260
+
261
+ # Populate running stats
262
+ brainstate.environ.set(fit=True)
263
+ x_batch = brainstate.random.randn(8, height, width, channels)
264
+ _ = net(x_batch)
265
+
266
+ # Test without batch
267
+ brainstate.environ.set(fit=False)
268
+ x_single = brainstate.random.randn(height, width, channels)
269
+ output = net(x_single)
270
+
271
+ self.assertEqual(output.shape, (height, width, channels))
272
+
273
+
274
+ class TestBatchNorm3d(parameterized.TestCase):
275
+ """Test BatchNorm3d with various configurations."""
276
+
277
+ @parameterized.product(
278
+ fit=[True, False],
279
+ feature_axis=[-1, 0],
280
+ track_running_stats=[True, False],
281
+ )
282
+ def test_batchnorm3d_with_batch(self, fit, feature_axis, track_running_stats):
283
+ """Test BatchNorm3d with batched input (volumes)."""
284
+ batch_size = 2
285
+ depth, height, width = 8, 16, 16
286
+ channels = 2
287
+
288
+ # Channel last: (batch, depth, height, width, channels)
289
+ if feature_axis == -1:
290
+ in_size = (depth, height, width, channels)
291
+ input_shape = (batch_size, depth, height, width, channels)
292
+ feature_axis_param = -1
293
+ # Channel first: (batch, channels, depth, height, width)
294
+ else:
295
+ in_size = (channels, depth, height, width)
296
+ input_shape = (batch_size, channels, depth, height, width)
297
+ feature_axis_param = 0
298
+
299
+ # affine can only be True when track_running_stats is True
300
+ affine = track_running_stats
301
+
302
+ net = brainstate.nn.BatchNorm3d(
303
+ in_size,
304
+ feature_axis=feature_axis_param,
305
+ track_running_stats=track_running_stats,
306
+ affine=affine
307
+ )
308
+ brainstate.environ.set(fit=fit)
309
+
310
+ x = brainstate.random.randn(*input_shape)
311
+ output = net(x)
312
+
313
+ # Check output shape matches input
314
+ self.assertEqual(output.shape, input_shape)
315
+
316
+ def test_batchnorm3d_without_batch(self):
317
+ """Test BatchNorm3d without batching."""
318
+ depth, height, width = 8, 16, 16
319
+ channels = 2
320
+ in_size = (depth, height, width, channels)
321
+
322
+ net = brainstate.nn.BatchNorm3d(in_size, track_running_stats=True)
323
+
324
+ # Populate running stats
325
+ brainstate.environ.set(fit=True)
326
+ x_batch = brainstate.random.randn(4, depth, height, width, channels)
327
+ _ = net(x_batch)
328
+
329
+ # Test without batch
330
+ brainstate.environ.set(fit=False)
331
+ x_single = brainstate.random.randn(depth, height, width, channels)
332
+ output = net(x_single)
333
+
334
+ self.assertEqual(output.shape, (depth, height, width, channels))
335
+
336
+
337
+ class TestLayerNorm(parameterized.TestCase):
338
+ """Test LayerNorm with various configurations."""
339
+
340
+ @parameterized.product(
341
+ reduction_axes=[(-1,), (-2, -1), (-3, -2, -1)],
342
+ use_bias=[True, False],
343
+ use_scale=[True, False],
344
+ )
345
+ def test_layernorm_basic(self, reduction_axes, use_bias, use_scale):
346
+ """Test LayerNorm with different reduction axes."""
347
+ in_size = (10, 20, 30)
348
+
349
+ net = brainstate.nn.LayerNorm(
350
+ in_size,
351
+ reduction_axes=reduction_axes,
352
+ use_bias=use_bias,
353
+ use_scale=use_scale,
354
+ )
355
+
356
+ # With batch
357
+ x = brainstate.random.randn(8, 10, 20, 30)
358
+ output = net(x)
359
+ self.assertEqual(output.shape, x.shape)
360
+
361
+ # Check normalization properties
362
+ mean = jnp.mean(output, axis=tuple(i + 1 for i in range(len(in_size))
363
+ if i - len(in_size) in reduction_axes))
364
+ var = jnp.var(output, axis=tuple(i + 1 for i in range(len(in_size))
365
+ if i - len(in_size) in reduction_axes))
366
+
367
+ def test_layernorm_2d_features(self):
368
+ """Test LayerNorm on 2D features (like in transformers)."""
369
+ seq_length = 50
370
+ hidden_dim = 128
371
+ batch_size = 16
372
+
373
+ in_size = (seq_length, hidden_dim)
374
+ net = brainstate.nn.LayerNorm(in_size, reduction_axes=-1, feature_axes=-1)
375
+
376
+ x = brainstate.random.randn(batch_size, seq_length, hidden_dim)
377
+ output = net(x)
378
+
379
+ self.assertEqual(output.shape, x.shape)
380
+
381
+ # Each position should be normalized independently
382
+ mean = jnp.mean(output, axis=-1)
383
+ var = jnp.var(output, axis=-1)
384
+
385
+ np.testing.assert_allclose(mean, 0.0, atol=1e-5)
386
+ np.testing.assert_allclose(var, 1.0, atol=1e-1)
387
+
388
+ def test_layernorm_without_batch(self):
389
+ """Test LayerNorm without batch dimension."""
390
+ in_size = (10, 20)
391
+ net = brainstate.nn.LayerNorm(in_size, reduction_axes=-1)
392
+
393
+ x = brainstate.random.randn(10, 20)
394
+ output = net(x)
395
+
396
+ self.assertEqual(output.shape, (10, 20))
397
+
398
+ @parameterized.product(
399
+ in_size=[(10,), (10, 20), (10, 20, 30)],
400
+ )
401
+ def test_layernorm_various_dims(self, in_size):
402
+ """Test LayerNorm with various input dimensions."""
403
+ net = brainstate.nn.LayerNorm(in_size)
404
+
405
+ # With batch
406
+ x_with_batch = brainstate.random.randn(8, *in_size)
407
+ output_with_batch = net(x_with_batch)
408
+ self.assertEqual(output_with_batch.shape, x_with_batch.shape)
409
+
410
+
411
+ class TestRMSNorm(parameterized.TestCase):
412
+ """Test RMSNorm with various configurations."""
413
+
414
+ @parameterized.product(
415
+ use_scale=[True, False],
416
+ reduction_axes=[(-1,), (-2, -1)],
417
+ )
418
+ def test_rmsnorm_basic(self, use_scale, reduction_axes):
419
+ """Test RMSNorm with different configurations."""
420
+ in_size = (10, 20)
421
+
422
+ net = brainstate.nn.RMSNorm(
423
+ in_size,
424
+ use_scale=use_scale,
425
+ reduction_axes=reduction_axes,
426
+ )
427
+
428
+ x = brainstate.random.randn(8, 10, 20)
429
+ output = net(x)
430
+
431
+ self.assertEqual(output.shape, x.shape)
432
+
433
+ def test_rmsnorm_transformer_like(self):
434
+ """Test RMSNorm in transformer-like setting."""
435
+ seq_length = 50
436
+ hidden_dim = 128
437
+ batch_size = 16
438
+
439
+ in_size = (seq_length, hidden_dim)
440
+ net = brainstate.nn.RMSNorm(in_size, reduction_axes=-1, feature_axes=-1)
441
+
442
+ x = brainstate.random.randn(batch_size, seq_length, hidden_dim)
443
+ output = net(x)
444
+
445
+ self.assertEqual(output.shape, x.shape)
446
+
447
+ # RMSNorm should have approximately unit RMS (not zero mean)
448
+ rms = jnp.sqrt(jnp.mean(jnp.square(output), axis=-1))
449
+ np.testing.assert_allclose(rms, 1.0, atol=1e-1)
450
+
451
+ def test_rmsnorm_without_batch(self):
452
+ """Test RMSNorm without batch dimension."""
453
+ in_size = (10, 20)
454
+ net = brainstate.nn.RMSNorm(in_size, reduction_axes=-1)
455
+
456
+ x = brainstate.random.randn(10, 20)
457
+ output = net(x)
458
+
459
+ self.assertEqual(output.shape, (10, 20))
460
+
461
+
462
+ class TestGroupNorm(parameterized.TestCase):
463
+ """Test GroupNorm with various configurations."""
464
+
465
+ @parameterized.product(
466
+ num_groups=[1, 2, 4, 8],
467
+ use_bias=[True, False],
468
+ use_scale=[True, False],
469
+ )
470
+ def test_groupnorm_basic(self, num_groups, use_bias, use_scale):
471
+ """Test GroupNorm with different number of groups."""
472
+ channels = 16
473
+ # GroupNorm requires 1D feature axis (just the channel dimension)
474
+ in_size = (channels,)
475
+
476
+ # Check if channels is divisible by num_groups
477
+ if channels % num_groups != 0:
478
+ return
479
+
480
+ net = brainstate.nn.GroupNorm(
481
+ in_size,
482
+ feature_axis=0,
483
+ num_groups=num_groups,
484
+ use_bias=use_bias,
485
+ use_scale=use_scale,
486
+ )
487
+
488
+ # Input needs at least 2D: (height, width, channels) or (batch, channels)
489
+ # Using (batch, channels) format
490
+ x = brainstate.random.randn(4, channels)
491
+ output = net(x)
492
+
493
+ self.assertEqual(output.shape, x.shape)
494
+
495
+ def test_groupnorm_channel_first(self):
496
+ """Test GroupNorm with channel-first format for images."""
497
+ channels = 16
498
+ # GroupNorm requires 1D feature (just channels)
499
+ in_size = (channels,)
500
+
501
+ net = brainstate.nn.GroupNorm(
502
+ in_size,
503
+ feature_axis=0,
504
+ num_groups=4,
505
+ )
506
+
507
+ # Test with image-like data: (batch, height, width, channels)
508
+ x = brainstate.random.randn(4, 32, 32, channels)
509
+ output = net(x)
510
+
511
+ self.assertEqual(output.shape, x.shape)
512
+
513
+ def test_groupnorm_channel_last(self):
514
+ """Test GroupNorm with channel-last format for images."""
515
+ channels = 16
516
+ # GroupNorm requires 1D feature (just channels)
517
+ in_size = (channels,)
518
+
519
+ net = brainstate.nn.GroupNorm(
520
+ in_size,
521
+ feature_axis=0, # feature_axis refers to position in in_size
522
+ num_groups=4,
523
+ )
524
+
525
+ # Test with image-like data: (batch, height, width, channels)
526
+ x = brainstate.random.randn(4, 32, 32, channels)
527
+ output = net(x)
528
+
529
+ self.assertEqual(output.shape, x.shape)
530
+
531
+ def test_groupnorm_equals_layernorm(self):
532
+ """Test that GroupNorm with num_groups=1 equals LayerNorm."""
533
+ channels = 16
534
+ # GroupNorm requires 1D feature
535
+ in_size = (channels,)
536
+
537
+ # GroupNorm with 1 group
538
+ group_norm = brainstate.nn.GroupNorm(
539
+ in_size,
540
+ feature_axis=0,
541
+ num_groups=1,
542
+ )
543
+
544
+ # LayerNorm with same setup
545
+ layer_norm = brainstate.nn.LayerNorm(
546
+ in_size,
547
+ reduction_axes=-1,
548
+ feature_axes=-1,
549
+ )
550
+
551
+ # Use 2D input: (batch, channels)
552
+ x = brainstate.random.randn(8, channels)
553
+
554
+ output_gn = group_norm(x)
555
+ output_ln = layer_norm(x)
556
+
557
+ # Shapes should match
558
+ self.assertEqual(output_gn.shape, output_ln.shape)
559
+
560
+ def test_groupnorm_group_size(self):
561
+ """Test GroupNorm with group_size instead of num_groups."""
562
+ channels = 16
563
+ group_size = 4
564
+ # GroupNorm requires 1D feature
565
+ in_size = (channels,)
566
+
567
+ net = brainstate.nn.GroupNorm(
568
+ in_size,
569
+ feature_axis=0,
570
+ num_groups=None,
571
+ group_size=group_size,
572
+ )
573
+
574
+ # Use 2D input: (batch, channels)
575
+ x = brainstate.random.randn(4, channels)
576
+ output = net(x)
577
+
578
+ self.assertEqual(output.shape, x.shape)
579
+ self.assertEqual(net.num_groups, channels // group_size)
580
+
581
+ def test_groupnorm_invalid_groups(self):
582
+ """Test that invalid num_groups raises error."""
583
+ channels = 15 # Not divisible by many numbers
584
+ # GroupNorm requires 1D feature
585
+ in_size = (channels,)
586
+
587
+ # Should raise error if num_groups doesn't divide channels
588
+ with self.assertRaises(ValueError):
589
+ net = brainstate.nn.GroupNorm(
590
+ in_size,
591
+ feature_axis=0,
592
+ num_groups=4, # 15 is not divisible by 4
593
+ )
594
+
595
+
596
+ class TestNormalizationUtilities(parameterized.TestCase):
597
+ """Test utility functions for normalization."""
598
+
599
+ def test_weight_standardization(self):
600
+ """Test weight_standardization function."""
601
+ w = brainstate.random.randn(3, 4, 5, 6)
602
+
603
+ w_std = brainstate.nn.weight_standardization(w, eps=1e-4)
604
+
605
+ self.assertEqual(w_std.shape, w.shape)
606
+
607
+ # Check that standardization works
608
+ # Mean should be close to 0 along non-output axes
609
+ mean = jnp.mean(w_std, axis=(0, 1, 2))
610
+ np.testing.assert_allclose(mean, 0.0, atol=1e-4)
611
+
612
+ def test_weight_standardization_with_gain(self):
613
+ """Test weight_standardization with gain parameter."""
614
+ w = brainstate.random.randn(3, 4, 5, 6)
615
+ gain = jnp.ones((6,))
616
+
617
+ w_std = brainstate.nn.weight_standardization(w, gain=gain)
618
+
619
+ self.assertEqual(w_std.shape, w.shape)
620
+
621
+
622
+ class TestNormalizationEdgeCases(parameterized.TestCase):
623
+ """Test edge cases and error conditions."""
624
+
625
+ def test_batchnorm_shape_mismatch(self):
626
+ """Test that BatchNorm raises error on shape mismatch."""
627
+ net = brainstate.nn.BatchNorm2d((28, 28, 3))
628
+
629
+ # Wrong shape should raise error
630
+ with self.assertRaises(ValueError):
631
+ x = brainstate.random.randn(4, 32, 32, 3) # Wrong height/width
632
+ _ = net(x)
633
+
634
+ def test_batchnorm_without_track_and_affine(self):
635
+ """Test that affine=True requires track_running_stats=True."""
636
+ # This should raise an assertion error
637
+ with self.assertRaises(AssertionError):
638
+ net = brainstate.nn.BatchNorm2d(
639
+ (28, 28, 3),
640
+ track_running_stats=False,
641
+ affine=True # Requires track_running_stats=True
642
+ )
643
+
644
+ def test_groupnorm_both_params(self):
645
+ """Test that GroupNorm raises error when both num_groups and group_size are specified."""
646
+ with self.assertRaises(ValueError):
647
+ net = brainstate.nn.GroupNorm(
648
+ (32, 32, 16),
649
+ num_groups=4,
650
+ group_size=4, # Can't specify both
651
+ )
652
+
653
+ def test_groupnorm_neither_param(self):
654
+ """Test that GroupNorm raises error when neither num_groups nor group_size are specified."""
655
+ with self.assertRaises(ValueError):
656
+ net = brainstate.nn.GroupNorm(
657
+ (32, 32, 16),
658
+ num_groups=None,
659
+ group_size=None, # Must specify one
660
+ )
661
+
662
+
663
+ class TestNormalizationConsistency(parameterized.TestCase):
664
+ """Test consistency across different batch sizes and modes."""
665
+
666
+ def test_batchnorm2d_consistency_across_batches(self):
667
+ """Test that BatchNorm2d behaves consistently across different batch sizes."""
668
+ in_size = (28, 28, 3)
669
+ net = brainstate.nn.BatchNorm2d(in_size, track_running_stats=True)
670
+
671
+ # Train on larger batch
672
+ brainstate.environ.set(fit=True)
673
+ x_large = brainstate.random.randn(32, 28, 28, 3)
674
+ _ = net(x_large)
675
+
676
+ # Test on smaller batch
677
+ brainstate.environ.set(fit=False)
678
+ x_small = brainstate.random.randn(4, 28, 28, 3)
679
+ output = net(x_small)
680
+
681
+ self.assertEqual(output.shape, x_small.shape)
682
+
683
+ def test_layernorm_consistency(self):
684
+ """Test that LayerNorm produces consistent results."""
685
+ in_size = (10, 20)
686
+ net = brainstate.nn.LayerNorm(in_size)
687
+
688
+ x = brainstate.random.randn(8, 10, 20)
689
+
690
+ # Run twice
691
+ output1 = net(x)
692
+ output2 = net(x)
693
+
694
+ # Should be deterministic
695
+ np.testing.assert_allclose(output1, output2)
696
+
697
+
698
+ if __name__ == '__main__':
699
+ absltest.main()