brainstate 0.2.0__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
- brainstate/__init__.py +169 -169
- brainstate/_compatible_import.py +340 -340
- brainstate/_compatible_import_test.py +681 -681
- brainstate/_deprecation.py +210 -210
- brainstate/_deprecation_test.py +2319 -2319
- brainstate/_error.py +45 -45
- brainstate/_state.py +1652 -1652
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -1495
- brainstate/environ_test.py +1223 -1223
- brainstate/graph/__init__.py +22 -22
- brainstate/graph/_node.py +240 -240
- brainstate/graph/_node_test.py +589 -589
- brainstate/graph/_operation.py +1624 -1624
- brainstate/graph/_operation_test.py +1147 -1147
- brainstate/mixin.py +1433 -1433
- brainstate/mixin_test.py +1017 -1017
- brainstate/nn/__init__.py +137 -137
- brainstate/nn/_activations.py +1100 -1100
- brainstate/nn/_activations_test.py +354 -354
- brainstate/nn/_collective_ops.py +633 -633
- brainstate/nn/_collective_ops_test.py +774 -774
- brainstate/nn/_common.py +226 -226
- brainstate/nn/_common_test.py +154 -154
- brainstate/nn/_conv.py +2010 -2010
- brainstate/nn/_conv_test.py +849 -849
- brainstate/nn/_delay.py +575 -575
- brainstate/nn/_delay_test.py +243 -243
- brainstate/nn/_dropout.py +618 -618
- brainstate/nn/_dropout_test.py +477 -477
- brainstate/nn/_dynamics.py +1267 -1267
- brainstate/nn/_dynamics_test.py +67 -67
- brainstate/nn/_elementwise.py +1298 -1298
- brainstate/nn/_elementwise_test.py +829 -829
- brainstate/nn/_embedding.py +408 -408
- brainstate/nn/_embedding_test.py +156 -156
- brainstate/nn/_event_fixedprob.py +233 -233
- brainstate/nn/_event_fixedprob_test.py +115 -115
- brainstate/nn/_event_linear.py +83 -83
- brainstate/nn/_event_linear_test.py +121 -121
- brainstate/nn/_exp_euler.py +254 -254
- brainstate/nn/_exp_euler_test.py +377 -377
- brainstate/nn/_linear.py +744 -744
- brainstate/nn/_linear_test.py +475 -475
- brainstate/nn/_metrics.py +1070 -1070
- brainstate/nn/_metrics_test.py +611 -611
- brainstate/nn/_module.py +384 -384
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -1334
- brainstate/nn/_normalizations_test.py +699 -699
- brainstate/nn/_paddings.py +1020 -1020
- brainstate/nn/_paddings_test.py +722 -722
- brainstate/nn/_poolings.py +2239 -2239
- brainstate/nn/_poolings_test.py +952 -952
- brainstate/nn/_rnns.py +946 -946
- brainstate/nn/_rnns_test.py +592 -592
- brainstate/nn/_utils.py +216 -216
- brainstate/nn/_utils_test.py +401 -401
- brainstate/nn/init.py +809 -809
- brainstate/nn/init_test.py +180 -180
- brainstate/random/__init__.py +270 -270
- brainstate/random/_rand_funs.py +3938 -3938
- brainstate/random/_rand_funs_test.py +640 -640
- brainstate/random/_rand_seed.py +675 -675
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1617 -1617
- brainstate/random/_rand_state_test.py +551 -551
- brainstate/transform/__init__.py +59 -59
- brainstate/transform/_ad_checkpoint.py +176 -176
- brainstate/transform/_ad_checkpoint_test.py +49 -49
- brainstate/transform/_autograd.py +1025 -1025
- brainstate/transform/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -316
- brainstate/transform/_conditions_test.py +220 -220
- brainstate/transform/_error_if.py +94 -94
- brainstate/transform/_error_if_test.py +52 -52
- brainstate/transform/_eval_shape.py +145 -145
- brainstate/transform/_eval_shape_test.py +38 -38
- brainstate/transform/_jit.py +399 -399
- brainstate/transform/_jit_test.py +143 -143
- brainstate/transform/_loop_collect_return.py +675 -675
- brainstate/transform/_loop_collect_return_test.py +58 -58
- brainstate/transform/_loop_no_collection.py +283 -283
- brainstate/transform/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2016 -2016
- brainstate/transform/_make_jaxpr_test.py +1510 -1510
- brainstate/transform/_mapping.py +529 -529
- brainstate/transform/_mapping_test.py +194 -194
- brainstate/transform/_progress_bar.py +255 -255
- brainstate/transform/_random.py +171 -171
- brainstate/transform/_unvmap.py +256 -256
- brainstate/transform/_util.py +286 -286
- brainstate/typing.py +837 -837
- brainstate/typing_test.py +780 -780
- brainstate/util/__init__.py +27 -27
- brainstate/util/_others.py +1024 -1024
- brainstate/util/_others_test.py +962 -962
- brainstate/util/_pretty_pytree.py +1301 -1301
- brainstate/util/_pretty_pytree_test.py +675 -675
- brainstate/util/_pretty_repr.py +462 -462
- brainstate/util/_pretty_repr_test.py +696 -696
- brainstate/util/filter.py +945 -945
- brainstate/util/filter_test.py +911 -911
- brainstate/util/struct.py +910 -910
- brainstate/util/struct_test.py +602 -602
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -108
- brainstate-0.2.1.dist-info/RECORD +111 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
- brainstate-0.2.0.dist-info/RECORD +0 -111
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
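
The remainder of the page renders one file of the diff in full: brainstate/nn/_normalizations_test.py. For reference, below is a minimal sketch of how such a wheel-to-wheel file diff can be reproduced locally with the Python standard library; the wheel filenames and the download step are assumptions, not part of this page (a wheel is a zip archive, so zipfile can read its members without installing anything):

# Minimal sketch: diff one file between two locally downloaded wheels.
# Assumed filenames; fetch e.g. with `pip download brainstate==0.2.0 --no-deps`.
import difflib
import zipfile

OLD_WHEEL = "brainstate-0.2.0-py2.py3-none-any.whl"  # assumed local path
NEW_WHEEL = "brainstate-0.2.1-py2.py3-none-any.whl"  # assumed local path
MEMBER = "brainstate/nn/_normalizations_test.py"


def read_member(wheel_path, member):
    # A wheel is a zip archive, so packaged files can be read directly.
    with zipfile.ZipFile(wheel_path) as zf:
        return zf.read(member).decode("utf-8").splitlines(keepends=True)


diff = difflib.unified_diff(
    read_member(OLD_WHEEL, MEMBER),
    read_member(NEW_WHEEL, MEMBER),
    fromfile="0.2.0/" + MEMBER,
    tofile="0.2.1/" + MEMBER,
)
print("".join(diff))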
@@ -1,699 +1,699 @@

Both sides of this hunk are line-for-line identical: 0.2.1 removes and re-adds brainstate/nn/_normalizations_test.py without changing its contents. The shared 699-line file:

# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from absl.testing import absltest
from absl.testing import parameterized
import jax.numpy as jnp
import numpy as np

import brainstate


class TestBatchNorm0d(parameterized.TestCase):
    """Test BatchNorm0d with various configurations."""

    @parameterized.product(
        fit=[True, False],
        feature_axis=[-1, 0],
        track_running_stats=[True, False],
    )
    def test_batchnorm0d_with_batch(self, fit, feature_axis, track_running_stats):
        """Test BatchNorm0d with batched input."""
        batch_size = 8
        channels = 10

        # Channel last: (batch, channels)
        if feature_axis == -1:
            in_size = (channels,)
            input_shape = (batch_size, channels)
        # Channel first: (batch, channels) - same for 0D
        else:
            in_size = (channels,)
            input_shape = (batch_size, channels)

        # affine can only be True when track_running_stats is True
        affine = track_running_stats

        net = brainstate.nn.BatchNorm0d(
            in_size,
            feature_axis=feature_axis,
            track_running_stats=track_running_stats,
            affine=affine
        )
        brainstate.environ.set(fit=fit)

        x = brainstate.random.randn(*input_shape)
        output = net(x)

        # Check output shape matches input
        self.assertEqual(output.shape, input_shape)

        # Check that output has approximately zero mean and unit variance when fitting
        if fit and track_running_stats:
            # Stats should be computed along batch dimension
            mean = jnp.mean(output, axis=0)
            var = jnp.var(output, axis=0)
            np.testing.assert_allclose(mean, 0.0, atol=1e-5)
            np.testing.assert_allclose(var, 1.0, atol=1e-1)

    def test_batchnorm0d_without_batch(self):
        """Test BatchNorm0d without batching."""
        channels = 10
        in_size = (channels,)

        net = brainstate.nn.BatchNorm0d(in_size, track_running_stats=True)
        brainstate.environ.set(fit=False)  # Use running stats

        # Run with batch first to populate running stats
        brainstate.environ.set(fit=True)
        x_batch = brainstate.random.randn(16, channels)
        _ = net(x_batch)

        # Now test without batch
        brainstate.environ.set(fit=False)
        x_single = brainstate.random.randn(channels)
        output = net(x_single)

        self.assertEqual(output.shape, (channels,))

    def test_batchnorm0d_affine(self):
        """Test BatchNorm0d with and without affine parameters."""
        channels = 10
        in_size = (channels,)

        # With affine
        net_affine = brainstate.nn.BatchNorm0d(in_size, affine=True)
        self.assertIsNotNone(net_affine.weight)

        # Without affine (track_running_stats must be False)
        net_no_affine = brainstate.nn.BatchNorm0d(
            in_size, affine=False, track_running_stats=False
        )
        self.assertIsNone(net_no_affine.weight)


class TestBatchNorm1d(parameterized.TestCase):
    """Test BatchNorm1d with various configurations."""

    @parameterized.product(
        fit=[True, False],
        feature_axis=[-1, 0],
        track_running_stats=[True, False],
    )
    def test_batchnorm1d_with_batch(self, fit, feature_axis, track_running_stats):
        """Test BatchNorm1d with batched input."""
        batch_size = 8
        length = 20
        channels = 10

        # Channel last: (batch, length, channels)
        if feature_axis == -1:
            in_size = (length, channels)
            input_shape = (batch_size, length, channels)
            feature_axis_param = -1
        # Channel first: (batch, channels, length)
        else:
            in_size = (channels, length)
            input_shape = (batch_size, channels, length)
            feature_axis_param = 0

        # affine can only be True when track_running_stats is True
        affine = track_running_stats

        net = brainstate.nn.BatchNorm1d(
            in_size,
            feature_axis=feature_axis_param,
            track_running_stats=track_running_stats,
            affine=affine
        )
        brainstate.environ.set(fit=fit)

        x = brainstate.random.randn(*input_shape)
        output = net(x)

        # Check output shape matches input
        self.assertEqual(output.shape, input_shape)

    def test_batchnorm1d_without_batch(self):
        """Test BatchNorm1d without batching."""
        length = 20
        channels = 10
        in_size = (length, channels)

        net = brainstate.nn.BatchNorm1d(in_size, track_running_stats=True)

        # Populate running stats first
        brainstate.environ.set(fit=True)
        x_batch = brainstate.random.randn(8, length, channels)
        _ = net(x_batch)

        # Test without batch
        brainstate.environ.set(fit=False)
        x_single = brainstate.random.randn(length, channels)
        output = net(x_single)

        self.assertEqual(output.shape, (length, channels))

    @parameterized.product(
        feature_axis=[-1, 0],
    )
    def test_batchnorm1d_channel_consistency(self, feature_axis):
        """Test that normalization is consistent across different channel configurations."""
        batch_size = 16
        length = 20
        channels = 10

        if feature_axis == -1:
            in_size = (length, channels)
            input_shape = (batch_size, length, channels)
        else:
            in_size = (channels, length)
            input_shape = (batch_size, channels, length)

        net = brainstate.nn.BatchNorm1d(in_size, feature_axis=feature_axis)
        brainstate.environ.set(fit=True)

        x = brainstate.random.randn(*input_shape)
        output = net(x)

        # Output should have same shape as input
        self.assertEqual(output.shape, input_shape)


class TestBatchNorm2d(parameterized.TestCase):
    """Test BatchNorm2d with various configurations."""

    @parameterized.product(
        fit=[True, False],
        feature_axis=[-1, 0],
        track_running_stats=[True, False],
    )
    def test_batchnorm2d_with_batch(self, fit, feature_axis, track_running_stats):
        """Test BatchNorm2d with batched input (images)."""
        batch_size = 4
        height, width = 28, 28
        channels = 3

        # Channel last: (batch, height, width, channels)
        if feature_axis == -1:
            in_size = (height, width, channels)
            input_shape = (batch_size, height, width, channels)
            feature_axis_param = -1
        # Channel first: (batch, channels, height, width)
        else:
            in_size = (channels, height, width)
            input_shape = (batch_size, channels, height, width)
            feature_axis_param = 0

        # affine can only be True when track_running_stats is True
        affine = track_running_stats

        net = brainstate.nn.BatchNorm2d(
            in_size,
            feature_axis=feature_axis_param,
            track_running_stats=track_running_stats,
            affine=affine
        )
        brainstate.environ.set(fit=fit)

        x = brainstate.random.randn(*input_shape)
        output = net(x)

        # Check output shape matches input
        self.assertEqual(output.shape, input_shape)

        # Check normalization properties during training
        if fit and track_running_stats:
            # For channel last: normalize over (batch, height, width)
            # For channel first: normalize over (batch, height, width)
            if feature_axis == -1:
                axes = (0, 1, 2)
            else:
                axes = (0, 2, 3)

            mean = jnp.mean(output, axis=axes)
            var = jnp.var(output, axis=axes)

            # Each channel should be approximately normalized
            np.testing.assert_allclose(mean, 0.0, atol=1e-5)
            np.testing.assert_allclose(var, 1.0, atol=1e-1)

    def test_batchnorm2d_without_batch(self):
        """Test BatchNorm2d without batching."""
        height, width = 28, 28
        channels = 3
        in_size = (height, width, channels)

        net = brainstate.nn.BatchNorm2d(in_size, track_running_stats=True)

        # Populate running stats
        brainstate.environ.set(fit=True)
        x_batch = brainstate.random.randn(8, height, width, channels)
        _ = net(x_batch)

        # Test without batch
        brainstate.environ.set(fit=False)
        x_single = brainstate.random.randn(height, width, channels)
        output = net(x_single)

        self.assertEqual(output.shape, (height, width, channels))


class TestBatchNorm3d(parameterized.TestCase):
    """Test BatchNorm3d with various configurations."""

    @parameterized.product(
        fit=[True, False],
        feature_axis=[-1, 0],
        track_running_stats=[True, False],
    )
    def test_batchnorm3d_with_batch(self, fit, feature_axis, track_running_stats):
        """Test BatchNorm3d with batched input (volumes)."""
        batch_size = 2
        depth, height, width = 8, 16, 16
        channels = 2

        # Channel last: (batch, depth, height, width, channels)
        if feature_axis == -1:
            in_size = (depth, height, width, channels)
            input_shape = (batch_size, depth, height, width, channels)
            feature_axis_param = -1
        # Channel first: (batch, channels, depth, height, width)
        else:
            in_size = (channels, depth, height, width)
            input_shape = (batch_size, channels, depth, height, width)
            feature_axis_param = 0

        # affine can only be True when track_running_stats is True
        affine = track_running_stats

        net = brainstate.nn.BatchNorm3d(
            in_size,
            feature_axis=feature_axis_param,
            track_running_stats=track_running_stats,
            affine=affine
        )
        brainstate.environ.set(fit=fit)

        x = brainstate.random.randn(*input_shape)
        output = net(x)

        # Check output shape matches input
        self.assertEqual(output.shape, input_shape)

    def test_batchnorm3d_without_batch(self):
        """Test BatchNorm3d without batching."""
        depth, height, width = 8, 16, 16
        channels = 2
        in_size = (depth, height, width, channels)

        net = brainstate.nn.BatchNorm3d(in_size, track_running_stats=True)

        # Populate running stats
        brainstate.environ.set(fit=True)
        x_batch = brainstate.random.randn(4, depth, height, width, channels)
        _ = net(x_batch)

        # Test without batch
        brainstate.environ.set(fit=False)
        x_single = brainstate.random.randn(depth, height, width, channels)
        output = net(x_single)

        self.assertEqual(output.shape, (depth, height, width, channels))


class TestLayerNorm(parameterized.TestCase):
    """Test LayerNorm with various configurations."""

    @parameterized.product(
        reduction_axes=[(-1,), (-2, -1), (-3, -2, -1)],
        use_bias=[True, False],
        use_scale=[True, False],
    )
    def test_layernorm_basic(self, reduction_axes, use_bias, use_scale):
        """Test LayerNorm with different reduction axes."""
        in_size = (10, 20, 30)

        net = brainstate.nn.LayerNorm(
            in_size,
            reduction_axes=reduction_axes,
            use_bias=use_bias,
            use_scale=use_scale,
        )

        # With batch
        x = brainstate.random.randn(8, 10, 20, 30)
        output = net(x)
        self.assertEqual(output.shape, x.shape)

        # Check normalization properties
        mean = jnp.mean(output, axis=tuple(i + 1 for i in range(len(in_size))
                                           if i - len(in_size) in reduction_axes))
        var = jnp.var(output, axis=tuple(i + 1 for i in range(len(in_size))
                                         if i - len(in_size) in reduction_axes))

    def test_layernorm_2d_features(self):
        """Test LayerNorm on 2D features (like in transformers)."""
        seq_length = 50
        hidden_dim = 128
        batch_size = 16

        in_size = (seq_length, hidden_dim)
        net = brainstate.nn.LayerNorm(in_size, reduction_axes=-1, feature_axes=-1)

        x = brainstate.random.randn(batch_size, seq_length, hidden_dim)
        output = net(x)

        self.assertEqual(output.shape, x.shape)

        # Each position should be normalized independently
        mean = jnp.mean(output, axis=-1)
        var = jnp.var(output, axis=-1)

        np.testing.assert_allclose(mean, 0.0, atol=1e-5)
        np.testing.assert_allclose(var, 1.0, atol=1e-1)

    def test_layernorm_without_batch(self):
        """Test LayerNorm without batch dimension."""
        in_size = (10, 20)
        net = brainstate.nn.LayerNorm(in_size, reduction_axes=-1)

        x = brainstate.random.randn(10, 20)
        output = net(x)

        self.assertEqual(output.shape, (10, 20))

    @parameterized.product(
        in_size=[(10,), (10, 20), (10, 20, 30)],
    )
    def test_layernorm_various_dims(self, in_size):
        """Test LayerNorm with various input dimensions."""
        net = brainstate.nn.LayerNorm(in_size)

        # With batch
        x_with_batch = brainstate.random.randn(8, *in_size)
        output_with_batch = net(x_with_batch)
        self.assertEqual(output_with_batch.shape, x_with_batch.shape)


class TestRMSNorm(parameterized.TestCase):
    """Test RMSNorm with various configurations."""

    @parameterized.product(
        use_scale=[True, False],
        reduction_axes=[(-1,), (-2, -1)],
    )
    def test_rmsnorm_basic(self, use_scale, reduction_axes):
        """Test RMSNorm with different configurations."""
        in_size = (10, 20)

        net = brainstate.nn.RMSNorm(
            in_size,
            use_scale=use_scale,
            reduction_axes=reduction_axes,
        )

        x = brainstate.random.randn(8, 10, 20)
        output = net(x)

        self.assertEqual(output.shape, x.shape)

    def test_rmsnorm_transformer_like(self):
        """Test RMSNorm in transformer-like setting."""
        seq_length = 50
        hidden_dim = 128
        batch_size = 16

        in_size = (seq_length, hidden_dim)
        net = brainstate.nn.RMSNorm(in_size, reduction_axes=-1, feature_axes=-1)

        x = brainstate.random.randn(batch_size, seq_length, hidden_dim)
        output = net(x)

        self.assertEqual(output.shape, x.shape)

        # RMSNorm should have approximately unit RMS (not zero mean)
        rms = jnp.sqrt(jnp.mean(jnp.square(output), axis=-1))
        np.testing.assert_allclose(rms, 1.0, atol=1e-1)

    def test_rmsnorm_without_batch(self):
        """Test RMSNorm without batch dimension."""
        in_size = (10, 20)
        net = brainstate.nn.RMSNorm(in_size, reduction_axes=-1)

        x = brainstate.random.randn(10, 20)
        output = net(x)

        self.assertEqual(output.shape, (10, 20))


class TestGroupNorm(parameterized.TestCase):
    """Test GroupNorm with various configurations."""

    @parameterized.product(
        num_groups=[1, 2, 4, 8],
        use_bias=[True, False],
        use_scale=[True, False],
    )
    def test_groupnorm_basic(self, num_groups, use_bias, use_scale):
        """Test GroupNorm with different number of groups."""
        channels = 16
        # GroupNorm requires 1D feature axis (just the channel dimension)
        in_size = (channels,)

        # Check if channels is divisible by num_groups
        if channels % num_groups != 0:
            return

        net = brainstate.nn.GroupNorm(
            in_size,
            feature_axis=0,
            num_groups=num_groups,
            use_bias=use_bias,
            use_scale=use_scale,
        )

        # Input needs at least 2D: (height, width, channels) or (batch, channels)
        # Using (batch, channels) format
        x = brainstate.random.randn(4, channels)
        output = net(x)

        self.assertEqual(output.shape, x.shape)

    def test_groupnorm_channel_first(self):
        """Test GroupNorm with channel-first format for images."""
        channels = 16
        # GroupNorm requires 1D feature (just channels)
        in_size = (channels,)

        net = brainstate.nn.GroupNorm(
            in_size,
            feature_axis=0,
            num_groups=4,
        )

        # Test with image-like data: (batch, height, width, channels)
        x = brainstate.random.randn(4, 32, 32, channels)
        output = net(x)

        self.assertEqual(output.shape, x.shape)

    def test_groupnorm_channel_last(self):
        """Test GroupNorm with channel-last format for images."""
        channels = 16
        # GroupNorm requires 1D feature (just channels)
        in_size = (channels,)

        net = brainstate.nn.GroupNorm(
            in_size,
            feature_axis=0,  # feature_axis refers to position in in_size
            num_groups=4,
        )

        # Test with image-like data: (batch, height, width, channels)
        x = brainstate.random.randn(4, 32, 32, channels)
        output = net(x)

        self.assertEqual(output.shape, x.shape)

    def test_groupnorm_equals_layernorm(self):
        """Test that GroupNorm with num_groups=1 equals LayerNorm."""
        channels = 16
        # GroupNorm requires 1D feature
        in_size = (channels,)

        # GroupNorm with 1 group
        group_norm = brainstate.nn.GroupNorm(
            in_size,
            feature_axis=0,
            num_groups=1,
        )

        # LayerNorm with same setup
        layer_norm = brainstate.nn.LayerNorm(
            in_size,
            reduction_axes=-1,
            feature_axes=-1,
        )

        # Use 2D input: (batch, channels)
        x = brainstate.random.randn(8, channels)

        output_gn = group_norm(x)
        output_ln = layer_norm(x)

        # Shapes should match
        self.assertEqual(output_gn.shape, output_ln.shape)

    def test_groupnorm_group_size(self):
        """Test GroupNorm with group_size instead of num_groups."""
        channels = 16
        group_size = 4
        # GroupNorm requires 1D feature
        in_size = (channels,)

        net = brainstate.nn.GroupNorm(
            in_size,
            feature_axis=0,
            num_groups=None,
            group_size=group_size,
        )

        # Use 2D input: (batch, channels)
        x = brainstate.random.randn(4, channels)
        output = net(x)

        self.assertEqual(output.shape, x.shape)
        self.assertEqual(net.num_groups, channels // group_size)

    def test_groupnorm_invalid_groups(self):
        """Test that invalid num_groups raises error."""
        channels = 15  # Not divisible by many numbers
        # GroupNorm requires 1D feature
        in_size = (channels,)

        # Should raise error if num_groups doesn't divide channels
        with self.assertRaises(ValueError):
            net = brainstate.nn.GroupNorm(
                in_size,
                feature_axis=0,
                num_groups=4,  # 15 is not divisible by 4
            )


class TestNormalizationUtilities(parameterized.TestCase):
    """Test utility functions for normalization."""

    def test_weight_standardization(self):
        """Test weight_standardization function."""
        w = brainstate.random.randn(3, 4, 5, 6)

        w_std = brainstate.nn.weight_standardization(w, eps=1e-4)

        self.assertEqual(w_std.shape, w.shape)

        # Check that standardization works
        # Mean should be close to 0 along non-output axes
        mean = jnp.mean(w_std, axis=(0, 1, 2))
        np.testing.assert_allclose(mean, 0.0, atol=1e-4)

    def test_weight_standardization_with_gain(self):
        """Test weight_standardization with gain parameter."""
        w = brainstate.random.randn(3, 4, 5, 6)
        gain = jnp.ones((6,))

        w_std = brainstate.nn.weight_standardization(w, gain=gain)

        self.assertEqual(w_std.shape, w.shape)


class TestNormalizationEdgeCases(parameterized.TestCase):
    """Test edge cases and error conditions."""

    def test_batchnorm_shape_mismatch(self):
        """Test that BatchNorm raises error on shape mismatch."""
        net = brainstate.nn.BatchNorm2d((28, 28, 3))

        # Wrong shape should raise error
        with self.assertRaises(ValueError):
            x = brainstate.random.randn(4, 32, 32, 3)  # Wrong height/width
            _ = net(x)

    def test_batchnorm_without_track_and_affine(self):
        """Test that affine=True requires track_running_stats=True."""
        # This should raise an assertion error
        with self.assertRaises(AssertionError):
            net = brainstate.nn.BatchNorm2d(
                (28, 28, 3),
                track_running_stats=False,
                affine=True  # Requires track_running_stats=True
            )

    def test_groupnorm_both_params(self):
        """Test that GroupNorm raises error when both num_groups and group_size are specified."""
        with self.assertRaises(ValueError):
            net = brainstate.nn.GroupNorm(
                (32, 32, 16),
                num_groups=4,
                group_size=4,  # Can't specify both
            )

    def test_groupnorm_neither_param(self):
        """Test that GroupNorm raises error when neither num_groups nor group_size are specified."""
        with self.assertRaises(ValueError):
            net = brainstate.nn.GroupNorm(
                (32, 32, 16),
                num_groups=None,
                group_size=None,  # Must specify one
            )


class TestNormalizationConsistency(parameterized.TestCase):
    """Test consistency across different batch sizes and modes."""

    def test_batchnorm2d_consistency_across_batches(self):
        """Test that BatchNorm2d behaves consistently across different batch sizes."""
        in_size = (28, 28, 3)
        net = brainstate.nn.BatchNorm2d(in_size, track_running_stats=True)

        # Train on larger batch
        brainstate.environ.set(fit=True)
        x_large = brainstate.random.randn(32, 28, 28, 3)
        _ = net(x_large)

        # Test on smaller batch
        brainstate.environ.set(fit=False)
        x_small = brainstate.random.randn(4, 28, 28, 3)
        output = net(x_small)

        self.assertEqual(output.shape, x_small.shape)

    def test_layernorm_consistency(self):
        """Test that LayerNorm produces consistent results."""
        in_size = (10, 20)
        net = brainstate.nn.LayerNorm(in_size)

        x = brainstate.random.randn(8, 10, 20)

        # Run twice
        output1 = net(x)
        output2 = net(x)

        # Should be deterministic
        np.testing.assert_allclose(output1, output2)


if __name__ == '__main__':
    absltest.main()
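
Since the module guards absltest.main() behind __name__ == '__main__', it runs as a standalone script (python brainstate/nn/_normalizations_test.py). The test classes ultimately subclass unittest.TestCase, so a minimal sketch for running the packaged suite programmatically instead — assuming brainstate (which pulls in absl-py and JAX) is installed:

# Minimal sketch: load and run the packaged test module with plain unittest.
# "brainstate.nn._normalizations_test" is importable because the wheel's
# RECORD lists brainstate/nn/_normalizations_test.py.
import unittest

suite = unittest.defaultTestLoader.loadTestsFromName(
    "brainstate.nn._normalizations_test"
)
unittest.TextTestRunner(verbosity=2).run(suite)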