brainstate 0.1.0.post20250503__py2.py3-none-any.whl → 0.1.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_compatible_import.py +10 -3
- brainstate/_state.py +178 -178
- brainstate/_utils.py +0 -1
- brainstate/augment/_autograd.py +0 -2
- brainstate/augment/_autograd_test.py +132 -133
- brainstate/augment/_eval_shape.py +0 -2
- brainstate/augment/_eval_shape_test.py +7 -9
- brainstate/augment/_mapping.py +2 -3
- brainstate/augment/_mapping_test.py +75 -76
- brainstate/augment/_random.py +0 -2
- brainstate/compile/_ad_checkpoint.py +0 -2
- brainstate/compile/_ad_checkpoint_test.py +6 -8
- brainstate/compile/_conditions.py +0 -2
- brainstate/compile/_conditions_test.py +35 -36
- brainstate/compile/_error_if.py +0 -2
- brainstate/compile/_error_if_test.py +10 -13
- brainstate/compile/_jit.py +9 -8
- brainstate/compile/_loop_collect_return.py +0 -2
- brainstate/compile/_loop_collect_return_test.py +7 -9
- brainstate/compile/_loop_no_collection.py +0 -2
- brainstate/compile/_loop_no_collection_test.py +7 -8
- brainstate/compile/_make_jaxpr.py +30 -17
- brainstate/compile/_make_jaxpr_test.py +20 -20
- brainstate/compile/_progress_bar.py +0 -1
- brainstate/compile/_unvmap.py +0 -1
- brainstate/compile/_util.py +0 -2
- brainstate/environ.py +0 -2
- brainstate/functional/_activations.py +0 -2
- brainstate/functional/_activations_test.py +61 -61
- brainstate/functional/_normalization.py +0 -2
- brainstate/functional/_others.py +0 -2
- brainstate/functional/_spikes.py +0 -1
- brainstate/graph/_graph_node.py +1 -3
- brainstate/graph/_graph_node_test.py +16 -18
- brainstate/graph/_graph_operation.py +4 -2
- brainstate/graph/_graph_operation_test.py +154 -156
- brainstate/init/_base.py +0 -2
- brainstate/init/_generic.py +0 -1
- brainstate/init/_random_inits.py +0 -1
- brainstate/init/_random_inits_test.py +20 -21
- brainstate/init/_regular_inits.py +0 -2
- brainstate/init/_regular_inits_test.py +4 -5
- brainstate/mixin.py +0 -2
- brainstate/nn/_collective_ops.py +0 -3
- brainstate/nn/_collective_ops_test.py +8 -8
- brainstate/nn/_common.py +0 -2
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +0 -2
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +18 -19
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +0 -1
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +9 -10
- brainstate/nn/_dyn_impl/_inputs.py +0 -1
- brainstate/nn/_dyn_impl/_rate_rnns.py +0 -1
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +6 -7
- brainstate/nn/_dyn_impl/_readout.py +0 -1
- brainstate/nn/_dyn_impl/_readout_test.py +9 -10
- brainstate/nn/_dynamics/_dynamics_base.py +0 -1
- brainstate/nn/_dynamics/_dynamics_base_test.py +14 -15
- brainstate/nn/_dynamics/_projection_base.py +0 -1
- brainstate/nn/_dynamics/_state_delay.py +0 -2
- brainstate/nn/_dynamics/_synouts.py +0 -2
- brainstate/nn/_dynamics/_synouts_test.py +4 -5
- brainstate/nn/_elementwise/_dropout.py +0 -2
- brainstate/nn/_elementwise/_dropout_test.py +9 -9
- brainstate/nn/_elementwise/_elementwise.py +0 -2
- brainstate/nn/_elementwise/_elementwise_test.py +57 -59
- brainstate/nn/_event/_fixedprob_mv.py +0 -1
- brainstate/nn/_event/_fixedprob_mv_test.py +0 -1
- brainstate/nn/_event/_linear_mv.py +0 -2
- brainstate/nn/_event/_linear_mv_test.py +0 -1
- brainstate/nn/_exp_euler.py +0 -2
- brainstate/nn/_exp_euler_test.py +5 -6
- brainstate/nn/_interaction/_conv.py +0 -2
- brainstate/nn/_interaction/_conv_test.py +31 -33
- brainstate/nn/_interaction/_embedding.py +0 -1
- brainstate/nn/_interaction/_linear.py +0 -2
- brainstate/nn/_interaction/_linear_test.py +15 -17
- brainstate/nn/_interaction/_normalizations.py +0 -2
- brainstate/nn/_interaction/_normalizations_test.py +10 -12
- brainstate/nn/_interaction/_poolings.py +0 -2
- brainstate/nn/_interaction/_poolings_test.py +19 -21
- brainstate/nn/_module.py +0 -1
- brainstate/nn/_module_test.py +34 -37
- brainstate/nn/metrics.py +0 -2
- brainstate/optim/_base.py +0 -2
- brainstate/optim/_lr_scheduler.py +0 -1
- brainstate/optim/_lr_scheduler_test.py +3 -3
- brainstate/optim/_optax_optimizer.py +0 -2
- brainstate/optim/_optax_optimizer_test.py +8 -9
- brainstate/optim/_sgd_optimizer.py +0 -1
- brainstate/random/_rand_funs.py +0 -1
- brainstate/random/_rand_funs_test.py +183 -184
- brainstate/random/_rand_seed.py +0 -1
- brainstate/random/_rand_seed_test.py +10 -12
- brainstate/random/_rand_state.py +0 -1
- brainstate/surrogate.py +0 -1
- brainstate/typing.py +0 -2
- brainstate/util/_caller.py +4 -6
- brainstate/util/_others.py +0 -2
- brainstate/util/_pretty_pytree.py +201 -150
- brainstate/util/_pretty_repr.py +0 -2
- brainstate/util/_pretty_table.py +57 -3
- brainstate/util/_scaling.py +0 -2
- brainstate/util/_struct.py +0 -2
- brainstate/util/filter.py +0 -2
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/METADATA +11 -5
- brainstate-0.1.2.dist-info/RECORD +133 -0
- brainstate-0.1.0.post20250503.dist-info/RECORD +0 -133
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/top_level.txt +0 -0
@@ -1,13 +1,11 @@
|
|
1
1
|
# -*- coding: utf-8 -*-
|
2
2
|
|
3
|
-
from __future__ import annotations
|
4
|
-
|
5
3
|
import jax.numpy as jnp
|
6
4
|
import pytest
|
7
5
|
from absl.testing import absltest
|
8
6
|
from absl.testing import parameterized
|
9
7
|
|
10
|
-
import brainstate
|
8
|
+
import brainstate
|
11
9
|
|
12
10
|
|
13
11
|
class TestConv(parameterized.TestCase):
|
@@ -19,8 +17,8 @@ class TestConv(parameterized.TestCase):
|
|
19
17
|
img = img.at[0, x:x + 10, y:y + 10, k].set(1.0)
|
20
18
|
img = img.at[1, x:x + 20, y:y + 20, k].set(3.0)
|
21
19
|
|
22
|
-
net =
|
23
|
-
|
20
|
+
net = brainstate.nn.Conv2d((200, 198, 4), out_channels=32, kernel_size=(3, 3),
|
21
|
+
stride=(2, 1), padding='VALID', groups=4)
|
24
22
|
out = net(img)
|
25
23
|
print("out shape: ", out.shape)
|
26
24
|
self.assertEqual(out.shape, (2, 99, 196, 32))
|
@@ -30,7 +28,7 @@ class TestConv(parameterized.TestCase):
|
|
30
28
|
# plt.show()
|
31
29
|
|
32
30
|
def test_conv1D(self):
|
33
|
-
model =
|
31
|
+
model = brainstate.nn.Conv1d((5, 3), out_channels=32, kernel_size=(3,))
|
34
32
|
input = jnp.ones((2, 5, 3))
|
35
33
|
out = model(input)
|
36
34
|
print("out shape: ", out.shape)
|
@@ -41,7 +39,7 @@ class TestConv(parameterized.TestCase):
|
|
41
39
|
# plt.show()
|
42
40
|
|
43
41
|
def test_conv2D(self):
|
44
|
-
model =
|
42
|
+
model = brainstate.nn.Conv2d((5, 5, 3), out_channels=32, kernel_size=(3, 3))
|
45
43
|
input = jnp.ones((2, 5, 5, 3))
|
46
44
|
|
47
45
|
out = model(input)
|
@@ -49,7 +47,7 @@ class TestConv(parameterized.TestCase):
|
|
49
47
|
self.assertEqual(out.shape, (2, 5, 5, 32))
|
50
48
|
|
51
49
|
def test_conv3D(self):
|
52
|
-
model =
|
50
|
+
model = brainstate.nn.Conv3d((5, 5, 5, 3), out_channels=32, kernel_size=(3, 3, 3))
|
53
51
|
input = jnp.ones((2, 5, 5, 5, 3))
|
54
52
|
out = model(input)
|
55
53
|
print("out shape: ", out.shape)
|
@@ -62,13 +60,13 @@ class TestConvTranspose1d(parameterized.TestCase):
|
|
62
60
|
|
63
61
|
x = jnp.ones((1, 8, 3))
|
64
62
|
for use_bias in [True, False]:
|
65
|
-
conv_transpose_module =
|
63
|
+
conv_transpose_module = brainstate.nn.ConvTranspose1d(
|
66
64
|
in_channels=3,
|
67
65
|
out_channels=4,
|
68
66
|
kernel_size=(3,),
|
69
67
|
padding='VALID',
|
70
|
-
w_initializer=
|
71
|
-
b_initializer=
|
68
|
+
w_initializer=brainstate.init.Constant(1.),
|
69
|
+
b_initializer=brainstate.init.Constant(1.) if use_bias else None,
|
72
70
|
)
|
73
71
|
self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4))
|
74
72
|
y = conv_transpose_module(x)
|
@@ -91,14 +89,14 @@ class TestConvTranspose1d(parameterized.TestCase):
|
|
91
89
|
|
92
90
|
x = jnp.ones((1, 8, 3))
|
93
91
|
m = jnp.tril(jnp.ones((3, 3, 4)))
|
94
|
-
conv_transpose_module =
|
92
|
+
conv_transpose_module = brainstate.nn.ConvTranspose1d(
|
95
93
|
in_channels=3,
|
96
94
|
out_channels=4,
|
97
95
|
kernel_size=(3,),
|
98
96
|
padding='VALID',
|
99
97
|
mask=m,
|
100
|
-
w_initializer=
|
101
|
-
b_initializer=
|
98
|
+
w_initializer=brainstate.init.Constant(),
|
99
|
+
b_initializer=brainstate.init.Constant(),
|
102
100
|
)
|
103
101
|
self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4))
|
104
102
|
y = conv_transpose_module(x)
|
@@ -119,14 +117,14 @@ class TestConvTranspose1d(parameterized.TestCase):
|
|
119
117
|
|
120
118
|
data = jnp.ones([1, 3, 1])
|
121
119
|
for use_bias in [True, False]:
|
122
|
-
net =
|
120
|
+
net = brainstate.nn.ConvTranspose1d(
|
123
121
|
in_channels=1,
|
124
122
|
out_channels=1,
|
125
123
|
kernel_size=3,
|
126
124
|
stride=1,
|
127
125
|
padding="SAME",
|
128
|
-
w_initializer=
|
129
|
-
b_initializer=
|
126
|
+
w_initializer=brainstate.init.Constant(),
|
127
|
+
b_initializer=brainstate.init.Constant() if use_bias else None,
|
130
128
|
)
|
131
129
|
out = net(data)
|
132
130
|
self.assertEqual(out.shape, (1, 3, 1))
|
@@ -143,13 +141,13 @@ class TestConvTranspose2d(parameterized.TestCase):
|
|
143
141
|
|
144
142
|
x = jnp.ones((1, 8, 8, 3))
|
145
143
|
for use_bias in [True, False]:
|
146
|
-
conv_transpose_module =
|
144
|
+
conv_transpose_module = brainstate.nn.ConvTranspose2d(
|
147
145
|
in_channels=3,
|
148
146
|
out_channels=4,
|
149
147
|
kernel_size=(3, 3),
|
150
148
|
padding='VALID',
|
151
|
-
w_initializer=
|
152
|
-
b_initializer=
|
149
|
+
w_initializer=brainstate.init.Constant(),
|
150
|
+
b_initializer=brainstate.init.Constant() if use_bias else None,
|
153
151
|
)
|
154
152
|
self.assertEqual(conv_transpose_module.w.shape, (3, 3, 3, 4))
|
155
153
|
y = conv_transpose_module(x)
|
@@ -159,13 +157,13 @@ class TestConvTranspose2d(parameterized.TestCase):
|
|
159
157
|
|
160
158
|
x = jnp.ones((1, 8, 8, 3))
|
161
159
|
m = jnp.tril(jnp.ones((3, 3, 3, 4)))
|
162
|
-
conv_transpose_module =
|
160
|
+
conv_transpose_module = brainstate.nn.ConvTranspose2d(
|
163
161
|
in_channels=3,
|
164
162
|
out_channels=4,
|
165
163
|
kernel_size=(3, 3),
|
166
164
|
padding='VALID',
|
167
165
|
mask=m,
|
168
|
-
w_initializer=
|
166
|
+
w_initializer=brainstate.init.Constant(),
|
169
167
|
)
|
170
168
|
y = conv_transpose_module(x)
|
171
169
|
print(y.shape)
|
@@ -174,14 +172,14 @@ class TestConvTranspose2d(parameterized.TestCase):
|
|
174
172
|
|
175
173
|
x = jnp.ones((1, 8, 8, 3))
|
176
174
|
for use_bias in [True, False]:
|
177
|
-
conv_transpose_module =
|
175
|
+
conv_transpose_module = brainstate.nn.ConvTranspose2d(
|
178
176
|
in_channels=3,
|
179
177
|
out_channels=4,
|
180
178
|
kernel_size=(3, 3),
|
181
179
|
stride=1,
|
182
180
|
padding='SAME',
|
183
|
-
w_initializer=
|
184
|
-
b_initializer=
|
181
|
+
w_initializer=brainstate.init.Constant(),
|
182
|
+
b_initializer=brainstate.init.Constant() if use_bias else None,
|
185
183
|
)
|
186
184
|
y = conv_transpose_module(x)
|
187
185
|
print(y.shape)
|
@@ -193,13 +191,13 @@ class TestConvTranspose3d(parameterized.TestCase):
|
|
193
191
|
|
194
192
|
x = jnp.ones((1, 8, 8, 8, 3))
|
195
193
|
for use_bias in [True, False]:
|
196
|
-
conv_transpose_module =
|
194
|
+
conv_transpose_module = brainstate.nn.ConvTranspose3d(
|
197
195
|
in_channels=3,
|
198
196
|
out_channels=4,
|
199
197
|
kernel_size=(3, 3, 3),
|
200
198
|
padding='VALID',
|
201
|
-
w_initializer=
|
202
|
-
b_initializer=
|
199
|
+
w_initializer=brainstate.init.Constant(),
|
200
|
+
b_initializer=brainstate.init.Constant() if use_bias else None,
|
203
201
|
)
|
204
202
|
y = conv_transpose_module(x)
|
205
203
|
print(y.shape)
|
@@ -208,13 +206,13 @@ class TestConvTranspose3d(parameterized.TestCase):
|
|
208
206
|
|
209
207
|
x = jnp.ones((1, 8, 8, 8, 3))
|
210
208
|
m = jnp.tril(jnp.ones((3, 3, 3, 3, 4)))
|
211
|
-
conv_transpose_module =
|
209
|
+
conv_transpose_module = brainstate.nn.ConvTranspose3d(
|
212
210
|
in_channels=3,
|
213
211
|
out_channels=4,
|
214
212
|
kernel_size=(3, 3, 3),
|
215
213
|
padding='VALID',
|
216
214
|
mask=m,
|
217
|
-
w_initializer=
|
215
|
+
w_initializer=brainstate.init.Constant(),
|
218
216
|
)
|
219
217
|
y = conv_transpose_module(x)
|
220
218
|
print(y.shape)
|
@@ -223,14 +221,14 @@ class TestConvTranspose3d(parameterized.TestCase):
|
|
223
221
|
|
224
222
|
x = jnp.ones((1, 8, 8, 8, 3))
|
225
223
|
for use_bias in [True, False]:
|
226
|
-
conv_transpose_module =
|
224
|
+
conv_transpose_module = brainstate.nn.ConvTranspose3d(
|
227
225
|
in_channels=3,
|
228
226
|
out_channels=4,
|
229
227
|
kernel_size=(3, 3, 3),
|
230
228
|
stride=1,
|
231
229
|
padding='SAME',
|
232
|
-
w_initializer=
|
233
|
-
b_initializer=
|
230
|
+
w_initializer=brainstate.init.Constant(),
|
231
|
+
b_initializer=brainstate.init.Constant() if use_bias else None,
|
234
232
|
)
|
235
233
|
y = conv_transpose_module(x)
|
236
234
|
print(y.shape)
|
@@ -12,7 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
from __future__ import annotations
|
16
15
|
|
17
16
|
from typing import Optional, Callable, Union
|
18
17
|
|
@@ -14,14 +14,12 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
16
|
|
17
|
-
from __future__ import annotations
|
18
|
-
|
19
17
|
import unittest
|
20
18
|
|
21
19
|
import brainunit as u
|
22
20
|
from absl.testing import parameterized
|
23
21
|
|
24
|
-
import brainstate
|
22
|
+
import brainstate
|
25
23
|
|
26
24
|
|
27
25
|
class TestDense(parameterized.TestCase):
|
@@ -32,19 +30,19 @@ class TestDense(parameterized.TestCase):
|
|
32
30
|
num_out=[20, ]
|
33
31
|
)
|
34
32
|
def test_Dense1(self, size, num_out):
|
35
|
-
f =
|
36
|
-
x =
|
33
|
+
f = brainstate.nn.Linear(10, num_out)
|
34
|
+
x = brainstate.random.random(size)
|
37
35
|
y = f(x)
|
38
36
|
self.assertTrue(y.shape == size[:-1] + (num_out,))
|
39
37
|
|
40
38
|
|
41
39
|
class TestSparseMatrix(unittest.TestCase):
|
42
40
|
def test_csr(self):
|
43
|
-
data =
|
41
|
+
data = brainstate.random.rand(10, 20)
|
44
42
|
data = data * (data > 0.9)
|
45
|
-
f =
|
43
|
+
f = brainstate.nn.SparseLinear(u.sparse.CSR.fromdense(data))
|
46
44
|
|
47
|
-
x =
|
45
|
+
x = brainstate.random.rand(10)
|
48
46
|
y = f(x)
|
49
47
|
self.assertTrue(
|
50
48
|
u.math.allclose(
|
@@ -53,7 +51,7 @@ class TestSparseMatrix(unittest.TestCase):
|
|
53
51
|
)
|
54
52
|
)
|
55
53
|
|
56
|
-
x =
|
54
|
+
x = brainstate.random.rand(5, 10)
|
57
55
|
y = f(x)
|
58
56
|
self.assertTrue(
|
59
57
|
u.math.allclose(
|
@@ -63,11 +61,11 @@ class TestSparseMatrix(unittest.TestCase):
|
|
63
61
|
)
|
64
62
|
|
65
63
|
def test_csc(self):
|
66
|
-
data =
|
64
|
+
data = brainstate.random.rand(10, 20)
|
67
65
|
data = data * (data > 0.9)
|
68
|
-
f =
|
66
|
+
f = brainstate.nn.SparseLinear(u.sparse.CSC.fromdense(data))
|
69
67
|
|
70
|
-
x =
|
68
|
+
x = brainstate.random.rand(10)
|
71
69
|
y = f(x)
|
72
70
|
self.assertTrue(
|
73
71
|
u.math.allclose(
|
@@ -76,7 +74,7 @@ class TestSparseMatrix(unittest.TestCase):
|
|
76
74
|
)
|
77
75
|
)
|
78
76
|
|
79
|
-
x =
|
77
|
+
x = brainstate.random.rand(5, 10)
|
80
78
|
y = f(x)
|
81
79
|
self.assertTrue(
|
82
80
|
u.math.allclose(
|
@@ -86,11 +84,11 @@ class TestSparseMatrix(unittest.TestCase):
|
|
86
84
|
)
|
87
85
|
|
88
86
|
def test_coo(self):
|
89
|
-
data =
|
87
|
+
data = brainstate.random.rand(10, 20)
|
90
88
|
data = data * (data > 0.9)
|
91
|
-
f =
|
89
|
+
f = brainstate.nn.SparseLinear(u.sparse.COO.fromdense(data))
|
92
90
|
|
93
|
-
x =
|
91
|
+
x = brainstate.random.rand(10)
|
94
92
|
y = f(x)
|
95
93
|
self.assertTrue(
|
96
94
|
u.math.allclose(
|
@@ -99,7 +97,7 @@ class TestSparseMatrix(unittest.TestCase):
|
|
99
97
|
)
|
100
98
|
)
|
101
99
|
|
102
|
-
x =
|
100
|
+
x = brainstate.random.rand(5, 10)
|
103
101
|
y = f(x)
|
104
102
|
self.assertTrue(
|
105
103
|
u.math.allclose(
|
@@ -13,12 +13,10 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
|
-
|
18
16
|
from absl.testing import absltest
|
19
17
|
from absl.testing import parameterized
|
20
18
|
|
21
|
-
import brainstate
|
19
|
+
import brainstate
|
22
20
|
|
23
21
|
|
24
22
|
class Test_Normalization(parameterized.TestCase):
|
@@ -26,27 +24,27 @@ class Test_Normalization(parameterized.TestCase):
|
|
26
24
|
fit=[True, False],
|
27
25
|
)
|
28
26
|
def test_BatchNorm1d(self, fit):
|
29
|
-
net =
|
30
|
-
|
31
|
-
input =
|
27
|
+
net = brainstate.nn.BatchNorm1d((3, 10))
|
28
|
+
brainstate.environ.set(fit=fit)
|
29
|
+
input = brainstate.random.randn(1, 3, 10)
|
32
30
|
output = net(input)
|
33
31
|
|
34
32
|
@parameterized.product(
|
35
33
|
fit=[True, False]
|
36
34
|
)
|
37
35
|
def test_BatchNorm2d(self, fit):
|
38
|
-
net =
|
39
|
-
|
40
|
-
input =
|
36
|
+
net = brainstate.nn.BatchNorm2d([3, 4, 10])
|
37
|
+
brainstate.environ.set(fit=fit)
|
38
|
+
input = brainstate.random.randn(1, 3, 4, 10)
|
41
39
|
output = net(input)
|
42
40
|
|
43
41
|
@parameterized.product(
|
44
42
|
fit=[True, False]
|
45
43
|
)
|
46
44
|
def test_BatchNorm3d(self, fit):
|
47
|
-
net =
|
48
|
-
|
49
|
-
input =
|
45
|
+
net = brainstate.nn.BatchNorm3d([3, 4, 5, 10])
|
46
|
+
brainstate.environ.set(fit=fit)
|
47
|
+
input = brainstate.random.randn(1, 3, 4, 5, 10)
|
50
48
|
output = net(input)
|
51
49
|
|
52
50
|
# @parameterized.product(
|
@@ -1,13 +1,11 @@
|
|
1
1
|
# -*- coding: utf-8 -*-
|
2
2
|
|
3
|
-
from __future__ import annotations
|
4
|
-
|
5
3
|
import jax
|
6
4
|
import numpy as np
|
7
5
|
from absl.testing import absltest
|
8
6
|
from absl.testing import parameterized
|
9
7
|
|
10
|
-
import brainstate
|
8
|
+
import brainstate
|
11
9
|
import brainstate.nn as nn
|
12
10
|
|
13
11
|
|
@@ -18,7 +16,7 @@ class TestFlatten(parameterized.TestCase):
|
|
18
16
|
(32, 8),
|
19
17
|
(10, 20, 30),
|
20
18
|
]:
|
21
|
-
arr =
|
19
|
+
arr = brainstate.random.rand(*size)
|
22
20
|
f = nn.Flatten(start_axis=0)
|
23
21
|
out = f(arr)
|
24
22
|
self.assertTrue(out.shape == (np.prod(size),))
|
@@ -29,21 +27,21 @@ class TestFlatten(parameterized.TestCase):
|
|
29
27
|
(32, 8),
|
30
28
|
(10, 20, 30),
|
31
29
|
]:
|
32
|
-
arr =
|
30
|
+
arr = brainstate.random.rand(*size)
|
33
31
|
f = nn.Flatten(start_axis=1)
|
34
32
|
out = f(arr)
|
35
33
|
self.assertTrue(out.shape == (size[0], np.prod(size[1:])))
|
36
34
|
|
37
35
|
def test_flatten3(self):
|
38
36
|
size = (16, 32, 32, 8)
|
39
|
-
arr =
|
37
|
+
arr = brainstate.random.rand(*size)
|
40
38
|
f = nn.Flatten(start_axis=0, in_size=(32, 8))
|
41
39
|
out = f(arr)
|
42
40
|
self.assertTrue(out.shape == (16, 32, 32 * 8))
|
43
41
|
|
44
42
|
def test_flatten4(self):
|
45
43
|
size = (16, 32, 32, 8)
|
46
|
-
arr =
|
44
|
+
arr = brainstate.random.rand(*size)
|
47
45
|
f = nn.Flatten(start_axis=1, in_size=(32, 32, 8))
|
48
46
|
out = f(arr)
|
49
47
|
self.assertTrue(out.shape == (16, 32, 32 * 8))
|
@@ -58,7 +56,7 @@ class TestPool(parameterized.TestCase):
|
|
58
56
|
super().__init__(*args, **kwargs)
|
59
57
|
|
60
58
|
def test_MaxPool2d_v1(self):
|
61
|
-
arr =
|
59
|
+
arr = brainstate.random.rand(16, 32, 32, 8)
|
62
60
|
|
63
61
|
out = nn.MaxPool2d(2, 2, channel_axis=-1)(arr)
|
64
62
|
self.assertTrue(out.shape == (16, 16, 16, 8))
|
@@ -79,7 +77,7 @@ class TestPool(parameterized.TestCase):
|
|
79
77
|
self.assertTrue(out.shape == (16, 17, 32, 5))
|
80
78
|
|
81
79
|
def test_AvgPool2d_v1(self):
|
82
|
-
arr =
|
80
|
+
arr = brainstate.random.rand(16, 32, 32, 8)
|
83
81
|
|
84
82
|
out = nn.AvgPool2d(2, 2, channel_axis=-1)(arr)
|
85
83
|
self.assertTrue(out.shape == (16, 16, 16, 8))
|
@@ -107,7 +105,7 @@ class TestPool(parameterized.TestCase):
|
|
107
105
|
def test_adaptive_pool1d(self, target_size):
|
108
106
|
from brainstate.nn._interaction._poolings import _adaptive_pool1d
|
109
107
|
|
110
|
-
arr =
|
108
|
+
arr = brainstate.random.rand(100)
|
111
109
|
op = jax.numpy.mean
|
112
110
|
|
113
111
|
out = _adaptive_pool1d(arr, target_size, op)
|
@@ -119,7 +117,7 @@ class TestPool(parameterized.TestCase):
|
|
119
117
|
self.assertTrue(out.shape == (target_size,))
|
120
118
|
|
121
119
|
def test_AdaptiveAvgPool2d_v1(self):
|
122
|
-
input =
|
120
|
+
input = brainstate.random.randn(64, 8, 9)
|
123
121
|
|
124
122
|
output = nn.AdaptiveAvgPool2d((5, 7), channel_axis=0)(input)
|
125
123
|
self.assertTrue(output.shape == (64, 5, 7))
|
@@ -137,8 +135,8 @@ class TestPool(parameterized.TestCase):
|
|
137
135
|
self.assertTrue(output.shape == (64, 2, 3))
|
138
136
|
|
139
137
|
def test_AdaptiveAvgPool2d_v2(self):
|
140
|
-
|
141
|
-
input =
|
138
|
+
brainstate.random.seed()
|
139
|
+
input = brainstate.random.randn(128, 64, 32, 16)
|
142
140
|
|
143
141
|
output = nn.AdaptiveAvgPool2d((5, 7), channel_axis=0)(input)
|
144
142
|
self.assertTrue(output.shape == (128, 64, 5, 7))
|
@@ -154,13 +152,13 @@ class TestPool(parameterized.TestCase):
|
|
154
152
|
print()
|
155
153
|
|
156
154
|
def test_AdaptiveAvgPool3d_v1(self):
|
157
|
-
input =
|
155
|
+
input = brainstate.random.randn(10, 128, 64, 32)
|
158
156
|
net = nn.AdaptiveAvgPool3d(target_size=[6, 5, 3], channel_axis=0)
|
159
157
|
output = net(input)
|
160
158
|
self.assertTrue(output.shape == (10, 6, 5, 3))
|
161
159
|
|
162
160
|
def test_AdaptiveAvgPool3d_v2(self):
|
163
|
-
input =
|
161
|
+
input = brainstate.random.randn(10, 20, 128, 64, 32)
|
164
162
|
net = nn.AdaptiveAvgPool3d(target_size=[6, 5, 3])
|
165
163
|
output = net(input)
|
166
164
|
self.assertTrue(output.shape == (10, 6, 5, 3, 32))
|
@@ -169,7 +167,7 @@ class TestPool(parameterized.TestCase):
|
|
169
167
|
axis=(-1, 0, 1)
|
170
168
|
)
|
171
169
|
def test_AdaptiveMaxPool1d_v1(self, axis):
|
172
|
-
input =
|
170
|
+
input = brainstate.random.randn(32, 16)
|
173
171
|
net = nn.AdaptiveMaxPool1d(target_size=4, channel_axis=axis)
|
174
172
|
output = net(input)
|
175
173
|
|
@@ -177,7 +175,7 @@ class TestPool(parameterized.TestCase):
|
|
177
175
|
axis=(-1, 0, 1, 2)
|
178
176
|
)
|
179
177
|
def test_AdaptiveMaxPool1d_v2(self, axis):
|
180
|
-
input =
|
178
|
+
input = brainstate.random.randn(2, 32, 16)
|
181
179
|
net = nn.AdaptiveMaxPool1d(target_size=4, channel_axis=axis)
|
182
180
|
output = net(input)
|
183
181
|
|
@@ -185,7 +183,7 @@ class TestPool(parameterized.TestCase):
|
|
185
183
|
axis=(-1, 0, 1, 2)
|
186
184
|
)
|
187
185
|
def test_AdaptiveMaxPool2d_v1(self, axis):
|
188
|
-
input =
|
186
|
+
input = brainstate.random.randn(32, 16, 12)
|
189
187
|
net = nn.AdaptiveAvgPool2d(target_size=[5, 4], channel_axis=axis)
|
190
188
|
output = net(input)
|
191
189
|
|
@@ -193,7 +191,7 @@ class TestPool(parameterized.TestCase):
|
|
193
191
|
axis=(-1, 0, 1, 2, 3)
|
194
192
|
)
|
195
193
|
def test_AdaptiveMaxPool2d_v2(self, axis):
|
196
|
-
input =
|
194
|
+
input = brainstate.random.randn(2, 32, 16, 12)
|
197
195
|
net = nn.AdaptiveAvgPool2d(target_size=[5, 4], channel_axis=axis)
|
198
196
|
output = net(input)
|
199
197
|
|
@@ -201,7 +199,7 @@ class TestPool(parameterized.TestCase):
|
|
201
199
|
axis=(-1, 0, 1, 2, 3)
|
202
200
|
)
|
203
201
|
def test_AdaptiveMaxPool3d_v1(self, axis):
|
204
|
-
input =
|
202
|
+
input = brainstate.random.randn(2, 128, 64, 32)
|
205
203
|
net = nn.AdaptiveMaxPool3d(target_size=[6, 5, 4], channel_axis=axis)
|
206
204
|
output = net(input)
|
207
205
|
print()
|
@@ -210,7 +208,7 @@ class TestPool(parameterized.TestCase):
|
|
210
208
|
axis=(-1, 0, 1, 2, 3, 4)
|
211
209
|
)
|
212
210
|
def test_AdaptiveMaxPool3d_v1(self, axis):
|
213
|
-
input =
|
211
|
+
input = brainstate.random.randn(2, 128, 64, 32, 16)
|
214
212
|
net = nn.AdaptiveMaxPool3d(target_size=[6, 5, 4], channel_axis=axis)
|
215
213
|
output = net(input)
|
216
214
|
|
brainstate/nn/_module.py
CHANGED
@@ -25,7 +25,6 @@ The basic classes include:
|
|
25
25
|
- ``Sequential``: The class for a sequential of modules, which update the modules sequentially.
|
26
26
|
|
27
27
|
"""
|
28
|
-
from __future__ import annotations
|
29
28
|
|
30
29
|
import warnings
|
31
30
|
from typing import Sequence, Optional, Tuple, Union, TYPE_CHECKING, Callable
|