brainstate 0.1.8__py2.py3-none-any.whl → 0.1.10__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +58 -51
- brainstate/_compatible_import.py +148 -148
- brainstate/_state.py +1605 -1663
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/augment/__init__.py +30 -30
- brainstate/augment/_autograd.py +778 -778
- brainstate/augment/_autograd_test.py +1289 -1289
- brainstate/augment/_eval_shape.py +99 -99
- brainstate/augment/_eval_shape_test.py +38 -38
- brainstate/augment/_mapping.py +1060 -1060
- brainstate/augment/_mapping_test.py +597 -597
- brainstate/augment/_random.py +151 -151
- brainstate/compile/__init__.py +38 -38
- brainstate/compile/_ad_checkpoint.py +204 -204
- brainstate/compile/_ad_checkpoint_test.py +49 -49
- brainstate/compile/_conditions.py +256 -256
- brainstate/compile/_conditions_test.py +220 -220
- brainstate/compile/_error_if.py +92 -92
- brainstate/compile/_error_if_test.py +52 -52
- brainstate/compile/_jit.py +346 -346
- brainstate/compile/_jit_test.py +143 -143
- brainstate/compile/_loop_collect_return.py +536 -536
- brainstate/compile/_loop_collect_return_test.py +58 -58
- brainstate/compile/_loop_no_collection.py +184 -184
- brainstate/compile/_loop_no_collection_test.py +50 -50
- brainstate/compile/_make_jaxpr.py +888 -888
- brainstate/compile/_make_jaxpr_test.py +156 -156
- brainstate/compile/_progress_bar.py +202 -202
- brainstate/compile/_unvmap.py +159 -159
- brainstate/compile/_util.py +147 -147
- brainstate/environ.py +563 -563
- brainstate/environ_test.py +62 -62
- brainstate/functional/__init__.py +27 -26
- brainstate/graph/__init__.py +29 -29
- brainstate/graph/_graph_node.py +244 -244
- brainstate/graph/_graph_node_test.py +73 -73
- brainstate/graph/_graph_operation.py +1738 -1738
- brainstate/graph/_graph_operation_test.py +563 -563
- brainstate/init/__init__.py +26 -26
- brainstate/init/_base.py +52 -52
- brainstate/init/_generic.py +244 -244
- brainstate/init/_random_inits.py +553 -553
- brainstate/init/_random_inits_test.py +149 -149
- brainstate/init/_regular_inits.py +105 -105
- brainstate/init/_regular_inits_test.py +50 -50
- brainstate/mixin.py +365 -363
- brainstate/mixin_test.py +77 -73
- brainstate/nn/__init__.py +135 -131
- brainstate/{functional → nn}/_activations.py +808 -813
- brainstate/{functional → nn}/_activations_test.py +331 -331
- brainstate/nn/_collective_ops.py +514 -514
- brainstate/nn/_collective_ops_test.py +43 -43
- brainstate/nn/_common.py +178 -178
- brainstate/nn/_conv.py +501 -501
- brainstate/nn/_conv_test.py +238 -238
- brainstate/nn/_delay.py +588 -502
- brainstate/nn/_delay_test.py +238 -184
- brainstate/nn/_dropout.py +426 -426
- brainstate/nn/_dropout_test.py +100 -100
- brainstate/nn/_dynamics.py +1343 -1343
- brainstate/nn/_dynamics_test.py +78 -78
- brainstate/nn/_elementwise.py +1119 -1119
- brainstate/nn/_elementwise_test.py +169 -169
- brainstate/nn/_embedding.py +58 -58
- brainstate/nn/_exp_euler.py +92 -92
- brainstate/nn/_exp_euler_test.py +35 -35
- brainstate/nn/_fixedprob.py +239 -239
- brainstate/nn/_fixedprob_test.py +114 -114
- brainstate/nn/_inputs.py +608 -608
- brainstate/nn/_linear.py +424 -424
- brainstate/nn/_linear_mv.py +83 -83
- brainstate/nn/_linear_mv_test.py +120 -120
- brainstate/nn/_linear_test.py +107 -107
- brainstate/nn/_ltp.py +28 -28
- brainstate/nn/_module.py +377 -377
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_neuron.py +705 -705
- brainstate/nn/_neuron_test.py +161 -161
- brainstate/nn/_normalizations.py +975 -918
- brainstate/nn/_normalizations_test.py +73 -73
- brainstate/{functional → nn}/_others.py +46 -46
- brainstate/nn/_poolings.py +1177 -1177
- brainstate/nn/_poolings_test.py +217 -217
- brainstate/nn/_projection.py +486 -486
- brainstate/nn/_rate_rnns.py +554 -554
- brainstate/nn/_rate_rnns_test.py +63 -63
- brainstate/nn/_readout.py +209 -209
- brainstate/nn/_readout_test.py +53 -53
- brainstate/nn/_stp.py +236 -236
- brainstate/nn/_synapse.py +505 -505
- brainstate/nn/_synapse_test.py +131 -131
- brainstate/nn/_synaptic_projection.py +423 -423
- brainstate/nn/_synouts.py +162 -162
- brainstate/nn/_synouts_test.py +57 -57
- brainstate/nn/_utils.py +89 -89
- brainstate/nn/metrics.py +388 -388
- brainstate/optim/__init__.py +38 -38
- brainstate/optim/_base.py +64 -64
- brainstate/optim/_lr_scheduler.py +448 -448
- brainstate/optim/_lr_scheduler_test.py +50 -50
- brainstate/optim/_optax_optimizer.py +152 -152
- brainstate/optim/_optax_optimizer_test.py +53 -53
- brainstate/optim/_sgd_optimizer.py +1104 -1104
- brainstate/random/__init__.py +24 -24
- brainstate/random/_rand_funs.py +3616 -3616
- brainstate/random/_rand_funs_test.py +567 -567
- brainstate/random/_rand_seed.py +210 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1409 -1409
- brainstate/random/_random_for_unit.py +52 -52
- brainstate/surrogate.py +1957 -1957
- brainstate/transform.py +23 -23
- brainstate/typing.py +304 -304
- brainstate/util/__init__.py +50 -50
- brainstate/util/caller.py +98 -98
- brainstate/util/error.py +55 -55
- brainstate/util/filter.py +469 -469
- brainstate/util/others.py +540 -540
- brainstate/util/pretty_pytree.py +945 -945
- brainstate/util/pretty_pytree_test.py +159 -159
- brainstate/util/pretty_repr.py +328 -328
- brainstate/util/pretty_table.py +2954 -2954
- brainstate/util/scaling.py +258 -258
- brainstate/util/struct.py +523 -523
- {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/METADATA +91 -99
- brainstate-0.1.10.dist-info/RECORD +130 -0
- {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/WHEEL +1 -1
- {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info/licenses}/LICENSE +202 -202
- brainstate/functional/_normalization.py +0 -81
- brainstate/functional/_spikes.py +0 -204
- brainstate-0.1.8.dist-info/RECORD +0 -132
- {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/top_level.txt +0 -0
brainstate/nn/_poolings_test.py
CHANGED
@@ -1,217 +1,217 @@
|
|
1
|
-
# -*- coding: utf-8 -*-
|
2
|
-
|
3
|
-
import jax
|
4
|
-
import numpy as np
|
5
|
-
from absl.testing import absltest
|
6
|
-
from absl.testing import parameterized
|
7
|
-
|
8
|
-
import brainstate
|
9
|
-
import brainstate.nn as nn
|
10
|
-
|
11
|
-
|
12
|
-
class TestFlatten(parameterized.TestCase):
    """Output-shape tests for ``nn.Flatten``."""

    def test_flatten1(self):
        # start_axis=0 collapses every axis of the input into a single axis.
        for size in [(16, 32, 32, 8), (32, 8), (10, 20, 30)]:
            with self.subTest(size=size):
                arr = brainstate.random.rand(*size)
                out = nn.Flatten(start_axis=0)(arr)
                self.assertEqual(out.shape, (int(np.prod(size)),))

    def test_flatten2(self):
        # start_axis=1 keeps the leading axis and collapses all remaining ones.
        for size in [(16, 32, 32, 8), (32, 8), (10, 20, 30)]:
            with self.subTest(size=size):
                arr = brainstate.random.rand(*size)
                out = nn.Flatten(start_axis=1)(arr)
                self.assertEqual(out.shape, (size[0], int(np.prod(size[1:]))))

    def test_flatten3(self):
        # With in_size=(32, 8), the expected shape shows start_axis is counted
        # within the trailing in_size axes; the leading (16, 32) are kept.
        size = (16, 32, 32, 8)
        arr = brainstate.random.rand(*size)
        out = nn.Flatten(start_axis=0, in_size=(32, 8))(arr)
        self.assertEqual(out.shape, (16, 32, 32 * 8))

    def test_flatten4(self):
        # With in_size=(32, 32, 8) and start_axis=1, only the last two of the
        # in_size axes are collapsed, matching test_flatten3's result shape.
        size = (16, 32, 32, 8)
        arr = brainstate.random.rand(*size)
        out = nn.Flatten(start_axis=1, in_size=(32, 32, 8))(arr)
        self.assertEqual(out.shape, (16, 32, 32 * 8))
|
48
|
-
|
49
|
-
|
50
|
-
class TestUnflatten(parameterized.TestCase):
    """Placeholder for ``nn.Unflatten`` tests; no test cases are defined yet."""
|
52
|
-
|
53
|
-
|
54
|
-
class TestPool(parameterized.TestCase):
    """Output-shape tests for the fixed-window and adaptive pooling layers."""

    @staticmethod
    def _adaptive_out_shape(in_shape, target_size, channel_axis):
        """Return the output shape expected from an adaptive pooling layer.

        Encodes the convention pinned by the literal assertions in
        ``test_AdaptiveAvgPool2d_v1``/``_v2`` and ``test_AdaptiveAvgPool3d_*``:
        the channel axis (when not ``None``) keeps its size, and the last
        ``len(target_size)`` of the remaining axes are resized to
        ``target_size``.

        Args:
          in_shape: Shape of the input array.
          target_size: An int (1-D pooling) or a sequence of ints.
          channel_axis: Index of the channel axis (negative values allowed),
            or ``None``.

        Returns:
          The expected output shape as a tuple.
        """
        if isinstance(target_size, int):
            target_size = (target_size,)
        ndim = len(in_shape)
        candidates = list(range(ndim))
        if channel_axis is not None:
            candidates.remove(channel_axis % ndim)
        out_shape = list(in_shape)
        for axis, size in zip(candidates[-len(target_size):], target_size):
            out_shape[axis] = size
        return tuple(out_shape)

    def test_MaxPool2d_v1(self):
        arr = brainstate.random.rand(16, 32, 32, 8)

        out = nn.MaxPool2d(2, 2, channel_axis=-1)(arr)
        self.assertEqual(out.shape, (16, 16, 16, 8))

        # Without a channel axis the last two axes are the pooled ones.
        out = nn.MaxPool2d(2, 2, channel_axis=None)(arr)
        self.assertEqual(out.shape, (16, 32, 16, 4))

        out = nn.MaxPool2d(2, 2, channel_axis=None, padding=1)(arr)
        self.assertEqual(out.shape, (16, 32, 17, 5))

        out = nn.MaxPool2d(2, 2, channel_axis=None, padding=(2, 1))(arr)
        self.assertEqual(out.shape, (16, 32, 18, 5))

        out = nn.MaxPool2d(2, 2, channel_axis=-1, padding=(1, 1))(arr)
        self.assertEqual(out.shape, (16, 17, 17, 8))

        out = nn.MaxPool2d(2, 2, channel_axis=2, padding=(1, 1))(arr)
        self.assertEqual(out.shape, (16, 17, 32, 5))

    def test_AvgPool2d_v1(self):
        arr = brainstate.random.rand(16, 32, 32, 8)

        out = nn.AvgPool2d(2, 2, channel_axis=-1)(arr)
        self.assertEqual(out.shape, (16, 16, 16, 8))

        # Without a channel axis the last two axes are the pooled ones.
        out = nn.AvgPool2d(2, 2, channel_axis=None)(arr)
        self.assertEqual(out.shape, (16, 32, 16, 4))

        out = nn.AvgPool2d(2, 2, channel_axis=None, padding=1)(arr)
        self.assertEqual(out.shape, (16, 32, 17, 5))

        out = nn.AvgPool2d(2, 2, channel_axis=None, padding=(2, 1))(arr)
        self.assertEqual(out.shape, (16, 32, 18, 5))

        out = nn.AvgPool2d(2, 2, channel_axis=-1, padding=(1, 1))(arr)
        self.assertEqual(out.shape, (16, 17, 17, 8))

        out = nn.AvgPool2d(2, 2, channel_axis=2, padding=(1, 1))(arr)
        self.assertEqual(out.shape, (16, 17, 32, 5))

    @parameterized.named_parameters(
        dict(testcase_name=f'target_size={target_size}',
             target_size=target_size)
        for target_size in [10, 9, 8, 7, 6]
    )
    def test_adaptive_pool1d(self, target_size):
        from brainstate.nn._poolings import _adaptive_pool1d

        arr = brainstate.random.rand(100)
        out = _adaptive_pool1d(arr, target_size, jax.numpy.mean)
        self.assertEqual(out.shape, (target_size,))

    def test_AdaptiveAvgPool2d_v1(self):
        x = brainstate.random.randn(64, 8, 9)

        output = nn.AdaptiveAvgPool2d((5, 7), channel_axis=0)(x)
        self.assertEqual(output.shape, (64, 5, 7))

        output = nn.AdaptiveAvgPool2d((2, 3), channel_axis=0)(x)
        self.assertEqual(output.shape, (64, 2, 3))

        output = nn.AdaptiveAvgPool2d((2, 3), channel_axis=-1)(x)
        self.assertEqual(output.shape, (2, 3, 9))

        output = nn.AdaptiveAvgPool2d((2, 3), channel_axis=1)(x)
        self.assertEqual(output.shape, (2, 8, 3))

        output = nn.AdaptiveAvgPool2d((2, 3), channel_axis=None)(x)
        self.assertEqual(output.shape, (64, 2, 3))

    def test_AdaptiveAvgPool2d_v2(self):
        brainstate.random.seed()
        x = brainstate.random.randn(128, 64, 32, 16)

        output = nn.AdaptiveAvgPool2d((5, 7), channel_axis=0)(x)
        self.assertEqual(output.shape, (128, 64, 5, 7))

        output = nn.AdaptiveAvgPool2d((2, 3), channel_axis=0)(x)
        self.assertEqual(output.shape, (128, 64, 2, 3))

        output = nn.AdaptiveAvgPool2d((2, 3), channel_axis=-1)(x)
        self.assertEqual(output.shape, (128, 2, 3, 16))

        output = nn.AdaptiveAvgPool2d((2, 3), channel_axis=1)(x)
        self.assertEqual(output.shape, (128, 64, 2, 3))

    def test_AdaptiveAvgPool3d_v1(self):
        x = brainstate.random.randn(10, 128, 64, 32)
        output = nn.AdaptiveAvgPool3d(target_size=[6, 5, 3], channel_axis=0)(x)
        self.assertEqual(output.shape, (10, 6, 5, 3))

    def test_AdaptiveAvgPool3d_v2(self):
        # Default channel axis (last). Axes 1-3 are pooled, so axis 0 acts as
        # a batch axis. The input is kept small: the former (10, 20, 128, 64, 32)
        # array allocated ~52M floats for a pure shape check.
        x = brainstate.random.randn(10, 20, 24, 16, 8)
        output = nn.AdaptiveAvgPool3d(target_size=[6, 5, 3])(x)
        self.assertEqual(output.shape, (10, 6, 5, 3, 8))

    @parameterized.product(
        axis=(-1, 0, 1)
    )
    def test_AdaptiveMaxPool1d_v1(self, axis):
        x = brainstate.random.randn(32, 16)
        output = nn.AdaptiveMaxPool1d(target_size=4, channel_axis=axis)(x)
        self.assertEqual(output.shape, self._adaptive_out_shape(x.shape, 4, axis))

    @parameterized.product(
        axis=(-1, 0, 1, 2)
    )
    def test_AdaptiveMaxPool1d_v2(self, axis):
        x = brainstate.random.randn(2, 32, 16)
        output = nn.AdaptiveMaxPool1d(target_size=4, channel_axis=axis)(x)
        self.assertEqual(output.shape, self._adaptive_out_shape(x.shape, 4, axis))

    @parameterized.product(
        axis=(-1, 0, 1, 2)
    )
    def test_AdaptiveMaxPool2d_v1(self, axis):
        x = brainstate.random.randn(32, 16, 12)
        # Previously constructed AdaptiveAvgPool2d, so max pooling was untested.
        output = nn.AdaptiveMaxPool2d(target_size=[5, 4], channel_axis=axis)(x)
        self.assertEqual(output.shape, self._adaptive_out_shape(x.shape, [5, 4], axis))

    @parameterized.product(
        axis=(-1, 0, 1, 2, 3)
    )
    def test_AdaptiveMaxPool2d_v2(self, axis):
        x = brainstate.random.randn(2, 32, 16, 12)
        output = nn.AdaptiveMaxPool2d(target_size=[5, 4], channel_axis=axis)(x)
        self.assertEqual(output.shape, self._adaptive_out_shape(x.shape, [5, 4], axis))

    @parameterized.product(
        axis=(-1, 0, 1, 2, 3)
    )
    def test_AdaptiveMaxPool3d_v1(self, axis):
        # Un-batched 4-D input: for most channel axes every other axis,
        # including axis 0, is pooled. The leading size is kept no smaller than
        # the largest target (6) so no case relies on upsampling behaviour.
        x = brainstate.random.randn(8, 128, 64, 32)
        output = nn.AdaptiveMaxPool3d(target_size=[6, 5, 4], channel_axis=axis)(x)
        self.assertEqual(output.shape, self._adaptive_out_shape(x.shape, [6, 5, 4], axis))

    @parameterized.product(
        axis=(-1, 0, 1, 2, 3, 4)
    )
    def test_AdaptiveMaxPool3d_v2(self, axis):
        # Renamed from a duplicate ``test_AdaptiveMaxPool3d_v1`` that silently
        # shadowed the 4-D test above.
        x = brainstate.random.randn(2, 128, 64, 32, 16)
        output = nn.AdaptiveMaxPool3d(target_size=[6, 5, 4], channel_axis=axis)(x)
        self.assertEqual(output.shape, self._adaptive_out_shape(x.shape, [6, 5, 4], axis))
|
214
|
-
|
215
|
-
|
216
|
-
if __name__ == "__main__":
    # Run every test case in this module through absl's test runner.
    absltest.main()
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
|
3
|
+
import jax
|
4
|
+
import numpy as np
|
5
|
+
from absl.testing import absltest
|
6
|
+
from absl.testing import parameterized
|
7
|
+
|
8
|
+
import brainstate
|
9
|
+
import brainstate.nn as nn
|
10
|
+
|
11
|
+
|
12
|
+
class TestFlatten(parameterized.TestCase):
    """Output-shape tests for ``nn.Flatten``."""

    def test_flatten1(self):
        # start_axis=0 collapses every axis of the input into a single axis.
        for size in [(16, 32, 32, 8), (32, 8), (10, 20, 30)]:
            with self.subTest(size=size):
                arr = brainstate.random.rand(*size)
                out = nn.Flatten(start_axis=0)(arr)
                self.assertEqual(out.shape, (int(np.prod(size)),))

    def test_flatten2(self):
        # start_axis=1 keeps the leading axis and collapses all remaining ones.
        for size in [(16, 32, 32, 8), (32, 8), (10, 20, 30)]:
            with self.subTest(size=size):
                arr = brainstate.random.rand(*size)
                out = nn.Flatten(start_axis=1)(arr)
                self.assertEqual(out.shape, (size[0], int(np.prod(size[1:]))))

    def test_flatten3(self):
        # With in_size=(32, 8), the expected shape shows start_axis is counted
        # within the trailing in_size axes; the leading (16, 32) are kept.
        size = (16, 32, 32, 8)
        arr = brainstate.random.rand(*size)
        out = nn.Flatten(start_axis=0, in_size=(32, 8))(arr)
        self.assertEqual(out.shape, (16, 32, 32 * 8))

    def test_flatten4(self):
        # With in_size=(32, 32, 8) and start_axis=1, only the last two of the
        # in_size axes are collapsed, matching test_flatten3's result shape.
        size = (16, 32, 32, 8)
        arr = brainstate.random.rand(*size)
        out = nn.Flatten(start_axis=1, in_size=(32, 32, 8))(arr)
        self.assertEqual(out.shape, (16, 32, 32 * 8))
|
48
|
+
|
49
|
+
|
50
|
+
class TestUnflatten(parameterized.TestCase):
    """Placeholder for ``nn.Unflatten`` tests; no test cases are defined yet."""
|
52
|
+
|
53
|
+
|
54
|
+
class TestPool(parameterized.TestCase):
    """Output-shape tests for the fixed-window and adaptive pooling layers."""

    @staticmethod
    def _adaptive_out_shape(in_shape, target_size, channel_axis):
        """Return the output shape expected from an adaptive pooling layer.

        Encodes the convention pinned by the literal assertions in
        ``test_AdaptiveAvgPool2d_v1``/``_v2`` and ``test_AdaptiveAvgPool3d_*``:
        the channel axis (when not ``None``) keeps its size, and the last
        ``len(target_size)`` of the remaining axes are resized to
        ``target_size``.

        Args:
          in_shape: Shape of the input array.
          target_size: An int (1-D pooling) or a sequence of ints.
          channel_axis: Index of the channel axis (negative values allowed),
            or ``None``.

        Returns:
          The expected output shape as a tuple.
        """
        if isinstance(target_size, int):
            target_size = (target_size,)
        ndim = len(in_shape)
        candidates = list(range(ndim))
        if channel_axis is not None:
            candidates.remove(channel_axis % ndim)
        out_shape = list(in_shape)
        for axis, size in zip(candidates[-len(target_size):], target_size):
            out_shape[axis] = size
        return tuple(out_shape)

    def test_MaxPool2d_v1(self):
        arr = brainstate.random.rand(16, 32, 32, 8)

        out = nn.MaxPool2d(2, 2, channel_axis=-1)(arr)
        self.assertEqual(out.shape, (16, 16, 16, 8))

        # Without a channel axis the last two axes are the pooled ones.
        out = nn.MaxPool2d(2, 2, channel_axis=None)(arr)
        self.assertEqual(out.shape, (16, 32, 16, 4))

        out = nn.MaxPool2d(2, 2, channel_axis=None, padding=1)(arr)
        self.assertEqual(out.shape, (16, 32, 17, 5))

        out = nn.MaxPool2d(2, 2, channel_axis=None, padding=(2, 1))(arr)
        self.assertEqual(out.shape, (16, 32, 18, 5))

        out = nn.MaxPool2d(2, 2, channel_axis=-1, padding=(1, 1))(arr)
        self.assertEqual(out.shape, (16, 17, 17, 8))

        out = nn.MaxPool2d(2, 2, channel_axis=2, padding=(1, 1))(arr)
        self.assertEqual(out.shape, (16, 17, 32, 5))

    def test_AvgPool2d_v1(self):
        arr = brainstate.random.rand(16, 32, 32, 8)

        out = nn.AvgPool2d(2, 2, channel_axis=-1)(arr)
        self.assertEqual(out.shape, (16, 16, 16, 8))

        # Without a channel axis the last two axes are the pooled ones.
        out = nn.AvgPool2d(2, 2, channel_axis=None)(arr)
        self.assertEqual(out.shape, (16, 32, 16, 4))

        out = nn.AvgPool2d(2, 2, channel_axis=None, padding=1)(arr)
        self.assertEqual(out.shape, (16, 32, 17, 5))

        out = nn.AvgPool2d(2, 2, channel_axis=None, padding=(2, 1))(arr)
        self.assertEqual(out.shape, (16, 32, 18, 5))

        out = nn.AvgPool2d(2, 2, channel_axis=-1, padding=(1, 1))(arr)
        self.assertEqual(out.shape, (16, 17, 17, 8))

        out = nn.AvgPool2d(2, 2, channel_axis=2, padding=(1, 1))(arr)
        self.assertEqual(out.shape, (16, 17, 32, 5))

    @parameterized.named_parameters(
        dict(testcase_name=f'target_size={target_size}',
             target_size=target_size)
        for target_size in [10, 9, 8, 7, 6]
    )
    def test_adaptive_pool1d(self, target_size):
        from brainstate.nn._poolings import _adaptive_pool1d

        arr = brainstate.random.rand(100)
        out = _adaptive_pool1d(arr, target_size, jax.numpy.mean)
        self.assertEqual(out.shape, (target_size,))

    def test_AdaptiveAvgPool2d_v1(self):
        x = brainstate.random.randn(64, 8, 9)

        output = nn.AdaptiveAvgPool2d((5, 7), channel_axis=0)(x)
        self.assertEqual(output.shape, (64, 5, 7))

        output = nn.AdaptiveAvgPool2d((2, 3), channel_axis=0)(x)
        self.assertEqual(output.shape, (64, 2, 3))

        output = nn.AdaptiveAvgPool2d((2, 3), channel_axis=-1)(x)
        self.assertEqual(output.shape, (2, 3, 9))

        output = nn.AdaptiveAvgPool2d((2, 3), channel_axis=1)(x)
        self.assertEqual(output.shape, (2, 8, 3))

        output = nn.AdaptiveAvgPool2d((2, 3), channel_axis=None)(x)
        self.assertEqual(output.shape, (64, 2, 3))

    def test_AdaptiveAvgPool2d_v2(self):
        brainstate.random.seed()
        x = brainstate.random.randn(128, 64, 32, 16)

        output = nn.AdaptiveAvgPool2d((5, 7), channel_axis=0)(x)
        self.assertEqual(output.shape, (128, 64, 5, 7))

        output = nn.AdaptiveAvgPool2d((2, 3), channel_axis=0)(x)
        self.assertEqual(output.shape, (128, 64, 2, 3))

        output = nn.AdaptiveAvgPool2d((2, 3), channel_axis=-1)(x)
        self.assertEqual(output.shape, (128, 2, 3, 16))

        output = nn.AdaptiveAvgPool2d((2, 3), channel_axis=1)(x)
        self.assertEqual(output.shape, (128, 64, 2, 3))

    def test_AdaptiveAvgPool3d_v1(self):
        x = brainstate.random.randn(10, 128, 64, 32)
        output = nn.AdaptiveAvgPool3d(target_size=[6, 5, 3], channel_axis=0)(x)
        self.assertEqual(output.shape, (10, 6, 5, 3))

    def test_AdaptiveAvgPool3d_v2(self):
        # Default channel axis (last). Axes 1-3 are pooled, so axis 0 acts as
        # a batch axis. The input is kept small: the former (10, 20, 128, 64, 32)
        # array allocated ~52M floats for a pure shape check.
        x = brainstate.random.randn(10, 20, 24, 16, 8)
        output = nn.AdaptiveAvgPool3d(target_size=[6, 5, 3])(x)
        self.assertEqual(output.shape, (10, 6, 5, 3, 8))

    @parameterized.product(
        axis=(-1, 0, 1)
    )
    def test_AdaptiveMaxPool1d_v1(self, axis):
        x = brainstate.random.randn(32, 16)
        output = nn.AdaptiveMaxPool1d(target_size=4, channel_axis=axis)(x)
        self.assertEqual(output.shape, self._adaptive_out_shape(x.shape, 4, axis))

    @parameterized.product(
        axis=(-1, 0, 1, 2)
    )
    def test_AdaptiveMaxPool1d_v2(self, axis):
        x = brainstate.random.randn(2, 32, 16)
        output = nn.AdaptiveMaxPool1d(target_size=4, channel_axis=axis)(x)
        self.assertEqual(output.shape, self._adaptive_out_shape(x.shape, 4, axis))

    @parameterized.product(
        axis=(-1, 0, 1, 2)
    )
    def test_AdaptiveMaxPool2d_v1(self, axis):
        x = brainstate.random.randn(32, 16, 12)
        # Previously constructed AdaptiveAvgPool2d, so max pooling was untested.
        output = nn.AdaptiveMaxPool2d(target_size=[5, 4], channel_axis=axis)(x)
        self.assertEqual(output.shape, self._adaptive_out_shape(x.shape, [5, 4], axis))

    @parameterized.product(
        axis=(-1, 0, 1, 2, 3)
    )
    def test_AdaptiveMaxPool2d_v2(self, axis):
        x = brainstate.random.randn(2, 32, 16, 12)
        output = nn.AdaptiveMaxPool2d(target_size=[5, 4], channel_axis=axis)(x)
        self.assertEqual(output.shape, self._adaptive_out_shape(x.shape, [5, 4], axis))

    @parameterized.product(
        axis=(-1, 0, 1, 2, 3)
    )
    def test_AdaptiveMaxPool3d_v1(self, axis):
        # Un-batched 4-D input: for most channel axes every other axis,
        # including axis 0, is pooled. The leading size is kept no smaller than
        # the largest target (6) so no case relies on upsampling behaviour.
        x = brainstate.random.randn(8, 128, 64, 32)
        output = nn.AdaptiveMaxPool3d(target_size=[6, 5, 4], channel_axis=axis)(x)
        self.assertEqual(output.shape, self._adaptive_out_shape(x.shape, [6, 5, 4], axis))

    @parameterized.product(
        axis=(-1, 0, 1, 2, 3, 4)
    )
    def test_AdaptiveMaxPool3d_v2(self, axis):
        # Renamed from a duplicate ``test_AdaptiveMaxPool3d_v1`` that silently
        # shadowed the 4-D test above.
        x = brainstate.random.randn(2, 128, 64, 32, 16)
        output = nn.AdaptiveMaxPool3d(target_size=[6, 5, 4], channel_axis=axis)(x)
        self.assertEqual(output.shape, self._adaptive_out_shape(x.shape, [6, 5, 4], axis))
|
214
|
+
|
215
|
+
|
216
|
+
if __name__ == "__main__":
    # Run every test case in this module through absl's test runner.
    absltest.main()
|