brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,238 +1,849 @@
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
1
16
  # -*- coding: utf-8 -*-
2
17
 
18
+ import unittest
19
+
20
+ import jax
3
21
  import jax.numpy as jnp
4
- import pytest
5
- from absl.testing import absltest
6
- from absl.testing import parameterized
7
22
 
8
23
  import brainstate
9
24
 
10
25
 
11
# Removed in 0.2.0: the 0.1.10 TestConv class (absl parameterized.TestCase based);
# the 0.2.0 file instead defines per-dimension TestConv1d / TestConv2d / TestConv3d.
- class TestConv(parameterized.TestCase):
12
# Grouped (groups=4) Conv2d, stride (2, 1), VALID, 3x3 kernel on a 200x198 image:
# expected spatial output is ceil((200 - 3 + 1) / 2) = 99 by 198 - 3 + 1 = 196.
- def test_Conv2D_img(self):
13
- img = jnp.zeros((2, 200, 198, 4))
14
- for k in range(4):
15
- x = 30 + 60 * k
16
- y = 20 + 60 * k
17
- img = img.at[0, x:x + 10, y:y + 10, k].set(1.0)
18
- img = img.at[1, x:x + 20, y:y + 20, k].set(3.0)
19
-
20
- net = brainstate.nn.Conv2d((200, 198, 4), out_channels=32, kernel_size=(3, 3),
21
- stride=(2, 1), padding='VALID', groups=4)
22
- out = net(img)
23
- print("out shape: ", out.shape)
24
- self.assertEqual(out.shape, (2, 99, 196, 32))
25
- # print("First output channel:")
26
- # plt.figure(figsize=(10, 10))
27
- # plt.imshow(np.array(img)[0, :, :, 0])
28
- # plt.show()
29
-
30
# Conv1d with the default padding: the test expects length 5 to be preserved.
- def test_conv1D(self):
31
- model = brainstate.nn.Conv1d((5, 3), out_channels=32, kernel_size=(3,))
32
- input = jnp.ones((2, 5, 3))
33
- out = model(input)
34
- print("out shape: ", out.shape)
35
- self.assertEqual(out.shape, (2, 5, 32))
36
- # print("First output channel:")
37
- # plt.figure(figsize=(10, 10))
38
- # plt.imshow(np.array(out)[0, :, :])
39
- # plt.show()
40
-
41
# Conv2d with the default padding: the test expects 5x5 spatial size to be preserved.
- def test_conv2D(self):
42
- model = brainstate.nn.Conv2d((5, 5, 3), out_channels=32, kernel_size=(3, 3))
43
- input = jnp.ones((2, 5, 5, 3))
44
-
45
- out = model(input)
46
- print("out shape: ", out.shape)
47
- self.assertEqual(out.shape, (2, 5, 5, 32))
48
-
49
# Conv3d with the default padding: the test expects 5x5x5 spatial size to be preserved.
- def test_conv3D(self):
50
- model = brainstate.nn.Conv3d((5, 5, 5, 3), out_channels=32, kernel_size=(3, 3, 3))
51
- input = jnp.ones((2, 5, 5, 5, 3))
52
- out = model(input)
53
- print("out shape: ", out.shape)
54
- self.assertEqual(out.shape, (2, 5, 5, 5, 32))
55
-
56
-
57
- @pytest.mark.skip(reason="not implemented yet")
58
- class TestConvTranspose1d(parameterized.TestCase):
59
- def test_conv_transpose(self):
60
-
61
- x = jnp.ones((1, 8, 3))
62
- for use_bias in [True, False]:
63
- conv_transpose_module = brainstate.nn.ConvTranspose1d(
64
- in_channels=3,
65
- out_channels=4,
66
- kernel_size=(3,),
67
- padding='VALID',
68
- w_initializer=brainstate.init.Constant(1.),
69
- b_initializer=brainstate.init.Constant(1.) if use_bias else None,
70
- )
71
- self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4))
72
- y = conv_transpose_module(x)
73
- print(y.shape)
74
- correct_ans = jnp.array([[[4., 4., 4., 4.],
75
- [7., 7., 7., 7.],
76
- [10., 10., 10., 10.],
77
- [10., 10., 10., 10.],
78
- [10., 10., 10., 10.],
79
- [10., 10., 10., 10.],
80
- [10., 10., 10., 10.],
81
- [10., 10., 10., 10.],
82
- [7., 7., 7., 7.],
83
- [4., 4., 4., 4.]]])
84
- if not use_bias:
85
- correct_ans -= 1.
86
- self.assertTrue(jnp.allclose(y, correct_ans))
87
-
88
- def test_single_input_masked_conv_transpose(self):
89
-
90
- x = jnp.ones((1, 8, 3))
91
- m = jnp.tril(jnp.ones((3, 3, 4)))
92
- conv_transpose_module = brainstate.nn.ConvTranspose1d(
93
- in_channels=3,
94
- out_channels=4,
95
- kernel_size=(3,),
96
- padding='VALID',
97
- mask=m,
98
- w_initializer=brainstate.init.Constant(),
99
- b_initializer=brainstate.init.Constant(),
100
- )
101
- self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4))
102
- y = conv_transpose_module(x)
103
- print(y.shape)
104
- correct_ans = jnp.array([[[4., 3., 2., 1.],
105
- [7., 5., 3., 1.],
106
- [10., 7., 4., 1.],
107
- [10., 7., 4., 1.],
108
- [10., 7., 4., 1.],
109
- [10., 7., 4., 1.],
110
- [10., 7., 4., 1.],
111
- [10., 7., 4., 1.],
112
- [7., 5., 3., 1.],
113
- [4., 3., 2., 1.]]])
114
- self.assertTrue(jnp.allclose(y, correct_ans))
115
-
116
- def test_computation_padding_same(self):
117
-
118
- data = jnp.ones([1, 3, 1])
119
- for use_bias in [True, False]:
120
- net = brainstate.nn.ConvTranspose1d(
121
- in_channels=1,
122
- out_channels=1,
26
class TestConv1d(unittest.TestCase):
    """Shape and configuration checks for ``brainstate.nn.Conv1d``."""

    def test_basic_channels_last(self):
        """Default (channels-last) layout maps (N, L, C_in) to (N, L, C_out)."""
        layer = brainstate.nn.Conv1d(in_size=(100, 16), out_channels=32, kernel_size=5)
        outputs = layer(jnp.ones((4, 100, 16)))

        self.assertEqual(outputs.shape, (4, 100, 32))
        self.assertEqual(layer.in_channels, 16)
        self.assertEqual(layer.out_channels, 32)
        self.assertFalse(layer.channel_first)

    def test_basic_channels_first(self):
        """``channel_first=True`` maps (N, C_in, L) to (N, C_out, L)."""
        layer = brainstate.nn.Conv1d(
            in_size=(16, 100),
            out_channels=32,
            kernel_size=5,
            channel_first=True,
        )
        outputs = layer(jnp.ones((4, 16, 100)))

        self.assertEqual(outputs.shape, (4, 32, 100))
        self.assertEqual(layer.in_channels, 16)
        self.assertEqual(layer.out_channels, 32)
        self.assertTrue(layer.channel_first)

    def test_without_batch(self):
        """An unbatched (L, C_in) input produces an unbatched (L, C_out) output."""
        layer = brainstate.nn.Conv1d(in_size=(50, 8), out_channels=16, kernel_size=3)
        outputs = layer(jnp.ones((50, 8)))

        self.assertEqual(outputs.shape, (50, 16))

    def test_stride(self):
        """Stride 2 with VALID padding shortens the sequence to 49 steps."""
        layer = brainstate.nn.Conv1d(
            in_size=(100, 8),
            out_channels=16,
            kernel_size=3,
            stride=2,
            padding='VALID',
        )
        outputs = layer(jnp.ones((2, 100, 8)))

        # VALID: ceil((100 - 3 + 1) / 2) == 49 output positions.
        self.assertEqual(outputs.shape, (2, 49, 16))

    def test_dilation(self):
        """A dilated kernel (rhs_dilation=2) keeps the length under default padding."""
        layer = brainstate.nn.Conv1d(in_size=(100, 8), out_channels=16, kernel_size=3, rhs_dilation=2)
        outputs = layer(jnp.ones((2, 100, 8)))

        self.assertEqual(outputs.shape, (2, 100, 16))

    def test_groups(self):
        """Grouped convolution with 4 groups produces the requested channel count."""
        layer = brainstate.nn.Conv1d(in_size=(100, 16), out_channels=32, kernel_size=3, groups=4)
        outputs = layer(jnp.ones((2, 100, 16)))

        self.assertEqual(outputs.shape, (2, 100, 32))
        self.assertEqual(layer.groups, 4)

    def test_with_bias(self):
        """Passing ``b_init`` adds a 'bias' entry to the layer's weight dict."""
        bias_init = brainstate.init.Constant(0.0)
        layer = brainstate.nn.Conv1d(in_size=(50, 8), out_channels=16, kernel_size=3, b_init=bias_init)
        outputs = layer(jnp.ones((2, 50, 8)))

        self.assertEqual(outputs.shape, (2, 50, 16))
        self.assertIn('bias', layer.weight.value)
94
+
95
+
96
class TestConv2d(unittest.TestCase):
    """Shape and configuration checks for ``brainstate.nn.Conv2d``."""

    def test_basic_channels_last(self):
        """Default layout maps (N, H, W, C_in) to (N, H, W, C_out)."""
        layer = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=64, kernel_size=3)
        outputs = layer(jnp.ones((8, 32, 32, 3)))

        self.assertEqual(outputs.shape, (8, 32, 32, 64))
        self.assertEqual(layer.in_channels, 3)
        self.assertEqual(layer.out_channels, 64)
        self.assertFalse(layer.channel_first)

    def test_basic_channels_first(self):
        """``channel_first=True`` maps (N, C_in, H, W) to (N, C_out, H, W)."""
        layer = brainstate.nn.Conv2d(
            in_size=(3, 32, 32),
            out_channels=64,
            kernel_size=3,
            channel_first=True,
        )
        outputs = layer(jnp.ones((8, 3, 32, 32)))

        self.assertEqual(outputs.shape, (8, 64, 32, 32))
        self.assertEqual(layer.in_channels, 3)
        self.assertEqual(layer.out_channels, 64)
        self.assertTrue(layer.channel_first)

    def test_rectangular_kernel(self):
        """A non-square (3, 5) kernel is stored as given and keeps spatial size."""
        layer = brainstate.nn.Conv2d(in_size=(64, 64, 16), out_channels=32, kernel_size=(3, 5))
        outputs = layer(jnp.ones((4, 64, 64, 16)))

        self.assertEqual(outputs.shape, (4, 64, 64, 32))
        self.assertEqual(layer.kernel_size, (3, 5))

    def test_stride_2d(self):
        """Stride (2, 2) with VALID padding yields a 31x31 map from 64x64."""
        layer = brainstate.nn.Conv2d(
            in_size=(64, 64, 3),
            out_channels=32,
            kernel_size=3,
            stride=(2, 2),
            padding='VALID',
        )
        outputs = layer(jnp.ones((4, 64, 64, 3)))

        # VALID: ceil((64 - 3 + 1) / 2) == 31 along each spatial axis.
        self.assertEqual(outputs.shape, (4, 31, 31, 32))

    def test_depthwise_convolution(self):
        """Depthwise convolution: one group per input channel."""
        layer = brainstate.nn.Conv2d(in_size=(32, 32, 16), out_channels=16, kernel_size=3, groups=16)
        outputs = layer(jnp.ones((4, 32, 32, 16)))

        self.assertEqual(outputs.shape, (4, 32, 32, 16))
        self.assertEqual(layer.groups, 16)

    def test_padding_same_vs_valid(self):
        """SAME keeps the 32x32 map; VALID with a 5x5 kernel shrinks it to 28x28."""
        same_layer = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=16, kernel_size=5, padding='SAME')
        valid_layer = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=16, kernel_size=5, padding='VALID')

        inputs = jnp.ones((2, 32, 32, 3))
        same_out = same_layer(inputs)
        valid_out = valid_layer(inputs)

        self.assertEqual(same_out.shape, (2, 32, 32, 16))   # size preserved
        self.assertEqual(valid_out.shape, (2, 28, 28, 16))  # 32 - 5 + 1 == 28
159
+
160
+
161
class TestConv3d(unittest.TestCase):
    """Shape and configuration checks for ``brainstate.nn.Conv3d``."""

    def test_basic_channels_last(self):
        """Default layout maps (N, D, H, W, C_in) to (N, D, H, W, C_out)."""
        layer = brainstate.nn.Conv3d(in_size=(16, 16, 16, 1), out_channels=32, kernel_size=3)
        outputs = layer(jnp.ones((2, 16, 16, 16, 1)))

        self.assertEqual(outputs.shape, (2, 16, 16, 16, 32))
        self.assertEqual(layer.in_channels, 1)
        self.assertEqual(layer.out_channels, 32)

    def test_basic_channels_first(self):
        """``channel_first=True`` maps (N, C_in, D, H, W) to (N, C_out, D, H, W)."""
        layer = brainstate.nn.Conv3d(
            in_size=(1, 16, 16, 16),
            out_channels=32,
            kernel_size=3,
            channel_first=True,
        )
        outputs = layer(jnp.ones((2, 1, 16, 16, 16)))

        self.assertEqual(outputs.shape, (2, 32, 16, 16, 16))
        self.assertEqual(layer.in_channels, 1)
        self.assertEqual(layer.out_channels, 32)

    def test_video_data(self):
        """Video-shaped input: (batch, frames, height, width, channels)."""
        layer = brainstate.nn.Conv3d(in_size=(8, 32, 32, 3), out_channels=64, kernel_size=(3, 3, 3))
        clip = jnp.ones((4, 8, 32, 32, 3))

        self.assertEqual(layer(clip).shape, (4, 8, 32, 32, 64))
191
+
192
+
193
class TestScaledWSConv1d(unittest.TestCase):
    """Checks for the weight-standardized ``brainstate.nn.ScaledWSConv1d``."""

    def test_basic(self):
        """Default construction produces the expected shape and sets ``eps``."""
        layer = brainstate.nn.ScaledWSConv1d(in_size=(100, 16), out_channels=32, kernel_size=5)
        outputs = layer(jnp.ones((4, 100, 16)))

        self.assertEqual(outputs.shape, (4, 100, 32))
        self.assertIsNotNone(layer.eps)

    def test_with_gain(self):
        """``ws_gain=True`` registers a 'gain' entry in the weight dict."""
        layer = brainstate.nn.ScaledWSConv1d(
            in_size=(100, 16), out_channels=32, kernel_size=5, ws_gain=True
        )
        outputs = layer(jnp.ones((4, 100, 16)))

        self.assertEqual(outputs.shape, (4, 100, 32))
        self.assertIn('gain', layer.weight.value)

    def test_without_gain(self):
        """``ws_gain=False`` leaves 'gain' out of the weight dict."""
        layer = brainstate.nn.ScaledWSConv1d(
            in_size=(100, 16), out_channels=32, kernel_size=5, ws_gain=False
        )
        outputs = layer(jnp.ones((4, 100, 16)))

        self.assertEqual(outputs.shape, (4, 100, 32))
        self.assertNotIn('gain', layer.weight.value)

    def test_custom_eps(self):
        """A user-supplied ``eps`` is stored unchanged."""
        custom_eps = 1e-5
        layer = brainstate.nn.ScaledWSConv1d(
            in_size=(100, 16), out_channels=32, kernel_size=5, eps=custom_eps
        )
        self.assertEqual(layer.eps, 1e-5)
227
+
228
+
229
class TestScaledWSConv2d(unittest.TestCase):
    """Checks for the weight-standardized ``brainstate.nn.ScaledWSConv2d``."""

    def test_basic_channels_last(self):
        """Default layout: (N, H, W, C_in) to (N, H, W, C_out)."""
        layer = brainstate.nn.ScaledWSConv2d(in_size=(64, 64, 3), out_channels=32, kernel_size=3)
        self.assertEqual(layer(jnp.ones((8, 64, 64, 3))).shape, (8, 64, 64, 32))

    def test_basic_channels_first(self):
        """Channels-first layout: (N, C_in, H, W) to (N, C_out, H, W)."""
        layer = brainstate.nn.ScaledWSConv2d(
            in_size=(3, 64, 64), out_channels=32, kernel_size=3, channel_first=True
        )
        self.assertEqual(layer(jnp.ones((8, 3, 64, 64))).shape, (8, 32, 64, 64))

    def test_with_group_norm_style(self):
        """Gain enabled, single group: the configuration typically paired with group norm."""
        layer = brainstate.nn.ScaledWSConv2d(
            in_size=(32, 32, 16),
            out_channels=32,
            kernel_size=3,
            ws_gain=True,
            groups=1,
        )
        outputs = layer(jnp.ones((4, 32, 32, 16)))

        self.assertEqual(outputs.shape, (4, 32, 32, 32))
261
+
262
+
263
class TestScaledWSConv3d(unittest.TestCase):
    """Checks for the weight-standardized ``brainstate.nn.ScaledWSConv3d``."""

    def test_basic(self):
        """Default layout: (N, D, H, W, C_in) to (N, D, H, W, C_out)."""
        layer = brainstate.nn.ScaledWSConv3d(in_size=(8, 16, 16, 3), out_channels=32, kernel_size=3)
        self.assertEqual(layer(jnp.ones((2, 8, 16, 16, 3))).shape, (2, 8, 16, 16, 32))

    def test_channels_first(self):
        """Channels-first layout: (N, C_in, D, H, W) to (N, C_out, D, H, W)."""
        layer = brainstate.nn.ScaledWSConv3d(
            in_size=(3, 8, 16, 16), out_channels=32, kernel_size=3, channel_first=True
        )
        self.assertEqual(layer(jnp.ones((2, 3, 8, 16, 16))).shape, (2, 32, 8, 16, 16))
281
+
282
+
283
class TestErrorHandling(unittest.TestCase):
    """Error handling for malformed inputs and invalid layer configurations."""

    def test_invalid_input_shape(self):
        """An input whose channel count differs from ``in_size`` raises ValueError."""
        conv = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=16, kernel_size=3)
        x_wrong = jnp.ones((8, 32, 32, 16))  # 16 channels, but the layer declares 3

        with self.assertRaises(ValueError):
            conv(x_wrong)

    def test_invalid_groups(self):
        """Construction fails when ``out_channels`` is not divisible by ``groups``."""
        # in_channels = 16 is divisible by groups = 4, but out_channels = 30 is
        # not, so only the output-channel constraint is violated. The layer is
        # never used afterwards, so the constructor result is not bound to a name.
        # NOTE(review): expecting AssertionError suggests the check is a bare
        # `assert`, which is stripped under `python -O` -- confirm in the layer code.
        with self.assertRaises(AssertionError):
            brainstate.nn.Conv2d(in_size=(32, 32, 16), out_channels=30, kernel_size=3, groups=4)

    def test_dimension_mismatch(self):
        """A rank-mismatched input (1D data for a 2D conv) raises ValueError."""
        conv = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=16, kernel_size=3)
        x_1d = jnp.ones((8, 32, 3))  # 1D instead of 2D spatial input

        with self.assertRaises(ValueError):
            conv(x_1d)
307
+
308
+
309
class TestOutputShapes(unittest.TestCase):
    """Spatial output-shape arithmetic for SAME / VALID padding."""

    def test_same_padding_preserves_size(self):
        """SAME padding with stride 1 preserves spatial size for several kernel sizes."""
        x = jnp.ones((4, 32, 32, 3))  # identical for every kernel size; build once
        for kernel_size in (3, 5, 7):
            # subTest reports every failing kernel size instead of stopping
            # at the first one.
            with self.subTest(kernel_size=kernel_size):
                conv = brainstate.nn.Conv2d(
                    in_size=(32, 32, 3), out_channels=16, kernel_size=kernel_size, padding='SAME'
                )
                y = conv(x)
                self.assertEqual(y.shape, (4, 32, 32, 16), f"Failed for kernel_size={kernel_size}")

    def test_valid_padding_reduces_size(self):
        """VALID padding shrinks each spatial axis by ``kernel_size - 1``."""
        conv = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=16, kernel_size=5, padding='VALID')
        x = jnp.ones((4, 32, 32, 3))
        y = conv(x)
        # 32 - 5 + 1 = 28
        self.assertEqual(y.shape, (4, 28, 28, 16))

    def test_output_size_attribute(self):
        """``out_size`` places the channel axis according to ``channel_first``."""
        conv_cl = brainstate.nn.Conv2d(in_size=(64, 64, 3), out_channels=32, kernel_size=3, channel_first=False)
        conv_cf = brainstate.nn.Conv2d(in_size=(3, 64, 64), out_channels=32, kernel_size=3, channel_first=True)

        self.assertEqual(conv_cl.out_size, (64, 64, 32))
        self.assertEqual(conv_cf.out_size, (32, 64, 64))
335
+
336
+
337
class TestChannelFormatConsistency(unittest.TestCase):
    """Where the output-channel axis lands for channels-last vs channels-first layers."""

    def test_conv1d_output_channels(self):
        """Conv1d: channels are the last axis by default, axis 1 with ``channel_first``."""
        last_layer = brainstate.nn.Conv1d(in_size=(100, 16), out_channels=32, kernel_size=3)
        first_layer = brainstate.nn.Conv1d(
            in_size=(16, 100), out_channels=32, kernel_size=3, channel_first=True
        )

        last_out = last_layer(jnp.ones((4, 100, 16)))
        first_out = first_layer(jnp.ones((4, 16, 100)))

        self.assertEqual(last_out.shape[-1], 32)  # (batch, length, channels)
        self.assertEqual(first_out.shape[1], 32)  # (batch, channels, length)

    def test_conv2d_output_channels(self):
        """Conv2d: same axis placement rule as the 1D case."""
        last_layer = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=64, kernel_size=3)
        first_layer = brainstate.nn.Conv2d(
            in_size=(3, 32, 32), out_channels=64, kernel_size=3, channel_first=True
        )

        last_out = last_layer(jnp.ones((4, 32, 32, 3)))
        first_out = first_layer(jnp.ones((4, 3, 32, 32)))

        self.assertEqual(last_out.shape[-1], 64)
        self.assertEqual(first_out.shape[1], 64)
369
+
370
+
371
class TestReproducibility(unittest.TestCase):
    """Determinism of convolution layers on identical inputs."""

    def test_deterministic_output(self):
        """Re-applying a layer to the same input reproduces its output.

        Two independently constructed layers are only compared by shape: their
        weights come from separate random initializations, so their outputs
        are not expected to match.
        """
        key = jax.random.PRNGKey(42)

        conv1 = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=16, kernel_size=3)
        conv2 = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=16, kernel_size=3)

        # Fixed key, so the input itself is reproducible across runs.
        x = jax.random.normal(key, (4, 32, 32, 3))

        y1 = conv1(x)
        y2 = conv2(x)
        self.assertEqual(y1.shape, y2.shape)

        # The actual determinism check the docstring promises: same layer,
        # same input, same values.
        self.assertTrue(jnp.allclose(y1, conv1(x)))
390
+
391
+
392
class TestRepr(unittest.TestCase):
    """``repr()`` output of convolution layers."""

    def test_conv_repr_channels_last(self):
        """The repr names the class, the layout flag, and both channel counts."""
        text = repr(brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=64, kernel_size=3))

        for fragment in ('Conv2d', 'channel_first=False', 'in_channels=3', 'out_channels=64'):
            self.assertIn(fragment, text)

    def test_conv_repr_channels_first(self):
        """The repr reflects ``channel_first=True``."""
        text = repr(
            brainstate.nn.Conv2d(in_size=(3, 32, 32), out_channels=64, kernel_size=3, channel_first=True)
        )

        for fragment in ('Conv2d', 'channel_first=True'):
            self.assertIn(fragment, text)
412
+
413
+
414
class TestConvTranspose1d(unittest.TestCase):
    """Test cases for the ``brainstate.nn.ConvTranspose1d`` layer."""

    def setUp(self):
        """Shared channels-last configuration used by the basic tests."""
        self.in_size = (28, 16)  # (length, channels)
        self.out_channels = 8
        self.kernel_size = 4

    def test_basic_channels_last(self):
        """Channels-last input keeps batch and puts output channels last."""
        conv_t = brainstate.nn.ConvTranspose1d(
            in_size=self.in_size,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=1,
        )
        x = jnp.ones((2, 28, 16))
        y = conv_t(x)

        self.assertEqual(len(y.shape), 3)
        self.assertEqual(y.shape[0], 2)  # batch size
        self.assertEqual(y.shape[-1], self.out_channels)
        self.assertEqual(conv_t.in_channels, 16)
        self.assertEqual(conv_t.out_channels, 8)
        self.assertFalse(conv_t.channel_first)

    def test_basic_channels_first(self):
        """Channels-first input keeps batch and puts output channels on axis 1."""
        conv_t = brainstate.nn.ConvTranspose1d(
            in_size=(16, 28),  # (C, L) for channels-first
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=1,
            channel_first=True,
        )
        x = jnp.ones((2, 16, 28))
        y = conv_t(x)

        self.assertEqual(len(y.shape), 3)
        self.assertEqual(y.shape[0], 2)  # batch size
        self.assertEqual(y.shape[1], self.out_channels)  # channels first
        self.assertEqual(conv_t.in_channels, 16)
        self.assertTrue(conv_t.channel_first)

    def test_stride_upsampling(self):
        """Stride 2 with SAME padding lengthens the sequence."""
        conv_t = brainstate.nn.ConvTranspose1d(
            in_size=(28, 16),
            out_channels=8,
            kernel_size=4,
            stride=2,
            padding='SAME',
        )
        x = jnp.ones((2, 28, 16))
        y = conv_t(x)

        # With stride=2 the output should be roughly 2x longer; only growth is asserted.
        self.assertGreater(y.shape[1], x.shape[1])

    def test_with_bias(self):
        """Passing ``b_init`` adds a 'bias' entry to the layer's weight dict."""
        conv_t = brainstate.nn.ConvTranspose1d(
            in_size=(50, 8),
            out_channels=16,
            kernel_size=3,
            b_init=brainstate.init.Constant(0.0),
        )
        x = jnp.ones((4, 50, 8))
        y = conv_t(x)

        # assertIn (as in TestConv1d.test_with_bias) reports the container on failure.
        self.assertIn('bias', conv_t.weight.value)
        self.assertEqual(y.shape[-1], 16)

    def test_without_batch(self):
        """An unbatched (L, C_in) input yields a rank-2 output."""
        conv_t = brainstate.nn.ConvTranspose1d(
            in_size=(28, 16),
            out_channels=8,
            kernel_size=4,
        )
        x = jnp.ones((28, 16))
        y = conv_t(x)

        self.assertEqual(len(y.shape), 2)
        self.assertEqual(y.shape[-1], 8)

    def test_groups(self):
        """Grouped transposed convolution produces the requested channel count."""
        conv_t = brainstate.nn.ConvTranspose1d(
            in_size=(28, 16),
            out_channels=16,
            kernel_size=3,
            groups=4,
        )
        x = jnp.ones((2, 28, 16))
        y = conv_t(x)

        self.assertEqual(y.shape[-1], 16)
513
+
514
+
515
class TestConvTranspose2d(unittest.TestCase):
    """Unit tests for ``brainstate.nn.ConvTranspose2d``."""

    def setUp(self):
        """Default config: channels-last input with 16x16 spatial size and 32 channels."""
        self.in_size = (16, 16, 32)
        self.out_channels = 16
        self.kernel_size = 4

    def test_basic_channels_last(self):
        """Default layout keeps the batch axis and puts channels last."""
        layer = brainstate.nn.ConvTranspose2d(
            in_size=self.in_size,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
        )
        out = layer(jnp.ones((4,) + self.in_size))

        self.assertEqual(len(out.shape), 4)
        self.assertEqual(out.shape[0], 4)  # batch size
        self.assertEqual(out.shape[-1], self.out_channels)
        self.assertEqual(layer.in_channels, 32)
        self.assertFalse(layer.channel_first)

    def test_basic_channels_first(self):
        """With channel_first=True the input is (C, H, W) and output channels sit on axis 1."""
        layer = brainstate.nn.ConvTranspose2d(
            in_size=(32, 16, 16),
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            channel_first=True,
        )
        out = layer(jnp.ones((4, 32, 16, 16)))

        self.assertEqual(len(out.shape), 4)
        self.assertEqual(out.shape[1], self.out_channels)
        self.assertTrue(layer.channel_first)

    def test_stride_upsampling(self):
        """stride=2 with SAME padding enlarges both spatial dimensions."""
        layer = brainstate.nn.ConvTranspose2d(
            in_size=self.in_size,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=2,
            padding='SAME',
        )
        inp = jnp.ones((4,) + self.in_size)
        out = layer(inp)

        # Roughly a 2x upsample; only strict growth is asserted.
        for axis in (1, 2):
            self.assertGreater(out.shape[axis], inp.shape[axis])

    def test_rectangular_kernel(self):
        """A non-square (3, 5) kernel is stored as given."""
        layer = brainstate.nn.ConvTranspose2d(
            in_size=self.in_size,
            out_channels=self.out_channels,
            kernel_size=(3, 5),
            stride=1,
        )
        out = layer(jnp.ones((2,) + self.in_size))

        self.assertEqual(layer.kernel_size, (3, 5))
        self.assertEqual(out.shape[-1], self.out_channels)

    def test_padding_valid(self):
        """VALID padding with stride=2 still upsamples the spatial dimensions."""
        layer = brainstate.nn.ConvTranspose2d(
            in_size=self.in_size,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=2,
            padding='VALID',
        )
        out = layer(jnp.ones((2,) + self.in_size))

        # Textbook VALID size: (in - 1) * stride + kernel = (16 - 1) * 2 + 4 = 34.
        # The exact value may differ with the implementation, so only growth
        # beyond the input size (16) is asserted.
        self.assertGreater(out.shape[1], 16)

    def test_groups(self):
        """Grouped 2D transposed convolution (groups=4, 32 -> 32 channels)."""
        layer = brainstate.nn.ConvTranspose2d(
            in_size=self.in_size,
            out_channels=32,
            kernel_size=3,
            groups=4,
        )
        out = layer(jnp.ones((2,) + self.in_size))
        self.assertEqual(out.shape[-1], 32)
615
+
616
+
617
class TestConvTranspose3d(unittest.TestCase):
    """Unit tests for ``brainstate.nn.ConvTranspose3d``."""

    def setUp(self):
        """Default config: channels-last input, three spatial axes of 8 and 16 channels."""
        self.in_size = (8, 8, 8, 16)
        self.out_channels = 8
        self.kernel_size = 4

    def test_basic_channels_last(self):
        """Default layout keeps the batch axis and puts channels last."""
        layer = brainstate.nn.ConvTranspose3d(
            in_size=self.in_size,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
        )
        out = layer(jnp.ones((2,) + self.in_size))

        self.assertEqual(len(out.shape), 5)
        self.assertEqual(out.shape[0], 2)  # batch size
        self.assertEqual(out.shape[-1], self.out_channels)
        self.assertEqual(layer.in_channels, 16)

    def test_basic_channels_first(self):
        """With channel_first=True the input is (C, D, H, W) and output channels sit on axis 1."""
        layer = brainstate.nn.ConvTranspose3d(
            in_size=(16, 8, 8, 8),
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            channel_first=True,
        )
        out = layer(jnp.ones((2, 16, 8, 8, 8)))

        self.assertEqual(len(out.shape), 5)
        self.assertEqual(out.shape[1], self.out_channels)
        self.assertTrue(layer.channel_first)

    def test_stride_upsampling(self):
        """stride=2 with SAME padding enlarges all three spatial dimensions."""
        layer = brainstate.nn.ConvTranspose3d(
            in_size=self.in_size,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=2,
            padding='SAME',
        )
        inp = jnp.ones((2,) + self.in_size)
        out = layer(inp)

        # Roughly a 2x upsample; only strict growth is asserted.
        for axis in (1, 2, 3):
            self.assertGreater(out.shape[axis], inp.shape[axis])
672
+
673
+
674
class TestErrorHandlingConvTranspose(unittest.TestCase):
    """Invalid configurations or inputs to transposed convolutions must raise."""

    def test_invalid_groups(self):
        """out_channels=15 is not divisible by groups=4, so construction fails an assertion."""
        with self.assertRaises(AssertionError):
            brainstate.nn.ConvTranspose2d(
                in_size=(16, 16, 32),
                out_channels=15,  # Not divisible by groups
                kernel_size=3,
                groups=4,
            )

    def test_dimension_mismatch(self):
        """Feeding 16 channels to a layer built for 32 channels raises ValueError."""
        layer = brainstate.nn.ConvTranspose2d(
            in_size=(16, 16, 32),
            out_channels=16,
            kernel_size=3,
        )
        wrong_channels = jnp.ones((2, 16, 16, 16))  # 16 channels instead of 32
        self.assertRaises(ValueError, layer, wrong_channels)

    def test_invalid_input_shape(self):
        """A rank-4 input (one extra leading axis) to a 1D layer raises ValueError."""
        layer = brainstate.nn.ConvTranspose1d(
            in_size=(28, 16),
            out_channels=8,
            kernel_size=3,
        )
        too_many_dims = jnp.ones((2, 2, 28, 16))
        self.assertRaises(ValueError, layer, too_many_dims)
710
+
711
+
712
class TestOutputShapesConvTranspose(unittest.TestCase):
    """Checks on the precomputed ``out_size`` and on actual output shapes."""

    def test_out_size_attribute_1d(self):
        """out_size is populated with one entry per axis of the 2-entry in_size."""
        layer = brainstate.nn.ConvTranspose1d(
            in_size=(28, 16),
            out_channels=8,
            kernel_size=4,
            stride=2,
        )
        out_size = layer.out_size
        self.assertIsNotNone(out_size)
        self.assertEqual(len(out_size), 2)

    def test_out_size_attribute_2d(self):
        """out_size is populated with one entry per axis of the 3-entry in_size."""
        layer = brainstate.nn.ConvTranspose2d(
            in_size=(16, 16, 32),
            out_channels=16,
            kernel_size=4,
            stride=2,
        )
        out_size = layer.out_size
        self.assertIsNotNone(out_size)
        self.assertEqual(len(out_size), 3)

    def test_upsampling_factor(self):
        """stride=2 with SAME padding roughly doubles each spatial dimension."""
        layer = brainstate.nn.ConvTranspose2d(
            in_size=(16, 16, 32),
            out_channels=16,
            kernel_size=4,
            stride=2,
            padding='SAME',
        )
        out = layer(jnp.ones((2, 16, 16, 32)))

        # 16 -> ~32 expected; 28 is a deliberately loose lower bound.
        for axis in (1, 2):
            self.assertGreaterEqual(out.shape[axis], 28)
754
+
755
+
756
class TestChannelFormatConsistencyConvTranspose(unittest.TestCase):
    """The output-channel axis follows the ``channel_first`` flag."""

    def test_conv_transpose_1d_output_channels(self):
        """1D: channels land on the last axis (default) or on axis 1 (channel_first)."""
        # Channels-last layout.
        last = brainstate.nn.ConvTranspose1d(
            in_size=(28, 16),
            out_channels=8,
            kernel_size=3,
        )
        self.assertEqual(last(jnp.ones((2, 28, 16))).shape[-1], 8)

        # Channels-first layout.
        first = brainstate.nn.ConvTranspose1d(
            in_size=(16, 28),
            out_channels=8,
            kernel_size=3,
            channel_first=True,
        )
        self.assertEqual(first(jnp.ones((2, 16, 28))).shape[1], 8)

    def test_conv_transpose_2d_output_channels(self):
        """2D: channels land on the last axis (default) or on axis 1 (channel_first)."""
        # Channels-last layout.
        last = brainstate.nn.ConvTranspose2d(
            in_size=(16, 16, 32),
            out_channels=16,
            kernel_size=3,
        )
        self.assertEqual(last(jnp.ones((2, 16, 16, 32))).shape[-1], 16)

        # Channels-first layout.
        first = brainstate.nn.ConvTranspose2d(
            in_size=(32, 16, 16),
            out_channels=16,
            kernel_size=3,
            channel_first=True,
        )
        self.assertEqual(first(jnp.ones((2, 32, 16, 16))).shape[1], 16)
804
+
805
+
806
class TestReproducibilityConvTranspose(unittest.TestCase):
    """Repeated calls on the same input give matching results."""

    def test_deterministic_output(self):
        """Two forward passes over an identical input agree (within allclose tolerance)."""
        layer = brainstate.nn.ConvTranspose2d(
            in_size=(16, 16, 32),
            out_channels=16,
            kernel_size=3,
        )
        inp = jnp.ones((2, 16, 16, 32))

        first, second = layer(inp), layer(inp)
        self.assertTrue(jnp.allclose(first, second))
822
+
823
+
824
+ class TestKernelShapeConvTranspose(unittest.TestCase):
825
+ """Test kernel shape computation for transposed convolutions."""
826
+
827
def test_kernel_shape_1d(self):
    """The 1D transposed kernel is (kernel_size, out_channels, in_channels // groups)."""
    kernel_size, out_channels, in_channels, groups = 4, 8, 16, 2
    conv_t = brainstate.nn.ConvTranspose1d(
        in_size=(28, in_channels),
        out_channels=out_channels,
        kernel_size=kernel_size,
        groups=groups,
    )
    self.assertEqual(
        conv_t.kernel_shape,
        (kernel_size, out_channels, in_channels // groups),
    )
838
+
839
+ def test_kernel_shape_2d(self):
840
+ """Test that kernel shape is correct for transposed conv 2D."""
841
+ conv_t = brainstate.nn.ConvTranspose2d(
842
+ in_size=(16, 16, 32),
843
+ out_channels=16,
844
+ kernel_size=4,
845
+ groups=4
846
+ )
847
+ # For transpose conv: (kernel_h, kernel_w, out_channels, in_channels // groups)
848
+ expected_shape = (4, 4, 16, 32 // 4)
849
+ self.assertEqual(conv_t.kernel_shape, expected_shape)