brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163) hide show
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +95 -29
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.9.dist-info/RECORD +0 -130
  161. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -15,58 +15,684 @@
15
15
 
16
16
  from absl.testing import absltest
17
17
  from absl.testing import parameterized
18
+ import jax.numpy as jnp
19
+ import numpy as np
18
20
 
19
21
  import brainstate
20
22
 
21
23
 
22
- class Test_Normalization(parameterized.TestCase):
24
class TestBatchNorm0d(parameterized.TestCase):
    """Test BatchNorm0d with various configurations."""

    @parameterized.product(
        fit=[True, False],
        feature_axis=[-1, 0],
        track_running_stats=[True, False],
    )
    def test_batchnorm0d_with_batch(self, fit, feature_axis, track_running_stats):
        """Test BatchNorm0d with batched input.

        For 0-D features ``in_size`` is just ``(channels,)``, so both
        ``feature_axis=-1`` and ``feature_axis=0`` address the same single
        axis and the input layout is ``(batch, channels)`` in either case.
        """
        batch_size = 8
        channels = 10
        in_size = (channels,)
        input_shape = (batch_size, channels)

        # affine can only be True when track_running_stats is True
        affine = track_running_stats

        net = brainstate.nn.BatchNorm0d(
            in_size,
            feature_axis=feature_axis,
            track_running_stats=track_running_stats,
            affine=affine,
        )
        brainstate.environ.set(fit=fit)

        x = brainstate.random.randn(*input_shape)
        output = net(x)

        # Output shape must match the input shape.
        self.assertEqual(output.shape, input_shape)

        # With batch statistics, every channel should be approximately
        # zero-mean / unit-variance along the batch axis.
        if fit and track_running_stats:
            mean = jnp.mean(output, axis=0)
            var = jnp.var(output, axis=0)
            np.testing.assert_allclose(mean, 0.0, atol=1e-5)
            np.testing.assert_allclose(var, 1.0, atol=1e-1)

    def test_batchnorm0d_without_batch(self):
        """Test BatchNorm0d on a single unbatched sample in evaluation mode."""
        channels = 10
        in_size = (channels,)

        net = brainstate.nn.BatchNorm0d(in_size, track_running_stats=True)

        # One fitting step on a batch populates the running statistics.
        brainstate.environ.set(fit=True)
        _ = net(brainstate.random.randn(16, channels))

        # Evaluate a single sample using the running statistics.
        brainstate.environ.set(fit=False)
        x_single = brainstate.random.randn(channels)
        output = net(x_single)

        self.assertEqual(output.shape, (channels,))
        # In evaluation mode the result must not depend on whether the
        # sample carries a leading batch axis.
        np.testing.assert_allclose(
            output, net(x_single[None])[0], rtol=1e-5, atol=1e-6
        )

    def test_batchnorm0d_affine(self):
        """Test BatchNorm0d with and without affine parameters."""
        channels = 10
        in_size = (channels,)

        # With affine: a learnable weight is created.
        net_affine = brainstate.nn.BatchNorm0d(in_size, affine=True)
        self.assertIsNotNone(net_affine.weight)

        # Without affine (track_running_stats must be False): no weight.
        net_no_affine = brainstate.nn.BatchNorm0d(
            in_size, affine=False, track_running_stats=False
        )
        self.assertIsNone(net_no_affine.weight)
105
+
106
+
107
class TestBatchNorm1d(parameterized.TestCase):
    """Test BatchNorm1d with various configurations."""

    @parameterized.product(
        fit=[True, False],
        feature_axis=[-1, 0],
        track_running_stats=[True, False],
    )
    def test_batchnorm1d_with_batch(self, fit, feature_axis, track_running_stats):
        """Test BatchNorm1d with batched sequence input."""
        batch_size = 8
        length = 20
        channels = 10

        if feature_axis == -1:
            # Channel last: (batch, length, channels).
            in_size = (length, channels)
            stat_axes = (0, 1)
        else:
            # Channel first: (batch, channels, length).
            in_size = (channels, length)
            stat_axes = (0, 2)
        input_shape = (batch_size, *in_size)

        # affine can only be True when track_running_stats is True
        affine = track_running_stats

        net = brainstate.nn.BatchNorm1d(
            in_size,
            feature_axis=feature_axis,
            track_running_stats=track_running_stats,
            affine=affine,
        )
        brainstate.environ.set(fit=fit)

        x = brainstate.random.randn(*input_shape)
        output = net(x)

        # Output shape must match the input shape.
        self.assertEqual(output.shape, input_shape)

        # With batch statistics, each channel should be approximately
        # zero-mean / unit-variance over the batch and length axes.
        if fit and track_running_stats:
            mean = jnp.mean(output, axis=stat_axes)
            var = jnp.var(output, axis=stat_axes)
            np.testing.assert_allclose(mean, 0.0, atol=1e-5)
            np.testing.assert_allclose(var, 1.0, atol=1e-1)

    def test_batchnorm1d_without_batch(self):
        """Test BatchNorm1d on a single unbatched sample in evaluation mode."""
        length = 20
        channels = 10
        in_size = (length, channels)

        net = brainstate.nn.BatchNorm1d(in_size, track_running_stats=True)

        # One fitting step on a batch populates the running statistics.
        brainstate.environ.set(fit=True)
        _ = net(brainstate.random.randn(8, length, channels))

        # Evaluate a single sample using the running statistics.
        brainstate.environ.set(fit=False)
        x_single = brainstate.random.randn(length, channels)
        output = net(x_single)

        self.assertEqual(output.shape, (length, channels))
        # In evaluation mode the result must not depend on whether the
        # sample carries a leading batch axis.
        np.testing.assert_allclose(
            output, net(x_single[None])[0], rtol=1e-5, atol=1e-6
        )

    def test_batchnorm1d_channel_consistency(self):
        """Channel-first and channel-last layouts must normalize identically.

        The same data is fed once as ``(batch, length, channels)`` and once
        transposed to ``(batch, channels, length)``; after undoing the
        transpose, both outputs must agree.
        """
        batch_size = 16
        length = 20
        channels = 10

        net_last = brainstate.nn.BatchNorm1d((length, channels), feature_axis=-1)
        net_first = brainstate.nn.BatchNorm1d((channels, length), feature_axis=0)
        brainstate.environ.set(fit=True)

        x_last = brainstate.random.randn(batch_size, length, channels)
        # Same values with the channel axis moved in front of the length axis.
        x_first = jnp.transpose(x_last, (0, 2, 1))

        out_last = net_last(x_last)
        out_first = net_first(x_first)

        # Each output keeps the shape of its own input.
        self.assertEqual(out_last.shape, x_last.shape)
        self.assertEqual(out_first.shape, x_first.shape)
        np.testing.assert_allclose(
            out_last, jnp.transpose(out_first, (0, 2, 1)), rtol=1e-5, atol=1e-5
        )
193
+
194
+
195
class TestBatchNorm2d(parameterized.TestCase):
    """Test BatchNorm2d with various configurations."""

    @parameterized.product(
        fit=[True, False],
        feature_axis=[-1, 0],
        track_running_stats=[True, False],
    )
    def test_batchnorm2d_with_batch(self, fit, feature_axis, track_running_stats):
        """Test BatchNorm2d with batched input (images)."""
        batch_size = 4
        height, width = 28, 28
        channels = 3

        if feature_axis == -1:
            # Channel last: (batch, height, width, channels).
            in_size = (height, width, channels)
            stat_axes = (0, 1, 2)
        else:
            # Channel first: (batch, channels, height, width).
            in_size = (channels, height, width)
            stat_axes = (0, 2, 3)
        input_shape = (batch_size, *in_size)

        # affine can only be True when track_running_stats is True
        affine = track_running_stats

        net = brainstate.nn.BatchNorm2d(
            in_size,
            feature_axis=feature_axis,
            track_running_stats=track_running_stats,
            affine=affine,
        )
        brainstate.environ.set(fit=fit)

        x = brainstate.random.randn(*input_shape)
        output = net(x)

        # Output shape must match the input shape.
        self.assertEqual(output.shape, input_shape)

        # With batch statistics, each channel should be approximately
        # zero-mean / unit-variance over the batch and spatial axes.
        if fit and track_running_stats:
            mean = jnp.mean(output, axis=stat_axes)
            var = jnp.var(output, axis=stat_axes)
            np.testing.assert_allclose(mean, 0.0, atol=1e-5)
            np.testing.assert_allclose(var, 1.0, atol=1e-1)

    def test_batchnorm2d_without_batch(self):
        """Test BatchNorm2d on a single unbatched image in evaluation mode."""
        height, width = 28, 28
        channels = 3
        in_size = (height, width, channels)

        net = brainstate.nn.BatchNorm2d(in_size, track_running_stats=True)

        # One fitting step on a batch populates the running statistics.
        brainstate.environ.set(fit=True)
        _ = net(brainstate.random.randn(8, height, width, channels))

        # Evaluate a single sample using the running statistics.
        brainstate.environ.set(fit=False)
        x_single = brainstate.random.randn(height, width, channels)
        output = net(x_single)

        self.assertEqual(output.shape, (height, width, channels))
        # In evaluation mode the result must not depend on whether the
        # sample carries a leading batch axis.
        np.testing.assert_allclose(
            output, net(x_single[None])[0], rtol=1e-5, atol=1e-6
        )
272
+
273
+
274
class TestBatchNorm3d(parameterized.TestCase):
    """Test BatchNorm3d with various configurations."""

    @parameterized.product(
        fit=[True, False],
        feature_axis=[-1, 0],
        track_running_stats=[True, False],
    )
    def test_batchnorm3d_with_batch(self, fit, feature_axis, track_running_stats):
        """Test BatchNorm3d with batched input (volumes)."""
        batch_size = 2
        depth, height, width = 8, 16, 16
        channels = 2

        if feature_axis == -1:
            # Channel last: (batch, depth, height, width, channels).
            in_size = (depth, height, width, channels)
            stat_axes = (0, 1, 2, 3)
        else:
            # Channel first: (batch, channels, depth, height, width).
            in_size = (channels, depth, height, width)
            stat_axes = (0, 2, 3, 4)
        input_shape = (batch_size, *in_size)

        # affine can only be True when track_running_stats is True
        affine = track_running_stats

        net = brainstate.nn.BatchNorm3d(
            in_size,
            feature_axis=feature_axis,
            track_running_stats=track_running_stats,
            affine=affine,
        )
        brainstate.environ.set(fit=fit)

        x = brainstate.random.randn(*input_shape)
        output = net(x)

        # Output shape must match the input shape.
        self.assertEqual(output.shape, input_shape)

        # With batch statistics, each channel should be approximately
        # zero-mean / unit-variance over the batch and volume axes.
        if fit and track_running_stats:
            mean = jnp.mean(output, axis=stat_axes)
            var = jnp.var(output, axis=stat_axes)
            np.testing.assert_allclose(mean, 0.0, atol=1e-5)
            np.testing.assert_allclose(var, 1.0, atol=1e-1)

    def test_batchnorm3d_without_batch(self):
        """Test BatchNorm3d on a single unbatched volume in evaluation mode."""
        depth, height, width = 8, 16, 16
        channels = 2
        in_size = (depth, height, width, channels)

        net = brainstate.nn.BatchNorm3d(in_size, track_running_stats=True)

        # One fitting step on a batch populates the running statistics.
        brainstate.environ.set(fit=True)
        _ = net(brainstate.random.randn(4, depth, height, width, channels))

        # Evaluate a single sample using the running statistics.
        brainstate.environ.set(fit=False)
        x_single = brainstate.random.randn(depth, height, width, channels)
        output = net(x_single)

        self.assertEqual(output.shape, (depth, height, width, channels))
        # In evaluation mode the result must not depend on whether the
        # sample carries a leading batch axis.
        np.testing.assert_allclose(
            output, net(x_single[None])[0], rtol=1e-5, atol=1e-6
        )
335
+
336
+
337
class TestLayerNorm(parameterized.TestCase):
    """Test LayerNorm with various configurations."""

    @parameterized.product(
        reduction_axes=[(-1,), (-2, -1), (-3, -2, -1)],
        use_bias=[True, False],
        use_scale=[True, False],
    )
    def test_layernorm_basic(self, reduction_axes, use_bias, use_scale):
        """Test LayerNorm with different reduction axes."""
        in_size = (10, 20, 30)

        net = brainstate.nn.LayerNorm(
            in_size,
            reduction_axes=reduction_axes,
            use_bias=use_bias,
            use_scale=use_scale,
        )

        # With a leading batch axis.
        x = brainstate.random.randn(8, *in_size)
        output = net(x)
        self.assertEqual(output.shape, x.shape)

        # reduction_axes are negative, so they address the same trailing axes
        # of the batched output as of in_size. Every slice reduced over them
        # should be approximately zero-mean / unit-variance.
        mean = jnp.mean(output, axis=reduction_axes)
        var = jnp.var(output, axis=reduction_axes)
        np.testing.assert_allclose(mean, 0.0, atol=1e-5)
        np.testing.assert_allclose(var, 1.0, atol=1e-1)

    def test_layernorm_2d_features(self):
        """Test LayerNorm on 2D features (like in transformers)."""
        seq_length = 50
        hidden_dim = 128
        batch_size = 16

        in_size = (seq_length, hidden_dim)
        net = brainstate.nn.LayerNorm(in_size, reduction_axes=-1, feature_axes=-1)

        x = brainstate.random.randn(batch_size, seq_length, hidden_dim)
        output = net(x)

        self.assertEqual(output.shape, x.shape)

        # Each position should be normalized independently over hidden_dim.
        mean = jnp.mean(output, axis=-1)
        var = jnp.var(output, axis=-1)

        np.testing.assert_allclose(mean, 0.0, atol=1e-5)
        np.testing.assert_allclose(var, 1.0, atol=1e-1)

    def test_layernorm_without_batch(self):
        """Test LayerNorm without batch dimension."""
        in_size = (10, 20)
        net = brainstate.nn.LayerNorm(in_size, reduction_axes=-1)

        x = brainstate.random.randn(10, 20)
        output = net(x)

        self.assertEqual(output.shape, (10, 20))

    @parameterized.product(
        in_size=[(10,), (10, 20), (10, 20, 30)],
    )
    def test_layernorm_various_dims(self, in_size):
        """Test LayerNorm with various input dimensions."""
        net = brainstate.nn.LayerNorm(in_size)

        # With a leading batch axis.
        x_with_batch = brainstate.random.randn(8, *in_size)
        output_with_batch = net(x_with_batch)
        self.assertEqual(output_with_batch.shape, x_with_batch.shape)
409
+
410
+
411
class TestRMSNorm(parameterized.TestCase):
    """Test RMSNorm with various configurations."""

    @parameterized.product(
        use_scale=[True, False],
        reduction_axes=[(-1,), (-2, -1)],
    )
    def test_rmsnorm_basic(self, use_scale, reduction_axes):
        """Test RMSNorm with different configurations."""
        in_size = (10, 20)

        net = brainstate.nn.RMSNorm(
            in_size,
            use_scale=use_scale,
            reduction_axes=reduction_axes,
        )

        x = brainstate.random.randn(8, *in_size)
        output = net(x)

        self.assertEqual(output.shape, x.shape)

        # RMSNorm rescales (it does not center), so the root-mean-square over
        # the reduced axes should be approximately one.
        rms = jnp.sqrt(jnp.mean(jnp.square(output), axis=reduction_axes))
        np.testing.assert_allclose(rms, 1.0, atol=1e-1)

    def test_rmsnorm_transformer_like(self):
        """Test RMSNorm in transformer-like setting."""
        seq_length = 50
        hidden_dim = 128
        batch_size = 16

        in_size = (seq_length, hidden_dim)
        net = brainstate.nn.RMSNorm(in_size, reduction_axes=-1, feature_axes=-1)

        x = brainstate.random.randn(batch_size, seq_length, hidden_dim)
        output = net(x)

        self.assertEqual(output.shape, x.shape)

        # RMSNorm should have approximately unit RMS (not zero mean).
        rms = jnp.sqrt(jnp.mean(jnp.square(output), axis=-1))
        np.testing.assert_allclose(rms, 1.0, atol=1e-1)

    def test_rmsnorm_without_batch(self):
        """Test RMSNorm without batch dimension."""
        in_size = (10, 20)
        net = brainstate.nn.RMSNorm(in_size, reduction_axes=-1)

        x = brainstate.random.randn(10, 20)
        output = net(x)

        self.assertEqual(output.shape, (10, 20))
460
+
461
+
462
class TestGroupNorm(parameterized.TestCase):
    """Test GroupNorm with various configurations."""

    @parameterized.product(
        num_groups=[1, 2, 4, 8],
        use_bias=[True, False],
        use_scale=[True, False],
    )
    def test_groupnorm_basic(self, num_groups, use_bias, use_scale):
        """Test GroupNorm with different numbers of groups."""
        # 16 is divisible by every num_groups value above, so every
        # parameter combination is actually exercised (nothing is skipped).
        channels = 16
        # in_size describes only the channel dimension; feature_axis=0
        # indexes into in_size.
        in_size = (channels,)

        net = brainstate.nn.GroupNorm(
            in_size,
            feature_axis=0,
            num_groups=num_groups,
            use_bias=use_bias,
            use_scale=use_scale,
        )

        # (batch, channels) input.
        x = brainstate.random.randn(4, channels)
        output = net(x)

        self.assertEqual(output.shape, x.shape)

    def test_groupnorm_channel_first(self):
        """Test GroupNorm on a non-square image-like input.

        NOTE(review): despite the test name, the input is channel-last
        ``(batch, height, width, channels)``; previously this test duplicated
        ``test_groupnorm_channel_last`` exactly. A genuine channel-first case
        likely needs a different ``in_size``/``feature_axis`` setup -- TODO
        confirm against the GroupNorm API.
        """
        channels = 16
        # in_size describes only the channel dimension.
        in_size = (channels,)

        net = brainstate.nn.GroupNorm(
            in_size,
            feature_axis=0,
            num_groups=4,
        )

        # Non-square spatial dims so this case differs from the channel-last test.
        x = brainstate.random.randn(4, 12, 20, channels)
        output = net(x)

        self.assertEqual(output.shape, x.shape)

    def test_groupnorm_channel_last(self):
        """Test GroupNorm with channel-last format for images."""
        channels = 16
        # in_size describes only the channel dimension.
        in_size = (channels,)

        net = brainstate.nn.GroupNorm(
            in_size,
            feature_axis=0,  # feature_axis refers to position in in_size
            num_groups=4,
        )

        # Image-like data: (batch, height, width, channels).
        x = brainstate.random.randn(4, 32, 32, channels)
        output = net(x)

        self.assertEqual(output.shape, x.shape)

    def test_groupnorm_equals_layernorm(self):
        """Test that GroupNorm with num_groups=1 equals LayerNorm."""
        channels = 16
        in_size = (channels,)

        # A single group spanning all channels.
        group_norm = brainstate.nn.GroupNorm(
            in_size,
            feature_axis=0,
            num_groups=1,
        )

        # LayerNorm normalizing over the same channel axis.
        layer_norm = brainstate.nn.LayerNorm(
            in_size,
            reduction_axes=-1,
            feature_axes=-1,
        )

        # (batch, channels) input.
        x = brainstate.random.randn(8, channels)

        output_gn = group_norm(x)
        output_ln = layer_norm(x)

        self.assertEqual(output_gn.shape, output_ln.shape)
        # The claim under test is numerical equality, not just matching shapes.
        # Tolerance absorbs small epsilon/variance-formula differences.
        np.testing.assert_allclose(output_gn, output_ln, rtol=1e-4, atol=1e-4)

    def test_groupnorm_group_size(self):
        """Test GroupNorm with group_size instead of num_groups."""
        channels = 16
        group_size = 4
        in_size = (channels,)

        net = brainstate.nn.GroupNorm(
            in_size,
            feature_axis=0,
            num_groups=None,
            group_size=group_size,
        )

        # (batch, channels) input.
        x = brainstate.random.randn(4, channels)
        output = net(x)

        self.assertEqual(output.shape, x.shape)
        # num_groups is derived from channels / group_size.
        self.assertEqual(net.num_groups, channels // group_size)

    def test_groupnorm_invalid_groups(self):
        """Test that a num_groups that does not divide channels raises."""
        channels = 15  # not divisible by 4
        in_size = (channels,)

        with self.assertRaises(ValueError):
            brainstate.nn.GroupNorm(
                in_size,
                feature_axis=0,
                num_groups=4,
            )
594
+
595
+
596
class TestNormalizationUtilities(parameterized.TestCase):
    """Test utility functions for normalization."""

    def test_weight_standardization(self):
        """Test weight_standardization function."""
        w = brainstate.random.randn(3, 4, 5, 6)

        w_std = brainstate.nn.weight_standardization(w, eps=1e-4)

        self.assertEqual(w_std.shape, w.shape)

        # The mean over all non-output axes should be close to zero.
        mean = jnp.mean(w_std, axis=(0, 1, 2))
        np.testing.assert_allclose(mean, 0.0, atol=1e-4)

    def test_weight_standardization_with_gain(self):
        """Test weight_standardization with gain parameter."""
        w = brainstate.random.randn(3, 4, 5, 6)
        gain = jnp.ones((6,))

        w_std = brainstate.nn.weight_standardization(w, gain=gain)

        self.assertEqual(w_std.shape, w.shape)
        # A unit gain must leave the standardized weights unchanged.
        np.testing.assert_allclose(
            w_std, brainstate.nn.weight_standardization(w), rtol=1e-5, atol=1e-6
        )
620
+
621
+
622
class TestNormalizationEdgeCases(parameterized.TestCase):
    """Test edge cases and error conditions."""

    def test_batchnorm_shape_mismatch(self):
        """Test that BatchNorm raises error on shape mismatch."""
        net = brainstate.nn.BatchNorm2d((28, 28, 3))
        # Set the mode explicitly so this test does not depend on whatever
        # `fit` value other tests left in the global environment.
        brainstate.environ.set(fit=True)

        # Build the input outside assertRaises so only the layer call is
        # checked for the ValueError.
        x = brainstate.random.randn(4, 32, 32, 3)  # wrong height/width
        with self.assertRaises(ValueError):
            net(x)

    def test_batchnorm_without_track_and_affine(self):
        """Test that affine=True requires track_running_stats=True."""
        with self.assertRaises(AssertionError):
            brainstate.nn.BatchNorm2d(
                (28, 28, 3),
                track_running_stats=False,
                affine=True,  # requires track_running_stats=True
            )

    def test_groupnorm_both_params(self):
        """Test that GroupNorm raises error when both num_groups and group_size are specified."""
        # Use the same (known-valid) in_size/feature_axis setup as the other
        # GroupNorm tests so the error can only come from the conflicting
        # num_groups/group_size arguments.
        with self.assertRaises(ValueError):
            brainstate.nn.GroupNorm(
                (16,),
                feature_axis=0,
                num_groups=4,
                group_size=4,  # can't specify both
            )

    def test_groupnorm_neither_param(self):
        """Test that GroupNorm raises error when neither num_groups nor group_size are specified."""
        # Known-valid in_size/feature_axis setup, as above, so the error
        # can only come from the missing num_groups/group_size.
        with self.assertRaises(ValueError):
            brainstate.nn.GroupNorm(
                (16,),
                feature_axis=0,
                num_groups=None,
                group_size=None,  # must specify one
            )
661
+
662
+
663
class TestNormalizationConsistency(parameterized.TestCase):
    """Test consistency across different batch sizes and modes."""

    def test_batchnorm2d_consistency_across_batches(self):
        """Test that BatchNorm2d behaves consistently across different batch sizes."""
        in_size = (28, 28, 3)
        net = brainstate.nn.BatchNorm2d(in_size, track_running_stats=True)

        # Fit on a larger batch to populate the running statistics.
        brainstate.environ.set(fit=True)
        _ = net(brainstate.random.randn(32, 28, 28, 3))

        # Evaluate on a smaller batch.
        brainstate.environ.set(fit=False)
        x_small = brainstate.random.randn(4, 28, 28, 3)
        output = net(x_small)

        self.assertEqual(output.shape, x_small.shape)
        # With running statistics each sample is normalized independently of
        # the rest of the batch, so evaluating a prefix alone must match.
        np.testing.assert_allclose(
            output[:1], net(x_small[:1]), rtol=1e-5, atol=1e-6
        )

    def test_layernorm_consistency(self):
        """Test that LayerNorm produces consistent results."""
        in_size = (10, 20)
        net = brainstate.nn.LayerNorm(in_size)

        x = brainstate.random.randn(8, 10, 20)

        # Two calls on the same input must be deterministic.
        output1 = net(x)
        output2 = net(x)

        np.testing.assert_allclose(output1, output2)
70
696
 
71
697
 
72
698
  if __name__ == '__main__':