brainstate-0.2.1-py2.py3-none-any.whl → brainstate-0.2.2-py2.py3-none-any.whl

This diff compares publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in the public registry.
Files changed (115)
  1. brainstate/__init__.py +167 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2297 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +2157 -1652
  8. brainstate/_state_test.py +1129 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1620 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1447 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +146 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +635 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +134 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +480 -477
  32. brainstate/nn/_dynamics.py +870 -1267
  33. brainstate/nn/_dynamics_test.py +53 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +391 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
  64. brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
  65. brainstate/random/_impl.py +672 -0
  66. brainstate/random/{_rand_seed.py → _seed.py} +675 -675
  67. brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
  68. brainstate/random/{_rand_state.py → _state.py} +1320 -1617
  69. brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
  70. brainstate/transform/__init__.py +56 -59
  71. brainstate/transform/_ad_checkpoint.py +176 -176
  72. brainstate/transform/_ad_checkpoint_test.py +49 -49
  73. brainstate/transform/_autograd.py +1025 -1025
  74. brainstate/transform/_autograd_test.py +1289 -1289
  75. brainstate/transform/_conditions.py +316 -316
  76. brainstate/transform/_conditions_test.py +220 -220
  77. brainstate/transform/_error_if.py +94 -94
  78. brainstate/transform/_error_if_test.py +52 -52
  79. brainstate/transform/_find_state.py +200 -0
  80. brainstate/transform/_find_state_test.py +84 -0
  81. brainstate/transform/_jit.py +399 -399
  82. brainstate/transform/_jit_test.py +143 -143
  83. brainstate/transform/_loop_collect_return.py +675 -675
  84. brainstate/transform/_loop_collect_return_test.py +58 -58
  85. brainstate/transform/_loop_no_collection.py +283 -283
  86. brainstate/transform/_loop_no_collection_test.py +50 -50
  87. brainstate/transform/_make_jaxpr.py +2176 -2016
  88. brainstate/transform/_make_jaxpr_test.py +1634 -1510
  89. brainstate/transform/_mapping.py +607 -529
  90. brainstate/transform/_mapping_test.py +104 -194
  91. brainstate/transform/_progress_bar.py +255 -255
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
  108. brainstate-0.2.2.dist-info/RECORD +111 -0
  109. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/transform/_eval_shape.py +0 -145
  111. brainstate/transform/_eval_shape_test.py +0 -38
  112. brainstate/transform/_random.py +0 -171
  113. brainstate-0.2.1.dist-info/RECORD +0 -111
  114. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  115. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
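The renames under brainstate/random/ above touch only underscore-prefixed internal modules, so code that imports through the public brainstate.random namespace should be unaffected. A minimal sketch, assuming the public namespace re-exports the sampling helpers (the randn call is the one used throughout the test file shown below):

import brainstate

# The public import path is unchanged by the private-module renames
# (_rand_funs.py -> _fun.py, _rand_seed.py -> _seed.py, _rand_state.py -> _state.py).
x = brainstate.random.randn(8, 10)
print(x.shape)  # (8, 10)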
@@ -1,699 +1,699 @@
(The hunk removes and re-adds the entire 699-line file; the removed and re-added text are identical as rendered, so the content appears once below.)

# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from absl.testing import absltest
from absl.testing import parameterized
import jax.numpy as jnp
import numpy as np

import brainstate


class TestBatchNorm0d(parameterized.TestCase):
    """Test BatchNorm0d with various configurations."""

    @parameterized.product(
        fit=[True, False],
        feature_axis=[-1, 0],
        track_running_stats=[True, False],
    )
    def test_batchnorm0d_with_batch(self, fit, feature_axis, track_running_stats):
        """Test BatchNorm0d with batched input."""
        batch_size = 8
        channels = 10

        # Channel last: (batch, channels)
        if feature_axis == -1:
            in_size = (channels,)
            input_shape = (batch_size, channels)
        # Channel first: (batch, channels) - same for 0D
        else:
            in_size = (channels,)
            input_shape = (batch_size, channels)

        # affine can only be True when track_running_stats is True
        affine = track_running_stats

        net = brainstate.nn.BatchNorm0d(
            in_size,
            feature_axis=feature_axis,
            track_running_stats=track_running_stats,
            affine=affine
        )
        brainstate.environ.set(fit=fit)

        x = brainstate.random.randn(*input_shape)
        output = net(x)

        # Check output shape matches input
        self.assertEqual(output.shape, input_shape)

        # Check that output has approximately zero mean and unit variance when fitting
        if fit and track_running_stats:
            # Stats should be computed along batch dimension
            mean = jnp.mean(output, axis=0)
            var = jnp.var(output, axis=0)
            np.testing.assert_allclose(mean, 0.0, atol=1e-5)
            np.testing.assert_allclose(var, 1.0, atol=1e-1)

    def test_batchnorm0d_without_batch(self):
        """Test BatchNorm0d without batching."""
        channels = 10
        in_size = (channels,)

        net = brainstate.nn.BatchNorm0d(in_size, track_running_stats=True)
        brainstate.environ.set(fit=False)  # Use running stats

        # Run with batch first to populate running stats
        brainstate.environ.set(fit=True)
        x_batch = brainstate.random.randn(16, channels)
        _ = net(x_batch)

        # Now test without batch
        brainstate.environ.set(fit=False)
        x_single = brainstate.random.randn(channels)
        output = net(x_single)

        self.assertEqual(output.shape, (channels,))

    def test_batchnorm0d_affine(self):
        """Test BatchNorm0d with and without affine parameters."""
        channels = 10
        in_size = (channels,)

        # With affine
        net_affine = brainstate.nn.BatchNorm0d(in_size, affine=True)
        self.assertIsNotNone(net_affine.weight)

        # Without affine (track_running_stats must be False)
        net_no_affine = brainstate.nn.BatchNorm0d(
            in_size, affine=False, track_running_stats=False
        )
        self.assertIsNone(net_no_affine.weight)


class TestBatchNorm1d(parameterized.TestCase):
    """Test BatchNorm1d with various configurations."""

    @parameterized.product(
        fit=[True, False],
        feature_axis=[-1, 0],
        track_running_stats=[True, False],
    )
    def test_batchnorm1d_with_batch(self, fit, feature_axis, track_running_stats):
        """Test BatchNorm1d with batched input."""
        batch_size = 8
        length = 20
        channels = 10

        # Channel last: (batch, length, channels)
        if feature_axis == -1:
            in_size = (length, channels)
            input_shape = (batch_size, length, channels)
            feature_axis_param = -1
        # Channel first: (batch, channels, length)
        else:
            in_size = (channels, length)
            input_shape = (batch_size, channels, length)
            feature_axis_param = 0

        # affine can only be True when track_running_stats is True
        affine = track_running_stats

        net = brainstate.nn.BatchNorm1d(
            in_size,
            feature_axis=feature_axis_param,
            track_running_stats=track_running_stats,
            affine=affine
        )
        brainstate.environ.set(fit=fit)

        x = brainstate.random.randn(*input_shape)
        output = net(x)

        # Check output shape matches input
        self.assertEqual(output.shape, input_shape)

    def test_batchnorm1d_without_batch(self):
        """Test BatchNorm1d without batching."""
        length = 20
        channels = 10
        in_size = (length, channels)

        net = brainstate.nn.BatchNorm1d(in_size, track_running_stats=True)

        # Populate running stats first
        brainstate.environ.set(fit=True)
        x_batch = brainstate.random.randn(8, length, channels)
        _ = net(x_batch)

        # Test without batch
        brainstate.environ.set(fit=False)
        x_single = brainstate.random.randn(length, channels)
        output = net(x_single)

        self.assertEqual(output.shape, (length, channels))

    @parameterized.product(
        feature_axis=[-1, 0],
    )
    def test_batchnorm1d_channel_consistency(self, feature_axis):
        """Test that normalization is consistent across different channel configurations."""
        batch_size = 16
        length = 20
        channels = 10

        if feature_axis == -1:
            in_size = (length, channels)
            input_shape = (batch_size, length, channels)
        else:
            in_size = (channels, length)
            input_shape = (batch_size, channels, length)

        net = brainstate.nn.BatchNorm1d(in_size, feature_axis=feature_axis)
        brainstate.environ.set(fit=True)

        x = brainstate.random.randn(*input_shape)
        output = net(x)

        # Output should have same shape as input
        self.assertEqual(output.shape, input_shape)


class TestBatchNorm2d(parameterized.TestCase):
    """Test BatchNorm2d with various configurations."""

    @parameterized.product(
        fit=[True, False],
        feature_axis=[-1, 0],
        track_running_stats=[True, False],
    )
    def test_batchnorm2d_with_batch(self, fit, feature_axis, track_running_stats):
        """Test BatchNorm2d with batched input (images)."""
        batch_size = 4
        height, width = 28, 28
        channels = 3

        # Channel last: (batch, height, width, channels)
        if feature_axis == -1:
            in_size = (height, width, channels)
            input_shape = (batch_size, height, width, channels)
            feature_axis_param = -1
        # Channel first: (batch, channels, height, width)
        else:
            in_size = (channels, height, width)
            input_shape = (batch_size, channels, height, width)
            feature_axis_param = 0

        # affine can only be True when track_running_stats is True
        affine = track_running_stats

        net = brainstate.nn.BatchNorm2d(
            in_size,
            feature_axis=feature_axis_param,
            track_running_stats=track_running_stats,
            affine=affine
        )
        brainstate.environ.set(fit=fit)

        x = brainstate.random.randn(*input_shape)
        output = net(x)

        # Check output shape matches input
        self.assertEqual(output.shape, input_shape)

        # Check normalization properties during training
        if fit and track_running_stats:
            # For channel last: normalize over (batch, height, width)
            # For channel first: normalize over (batch, height, width)
            if feature_axis == -1:
                axes = (0, 1, 2)
            else:
                axes = (0, 2, 3)

            mean = jnp.mean(output, axis=axes)
            var = jnp.var(output, axis=axes)

            # Each channel should be approximately normalized
            np.testing.assert_allclose(mean, 0.0, atol=1e-5)
            np.testing.assert_allclose(var, 1.0, atol=1e-1)

    def test_batchnorm2d_without_batch(self):
        """Test BatchNorm2d without batching."""
        height, width = 28, 28
        channels = 3
        in_size = (height, width, channels)

        net = brainstate.nn.BatchNorm2d(in_size, track_running_stats=True)

        # Populate running stats
        brainstate.environ.set(fit=True)
        x_batch = brainstate.random.randn(8, height, width, channels)
        _ = net(x_batch)

        # Test without batch
        brainstate.environ.set(fit=False)
        x_single = brainstate.random.randn(height, width, channels)
        output = net(x_single)

        self.assertEqual(output.shape, (height, width, channels))


class TestBatchNorm3d(parameterized.TestCase):
    """Test BatchNorm3d with various configurations."""

    @parameterized.product(
        fit=[True, False],
        feature_axis=[-1, 0],
        track_running_stats=[True, False],
    )
    def test_batchnorm3d_with_batch(self, fit, feature_axis, track_running_stats):
        """Test BatchNorm3d with batched input (volumes)."""
        batch_size = 2
        depth, height, width = 8, 16, 16
        channels = 2

        # Channel last: (batch, depth, height, width, channels)
        if feature_axis == -1:
            in_size = (depth, height, width, channels)
            input_shape = (batch_size, depth, height, width, channels)
            feature_axis_param = -1
        # Channel first: (batch, channels, depth, height, width)
        else:
            in_size = (channels, depth, height, width)
            input_shape = (batch_size, channels, depth, height, width)
            feature_axis_param = 0

        # affine can only be True when track_running_stats is True
        affine = track_running_stats

        net = brainstate.nn.BatchNorm3d(
            in_size,
            feature_axis=feature_axis_param,
            track_running_stats=track_running_stats,
            affine=affine
        )
        brainstate.environ.set(fit=fit)

        x = brainstate.random.randn(*input_shape)
        output = net(x)

        # Check output shape matches input
        self.assertEqual(output.shape, input_shape)

    def test_batchnorm3d_without_batch(self):
        """Test BatchNorm3d without batching."""
        depth, height, width = 8, 16, 16
        channels = 2
        in_size = (depth, height, width, channels)

        net = brainstate.nn.BatchNorm3d(in_size, track_running_stats=True)

        # Populate running stats
        brainstate.environ.set(fit=True)
        x_batch = brainstate.random.randn(4, depth, height, width, channels)
        _ = net(x_batch)

        # Test without batch
        brainstate.environ.set(fit=False)
        x_single = brainstate.random.randn(depth, height, width, channels)
        output = net(x_single)

        self.assertEqual(output.shape, (depth, height, width, channels))


class TestLayerNorm(parameterized.TestCase):
    """Test LayerNorm with various configurations."""

    @parameterized.product(
        reduction_axes=[(-1,), (-2, -1), (-3, -2, -1)],
        use_bias=[True, False],
        use_scale=[True, False],
    )
    def test_layernorm_basic(self, reduction_axes, use_bias, use_scale):
        """Test LayerNorm with different reduction axes."""
        in_size = (10, 20, 30)

        net = brainstate.nn.LayerNorm(
            in_size,
            reduction_axes=reduction_axes,
            use_bias=use_bias,
            use_scale=use_scale,
        )

        # With batch
        x = brainstate.random.randn(8, 10, 20, 30)
        output = net(x)
        self.assertEqual(output.shape, x.shape)

        # Check normalization properties
        mean = jnp.mean(output, axis=tuple(i + 1 for i in range(len(in_size))
                                           if i - len(in_size) in reduction_axes))
        var = jnp.var(output, axis=tuple(i + 1 for i in range(len(in_size))
                                         if i - len(in_size) in reduction_axes))

    def test_layernorm_2d_features(self):
        """Test LayerNorm on 2D features (like in transformers)."""
        seq_length = 50
        hidden_dim = 128
        batch_size = 16

        in_size = (seq_length, hidden_dim)
        net = brainstate.nn.LayerNorm(in_size, reduction_axes=-1, feature_axes=-1)

        x = brainstate.random.randn(batch_size, seq_length, hidden_dim)
        output = net(x)

        self.assertEqual(output.shape, x.shape)

        # Each position should be normalized independently
        mean = jnp.mean(output, axis=-1)
        var = jnp.var(output, axis=-1)

        np.testing.assert_allclose(mean, 0.0, atol=1e-5)
        np.testing.assert_allclose(var, 1.0, atol=1e-1)

    def test_layernorm_without_batch(self):
        """Test LayerNorm without batch dimension."""
        in_size = (10, 20)
        net = brainstate.nn.LayerNorm(in_size, reduction_axes=-1)

        x = brainstate.random.randn(10, 20)
        output = net(x)

        self.assertEqual(output.shape, (10, 20))

    @parameterized.product(
        in_size=[(10,), (10, 20), (10, 20, 30)],
    )
    def test_layernorm_various_dims(self, in_size):
        """Test LayerNorm with various input dimensions."""
        net = brainstate.nn.LayerNorm(in_size)

        # With batch
        x_with_batch = brainstate.random.randn(8, *in_size)
        output_with_batch = net(x_with_batch)
        self.assertEqual(output_with_batch.shape, x_with_batch.shape)


class TestRMSNorm(parameterized.TestCase):
    """Test RMSNorm with various configurations."""

    @parameterized.product(
        use_scale=[True, False],
        reduction_axes=[(-1,), (-2, -1)],
    )
    def test_rmsnorm_basic(self, use_scale, reduction_axes):
        """Test RMSNorm with different configurations."""
        in_size = (10, 20)

        net = brainstate.nn.RMSNorm(
            in_size,
            use_scale=use_scale,
            reduction_axes=reduction_axes,
        )

        x = brainstate.random.randn(8, 10, 20)
        output = net(x)

        self.assertEqual(output.shape, x.shape)

    def test_rmsnorm_transformer_like(self):
        """Test RMSNorm in transformer-like setting."""
        seq_length = 50
        hidden_dim = 128
        batch_size = 16

        in_size = (seq_length, hidden_dim)
        net = brainstate.nn.RMSNorm(in_size, reduction_axes=-1, feature_axes=-1)

        x = brainstate.random.randn(batch_size, seq_length, hidden_dim)
        output = net(x)

        self.assertEqual(output.shape, x.shape)

        # RMSNorm should have approximately unit RMS (not zero mean)
        rms = jnp.sqrt(jnp.mean(jnp.square(output), axis=-1))
        np.testing.assert_allclose(rms, 1.0, atol=1e-1)

    def test_rmsnorm_without_batch(self):
        """Test RMSNorm without batch dimension."""
        in_size = (10, 20)
        net = brainstate.nn.RMSNorm(in_size, reduction_axes=-1)

        x = brainstate.random.randn(10, 20)
        output = net(x)

        self.assertEqual(output.shape, (10, 20))


class TestGroupNorm(parameterized.TestCase):
    """Test GroupNorm with various configurations."""

    @parameterized.product(
        num_groups=[1, 2, 4, 8],
        use_bias=[True, False],
        use_scale=[True, False],
    )
    def test_groupnorm_basic(self, num_groups, use_bias, use_scale):
        """Test GroupNorm with different number of groups."""
        channels = 16
        # GroupNorm requires 1D feature axis (just the channel dimension)
        in_size = (channels,)

        # Check if channels is divisible by num_groups
        if channels % num_groups != 0:
            return

        net = brainstate.nn.GroupNorm(
            in_size,
            feature_axis=0,
            num_groups=num_groups,
            use_bias=use_bias,
            use_scale=use_scale,
        )

        # Input needs at least 2D: (height, width, channels) or (batch, channels)
        # Using (batch, channels) format
        x = brainstate.random.randn(4, channels)
        output = net(x)

        self.assertEqual(output.shape, x.shape)

    def test_groupnorm_channel_first(self):
        """Test GroupNorm with channel-first format for images."""
        channels = 16
        # GroupNorm requires 1D feature (just channels)
        in_size = (channels,)

        net = brainstate.nn.GroupNorm(
            in_size,
            feature_axis=0,
            num_groups=4,
        )

        # Test with image-like data: (batch, height, width, channels)
        x = brainstate.random.randn(4, 32, 32, channels)
        output = net(x)

        self.assertEqual(output.shape, x.shape)

    def test_groupnorm_channel_last(self):
        """Test GroupNorm with channel-last format for images."""
        channels = 16
        # GroupNorm requires 1D feature (just channels)
        in_size = (channels,)

        net = brainstate.nn.GroupNorm(
            in_size,
            feature_axis=0,  # feature_axis refers to position in in_size
            num_groups=4,
        )

        # Test with image-like data: (batch, height, width, channels)
        x = brainstate.random.randn(4, 32, 32, channels)
        output = net(x)

        self.assertEqual(output.shape, x.shape)

    def test_groupnorm_equals_layernorm(self):
        """Test that GroupNorm with num_groups=1 equals LayerNorm."""
        channels = 16
        # GroupNorm requires 1D feature
        in_size = (channels,)

        # GroupNorm with 1 group
        group_norm = brainstate.nn.GroupNorm(
            in_size,
            feature_axis=0,
            num_groups=1,
        )

        # LayerNorm with same setup
        layer_norm = brainstate.nn.LayerNorm(
            in_size,
            reduction_axes=-1,
            feature_axes=-1,
        )

        # Use 2D input: (batch, channels)
        x = brainstate.random.randn(8, channels)

        output_gn = group_norm(x)
        output_ln = layer_norm(x)

        # Shapes should match
        self.assertEqual(output_gn.shape, output_ln.shape)

    def test_groupnorm_group_size(self):
        """Test GroupNorm with group_size instead of num_groups."""
        channels = 16
        group_size = 4
        # GroupNorm requires 1D feature
        in_size = (channels,)

        net = brainstate.nn.GroupNorm(
            in_size,
            feature_axis=0,
            num_groups=None,
            group_size=group_size,
        )

        # Use 2D input: (batch, channels)
        x = brainstate.random.randn(4, channels)
        output = net(x)

        self.assertEqual(output.shape, x.shape)
        self.assertEqual(net.num_groups, channels // group_size)

    def test_groupnorm_invalid_groups(self):
        """Test that invalid num_groups raises error."""
        channels = 15  # Not divisible by many numbers
        # GroupNorm requires 1D feature
        in_size = (channels,)

        # Should raise error if num_groups doesn't divide channels
        with self.assertRaises(ValueError):
            net = brainstate.nn.GroupNorm(
                in_size,
                feature_axis=0,
                num_groups=4,  # 15 is not divisible by 4
            )


class TestNormalizationUtilities(parameterized.TestCase):
    """Test utility functions for normalization."""

    def test_weight_standardization(self):
        """Test weight_standardization function."""
        w = brainstate.random.randn(3, 4, 5, 6)

        w_std = brainstate.nn.weight_standardization(w, eps=1e-4)

        self.assertEqual(w_std.shape, w.shape)

        # Check that standardization works
        # Mean should be close to 0 along non-output axes
        mean = jnp.mean(w_std, axis=(0, 1, 2))
        np.testing.assert_allclose(mean, 0.0, atol=1e-4)

    def test_weight_standardization_with_gain(self):
        """Test weight_standardization with gain parameter."""
        w = brainstate.random.randn(3, 4, 5, 6)
        gain = jnp.ones((6,))

        w_std = brainstate.nn.weight_standardization(w, gain=gain)

        self.assertEqual(w_std.shape, w.shape)


class TestNormalizationEdgeCases(parameterized.TestCase):
    """Test edge cases and error conditions."""

    def test_batchnorm_shape_mismatch(self):
        """Test that BatchNorm raises error on shape mismatch."""
        net = brainstate.nn.BatchNorm2d((28, 28, 3))

        # Wrong shape should raise error
        with self.assertRaises(ValueError):
            x = brainstate.random.randn(4, 32, 32, 3)  # Wrong height/width
            _ = net(x)

    def test_batchnorm_without_track_and_affine(self):
        """Test that affine=True requires track_running_stats=True."""
        # This should raise an assertion error
        with self.assertRaises(AssertionError):
            net = brainstate.nn.BatchNorm2d(
                (28, 28, 3),
                track_running_stats=False,
                affine=True  # Requires track_running_stats=True
            )

    def test_groupnorm_both_params(self):
        """Test that GroupNorm raises error when both num_groups and group_size are specified."""
        with self.assertRaises(ValueError):
            net = brainstate.nn.GroupNorm(
                (32, 32, 16),
                num_groups=4,
                group_size=4,  # Can't specify both
            )

    def test_groupnorm_neither_param(self):
        """Test that GroupNorm raises error when neither num_groups nor group_size are specified."""
        with self.assertRaises(ValueError):
            net = brainstate.nn.GroupNorm(
                (32, 32, 16),
                num_groups=None,
                group_size=None,  # Must specify one
            )


class TestNormalizationConsistency(parameterized.TestCase):
    """Test consistency across different batch sizes and modes."""

    def test_batchnorm2d_consistency_across_batches(self):
        """Test that BatchNorm2d behaves consistently across different batch sizes."""
        in_size = (28, 28, 3)
        net = brainstate.nn.BatchNorm2d(in_size, track_running_stats=True)

        # Train on larger batch
        brainstate.environ.set(fit=True)
        x_large = brainstate.random.randn(32, 28, 28, 3)
        _ = net(x_large)

        # Test on smaller batch
        brainstate.environ.set(fit=False)
        x_small = brainstate.random.randn(4, 28, 28, 3)
        output = net(x_small)

        self.assertEqual(output.shape, x_small.shape)

    def test_layernorm_consistency(self):
        """Test that LayerNorm produces consistent results."""
        in_size = (10, 20)
        net = brainstate.nn.LayerNorm(in_size)

        x = brainstate.random.randn(8, 10, 20)

        # Run twice
        output1 = net(x)
        output2 = net(x)

        # Should be deterministic
        np.testing.assert_allclose(output1, output2)


if __name__ == '__main__':
    absltest.main()
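For orientation, a minimal standalone sketch distilled from the tests above, showing the train/eval pattern they exercise; every call appears verbatim in the test file, which can also be run directly thanks to its absltest.main() entry point:

import brainstate

# Minimal sketch (calls taken from the tests above, not additional API surface):
# BatchNorm2d on channel-last images, switched between batch statistics and
# running statistics via the environ 'fit' flag.
net = brainstate.nn.BatchNorm2d((28, 28, 3), track_running_stats=True)

brainstate.environ.set(fit=True)                   # training mode: batch statistics
_ = net(brainstate.random.randn(32, 28, 28, 3))    # populates running stats

brainstate.environ.set(fit=False)                  # inference mode: running statistics
y = net(brainstate.random.randn(28, 28, 3))        # unbatched input is also accepted
print(y.shape)                                     # (28, 28, 3)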