brainstate-0.1.9-py2.py3-none-any.whl → brainstate-0.2.0-py2.py3-none-any.whl
This diff compares the publicly released contents of the two package versions as they appear in their public registries; it is provided for informational purposes only.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +95 -29
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.9.dist-info/RECORD +0 -130
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
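The file moves above amount to a package-level reorganization: the former `brainstate.compile` and `brainstate.augment` subpackages are merged into a single `brainstate.transform` subpackage, `brainstate.init` is folded into `brainstate.nn`, and several `util` modules become private `_`-prefixed files. The sketch below illustrates what the corresponding import update in downstream code might look like; the public aliases used (`jit`, `grad`) are assumptions inferred from the renamed file names (`transform/_jit.py`, `transform/_autograd.py`) and are not confirmed by this diff, so verify them against the 0.2.0 release (the new `_deprecation.py` suggests old paths may keep working with warnings for a while).

```python
# Hypothetical import migration from brainstate 0.1.9 to 0.2.0, inferred from the
# renamed files above (compile/ and augment/ -> transform/). The exact public names
# are assumptions; check the released package before relying on them.
import jax.numpy as jnp
import brainstate

# 0.1.9 (old layout):
#   fn = brainstate.compile.jit(fn)        # compile/_jit.py
#   g  = brainstate.augment.grad(fn, ...)  # augment/_autograd.py

# 0.2.0 (new layout):
@brainstate.transform.jit                  # transform/_jit.py
def double(x):
    return x * 2

print(double(jnp.arange(3)))
```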
brainstate/nn/_conv_test.py
CHANGED
@@ -1,238 +1,849 @@
+# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 # -*- coding: utf-8 -*-

+import unittest
+
+import jax
 import jax.numpy as jnp
-import pytest
-from absl.testing import absltest
-from absl.testing import parameterized

 import brainstate


-class
[… old lines 12-100 removed; their content is not preserved in this diff view …]
-        self.assertEqual(
[… old lines 102-122 removed; their content is not preserved in this diff view …]
+class TestConv1d(unittest.TestCase):
+    """Test cases for 1D convolution."""
+
+    def test_basic_channels_last(self):
+        """Test basic Conv1d with channels-last format."""
+        conv = brainstate.nn.Conv1d(in_size=(100, 16), out_channels=32, kernel_size=5)
+        x = jnp.ones((4, 100, 16))
+        y = conv(x)
+
+        self.assertEqual(y.shape, (4, 100, 32))
+        self.assertEqual(conv.in_channels, 16)
+        self.assertEqual(conv.out_channels, 32)
+        self.assertFalse(conv.channel_first)
+
+    def test_basic_channels_first(self):
+        """Test basic Conv1d with channels-first format."""
+        conv = brainstate.nn.Conv1d(in_size=(16, 100), out_channels=32, kernel_size=5, channel_first=True)
+        x = jnp.ones((4, 16, 100))
+        y = conv(x)
+
+        self.assertEqual(y.shape, (4, 32, 100))
+        self.assertEqual(conv.in_channels, 16)
+        self.assertEqual(conv.out_channels, 32)
+        self.assertTrue(conv.channel_first)
+
+    def test_without_batch(self):
+        """Test Conv1d without batch dimension."""
+        conv = brainstate.nn.Conv1d(in_size=(50, 8), out_channels=16, kernel_size=3)
+        x = jnp.ones((50, 8))
+        y = conv(x)
+
+        self.assertEqual(y.shape, (50, 16))
+
+    def test_stride(self):
+        """Test Conv1d with stride."""
+        conv = brainstate.nn.Conv1d(in_size=(100, 8), out_channels=16, kernel_size=3, stride=2, padding='VALID')
+        x = jnp.ones((2, 100, 8))
+        y = conv(x)
+
+        # VALID padding: output = (100 - 3 + 1) / 2 = 49
+        self.assertEqual(y.shape, (2, 49, 16))
+
+    def test_dilation(self):
+        """Test Conv1d with dilated convolution."""
+        conv = brainstate.nn.Conv1d(in_size=(100, 8), out_channels=16, kernel_size=3, rhs_dilation=2)
+        x = jnp.ones((2, 100, 8))
+        y = conv(x)
+
+        self.assertEqual(y.shape, (2, 100, 16))
+
+    def test_groups(self):
+        """Test Conv1d with grouped convolution."""
+        conv = brainstate.nn.Conv1d(in_size=(100, 16), out_channels=32, kernel_size=3, groups=4)
+        x = jnp.ones((2, 100, 16))
+        y = conv(x)
+
+        self.assertEqual(y.shape, (2, 100, 32))
+        self.assertEqual(conv.groups, 4)
+
+    def test_with_bias(self):
+        """Test Conv1d with bias."""
+        conv = brainstate.nn.Conv1d(in_size=(50, 8), out_channels=16, kernel_size=3,
+                                    b_init=brainstate.init.Constant(0.0))
+        x = jnp.ones((2, 50, 8))
+        y = conv(x)
+
+        self.assertEqual(y.shape, (2, 50, 16))
+        self.assertIn('bias', conv.weight.value)
+
+
+class TestConv2d(unittest.TestCase):
+    """Test cases for 2D convolution."""
+
+    def test_basic_channels_last(self):
+        """Test basic Conv2d with channels-last format."""
+        conv = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=64, kernel_size=3)
+        x = jnp.ones((8, 32, 32, 3))
+        y = conv(x)
+
+        self.assertEqual(y.shape, (8, 32, 32, 64))
+        self.assertEqual(conv.in_channels, 3)
+        self.assertEqual(conv.out_channels, 64)
+        self.assertFalse(conv.channel_first)
+
+    def test_basic_channels_first(self):
+        """Test basic Conv2d with channels-first format."""
+        conv = brainstate.nn.Conv2d(in_size=(3, 32, 32), out_channels=64, kernel_size=3, channel_first=True)
+        x = jnp.ones((8, 3, 32, 32))
+        y = conv(x)
+
+        self.assertEqual(y.shape, (8, 64, 32, 32))
+        self.assertEqual(conv.in_channels, 3)
+        self.assertEqual(conv.out_channels, 64)
+        self.assertTrue(conv.channel_first)
+
+    def test_rectangular_kernel(self):
+        """Test Conv2d with rectangular kernel."""
+        conv = brainstate.nn.Conv2d(in_size=(64, 64, 16), out_channels=32, kernel_size=(3, 5))
+        x = jnp.ones((4, 64, 64, 16))
+        y = conv(x)
+
+        self.assertEqual(y.shape, (4, 64, 64, 32))
+        self.assertEqual(conv.kernel_size, (3, 5))
+
+    def test_stride_2d(self):
+        """Test Conv2d with different strides."""
+        conv = brainstate.nn.Conv2d(in_size=(64, 64, 3), out_channels=32, kernel_size=3, stride=(2, 2), padding='VALID')
+        x = jnp.ones((4, 64, 64, 3))
+        y = conv(x)
+
+        # VALID padding: output = (64 - 3 + 1) / 2 = 31
+        self.assertEqual(y.shape, (4, 31, 31, 32))
+
+    def test_depthwise_convolution(self):
+        """Test depthwise convolution (groups = in_channels)."""
+        conv = brainstate.nn.Conv2d(in_size=(32, 32, 16), out_channels=16, kernel_size=3, groups=16)
+        x = jnp.ones((4, 32, 32, 16))
+        y = conv(x)
+
+        self.assertEqual(y.shape, (4, 32, 32, 16))
+        self.assertEqual(conv.groups, 16)
+
+    def test_padding_same_vs_valid(self):
+        """Test different padding modes."""
+        conv_same = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=16, kernel_size=5, padding='SAME')
+        conv_valid = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=16, kernel_size=5, padding='VALID')
+
+        x = jnp.ones((2, 32, 32, 3))
+        y_same = conv_same(x)
+        y_valid = conv_valid(x)
+
+        self.assertEqual(y_same.shape, (2, 32, 32, 16))  # SAME preserves size
+        self.assertEqual(y_valid.shape, (2, 28, 28, 16))  # VALID reduces size
+
+
+class TestConv3d(unittest.TestCase):
+    """Test cases for 3D convolution."""
+
+    def test_basic_channels_last(self):
+        """Test basic Conv3d with channels-last format."""
+        conv = brainstate.nn.Conv3d(in_size=(16, 16, 16, 1), out_channels=32, kernel_size=3)
+        x = jnp.ones((2, 16, 16, 16, 1))
+        y = conv(x)
+
+        self.assertEqual(y.shape, (2, 16, 16, 16, 32))
+        self.assertEqual(conv.in_channels, 1)
+        self.assertEqual(conv.out_channels, 32)
+
+    def test_basic_channels_first(self):
+        """Test basic Conv3d with channels-first format."""
+        conv = brainstate.nn.Conv3d(in_size=(1, 16, 16, 16), out_channels=32, kernel_size=3, channel_first=True)
+        x = jnp.ones((2, 1, 16, 16, 16))
+        y = conv(x)
+
+        self.assertEqual(y.shape, (2, 32, 16, 16, 16))
+        self.assertEqual(conv.in_channels, 1)
+        self.assertEqual(conv.out_channels, 32)
+
+    def test_video_data(self):
+        """Test Conv3d for video data."""
+        conv = brainstate.nn.Conv3d(in_size=(8, 32, 32, 3), out_channels=64, kernel_size=(3, 3, 3))
+        x = jnp.ones((4, 8, 32, 32, 3))  # batch, frames, height, width, channels
+        y = conv(x)
+
+        self.assertEqual(y.shape, (4, 8, 32, 32, 64))
+
+
+class TestScaledWSConv1d(unittest.TestCase):
+    """Test cases for 1D convolution with weight standardization."""
+
+    def test_basic(self):
+        """Test basic ScaledWSConv1d."""
+        conv = brainstate.nn.ScaledWSConv1d(in_size=(100, 16), out_channels=32, kernel_size=5)
+        x = jnp.ones((4, 100, 16))
+        y = conv(x)
+
+        self.assertEqual(y.shape, (4, 100, 32))
+        self.assertIsNotNone(conv.eps)
+
+    def test_with_gain(self):
+        """Test ScaledWSConv1d with gain parameter."""
+        conv = brainstate.nn.ScaledWSConv1d(in_size=(100, 16), out_channels=32, kernel_size=5, ws_gain=True)
+        x = jnp.ones((4, 100, 16))
+        y = conv(x)
+
+        self.assertEqual(y.shape, (4, 100, 32))
+        self.assertIn('gain', conv.weight.value)
+
+    def test_without_gain(self):
+        """Test ScaledWSConv1d without gain parameter."""
+        conv = brainstate.nn.ScaledWSConv1d(in_size=(100, 16), out_channels=32, kernel_size=5, ws_gain=False)
+        x = jnp.ones((4, 100, 16))
+        y = conv(x)
+
+        self.assertEqual(y.shape, (4, 100, 32))
+        self.assertNotIn('gain', conv.weight.value)
+
+    def test_custom_eps(self):
+        """Test ScaledWSConv1d with custom epsilon."""
+        conv = brainstate.nn.ScaledWSConv1d(in_size=(100, 16), out_channels=32, kernel_size=5, eps=1e-5)
+        self.assertEqual(conv.eps, 1e-5)
+
+
+class TestScaledWSConv2d(unittest.TestCase):
+    """Test cases for 2D convolution with weight standardization."""
+
+    def test_basic_channels_last(self):
+        """Test basic ScaledWSConv2d with channels-last format."""
+        conv = brainstate.nn.ScaledWSConv2d(in_size=(64, 64, 3), out_channels=32, kernel_size=3)
+        x = jnp.ones((8, 64, 64, 3))
+        y = conv(x)
+
+        self.assertEqual(y.shape, (8, 64, 64, 32))
+
+    def test_basic_channels_first(self):
+        """Test basic ScaledWSConv2d with channels-first format."""
+        conv = brainstate.nn.ScaledWSConv2d(in_size=(3, 64, 64), out_channels=32, kernel_size=3, channel_first=True)
+        x = jnp.ones((8, 3, 64, 64))
+        y = conv(x)
+
+        self.assertEqual(y.shape, (8, 32, 64, 64))
+
+    def test_with_group_norm_style(self):
+        """Test ScaledWSConv2d for use with group normalization."""
+        conv = brainstate.nn.ScaledWSConv2d(
+            in_size=(32, 32, 16),
+            out_channels=32,
+            kernel_size=3,
+            ws_gain=True,
+            groups=1
+        )
+        x = jnp.ones((4, 32, 32, 16))
+        y = conv(x)
+
+        self.assertEqual(y.shape, (4, 32, 32, 32))
+
+
+class TestScaledWSConv3d(unittest.TestCase):
+    """Test cases for 3D convolution with weight standardization."""
+
+    def test_basic(self):
+        """Test basic ScaledWSConv3d."""
+        conv = brainstate.nn.ScaledWSConv3d(in_size=(8, 16, 16, 3), out_channels=32, kernel_size=3)
+        x = jnp.ones((2, 8, 16, 16, 3))
+        y = conv(x)
+
+        self.assertEqual(y.shape, (2, 8, 16, 16, 32))
+
+    def test_channels_first(self):
+        """Test ScaledWSConv3d with channels-first format."""
+        conv = brainstate.nn.ScaledWSConv3d(in_size=(3, 8, 16, 16), out_channels=32, kernel_size=3, channel_first=True)
+        x = jnp.ones((2, 3, 8, 16, 16))
+        y = conv(x)
+
+        self.assertEqual(y.shape, (2, 32, 8, 16, 16))
+
+
+class TestErrorHandling(unittest.TestCase):
+    """Test error handling and edge cases."""
+
+    def test_invalid_input_shape(self):
+        """Test that invalid input shapes raise appropriate errors."""
+        conv = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=16, kernel_size=3)
+        x_wrong = jnp.ones((8, 32, 32, 16))  # Wrong number of channels
+
+        with self.assertRaises(ValueError):
+            conv(x_wrong)
+
+    def test_invalid_groups(self):
+        """Test that invalid group configurations raise errors."""
+        with self.assertRaises(AssertionError):
+            # out_channels not divisible by groups
+            conv = brainstate.nn.Conv2d(in_size=(32, 32, 16), out_channels=30, kernel_size=3, groups=4)
+
+    def test_dimension_mismatch(self):
+        """Test dimension mismatch detection."""
+        conv = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=16, kernel_size=3)
+        x_1d = jnp.ones((8, 32, 3))  # 1D instead of 2D
+
+        with self.assertRaises(ValueError):
+            conv(x_1d)
+
+
+class TestOutputShapes(unittest.TestCase):
+    """Test output shape calculations."""
+
+    def test_same_padding_preserves_size(self):
+        """Test that SAME padding preserves spatial dimensions when stride=1."""
+        for kernel_size in [3, 5, 7]:
+            conv = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=16, kernel_size=kernel_size, padding='SAME')
+            x = jnp.ones((4, 32, 32, 3))
+            y = conv(x)
+            self.assertEqual(y.shape, (4, 32, 32, 16), f"Failed for kernel_size={kernel_size}")
+
+    def test_valid_padding_reduces_size(self):
+        """Test that VALID padding reduces spatial dimensions."""
+        conv = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=16, kernel_size=5, padding='VALID')
+        x = jnp.ones((4, 32, 32, 3))
+        y = conv(x)
+        # 32 - 5 + 1 = 28
+        self.assertEqual(y.shape, (4, 28, 28, 16))
+
+    def test_output_size_attribute(self):
+        """Test that out_size attribute is correctly computed."""
+        conv_cl = brainstate.nn.Conv2d(in_size=(64, 64, 3), out_channels=32, kernel_size=3, channel_first=False)
+        conv_cf = brainstate.nn.Conv2d(in_size=(3, 64, 64), out_channels=32, kernel_size=3, channel_first=True)
+
+        self.assertEqual(conv_cl.out_size, (64, 64, 32))
+        self.assertEqual(conv_cf.out_size, (32, 64, 64))
+
+
+class TestChannelFormatConsistency(unittest.TestCase):
+    """Test consistency between channels-first and channels-last formats."""
+
+    def test_conv1d_output_channels(self):
+        """Test that output channels are in correct position for both formats."""
+        conv_cl = brainstate.nn.Conv1d(in_size=(100, 16), out_channels=32, kernel_size=3)
+        conv_cf = brainstate.nn.Conv1d(in_size=(16, 100), out_channels=32, kernel_size=3, channel_first=True)
+
+        x_cl = jnp.ones((4, 100, 16))
+        x_cf = jnp.ones((4, 16, 100))
+
+        y_cl = conv_cl(x_cl)
+        y_cf = conv_cf(x_cf)
+
+        # Channels-last: channels in last dimension
+        self.assertEqual(y_cl.shape[-1], 32)
+        # Channels-first: channels in first dimension (after batch)
+        self.assertEqual(y_cf.shape[1], 32)
+
+    def test_conv2d_output_channels(self):
+        """Test 2D output channel positions."""
+        conv_cl = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=64, kernel_size=3)
+        conv_cf = brainstate.nn.Conv2d(in_size=(3, 32, 32), out_channels=64, kernel_size=3, channel_first=True)
+
+        x_cl = jnp.ones((4, 32, 32, 3))
+        x_cf = jnp.ones((4, 3, 32, 32))
+
+        y_cl = conv_cl(x_cl)
+        y_cf = conv_cf(x_cf)
+
+        self.assertEqual(y_cl.shape[-1], 64)
+        self.assertEqual(y_cf.shape[1], 64)
+
+
+class TestReproducibility(unittest.TestCase):
+    """Test reproducibility with fixed seeds."""
+
+    def test_deterministic_output(self):
+        """Test that same seed produces same output."""
+        key = jax.random.PRNGKey(42)
+
+        conv1 = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=16, kernel_size=3)
+        conv2 = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=16, kernel_size=3)
+
+        # Use same random key for input
+        x = jax.random.normal(key, (4, 32, 32, 3))
+
+        # Note: outputs will differ due to different weight initialization
+        # This test just ensures no crashes with random inputs
+        y1 = conv1(x)
+        y2 = conv2(x)
+
+        self.assertEqual(y1.shape, y2.shape)
+
+
+class TestRepr(unittest.TestCase):
+    """Test string representations."""
+
+    def test_conv_repr_channels_last(self):
+        """Test __repr__ for channels-last format."""
+        conv = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=64, kernel_size=3)
+        repr_str = repr(conv)
+
+        self.assertIn('Conv2d', repr_str)
+        self.assertIn('channel_first=False', repr_str)
+        self.assertIn('in_channels=3', repr_str)
+        self.assertIn('out_channels=64', repr_str)
+
+    def test_conv_repr_channels_first(self):
+        """Test __repr__ for channels-first format."""
+        conv = brainstate.nn.Conv2d(in_size=(3, 32, 32), out_channels=64, kernel_size=3, channel_first=True)
+        repr_str = repr(conv)
+
+        self.assertIn('Conv2d', repr_str)
+        self.assertIn('channel_first=True', repr_str)
+
+
+class TestConvTranspose1d(unittest.TestCase):
+    """Test cases for ConvTranspose1d layer."""
+
+    def setUp(self):
+        """Set up test fixtures."""
+        self.in_size = (28, 16)
+        self.out_channels = 8
+        self.kernel_size = 4
+
+    def test_basic_channels_last(self):
+        """Test basic ConvTranspose1d with channels-last format."""
+        conv_t = brainstate.nn.ConvTranspose1d(
+            in_size=self.in_size,
+            out_channels=self.out_channels,
+            kernel_size=self.kernel_size,
+            stride=1
+        )
+        x = jnp.ones((2, 28, 16))
+        y = conv_t(x)
+
+        self.assertEqual(len(y.shape), 3)
+        self.assertEqual(y.shape[0], 2)  # batch size
+        self.assertEqual(y.shape[-1], self.out_channels)
+        self.assertEqual(conv_t.in_channels, 16)
+        self.assertEqual(conv_t.out_channels, 8)
+        self.assertFalse(conv_t.channel_first)
+
+    def test_basic_channels_first(self):
+        """Test basic ConvTranspose1d with channels-first format."""
+        conv_t = brainstate.nn.ConvTranspose1d(
+            in_size=(16, 28),  # (C, L) for channels-first
+            out_channels=self.out_channels,
+            kernel_size=self.kernel_size,
+            stride=1,
+            channel_first=True
+        )
+        x = jnp.ones((2, 16, 28))
+        y = conv_t(x)
+
+        self.assertEqual(len(y.shape), 3)
+        self.assertEqual(y.shape[0], 2)  # batch size
+        self.assertEqual(y.shape[1], self.out_channels)  # channels first
+        self.assertEqual(conv_t.in_channels, 16)
+        self.assertTrue(conv_t.channel_first)
+
+    def test_stride_upsampling(self):
+        """Test transposed convolution with stride for upsampling."""
+        conv_t = brainstate.nn.ConvTranspose1d(
+            in_size=(28, 16),
+            out_channels=8,
+            kernel_size=4,
+            stride=2,
+            padding='SAME'
+        )
+        x = jnp.ones((2, 28, 16))
+        y = conv_t(x)
+
+        # With stride=2, output should be approximately 2x larger
+        self.assertGreater(y.shape[1], x.shape[1])
+
+    def test_with_bias(self):
+        """Test ConvTranspose1d with bias."""
+        conv_t = brainstate.nn.ConvTranspose1d(
+            in_size=(50, 8),
+            out_channels=16,
+            kernel_size=3,
+            b_init=brainstate.init.Constant(0.0)
+        )
+        x = jnp.ones((4, 50, 8))
+        y = conv_t(x)
+
+        self.assertTrue('bias' in conv_t.weight.value)
+        self.assertEqual(y.shape[-1], 16)
+
+    def test_without_batch(self):
+        """Test ConvTranspose1d without batch dimension."""
+        conv_t = brainstate.nn.ConvTranspose1d(
+            in_size=(28, 16),
+            out_channels=8,
+            kernel_size=4
+        )
+        x = jnp.ones((28, 16))
+        y = conv_t(x)
+
+        self.assertEqual(len(y.shape), 2)
+        self.assertEqual(y.shape[-1], 8)
+
+    def test_groups(self):
+        """Test grouped transposed convolution."""
+        conv_t = brainstate.nn.ConvTranspose1d(
+            in_size=(28, 16),
+            out_channels=16,
+            kernel_size=3,
+            groups=4
+        )
+        x = jnp.ones((2, 28, 16))
+        y = conv_t(x)
+
+        self.assertEqual(y.shape[-1], 16)
+
+
+class TestConvTranspose2d(unittest.TestCase):
+    """Test cases for ConvTranspose2d layer."""
+
+    def setUp(self):
+        """Set up test fixtures."""
+        self.in_size = (16, 16, 32)
+        self.out_channels = 16
+        self.kernel_size = 4
+
+    def test_basic_channels_last(self):
+        """Test basic ConvTranspose2d with channels-last format."""
+        conv_t = brainstate.nn.ConvTranspose2d(
+            in_size=self.in_size,
+            out_channels=self.out_channels,
+            kernel_size=self.kernel_size
+        )
+        x = jnp.ones((4, 16, 16, 32))
+        y = conv_t(x)
+
+        self.assertEqual(len(y.shape), 4)
+        self.assertEqual(y.shape[0], 4)  # batch size
+        self.assertEqual(y.shape[-1], self.out_channels)
+        self.assertEqual(conv_t.in_channels, 32)
+        self.assertFalse(conv_t.channel_first)
+
+    def test_basic_channels_first(self):
+        """Test basic ConvTranspose2d with channels-first format."""
+        conv_t = brainstate.nn.ConvTranspose2d(
+            in_size=(32, 16, 16),  # (C, H, W) for channels-first
+            out_channels=self.out_channels,
+            kernel_size=self.kernel_size,
+            channel_first=True
+        )
+        x = jnp.ones((4, 32, 16, 16))
+        y = conv_t(x)
+
+        self.assertEqual(len(y.shape), 4)
+        self.assertEqual(y.shape[1], self.out_channels)  # channels first
+        self.assertTrue(conv_t.channel_first)
+
+    def test_stride_upsampling(self):
+        """Test 2x upsampling with stride=2."""
+        conv_t = brainstate.nn.ConvTranspose2d(
+            in_size=(16, 16, 32),
+            out_channels=16,
+            kernel_size=4,
+            stride=2,
+            padding='SAME'
+        )
+        x = jnp.ones((4, 16, 16, 32))
+        y = conv_t(x)
+
+        # With stride=2, output should be approximately 2x larger in each spatial dimension
+        self.assertGreater(y.shape[1], x.shape[1])
+        self.assertGreater(y.shape[2], x.shape[2])
+
+    def test_rectangular_kernel(self):
+        """Test ConvTranspose2d with rectangular kernel."""
+        conv_t = brainstate.nn.ConvTranspose2d(
+            in_size=(16, 16, 32),
+            out_channels=16,
+            kernel_size=(3, 5),
+            stride=1
+        )
+        x = jnp.ones((2, 16, 16, 32))
+        y = conv_t(x)
+
+        self.assertEqual(conv_t.kernel_size, (3, 5))
+        self.assertEqual(y.shape[-1], 16)
+
+    def test_padding_valid(self):
+        """Test ConvTranspose2d with VALID padding."""
+        conv_t = brainstate.nn.ConvTranspose2d(
+            in_size=(16, 16, 32),
+            out_channels=16,
+            kernel_size=4,
+            stride=2,
+            padding='VALID'
+        )
+        x = jnp.ones((2, 16, 16, 32))
+        y = conv_t(x)
+
+        # VALID padding means no padding, output computed by formula:
+        # out = (in - 1) * stride + kernel
+        # out = (16 - 1) * 2 + 4 = 34 (but JAX may compute it slightly differently)
+        # At minimum, it should upsample
+        self.assertGreater(y.shape[1], 16)
+
+    def test_groups(self):
+        """Test grouped transposed convolution."""
+        conv_t = brainstate.nn.ConvTranspose2d(
+            in_size=(16, 16, 32),
+            out_channels=32,
+            kernel_size=3,
+            groups=4
+        )
+        x = jnp.ones((2, 16, 16, 32))
+        y = conv_t(x)
+
+        self.assertEqual(y.shape[-1], 32)
+
+
+class TestConvTranspose3d(unittest.TestCase):
+    """Test cases for ConvTranspose3d layer."""
+
+    def setUp(self):
+        """Set up test fixtures."""
+        self.in_size = (8, 8, 8, 16)
+        self.out_channels = 8
+        self.kernel_size = 4
+
+    def test_basic_channels_last(self):
+        """Test basic ConvTranspose3d with channels-last format."""
+        conv_t = brainstate.nn.ConvTranspose3d(
+            in_size=self.in_size,
+            out_channels=self.out_channels,
+            kernel_size=self.kernel_size
+        )
+        x = jnp.ones((2, 8, 8, 8, 16))
+        y = conv_t(x)
+
+        self.assertEqual(len(y.shape), 5)
+        self.assertEqual(y.shape[0], 2)  # batch size
+        self.assertEqual(y.shape[-1], self.out_channels)
+        self.assertEqual(conv_t.in_channels, 16)
+
+    def test_basic_channels_first(self):
+        """Test basic ConvTranspose3d with channels-first format."""
+        conv_t = brainstate.nn.ConvTranspose3d(
+            in_size=(16, 8, 8, 8),  # (C, H, W, D) for channels-first
+            out_channels=self.out_channels,
+            kernel_size=self.kernel_size,
+            channel_first=True
+        )
+        x = jnp.ones((2, 16, 8, 8, 8))
+        y = conv_t(x)
+
+        self.assertEqual(len(y.shape), 5)
+        self.assertEqual(y.shape[1], self.out_channels)  # channels first
+        self.assertTrue(conv_t.channel_first)
+
+    def test_stride_upsampling(self):
+        """Test 3D upsampling with stride=2."""
+        conv_t = brainstate.nn.ConvTranspose3d(
+            in_size=(8, 8, 8, 16),
+            out_channels=8,
+            kernel_size=4,
+            stride=2,
+            padding='SAME'
+        )
+        x = jnp.ones((2, 8, 8, 8, 16))
+        y = conv_t(x)
+
+        # With stride=2, output should be approximately 2x larger
+        self.assertGreater(y.shape[1], x.shape[1])
+        self.assertGreater(y.shape[2], x.shape[2])
+        self.assertGreater(y.shape[3], x.shape[3])
+
+
+class TestErrorHandlingConvTranspose(unittest.TestCase):
+    """Test error handling for transposed convolutions."""
+
+    def test_invalid_groups(self):
+        """Test that invalid groups raises assertion error."""
+        with self.assertRaises(AssertionError):
+            brainstate.nn.ConvTranspose2d(
+                in_size=(16, 16, 32),
+                out_channels=15,  # Not divisible by groups
                 kernel_size=3,
-
-                padding="SAME",
-                w_initializer=brainstate.init.Constant(),
-                b_initializer=brainstate.init.Constant() if use_bias else None,
+                groups=4
             )
-            out = net(data)
-            self.assertEqual(out.shape, (1, 3, 1))
-            out = jnp.squeeze(out, axis=(0, 2))
-            expected_out = jnp.asarray([2, 3, 2])
-            if use_bias:
-                expected_out += 1
-            self.assertTrue(jnp.allclose(out, expected_out, rtol=1e-5))
-
-
-@pytest.mark.skip(reason="not implemented yet")
-class TestConvTranspose2d(parameterized.TestCase):
-    def test_conv_transpose(self):
-
-        x = jnp.ones((1, 8, 8, 3))
-        for use_bias in [True, False]:
-            conv_transpose_module = brainstate.nn.ConvTranspose2d(
-                in_channels=3,
-                out_channels=4,
-                kernel_size=(3, 3),
-                padding='VALID',
-                w_initializer=brainstate.init.Constant(),
-                b_initializer=brainstate.init.Constant() if use_bias else None,
-            )
-            self.assertEqual(conv_transpose_module.w.shape, (3, 3, 3, 4))
-            y = conv_transpose_module(x)
-            print(y.shape)
-
-    def test_single_input_masked_conv_transpose(self):
-
-        x = jnp.ones((1, 8, 8, 3))
-        m = jnp.tril(jnp.ones((3, 3, 3, 4)))
-        conv_transpose_module = brainstate.nn.ConvTranspose2d(
-            in_channels=3,
-            out_channels=4,
-            kernel_size=(3, 3),
-            padding='VALID',
-            mask=m,
-            w_initializer=brainstate.init.Constant(),
-        )
-        y = conv_transpose_module(x)
-        print(y.shape)
-
-    def test_computation_padding_same(self):
-
-        x = jnp.ones((1, 8, 8, 3))
-        for use_bias in [True, False]:
-            conv_transpose_module = brainstate.nn.ConvTranspose2d(
-                in_channels=3,
-                out_channels=4,
-                kernel_size=(3, 3),
-                stride=1,
-                padding='SAME',
-                w_initializer=brainstate.init.Constant(),
-                b_initializer=brainstate.init.Constant() if use_bias else None,
-            )
-            y = conv_transpose_module(x)
-            print(y.shape)
-
-
-@pytest.mark.skip(reason="not implemented yet")
-class TestConvTranspose3d(parameterized.TestCase):
-    def test_conv_transpose(self):
-
-        x = jnp.ones((1, 8, 8, 8, 3))
-        for use_bias in [True, False]:
-            conv_transpose_module = brainstate.nn.ConvTranspose3d(
-                in_channels=3,
-                out_channels=4,
-                kernel_size=(3, 3, 3),
-                padding='VALID',
-                w_initializer=brainstate.init.Constant(),
-                b_initializer=brainstate.init.Constant() if use_bias else None,
-            )
-            y = conv_transpose_module(x)
-            print(y.shape)
-
-    def test_single_input_masked_conv_transpose(self):
-
-        x = jnp.ones((1, 8, 8, 8, 3))
-        m = jnp.tril(jnp.ones((3, 3, 3, 3, 4)))
-        conv_transpose_module = brainstate.nn.ConvTranspose3d(
-            in_channels=3,
-            out_channels=4,
-            kernel_size=(3, 3, 3),
-            padding='VALID',
-            mask=m,
-            w_initializer=brainstate.init.Constant(),
-        )
-        y = conv_transpose_module(x)
-        print(y.shape)
-
-    def test_computation_padding_same(self):
-
-        x = jnp.ones((1, 8, 8, 8, 3))
-        for use_bias in [True, False]:
-            conv_transpose_module = brainstate.nn.ConvTranspose3d(
-                in_channels=3,
-                out_channels=4,
-                kernel_size=(3, 3, 3),
-                stride=1,
-                padding='SAME',
-                w_initializer=brainstate.init.Constant(),
-                b_initializer=brainstate.init.Constant() if use_bias else None,
-            )
-            y = conv_transpose_module(x)
-            print(y.shape)

+    def test_dimension_mismatch(self):
+        """Test that wrong input dimensions raise error."""
+        conv_t = brainstate.nn.ConvTranspose2d(
+            in_size=(16, 16, 32),
+            out_channels=16,
+            kernel_size=3
+        )
+        x = jnp.ones((2, 16, 16, 16))  # Wrong number of channels
+
+        with self.assertRaises(ValueError):
+            conv_t(x)

-
-
+    def test_invalid_input_shape(self):
+        """Test that invalid input shape raises error."""
+        conv_t = brainstate.nn.ConvTranspose1d(
+            in_size=(28, 16),
+            out_channels=8,
+            kernel_size=3
+        )
+        x = jnp.ones((2, 2, 28, 16))  # Too many dimensions
+
+        with self.assertRaises(ValueError):
+            conv_t(x)
+
+
+class TestOutputShapesConvTranspose(unittest.TestCase):
+    """Test output shape computation for transposed convolutions."""
+
+    def test_out_size_attribute_1d(self):
+        """Test that out_size attribute is correctly computed for 1D."""
+        conv_t = brainstate.nn.ConvTranspose1d(
+            in_size=(28, 16),
+            out_channels=8,
+            kernel_size=4,
+            stride=2
+        )
+
+        self.assertIsNotNone(conv_t.out_size)
+        self.assertEqual(len(conv_t.out_size), 2)
+
+    def test_out_size_attribute_2d(self):
+        """Test that out_size attribute is correctly computed for 2D."""
+        conv_t = brainstate.nn.ConvTranspose2d(
+            in_size=(16, 16, 32),
+            out_channels=16,
+            kernel_size=4,
+            stride=2
+        )
+
+        self.assertIsNotNone(conv_t.out_size)
+        self.assertEqual(len(conv_t.out_size), 3)
+
+    def test_upsampling_factor(self):
+        """Test that stride=2 approximately doubles spatial dimensions."""
+        conv_t = brainstate.nn.ConvTranspose2d(
+            in_size=(16, 16, 32),
+            out_channels=16,
+            kernel_size=4,
+            stride=2,
+            padding='SAME'
+        )
+        x = jnp.ones((2, 16, 16, 32))
+        y = conv_t(x)
+
+        # For SAME padding and stride=2, output should be approximately 2x input
+        self.assertGreaterEqual(y.shape[1], 28)
+        self.assertGreaterEqual(y.shape[2], 28)
+
+
+class TestChannelFormatConsistencyConvTranspose(unittest.TestCase):
+    """Test consistency between different channel formats."""
+
+    def test_conv_transpose_1d_output_channels(self):
+        """Test that output channels are in correct position for both formats."""
+        # Channels-last
+        conv_t_last = brainstate.nn.ConvTranspose1d(
+            in_size=(28, 16),
+            out_channels=8,
+            kernel_size=3
+        )
+        x_last = jnp.ones((2, 28, 16))
+        y_last = conv_t_last(x_last)
+        self.assertEqual(y_last.shape[-1], 8)
+
+        # Channels-first
+        conv_t_first = brainstate.nn.ConvTranspose1d(
+            in_size=(16, 28),
+            out_channels=8,
+            kernel_size=3,
+            channel_first=True
+        )
+        x_first = jnp.ones((2, 16, 28))
+        y_first = conv_t_first(x_first)
+        self.assertEqual(y_first.shape[1], 8)
+
+    def test_conv_transpose_2d_output_channels(self):
+        """Test that output channels are in correct position for both formats."""
+        # Channels-last
+        conv_t_last = brainstate.nn.ConvTranspose2d(
+            in_size=(16, 16, 32),
+            out_channels=16,
+            kernel_size=3
+        )
+        x_last = jnp.ones((2, 16, 16, 32))
+        y_last = conv_t_last(x_last)
+        self.assertEqual(y_last.shape[-1], 16)
+
+        # Channels-first
+        conv_t_first = brainstate.nn.ConvTranspose2d(
+            in_size=(32, 16, 16),
+            out_channels=16,
+            kernel_size=3,
+            channel_first=True
+        )
+        x_first = jnp.ones((2, 32, 16, 16))
+        y_first = conv_t_first(x_first)
+        self.assertEqual(y_first.shape[1], 16)
+
+
+class TestReproducibilityConvTranspose(unittest.TestCase):
+    """Test deterministic behavior of transposed convolutions."""
+
+    def test_deterministic_output(self):
+        """Test that same input produces same output."""
+        conv_t = brainstate.nn.ConvTranspose2d(
+            in_size=(16, 16, 32),
+            out_channels=16,
+            kernel_size=3
+        )
+        x = jnp.ones((2, 16, 16, 32))
+
+        y1 = conv_t(x)
+        y2 = conv_t(x)
+
+        self.assertTrue(jnp.allclose(y1, y2))
+
+
+class TestKernelShapeConvTranspose(unittest.TestCase):
+    """Test kernel shape computation for transposed convolutions."""
+
+    def test_kernel_shape_1d(self):
+        """Test that kernel shape is correct for transposed conv 1D."""
+        conv_t = brainstate.nn.ConvTranspose1d(
+            in_size=(28, 16),
+            out_channels=8,
+            kernel_size=4,
+            groups=2
+        )
+        # For transpose conv: (kernel_size, out_channels, in_channels // groups)
+        expected_shape = (4, 8, 16 // 2)
+        self.assertEqual(conv_t.kernel_shape, expected_shape)
+
+    def test_kernel_shape_2d(self):
+        """Test that kernel shape is correct for transposed conv 2D."""
+        conv_t = brainstate.nn.ConvTranspose2d(
+            in_size=(16, 16, 32),
+            out_channels=16,
+            kernel_size=4,
+            groups=4
+        )
+        # For transpose conv: (kernel_h, kernel_w, out_channels, in_channels // groups)
+        expected_shape = (4, 4, 16, 32 // 4)
+        self.assertEqual(conv_t.kernel_shape, expected_shape)