brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +15 -28
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/nn/_poolings_test.py
CHANGED
@@ -1,6 +1,22 @@
|
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
1
16
|
# -*- coding: utf-8 -*-
|
2
17
|
|
3
18
|
import jax
|
19
|
+
import jax.numpy as jnp
|
4
20
|
import numpy as np
|
5
21
|
from absl.testing import absltest
|
6
22
|
from absl.testing import parameterized
|
@@ -48,15 +64,294 @@ class TestFlatten(parameterized.TestCase):
|
|
48
64
|
|
49
65
|
|
50
66
|
class TestUnflatten(parameterized.TestCase):
    """Comprehensive tests for the Unflatten layer.

    Note: due to a bug in ``u.math.unflatten`` with negative axis handling,
    only non-negative ``axis`` values are tested here.
    """

    def test_unflatten_basic_2d(self):
        """Test basic Unflatten functionality for 2D tensors."""
        arr = brainstate.random.rand(6, 12)

        # Unflatten the last dimension (positive axis, see class note).
        unflatten = nn.Unflatten(axis=1, sizes=(3, 4))
        out = unflatten(arr)
        self.assertEqual(out.shape, (6, 3, 4))

        # Unflatten the first dimension.
        unflatten = nn.Unflatten(axis=0, sizes=(2, 3))
        out = unflatten(arr)
        self.assertEqual(out.shape, (2, 3, 12))

    def test_unflatten_basic_3d(self):
        """Test basic Unflatten functionality for 3D tensors."""
        arr = brainstate.random.rand(4, 6, 24)

        # Unflatten the last dimension using a positive index.
        unflatten = nn.Unflatten(axis=2, sizes=(2, 3, 4))
        out = unflatten(arr)
        self.assertEqual(out.shape, (4, 6, 2, 3, 4))

        # Unflatten a middle dimension.
        unflatten = nn.Unflatten(axis=1, sizes=(2, 3))
        out = unflatten(arr)
        self.assertEqual(out.shape, (4, 2, 3, 24))

    def test_unflatten_with_in_size(self):
        """Test Unflatten with the in_size parameter."""
        unflatten = nn.Unflatten(axis=1, sizes=(2, 3), in_size=(4, 6))

        # out_size must be available right after construction.
        self.assertIsNotNone(unflatten.out_size)
        self.assertEqual(unflatten.out_size, (4, 2, 3))

        # Applying the layer yields the advertised shape.
        arr = brainstate.random.rand(4, 6)
        out = unflatten(arr)
        self.assertEqual(out.shape, (4, 2, 3))

    def test_unflatten_preserve_batch_dims(self):
        """Test that Unflatten preserves leading batch dimensions."""
        arr = brainstate.random.rand(2, 3, 4, 20)

        # Unflatten the last dimension (positive axis).
        unflatten = nn.Unflatten(axis=3, sizes=(4, 5))
        out = unflatten(arr)
        self.assertEqual(out.shape, (2, 3, 4, 4, 5))

    def test_unflatten_single_element_split(self):
        """Test Unflatten with sizes containing 1."""
        arr = brainstate.random.rand(3, 12)

        # Include a dimension of size 1.
        unflatten = nn.Unflatten(axis=1, sizes=(1, 3, 4))
        out = unflatten(arr)
        self.assertEqual(out.shape, (3, 1, 3, 4))

        # Multiple size-1 dimensions.
        unflatten = nn.Unflatten(axis=1, sizes=(1, 1, 12))
        out = unflatten(arr)
        self.assertEqual(out.shape, (3, 1, 1, 12))

    def test_unflatten_large_split(self):
        """Test Unflatten splitting one axis into many dimensions."""
        arr = brainstate.random.rand(2, 120)

        unflatten = nn.Unflatten(axis=1, sizes=(2, 3, 4, 5))
        out = unflatten(arr)
        self.assertEqual(out.shape, (2, 2, 3, 4, 5))

        # Element count AND row-major element order are preserved.
        # (The old check ``2 * 3 * 4 * 5 == 120`` only tested arithmetic.)
        self.assertEqual(arr.size, out.size)
        np.testing.assert_array_equal(out, jnp.reshape(arr, (2, 2, 3, 4, 5)))

    def test_unflatten_flatten_inverse(self):
        """Test that Unflatten is the inverse of Flatten."""
        original = brainstate.random.rand(2, 3, 4, 5)

        # Flatten dimensions 1 and 2.
        flatten = nn.Flatten(start_axis=1, end_axis=2)
        flattened = flatten(original)
        self.assertEqual(flattened.shape, (2, 12, 5))

        # Unflatten back.
        unflatten = nn.Unflatten(axis=1, sizes=(3, 4))
        restored = unflatten(flattened)
        self.assertEqual(restored.shape, original.shape)

        # Pure reshapes: values must round-trip exactly.
        np.testing.assert_array_equal(restored, original)

    def test_unflatten_sequential_operations(self):
        """Test Unflatten applied several times in sequence."""
        arr = brainstate.random.rand(4, 24)

        unflatten1 = nn.Unflatten(axis=1, sizes=(6, 4))
        intermediate = unflatten1(arr)
        self.assertEqual(intermediate.shape, (4, 6, 4))

        unflatten2 = nn.Unflatten(axis=1, sizes=(2, 3))
        final = unflatten2(intermediate)
        self.assertEqual(final.shape, (4, 2, 3, 4))

    def test_unflatten_error_cases(self):
        """Test error handling in Unflatten."""
        # sizes must be a tuple or list.
        with self.assertRaises(TypeError):
            nn.Unflatten(axis=0, sizes=12)

        with self.assertRaises(TypeError):
            nn.Unflatten(axis=0, sizes="invalid")

        # Every element of sizes must be an integer.
        with self.assertRaises(TypeError):
            nn.Unflatten(axis=0, sizes=(2, "invalid"))

        with self.assertRaises(TypeError):
            nn.Unflatten(axis=0, sizes=(2.5, 3))

    @parameterized.named_parameters(
        ('axis_0_2d', 0, (10, 20), (2, 5)),
        ('axis_1_2d', 1, (10, 20), (4, 5)),
        ('axis_0_3d', 0, (6, 8, 10), (2, 3)),
        ('axis_1_3d', 1, (6, 8, 10), (2, 4)),
        ('axis_2_3d', 2, (6, 8, 10), (2, 5)),
    )
    def test_unflatten_parameterized(self, axis, input_shape, unflatten_sizes):
        """Parameterized test for various axis and shape combinations."""
        arr = brainstate.random.rand(*input_shape)
        unflatten = nn.Unflatten(axis=axis, sizes=unflatten_sizes)
        out = unflatten(arr)

        # Sanity-check the fixture: sizes must factor the target dimension.
        original_dim_size = input_shape[axis]
        self.assertEqual(np.prod(unflatten_sizes), original_dim_size)

        # The target axis is replaced in place by ``unflatten_sizes``.
        expected_shape = list(input_shape)
        expected_shape[axis:axis + 1] = unflatten_sizes
        self.assertEqual(out.shape, tuple(expected_shape))

        # Total size preserved.
        self.assertEqual(arr.size, out.size)

    def test_unflatten_values_preserved(self):
        """Test that values are correctly preserved during unflatten."""
        # Known pattern: 0..23 laid out as (2, 12).
        arr = jnp.arange(24).reshape(2, 12)

        unflatten = nn.Unflatten(axis=1, sizes=(3, 4))
        out = unflatten(arr)

        self.assertEqual(out.shape, (2, 3, 4))

        # Row-major split: out[b, i, :] == arr[b, 4 * i:4 * i + 4] for every
        # batch b and row i; compare the whole tensor in one assertion.
        np.testing.assert_array_equal(out, jnp.arange(24).reshape(2, 3, 4))

    def test_unflatten_with_complex_shapes(self):
        """Test Unflatten with complex multi-dimensional shapes."""
        # 5D tensor: unflatten the last dimension (positive axis).
        arr = brainstate.random.rand(2, 3, 4, 5, 60)
        unflatten = nn.Unflatten(axis=4, sizes=(3, 4, 5))
        out = unflatten(arr)
        self.assertEqual(out.shape, (2, 3, 4, 5, 3, 4, 5))

        # Unflatten a middle dimension.
        arr = brainstate.random.rand(2, 3, 12, 5, 6)
        unflatten = nn.Unflatten(axis=2, sizes=(3, 4))
        out = unflatten(arr)
        self.assertEqual(out.shape, (2, 3, 3, 4, 5, 6))

    def test_unflatten_edge_cases(self):
        """Test edge cases for Unflatten."""
        # Single-element tensor.
        arr = brainstate.random.rand(1)
        unflatten = nn.Unflatten(axis=0, sizes=(1,))
        out = unflatten(arr)
        self.assertEqual(out.shape, (1,))

        # Unflatten to the same dimension (effectively a no-op).
        arr = brainstate.random.rand(3, 5)
        unflatten = nn.Unflatten(axis=1, sizes=(5,))
        out = unflatten(arr)
        self.assertEqual(out.shape, (3, 5))

        # Deep split: 1024 = 4 ** 5 into five factors of 4.
        arr = brainstate.random.rand(2, 1024)
        unflatten = nn.Unflatten(axis=1, sizes=(4, 4, 4, 4, 4))
        out = unflatten(arr)
        self.assertEqual(out.shape, (2, 4, 4, 4, 4, 4))
        # Check element order instead of the old tautology ``4**5 == 1024``.
        np.testing.assert_array_equal(out, jnp.reshape(arr, (2, 4, 4, 4, 4, 4)))

    def test_unflatten_jit_compatibility(self):
        """Test that Unflatten works with JAX JIT compilation."""
        arr = brainstate.random.rand(4, 12)
        unflatten = nn.Unflatten(axis=1, sizes=(3, 4))

        # JIT compile the unflatten operation.
        jitted_unflatten = jax.jit(unflatten.update)

        out_normal = unflatten(arr)
        out_jitted = jitted_unflatten(arr)

        self.assertEqual(out_normal.shape, (4, 3, 4))
        self.assertEqual(out_jitted.shape, (4, 3, 4))
        np.testing.assert_allclose(out_jitted, out_normal)
|
297
|
+
|
298
|
+
|
299
|
+
class TestMaxPool1d(parameterized.TestCase):
    """Shape tests for ``nn.MaxPool1d`` on channels-last and channel-free input."""

    def test_maxpool1d_basic(self):
        """The pooled length matches the expectation for two kernel/stride pairs."""
        signal = brainstate.random.rand(16, 32, 8)  # (batch, length, channels)
        expectations = {
            (2, 2): (16, 16, 8),  # non-overlapping windows halve the length
            (3, 1): (16, 30, 8),  # a sliding window of 3 drops two positions
        }
        for (kernel, stride), want in expectations.items():
            with self.subTest(kernel=kernel, stride=stride):
                layer = nn.MaxPool1d(kernel, stride, channel_axis=-1)
                self.assertEqual(layer(signal).shape, want)

    def test_maxpool1d_padding(self):
        """Scalar padding and one-element tuple padding give the same shape."""
        signal = brainstate.random.rand(4, 10, 3)
        for pad in (1, (1,)):
            with self.subTest(padding=pad):
                layer = nn.MaxPool1d(3, 2, padding=pad, channel_axis=-1)
                self.assertEqual(layer(signal).shape, (4, 5, 3))

    def test_maxpool1d_return_indices(self):
        """return_indices=True yields (values, indices) with identical shapes."""
        signal = brainstate.random.rand(2, 10, 3)
        layer = nn.MaxPool1d(2, 2, channel_axis=-1, return_indices=True)
        values, idx = layer(signal)
        self.assertEqual(values.shape, (2, 5, 3))
        self.assertEqual(idx.shape, (2, 5, 3))

    def test_maxpool1d_no_channel_axis(self):
        """With channel_axis=None the trailing axis of a 2D input is pooled."""
        signal = brainstate.random.rand(16, 32)
        pooled = nn.MaxPool1d(2, 2, channel_axis=None)(signal)
        self.assertEqual(pooled.shape, (16, 16))
|
347
|
+
|
348
|
+
|
349
|
+
class TestMaxPool2d(parameterized.TestCase):
|
350
|
+
"""Comprehensive tests for MaxPool2d."""
|
351
|
+
|
352
|
+
def test_maxpool2d_basic(self):
|
353
|
+
"""Test basic MaxPool2d functionality."""
|
354
|
+
arr = brainstate.random.rand(16, 32, 32, 8) # (batch, height, width, channels)
|
60
355
|
|
61
356
|
out = nn.MaxPool2d(2, 2, channel_axis=-1)(arr)
|
62
357
|
self.assertTrue(out.shape == (16, 16, 16, 8))
|
@@ -64,6 +359,10 @@ class TestPool(parameterized.TestCase):
|
|
64
359
|
out = nn.MaxPool2d(2, 2, channel_axis=None)(arr)
|
65
360
|
self.assertTrue(out.shape == (16, 32, 16, 4))
|
66
361
|
|
362
|
+
def test_maxpool2d_padding(self):
|
363
|
+
"""Test MaxPool2d with padding."""
|
364
|
+
arr = brainstate.random.rand(16, 32, 32, 8)
|
365
|
+
|
67
366
|
out = nn.MaxPool2d(2, 2, channel_axis=None, padding=1)(arr)
|
68
367
|
self.assertTrue(out.shape == (16, 32, 17, 5))
|
69
368
|
|
@@ -76,7 +375,100 @@ class TestPool(parameterized.TestCase):
|
|
76
375
|
out = nn.MaxPool2d(2, 2, channel_axis=2, padding=(1, 1))(arr)
|
77
376
|
self.assertTrue(out.shape == (16, 17, 32, 5))
|
78
377
|
|
79
|
-
def
|
378
|
+
def test_maxpool2d_return_indices(self):
    """return_indices=True yields a (values, indices) pair of identical shape."""
    image = brainstate.random.rand(2, 8, 8, 3)
    layer = nn.MaxPool2d(2, 2, channel_axis=-1, return_indices=True)
    values, idx = layer(image)
    for label, tensor in (('values', values), ('indices', idx)):
        with self.subTest(output=label):
            self.assertEqual(tensor.shape, (2, 4, 4, 3))
|
386
|
+
|
387
|
+
def test_maxpool2d_different_strides(self):
    """Per-axis strides (2, 1) shrink height and width by different amounts."""
    image = brainstate.random.rand(2, 16, 16, 4)
    pooled = nn.MaxPool2d(3, stride=(2, 1), channel_axis=-1)(image)
    # height: (16 - 3) // 2 + 1 == 7; width: (16 - 3) // 1 + 1 == 14
    self.assertEqual(pooled.shape, (2, 7, 14, 4))
|
395
|
+
|
396
|
+
|
397
|
+
class TestMaxPool3d(parameterized.TestCase):
    """Shape tests for ``nn.MaxPool3d`` on channels-last volumes."""

    def test_maxpool3d_basic(self):
        """Every spatial extent shrinks according to kernel and stride."""
        volume = brainstate.random.rand(2, 16, 16, 16, 4)  # (batch, depth, height, width, channels)
        # (kernel, stride, expected spatial side)
        for kernel, stride, side in ((2, 2, 8), (3, 1, 14)):
            with self.subTest(kernel=kernel, stride=stride):
                pooled = nn.MaxPool3d(kernel, stride, channel_axis=-1)(volume)
                self.assertEqual(pooled.shape, (2, side, side, side, 4))

    def test_maxpool3d_padding(self):
        """Padding of 1 with kernel 3 / stride 2 maps 8 -> 4 per spatial axis."""
        volume = brainstate.random.rand(1, 8, 8, 8, 2)
        pooled = nn.MaxPool3d(3, 2, padding=1, channel_axis=-1)(volume)
        self.assertEqual(pooled.shape, (1, 4, 4, 4, 2))

    def test_maxpool3d_return_indices(self):
        """return_indices=True yields a (values, indices) pair of identical shape."""
        volume = brainstate.random.rand(1, 4, 4, 4, 2)
        layer = nn.MaxPool3d(2, 2, channel_axis=-1, return_indices=True)
        values, positions = layer(volume)
        for tensor in (values, positions):
            self.assertEqual(tensor.shape, (1, 2, 2, 2, 2))
|
428
|
+
|
429
|
+
|
430
|
+
class TestAvgPool1d(parameterized.TestCase):
    """Comprehensive tests for AvgPool1d."""

    def test_avgpool1d_basic(self):
        """Test basic AvgPool1d functionality."""
        arr = brainstate.random.rand(4, 16, 8)  # (batch, length, channels)

        pool = nn.AvgPool1d(2, 2, channel_axis=-1)
        out = pool(arr)
        self.assertEqual(out.shape, (4, 8, 8))

        # Averaging a constant signal must return the same constant.
        arr = jnp.ones((1, 4, 2))
        out = pool(arr)
        np.testing.assert_allclose(out, jnp.ones((1, 2, 2)), rtol=1e-5, atol=1e-8)

    def test_avgpool1d_padding(self):
        """Test AvgPool1d with padding."""
        arr = brainstate.random.rand(2, 10, 3)

        pool = nn.AvgPool1d(3, 2, padding=1, channel_axis=-1)
        out = pool(arr)
        self.assertEqual(out.shape, (2, 5, 3))

    def test_avgpool1d_divisor_override(self):
        """Test that AvgPool1d divides each window sum by the kernel size.

        A constant input alone cannot detect a wrong divisor (or even a
        different reduction such as max), so a ramp is checked as well.
        """
        pool = nn.AvgPool1d(2, 2, channel_axis=-1)

        # Constant input stays constant.
        arr = jnp.ones((1, 4, 1))
        out = pool(arr)
        np.testing.assert_allclose(out, jnp.ones((1, 2, 1)), rtol=1e-5, atol=1e-8)

        # Ramp [1, 2, 3, 4]: windows [1, 2] and [3, 4] average to 1.5 and 3.5.
        arr = jnp.arange(1., 5.).reshape(1, 4, 1)
        out = pool(arr)
        self.assertEqual(out.shape, (1, 2, 1))
        np.testing.assert_allclose(out, jnp.array([1.5, 3.5]).reshape(1, 2, 1), rtol=1e-6)
|
465
|
+
|
466
|
+
|
467
|
+
class TestAvgPool2d(parameterized.TestCase):
|
468
|
+
"""Comprehensive tests for AvgPool2d."""
|
469
|
+
|
470
|
+
def test_avgpool2d_basic(self):
|
471
|
+
"""Test basic AvgPool2d functionality."""
|
80
472
|
arr = brainstate.random.rand(16, 32, 32, 8)
|
81
473
|
|
82
474
|
out = nn.AvgPool2d(2, 2, channel_axis=-1)(arr)
|
@@ -85,6 +477,10 @@ class TestPool(parameterized.TestCase):
|
|
85
477
|
out = nn.AvgPool2d(2, 2, channel_axis=None)(arr)
|
86
478
|
self.assertTrue(out.shape == (16, 32, 16, 4))
|
87
479
|
|
480
|
+
def test_avgpool2d_padding(self):
|
481
|
+
"""Test AvgPool2d with padding."""
|
482
|
+
arr = brainstate.random.rand(16, 32, 32, 8)
|
483
|
+
|
88
484
|
out = nn.AvgPool2d(2, 2, channel_axis=None, padding=1)(arr)
|
89
485
|
self.assertTrue(out.shape == (16, 32, 17, 5))
|
90
486
|
|
@@ -97,121 +493,461 @@ class TestPool(parameterized.TestCase):
|
|
97
493
|
out = nn.AvgPool2d(2, 2, channel_axis=2, padding=(1, 1))(arr)
|
98
494
|
self.assertTrue(out.shape == (16, 17, 32, 5))
|
99
495
|
|
496
|
+
def test_avgpool2d_values(self):
    """Test AvgPool2d computes correct average values.

    A constant input only proves constants are preserved; a ramp input is
    also checked so the actual window averaging is verified.
    """
    pool = nn.AvgPool2d(2, 2, channel_axis=-1)

    # Constant input must stay constant.
    arr = jnp.ones((1, 4, 4, 1))
    out = pool(arr)
    np.testing.assert_allclose(out, jnp.ones((1, 2, 2, 1)), rtol=1e-5, atol=1e-8)

    # Ramp 0..15 on a 4x4 grid: the 2x2 window starting at value r holds
    # {r, r+1, r+4, r+5}, whose mean is r + 2.5.
    arr = jnp.arange(16.).reshape(1, 4, 4, 1)
    out = pool(arr)
    expected = jnp.array([[2.5, 4.5], [10.5, 12.5]]).reshape(1, 2, 2, 1)
    self.assertEqual(out.shape, (1, 2, 2, 1))
    np.testing.assert_allclose(out, expected, rtol=1e-6)
|
502
|
+
|
503
|
+
|
504
|
+
class TestAvgPool3d(parameterized.TestCase):
    """Shape tests for ``nn.AvgPool3d`` on channels-last volumes."""

    def _check(self, shape, expected, *args, **kwargs):
        """Pool a random volume of ``shape`` and compare the output shape to ``expected``."""
        volume = brainstate.random.rand(*shape)
        pooled = nn.AvgPool3d(*args, channel_axis=-1, **kwargs)(volume)
        self.assertEqual(pooled.shape, expected)

    def test_avgpool3d_basic(self):
        """Kernel 2 / stride 2 halves every spatial extent."""
        self._check((2, 8, 8, 8, 4), (2, 4, 4, 4, 4), 2, 2)

    def test_avgpool3d_padding(self):
        """Kernel 3 / stride 2 with padding 1 maps 6 -> 3 per spatial axis."""
        self._check((1, 6, 6, 6, 2), (1, 3, 3, 3, 2), 3, 2, padding=1)
|
522
|
+
|
523
|
+
|
524
|
+
class TestMaxUnpool1d(parameterized.TestCase):
    """Round-trip tests pairing ``nn.MaxPool1d`` with ``nn.MaxUnpool1d``."""

    @staticmethod
    def _pool_with_indices(x):
        """Max-pool ``x`` (kernel 2, stride 2, channels-last) and return (values, indices)."""
        return nn.MaxPool1d(2, 2, channel_axis=-1, return_indices=True)(x)

    def test_maxunpool1d_basic(self):
        """Unpooling without output_size restores the pre-pool shape here."""
        signal = brainstate.random.rand(2, 8, 3)
        values, positions = self._pool_with_indices(signal)
        restored = nn.MaxUnpool1d(2, 2, channel_axis=-1)(values, positions)
        # Length 8 is a multiple of the stride, so the default output length
        # is expected to equal the original length.
        self.assertEqual(restored.shape, (2, 8, 3))

    def test_maxunpool1d_with_output_size(self):
        """An explicit output_size fixes the shape of the unpooled tensor."""
        signal = brainstate.random.rand(1, 10, 2)
        values, positions = self._pool_with_indices(signal)
        restored = nn.MaxUnpool1d(2, 2, channel_axis=-1)(
            values, positions, output_size=(1, 10, 2)
        )
        self.assertEqual(restored.shape, (1, 10, 2))
|
554
|
+
|
555
|
+
|
556
|
+
class TestMaxUnpool2d(parameterized.TestCase):
    """Comprehensive tests for MaxUnpool2d."""

    def test_maxunpool2d_basic(self):
        """Test basic MaxUnpool2d functionality."""
        arr = brainstate.random.rand(2, 8, 8, 3)

        # Pool with indices.
        pool = nn.MaxPool2d(2, 2, channel_axis=-1, return_indices=True)
        pooled, indices = pool(arr)

        # Unpool.
        unpool = nn.MaxUnpool2d(2, 2, channel_axis=-1)
        unpooled = unpool(pooled, indices)

        self.assertEqual(unpooled.shape, (2, 8, 8, 3))

    def test_maxunpool2d_values(self):
        """Test MaxUnpool2d places each max back at its argmax position.

        Previously only the global max and the shape were checked, which would
        also pass if values were scattered to the wrong positions.
        """
        arr = jnp.array([[1., 2., 3., 4.],
                         [5., 6., 7., 8.]])  # (2, 4)
        # (1, 2, 2, 2): arr[0, h, w, c] == 1 + 4*h + 2*w + c, i.e.
        #   channel 0: [[1, 3], [5, 7]] -> max 7 at (h=1, w=1)
        #   channel 1: [[2, 4], [6, 8]] -> max 8 at (h=1, w=1)
        arr = arr.reshape(1, 2, 2, 2)

        pool = nn.MaxPool2d(2, 2, channel_axis=-1, return_indices=True)
        pooled, indices = pool(arr)
        np.testing.assert_array_equal(pooled, jnp.array([[[[7., 8.]]]]))

        unpool = nn.MaxUnpool2d(2, 2, channel_axis=-1)
        unpooled = unpool(pooled, indices)
        self.assertEqual(unpooled.shape, (1, 2, 2, 2))

        # Max values return to (h=1, w=1) in their own channel; every other
        # position is zero-filled.
        expected = jnp.zeros((1, 2, 2, 2)).at[0, 1, 1].set(jnp.array([7., 8.]))
        np.testing.assert_array_equal(unpooled, expected)
|
592
|
+
|
593
|
+
|
594
|
+
class TestMaxUnpool3d(parameterized.TestCase):
    """Round-trip shape test for ``nn.MaxPool3d`` followed by ``nn.MaxUnpool3d``."""

    def test_maxunpool3d_basic(self):
        """Pool-then-unpool with matching kernel/stride restores the input shape."""
        shape = (1, 4, 4, 4, 2)
        volume = brainstate.random.rand(*shape)

        values, positions = nn.MaxPool3d(2, 2, channel_axis=-1, return_indices=True)(volume)
        restored = nn.MaxUnpool3d(2, 2, channel_axis=-1)(values, positions)

        self.assertEqual(restored.shape, shape)
|
610
|
+
|
611
|
+
|
612
|
+
class TestLPPool1d(parameterized.TestCase):
    """Tests for ``nn.LPPool1d`` on channels-last sequences."""

    def test_lppool1d_basic(self):
        """L2 pooling with kernel 2 / stride 2 halves the length axis."""
        seq = brainstate.random.rand(2, 16, 4)
        layer = nn.LPPool1d(norm_type=2, kernel_size=2, stride=2, channel_axis=-1)
        self.assertEqual(layer(seq).shape, (2, 8, 4))

    def test_lppool1d_different_norms(self):
        """Several norm orders all produce the same output shape."""
        seq = brainstate.random.rand(1, 8, 2)
        # Intended regimes: p=1 is close to averaging, p=2 is an L2 pool and a
        # large p approaches max pooling. Only the shapes are asserted here.
        orders = (1, 2, 10)
        results = [
            nn.LPPool1d(norm_type=p, kernel_size=2, stride=2, channel_axis=-1)(seq)
            for p in orders
        ]
        for p, res in zip(orders, results):
            with self.subTest(norm_type=p):
                self.assertEqual(res.shape, (1, 4, 2))

    def test_lppool1d_value_check(self):
        """A window of identical values pools to that same value."""
        seq = jnp.full((1, 2, 2), 2.)  # (batch, length, channels), every entry 2.0
        layer = nn.LPPool1d(norm_type=2, kernel_size=2, stride=2, channel_axis=-1)
        pooled = layer(seq)
        self.assertTrue(jnp.allclose(pooled, 2.0, atol=1e-5))
|
654
|
+
|
655
|
+
|
656
|
+
class TestLPPool2d(parameterized.TestCase):
|
657
|
+
"""Comprehensive tests for LPPool2d."""
|
658
|
+
|
659
|
+
def test_lppool2d_basic(self):
|
660
|
+
"""Test basic LPPool2d functionality."""
|
661
|
+
arr = brainstate.random.rand(2, 8, 8, 4)
|
662
|
+
|
663
|
+
pool = nn.LPPool2d(norm_type=2, kernel_size=2, stride=2, channel_axis=-1)
|
664
|
+
out = pool(arr)
|
665
|
+
self.assertEqual(out.shape, (2, 4, 4, 4))
|
666
|
+
|
667
|
+
def test_lppool2d_padding(self):
|
668
|
+
"""Test LPPool2d with padding."""
|
669
|
+
arr = brainstate.random.rand(1, 7, 7, 2)
|
670
|
+
|
671
|
+
pool = nn.LPPool2d(norm_type=2, kernel_size=3, stride=2, padding=1, channel_axis=-1)
|
672
|
+
out = pool(arr)
|
673
|
+
self.assertEqual(out.shape, (1, 4, 4, 2))
|
674
|
+
|
675
|
+
def test_lppool2d_different_kernel_sizes(self):
|
676
|
+
"""Test LPPool2d with non-square kernels."""
|
677
|
+
arr = brainstate.random.rand(1, 8, 6, 2)
|
678
|
+
|
679
|
+
pool = nn.LPPool2d(norm_type=2, kernel_size=(3, 2), stride=(2, 1), channel_axis=-1)
|
680
|
+
out = pool(arr)
|
681
|
+
self.assertEqual(out.shape, (1, 3, 5, 2))
|
682
|
+
|
683
|
+
|
684
|
+
class TestLPPool3d(parameterized.TestCase):
|
685
|
+
"""Comprehensive tests for LPPool3d."""
|
686
|
+
|
687
|
+
def test_lppool3d_basic(self):
|
688
|
+
"""Test basic LPPool3d functionality."""
|
689
|
+
arr = brainstate.random.rand(1, 8, 8, 8, 2)
|
690
|
+
|
691
|
+
pool = nn.LPPool3d(norm_type=2, kernel_size=2, stride=2, channel_axis=-1)
|
692
|
+
out = pool(arr)
|
693
|
+
self.assertEqual(out.shape, (1, 4, 4, 4, 2))
|
694
|
+
|
695
|
+
def test_lppool3d_different_norms(self):
|
696
|
+
"""Test LPPool3d with different norm types."""
|
697
|
+
arr = brainstate.random.rand(1, 4, 4, 4, 1)
|
698
|
+
|
699
|
+
# Different p values should give different results
|
700
|
+
pool1 = nn.LPPool3d(norm_type=1, kernel_size=2, stride=2, channel_axis=-1)
|
701
|
+
pool2 = nn.LPPool3d(norm_type=2, kernel_size=2, stride=2, channel_axis=-1)
|
702
|
+
pool3 = nn.LPPool3d(norm_type=3, kernel_size=2, stride=2, channel_axis=-1)
|
703
|
+
|
704
|
+
out1 = pool1(arr)
|
705
|
+
out2 = pool2(arr)
|
706
|
+
out3 = pool3(arr)
|
707
|
+
|
708
|
+
# All should have same shape
|
709
|
+
self.assertEqual(out1.shape, (1, 2, 2, 2, 1))
|
710
|
+
self.assertEqual(out2.shape, (1, 2, 2, 2, 1))
|
711
|
+
self.assertEqual(out3.shape, (1, 2, 2, 2, 1))
|
712
|
+
|
713
|
+
# Values should be different (unless input is uniform)
|
714
|
+
self.assertFalse(jnp.allclose(out1, out2))
|
715
|
+
self.assertFalse(jnp.allclose(out2, out3))
|
716
|
+
|
717
|
+
|
718
|
+
class TestAdaptivePool(parameterized.TestCase):
|
719
|
+
"""Tests for adaptive pooling layers."""
|
720
|
+
|
100
721
|
@parameterized.named_parameters(
|
101
722
|
dict(testcase_name=f'target_size={target_size}',
|
102
723
|
target_size=target_size)
|
103
724
|
for target_size in [10, 9, 8, 7, 6]
|
104
725
|
)
|
105
726
|
def test_adaptive_pool1d(self, target_size):
|
727
|
+
"""Test internal adaptive pooling function."""
|
106
728
|
from brainstate.nn._poolings import _adaptive_pool1d
|
107
729
|
|
108
730
|
arr = brainstate.random.rand(100)
|
109
731
|
op = jax.numpy.mean
|
110
732
|
|
111
733
|
out = _adaptive_pool1d(arr, target_size, op)
|
112
|
-
print(out.shape)
|
113
734
|
self.assertTrue(out.shape == (target_size,))
|
114
735
|
|
115
|
-
|
116
|
-
|
117
|
-
|
736
|
+
def test_adaptive_avg_pool1d(self):
|
737
|
+
"""Test AdaptiveAvgPool1d."""
|
738
|
+
input = brainstate.random.randn(2, 32, 4)
|
118
739
|
|
119
|
-
|
120
|
-
|
740
|
+
# Test with different target sizes
|
741
|
+
pool = nn.AdaptiveAvgPool1d(5, channel_axis=-1)
|
742
|
+
output = pool(input)
|
743
|
+
self.assertEqual(output.shape, (2, 5, 4))
|
121
744
|
|
122
|
-
|
123
|
-
|
745
|
+
# Test with single element input
|
746
|
+
pool = nn.AdaptiveAvgPool1d(1, channel_axis=-1)
|
747
|
+
output = pool(input)
|
748
|
+
self.assertEqual(output.shape, (2, 1, 4))
|
124
749
|
|
125
|
-
|
126
|
-
|
750
|
+
def test_adaptive_avg_pool2d(self):
|
751
|
+
"""Test AdaptiveAvgPool2d."""
|
752
|
+
input = brainstate.random.randn(2, 8, 9, 3)
|
127
753
|
|
128
|
-
|
129
|
-
|
754
|
+
# Square output
|
755
|
+
output = nn.AdaptiveAvgPool2d(5, channel_axis=-1)(input)
|
756
|
+
self.assertEqual(output.shape, (2, 5, 5, 3))
|
130
757
|
|
131
|
-
|
132
|
-
|
758
|
+
# Non-square output
|
759
|
+
output = nn.AdaptiveAvgPool2d((5, 7), channel_axis=-1)(input)
|
760
|
+
self.assertEqual(output.shape, (2, 5, 7, 3))
|
133
761
|
|
134
|
-
|
135
|
-
|
762
|
+
# Test with single integer (square output)
|
763
|
+
output = nn.AdaptiveAvgPool2d(4, channel_axis=-1)(input)
|
764
|
+
self.assertEqual(output.shape, (2, 4, 4, 3))
|
136
765
|
|
137
|
-
def
|
138
|
-
|
139
|
-
input = brainstate.random.randn(
|
766
|
+
def test_adaptive_avg_pool3d(self):
|
767
|
+
"""Test AdaptiveAvgPool3d."""
|
768
|
+
input = brainstate.random.randn(1, 8, 6, 4, 2)
|
140
769
|
|
141
|
-
|
142
|
-
|
770
|
+
pool = nn.AdaptiveAvgPool3d((4, 3, 2), channel_axis=-1)
|
771
|
+
output = pool(input)
|
772
|
+
self.assertEqual(output.shape, (1, 4, 3, 2, 2))
|
143
773
|
|
144
|
-
|
145
|
-
|
774
|
+
# Cube output
|
775
|
+
pool = nn.AdaptiveAvgPool3d(3, channel_axis=-1)
|
776
|
+
output = pool(input)
|
777
|
+
self.assertEqual(output.shape, (1, 3, 3, 3, 2))
|
778
|
+
|
779
|
+
def test_adaptive_max_pool1d(self):
|
780
|
+
"""Test AdaptiveMaxPool1d."""
|
781
|
+
input = brainstate.random.randn(2, 32, 4)
|
782
|
+
|
783
|
+
pool = nn.AdaptiveMaxPool1d(8, channel_axis=-1)
|
784
|
+
output = pool(input)
|
785
|
+
self.assertEqual(output.shape, (2, 8, 4))
|
786
|
+
|
787
|
+
def test_adaptive_max_pool2d(self):
|
788
|
+
"""Test AdaptiveMaxPool2d."""
|
789
|
+
input = brainstate.random.randn(2, 10, 8, 3)
|
790
|
+
|
791
|
+
pool = nn.AdaptiveMaxPool2d((5, 4), channel_axis=-1)
|
792
|
+
output = pool(input)
|
793
|
+
self.assertEqual(output.shape, (2, 5, 4, 3))
|
794
|
+
|
795
|
+
def test_adaptive_max_pool3d(self):
|
796
|
+
"""Test AdaptiveMaxPool3d."""
|
797
|
+
input = brainstate.random.randn(1, 8, 8, 8, 2)
|
798
|
+
|
799
|
+
pool = nn.AdaptiveMaxPool3d((4, 4, 4), channel_axis=-1)
|
800
|
+
output = pool(input)
|
801
|
+
self.assertEqual(output.shape, (1, 4, 4, 4, 2))
|
802
|
+
|
803
|
+
|
804
|
+
class TestPoolingEdgeCases(parameterized.TestCase):
|
805
|
+
"""Test edge cases and error conditions."""
|
806
|
+
|
807
|
+
def test_pool_with_stride_none(self):
|
808
|
+
"""Test pooling with stride=None (defaults to kernel_size)."""
|
809
|
+
arr = brainstate.random.rand(1, 8, 2)
|
810
|
+
|
811
|
+
pool = nn.MaxPool1d(kernel_size=3, stride=None, channel_axis=-1)
|
812
|
+
out = pool(arr)
|
813
|
+
# stride defaults to kernel_size=3
|
814
|
+
self.assertEqual(out.shape, (1, 2, 2))
|
146
815
|
|
147
|
-
|
148
|
-
|
816
|
+
def test_pool_with_large_kernel(self):
|
817
|
+
"""Test pooling with kernel larger than input."""
|
818
|
+
arr = brainstate.random.rand(1, 4, 2)
|
149
819
|
|
150
|
-
|
151
|
-
|
152
|
-
|
820
|
+
# Kernel size larger than spatial dimension
|
821
|
+
pool = nn.MaxPool1d(kernel_size=5, stride=1, channel_axis=-1)
|
822
|
+
out = pool(arr)
|
823
|
+
# Should handle gracefully (may produce empty output or handle with padding)
|
824
|
+
self.assertTrue(out.shape[1] >= 0)
|
153
825
|
|
154
|
-
def
|
155
|
-
|
156
|
-
|
157
|
-
output = net(input)
|
158
|
-
self.assertTrue(output.shape == (10, 6, 5, 3))
|
826
|
+
def test_pool_single_element(self):
|
827
|
+
"""Test pooling on single-element tensors."""
|
828
|
+
arr = brainstate.random.rand(1, 1, 1)
|
159
829
|
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
self.assertTrue(output.shape == (10, 6, 5, 3, 32))
|
830
|
+
pool = nn.AvgPool1d(1, 1, channel_axis=-1)
|
831
|
+
out = pool(arr)
|
832
|
+
self.assertEqual(out.shape, (1, 1, 1))
|
833
|
+
self.assertTrue(jnp.allclose(out, arr))
|
165
834
|
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
def test_AdaptiveMaxPool1d_v1(self, axis):
|
170
|
-
input = brainstate.random.randn(32, 16)
|
171
|
-
net = nn.AdaptiveMaxPool1d(target_size=4, channel_axis=axis)
|
172
|
-
output = net(input)
|
835
|
+
def test_adaptive_pool_smaller_output(self):
|
836
|
+
"""Test adaptive pooling with output smaller than input."""
|
837
|
+
arr = brainstate.random.rand(1, 16, 2)
|
173
838
|
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
input = brainstate.random.randn(2, 32, 16)
|
179
|
-
net = nn.AdaptiveMaxPool1d(target_size=4, channel_axis=axis)
|
180
|
-
output = net(input)
|
839
|
+
# Adaptive pooling to smaller size
|
840
|
+
pool = nn.AdaptiveAvgPool1d(4, channel_axis=-1)
|
841
|
+
out = pool(arr)
|
842
|
+
self.assertEqual(out.shape, (1, 4, 2))
|
181
843
|
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
input = brainstate.random.randn(32, 16, 12)
|
187
|
-
net = nn.AdaptiveAvgPool2d(target_size=[5, 4], channel_axis=axis)
|
188
|
-
output = net(input)
|
844
|
+
def test_unpool_without_indices(self):
|
845
|
+
"""Test unpooling behavior with placeholder indices."""
|
846
|
+
pooled = brainstate.random.rand(1, 4, 2)
|
847
|
+
indices = jnp.zeros_like(pooled, dtype=jnp.int32)
|
189
848
|
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
input = brainstate.random.randn(2, 32, 16, 12)
|
195
|
-
net = nn.AdaptiveAvgPool2d(target_size=[5, 4], channel_axis=axis)
|
196
|
-
output = net(input)
|
849
|
+
unpool = nn.MaxUnpool1d(2, 2, channel_axis=-1)
|
850
|
+
# Should not raise error even with zero indices
|
851
|
+
unpooled = unpool(pooled, indices)
|
852
|
+
self.assertEqual(unpooled.shape, (1, 8, 2))
|
197
853
|
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
854
|
+
def test_lppool_extreme_norm(self):
|
855
|
+
"""Test LPPool with extreme norm values."""
|
856
|
+
arr = brainstate.random.rand(1, 8, 2) + 0.1 # Avoid zeros
|
857
|
+
|
858
|
+
# Very large p (approaches max pooling)
|
859
|
+
pool_large = nn.LPPool1d(norm_type=20, kernel_size=2, stride=2, channel_axis=-1)
|
860
|
+
out_large = pool_large(arr)
|
861
|
+
|
862
|
+
# Compare with actual max pooling
|
863
|
+
pool_max = nn.MaxPool1d(2, 2, channel_axis=-1)
|
864
|
+
out_max = pool_max(arr)
|
865
|
+
|
866
|
+
# Should approach max pooling for large p (but not exactly equal)
|
867
|
+
# Just check shapes match
|
868
|
+
self.assertEqual(out_large.shape, out_max.shape)
|
869
|
+
|
870
|
+
def test_pool_with_channels_first(self):
|
871
|
+
"""Test pooling with channels in different positions."""
|
872
|
+
arr = brainstate.random.rand(3, 16, 8) # (dim0, dim1, dim2)
|
873
|
+
|
874
|
+
# Channel axis at position 0 - treats dim 0 as channels, pools last dimension
|
875
|
+
pool = nn.MaxPool1d(2, 2, channel_axis=0)
|
876
|
+
out = pool(arr)
|
877
|
+
# Pools the last dimension, keeping first two
|
878
|
+
self.assertEqual(out.shape, (3, 16, 4))
|
879
|
+
|
880
|
+
# Channel axis at position -1 (last) - pools middle dimension
|
881
|
+
pool = nn.MaxPool1d(2, 2, channel_axis=-1)
|
882
|
+
out = pool(arr)
|
883
|
+
# Pools the middle dimension, keeping first and last
|
884
|
+
self.assertEqual(out.shape, (3, 8, 8))
|
885
|
+
|
886
|
+
# No channel axis - pools last dimension, treating earlier dims as batch
|
887
|
+
pool = nn.MaxPool1d(2, 2, channel_axis=None)
|
888
|
+
out = pool(arr)
|
889
|
+
# Pools the last dimension
|
890
|
+
self.assertEqual(out.shape, (3, 16, 4))
|
891
|
+
|
892
|
+
|
893
|
+
class TestPoolingMathematicalProperties(parameterized.TestCase):
|
894
|
+
"""Test mathematical properties of pooling operations."""
|
895
|
+
|
896
|
+
def test_maxpool_idempotence(self):
|
897
|
+
"""Test that max pooling with kernel_size=1 is identity."""
|
898
|
+
arr = brainstate.random.rand(2, 8, 3)
|
899
|
+
|
900
|
+
pool = nn.MaxPool1d(1, 1, channel_axis=-1)
|
901
|
+
out = pool(arr)
|
902
|
+
|
903
|
+
self.assertTrue(jnp.allclose(out, arr))
|
904
|
+
|
905
|
+
def test_avgpool_constant_input(self):
|
906
|
+
"""Test average pooling on constant input."""
|
907
|
+
arr = jnp.ones((1, 8, 2)) * 5.0
|
908
|
+
|
909
|
+
pool = nn.AvgPool1d(2, 2, channel_axis=-1)
|
910
|
+
out = pool(arr)
|
911
|
+
|
912
|
+
# Average of constant should be the constant
|
913
|
+
self.assertTrue(jnp.allclose(out, 5.0))
|
914
|
+
|
915
|
+
def test_lppool_norm_properties(self):
|
916
|
+
"""Test Lp pooling norm properties."""
|
917
|
+
arr = brainstate.random.rand(1, 4, 1) + 0.1
|
918
|
+
|
919
|
+
# L1 norm (p=1) should give average of absolute values
|
920
|
+
pool_l1 = nn.LPPool1d(norm_type=1, kernel_size=4, stride=4, channel_axis=-1)
|
921
|
+
out_l1 = pool_l1(arr)
|
922
|
+
|
923
|
+
# Manual calculation
|
924
|
+
manual_l1 = jnp.mean(jnp.abs(arr[:, :4, :]))
|
925
|
+
|
926
|
+
self.assertTrue(jnp.allclose(out_l1[0, 0, 0], manual_l1, rtol=1e-5))
|
927
|
+
|
928
|
+
def test_maxpool_monotonicity(self):
|
929
|
+
"""Test that max pooling preserves monotonicity."""
|
930
|
+
arr1 = brainstate.random.rand(1, 8, 2)
|
931
|
+
arr2 = arr1 + 1.0 # Strictly greater
|
932
|
+
|
933
|
+
pool = nn.MaxPool1d(2, 2, channel_axis=-1)
|
934
|
+
out1 = pool(arr1)
|
935
|
+
out2 = pool(arr2)
|
936
|
+
|
937
|
+
# out2 should be strictly greater than out1
|
938
|
+
self.assertTrue(jnp.all(out2 > out1))
|
939
|
+
|
940
|
+
def test_adaptive_pool_preserves_values(self):
|
941
|
+
"""Test that adaptive pooling with same size preserves values."""
|
942
|
+
arr = brainstate.random.rand(1, 8, 2)
|
943
|
+
|
944
|
+
# Adaptive pool to same size
|
945
|
+
pool = nn.AdaptiveAvgPool1d(8, channel_axis=-1)
|
946
|
+
out = pool(arr)
|
947
|
+
|
948
|
+
# Should be approximately equal (might have small numerical differences)
|
949
|
+
self.assertTrue(jnp.allclose(out, arr, rtol=1e-5))
|
214
950
|
|
215
951
|
|
216
952
|
if __name__ == '__main__':
|
217
|
-
absltest.main()
|
953
|
+
absltest.main()
|