brainstate 0.1.1__py2.py3-none-any.whl → 0.1.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_compatible_import.py +3 -0
- brainstate/_state.py +1 -1
- brainstate/augment/_autograd_test.py +132 -133
- brainstate/augment/_eval_shape_test.py +7 -9
- brainstate/augment/_mapping_test.py +75 -76
- brainstate/compile/_ad_checkpoint_test.py +6 -8
- brainstate/compile/_conditions_test.py +35 -36
- brainstate/compile/_error_if_test.py +10 -13
- brainstate/compile/_loop_collect_return_test.py +7 -9
- brainstate/compile/_loop_no_collection_test.py +7 -8
- brainstate/compile/_make_jaxpr.py +29 -14
- brainstate/compile/_make_jaxpr_test.py +20 -20
- brainstate/functional/_activations_test.py +61 -61
- brainstate/graph/_graph_node_test.py +16 -18
- brainstate/graph/_graph_operation_test.py +154 -156
- brainstate/init/_random_inits_test.py +20 -21
- brainstate/init/_regular_inits_test.py +4 -5
- brainstate/nn/_collective_ops_test.py +8 -8
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +18 -19
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +9 -10
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +6 -7
- brainstate/nn/_dyn_impl/_readout_test.py +9 -10
- brainstate/nn/_dynamics/_dynamics_base_test.py +14 -15
- brainstate/nn/_dynamics/_synouts_test.py +4 -5
- brainstate/nn/_elementwise/_dropout_test.py +9 -9
- brainstate/nn/_elementwise/_elementwise_test.py +57 -59
- brainstate/nn/_event/_fixedprob_mv_test.py +0 -1
- brainstate/nn/_event/_linear_mv_test.py +0 -1
- brainstate/nn/_exp_euler_test.py +5 -6
- brainstate/nn/_interaction/_conv_test.py +31 -33
- brainstate/nn/_interaction/_linear_test.py +15 -17
- brainstate/nn/_interaction/_normalizations_test.py +10 -12
- brainstate/nn/_interaction/_poolings_test.py +19 -21
- brainstate/nn/_module_test.py +34 -37
- brainstate/optim/_lr_scheduler_test.py +3 -3
- brainstate/optim/_optax_optimizer_test.py +8 -9
- brainstate/random/_rand_funs_test.py +183 -184
- brainstate/random/_rand_seed_test.py +10 -12
- {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/METADATA +1 -1
- {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/RECORD +44 -44
- {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/LICENSE +0 -0
- {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/WHEEL +0 -0
- {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/top_level.txt +0 -0
@@ -18,19 +18,19 @@ import unittest
|
|
18
18
|
|
19
19
|
import numpy as np
|
20
20
|
|
21
|
-
import brainstate
|
21
|
+
import brainstate
|
22
22
|
|
23
23
|
|
24
24
|
class TestDropout(unittest.TestCase):
|
25
25
|
|
26
26
|
def test_dropout(self):
|
27
27
|
# Create a Dropout layer with a dropout rate of 0.5
|
28
|
-
dropout_layer =
|
28
|
+
dropout_layer = brainstate.nn.Dropout(0.5)
|
29
29
|
|
30
30
|
# Input data
|
31
31
|
input_data = np.arange(20)
|
32
32
|
|
33
|
-
with
|
33
|
+
with brainstate.environ.context(fit=True):
|
34
34
|
# Apply dropout
|
35
35
|
output_data = dropout_layer(input_data)
|
36
36
|
|
@@ -47,10 +47,10 @@ class TestDropout(unittest.TestCase):
|
|
47
47
|
np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements)
|
48
48
|
|
49
49
|
def test_DropoutFixed(self):
|
50
|
-
dropout_layer =
|
50
|
+
dropout_layer = brainstate.nn.DropoutFixed(in_size=(2, 3), prob=0.5)
|
51
51
|
dropout_layer.init_state(batch_size=2)
|
52
52
|
input_data = np.random.randn(2, 2, 3)
|
53
|
-
with
|
53
|
+
with brainstate.environ.context(fit=True):
|
54
54
|
output_data = dropout_layer.update(input_data)
|
55
55
|
self.assertEqual(input_data.shape, output_data.shape)
|
56
56
|
self.assertTrue(np.any(output_data == 0))
|
@@ -72,9 +72,9 @@ class TestDropout(unittest.TestCase):
|
|
72
72
|
# np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements, decimal=4)
|
73
73
|
|
74
74
|
def test_Dropout2d(self):
|
75
|
-
dropout_layer =
|
75
|
+
dropout_layer = brainstate.nn.Dropout2d(prob=0.5)
|
76
76
|
input_data = np.random.randn(2, 3, 4, 5)
|
77
|
-
with
|
77
|
+
with brainstate.environ.context(fit=True):
|
78
78
|
output_data = dropout_layer(input_data)
|
79
79
|
self.assertEqual(input_data.shape, output_data.shape)
|
80
80
|
self.assertTrue(np.any(output_data == 0))
|
@@ -84,9 +84,9 @@ class TestDropout(unittest.TestCase):
|
|
84
84
|
np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements, decimal=4)
|
85
85
|
|
86
86
|
def test_Dropout3d(self):
|
87
|
-
dropout_layer =
|
87
|
+
dropout_layer = brainstate.nn.Dropout3d(prob=0.5)
|
88
88
|
input_data = np.random.randn(2, 3, 4, 5, 6)
|
89
|
-
with
|
89
|
+
with brainstate.environ.context(fit=True):
|
90
90
|
output_data = dropout_layer(input_data)
|
91
91
|
self.assertEqual(input_data.shape, output_data.shape)
|
92
92
|
self.assertTrue(np.any(output_data == 0))
|
@@ -13,157 +13,155 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
|
-
|
18
16
|
from absl.testing import absltest
|
19
17
|
from absl.testing import parameterized
|
20
18
|
|
21
|
-
import brainstate
|
19
|
+
import brainstate
|
22
20
|
|
23
21
|
|
24
22
|
class Test_Activation(parameterized.TestCase):
|
25
23
|
|
26
24
|
def test_Threshold(self):
|
27
|
-
threshold_layer =
|
28
|
-
input =
|
25
|
+
threshold_layer = brainstate.nn.Threshold(5, 20)
|
26
|
+
input = brainstate.random.randn(2)
|
29
27
|
output = threshold_layer(input)
|
30
28
|
|
31
29
|
def test_ReLU(self):
|
32
|
-
ReLU_layer =
|
33
|
-
input =
|
30
|
+
ReLU_layer = brainstate.nn.ReLU()
|
31
|
+
input = brainstate.random.randn(2)
|
34
32
|
output = ReLU_layer(input)
|
35
33
|
|
36
34
|
def test_RReLU(self):
|
37
|
-
RReLU_layer =
|
38
|
-
input =
|
35
|
+
RReLU_layer = brainstate.nn.RReLU(lower=0, upper=1)
|
36
|
+
input = brainstate.random.randn(2)
|
39
37
|
output = RReLU_layer(input)
|
40
38
|
|
41
39
|
def test_Hardtanh(self):
|
42
|
-
Hardtanh_layer =
|
43
|
-
input =
|
40
|
+
Hardtanh_layer = brainstate.nn.Hardtanh(min_val=0, max_val=1, )
|
41
|
+
input = brainstate.random.randn(2)
|
44
42
|
output = Hardtanh_layer(input)
|
45
43
|
|
46
44
|
def test_ReLU6(self):
|
47
|
-
ReLU6_layer =
|
48
|
-
input =
|
45
|
+
ReLU6_layer = brainstate.nn.ReLU6()
|
46
|
+
input = brainstate.random.randn(2)
|
49
47
|
output = ReLU6_layer(input)
|
50
48
|
|
51
49
|
def test_Sigmoid(self):
|
52
|
-
Sigmoid_layer =
|
53
|
-
input =
|
50
|
+
Sigmoid_layer = brainstate.nn.Sigmoid()
|
51
|
+
input = brainstate.random.randn(2)
|
54
52
|
output = Sigmoid_layer(input)
|
55
53
|
|
56
54
|
def test_Hardsigmoid(self):
|
57
|
-
Hardsigmoid_layer =
|
58
|
-
input =
|
55
|
+
Hardsigmoid_layer = brainstate.nn.Hardsigmoid()
|
56
|
+
input = brainstate.random.randn(2)
|
59
57
|
output = Hardsigmoid_layer(input)
|
60
58
|
|
61
59
|
def test_Tanh(self):
|
62
|
-
Tanh_layer =
|
63
|
-
input =
|
60
|
+
Tanh_layer = brainstate.nn.Tanh()
|
61
|
+
input = brainstate.random.randn(2)
|
64
62
|
output = Tanh_layer(input)
|
65
63
|
|
66
64
|
def test_SiLU(self):
|
67
|
-
SiLU_layer =
|
68
|
-
input =
|
65
|
+
SiLU_layer = brainstate.nn.SiLU()
|
66
|
+
input = brainstate.random.randn(2)
|
69
67
|
output = SiLU_layer(input)
|
70
68
|
|
71
69
|
def test_Mish(self):
|
72
|
-
Mish_layer =
|
73
|
-
input =
|
70
|
+
Mish_layer = brainstate.nn.Mish()
|
71
|
+
input = brainstate.random.randn(2)
|
74
72
|
output = Mish_layer(input)
|
75
73
|
|
76
74
|
def test_Hardswish(self):
|
77
|
-
Hardswish_layer =
|
78
|
-
input =
|
75
|
+
Hardswish_layer = brainstate.nn.Hardswish()
|
76
|
+
input = brainstate.random.randn(2)
|
79
77
|
output = Hardswish_layer(input)
|
80
78
|
|
81
79
|
def test_ELU(self):
|
82
|
-
ELU_layer =
|
83
|
-
input =
|
80
|
+
ELU_layer = brainstate.nn.ELU(alpha=0.5, )
|
81
|
+
input = brainstate.random.randn(2)
|
84
82
|
output = ELU_layer(input)
|
85
83
|
|
86
84
|
def test_CELU(self):
|
87
|
-
CELU_layer =
|
88
|
-
input =
|
85
|
+
CELU_layer = brainstate.nn.CELU(alpha=0.5, )
|
86
|
+
input = brainstate.random.randn(2)
|
89
87
|
output = CELU_layer(input)
|
90
88
|
|
91
89
|
def test_SELU(self):
|
92
|
-
SELU_layer =
|
93
|
-
input =
|
90
|
+
SELU_layer = brainstate.nn.SELU()
|
91
|
+
input = brainstate.random.randn(2)
|
94
92
|
output = SELU_layer(input)
|
95
93
|
|
96
94
|
def test_GLU(self):
|
97
|
-
GLU_layer =
|
98
|
-
input =
|
95
|
+
GLU_layer = brainstate.nn.GLU()
|
96
|
+
input = brainstate.random.randn(4, 2)
|
99
97
|
output = GLU_layer(input)
|
100
98
|
|
101
99
|
@parameterized.product(
|
102
100
|
approximate=['tanh', 'none']
|
103
101
|
)
|
104
102
|
def test_GELU(self, approximate):
|
105
|
-
GELU_layer =
|
106
|
-
input =
|
103
|
+
GELU_layer = brainstate.nn.GELU()
|
104
|
+
input = brainstate.random.randn(2)
|
107
105
|
output = GELU_layer(input)
|
108
106
|
|
109
107
|
def test_Hardshrink(self):
|
110
|
-
Hardshrink_layer =
|
111
|
-
input =
|
108
|
+
Hardshrink_layer = brainstate.nn.Hardshrink(lambd=1)
|
109
|
+
input = brainstate.random.randn(2)
|
112
110
|
output = Hardshrink_layer(input)
|
113
111
|
|
114
112
|
def test_LeakyReLU(self):
|
115
|
-
LeakyReLU_layer =
|
116
|
-
input =
|
113
|
+
LeakyReLU_layer = brainstate.nn.LeakyReLU()
|
114
|
+
input = brainstate.random.randn(2)
|
117
115
|
output = LeakyReLU_layer(input)
|
118
116
|
|
119
117
|
def test_LogSigmoid(self):
|
120
|
-
LogSigmoid_layer =
|
121
|
-
input =
|
118
|
+
LogSigmoid_layer = brainstate.nn.LogSigmoid()
|
119
|
+
input = brainstate.random.randn(2)
|
122
120
|
output = LogSigmoid_layer(input)
|
123
121
|
|
124
122
|
def test_Softplus(self):
|
125
|
-
Softplus_layer =
|
126
|
-
input =
|
123
|
+
Softplus_layer = brainstate.nn.Softplus()
|
124
|
+
input = brainstate.random.randn(2)
|
127
125
|
output = Softplus_layer(input)
|
128
126
|
|
129
127
|
def test_Softshrink(self):
|
130
|
-
Softshrink_layer =
|
131
|
-
input =
|
128
|
+
Softshrink_layer = brainstate.nn.Softshrink(lambd=1)
|
129
|
+
input = brainstate.random.randn(2)
|
132
130
|
output = Softshrink_layer(input)
|
133
131
|
|
134
132
|
def test_PReLU(self):
|
135
|
-
PReLU_layer =
|
136
|
-
input =
|
133
|
+
PReLU_layer = brainstate.nn.PReLU(num_parameters=2, init=0.5)
|
134
|
+
input = brainstate.random.randn(2)
|
137
135
|
output = PReLU_layer(input)
|
138
136
|
|
139
137
|
def test_Softsign(self):
|
140
|
-
Softsign_layer =
|
141
|
-
input =
|
138
|
+
Softsign_layer = brainstate.nn.Softsign()
|
139
|
+
input = brainstate.random.randn(2)
|
142
140
|
output = Softsign_layer(input)
|
143
141
|
|
144
142
|
def test_Tanhshrink(self):
|
145
|
-
Tanhshrink_layer =
|
146
|
-
input =
|
143
|
+
Tanhshrink_layer = brainstate.nn.Tanhshrink()
|
144
|
+
input = brainstate.random.randn(2)
|
147
145
|
output = Tanhshrink_layer(input)
|
148
146
|
|
149
147
|
def test_Softmin(self):
|
150
|
-
Softmin_layer =
|
151
|
-
input =
|
148
|
+
Softmin_layer = brainstate.nn.Softmin(dim=2)
|
149
|
+
input = brainstate.random.randn(2, 3, 4)
|
152
150
|
output = Softmin_layer(input)
|
153
151
|
|
154
152
|
def test_Softmax(self):
|
155
|
-
Softmax_layer =
|
156
|
-
input =
|
153
|
+
Softmax_layer = brainstate.nn.Softmax(dim=2)
|
154
|
+
input = brainstate.random.randn(2, 3, 4)
|
157
155
|
output = Softmax_layer(input)
|
158
156
|
|
159
157
|
def test_Softmax2d(self):
|
160
|
-
Softmax2d_layer =
|
161
|
-
input =
|
158
|
+
Softmax2d_layer = brainstate.nn.Softmax2d()
|
159
|
+
input = brainstate.random.randn(2, 3, 12, 13)
|
162
160
|
output = Softmax2d_layer(input)
|
163
161
|
|
164
162
|
def test_LogSoftmax(self):
|
165
|
-
LogSoftmax_layer =
|
166
|
-
input =
|
163
|
+
LogSoftmax_layer = brainstate.nn.LogSoftmax(dim=2)
|
164
|
+
input = brainstate.random.randn(2, 3, 4)
|
167
165
|
output = LogSoftmax_layer(input)
|
168
166
|
|
169
167
|
|
brainstate/nn/_exp_euler_test.py
CHANGED
@@ -13,13 +13,12 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
16
|
|
18
17
|
import unittest
|
19
18
|
|
20
19
|
import brainunit as u
|
21
20
|
|
22
|
-
import brainstate
|
21
|
+
import brainstate
|
23
22
|
|
24
23
|
|
25
24
|
class TestExpEuler(unittest.TestCase):
|
@@ -27,10 +26,10 @@ class TestExpEuler(unittest.TestCase):
|
|
27
26
|
def fun(x, tau):
|
28
27
|
return -x / tau
|
29
28
|
|
30
|
-
with
|
29
|
+
with brainstate.environ.context(dt=0.1):
|
31
30
|
with self.assertRaises(AssertionError):
|
32
|
-
r =
|
31
|
+
r = brainstate.nn.exp_euler_step(fun, 1.0 * u.mV, 1. * u.ms)
|
33
32
|
|
34
|
-
with
|
35
|
-
r =
|
33
|
+
with brainstate.environ.context(dt=1. * u.ms):
|
34
|
+
r = brainstate.nn.exp_euler_step(fun, 1.0 * u.mV, 1. * u.ms)
|
36
35
|
print(r)
|
@@ -1,13 +1,11 @@
|
|
1
1
|
# -*- coding: utf-8 -*-
|
2
2
|
|
3
|
-
from __future__ import annotations
|
4
|
-
|
5
3
|
import jax.numpy as jnp
|
6
4
|
import pytest
|
7
5
|
from absl.testing import absltest
|
8
6
|
from absl.testing import parameterized
|
9
7
|
|
10
|
-
import brainstate
|
8
|
+
import brainstate
|
11
9
|
|
12
10
|
|
13
11
|
class TestConv(parameterized.TestCase):
|
@@ -19,8 +17,8 @@ class TestConv(parameterized.TestCase):
|
|
19
17
|
img = img.at[0, x:x + 10, y:y + 10, k].set(1.0)
|
20
18
|
img = img.at[1, x:x + 20, y:y + 20, k].set(3.0)
|
21
19
|
|
22
|
-
net =
|
23
|
-
|
20
|
+
net = brainstate.nn.Conv2d((200, 198, 4), out_channels=32, kernel_size=(3, 3),
|
21
|
+
stride=(2, 1), padding='VALID', groups=4)
|
24
22
|
out = net(img)
|
25
23
|
print("out shape: ", out.shape)
|
26
24
|
self.assertEqual(out.shape, (2, 99, 196, 32))
|
@@ -30,7 +28,7 @@ class TestConv(parameterized.TestCase):
|
|
30
28
|
# plt.show()
|
31
29
|
|
32
30
|
def test_conv1D(self):
|
33
|
-
model =
|
31
|
+
model = brainstate.nn.Conv1d((5, 3), out_channels=32, kernel_size=(3,))
|
34
32
|
input = jnp.ones((2, 5, 3))
|
35
33
|
out = model(input)
|
36
34
|
print("out shape: ", out.shape)
|
@@ -41,7 +39,7 @@ class TestConv(parameterized.TestCase):
|
|
41
39
|
# plt.show()
|
42
40
|
|
43
41
|
def test_conv2D(self):
|
44
|
-
model =
|
42
|
+
model = brainstate.nn.Conv2d((5, 5, 3), out_channels=32, kernel_size=(3, 3))
|
45
43
|
input = jnp.ones((2, 5, 5, 3))
|
46
44
|
|
47
45
|
out = model(input)
|
@@ -49,7 +47,7 @@ class TestConv(parameterized.TestCase):
|
|
49
47
|
self.assertEqual(out.shape, (2, 5, 5, 32))
|
50
48
|
|
51
49
|
def test_conv3D(self):
|
52
|
-
model =
|
50
|
+
model = brainstate.nn.Conv3d((5, 5, 5, 3), out_channels=32, kernel_size=(3, 3, 3))
|
53
51
|
input = jnp.ones((2, 5, 5, 5, 3))
|
54
52
|
out = model(input)
|
55
53
|
print("out shape: ", out.shape)
|
@@ -62,13 +60,13 @@ class TestConvTranspose1d(parameterized.TestCase):
|
|
62
60
|
|
63
61
|
x = jnp.ones((1, 8, 3))
|
64
62
|
for use_bias in [True, False]:
|
65
|
-
conv_transpose_module =
|
63
|
+
conv_transpose_module = brainstate.nn.ConvTranspose1d(
|
66
64
|
in_channels=3,
|
67
65
|
out_channels=4,
|
68
66
|
kernel_size=(3,),
|
69
67
|
padding='VALID',
|
70
|
-
w_initializer=
|
71
|
-
b_initializer=
|
68
|
+
w_initializer=brainstate.init.Constant(1.),
|
69
|
+
b_initializer=brainstate.init.Constant(1.) if use_bias else None,
|
72
70
|
)
|
73
71
|
self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4))
|
74
72
|
y = conv_transpose_module(x)
|
@@ -91,14 +89,14 @@ class TestConvTranspose1d(parameterized.TestCase):
|
|
91
89
|
|
92
90
|
x = jnp.ones((1, 8, 3))
|
93
91
|
m = jnp.tril(jnp.ones((3, 3, 4)))
|
94
|
-
conv_transpose_module =
|
92
|
+
conv_transpose_module = brainstate.nn.ConvTranspose1d(
|
95
93
|
in_channels=3,
|
96
94
|
out_channels=4,
|
97
95
|
kernel_size=(3,),
|
98
96
|
padding='VALID',
|
99
97
|
mask=m,
|
100
|
-
w_initializer=
|
101
|
-
b_initializer=
|
98
|
+
w_initializer=brainstate.init.Constant(),
|
99
|
+
b_initializer=brainstate.init.Constant(),
|
102
100
|
)
|
103
101
|
self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4))
|
104
102
|
y = conv_transpose_module(x)
|
@@ -119,14 +117,14 @@ class TestConvTranspose1d(parameterized.TestCase):
|
|
119
117
|
|
120
118
|
data = jnp.ones([1, 3, 1])
|
121
119
|
for use_bias in [True, False]:
|
122
|
-
net =
|
120
|
+
net = brainstate.nn.ConvTranspose1d(
|
123
121
|
in_channels=1,
|
124
122
|
out_channels=1,
|
125
123
|
kernel_size=3,
|
126
124
|
stride=1,
|
127
125
|
padding="SAME",
|
128
|
-
w_initializer=
|
129
|
-
b_initializer=
|
126
|
+
w_initializer=brainstate.init.Constant(),
|
127
|
+
b_initializer=brainstate.init.Constant() if use_bias else None,
|
130
128
|
)
|
131
129
|
out = net(data)
|
132
130
|
self.assertEqual(out.shape, (1, 3, 1))
|
@@ -143,13 +141,13 @@ class TestConvTranspose2d(parameterized.TestCase):
|
|
143
141
|
|
144
142
|
x = jnp.ones((1, 8, 8, 3))
|
145
143
|
for use_bias in [True, False]:
|
146
|
-
conv_transpose_module =
|
144
|
+
conv_transpose_module = brainstate.nn.ConvTranspose2d(
|
147
145
|
in_channels=3,
|
148
146
|
out_channels=4,
|
149
147
|
kernel_size=(3, 3),
|
150
148
|
padding='VALID',
|
151
|
-
w_initializer=
|
152
|
-
b_initializer=
|
149
|
+
w_initializer=brainstate.init.Constant(),
|
150
|
+
b_initializer=brainstate.init.Constant() if use_bias else None,
|
153
151
|
)
|
154
152
|
self.assertEqual(conv_transpose_module.w.shape, (3, 3, 3, 4))
|
155
153
|
y = conv_transpose_module(x)
|
@@ -159,13 +157,13 @@ class TestConvTranspose2d(parameterized.TestCase):
|
|
159
157
|
|
160
158
|
x = jnp.ones((1, 8, 8, 3))
|
161
159
|
m = jnp.tril(jnp.ones((3, 3, 3, 4)))
|
162
|
-
conv_transpose_module =
|
160
|
+
conv_transpose_module = brainstate.nn.ConvTranspose2d(
|
163
161
|
in_channels=3,
|
164
162
|
out_channels=4,
|
165
163
|
kernel_size=(3, 3),
|
166
164
|
padding='VALID',
|
167
165
|
mask=m,
|
168
|
-
w_initializer=
|
166
|
+
w_initializer=brainstate.init.Constant(),
|
169
167
|
)
|
170
168
|
y = conv_transpose_module(x)
|
171
169
|
print(y.shape)
|
@@ -174,14 +172,14 @@ class TestConvTranspose2d(parameterized.TestCase):
|
|
174
172
|
|
175
173
|
x = jnp.ones((1, 8, 8, 3))
|
176
174
|
for use_bias in [True, False]:
|
177
|
-
conv_transpose_module =
|
175
|
+
conv_transpose_module = brainstate.nn.ConvTranspose2d(
|
178
176
|
in_channels=3,
|
179
177
|
out_channels=4,
|
180
178
|
kernel_size=(3, 3),
|
181
179
|
stride=1,
|
182
180
|
padding='SAME',
|
183
|
-
w_initializer=
|
184
|
-
b_initializer=
|
181
|
+
w_initializer=brainstate.init.Constant(),
|
182
|
+
b_initializer=brainstate.init.Constant() if use_bias else None,
|
185
183
|
)
|
186
184
|
y = conv_transpose_module(x)
|
187
185
|
print(y.shape)
|
@@ -193,13 +191,13 @@ class TestConvTranspose3d(parameterized.TestCase):
|
|
193
191
|
|
194
192
|
x = jnp.ones((1, 8, 8, 8, 3))
|
195
193
|
for use_bias in [True, False]:
|
196
|
-
conv_transpose_module =
|
194
|
+
conv_transpose_module = brainstate.nn.ConvTranspose3d(
|
197
195
|
in_channels=3,
|
198
196
|
out_channels=4,
|
199
197
|
kernel_size=(3, 3, 3),
|
200
198
|
padding='VALID',
|
201
|
-
w_initializer=
|
202
|
-
b_initializer=
|
199
|
+
w_initializer=brainstate.init.Constant(),
|
200
|
+
b_initializer=brainstate.init.Constant() if use_bias else None,
|
203
201
|
)
|
204
202
|
y = conv_transpose_module(x)
|
205
203
|
print(y.shape)
|
@@ -208,13 +206,13 @@ class TestConvTranspose3d(parameterized.TestCase):
|
|
208
206
|
|
209
207
|
x = jnp.ones((1, 8, 8, 8, 3))
|
210
208
|
m = jnp.tril(jnp.ones((3, 3, 3, 3, 4)))
|
211
|
-
conv_transpose_module =
|
209
|
+
conv_transpose_module = brainstate.nn.ConvTranspose3d(
|
212
210
|
in_channels=3,
|
213
211
|
out_channels=4,
|
214
212
|
kernel_size=(3, 3, 3),
|
215
213
|
padding='VALID',
|
216
214
|
mask=m,
|
217
|
-
w_initializer=
|
215
|
+
w_initializer=brainstate.init.Constant(),
|
218
216
|
)
|
219
217
|
y = conv_transpose_module(x)
|
220
218
|
print(y.shape)
|
@@ -223,14 +221,14 @@ class TestConvTranspose3d(parameterized.TestCase):
|
|
223
221
|
|
224
222
|
x = jnp.ones((1, 8, 8, 8, 3))
|
225
223
|
for use_bias in [True, False]:
|
226
|
-
conv_transpose_module =
|
224
|
+
conv_transpose_module = brainstate.nn.ConvTranspose3d(
|
227
225
|
in_channels=3,
|
228
226
|
out_channels=4,
|
229
227
|
kernel_size=(3, 3, 3),
|
230
228
|
stride=1,
|
231
229
|
padding='SAME',
|
232
|
-
w_initializer=
|
233
|
-
b_initializer=
|
230
|
+
w_initializer=brainstate.init.Constant(),
|
231
|
+
b_initializer=brainstate.init.Constant() if use_bias else None,
|
234
232
|
)
|
235
233
|
y = conv_transpose_module(x)
|
236
234
|
print(y.shape)
|
@@ -14,14 +14,12 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
16
|
|
17
|
-
from __future__ import annotations
|
18
|
-
|
19
17
|
import unittest
|
20
18
|
|
21
19
|
import brainunit as u
|
22
20
|
from absl.testing import parameterized
|
23
21
|
|
24
|
-
import brainstate
|
22
|
+
import brainstate
|
25
23
|
|
26
24
|
|
27
25
|
class TestDense(parameterized.TestCase):
|
@@ -32,19 +30,19 @@ class TestDense(parameterized.TestCase):
|
|
32
30
|
num_out=[20, ]
|
33
31
|
)
|
34
32
|
def test_Dense1(self, size, num_out):
|
35
|
-
f =
|
36
|
-
x =
|
33
|
+
f = brainstate.nn.Linear(10, num_out)
|
34
|
+
x = brainstate.random.random(size)
|
37
35
|
y = f(x)
|
38
36
|
self.assertTrue(y.shape == size[:-1] + (num_out,))
|
39
37
|
|
40
38
|
|
41
39
|
class TestSparseMatrix(unittest.TestCase):
|
42
40
|
def test_csr(self):
|
43
|
-
data =
|
41
|
+
data = brainstate.random.rand(10, 20)
|
44
42
|
data = data * (data > 0.9)
|
45
|
-
f =
|
43
|
+
f = brainstate.nn.SparseLinear(u.sparse.CSR.fromdense(data))
|
46
44
|
|
47
|
-
x =
|
45
|
+
x = brainstate.random.rand(10)
|
48
46
|
y = f(x)
|
49
47
|
self.assertTrue(
|
50
48
|
u.math.allclose(
|
@@ -53,7 +51,7 @@ class TestSparseMatrix(unittest.TestCase):
|
|
53
51
|
)
|
54
52
|
)
|
55
53
|
|
56
|
-
x =
|
54
|
+
x = brainstate.random.rand(5, 10)
|
57
55
|
y = f(x)
|
58
56
|
self.assertTrue(
|
59
57
|
u.math.allclose(
|
@@ -63,11 +61,11 @@ class TestSparseMatrix(unittest.TestCase):
|
|
63
61
|
)
|
64
62
|
|
65
63
|
def test_csc(self):
|
66
|
-
data =
|
64
|
+
data = brainstate.random.rand(10, 20)
|
67
65
|
data = data * (data > 0.9)
|
68
|
-
f =
|
66
|
+
f = brainstate.nn.SparseLinear(u.sparse.CSC.fromdense(data))
|
69
67
|
|
70
|
-
x =
|
68
|
+
x = brainstate.random.rand(10)
|
71
69
|
y = f(x)
|
72
70
|
self.assertTrue(
|
73
71
|
u.math.allclose(
|
@@ -76,7 +74,7 @@ class TestSparseMatrix(unittest.TestCase):
|
|
76
74
|
)
|
77
75
|
)
|
78
76
|
|
79
|
-
x =
|
77
|
+
x = brainstate.random.rand(5, 10)
|
80
78
|
y = f(x)
|
81
79
|
self.assertTrue(
|
82
80
|
u.math.allclose(
|
@@ -86,11 +84,11 @@ class TestSparseMatrix(unittest.TestCase):
|
|
86
84
|
)
|
87
85
|
|
88
86
|
def test_coo(self):
|
89
|
-
data =
|
87
|
+
data = brainstate.random.rand(10, 20)
|
90
88
|
data = data * (data > 0.9)
|
91
|
-
f =
|
89
|
+
f = brainstate.nn.SparseLinear(u.sparse.COO.fromdense(data))
|
92
90
|
|
93
|
-
x =
|
91
|
+
x = brainstate.random.rand(10)
|
94
92
|
y = f(x)
|
95
93
|
self.assertTrue(
|
96
94
|
u.math.allclose(
|
@@ -99,7 +97,7 @@ class TestSparseMatrix(unittest.TestCase):
|
|
99
97
|
)
|
100
98
|
)
|
101
99
|
|
102
|
-
x =
|
100
|
+
x = brainstate.random.rand(5, 10)
|
103
101
|
y = f(x)
|
104
102
|
self.assertTrue(
|
105
103
|
u.math.allclose(
|
@@ -13,12 +13,10 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
|
-
|
18
16
|
from absl.testing import absltest
|
19
17
|
from absl.testing import parameterized
|
20
18
|
|
21
|
-
import brainstate
|
19
|
+
import brainstate
|
22
20
|
|
23
21
|
|
24
22
|
class Test_Normalization(parameterized.TestCase):
|
@@ -26,27 +24,27 @@ class Test_Normalization(parameterized.TestCase):
|
|
26
24
|
fit=[True, False],
|
27
25
|
)
|
28
26
|
def test_BatchNorm1d(self, fit):
|
29
|
-
net =
|
30
|
-
|
31
|
-
input =
|
27
|
+
net = brainstate.nn.BatchNorm1d((3, 10))
|
28
|
+
brainstate.environ.set(fit=fit)
|
29
|
+
input = brainstate.random.randn(1, 3, 10)
|
32
30
|
output = net(input)
|
33
31
|
|
34
32
|
@parameterized.product(
|
35
33
|
fit=[True, False]
|
36
34
|
)
|
37
35
|
def test_BatchNorm2d(self, fit):
|
38
|
-
net =
|
39
|
-
|
40
|
-
input =
|
36
|
+
net = brainstate.nn.BatchNorm2d([3, 4, 10])
|
37
|
+
brainstate.environ.set(fit=fit)
|
38
|
+
input = brainstate.random.randn(1, 3, 4, 10)
|
41
39
|
output = net(input)
|
42
40
|
|
43
41
|
@parameterized.product(
|
44
42
|
fit=[True, False]
|
45
43
|
)
|
46
44
|
def test_BatchNorm3d(self, fit):
|
47
|
-
net =
|
48
|
-
|
49
|
-
input =
|
45
|
+
net = brainstate.nn.BatchNorm3d([3, 4, 5, 10])
|
46
|
+
brainstate.environ.set(fit=fit)
|
47
|
+
input = brainstate.random.randn(1, 3, 4, 5, 10)
|
50
48
|
output = net(input)
|
51
49
|
|
52
50
|
# @parameterized.product(
|