brainstate-0.2.0-py2.py3-none-any.whl → brainstate-0.2.1-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. brainstate/__init__.py +169 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2319 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +1652 -1652
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1624 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1433 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +137 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +633 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +154 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +477 -477
  32. brainstate/nn/_dynamics.py +1267 -1267
  33. brainstate/nn/_dynamics_test.py +67 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +384 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/_rand_funs.py +3938 -3938
  64. brainstate/random/_rand_funs_test.py +640 -640
  65. brainstate/random/_rand_seed.py +675 -675
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1617
  68. brainstate/random/_rand_state_test.py +551 -551
  69. brainstate/transform/__init__.py +59 -59
  70. brainstate/transform/_ad_checkpoint.py +176 -176
  71. brainstate/transform/_ad_checkpoint_test.py +49 -49
  72. brainstate/transform/_autograd.py +1025 -1025
  73. brainstate/transform/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -316
  75. brainstate/transform/_conditions_test.py +220 -220
  76. brainstate/transform/_error_if.py +94 -94
  77. brainstate/transform/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -145
  79. brainstate/transform/_eval_shape_test.py +38 -38
  80. brainstate/transform/_jit.py +399 -399
  81. brainstate/transform/_jit_test.py +143 -143
  82. brainstate/transform/_loop_collect_return.py +675 -675
  83. brainstate/transform/_loop_collect_return_test.py +58 -58
  84. brainstate/transform/_loop_no_collection.py +283 -283
  85. brainstate/transform/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -2016
  87. brainstate/transform/_make_jaxpr_test.py +1510 -1510
  88. brainstate/transform/_mapping.py +529 -529
  89. brainstate/transform/_mapping_test.py +194 -194
  90. brainstate/transform/_progress_bar.py +255 -255
  91. brainstate/transform/_random.py +171 -171
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -108
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate-0.2.0.dist-info/RECORD +0 -111
  111. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  112. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
brainstate/nn/_poolings_test.py
@@ -1,953 +1,953 @@
1
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
- import jax
19
- import jax.numpy as jnp
20
- import numpy as np
21
- from absl.testing import absltest
22
- from absl.testing import parameterized
23
-
24
- import brainstate
25
- import brainstate.nn as nn
26
-
27
-
28
- class TestFlatten(parameterized.TestCase):
29
- def test_flatten1(self):
30
- for size in [
31
- (16, 32, 32, 8),
32
- (32, 8),
33
- (10, 20, 30),
34
- ]:
35
- arr = brainstate.random.rand(*size)
36
- f = nn.Flatten(start_axis=0)
37
- out = f(arr)
38
- self.assertTrue(out.shape == (np.prod(size),))
39
-
40
- def test_flatten2(self):
41
- for size in [
42
- (16, 32, 32, 8),
43
- (32, 8),
44
- (10, 20, 30),
45
- ]:
46
- arr = brainstate.random.rand(*size)
47
- f = nn.Flatten(start_axis=1)
48
- out = f(arr)
49
- self.assertTrue(out.shape == (size[0], np.prod(size[1:])))
50
-
51
- def test_flatten3(self):
52
- size = (16, 32, 32, 8)
53
- arr = brainstate.random.rand(*size)
54
- f = nn.Flatten(start_axis=0, in_size=(32, 8))
55
- out = f(arr)
56
- self.assertTrue(out.shape == (16, 32, 32 * 8))
57
-
58
- def test_flatten4(self):
59
- size = (16, 32, 32, 8)
60
- arr = brainstate.random.rand(*size)
61
- f = nn.Flatten(start_axis=1, in_size=(32, 32, 8))
62
- out = f(arr)
63
- self.assertTrue(out.shape == (16, 32, 32 * 8))
64
-
65
-
66
- class TestUnflatten(parameterized.TestCase):
67
- """Comprehensive tests for Unflatten layer.
68
-
69
- Note: Due to a bug in u.math.unflatten with negative axis handling,
70
- we only test with positive axis values.
71
- """
72
-
73
- def test_unflatten_basic_2d(self):
74
- """Test basic Unflatten functionality for 2D tensors."""
75
- arr = brainstate.random.rand(6, 12)
76
-
77
- # Unflatten last dimension (use positive axis due to bug)
78
- unflatten = nn.Unflatten(axis=1, sizes=(3, 4))
79
- out = unflatten(arr)
80
- self.assertEqual(out.shape, (6, 3, 4))
81
-
82
- # Unflatten first dimension
83
- unflatten = nn.Unflatten(axis=0, sizes=(2, 3))
84
- out = unflatten(arr)
85
- self.assertEqual(out.shape, (2, 3, 12))
86
-
87
- def test_unflatten_basic_3d(self):
88
- """Test basic Unflatten functionality for 3D tensors."""
89
- arr = brainstate.random.rand(4, 6, 24)
90
-
91
- # Unflatten last dimension using positive index
92
- unflatten = nn.Unflatten(axis=2, sizes=(2, 3, 4))
93
- out = unflatten(arr)
94
- self.assertEqual(out.shape, (4, 6, 2, 3, 4))
95
-
96
- # Unflatten middle dimension
97
- unflatten = nn.Unflatten(axis=1, sizes=(2, 3))
98
- out = unflatten(arr)
99
- self.assertEqual(out.shape, (4, 2, 3, 24))
100
-
101
- def test_unflatten_with_in_size(self):
102
- """Test Unflatten with in_size parameter."""
103
- # Test with in_size specified
104
- unflatten = nn.Unflatten(axis=1, sizes=(2, 3), in_size=(4, 6))
105
-
106
- # Check that out_size is computed correctly
107
- self.assertIsNotNone(unflatten.out_size)
108
- self.assertEqual(unflatten.out_size, (4, 2, 3))
109
-
110
- # Apply to actual tensor
111
- arr = brainstate.random.rand(4, 6)
112
- out = unflatten(arr)
113
- self.assertEqual(out.shape, (4, 2, 3))
114
-
115
- def test_unflatten_preserve_batch_dims(self):
116
- """Test that Unflatten preserves batch dimensions."""
117
- # Multiple batch dimensions
118
- arr = brainstate.random.rand(2, 3, 4, 20)
119
-
120
- # Unflatten last dimension (use positive axis)
121
- unflatten = nn.Unflatten(axis=3, sizes=(4, 5))
122
- out = unflatten(arr)
123
- self.assertEqual(out.shape, (2, 3, 4, 4, 5))
124
-
125
- def test_unflatten_single_element_split(self):
126
- """Test Unflatten with sizes containing 1."""
127
- arr = brainstate.random.rand(3, 12)
128
-
129
- # Include dimension of size 1
130
- unflatten = nn.Unflatten(axis=1, sizes=(1, 3, 4))
131
- out = unflatten(arr)
132
- self.assertEqual(out.shape, (3, 1, 3, 4))
133
-
134
- # Multiple ones
135
- unflatten = nn.Unflatten(axis=1, sizes=(1, 1, 12))
136
- out = unflatten(arr)
137
- self.assertEqual(out.shape, (3, 1, 1, 12))
138
-
139
- def test_unflatten_large_split(self):
140
- """Test Unflatten with large number of dimensions."""
141
- arr = brainstate.random.rand(2, 120)
142
-
143
- # Split into many dimensions
144
- unflatten = nn.Unflatten(axis=1, sizes=(2, 3, 4, 5))
145
- out = unflatten(arr)
146
- self.assertEqual(out.shape, (2, 2, 3, 4, 5))
147
-
148
- # Verify total elements preserved
149
- self.assertEqual(arr.size, out.size)
150
- self.assertEqual(2 * 3 * 4 * 5, 120)
151
-
152
- def test_unflatten_flatten_inverse(self):
153
- """Test that Unflatten is inverse of Flatten."""
154
- original = brainstate.random.rand(2, 3, 4, 5)
155
-
156
- # Flatten dimensions 1 and 2
157
- flatten = nn.Flatten(start_axis=1, end_axis=2)
158
- flattened = flatten(original)
159
- self.assertEqual(flattened.shape, (2, 12, 5))
160
-
161
- # Unflatten back
162
- unflatten = nn.Unflatten(axis=1, sizes=(3, 4))
163
- restored = unflatten(flattened)
164
- self.assertEqual(restored.shape, original.shape)
165
-
166
- # Values should be identical
167
- self.assertTrue(jnp.allclose(original, restored))
168
-
169
- def test_unflatten_sequential_operations(self):
170
- """Test Unflatten in sequential operations."""
171
- arr = brainstate.random.rand(4, 24)
172
-
173
- # Apply multiple unflatten operations
174
- unflatten1 = nn.Unflatten(axis=1, sizes=(6, 4))
175
- intermediate = unflatten1(arr)
176
- self.assertEqual(intermediate.shape, (4, 6, 4))
177
-
178
- unflatten2 = nn.Unflatten(axis=1, sizes=(2, 3))
179
- final = unflatten2(intermediate)
180
- self.assertEqual(final.shape, (4, 2, 3, 4))
181
-
182
- def test_unflatten_error_cases(self):
183
- """Test error handling in Unflatten."""
184
- # Test invalid sizes type
185
- with self.assertRaises(TypeError):
186
- nn.Unflatten(axis=0, sizes=12) # sizes must be tuple or list
187
-
188
- with self.assertRaises(TypeError):
189
- nn.Unflatten(axis=0, sizes="invalid")
190
-
191
- # Test invalid element in sizes
192
- with self.assertRaises(TypeError):
193
- nn.Unflatten(axis=0, sizes=(2, "invalid"))
194
-
195
- with self.assertRaises(TypeError):
196
- nn.Unflatten(axis=0, sizes=(2.5, 3)) # must be integers
197
-
198
- @parameterized.named_parameters(
199
- ('axis_0_2d', 0, (10, 20), (2, 5)),
200
- ('axis_1_2d', 1, (10, 20), (4, 5)),
201
- ('axis_0_3d', 0, (6, 8, 10), (2, 3)),
202
- ('axis_1_3d', 1, (6, 8, 10), (2, 4)),
203
- ('axis_2_3d', 2, (6, 8, 10), (2, 5)),
204
- )
205
- def test_unflatten_parameterized(self, axis, input_shape, unflatten_sizes):
206
- """Parameterized test for various axis and shape combinations."""
207
- arr = brainstate.random.rand(*input_shape)
208
- unflatten = nn.Unflatten(axis=axis, sizes=unflatten_sizes)
209
- out = unflatten(arr)
210
-
211
- # Check that product of unflatten_sizes matches original dimension
212
- original_dim_size = input_shape[axis]
213
- self.assertEqual(np.prod(unflatten_sizes), original_dim_size)
214
-
215
- # Check output shape
216
- expected_shape = list(input_shape)
217
- expected_shape[axis:axis+1] = unflatten_sizes
218
- self.assertEqual(out.shape, tuple(expected_shape))
219
-
220
- # Check total size preserved
221
- self.assertEqual(arr.size, out.size)
222
-
223
- def test_unflatten_values_preserved(self):
224
- """Test that values are correctly preserved during unflatten."""
225
- # Create a tensor with known pattern
226
- arr = jnp.arange(24).reshape(2, 12)
227
-
228
- unflatten = nn.Unflatten(axis=1, sizes=(3, 4))
229
- out = unflatten(arr)
230
-
231
- # Check shape
232
- self.assertEqual(out.shape, (2, 3, 4))
233
-
234
- # Check that values are correctly rearranged
235
- # First batch
236
- self.assertTrue(jnp.allclose(out[0, 0, :], jnp.arange(0, 4)))
237
- self.assertTrue(jnp.allclose(out[0, 1, :], jnp.arange(4, 8)))
238
- self.assertTrue(jnp.allclose(out[0, 2, :], jnp.arange(8, 12)))
239
-
240
- # Second batch
241
- self.assertTrue(jnp.allclose(out[1, 0, :], jnp.arange(12, 16)))
242
- self.assertTrue(jnp.allclose(out[1, 1, :], jnp.arange(16, 20)))
243
- self.assertTrue(jnp.allclose(out[1, 2, :], jnp.arange(20, 24)))
244
-
245
- def test_unflatten_with_complex_shapes(self):
246
- """Test Unflatten with complex multi-dimensional shapes."""
247
- # 5D tensor
248
- arr = brainstate.random.rand(2, 3, 4, 5, 60)
249
-
250
- # Unflatten last dimension (use positive axis)
251
- unflatten = nn.Unflatten(axis=4, sizes=(3, 4, 5))
252
- out = unflatten(arr)
253
- self.assertEqual(out.shape, (2, 3, 4, 5, 3, 4, 5))
254
-
255
- # Unflatten middle dimension
256
- arr = brainstate.random.rand(2, 3, 12, 5, 6)
257
- unflatten = nn.Unflatten(axis=2, sizes=(3, 4))
258
- out = unflatten(arr)
259
- self.assertEqual(out.shape, (2, 3, 3, 4, 5, 6))
260
-
261
- def test_unflatten_edge_cases(self):
262
- """Test edge cases for Unflatten."""
263
- # Single element tensor
264
- arr = brainstate.random.rand(1)
265
- unflatten = nn.Unflatten(axis=0, sizes=(1,))
266
- out = unflatten(arr)
267
- self.assertEqual(out.shape, (1,))
268
-
269
- # Unflatten to same dimension (essentially no-op)
270
- arr = brainstate.random.rand(3, 5)
271
- unflatten = nn.Unflatten(axis=1, sizes=(5,))
272
- out = unflatten(arr)
273
- self.assertEqual(out.shape, (3, 5))
274
-
275
- # Very large unflatten
276
- arr = brainstate.random.rand(2, 1024)
277
- unflatten = nn.Unflatten(axis=1, sizes=(4, 4, 4, 4, 4))
278
- out = unflatten(arr)
279
- self.assertEqual(out.shape, (2, 4, 4, 4, 4, 4))
280
- self.assertEqual(4**5, 1024)
281
-
282
- def test_unflatten_jit_compatibility(self):
283
- """Test that Unflatten works with JAX JIT compilation."""
284
- arr = brainstate.random.rand(4, 12)
285
- unflatten = nn.Unflatten(axis=1, sizes=(3, 4))
286
-
287
- # JIT compile the unflatten operation
288
- jitted_unflatten = jax.jit(unflatten.update)
289
-
290
- # Compare results
291
- out_normal = unflatten(arr)
292
- out_jitted = jitted_unflatten(arr)
293
-
294
- self.assertEqual(out_normal.shape, (4, 3, 4))
295
- self.assertEqual(out_jitted.shape, (4, 3, 4))
296
- self.assertTrue(jnp.allclose(out_normal, out_jitted))
297
-
298
-
299
- class TestMaxPool1d(parameterized.TestCase):
300
- """Comprehensive tests for MaxPool1d."""
301
-
302
- def test_maxpool1d_basic(self):
303
- """Test basic MaxPool1d functionality."""
304
- # Test with different input shapes
305
- arr = brainstate.random.rand(16, 32, 8) # (batch, length, channels)
306
-
307
- # Test with kernel_size=2, stride=2
308
- pool = nn.MaxPool1d(2, 2, channel_axis=-1)
309
- out = pool(arr)
310
- self.assertEqual(out.shape, (16, 16, 8))
311
-
312
- # Test with kernel_size=3, stride=1
313
- pool = nn.MaxPool1d(3, 1, channel_axis=-1)
314
- out = pool(arr)
315
- self.assertEqual(out.shape, (16, 30, 8))
316
-
317
- def test_maxpool1d_padding(self):
318
- """Test MaxPool1d with padding."""
319
- arr = brainstate.random.rand(4, 10, 3)
320
-
321
- # Test with padding
322
- pool = nn.MaxPool1d(3, 2, padding=1, channel_axis=-1)
323
- out = pool(arr)
324
- self.assertEqual(out.shape, (4, 5, 3))
325
-
326
- # Test with tuple padding (same value for both sides in 1D)
327
- pool = nn.MaxPool1d(3, 2, padding=(1,), channel_axis=-1)
328
- out = pool(arr)
329
- self.assertEqual(out.shape, (4, 5, 3))
330
-
331
- def test_maxpool1d_return_indices(self):
332
- """Test MaxPool1d with return_indices=True."""
333
- arr = brainstate.random.rand(2, 10, 3)
334
-
335
- pool = nn.MaxPool1d(2, 2, channel_axis=-1, return_indices=True)
336
- out, indices = pool(arr)
337
- self.assertEqual(out.shape, (2, 5, 3))
338
- self.assertEqual(indices.shape, (2, 5, 3))
339
-
340
- def test_maxpool1d_no_channel_axis(self):
341
- """Test MaxPool1d without channel axis."""
342
- arr = brainstate.random.rand(16, 32)
343
-
344
- pool = nn.MaxPool1d(2, 2, channel_axis=None)
345
- out = pool(arr)
346
- self.assertEqual(out.shape, (16, 16))
347
-
348
-
349
- class TestMaxPool2d(parameterized.TestCase):
350
- """Comprehensive tests for MaxPool2d."""
351
-
352
- def test_maxpool2d_basic(self):
353
- """Test basic MaxPool2d functionality."""
354
- arr = brainstate.random.rand(16, 32, 32, 8) # (batch, height, width, channels)
355
-
356
- out = nn.MaxPool2d(2, 2, channel_axis=-1)(arr)
357
- self.assertTrue(out.shape == (16, 16, 16, 8))
358
-
359
- out = nn.MaxPool2d(2, 2, channel_axis=None)(arr)
360
- self.assertTrue(out.shape == (16, 32, 16, 4))
361
-
362
- def test_maxpool2d_padding(self):
363
- """Test MaxPool2d with padding."""
364
- arr = brainstate.random.rand(16, 32, 32, 8)
365
-
366
- out = nn.MaxPool2d(2, 2, channel_axis=None, padding=1)(arr)
367
- self.assertTrue(out.shape == (16, 32, 17, 5))
368
-
369
- out = nn.MaxPool2d(2, 2, channel_axis=None, padding=(2, 1))(arr)
370
- self.assertTrue(out.shape == (16, 32, 18, 5))
371
-
372
- out = nn.MaxPool2d(2, 2, channel_axis=-1, padding=(1, 1))(arr)
373
- self.assertTrue(out.shape == (16, 17, 17, 8))
374
-
375
- out = nn.MaxPool2d(2, 2, channel_axis=2, padding=(1, 1))(arr)
376
- self.assertTrue(out.shape == (16, 17, 32, 5))
377
-
378
- def test_maxpool2d_return_indices(self):
379
- """Test MaxPool2d with return_indices=True."""
380
- arr = brainstate.random.rand(2, 8, 8, 3)
381
-
382
- pool = nn.MaxPool2d(2, 2, channel_axis=-1, return_indices=True)
383
- out, indices = pool(arr)
384
- self.assertEqual(out.shape, (2, 4, 4, 3))
385
- self.assertEqual(indices.shape, (2, 4, 4, 3))
386
-
387
- def test_maxpool2d_different_strides(self):
388
- """Test MaxPool2d with different stride values."""
389
- arr = brainstate.random.rand(2, 16, 16, 4)
390
-
391
- # Different strides for height and width
392
- pool = nn.MaxPool2d(3, stride=(2, 1), channel_axis=-1)
393
- out = pool(arr)
394
- self.assertEqual(out.shape, (2, 7, 14, 4))
395
-
396
-
397
- class TestMaxPool3d(parameterized.TestCase):
398
- """Comprehensive tests for MaxPool3d."""
399
-
400
- def test_maxpool3d_basic(self):
401
- """Test basic MaxPool3d functionality."""
402
- arr = brainstate.random.rand(2, 16, 16, 16, 4) # (batch, depth, height, width, channels)
403
-
404
- pool = nn.MaxPool3d(2, 2, channel_axis=-1)
405
- out = pool(arr)
406
- self.assertEqual(out.shape, (2, 8, 8, 8, 4))
407
-
408
- pool = nn.MaxPool3d(3, 1, channel_axis=-1)
409
- out = pool(arr)
410
- self.assertEqual(out.shape, (2, 14, 14, 14, 4))
411
-
412
- def test_maxpool3d_padding(self):
413
- """Test MaxPool3d with padding."""
414
- arr = brainstate.random.rand(1, 8, 8, 8, 2)
415
-
416
- pool = nn.MaxPool3d(3, 2, padding=1, channel_axis=-1)
417
- out = pool(arr)
418
- self.assertEqual(out.shape, (1, 4, 4, 4, 2))
419
-
420
- def test_maxpool3d_return_indices(self):
421
- """Test MaxPool3d with return_indices=True."""
422
- arr = brainstate.random.rand(1, 4, 4, 4, 2)
423
-
424
- pool = nn.MaxPool3d(2, 2, channel_axis=-1, return_indices=True)
425
- out, indices = pool(arr)
426
- self.assertEqual(out.shape, (1, 2, 2, 2, 2))
427
- self.assertEqual(indices.shape, (1, 2, 2, 2, 2))
428
-
429
-
430
- class TestAvgPool1d(parameterized.TestCase):
431
- """Comprehensive tests for AvgPool1d."""
432
-
433
- def test_avgpool1d_basic(self):
434
- """Test basic AvgPool1d functionality."""
435
- arr = brainstate.random.rand(4, 16, 8) # (batch, length, channels)
436
-
437
- pool = nn.AvgPool1d(2, 2, channel_axis=-1)
438
- out = pool(arr)
439
- self.assertEqual(out.shape, (4, 8, 8))
440
-
441
- # Test averaging values
442
- arr = jnp.ones((1, 4, 2))
443
- pool = nn.AvgPool1d(2, 2, channel_axis=-1)
444
- out = pool(arr)
445
- self.assertTrue(jnp.allclose(out, jnp.ones((1, 2, 2))))
446
-
447
- def test_avgpool1d_padding(self):
448
- """Test AvgPool1d with padding."""
449
- arr = brainstate.random.rand(2, 10, 3)
450
-
451
- pool = nn.AvgPool1d(3, 2, padding=1, channel_axis=-1)
452
- out = pool(arr)
453
- self.assertEqual(out.shape, (2, 5, 3))
454
-
455
- def test_avgpool1d_divisor_override(self):
456
- """Test AvgPool1d divisor behavior."""
457
- arr = jnp.ones((1, 4, 1))
458
-
459
- # Standard average pooling
460
- pool = nn.AvgPool1d(2, 2, channel_axis=-1)
461
- out = pool(arr)
462
-
463
- # All values should still be 1.0 for constant input
464
- self.assertTrue(jnp.allclose(out, jnp.ones((1, 2, 1))))
465
-
466
-
467
- class TestAvgPool2d(parameterized.TestCase):
468
- """Comprehensive tests for AvgPool2d."""
469
-
470
- def test_avgpool2d_basic(self):
471
- """Test basic AvgPool2d functionality."""
472
- arr = brainstate.random.rand(16, 32, 32, 8)
473
-
474
- out = nn.AvgPool2d(2, 2, channel_axis=-1)(arr)
475
- self.assertTrue(out.shape == (16, 16, 16, 8))
476
-
477
- out = nn.AvgPool2d(2, 2, channel_axis=None)(arr)
478
- self.assertTrue(out.shape == (16, 32, 16, 4))
479
-
480
- def test_avgpool2d_padding(self):
481
- """Test AvgPool2d with padding."""
482
- arr = brainstate.random.rand(16, 32, 32, 8)
483
-
484
- out = nn.AvgPool2d(2, 2, channel_axis=None, padding=1)(arr)
485
- self.assertTrue(out.shape == (16, 32, 17, 5))
486
-
487
- out = nn.AvgPool2d(2, 2, channel_axis=None, padding=(2, 1))(arr)
488
- self.assertTrue(out.shape == (16, 32, 18, 5))
489
-
490
- out = nn.AvgPool2d(2, 2, channel_axis=-1, padding=(1, 1))(arr)
491
- self.assertTrue(out.shape == (16, 17, 17, 8))
492
-
493
- out = nn.AvgPool2d(2, 2, channel_axis=2, padding=(1, 1))(arr)
494
- self.assertTrue(out.shape == (16, 17, 32, 5))
495
-
496
- def test_avgpool2d_values(self):
497
- """Test AvgPool2d computes correct average values."""
498
- arr = jnp.ones((1, 4, 4, 1))
499
- pool = nn.AvgPool2d(2, 2, channel_axis=-1)
500
- out = pool(arr)
501
- self.assertTrue(jnp.allclose(out, jnp.ones((1, 2, 2, 1))))
502
-
503
-
504
- class TestAvgPool3d(parameterized.TestCase):
505
- """Comprehensive tests for AvgPool3d."""
506
-
507
- def test_avgpool3d_basic(self):
508
- """Test basic AvgPool3d functionality."""
509
- arr = brainstate.random.rand(2, 8, 8, 8, 4)
510
-
511
- pool = nn.AvgPool3d(2, 2, channel_axis=-1)
512
- out = pool(arr)
513
- self.assertEqual(out.shape, (2, 4, 4, 4, 4))
514
-
515
- def test_avgpool3d_padding(self):
516
- """Test AvgPool3d with padding."""
517
- arr = brainstate.random.rand(1, 6, 6, 6, 2)
518
-
519
- pool = nn.AvgPool3d(3, 2, padding=1, channel_axis=-1)
520
- out = pool(arr)
521
- self.assertEqual(out.shape, (1, 3, 3, 3, 2))
522
-
523
-
524
- class TestMaxUnpool1d(parameterized.TestCase):
525
- """Comprehensive tests for MaxUnpool1d."""
526
-
527
- def test_maxunpool1d_basic(self):
528
- """Test basic MaxUnpool1d functionality."""
529
- # Create input
530
- arr = brainstate.random.rand(2, 8, 3)
531
-
532
- # Pool with indices
533
- pool = nn.MaxPool1d(2, 2, channel_axis=-1, return_indices=True)
534
- pooled, indices = pool(arr)
535
-
536
- # Unpool
537
- unpool = nn.MaxUnpool1d(2, 2, channel_axis=-1)
538
- unpooled = unpool(pooled, indices)
539
-
540
- # Shape should match original (or be close depending on padding)
541
- self.assertEqual(unpooled.shape, (2, 8, 3))
542
-
543
- def test_maxunpool1d_with_output_size(self):
544
- """Test MaxUnpool1d with explicit output_size."""
545
- arr = brainstate.random.rand(1, 10, 2)
546
-
547
- pool = nn.MaxPool1d(2, 2, channel_axis=-1, return_indices=True)
548
- pooled, indices = pool(arr)
549
-
550
- unpool = nn.MaxUnpool1d(2, 2, channel_axis=-1)
551
- unpooled = unpool(pooled, indices, output_size=(1, 10, 2))
552
-
553
- self.assertEqual(unpooled.shape, (1, 10, 2))
554
-
555
-
556
- class TestMaxUnpool2d(parameterized.TestCase):
557
- """Comprehensive tests for MaxUnpool2d."""
558
-
559
- def test_maxunpool2d_basic(self):
560
- """Test basic MaxUnpool2d functionality."""
561
- arr = brainstate.random.rand(2, 8, 8, 3)
562
-
563
- # Pool with indices
564
- pool = nn.MaxPool2d(2, 2, channel_axis=-1, return_indices=True)
565
- pooled, indices = pool(arr)
566
-
567
- # Unpool
568
- unpool = nn.MaxUnpool2d(2, 2, channel_axis=-1)
569
- unpooled = unpool(pooled, indices)
570
-
571
- self.assertEqual(unpooled.shape, (2, 8, 8, 3))
572
-
573
- def test_maxunpool2d_values(self):
574
- """Test MaxUnpool2d places values correctly."""
575
- # Create simple input where we can track values
576
- arr = jnp.array([[1., 2., 3., 4.],
577
- [5., 6., 7., 8.]]) # (2, 4)
578
- arr = arr.reshape(1, 2, 2, 2) # (1, 2, 2, 2)
579
-
580
- # Pool to get max value and its index
581
- pool = nn.MaxPool2d(2, 2, channel_axis=-1, return_indices=True)
582
- pooled, indices = pool(arr)
583
-
584
- # Unpool
585
- unpool = nn.MaxUnpool2d(2, 2, channel_axis=-1)
586
- unpooled = unpool(pooled, indices)
587
-
588
- # Check that max value (8.0) is preserved
589
- self.assertTrue(jnp.max(unpooled) == 8.0)
590
- # Check shape
591
- self.assertEqual(unpooled.shape, (1, 2, 2, 2))
592
-
593
-
594
- class TestMaxUnpool3d(parameterized.TestCase):
595
- """Comprehensive tests for MaxUnpool3d."""
596
-
597
- def test_maxunpool3d_basic(self):
598
- """Test basic MaxUnpool3d functionality."""
599
- arr = brainstate.random.rand(1, 4, 4, 4, 2)
600
-
601
- # Pool with indices
602
- pool = nn.MaxPool3d(2, 2, channel_axis=-1, return_indices=True)
603
- pooled, indices = pool(arr)
604
-
605
- # Unpool
606
- unpool = nn.MaxUnpool3d(2, 2, channel_axis=-1)
607
- unpooled = unpool(pooled, indices)
608
-
609
- self.assertEqual(unpooled.shape, (1, 4, 4, 4, 2))
610
-
611
-
612
- class TestLPPool1d(parameterized.TestCase):
613
- """Comprehensive tests for LPPool1d."""
614
-
615
- def test_lppool1d_basic(self):
616
- """Test basic LPPool1d functionality."""
617
- arr = brainstate.random.rand(2, 16, 4)
618
-
619
- # Test L2 pooling (norm_type=2)
620
- pool = nn.LPPool1d(norm_type=2, kernel_size=2, stride=2, channel_axis=-1)
621
- out = pool(arr)
622
- self.assertEqual(out.shape, (2, 8, 4))
623
-
624
- def test_lppool1d_different_norms(self):
625
- """Test LPPool1d with different norm types."""
626
- arr = brainstate.random.rand(1, 8, 2)
627
-
628
- # Test with p=1 (should be similar to average)
629
- pool1 = nn.LPPool1d(norm_type=1, kernel_size=2, stride=2, channel_axis=-1)
630
- out1 = pool1(arr)
631
-
632
- # Test with p=2 (L2 norm)
633
- pool2 = nn.LPPool1d(norm_type=2, kernel_size=2, stride=2, channel_axis=-1)
634
- out2 = pool2(arr)
635
-
636
- # Test with large p (should approach max pooling)
637
- pool_inf = nn.LPPool1d(norm_type=10, kernel_size=2, stride=2, channel_axis=-1)
638
- out_inf = pool_inf(arr)
639
-
640
- self.assertEqual(out1.shape, (1, 4, 2))
641
- self.assertEqual(out2.shape, (1, 4, 2))
642
- self.assertEqual(out_inf.shape, (1, 4, 2))
643
-
644
- def test_lppool1d_value_check(self):
645
- """Test LPPool1d computes correct values."""
646
- # Simple test case
647
- arr = jnp.array([[[2., 2.], [2., 2.]]]) # (1, 2, 2)
648
-
649
- pool = nn.LPPool1d(norm_type=2, kernel_size=2, stride=2, channel_axis=-1)
650
- out = pool(arr)
651
-
652
- # For constant values, Lp norm should equal the value
653
- self.assertTrue(jnp.allclose(out, 2.0, atol=1e-5))
654
-
655
-
656
- class TestLPPool2d(parameterized.TestCase):
657
- """Comprehensive tests for LPPool2d."""
658
-
659
- def test_lppool2d_basic(self):
660
- """Test basic LPPool2d functionality."""
661
- arr = brainstate.random.rand(2, 8, 8, 4)
662
-
663
- pool = nn.LPPool2d(norm_type=2, kernel_size=2, stride=2, channel_axis=-1)
664
- out = pool(arr)
665
- self.assertEqual(out.shape, (2, 4, 4, 4))
666
-
667
- def test_lppool2d_padding(self):
668
- """Test LPPool2d with padding."""
669
- arr = brainstate.random.rand(1, 7, 7, 2)
670
-
671
- pool = nn.LPPool2d(norm_type=2, kernel_size=3, stride=2, padding=1, channel_axis=-1)
672
- out = pool(arr)
673
- self.assertEqual(out.shape, (1, 4, 4, 2))
674
-
675
- def test_lppool2d_different_kernel_sizes(self):
676
- """Test LPPool2d with non-square kernels."""
677
- arr = brainstate.random.rand(1, 8, 6, 2)
678
-
679
- pool = nn.LPPool2d(norm_type=2, kernel_size=(3, 2), stride=(2, 1), channel_axis=-1)
680
- out = pool(arr)
681
- self.assertEqual(out.shape, (1, 3, 5, 2))
682
-
683
-
684
- class TestLPPool3d(parameterized.TestCase):
685
- """Comprehensive tests for LPPool3d."""
686
-
687
- def test_lppool3d_basic(self):
688
- """Test basic LPPool3d functionality."""
689
- arr = brainstate.random.rand(1, 8, 8, 8, 2)
690
-
691
- pool = nn.LPPool3d(norm_type=2, kernel_size=2, stride=2, channel_axis=-1)
692
- out = pool(arr)
693
- self.assertEqual(out.shape, (1, 4, 4, 4, 2))
694
-
695
- def test_lppool3d_different_norms(self):
696
- """Test LPPool3d with different norm types."""
697
- arr = brainstate.random.rand(1, 4, 4, 4, 1)
698
-
699
- # Different p values should give different results
700
- pool1 = nn.LPPool3d(norm_type=1, kernel_size=2, stride=2, channel_axis=-1)
701
- pool2 = nn.LPPool3d(norm_type=2, kernel_size=2, stride=2, channel_axis=-1)
702
- pool3 = nn.LPPool3d(norm_type=3, kernel_size=2, stride=2, channel_axis=-1)
703
-
704
- out1 = pool1(arr)
705
- out2 = pool2(arr)
706
- out3 = pool3(arr)
707
-
708
- # All should have same shape
709
- self.assertEqual(out1.shape, (1, 2, 2, 2, 1))
710
- self.assertEqual(out2.shape, (1, 2, 2, 2, 1))
711
- self.assertEqual(out3.shape, (1, 2, 2, 2, 1))
712
-
713
- # Values should be different (unless input is uniform)
714
- self.assertFalse(jnp.allclose(out1, out2))
715
- self.assertFalse(jnp.allclose(out2, out3))
716
-
717
-
718
- class TestAdaptivePool(parameterized.TestCase):
719
- """Tests for adaptive pooling layers."""
720
-
721
- @parameterized.named_parameters(
722
- dict(testcase_name=f'target_size={target_size}',
723
- target_size=target_size)
724
- for target_size in [10, 9, 8, 7, 6]
725
- )
726
- def test_adaptive_pool1d(self, target_size):
727
- """Test internal adaptive pooling function."""
728
- from brainstate.nn._poolings import _adaptive_pool1d
729
-
730
- arr = brainstate.random.rand(100)
731
- op = jax.numpy.mean
732
-
733
- out = _adaptive_pool1d(arr, target_size, op)
734
- self.assertTrue(out.shape == (target_size,))
735
-
736
- def test_adaptive_avg_pool1d(self):
737
- """Test AdaptiveAvgPool1d."""
738
- input = brainstate.random.randn(2, 32, 4)
739
-
740
- # Test with different target sizes
741
- pool = nn.AdaptiveAvgPool1d(5, channel_axis=-1)
742
- output = pool(input)
743
- self.assertEqual(output.shape, (2, 5, 4))
744
-
745
- # Test with single element input
746
- pool = nn.AdaptiveAvgPool1d(1, channel_axis=-1)
747
- output = pool(input)
748
- self.assertEqual(output.shape, (2, 1, 4))
749
-
750
- def test_adaptive_avg_pool2d(self):
751
- """Test AdaptiveAvgPool2d."""
752
- input = brainstate.random.randn(2, 8, 9, 3)
753
-
754
- # Square output
755
- output = nn.AdaptiveAvgPool2d(5, channel_axis=-1)(input)
756
- self.assertEqual(output.shape, (2, 5, 5, 3))
757
-
758
- # Non-square output
759
- output = nn.AdaptiveAvgPool2d((5, 7), channel_axis=-1)(input)
760
- self.assertEqual(output.shape, (2, 5, 7, 3))
761
-
762
- # Test with single integer (square output)
763
- output = nn.AdaptiveAvgPool2d(4, channel_axis=-1)(input)
764
- self.assertEqual(output.shape, (2, 4, 4, 3))
765
-
766
- def test_adaptive_avg_pool3d(self):
767
- """Test AdaptiveAvgPool3d."""
768
- input = brainstate.random.randn(1, 8, 6, 4, 2)
769
-
770
- pool = nn.AdaptiveAvgPool3d((4, 3, 2), channel_axis=-1)
771
- output = pool(input)
772
- self.assertEqual(output.shape, (1, 4, 3, 2, 2))
773
-
774
- # Cube output
775
- pool = nn.AdaptiveAvgPool3d(3, channel_axis=-1)
776
- output = pool(input)
777
- self.assertEqual(output.shape, (1, 3, 3, 3, 2))
778
-
779
- def test_adaptive_max_pool1d(self):
780
- """Test AdaptiveMaxPool1d."""
781
- input = brainstate.random.randn(2, 32, 4)
782
-
783
- pool = nn.AdaptiveMaxPool1d(8, channel_axis=-1)
784
- output = pool(input)
785
- self.assertEqual(output.shape, (2, 8, 4))
786
-
787
- def test_adaptive_max_pool2d(self):
788
- """Test AdaptiveMaxPool2d."""
789
- input = brainstate.random.randn(2, 10, 8, 3)
790
-
791
- pool = nn.AdaptiveMaxPool2d((5, 4), channel_axis=-1)
792
- output = pool(input)
793
- self.assertEqual(output.shape, (2, 5, 4, 3))
794
-
795
- def test_adaptive_max_pool3d(self):
796
- """Test AdaptiveMaxPool3d."""
797
- input = brainstate.random.randn(1, 8, 8, 8, 2)
798
-
799
- pool = nn.AdaptiveMaxPool3d((4, 4, 4), channel_axis=-1)
800
- output = pool(input)
801
- self.assertEqual(output.shape, (1, 4, 4, 4, 2))
802
-
803
-
804
- class TestPoolingEdgeCases(parameterized.TestCase):
805
- """Test edge cases and error conditions."""
806
-
807
- def test_pool_with_stride_none(self):
808
- """Test pooling with stride=None (defaults to kernel_size)."""
809
- arr = brainstate.random.rand(1, 8, 2)
810
-
811
- pool = nn.MaxPool1d(kernel_size=3, stride=None, channel_axis=-1)
812
- out = pool(arr)
813
- # stride defaults to kernel_size=3
814
- self.assertEqual(out.shape, (1, 2, 2))
815
-
816
- def test_pool_with_large_kernel(self):
817
- """Test pooling with kernel larger than input."""
818
- arr = brainstate.random.rand(1, 4, 2)
819
-
820
- # Kernel size larger than spatial dimension
821
- pool = nn.MaxPool1d(kernel_size=5, stride=1, channel_axis=-1)
822
- out = pool(arr)
823
- # Should handle gracefully (may produce empty output or handle with padding)
824
- self.assertTrue(out.shape[1] >= 0)
825
-
826
- def test_pool_single_element(self):
827
- """Test pooling on single-element tensors."""
828
- arr = brainstate.random.rand(1, 1, 1)
829
-
830
- pool = nn.AvgPool1d(1, 1, channel_axis=-1)
831
- out = pool(arr)
832
- self.assertEqual(out.shape, (1, 1, 1))
833
- self.assertTrue(jnp.allclose(out, arr))
834
-
835
- def test_adaptive_pool_smaller_output(self):
836
- """Test adaptive pooling with output smaller than input."""
837
- arr = brainstate.random.rand(1, 16, 2)
838
-
839
- # Adaptive pooling to smaller size
840
- pool = nn.AdaptiveAvgPool1d(4, channel_axis=-1)
841
- out = pool(arr)
842
- self.assertEqual(out.shape, (1, 4, 2))
843
-
844
- def test_unpool_without_indices(self):
845
- """Test unpooling behavior with placeholder indices."""
846
- pooled = brainstate.random.rand(1, 4, 2)
847
- indices = jnp.zeros_like(pooled, dtype=jnp.int32)
848
-
849
- unpool = nn.MaxUnpool1d(2, 2, channel_axis=-1)
850
- # Should not raise error even with zero indices
851
- unpooled = unpool(pooled, indices)
852
- self.assertEqual(unpooled.shape, (1, 8, 2))
853
-
854
- def test_lppool_extreme_norm(self):
855
- """Test LPPool with extreme norm values."""
856
- arr = brainstate.random.rand(1, 8, 2) + 0.1 # Avoid zeros
857
-
858
- # Very large p (approaches max pooling)
859
- pool_large = nn.LPPool1d(norm_type=20, kernel_size=2, stride=2, channel_axis=-1)
860
- out_large = pool_large(arr)
861
-
862
- # Compare with actual max pooling
863
- pool_max = nn.MaxPool1d(2, 2, channel_axis=-1)
864
- out_max = pool_max(arr)
865
-
866
- # Should approach max pooling for large p (but not exactly equal)
867
- # Just check shapes match
868
- self.assertEqual(out_large.shape, out_max.shape)
869
-
870
- def test_pool_with_channels_first(self):
871
- """Test pooling with channels in different positions."""
872
- arr = brainstate.random.rand(3, 16, 8) # (dim0, dim1, dim2)
873
-
874
- # Channel axis at position 0 - treats dim 0 as channels, pools last dimension
875
- pool = nn.MaxPool1d(2, 2, channel_axis=0)
876
- out = pool(arr)
877
- # Pools the last dimension, keeping first two
878
- self.assertEqual(out.shape, (3, 16, 4))
879
-
880
- # Channel axis at position -1 (last) - pools middle dimension
881
- pool = nn.MaxPool1d(2, 2, channel_axis=-1)
882
- out = pool(arr)
883
- # Pools the middle dimension, keeping first and last
884
- self.assertEqual(out.shape, (3, 8, 8))
885
-
886
- # No channel axis - pools last dimension, treating earlier dims as batch
887
- pool = nn.MaxPool1d(2, 2, channel_axis=None)
888
- out = pool(arr)
889
- # Pools the last dimension
890
- self.assertEqual(out.shape, (3, 16, 4))
891
-
892
-
893
- class TestPoolingMathematicalProperties(parameterized.TestCase):
894
- """Test mathematical properties of pooling operations."""
895
-
896
- def test_maxpool_idempotence(self):
897
- """Test that max pooling with kernel_size=1 is identity."""
898
- arr = brainstate.random.rand(2, 8, 3)
899
-
900
- pool = nn.MaxPool1d(1, 1, channel_axis=-1)
901
- out = pool(arr)
902
-
903
- self.assertTrue(jnp.allclose(out, arr))
904
-
905
- def test_avgpool_constant_input(self):
906
- """Test average pooling on constant input."""
907
- arr = jnp.ones((1, 8, 2)) * 5.0
908
-
909
- pool = nn.AvgPool1d(2, 2, channel_axis=-1)
910
- out = pool(arr)
911
-
912
- # Average of constant should be the constant
913
- self.assertTrue(jnp.allclose(out, 5.0))
914
-
915
- def test_lppool_norm_properties(self):
916
- """Test Lp pooling norm properties."""
917
- arr = brainstate.random.rand(1, 4, 1) + 0.1
918
-
919
- # L1 norm (p=1) should give average of absolute values
920
- pool_l1 = nn.LPPool1d(norm_type=1, kernel_size=4, stride=4, channel_axis=-1)
921
- out_l1 = pool_l1(arr)
922
-
923
- # Manual calculation
924
- manual_l1 = jnp.mean(jnp.abs(arr[:, :4, :]))
925
-
926
- self.assertTrue(jnp.allclose(out_l1[0, 0, 0], manual_l1, rtol=1e-5))
927
-
928
- def test_maxpool_monotonicity(self):
929
- """Test that max pooling preserves monotonicity."""
930
- arr1 = brainstate.random.rand(1, 8, 2)
931
- arr2 = arr1 + 1.0 # Strictly greater
932
-
933
- pool = nn.MaxPool1d(2, 2, channel_axis=-1)
934
- out1 = pool(arr1)
935
- out2 = pool(arr2)
936
-
937
- # out2 should be strictly greater than out1
938
- self.assertTrue(jnp.all(out2 > out1))
939
-
940
- def test_adaptive_pool_preserves_values(self):
941
- """Test that adaptive pooling with same size preserves values."""
942
- arr = brainstate.random.rand(1, 8, 2)
943
-
944
- # Adaptive pool to same size
945
- pool = nn.AdaptiveAvgPool1d(8, channel_axis=-1)
946
- out = pool(arr)
947
-
948
- # Should be approximately equal (might have small numerical differences)
949
- self.assertTrue(jnp.allclose(out, arr, rtol=1e-5))
950
-
951
-
952
- if __name__ == '__main__':
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+ import jax
19
+ import jax.numpy as jnp
20
+ import numpy as np
21
+ from absl.testing import absltest
22
+ from absl.testing import parameterized
23
+
24
+ import brainstate
25
+ import brainstate.nn as nn
26
+
27
+
28
+ class TestFlatten(parameterized.TestCase):
29
+ def test_flatten1(self):
30
+ for size in [
31
+ (16, 32, 32, 8),
32
+ (32, 8),
33
+ (10, 20, 30),
34
+ ]:
35
+ arr = brainstate.random.rand(*size)
36
+ f = nn.Flatten(start_axis=0)
37
+ out = f(arr)
38
+ self.assertTrue(out.shape == (np.prod(size),))
39
+
40
+ def test_flatten2(self):
41
+ for size in [
42
+ (16, 32, 32, 8),
43
+ (32, 8),
44
+ (10, 20, 30),
45
+ ]:
46
+ arr = brainstate.random.rand(*size)
47
+ f = nn.Flatten(start_axis=1)
48
+ out = f(arr)
49
+ self.assertTrue(out.shape == (size[0], np.prod(size[1:])))
50
+
51
+ def test_flatten3(self):
52
+ size = (16, 32, 32, 8)
53
+ arr = brainstate.random.rand(*size)
54
+ f = nn.Flatten(start_axis=0, in_size=(32, 8))
55
+ out = f(arr)
56
+ self.assertTrue(out.shape == (16, 32, 32 * 8))
57
+
58
+ def test_flatten4(self):
59
+ size = (16, 32, 32, 8)
60
+ arr = brainstate.random.rand(*size)
61
+ f = nn.Flatten(start_axis=1, in_size=(32, 32, 8))
62
+ out = f(arr)
63
+ self.assertTrue(out.shape == (16, 32, 32 * 8))
64
+
65
+
66
+ class TestUnflatten(parameterized.TestCase):
67
+ """Comprehensive tests for Unflatten layer.
68
+
69
+ Note: Due to a bug in u.math.unflatten with negative axis handling,
70
+ we only test with positive axis values.
71
+ """
72
+
73
+ def test_unflatten_basic_2d(self):
74
+ """Test basic Unflatten functionality for 2D tensors."""
75
+ arr = brainstate.random.rand(6, 12)
76
+
77
+ # Unflatten last dimension (use positive axis due to bug)
78
+ unflatten = nn.Unflatten(axis=1, sizes=(3, 4))
79
+ out = unflatten(arr)
80
+ self.assertEqual(out.shape, (6, 3, 4))
81
+
82
+ # Unflatten first dimension
83
+ unflatten = nn.Unflatten(axis=0, sizes=(2, 3))
84
+ out = unflatten(arr)
85
+ self.assertEqual(out.shape, (2, 3, 12))
86
+
87
+ def test_unflatten_basic_3d(self):
88
+ """Test basic Unflatten functionality for 3D tensors."""
89
+ arr = brainstate.random.rand(4, 6, 24)
90
+
91
+ # Unflatten last dimension using positive index
92
+ unflatten = nn.Unflatten(axis=2, sizes=(2, 3, 4))
93
+ out = unflatten(arr)
94
+ self.assertEqual(out.shape, (4, 6, 2, 3, 4))
95
+
96
+ # Unflatten middle dimension
97
+ unflatten = nn.Unflatten(axis=1, sizes=(2, 3))
98
+ out = unflatten(arr)
99
+ self.assertEqual(out.shape, (4, 2, 3, 24))
100
+
101
+ def test_unflatten_with_in_size(self):
102
+ """Test Unflatten with in_size parameter."""
103
+ # Test with in_size specified
104
+ unflatten = nn.Unflatten(axis=1, sizes=(2, 3), in_size=(4, 6))
105
+
106
+ # Check that out_size is computed correctly
107
+ self.assertIsNotNone(unflatten.out_size)
108
+ self.assertEqual(unflatten.out_size, (4, 2, 3))
109
+
110
+ # Apply to actual tensor
111
+ arr = brainstate.random.rand(4, 6)
112
+ out = unflatten(arr)
113
+ self.assertEqual(out.shape, (4, 2, 3))
114
+
115
+ def test_unflatten_preserve_batch_dims(self):
116
+ """Test that Unflatten preserves batch dimensions."""
117
+ # Multiple batch dimensions
118
+ arr = brainstate.random.rand(2, 3, 4, 20)
119
+
120
+ # Unflatten last dimension (use positive axis)
121
+ unflatten = nn.Unflatten(axis=3, sizes=(4, 5))
122
+ out = unflatten(arr)
123
+ self.assertEqual(out.shape, (2, 3, 4, 4, 5))
124
+
125
+ def test_unflatten_single_element_split(self):
126
+ """Test Unflatten with sizes containing 1."""
127
+ arr = brainstate.random.rand(3, 12)
128
+
129
+ # Include dimension of size 1
130
+ unflatten = nn.Unflatten(axis=1, sizes=(1, 3, 4))
131
+ out = unflatten(arr)
132
+ self.assertEqual(out.shape, (3, 1, 3, 4))
133
+
134
+ # Multiple ones
135
+ unflatten = nn.Unflatten(axis=1, sizes=(1, 1, 12))
136
+ out = unflatten(arr)
137
+ self.assertEqual(out.shape, (3, 1, 1, 12))
138
+
139
+ def test_unflatten_large_split(self):
140
+ """Test Unflatten with large number of dimensions."""
141
+ arr = brainstate.random.rand(2, 120)
142
+
143
+ # Split into many dimensions
144
+ unflatten = nn.Unflatten(axis=1, sizes=(2, 3, 4, 5))
145
+ out = unflatten(arr)
146
+ self.assertEqual(out.shape, (2, 2, 3, 4, 5))
147
+
148
+ # Verify total elements preserved
149
+ self.assertEqual(arr.size, out.size)
150
+ self.assertEqual(2 * 3 * 4 * 5, 120)
151
+
152
+ def test_unflatten_flatten_inverse(self):
153
+ """Test that Unflatten is inverse of Flatten."""
154
+ original = brainstate.random.rand(2, 3, 4, 5)
155
+
156
+ # Flatten dimensions 1 and 2
157
+ flatten = nn.Flatten(start_axis=1, end_axis=2)
158
+ flattened = flatten(original)
159
+ self.assertEqual(flattened.shape, (2, 12, 5))
160
+
161
+ # Unflatten back
162
+ unflatten = nn.Unflatten(axis=1, sizes=(3, 4))
163
+ restored = unflatten(flattened)
164
+ self.assertEqual(restored.shape, original.shape)
165
+
166
+ # Values should be identical
167
+ self.assertTrue(jnp.allclose(original, restored))
168
+
169
+ def test_unflatten_sequential_operations(self):
170
+ """Test Unflatten in sequential operations."""
171
+ arr = brainstate.random.rand(4, 24)
172
+
173
+ # Apply multiple unflatten operations
174
+ unflatten1 = nn.Unflatten(axis=1, sizes=(6, 4))
175
+ intermediate = unflatten1(arr)
176
+ self.assertEqual(intermediate.shape, (4, 6, 4))
177
+
178
+ unflatten2 = nn.Unflatten(axis=1, sizes=(2, 3))
179
+ final = unflatten2(intermediate)
180
+ self.assertEqual(final.shape, (4, 2, 3, 4))
181
+
182
+ def test_unflatten_error_cases(self):
183
+ """Test error handling in Unflatten."""
184
+ # Test invalid sizes type
185
+ with self.assertRaises(TypeError):
186
+ nn.Unflatten(axis=0, sizes=12) # sizes must be tuple or list
187
+
188
+ with self.assertRaises(TypeError):
189
+ nn.Unflatten(axis=0, sizes="invalid")
190
+
191
+ # Test invalid element in sizes
192
+ with self.assertRaises(TypeError):
193
+ nn.Unflatten(axis=0, sizes=(2, "invalid"))
194
+
195
+ with self.assertRaises(TypeError):
196
+ nn.Unflatten(axis=0, sizes=(2.5, 3)) # must be integers
197
+
198
+ @parameterized.named_parameters(
199
+ ('axis_0_2d', 0, (10, 20), (2, 5)),
200
+ ('axis_1_2d', 1, (10, 20), (4, 5)),
201
+ ('axis_0_3d', 0, (6, 8, 10), (2, 3)),
202
+ ('axis_1_3d', 1, (6, 8, 10), (2, 4)),
203
+ ('axis_2_3d', 2, (6, 8, 10), (2, 5)),
204
+ )
205
+ def test_unflatten_parameterized(self, axis, input_shape, unflatten_sizes):
206
+ """Parameterized test for various axis and shape combinations."""
207
+ arr = brainstate.random.rand(*input_shape)
208
+ unflatten = nn.Unflatten(axis=axis, sizes=unflatten_sizes)
209
+ out = unflatten(arr)
210
+
211
+ # Check that product of unflatten_sizes matches original dimension
212
+ original_dim_size = input_shape[axis]
213
+ self.assertEqual(np.prod(unflatten_sizes), original_dim_size)
214
+
215
+ # Check output shape
216
+ expected_shape = list(input_shape)
217
+ expected_shape[axis:axis+1] = unflatten_sizes
218
+ self.assertEqual(out.shape, tuple(expected_shape))
219
+
220
+ # Check total size preserved
221
+ self.assertEqual(arr.size, out.size)
222
+
223
+ def test_unflatten_values_preserved(self):
224
+ """Test that values are correctly preserved during unflatten."""
225
+ # Create a tensor with known pattern
226
+ arr = jnp.arange(24).reshape(2, 12)
227
+
228
+ unflatten = nn.Unflatten(axis=1, sizes=(3, 4))
229
+ out = unflatten(arr)
230
+
231
+ # Check shape
232
+ self.assertEqual(out.shape, (2, 3, 4))
233
+
234
+ # Check that values are correctly rearranged
235
+ # First batch
236
+ self.assertTrue(jnp.allclose(out[0, 0, :], jnp.arange(0, 4)))
237
+ self.assertTrue(jnp.allclose(out[0, 1, :], jnp.arange(4, 8)))
238
+ self.assertTrue(jnp.allclose(out[0, 2, :], jnp.arange(8, 12)))
239
+
240
+ # Second batch
241
+ self.assertTrue(jnp.allclose(out[1, 0, :], jnp.arange(12, 16)))
242
+ self.assertTrue(jnp.allclose(out[1, 1, :], jnp.arange(16, 20)))
243
+ self.assertTrue(jnp.allclose(out[1, 2, :], jnp.arange(20, 24)))
244
+
245
+ def test_unflatten_with_complex_shapes(self):
246
+ """Test Unflatten with complex multi-dimensional shapes."""
247
+ # 5D tensor
248
+ arr = brainstate.random.rand(2, 3, 4, 5, 60)
249
+
250
+ # Unflatten last dimension (use positive axis)
251
+ unflatten = nn.Unflatten(axis=4, sizes=(3, 4, 5))
252
+ out = unflatten(arr)
253
+ self.assertEqual(out.shape, (2, 3, 4, 5, 3, 4, 5))
254
+
255
+ # Unflatten middle dimension
256
+ arr = brainstate.random.rand(2, 3, 12, 5, 6)
257
+ unflatten = nn.Unflatten(axis=2, sizes=(3, 4))
258
+ out = unflatten(arr)
259
+ self.assertEqual(out.shape, (2, 3, 3, 4, 5, 6))
260
+
261
+ def test_unflatten_edge_cases(self):
262
+ """Test edge cases for Unflatten."""
263
+ # Single element tensor
264
+ arr = brainstate.random.rand(1)
265
+ unflatten = nn.Unflatten(axis=0, sizes=(1,))
266
+ out = unflatten(arr)
267
+ self.assertEqual(out.shape, (1,))
268
+
269
+ # Unflatten to same dimension (essentially no-op)
270
+ arr = brainstate.random.rand(3, 5)
271
+ unflatten = nn.Unflatten(axis=1, sizes=(5,))
272
+ out = unflatten(arr)
273
+ self.assertEqual(out.shape, (3, 5))
274
+
275
+ # Very large unflatten
276
+ arr = brainstate.random.rand(2, 1024)
277
+ unflatten = nn.Unflatten(axis=1, sizes=(4, 4, 4, 4, 4))
278
+ out = unflatten(arr)
279
+ self.assertEqual(out.shape, (2, 4, 4, 4, 4, 4))
280
+ self.assertEqual(4**5, 1024)
281
+
282
+ def test_unflatten_jit_compatibility(self):
283
+ """Test that Unflatten works with JAX JIT compilation."""
284
+ arr = brainstate.random.rand(4, 12)
285
+ unflatten = nn.Unflatten(axis=1, sizes=(3, 4))
286
+
287
+ # JIT compile the unflatten operation
288
+ jitted_unflatten = jax.jit(unflatten.update)
289
+
290
+ # Compare results
291
+ out_normal = unflatten(arr)
292
+ out_jitted = jitted_unflatten(arr)
293
+
294
+ self.assertEqual(out_normal.shape, (4, 3, 4))
295
+ self.assertEqual(out_jitted.shape, (4, 3, 4))
296
+ self.assertTrue(jnp.allclose(out_normal, out_jitted))
297
+
298
+
299
+ class TestMaxPool1d(parameterized.TestCase):
300
+ """Comprehensive tests for MaxPool1d."""
301
+
302
+ def test_maxpool1d_basic(self):
303
+ """Test basic MaxPool1d functionality."""
304
+ # Test with different input shapes
305
+ arr = brainstate.random.rand(16, 32, 8) # (batch, length, channels)
306
+
307
+ # Test with kernel_size=2, stride=2
308
+ pool = nn.MaxPool1d(2, 2, channel_axis=-1)
309
+ out = pool(arr)
310
+ self.assertEqual(out.shape, (16, 16, 8))
311
+
312
+ # Test with kernel_size=3, stride=1
313
+ pool = nn.MaxPool1d(3, 1, channel_axis=-1)
314
+ out = pool(arr)
315
+ self.assertEqual(out.shape, (16, 30, 8))
316
+
317
+ def test_maxpool1d_padding(self):
318
+ """Test MaxPool1d with padding."""
319
+ arr = brainstate.random.rand(4, 10, 3)
320
+
321
+ # Test with padding
322
+ pool = nn.MaxPool1d(3, 2, padding=1, channel_axis=-1)
323
+ out = pool(arr)
324
+ self.assertEqual(out.shape, (4, 5, 3))
325
+
326
+ # Test with tuple padding (same value for both sides in 1D)
327
+ pool = nn.MaxPool1d(3, 2, padding=(1,), channel_axis=-1)
328
+ out = pool(arr)
329
+ self.assertEqual(out.shape, (4, 5, 3))
330
+
331
+ def test_maxpool1d_return_indices(self):
332
+ """Test MaxPool1d with return_indices=True."""
333
+ arr = brainstate.random.rand(2, 10, 3)
334
+
335
+ pool = nn.MaxPool1d(2, 2, channel_axis=-1, return_indices=True)
336
+ out, indices = pool(arr)
337
+ self.assertEqual(out.shape, (2, 5, 3))
338
+ self.assertEqual(indices.shape, (2, 5, 3))
339
+
340
+ def test_maxpool1d_no_channel_axis(self):
341
+ """Test MaxPool1d without channel axis."""
342
+ arr = brainstate.random.rand(16, 32)
343
+
344
+ pool = nn.MaxPool1d(2, 2, channel_axis=None)
345
+ out = pool(arr)
346
+ self.assertEqual(out.shape, (16, 16))
347
+
348
+
349
+ class TestMaxPool2d(parameterized.TestCase):
350
+ """Comprehensive tests for MaxPool2d."""
351
+
352
+ def test_maxpool2d_basic(self):
353
+ """Test basic MaxPool2d functionality."""
354
+ arr = brainstate.random.rand(16, 32, 32, 8) # (batch, height, width, channels)
355
+
356
+ out = nn.MaxPool2d(2, 2, channel_axis=-1)(arr)
357
+ self.assertTrue(out.shape == (16, 16, 16, 8))
358
+
359
+ out = nn.MaxPool2d(2, 2, channel_axis=None)(arr)
360
+ self.assertTrue(out.shape == (16, 32, 16, 4))
361
+
362
+ def test_maxpool2d_padding(self):
363
+ """Test MaxPool2d with padding."""
364
+ arr = brainstate.random.rand(16, 32, 32, 8)
365
+
366
+ out = nn.MaxPool2d(2, 2, channel_axis=None, padding=1)(arr)
367
+ self.assertTrue(out.shape == (16, 32, 17, 5))
368
+
369
+ out = nn.MaxPool2d(2, 2, channel_axis=None, padding=(2, 1))(arr)
370
+ self.assertTrue(out.shape == (16, 32, 18, 5))
371
+
372
+ out = nn.MaxPool2d(2, 2, channel_axis=-1, padding=(1, 1))(arr)
373
+ self.assertTrue(out.shape == (16, 17, 17, 8))
374
+
375
+ out = nn.MaxPool2d(2, 2, channel_axis=2, padding=(1, 1))(arr)
376
+ self.assertTrue(out.shape == (16, 17, 32, 5))
377
+
378
+ def test_maxpool2d_return_indices(self):
379
+ """Test MaxPool2d with return_indices=True."""
380
+ arr = brainstate.random.rand(2, 8, 8, 3)
381
+
382
+ pool = nn.MaxPool2d(2, 2, channel_axis=-1, return_indices=True)
383
+ out, indices = pool(arr)
384
+ self.assertEqual(out.shape, (2, 4, 4, 3))
385
+ self.assertEqual(indices.shape, (2, 4, 4, 3))
386
+
387
+ def test_maxpool2d_different_strides(self):
388
+ """Test MaxPool2d with different stride values."""
389
+ arr = brainstate.random.rand(2, 16, 16, 4)
390
+
391
+ # Different strides for height and width
392
+ pool = nn.MaxPool2d(3, stride=(2, 1), channel_axis=-1)
393
+ out = pool(arr)
394
+ self.assertEqual(out.shape, (2, 7, 14, 4))
395
+
396
+
397
+ class TestMaxPool3d(parameterized.TestCase):
398
+ """Comprehensive tests for MaxPool3d."""
399
+
400
+ def test_maxpool3d_basic(self):
401
+ """Test basic MaxPool3d functionality."""
402
+ arr = brainstate.random.rand(2, 16, 16, 16, 4) # (batch, depth, height, width, channels)
403
+
404
+ pool = nn.MaxPool3d(2, 2, channel_axis=-1)
405
+ out = pool(arr)
406
+ self.assertEqual(out.shape, (2, 8, 8, 8, 4))
407
+
408
+ pool = nn.MaxPool3d(3, 1, channel_axis=-1)
409
+ out = pool(arr)
410
+ self.assertEqual(out.shape, (2, 14, 14, 14, 4))
411
+
412
+ def test_maxpool3d_padding(self):
413
+ """Test MaxPool3d with padding."""
414
+ arr = brainstate.random.rand(1, 8, 8, 8, 2)
415
+
416
+ pool = nn.MaxPool3d(3, 2, padding=1, channel_axis=-1)
417
+ out = pool(arr)
418
+ self.assertEqual(out.shape, (1, 4, 4, 4, 2))
419
+
420
+ def test_maxpool3d_return_indices(self):
421
+ """Test MaxPool3d with return_indices=True."""
422
+ arr = brainstate.random.rand(1, 4, 4, 4, 2)
423
+
424
+ pool = nn.MaxPool3d(2, 2, channel_axis=-1, return_indices=True)
425
+ out, indices = pool(arr)
426
+ self.assertEqual(out.shape, (1, 2, 2, 2, 2))
427
+ self.assertEqual(indices.shape, (1, 2, 2, 2, 2))
428
+
429
+
430
+ class TestAvgPool1d(parameterized.TestCase):
431
+ """Comprehensive tests for AvgPool1d."""
432
+
433
+ def test_avgpool1d_basic(self):
434
+ """Test basic AvgPool1d functionality."""
435
+ arr = brainstate.random.rand(4, 16, 8) # (batch, length, channels)
436
+
437
+ pool = nn.AvgPool1d(2, 2, channel_axis=-1)
438
+ out = pool(arr)
439
+ self.assertEqual(out.shape, (4, 8, 8))
440
+
441
+ # Test averaging values
442
+ arr = jnp.ones((1, 4, 2))
443
+ pool = nn.AvgPool1d(2, 2, channel_axis=-1)
444
+ out = pool(arr)
445
+ self.assertTrue(jnp.allclose(out, jnp.ones((1, 2, 2))))
446
+
447
+ def test_avgpool1d_padding(self):
448
+ """Test AvgPool1d with padding."""
449
+ arr = brainstate.random.rand(2, 10, 3)
450
+
451
+ pool = nn.AvgPool1d(3, 2, padding=1, channel_axis=-1)
452
+ out = pool(arr)
453
+ self.assertEqual(out.shape, (2, 5, 3))
454
+
455
+ def test_avgpool1d_divisor_override(self):
456
+ """Test AvgPool1d divisor behavior."""
457
+ arr = jnp.ones((1, 4, 1))
458
+
459
+ # Standard average pooling
460
+ pool = nn.AvgPool1d(2, 2, channel_axis=-1)
461
+ out = pool(arr)
462
+
463
+ # All values should still be 1.0 for constant input
464
+ self.assertTrue(jnp.allclose(out, jnp.ones((1, 2, 1))))
465
+
466
+
467
+ class TestAvgPool2d(parameterized.TestCase):
468
+ """Comprehensive tests for AvgPool2d."""
469
+
470
+ def test_avgpool2d_basic(self):
471
+ """Test basic AvgPool2d functionality."""
472
+ arr = brainstate.random.rand(16, 32, 32, 8)
473
+
474
+ out = nn.AvgPool2d(2, 2, channel_axis=-1)(arr)
475
+ self.assertTrue(out.shape == (16, 16, 16, 8))
476
+
477
+ out = nn.AvgPool2d(2, 2, channel_axis=None)(arr)
478
+ self.assertTrue(out.shape == (16, 32, 16, 4))
479
+
480
+ def test_avgpool2d_padding(self):
481
+ """Test AvgPool2d with padding."""
482
+ arr = brainstate.random.rand(16, 32, 32, 8)
483
+
484
+ out = nn.AvgPool2d(2, 2, channel_axis=None, padding=1)(arr)
485
+ self.assertTrue(out.shape == (16, 32, 17, 5))
486
+
487
+ out = nn.AvgPool2d(2, 2, channel_axis=None, padding=(2, 1))(arr)
488
+ self.assertTrue(out.shape == (16, 32, 18, 5))
489
+
490
+ out = nn.AvgPool2d(2, 2, channel_axis=-1, padding=(1, 1))(arr)
491
+ self.assertTrue(out.shape == (16, 17, 17, 8))
492
+
493
+ out = nn.AvgPool2d(2, 2, channel_axis=2, padding=(1, 1))(arr)
494
+ self.assertTrue(out.shape == (16, 17, 32, 5))
495
+
496
+ def test_avgpool2d_values(self):
497
+ """Test AvgPool2d computes correct average values."""
498
+ arr = jnp.ones((1, 4, 4, 1))
499
+ pool = nn.AvgPool2d(2, 2, channel_axis=-1)
500
+ out = pool(arr)
501
+ self.assertTrue(jnp.allclose(out, jnp.ones((1, 2, 2, 1))))
502
+
503
+
504
+ class TestAvgPool3d(parameterized.TestCase):
505
+ """Comprehensive tests for AvgPool3d."""
506
+
507
+ def test_avgpool3d_basic(self):
508
+ """Test basic AvgPool3d functionality."""
509
+ arr = brainstate.random.rand(2, 8, 8, 8, 4)
510
+
511
+ pool = nn.AvgPool3d(2, 2, channel_axis=-1)
512
+ out = pool(arr)
513
+ self.assertEqual(out.shape, (2, 4, 4, 4, 4))
514
+
515
+ def test_avgpool3d_padding(self):
516
+ """Test AvgPool3d with padding."""
517
+ arr = brainstate.random.rand(1, 6, 6, 6, 2)
518
+
519
+ pool = nn.AvgPool3d(3, 2, padding=1, channel_axis=-1)
520
+ out = pool(arr)
521
+ self.assertEqual(out.shape, (1, 3, 3, 3, 2))
522
+
523
+
524
+ class TestMaxUnpool1d(parameterized.TestCase):
525
+ """Comprehensive tests for MaxUnpool1d."""
526
+
527
+ def test_maxunpool1d_basic(self):
528
+ """Test basic MaxUnpool1d functionality."""
529
+ # Create input
530
+ arr = brainstate.random.rand(2, 8, 3)
531
+
532
+ # Pool with indices
533
+ pool = nn.MaxPool1d(2, 2, channel_axis=-1, return_indices=True)
534
+ pooled, indices = pool(arr)
535
+
536
+ # Unpool
537
+ unpool = nn.MaxUnpool1d(2, 2, channel_axis=-1)
538
+ unpooled = unpool(pooled, indices)
539
+
540
+ # With kernel 2 and stride 2 on a length-8 input, unpooling restores the original shape exactly
541
+ self.assertEqual(unpooled.shape, (2, 8, 3))
542
+
543
+ def test_maxunpool1d_with_output_size(self):
544
+ """Test MaxUnpool1d with explicit output_size."""
545
+ arr = brainstate.random.rand(1, 10, 2)
546
+
547
+ pool = nn.MaxPool1d(2, 2, channel_axis=-1, return_indices=True)
548
+ pooled, indices = pool(arr)
549
+
550
+ unpool = nn.MaxUnpool1d(2, 2, channel_axis=-1)
551
+ unpooled = unpool(pooled, indices, output_size=(1, 10, 2))
552
+
553
+ self.assertEqual(unpooled.shape, (1, 10, 2))
554
+
555
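Passing output_size is how the caller resolves the ambiguity left by pooling: with kernel 2 and stride 2, inputs of length 10 and length 11 both pool to length 5, so the unpool layer cannot infer the original length on its own (here the test simply passes the original (1, 10, 2)). A one-line check of that ambiguity, using a hypothetical valid_windows helper:

def valid_windows(L, k, s):
    # window count without padding
    return (L - k) // s + 1

assert valid_windows(10, 2, 2) == valid_windows(11, 2, 2) == 5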
+
556
+ class TestMaxUnpool2d(parameterized.TestCase):
557
+ """Comprehensive tests for MaxUnpool2d."""
558
+
559
+ def test_maxunpool2d_basic(self):
560
+ """Test basic MaxUnpool2d functionality."""
561
+ arr = brainstate.random.rand(2, 8, 8, 3)
562
+
563
+ # Pool with indices
564
+ pool = nn.MaxPool2d(2, 2, channel_axis=-1, return_indices=True)
565
+ pooled, indices = pool(arr)
566
+
567
+ # Unpool
568
+ unpool = nn.MaxUnpool2d(2, 2, channel_axis=-1)
569
+ unpooled = unpool(pooled, indices)
570
+
571
+ self.assertEqual(unpooled.shape, (2, 8, 8, 3))
572
+
573
+ def test_maxunpool2d_values(self):
574
+ """Test MaxUnpool2d places values correctly."""
575
+ # Create simple input where we can track values
576
+ arr = jnp.array([[1., 2., 3., 4.],
577
+ [5., 6., 7., 8.]]) # (2, 4)
578
+ arr = arr.reshape(1, 2, 2, 2) # (1, 2, 2, 2)
579
+
580
+ # Pool to get max value and its index
581
+ pool = nn.MaxPool2d(2, 2, channel_axis=-1, return_indices=True)
582
+ pooled, indices = pool(arr)
583
+
584
+ # Unpool
585
+ unpool = nn.MaxUnpool2d(2, 2, channel_axis=-1)
586
+ unpooled = unpool(pooled, indices)
587
+
588
+ # Check that max value (8.0) is preserved
589
+ self.assertTrue(jnp.max(unpooled) == 8.0)
590
+ # Check shape
591
+ self.assertEqual(unpooled.shape, (1, 2, 2, 2))
592
+
593
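The value test above exercises the pool/unpool contract: return_indices=True records where each window maximum was taken, and the unpool layer scatters the pooled values back to those positions while every other position stays zero, so the original maximum (8.0) survives the round trip. A minimal sketch of that scatter idea in jax.numpy, assuming indices are flat positions along the pooled axis (brainstate's actual index layout may differ):

import jax.numpy as jnp

x = jnp.array([1., 4., 2., 3.])   # one channel, length-4 signal
pooled = jnp.array([4., 3.])      # max of each non-overlapping window of 2
indices = jnp.array([1, 3])       # positions of those maxima in x

# scatter the pooled maxima back; non-maximum positions remain zero
unpooled = jnp.zeros_like(x).at[indices].set(pooled)
assert unpooled.tolist() == [0., 4., 0., 3.]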
+
594
+ class TestMaxUnpool3d(parameterized.TestCase):
595
+ """Comprehensive tests for MaxUnpool3d."""
596
+
597
+ def test_maxunpool3d_basic(self):
598
+ """Test basic MaxUnpool3d functionality."""
599
+ arr = brainstate.random.rand(1, 4, 4, 4, 2)
600
+
601
+ # Pool with indices
602
+ pool = nn.MaxPool3d(2, 2, channel_axis=-1, return_indices=True)
603
+ pooled, indices = pool(arr)
604
+
605
+ # Unpool
606
+ unpool = nn.MaxUnpool3d(2, 2, channel_axis=-1)
607
+ unpooled = unpool(pooled, indices)
608
+
609
+ self.assertEqual(unpooled.shape, (1, 4, 4, 4, 2))
610
+
611
+
612
+ class TestLPPool1d(parameterized.TestCase):
613
+ """Comprehensive tests for LPPool1d."""
614
+
615
+ def test_lppool1d_basic(self):
616
+ """Test basic LPPool1d functionality."""
617
+ arr = brainstate.random.rand(2, 16, 4)
618
+
619
+ # Test L2 pooling (norm_type=2)
620
+ pool = nn.LPPool1d(norm_type=2, kernel_size=2, stride=2, channel_axis=-1)
621
+ out = pool(arr)
622
+ self.assertEqual(out.shape, (2, 8, 4))
623
+
624
+ def test_lppool1d_different_norms(self):
625
+ """Test LPPool1d with different norm types."""
626
+ arr = brainstate.random.rand(1, 8, 2)
627
+
628
+ # Test with p=1 (should be similar to average)
629
+ pool1 = nn.LPPool1d(norm_type=1, kernel_size=2, stride=2, channel_axis=-1)
630
+ out1 = pool1(arr)
631
+
632
+ # Test with p=2 (L2 norm)
633
+ pool2 = nn.LPPool1d(norm_type=2, kernel_size=2, stride=2, channel_axis=-1)
634
+ out2 = pool2(arr)
635
+
636
+ # Test with large p (should approach max pooling)
637
+ pool_inf = nn.LPPool1d(norm_type=10, kernel_size=2, stride=2, channel_axis=-1)
638
+ out_inf = pool_inf(arr)
639
+
640
+ self.assertEqual(out1.shape, (1, 4, 2))
641
+ self.assertEqual(out2.shape, (1, 4, 2))
642
+ self.assertEqual(out_inf.shape, (1, 4, 2))
643
+
644
+ def test_lppool1d_value_check(self):
645
+ """Test LPPool1d computes correct values."""
646
+ # Simple test case
647
+ arr = jnp.array([[[2., 2.], [2., 2.]]]) # (1, 2, 2)
648
+
649
+ pool = nn.LPPool1d(norm_type=2, kernel_size=2, stride=2, channel_axis=-1)
650
+ out = pool(arr)
651
+
652
+ # For constant values, Lp norm should equal the value
653
+ self.assertTrue(jnp.allclose(out, 2.0, atol=1e-5))
654
+
655
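The value check above (a constant window of 2.0 pooling to 2.0 for p=2) and the "large p approaches max pooling" comment are both consistent with reading Lp pooling as a power mean, (mean(|x|**p))**(1/p), rather than a plain p-norm sum. A short illustrative sketch of that arithmetic (not the brainstate implementation):

import jax.numpy as jnp

def lp_window(x, p):
    # power-mean form: equals the value for a constant window,
    # and approaches max(|x|) as p grows
    return jnp.mean(jnp.abs(x) ** p) ** (1.0 / p)

assert jnp.allclose(lp_window(jnp.array([2., 2.]), 2), 2.0)   # matches the value check above
assert float(lp_window(jnp.array([1., 3.]), 40)) > 2.9        # large p ~ window max of 3.0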
+
656
+ class TestLPPool2d(parameterized.TestCase):
657
+ """Comprehensive tests for LPPool2d."""
658
+
659
+ def test_lppool2d_basic(self):
660
+ """Test basic LPPool2d functionality."""
661
+ arr = brainstate.random.rand(2, 8, 8, 4)
662
+
663
+ pool = nn.LPPool2d(norm_type=2, kernel_size=2, stride=2, channel_axis=-1)
664
+ out = pool(arr)
665
+ self.assertEqual(out.shape, (2, 4, 4, 4))
666
+
667
+ def test_lppool2d_padding(self):
668
+ """Test LPPool2d with padding."""
669
+ arr = brainstate.random.rand(1, 7, 7, 2)
670
+
671
+ pool = nn.LPPool2d(norm_type=2, kernel_size=3, stride=2, padding=1, channel_axis=-1)
672
+ out = pool(arr)
673
+ self.assertEqual(out.shape, (1, 4, 4, 2))
674
+
675
+ def test_lppool2d_different_kernel_sizes(self):
676
+ """Test LPPool2d with non-square kernels."""
677
+ arr = brainstate.random.rand(1, 8, 6, 2)
678
+
679
+ pool = nn.LPPool2d(norm_type=2, kernel_size=(3, 2), stride=(2, 1), channel_axis=-1)
680
+ out = pool(arr)
681
+ self.assertEqual(out.shape, (1, 3, 5, 2))
682
+
683
+
684
+ class TestLPPool3d(parameterized.TestCase):
685
+ """Comprehensive tests for LPPool3d."""
686
+
687
+ def test_lppool3d_basic(self):
688
+ """Test basic LPPool3d functionality."""
689
+ arr = brainstate.random.rand(1, 8, 8, 8, 2)
690
+
691
+ pool = nn.LPPool3d(norm_type=2, kernel_size=2, stride=2, channel_axis=-1)
692
+ out = pool(arr)
693
+ self.assertEqual(out.shape, (1, 4, 4, 4, 2))
694
+
695
+ def test_lppool3d_different_norms(self):
696
+ """Test LPPool3d with different norm types."""
697
+ arr = brainstate.random.rand(1, 4, 4, 4, 1)
698
+
699
+ # Different p values should give different results
700
+ pool1 = nn.LPPool3d(norm_type=1, kernel_size=2, stride=2, channel_axis=-1)
701
+ pool2 = nn.LPPool3d(norm_type=2, kernel_size=2, stride=2, channel_axis=-1)
702
+ pool3 = nn.LPPool3d(norm_type=3, kernel_size=2, stride=2, channel_axis=-1)
703
+
704
+ out1 = pool1(arr)
705
+ out2 = pool2(arr)
706
+ out3 = pool3(arr)
707
+
708
+ # All should have same shape
709
+ self.assertEqual(out1.shape, (1, 2, 2, 2, 1))
710
+ self.assertEqual(out2.shape, (1, 2, 2, 2, 1))
711
+ self.assertEqual(out3.shape, (1, 2, 2, 2, 1))
712
+
713
+ # Values should be different (unless input is uniform)
714
+ self.assertFalse(jnp.allclose(out1, out2))
715
+ self.assertFalse(jnp.allclose(out2, out3))
716
+
717
+
718
+ class TestAdaptivePool(parameterized.TestCase):
719
+ """Tests for adaptive pooling layers."""
720
+
721
+ @parameterized.named_parameters(
722
+ dict(testcase_name=f'target_size={target_size}',
723
+ target_size=target_size)
724
+ for target_size in [10, 9, 8, 7, 6]
725
+ )
726
+ def test_adaptive_pool1d(self, target_size):
727
+ """Test internal adaptive pooling function."""
728
+ from brainstate.nn._poolings import _adaptive_pool1d
729
+
730
+ arr = brainstate.random.rand(100)
731
+ op = jax.numpy.mean
732
+
733
+ out = _adaptive_pool1d(arr, target_size, op)
734
+ self.assertTrue(out.shape == (target_size,))
735
+
736
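Adaptive pooling derives its windows from the target size instead of taking kernel/stride arguments, which is why every target in the parameterized list above works for the length-100 input. One common way to form the bins, shown here only as a sketch of the general idea (not necessarily how _adaptive_pool1d splits the axis):

import jax.numpy as jnp

def adaptive_mean_1d(x, target_size):
    # bin i covers floor(i*L/target) .. ceil((i+1)*L/target); each bin is averaged
    L = x.shape[0]
    starts = [(i * L) // target_size for i in range(target_size)]
    ends = [-((-(i + 1) * L) // target_size) for i in range(target_size)]  # ceil division
    return jnp.stack([x[s:e].mean() for s, e in zip(starts, ends)])

assert adaptive_mean_1d(jnp.arange(100.0), 7).shape == (7,)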
+ def test_adaptive_avg_pool1d(self):
737
+ """Test AdaptiveAvgPool1d."""
738
+ input = brainstate.random.randn(2, 32, 4)
739
+
740
+ # Test with different target sizes
741
+ pool = nn.AdaptiveAvgPool1d(5, channel_axis=-1)
742
+ output = pool(input)
743
+ self.assertEqual(output.shape, (2, 5, 4))
744
+
745
+ # Test with single element input
746
+ pool = nn.AdaptiveAvgPool1d(1, channel_axis=-1)
747
+ output = pool(input)
748
+ self.assertEqual(output.shape, (2, 1, 4))
749
+
750
+ def test_adaptive_avg_pool2d(self):
751
+ """Test AdaptiveAvgPool2d."""
752
+ input = brainstate.random.randn(2, 8, 9, 3)
753
+
754
+ # Square output
755
+ output = nn.AdaptiveAvgPool2d(5, channel_axis=-1)(input)
756
+ self.assertEqual(output.shape, (2, 5, 5, 3))
757
+
758
+ # Non-square output
759
+ output = nn.AdaptiveAvgPool2d((5, 7), channel_axis=-1)(input)
760
+ self.assertEqual(output.shape, (2, 5, 7, 3))
761
+
762
+ # Test with single integer (square output)
763
+ output = nn.AdaptiveAvgPool2d(4, channel_axis=-1)(input)
764
+ self.assertEqual(output.shape, (2, 4, 4, 3))
765
+
766
+ def test_adaptive_avg_pool3d(self):
767
+ """Test AdaptiveAvgPool3d."""
768
+ input = brainstate.random.randn(1, 8, 6, 4, 2)
769
+
770
+ pool = nn.AdaptiveAvgPool3d((4, 3, 2), channel_axis=-1)
771
+ output = pool(input)
772
+ self.assertEqual(output.shape, (1, 4, 3, 2, 2))
773
+
774
+ # Cube output
775
+ pool = nn.AdaptiveAvgPool3d(3, channel_axis=-1)
776
+ output = pool(input)
777
+ self.assertEqual(output.shape, (1, 3, 3, 3, 2))
778
+
779
+ def test_adaptive_max_pool1d(self):
780
+ """Test AdaptiveMaxPool1d."""
781
+ input = brainstate.random.randn(2, 32, 4)
782
+
783
+ pool = nn.AdaptiveMaxPool1d(8, channel_axis=-1)
784
+ output = pool(input)
785
+ self.assertEqual(output.shape, (2, 8, 4))
786
+
787
+ def test_adaptive_max_pool2d(self):
788
+ """Test AdaptiveMaxPool2d."""
789
+ input = brainstate.random.randn(2, 10, 8, 3)
790
+
791
+ pool = nn.AdaptiveMaxPool2d((5, 4), channel_axis=-1)
792
+ output = pool(input)
793
+ self.assertEqual(output.shape, (2, 5, 4, 3))
794
+
795
+ def test_adaptive_max_pool3d(self):
796
+ """Test AdaptiveMaxPool3d."""
797
+ input = brainstate.random.randn(1, 8, 8, 8, 2)
798
+
799
+ pool = nn.AdaptiveMaxPool3d((4, 4, 4), channel_axis=-1)
800
+ output = pool(input)
801
+ self.assertEqual(output.shape, (1, 4, 4, 4, 2))
802
+
803
+
804
+ class TestPoolingEdgeCases(parameterized.TestCase):
805
+ """Test edge cases and error conditions."""
806
+
807
+ def test_pool_with_stride_none(self):
808
+ """Test pooling with stride=None (defaults to kernel_size)."""
809
+ arr = brainstate.random.rand(1, 8, 2)
810
+
811
+ pool = nn.MaxPool1d(kernel_size=3, stride=None, channel_axis=-1)
812
+ out = pool(arr)
813
+ # stride defaults to kernel_size=3
814
+ self.assertEqual(out.shape, (1, 2, 2))
815
+
816
+ def test_pool_with_large_kernel(self):
817
+ """Test pooling with kernel larger than input."""
818
+ arr = brainstate.random.rand(1, 4, 2)
819
+
820
+ # Kernel size larger than spatial dimension
821
+ pool = nn.MaxPool1d(kernel_size=5, stride=1, channel_axis=-1)
822
+ out = pool(arr)
823
+ # Kernel size 5 exceeds the spatial length 4, so with no padding the valid-window count (4 - 5) + 1 is 0; an empty output is acceptable
824
+ self.assertTrue(out.shape[1] >= 0)
825
+
826
+ def test_pool_single_element(self):
827
+ """Test pooling on single-element tensors."""
828
+ arr = brainstate.random.rand(1, 1, 1)
829
+
830
+ pool = nn.AvgPool1d(1, 1, channel_axis=-1)
831
+ out = pool(arr)
832
+ self.assertEqual(out.shape, (1, 1, 1))
833
+ self.assertTrue(jnp.allclose(out, arr))
834
+
835
+ def test_adaptive_pool_smaller_output(self):
836
+ """Test adaptive pooling with output smaller than input."""
837
+ arr = brainstate.random.rand(1, 16, 2)
838
+
839
+ # Adaptive pooling to smaller size
840
+ pool = nn.AdaptiveAvgPool1d(4, channel_axis=-1)
841
+ out = pool(arr)
842
+ self.assertEqual(out.shape, (1, 4, 2))
843
+
844
+ def test_unpool_without_indices(self):
845
+ """Test unpooling behavior with placeholder indices."""
846
+ pooled = brainstate.random.rand(1, 4, 2)
847
+ indices = jnp.zeros_like(pooled, dtype=jnp.int32)
848
+
849
+ unpool = nn.MaxUnpool1d(2, 2, channel_axis=-1)
850
+ # Should not raise error even with zero indices
851
+ unpooled = unpool(pooled, indices)
852
+ self.assertEqual(unpooled.shape, (1, 8, 2))
853
+
854
+ def test_lppool_extreme_norm(self):
855
+ """Test LPPool with extreme norm values."""
856
+ arr = brainstate.random.rand(1, 8, 2) + 0.1 # Avoid zeros
857
+
858
+ # Very large p (approaches max pooling)
859
+ pool_large = nn.LPPool1d(norm_type=20, kernel_size=2, stride=2, channel_axis=-1)
860
+ out_large = pool_large(arr)
861
+
862
+ # Compare with actual max pooling
863
+ pool_max = nn.MaxPool1d(2, 2, channel_axis=-1)
864
+ out_max = pool_max(arr)
865
+
866
+ # Should approach max pooling for large p (but not exactly equal)
867
+ # Just check shapes match
868
+ self.assertEqual(out_large.shape, out_max.shape)
869
+
870
+ def test_pool_with_channels_first(self):
871
+ """Test pooling with channels in different positions."""
872
+ arr = brainstate.random.rand(3, 16, 8) # (dim0, dim1, dim2)
873
+
874
+ # Channel axis at position 0 - treats dim 0 as channels, pools last dimension
875
+ pool = nn.MaxPool1d(2, 2, channel_axis=0)
876
+ out = pool(arr)
877
+ # Pools the last dimension, keeping first two
878
+ self.assertEqual(out.shape, (3, 16, 4))
879
+
880
+ # Channel axis at position -1 (last) - pools middle dimension
881
+ pool = nn.MaxPool1d(2, 2, channel_axis=-1)
882
+ out = pool(arr)
883
+ # Pools the middle dimension, keeping first and last
884
+ self.assertEqual(out.shape, (3, 8, 8))
885
+
886
+ # No channel axis - pools last dimension, treating earlier dims as batch
887
+ pool = nn.MaxPool1d(2, 2, channel_axis=None)
888
+ out = pool(arr)
889
+ # Pools the last dimension
890
+ self.assertEqual(out.shape, (3, 16, 4))
891
+
892
+
893
+ class TestPoolingMathematicalProperties(parameterized.TestCase):
894
+ """Test mathematical properties of pooling operations."""
895
+
896
+ def test_maxpool_idempotence(self):
897
+ """Test that max pooling with kernel_size=1 is identity."""
898
+ arr = brainstate.random.rand(2, 8, 3)
899
+
900
+ pool = nn.MaxPool1d(1, 1, channel_axis=-1)
901
+ out = pool(arr)
902
+
903
+ self.assertTrue(jnp.allclose(out, arr))
904
+
905
+ def test_avgpool_constant_input(self):
906
+ """Test average pooling on constant input."""
907
+ arr = jnp.ones((1, 8, 2)) * 5.0
908
+
909
+ pool = nn.AvgPool1d(2, 2, channel_axis=-1)
910
+ out = pool(arr)
911
+
912
+ # Average of constant should be the constant
913
+ self.assertTrue(jnp.allclose(out, 5.0))
914
+
915
+ def test_lppool_norm_properties(self):
916
+ """Test Lp pooling norm properties."""
917
+ arr = brainstate.random.rand(1, 4, 1) + 0.1
918
+
919
+ # L1 norm (p=1) should give average of absolute values
920
+ pool_l1 = nn.LPPool1d(norm_type=1, kernel_size=4, stride=4, channel_axis=-1)
921
+ out_l1 = pool_l1(arr)
922
+
923
+ # Manual calculation
924
+ manual_l1 = jnp.mean(jnp.abs(arr[:, :4, :]))
925
+
926
+ self.assertTrue(jnp.allclose(out_l1[0, 0, 0], manual_l1, rtol=1e-5))
927
+
928
+ def test_maxpool_monotonicity(self):
929
+ """Test that max pooling preserves monotonicity."""
930
+ arr1 = brainstate.random.rand(1, 8, 2)
931
+ arr2 = arr1 + 1.0 # Strictly greater
932
+
933
+ pool = nn.MaxPool1d(2, 2, channel_axis=-1)
934
+ out1 = pool(arr1)
935
+ out2 = pool(arr2)
936
+
937
+ # out2 should be strictly greater than out1
938
+ self.assertTrue(jnp.all(out2 > out1))
939
+
940
+ def test_adaptive_pool_preserves_values(self):
941
+ """Test that adaptive pooling with same size preserves values."""
942
+ arr = brainstate.random.rand(1, 8, 2)
943
+
944
+ # Adaptive pool to same size
945
+ pool = nn.AdaptiveAvgPool1d(8, channel_axis=-1)
946
+ out = pool(arr)
947
+
948
+ # Should be approximately equal (might have small numerical differences)
949
+ self.assertTrue(jnp.allclose(out, arr, rtol=1e-5))
950
+
951
+
952
+ if __name__ == '__main__':
953
953
  absltest.main()