brainstate 0.1.7__py2.py3-none-any.whl → 0.1.9__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +58 -51
- brainstate/_compatible_import.py +148 -148
- brainstate/_state.py +1605 -1663
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/augment/__init__.py +30 -30
- brainstate/augment/_autograd.py +778 -778
- brainstate/augment/_autograd_test.py +1289 -1289
- brainstate/augment/_eval_shape.py +99 -99
- brainstate/augment/_eval_shape_test.py +38 -38
- brainstate/augment/_mapping.py +1060 -1060
- brainstate/augment/_mapping_test.py +597 -597
- brainstate/augment/_random.py +151 -151
- brainstate/compile/__init__.py +38 -38
- brainstate/compile/_ad_checkpoint.py +204 -204
- brainstate/compile/_ad_checkpoint_test.py +49 -49
- brainstate/compile/_conditions.py +256 -256
- brainstate/compile/_conditions_test.py +220 -220
- brainstate/compile/_error_if.py +92 -92
- brainstate/compile/_error_if_test.py +52 -52
- brainstate/compile/_jit.py +346 -346
- brainstate/compile/_jit_test.py +143 -143
- brainstate/compile/_loop_collect_return.py +536 -536
- brainstate/compile/_loop_collect_return_test.py +58 -58
- brainstate/compile/_loop_no_collection.py +184 -184
- brainstate/compile/_loop_no_collection_test.py +50 -50
- brainstate/compile/_make_jaxpr.py +888 -888
- brainstate/compile/_make_jaxpr_test.py +156 -146
- brainstate/compile/_progress_bar.py +202 -202
- brainstate/compile/_unvmap.py +159 -159
- brainstate/compile/_util.py +147 -147
- brainstate/environ.py +563 -563
- brainstate/environ_test.py +62 -62
- brainstate/functional/__init__.py +27 -26
- brainstate/graph/__init__.py +29 -29
- brainstate/graph/_graph_node.py +244 -244
- brainstate/graph/_graph_node_test.py +73 -73
- brainstate/graph/_graph_operation.py +1738 -1738
- brainstate/graph/_graph_operation_test.py +563 -563
- brainstate/init/__init__.py +26 -26
- brainstate/init/_base.py +52 -52
- brainstate/init/_generic.py +244 -244
- brainstate/init/_random_inits.py +553 -553
- brainstate/init/_random_inits_test.py +149 -149
- brainstate/init/_regular_inits.py +105 -105
- brainstate/init/_regular_inits_test.py +50 -50
- brainstate/mixin.py +365 -363
- brainstate/mixin_test.py +77 -73
- brainstate/nn/__init__.py +135 -131
- brainstate/{functional → nn}/_activations.py +808 -813
- brainstate/{functional → nn}/_activations_test.py +331 -331
- brainstate/nn/_collective_ops.py +514 -514
- brainstate/nn/_collective_ops_test.py +43 -43
- brainstate/nn/_common.py +178 -178
- brainstate/nn/_conv.py +501 -501
- brainstate/nn/_conv_test.py +238 -238
- brainstate/nn/_delay.py +509 -470
- brainstate/nn/_delay_test.py +238 -0
- brainstate/nn/_dropout.py +426 -426
- brainstate/nn/_dropout_test.py +100 -100
- brainstate/nn/_dynamics.py +1343 -1361
- brainstate/nn/_dynamics_test.py +78 -78
- brainstate/nn/_elementwise.py +1119 -1120
- brainstate/nn/_elementwise_test.py +169 -169
- brainstate/nn/_embedding.py +58 -58
- brainstate/nn/_exp_euler.py +92 -92
- brainstate/nn/_exp_euler_test.py +35 -35
- brainstate/nn/_fixedprob.py +239 -239
- brainstate/nn/_fixedprob_test.py +114 -114
- brainstate/nn/_inputs.py +608 -608
- brainstate/nn/_linear.py +424 -424
- brainstate/nn/_linear_mv.py +83 -83
- brainstate/nn/_linear_mv_test.py +120 -120
- brainstate/nn/_linear_test.py +107 -107
- brainstate/nn/_ltp.py +28 -28
- brainstate/nn/_module.py +377 -377
- brainstate/nn/_module_test.py +40 -208
- brainstate/nn/_neuron.py +705 -705
- brainstate/nn/_neuron_test.py +161 -161
- brainstate/nn/_normalizations.py +975 -918
- brainstate/nn/_normalizations_test.py +73 -73
- brainstate/{functional → nn}/_others.py +46 -46
- brainstate/nn/_poolings.py +1177 -1177
- brainstate/nn/_poolings_test.py +217 -217
- brainstate/nn/_projection.py +486 -486
- brainstate/nn/_rate_rnns.py +554 -554
- brainstate/nn/_rate_rnns_test.py +63 -63
- brainstate/nn/_readout.py +209 -209
- brainstate/nn/_readout_test.py +53 -53
- brainstate/nn/_stp.py +236 -236
- brainstate/nn/_synapse.py +505 -505
- brainstate/nn/_synapse_test.py +131 -131
- brainstate/nn/_synaptic_projection.py +423 -423
- brainstate/nn/_synouts.py +162 -162
- brainstate/nn/_synouts_test.py +57 -57
- brainstate/nn/_utils.py +89 -89
- brainstate/nn/metrics.py +388 -388
- brainstate/optim/__init__.py +38 -38
- brainstate/optim/_base.py +64 -64
- brainstate/optim/_lr_scheduler.py +448 -448
- brainstate/optim/_lr_scheduler_test.py +50 -50
- brainstate/optim/_optax_optimizer.py +152 -152
- brainstate/optim/_optax_optimizer_test.py +53 -53
- brainstate/optim/_sgd_optimizer.py +1104 -1104
- brainstate/random/__init__.py +24 -24
- brainstate/random/_rand_funs.py +3616 -3616
- brainstate/random/_rand_funs_test.py +567 -567
- brainstate/random/_rand_seed.py +210 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1409 -1409
- brainstate/random/_random_for_unit.py +52 -52
- brainstate/surrogate.py +1957 -1957
- brainstate/transform.py +23 -23
- brainstate/typing.py +304 -304
- brainstate/util/__init__.py +50 -50
- brainstate/util/caller.py +98 -98
- brainstate/util/error.py +55 -55
- brainstate/util/filter.py +469 -469
- brainstate/util/others.py +540 -540
- brainstate/util/pretty_pytree.py +945 -945
- brainstate/util/pretty_pytree_test.py +159 -159
- brainstate/util/pretty_repr.py +328 -328
- brainstate/util/pretty_table.py +2954 -2954
- brainstate/util/scaling.py +258 -258
- brainstate/util/struct.py +523 -523
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/METADATA +91 -99
- brainstate-0.1.9.dist-info/RECORD +130 -0
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/WHEEL +1 -1
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info/licenses}/LICENSE +202 -202
- brainstate/functional/_normalization.py +0 -81
- brainstate/functional/_spikes.py +0 -204
- brainstate-0.1.7.dist-info/RECORD +0 -131
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/top_level.txt +0 -0
brainstate/nn/_conv_test.py
CHANGED
@@ -1,238 +1,238 @@
|
|
1
|
-
# -*- coding: utf-8 -*-
|
2
|
-
|
3
|
-
import jax.numpy as jnp
|
4
|
-
import pytest
|
5
|
-
from absl.testing import absltest
|
6
|
-
from absl.testing import parameterized
|
7
|
-
|
8
|
-
import brainstate
|
9
|
-
|
10
|
-
|
11
|
-
class TestConv(parameterized.TestCase):
|
12
|
-
def test_Conv2D_img(self):
|
13
|
-
img = jnp.zeros((2, 200, 198, 4))
|
14
|
-
for k in range(4):
|
15
|
-
x = 30 + 60 * k
|
16
|
-
y = 20 + 60 * k
|
17
|
-
img = img.at[0, x:x + 10, y:y + 10, k].set(1.0)
|
18
|
-
img = img.at[1, x:x + 20, y:y + 20, k].set(3.0)
|
19
|
-
|
20
|
-
net = brainstate.nn.Conv2d((200, 198, 4), out_channels=32, kernel_size=(3, 3),
|
21
|
-
stride=(2, 1), padding='VALID', groups=4)
|
22
|
-
out = net(img)
|
23
|
-
print("out shape: ", out.shape)
|
24
|
-
self.assertEqual(out.shape, (2, 99, 196, 32))
|
25
|
-
# print("First output channel:")
|
26
|
-
# plt.figure(figsize=(10, 10))
|
27
|
-
# plt.imshow(np.array(img)[0, :, :, 0])
|
28
|
-
# plt.show()
|
29
|
-
|
30
|
-
def test_conv1D(self):
|
31
|
-
model = brainstate.nn.Conv1d((5, 3), out_channels=32, kernel_size=(3,))
|
32
|
-
input = jnp.ones((2, 5, 3))
|
33
|
-
out = model(input)
|
34
|
-
print("out shape: ", out.shape)
|
35
|
-
self.assertEqual(out.shape, (2, 5, 32))
|
36
|
-
# print("First output channel:")
|
37
|
-
# plt.figure(figsize=(10, 10))
|
38
|
-
# plt.imshow(np.array(out)[0, :, :])
|
39
|
-
# plt.show()
|
40
|
-
|
41
|
-
def test_conv2D(self):
|
42
|
-
model = brainstate.nn.Conv2d((5, 5, 3), out_channels=32, kernel_size=(3, 3))
|
43
|
-
input = jnp.ones((2, 5, 5, 3))
|
44
|
-
|
45
|
-
out = model(input)
|
46
|
-
print("out shape: ", out.shape)
|
47
|
-
self.assertEqual(out.shape, (2, 5, 5, 32))
|
48
|
-
|
49
|
-
def test_conv3D(self):
|
50
|
-
model = brainstate.nn.Conv3d((5, 5, 5, 3), out_channels=32, kernel_size=(3, 3, 3))
|
51
|
-
input = jnp.ones((2, 5, 5, 5, 3))
|
52
|
-
out = model(input)
|
53
|
-
print("out shape: ", out.shape)
|
54
|
-
self.assertEqual(out.shape, (2, 5, 5, 5, 32))
|
55
|
-
|
56
|
-
|
57
|
-
@pytest.mark.skip(reason="not implemented yet")
|
58
|
-
class TestConvTranspose1d(parameterized.TestCase):
|
59
|
-
def test_conv_transpose(self):
|
60
|
-
|
61
|
-
x = jnp.ones((1, 8, 3))
|
62
|
-
for use_bias in [True, False]:
|
63
|
-
conv_transpose_module = brainstate.nn.ConvTranspose1d(
|
64
|
-
in_channels=3,
|
65
|
-
out_channels=4,
|
66
|
-
kernel_size=(3,),
|
67
|
-
padding='VALID',
|
68
|
-
w_initializer=brainstate.init.Constant(1.),
|
69
|
-
b_initializer=brainstate.init.Constant(1.) if use_bias else None,
|
70
|
-
)
|
71
|
-
self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4))
|
72
|
-
y = conv_transpose_module(x)
|
73
|
-
print(y.shape)
|
74
|
-
correct_ans = jnp.array([[[4., 4., 4., 4.],
|
75
|
-
[7., 7., 7., 7.],
|
76
|
-
[10., 10., 10., 10.],
|
77
|
-
[10., 10., 10., 10.],
|
78
|
-
[10., 10., 10., 10.],
|
79
|
-
[10., 10., 10., 10.],
|
80
|
-
[10., 10., 10., 10.],
|
81
|
-
[10., 10., 10., 10.],
|
82
|
-
[7., 7., 7., 7.],
|
83
|
-
[4., 4., 4., 4.]]])
|
84
|
-
if not use_bias:
|
85
|
-
correct_ans -= 1.
|
86
|
-
self.assertTrue(jnp.allclose(y, correct_ans))
|
87
|
-
|
88
|
-
def test_single_input_masked_conv_transpose(self):
|
89
|
-
|
90
|
-
x = jnp.ones((1, 8, 3))
|
91
|
-
m = jnp.tril(jnp.ones((3, 3, 4)))
|
92
|
-
conv_transpose_module = brainstate.nn.ConvTranspose1d(
|
93
|
-
in_channels=3,
|
94
|
-
out_channels=4,
|
95
|
-
kernel_size=(3,),
|
96
|
-
padding='VALID',
|
97
|
-
mask=m,
|
98
|
-
w_initializer=brainstate.init.Constant(),
|
99
|
-
b_initializer=brainstate.init.Constant(),
|
100
|
-
)
|
101
|
-
self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4))
|
102
|
-
y = conv_transpose_module(x)
|
103
|
-
print(y.shape)
|
104
|
-
correct_ans = jnp.array([[[4., 3., 2., 1.],
|
105
|
-
[7., 5., 3., 1.],
|
106
|
-
[10., 7., 4., 1.],
|
107
|
-
[10., 7., 4., 1.],
|
108
|
-
[10., 7., 4., 1.],
|
109
|
-
[10., 7., 4., 1.],
|
110
|
-
[10., 7., 4., 1.],
|
111
|
-
[10., 7., 4., 1.],
|
112
|
-
[7., 5., 3., 1.],
|
113
|
-
[4., 3., 2., 1.]]])
|
114
|
-
self.assertTrue(jnp.allclose(y, correct_ans))
|
115
|
-
|
116
|
-
def test_computation_padding_same(self):
|
117
|
-
|
118
|
-
data = jnp.ones([1, 3, 1])
|
119
|
-
for use_bias in [True, False]:
|
120
|
-
net = brainstate.nn.ConvTranspose1d(
|
121
|
-
in_channels=1,
|
122
|
-
out_channels=1,
|
123
|
-
kernel_size=3,
|
124
|
-
stride=1,
|
125
|
-
padding="SAME",
|
126
|
-
w_initializer=brainstate.init.Constant(),
|
127
|
-
b_initializer=brainstate.init.Constant() if use_bias else None,
|
128
|
-
)
|
129
|
-
out = net(data)
|
130
|
-
self.assertEqual(out.shape, (1, 3, 1))
|
131
|
-
out = jnp.squeeze(out, axis=(0, 2))
|
132
|
-
expected_out = jnp.asarray([2, 3, 2])
|
133
|
-
if use_bias:
|
134
|
-
expected_out += 1
|
135
|
-
self.assertTrue(jnp.allclose(out, expected_out, rtol=1e-5))
|
136
|
-
|
137
|
-
|
138
|
-
@pytest.mark.skip(reason="not implemented yet")
|
139
|
-
class TestConvTranspose2d(parameterized.TestCase):
|
140
|
-
def test_conv_transpose(self):
|
141
|
-
|
142
|
-
x = jnp.ones((1, 8, 8, 3))
|
143
|
-
for use_bias in [True, False]:
|
144
|
-
conv_transpose_module = brainstate.nn.ConvTranspose2d(
|
145
|
-
in_channels=3,
|
146
|
-
out_channels=4,
|
147
|
-
kernel_size=(3, 3),
|
148
|
-
padding='VALID',
|
149
|
-
w_initializer=brainstate.init.Constant(),
|
150
|
-
b_initializer=brainstate.init.Constant() if use_bias else None,
|
151
|
-
)
|
152
|
-
self.assertEqual(conv_transpose_module.w.shape, (3, 3, 3, 4))
|
153
|
-
y = conv_transpose_module(x)
|
154
|
-
print(y.shape)
|
155
|
-
|
156
|
-
def test_single_input_masked_conv_transpose(self):
|
157
|
-
|
158
|
-
x = jnp.ones((1, 8, 8, 3))
|
159
|
-
m = jnp.tril(jnp.ones((3, 3, 3, 4)))
|
160
|
-
conv_transpose_module = brainstate.nn.ConvTranspose2d(
|
161
|
-
in_channels=3,
|
162
|
-
out_channels=4,
|
163
|
-
kernel_size=(3, 3),
|
164
|
-
padding='VALID',
|
165
|
-
mask=m,
|
166
|
-
w_initializer=brainstate.init.Constant(),
|
167
|
-
)
|
168
|
-
y = conv_transpose_module(x)
|
169
|
-
print(y.shape)
|
170
|
-
|
171
|
-
def test_computation_padding_same(self):
|
172
|
-
|
173
|
-
x = jnp.ones((1, 8, 8, 3))
|
174
|
-
for use_bias in [True, False]:
|
175
|
-
conv_transpose_module = brainstate.nn.ConvTranspose2d(
|
176
|
-
in_channels=3,
|
177
|
-
out_channels=4,
|
178
|
-
kernel_size=(3, 3),
|
179
|
-
stride=1,
|
180
|
-
padding='SAME',
|
181
|
-
w_initializer=brainstate.init.Constant(),
|
182
|
-
b_initializer=brainstate.init.Constant() if use_bias else None,
|
183
|
-
)
|
184
|
-
y = conv_transpose_module(x)
|
185
|
-
print(y.shape)
|
186
|
-
|
187
|
-
|
188
|
-
@pytest.mark.skip(reason="not implemented yet")
|
189
|
-
class TestConvTranspose3d(parameterized.TestCase):
|
190
|
-
def test_conv_transpose(self):
|
191
|
-
|
192
|
-
x = jnp.ones((1, 8, 8, 8, 3))
|
193
|
-
for use_bias in [True, False]:
|
194
|
-
conv_transpose_module = brainstate.nn.ConvTranspose3d(
|
195
|
-
in_channels=3,
|
196
|
-
out_channels=4,
|
197
|
-
kernel_size=(3, 3, 3),
|
198
|
-
padding='VALID',
|
199
|
-
w_initializer=brainstate.init.Constant(),
|
200
|
-
b_initializer=brainstate.init.Constant() if use_bias else None,
|
201
|
-
)
|
202
|
-
y = conv_transpose_module(x)
|
203
|
-
print(y.shape)
|
204
|
-
|
205
|
-
def test_single_input_masked_conv_transpose(self):
|
206
|
-
|
207
|
-
x = jnp.ones((1, 8, 8, 8, 3))
|
208
|
-
m = jnp.tril(jnp.ones((3, 3, 3, 3, 4)))
|
209
|
-
conv_transpose_module = brainstate.nn.ConvTranspose3d(
|
210
|
-
in_channels=3,
|
211
|
-
out_channels=4,
|
212
|
-
kernel_size=(3, 3, 3),
|
213
|
-
padding='VALID',
|
214
|
-
mask=m,
|
215
|
-
w_initializer=brainstate.init.Constant(),
|
216
|
-
)
|
217
|
-
y = conv_transpose_module(x)
|
218
|
-
print(y.shape)
|
219
|
-
|
220
|
-
def test_computation_padding_same(self):
|
221
|
-
|
222
|
-
x = jnp.ones((1, 8, 8, 8, 3))
|
223
|
-
for use_bias in [True, False]:
|
224
|
-
conv_transpose_module = brainstate.nn.ConvTranspose3d(
|
225
|
-
in_channels=3,
|
226
|
-
out_channels=4,
|
227
|
-
kernel_size=(3, 3, 3),
|
228
|
-
stride=1,
|
229
|
-
padding='SAME',
|
230
|
-
w_initializer=brainstate.init.Constant(),
|
231
|
-
b_initializer=brainstate.init.Constant() if use_bias else None,
|
232
|
-
)
|
233
|
-
y = conv_transpose_module(x)
|
234
|
-
print(y.shape)
|
235
|
-
|
236
|
-
|
237
|
-
if __name__ == '__main__':
|
238
|
-
absltest.main()
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
|
3
|
+
import jax.numpy as jnp
|
4
|
+
import pytest
|
5
|
+
from absl.testing import absltest
|
6
|
+
from absl.testing import parameterized
|
7
|
+
|
8
|
+
import brainstate
|
9
|
+
|
10
|
+
|
11
|
+
class TestConv(parameterized.TestCase):
|
12
|
+
def test_Conv2D_img(self):
|
13
|
+
img = jnp.zeros((2, 200, 198, 4))
|
14
|
+
for k in range(4):
|
15
|
+
x = 30 + 60 * k
|
16
|
+
y = 20 + 60 * k
|
17
|
+
img = img.at[0, x:x + 10, y:y + 10, k].set(1.0)
|
18
|
+
img = img.at[1, x:x + 20, y:y + 20, k].set(3.0)
|
19
|
+
|
20
|
+
net = brainstate.nn.Conv2d((200, 198, 4), out_channels=32, kernel_size=(3, 3),
|
21
|
+
stride=(2, 1), padding='VALID', groups=4)
|
22
|
+
out = net(img)
|
23
|
+
print("out shape: ", out.shape)
|
24
|
+
self.assertEqual(out.shape, (2, 99, 196, 32))
|
25
|
+
# print("First output channel:")
|
26
|
+
# plt.figure(figsize=(10, 10))
|
27
|
+
# plt.imshow(np.array(img)[0, :, :, 0])
|
28
|
+
# plt.show()
|
29
|
+
|
30
|
+
def test_conv1D(self):
|
31
|
+
model = brainstate.nn.Conv1d((5, 3), out_channels=32, kernel_size=(3,))
|
32
|
+
input = jnp.ones((2, 5, 3))
|
33
|
+
out = model(input)
|
34
|
+
print("out shape: ", out.shape)
|
35
|
+
self.assertEqual(out.shape, (2, 5, 32))
|
36
|
+
# print("First output channel:")
|
37
|
+
# plt.figure(figsize=(10, 10))
|
38
|
+
# plt.imshow(np.array(out)[0, :, :])
|
39
|
+
# plt.show()
|
40
|
+
|
41
|
+
def test_conv2D(self):
|
42
|
+
model = brainstate.nn.Conv2d((5, 5, 3), out_channels=32, kernel_size=(3, 3))
|
43
|
+
input = jnp.ones((2, 5, 5, 3))
|
44
|
+
|
45
|
+
out = model(input)
|
46
|
+
print("out shape: ", out.shape)
|
47
|
+
self.assertEqual(out.shape, (2, 5, 5, 32))
|
48
|
+
|
49
|
+
def test_conv3D(self):
|
50
|
+
model = brainstate.nn.Conv3d((5, 5, 5, 3), out_channels=32, kernel_size=(3, 3, 3))
|
51
|
+
input = jnp.ones((2, 5, 5, 5, 3))
|
52
|
+
out = model(input)
|
53
|
+
print("out shape: ", out.shape)
|
54
|
+
self.assertEqual(out.shape, (2, 5, 5, 5, 32))
|
55
|
+
|
56
|
+
|
57
|
+
@pytest.mark.skip(reason="not implemented yet")
|
58
|
+
class TestConvTranspose1d(parameterized.TestCase):
|
59
|
+
def test_conv_transpose(self):
|
60
|
+
|
61
|
+
x = jnp.ones((1, 8, 3))
|
62
|
+
for use_bias in [True, False]:
|
63
|
+
conv_transpose_module = brainstate.nn.ConvTranspose1d(
|
64
|
+
in_channels=3,
|
65
|
+
out_channels=4,
|
66
|
+
kernel_size=(3,),
|
67
|
+
padding='VALID',
|
68
|
+
w_initializer=brainstate.init.Constant(1.),
|
69
|
+
b_initializer=brainstate.init.Constant(1.) if use_bias else None,
|
70
|
+
)
|
71
|
+
self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4))
|
72
|
+
y = conv_transpose_module(x)
|
73
|
+
print(y.shape)
|
74
|
+
correct_ans = jnp.array([[[4., 4., 4., 4.],
|
75
|
+
[7., 7., 7., 7.],
|
76
|
+
[10., 10., 10., 10.],
|
77
|
+
[10., 10., 10., 10.],
|
78
|
+
[10., 10., 10., 10.],
|
79
|
+
[10., 10., 10., 10.],
|
80
|
+
[10., 10., 10., 10.],
|
81
|
+
[10., 10., 10., 10.],
|
82
|
+
[7., 7., 7., 7.],
|
83
|
+
[4., 4., 4., 4.]]])
|
84
|
+
if not use_bias:
|
85
|
+
correct_ans -= 1.
|
86
|
+
self.assertTrue(jnp.allclose(y, correct_ans))
|
87
|
+
|
88
|
+
def test_single_input_masked_conv_transpose(self):
|
89
|
+
|
90
|
+
x = jnp.ones((1, 8, 3))
|
91
|
+
m = jnp.tril(jnp.ones((3, 3, 4)))
|
92
|
+
conv_transpose_module = brainstate.nn.ConvTranspose1d(
|
93
|
+
in_channels=3,
|
94
|
+
out_channels=4,
|
95
|
+
kernel_size=(3,),
|
96
|
+
padding='VALID',
|
97
|
+
mask=m,
|
98
|
+
w_initializer=brainstate.init.Constant(),
|
99
|
+
b_initializer=brainstate.init.Constant(),
|
100
|
+
)
|
101
|
+
self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4))
|
102
|
+
y = conv_transpose_module(x)
|
103
|
+
print(y.shape)
|
104
|
+
correct_ans = jnp.array([[[4., 3., 2., 1.],
|
105
|
+
[7., 5., 3., 1.],
|
106
|
+
[10., 7., 4., 1.],
|
107
|
+
[10., 7., 4., 1.],
|
108
|
+
[10., 7., 4., 1.],
|
109
|
+
[10., 7., 4., 1.],
|
110
|
+
[10., 7., 4., 1.],
|
111
|
+
[10., 7., 4., 1.],
|
112
|
+
[7., 5., 3., 1.],
|
113
|
+
[4., 3., 2., 1.]]])
|
114
|
+
self.assertTrue(jnp.allclose(y, correct_ans))
|
115
|
+
|
116
|
+
def test_computation_padding_same(self):
|
117
|
+
|
118
|
+
data = jnp.ones([1, 3, 1])
|
119
|
+
for use_bias in [True, False]:
|
120
|
+
net = brainstate.nn.ConvTranspose1d(
|
121
|
+
in_channels=1,
|
122
|
+
out_channels=1,
|
123
|
+
kernel_size=3,
|
124
|
+
stride=1,
|
125
|
+
padding="SAME",
|
126
|
+
w_initializer=brainstate.init.Constant(),
|
127
|
+
b_initializer=brainstate.init.Constant() if use_bias else None,
|
128
|
+
)
|
129
|
+
out = net(data)
|
130
|
+
self.assertEqual(out.shape, (1, 3, 1))
|
131
|
+
out = jnp.squeeze(out, axis=(0, 2))
|
132
|
+
expected_out = jnp.asarray([2, 3, 2])
|
133
|
+
if use_bias:
|
134
|
+
expected_out += 1
|
135
|
+
self.assertTrue(jnp.allclose(out, expected_out, rtol=1e-5))
|
136
|
+
|
137
|
+
|
138
|
+
@pytest.mark.skip(reason="not implemented yet")
|
139
|
+
class TestConvTranspose2d(parameterized.TestCase):
|
140
|
+
def test_conv_transpose(self):
|
141
|
+
|
142
|
+
x = jnp.ones((1, 8, 8, 3))
|
143
|
+
for use_bias in [True, False]:
|
144
|
+
conv_transpose_module = brainstate.nn.ConvTranspose2d(
|
145
|
+
in_channels=3,
|
146
|
+
out_channels=4,
|
147
|
+
kernel_size=(3, 3),
|
148
|
+
padding='VALID',
|
149
|
+
w_initializer=brainstate.init.Constant(),
|
150
|
+
b_initializer=brainstate.init.Constant() if use_bias else None,
|
151
|
+
)
|
152
|
+
self.assertEqual(conv_transpose_module.w.shape, (3, 3, 3, 4))
|
153
|
+
y = conv_transpose_module(x)
|
154
|
+
print(y.shape)
|
155
|
+
|
156
|
+
def test_single_input_masked_conv_transpose(self):
|
157
|
+
|
158
|
+
x = jnp.ones((1, 8, 8, 3))
|
159
|
+
m = jnp.tril(jnp.ones((3, 3, 3, 4)))
|
160
|
+
conv_transpose_module = brainstate.nn.ConvTranspose2d(
|
161
|
+
in_channels=3,
|
162
|
+
out_channels=4,
|
163
|
+
kernel_size=(3, 3),
|
164
|
+
padding='VALID',
|
165
|
+
mask=m,
|
166
|
+
w_initializer=brainstate.init.Constant(),
|
167
|
+
)
|
168
|
+
y = conv_transpose_module(x)
|
169
|
+
print(y.shape)
|
170
|
+
|
171
|
+
def test_computation_padding_same(self):
|
172
|
+
|
173
|
+
x = jnp.ones((1, 8, 8, 3))
|
174
|
+
for use_bias in [True, False]:
|
175
|
+
conv_transpose_module = brainstate.nn.ConvTranspose2d(
|
176
|
+
in_channels=3,
|
177
|
+
out_channels=4,
|
178
|
+
kernel_size=(3, 3),
|
179
|
+
stride=1,
|
180
|
+
padding='SAME',
|
181
|
+
w_initializer=brainstate.init.Constant(),
|
182
|
+
b_initializer=brainstate.init.Constant() if use_bias else None,
|
183
|
+
)
|
184
|
+
y = conv_transpose_module(x)
|
185
|
+
print(y.shape)
|
186
|
+
|
187
|
+
|
188
|
+
@pytest.mark.skip(reason="not implemented yet")
|
189
|
+
class TestConvTranspose3d(parameterized.TestCase):
|
190
|
+
def test_conv_transpose(self):
|
191
|
+
|
192
|
+
x = jnp.ones((1, 8, 8, 8, 3))
|
193
|
+
for use_bias in [True, False]:
|
194
|
+
conv_transpose_module = brainstate.nn.ConvTranspose3d(
|
195
|
+
in_channels=3,
|
196
|
+
out_channels=4,
|
197
|
+
kernel_size=(3, 3, 3),
|
198
|
+
padding='VALID',
|
199
|
+
w_initializer=brainstate.init.Constant(),
|
200
|
+
b_initializer=brainstate.init.Constant() if use_bias else None,
|
201
|
+
)
|
202
|
+
y = conv_transpose_module(x)
|
203
|
+
print(y.shape)
|
204
|
+
|
205
|
+
def test_single_input_masked_conv_transpose(self):
|
206
|
+
|
207
|
+
x = jnp.ones((1, 8, 8, 8, 3))
|
208
|
+
m = jnp.tril(jnp.ones((3, 3, 3, 3, 4)))
|
209
|
+
conv_transpose_module = brainstate.nn.ConvTranspose3d(
|
210
|
+
in_channels=3,
|
211
|
+
out_channels=4,
|
212
|
+
kernel_size=(3, 3, 3),
|
213
|
+
padding='VALID',
|
214
|
+
mask=m,
|
215
|
+
w_initializer=brainstate.init.Constant(),
|
216
|
+
)
|
217
|
+
y = conv_transpose_module(x)
|
218
|
+
print(y.shape)
|
219
|
+
|
220
|
+
def test_computation_padding_same(self):
|
221
|
+
|
222
|
+
x = jnp.ones((1, 8, 8, 8, 3))
|
223
|
+
for use_bias in [True, False]:
|
224
|
+
conv_transpose_module = brainstate.nn.ConvTranspose3d(
|
225
|
+
in_channels=3,
|
226
|
+
out_channels=4,
|
227
|
+
kernel_size=(3, 3, 3),
|
228
|
+
stride=1,
|
229
|
+
padding='SAME',
|
230
|
+
w_initializer=brainstate.init.Constant(),
|
231
|
+
b_initializer=brainstate.init.Constant() if use_bias else None,
|
232
|
+
)
|
233
|
+
y = conv_transpose_module(x)
|
234
|
+
print(y.shape)
|
235
|
+
|
236
|
+
|
237
|
+
if __name__ == '__main__':
|
238
|
+
absltest.main()
|