brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. brainstate/__init__.py +167 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2297 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +2157 -1652
  8. brainstate/_state_test.py +1129 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1620 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1447 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +146 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +635 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +134 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +480 -477
  32. brainstate/nn/_dynamics.py +870 -1267
  33. brainstate/nn/_dynamics_test.py +53 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +391 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
  64. brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
  65. brainstate/random/_impl.py +672 -0
  66. brainstate/random/{_rand_seed.py → _seed.py} +675 -675
  67. brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
  68. brainstate/random/{_rand_state.py → _state.py} +1320 -1617
  69. brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
  70. brainstate/transform/__init__.py +56 -59
  71. brainstate/transform/_ad_checkpoint.py +176 -176
  72. brainstate/transform/_ad_checkpoint_test.py +49 -49
  73. brainstate/transform/_autograd.py +1025 -1025
  74. brainstate/transform/_autograd_test.py +1289 -1289
  75. brainstate/transform/_conditions.py +316 -316
  76. brainstate/transform/_conditions_test.py +220 -220
  77. brainstate/transform/_error_if.py +94 -94
  78. brainstate/transform/_error_if_test.py +52 -52
  79. brainstate/transform/_find_state.py +200 -0
  80. brainstate/transform/_find_state_test.py +84 -0
  81. brainstate/transform/_jit.py +399 -399
  82. brainstate/transform/_jit_test.py +143 -143
  83. brainstate/transform/_loop_collect_return.py +675 -675
  84. brainstate/transform/_loop_collect_return_test.py +58 -58
  85. brainstate/transform/_loop_no_collection.py +283 -283
  86. brainstate/transform/_loop_no_collection_test.py +50 -50
  87. brainstate/transform/_make_jaxpr.py +2176 -2016
  88. brainstate/transform/_make_jaxpr_test.py +1634 -1510
  89. brainstate/transform/_mapping.py +607 -529
  90. brainstate/transform/_mapping_test.py +104 -194
  91. brainstate/transform/_progress_bar.py +255 -255
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
  108. brainstate-0.2.2.dist-info/RECORD +111 -0
  109. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/transform/_eval_shape.py +0 -145
  111. brainstate/transform/_eval_shape_test.py +0 -38
  112. brainstate/transform/_random.py +0 -171
  113. brainstate-0.2.1.dist-info/RECORD +0 -111
  114. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  115. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
@@ -1,849 +1,849 @@
1
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
- import unittest
19
-
20
- import jax
21
- import jax.numpy as jnp
22
-
23
- import brainstate
24
-
25
-
26
class TestConv1d(unittest.TestCase):
    """Test cases for 1D convolution (``brainstate.nn.Conv1d``)."""

    def test_basic_channels_last(self):
        """Test basic Conv1d with channels-last format."""
        conv = brainstate.nn.Conv1d(in_size=(100, 16), out_channels=32, kernel_size=5)
        x = jnp.ones((4, 100, 16))
        y = conv(x)

        self.assertEqual(y.shape, (4, 100, 32))
        self.assertEqual(conv.in_channels, 16)
        self.assertEqual(conv.out_channels, 32)
        self.assertFalse(conv.channel_first)

    def test_basic_channels_first(self):
        """Test basic Conv1d with channels-first format."""
        conv = brainstate.nn.Conv1d(in_size=(16, 100), out_channels=32, kernel_size=5, channel_first=True)
        x = jnp.ones((4, 16, 100))
        y = conv(x)

        self.assertEqual(y.shape, (4, 32, 100))
        self.assertEqual(conv.in_channels, 16)
        self.assertEqual(conv.out_channels, 32)
        self.assertTrue(conv.channel_first)

    def test_without_batch(self):
        """Test Conv1d on an unbatched input whose shape equals ``in_size``."""
        conv = brainstate.nn.Conv1d(in_size=(50, 8), out_channels=16, kernel_size=3)
        x = jnp.ones((50, 8))
        y = conv(x)

        self.assertEqual(y.shape, (50, 16))

    def test_stride(self):
        """Test Conv1d with stride."""
        conv = brainstate.nn.Conv1d(in_size=(100, 8), out_channels=16, kernel_size=3, stride=2, padding='VALID')
        x = jnp.ones((2, 100, 8))
        y = conv(x)

        # VALID padding: output = floor((L - K) / S) + 1 = floor((100 - 3) / 2) + 1 = 49
        self.assertEqual(y.shape, (2, 49, 16))

    def test_dilation(self):
        """Test Conv1d with dilated convolution."""
        conv = brainstate.nn.Conv1d(in_size=(100, 8), out_channels=16, kernel_size=3, rhs_dilation=2)
        x = jnp.ones((2, 100, 8))
        y = conv(x)

        self.assertEqual(y.shape, (2, 100, 16))

    def test_groups(self):
        """Test Conv1d with grouped convolution (16 in / 32 out channels, 4 groups)."""
        conv = brainstate.nn.Conv1d(in_size=(100, 16), out_channels=32, kernel_size=3, groups=4)
        x = jnp.ones((2, 100, 16))
        y = conv(x)

        self.assertEqual(y.shape, (2, 100, 32))
        self.assertEqual(conv.groups, 4)

    def test_with_bias(self):
        """Test that supplying ``b_init`` adds a ``'bias'`` entry to ``weight.value``."""
        conv = brainstate.nn.Conv1d(in_size=(50, 8), out_channels=16, kernel_size=3,
                                    b_init=brainstate.init.Constant(0.0))
        x = jnp.ones((2, 50, 8))
        y = conv(x)

        self.assertEqual(y.shape, (2, 50, 16))
        self.assertIn('bias', conv.weight.value)
94
-
95
-
96
class TestConv2d(unittest.TestCase):
    """Test cases for 2D convolution (``brainstate.nn.Conv2d``)."""

    def test_basic_channels_last(self):
        """Test basic Conv2d with channels-last format."""
        conv = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=64, kernel_size=3)
        x = jnp.ones((8, 32, 32, 3))
        y = conv(x)

        self.assertEqual(y.shape, (8, 32, 32, 64))
        self.assertEqual(conv.in_channels, 3)
        self.assertEqual(conv.out_channels, 64)
        self.assertFalse(conv.channel_first)

    def test_basic_channels_first(self):
        """Test basic Conv2d with channels-first format."""
        conv = brainstate.nn.Conv2d(in_size=(3, 32, 32), out_channels=64, kernel_size=3, channel_first=True)
        x = jnp.ones((8, 3, 32, 32))
        y = conv(x)

        self.assertEqual(y.shape, (8, 64, 32, 32))
        self.assertEqual(conv.in_channels, 3)
        self.assertEqual(conv.out_channels, 64)
        self.assertTrue(conv.channel_first)

    def test_rectangular_kernel(self):
        """Test Conv2d with a non-square ``(3, 5)`` kernel."""
        conv = brainstate.nn.Conv2d(in_size=(64, 64, 16), out_channels=32, kernel_size=(3, 5))
        x = jnp.ones((4, 64, 64, 16))
        y = conv(x)

        self.assertEqual(y.shape, (4, 64, 64, 32))
        self.assertEqual(conv.kernel_size, (3, 5))

    def test_stride_2d(self):
        """Test Conv2d with a ``(2, 2)`` stride and VALID padding."""
        conv = brainstate.nn.Conv2d(in_size=(64, 64, 3), out_channels=32, kernel_size=3, stride=(2, 2), padding='VALID')
        x = jnp.ones((4, 64, 64, 3))
        y = conv(x)

        # VALID padding: output = floor((L - K) / S) + 1 = floor((64 - 3) / 2) + 1 = 31
        self.assertEqual(y.shape, (4, 31, 31, 32))

    def test_depthwise_convolution(self):
        """Test depthwise convolution (groups = in_channels)."""
        conv = brainstate.nn.Conv2d(in_size=(32, 32, 16), out_channels=16, kernel_size=3, groups=16)
        x = jnp.ones((4, 32, 32, 16))
        y = conv(x)

        self.assertEqual(y.shape, (4, 32, 32, 16))
        self.assertEqual(conv.groups, 16)

    def test_padding_same_vs_valid(self):
        """Test that SAME keeps the spatial size while VALID shrinks it by ``K - 1``."""
        conv_same = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=16, kernel_size=5, padding='SAME')
        conv_valid = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=16, kernel_size=5, padding='VALID')

        x = jnp.ones((2, 32, 32, 3))
        y_same = conv_same(x)
        y_valid = conv_valid(x)

        self.assertEqual(y_same.shape, (2, 32, 32, 16))  # SAME preserves size
        self.assertEqual(y_valid.shape, (2, 28, 28, 16))  # VALID: 32 - 5 + 1 = 28
159
-
160
-
161
class TestConv3d(unittest.TestCase):
    """Shape and channel-count checks for ``brainstate.nn.Conv3d``."""

    def test_basic_channels_last(self):
        """A channels-last 3D input keeps its spatial extent and maps 1 -> 32 channels."""
        layer = brainstate.nn.Conv3d(in_size=(16, 16, 16, 1), out_channels=32, kernel_size=3)
        out = layer(jnp.ones((2, 16, 16, 16, 1)))

        self.assertEqual(out.shape, (2, 16, 16, 16, 32))
        self.assertEqual(layer.in_channels, 1)
        self.assertEqual(layer.out_channels, 32)

    def test_basic_channels_first(self):
        """A channels-first 3D input places the 32 output channels right after the batch axis."""
        layer = brainstate.nn.Conv3d(
            in_size=(1, 16, 16, 16),
            out_channels=32,
            kernel_size=3,
            channel_first=True,
        )
        out = layer(jnp.ones((2, 1, 16, 16, 16)))

        self.assertEqual(out.shape, (2, 32, 16, 16, 16))
        self.assertEqual(layer.in_channels, 1)
        self.assertEqual(layer.out_channels, 32)

    def test_video_data(self):
        """A clip laid out as (batch, frames, height, width, channels) is convolved shape-preservingly."""
        layer = brainstate.nn.Conv3d(in_size=(8, 32, 32, 3), out_channels=64, kernel_size=(3, 3, 3))
        clip = jnp.ones((4, 8, 32, 32, 3))

        self.assertEqual(layer(clip).shape, (4, 8, 32, 32, 64))
191
-
192
-
193
class TestScaledWSConv1d(unittest.TestCase):
    """Checks for the weight-standardized 1D convolution ``ScaledWSConv1d``."""

    def test_basic(self):
        """Default construction yields the expected output shape and a non-None ``eps``."""
        layer = brainstate.nn.ScaledWSConv1d(in_size=(100, 16), out_channels=32, kernel_size=5)
        out = layer(jnp.ones((4, 100, 16)))

        self.assertEqual(out.shape, (4, 100, 32))
        self.assertIsNotNone(layer.eps)

    def test_with_gain(self):
        """With ``ws_gain=True`` a ``'gain'`` entry appears in ``weight.value``."""
        layer = brainstate.nn.ScaledWSConv1d(
            in_size=(100, 16), out_channels=32, kernel_size=5, ws_gain=True
        )
        out = layer(jnp.ones((4, 100, 16)))

        self.assertEqual(out.shape, (4, 100, 32))
        self.assertIn('gain', layer.weight.value)

    def test_without_gain(self):
        """With ``ws_gain=False`` no ``'gain'`` entry is present in ``weight.value``."""
        layer = brainstate.nn.ScaledWSConv1d(
            in_size=(100, 16), out_channels=32, kernel_size=5, ws_gain=False
        )
        out = layer(jnp.ones((4, 100, 16)))

        self.assertEqual(out.shape, (4, 100, 32))
        self.assertNotIn('gain', layer.weight.value)

    def test_custom_eps(self):
        """A caller-supplied ``eps`` is stored on the layer as given."""
        layer = brainstate.nn.ScaledWSConv1d(
            in_size=(100, 16), out_channels=32, kernel_size=5, eps=1e-5
        )
        self.assertEqual(layer.eps, 1e-5)
227
-
228
-
229
class TestScaledWSConv2d(unittest.TestCase):
    """Checks for the weight-standardized 2D convolution ``ScaledWSConv2d``."""

    def test_basic_channels_last(self):
        """Channels-last input: spatial size kept, 3 -> 32 channels in the last axis."""
        layer = brainstate.nn.ScaledWSConv2d(in_size=(64, 64, 3), out_channels=32, kernel_size=3)
        self.assertEqual(layer(jnp.ones((8, 64, 64, 3))).shape, (8, 64, 64, 32))

    def test_basic_channels_first(self):
        """Channels-first input: 32 output channels sit right after the batch axis."""
        layer = brainstate.nn.ScaledWSConv2d(
            in_size=(3, 64, 64), out_channels=32, kernel_size=3, channel_first=True
        )
        self.assertEqual(layer(jnp.ones((8, 3, 64, 64))).shape, (8, 32, 64, 64))

    def test_with_group_norm_style(self):
        """A gain-enabled, ungrouped layer (as typically paired with group norm) has the expected shape."""
        layer = brainstate.nn.ScaledWSConv2d(
            in_size=(32, 32, 16),
            out_channels=32,
            kernel_size=3,
            ws_gain=True,
            groups=1,
        )
        batch = jnp.ones((4, 32, 32, 16))

        self.assertEqual(layer(batch).shape, (4, 32, 32, 32))
261
-
262
-
263
class TestScaledWSConv3d(unittest.TestCase):
    """Checks for the weight-standardized 3D convolution ``ScaledWSConv3d``."""

    def test_basic(self):
        """Channels-last input keeps its spatial extent and gains 32 output channels."""
        layer = brainstate.nn.ScaledWSConv3d(in_size=(8, 16, 16, 3), out_channels=32, kernel_size=3)
        result = layer(jnp.ones((2, 8, 16, 16, 3)))
        self.assertEqual(result.shape, (2, 8, 16, 16, 32))

    def test_channels_first(self):
        """Channels-first input puts the 32 output channels at axis 1."""
        layer = brainstate.nn.ScaledWSConv3d(
            in_size=(3, 8, 16, 16), out_channels=32, kernel_size=3, channel_first=True
        )
        result = layer(jnp.ones((2, 3, 8, 16, 16)))
        self.assertEqual(result.shape, (2, 32, 8, 16, 16))
281
-
282
-
283
class TestErrorHandling(unittest.TestCase):
    """Test error handling and edge cases."""

    def test_invalid_input_shape(self):
        """Test that invalid input shapes raise appropriate errors."""
        conv = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=16, kernel_size=3)
        x_wrong = jnp.ones((8, 32, 32, 16))  # Wrong number of channels (16 instead of 3)

        with self.assertRaises(ValueError):
            conv(x_wrong)

    def test_invalid_groups(self):
        """Test that invalid group configurations raise errors at construction time."""
        with self.assertRaises(AssertionError):
            # out_channels (30) is not divisible by groups (4); only construction is
            # exercised, so the result is intentionally not bound to a name.
            brainstate.nn.Conv2d(in_size=(32, 32, 16), out_channels=30, kernel_size=3, groups=4)

    def test_dimension_mismatch(self):
        """Test dimension mismatch detection."""
        conv = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=16, kernel_size=3)
        x_1d = jnp.ones((8, 32, 3))  # 1D instead of 2D

        with self.assertRaises(ValueError):
            conv(x_1d)
307
-
308
-
309
class TestOutputShapes(unittest.TestCase):
    """Spatial output-size checks for ``Conv2d`` under different padding modes."""

    def test_same_padding_preserves_size(self):
        """With stride 1, 'SAME' padding keeps the 32x32 spatial size for kernels 3, 5 and 7."""
        for ks in (3, 5, 7):
            layer = brainstate.nn.Conv2d(
                in_size=(32, 32, 3), out_channels=16, kernel_size=ks, padding='SAME'
            )
            out = layer(jnp.ones((4, 32, 32, 3)))
            self.assertEqual(out.shape, (4, 32, 32, 16), f"Failed for kernel_size={ks}")

    def test_valid_padding_reduces_size(self):
        """'VALID' padding shrinks each spatial axis to ``L - K + 1``."""
        layer = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=16, kernel_size=5, padding='VALID')
        out = layer(jnp.ones((4, 32, 32, 3)))
        # 32 - 5 + 1 = 28
        self.assertEqual(out.shape, (4, 28, 28, 16))

    def test_output_size_attribute(self):
        """``out_size`` (no batch axis) places channels according to ``channel_first``."""
        last = brainstate.nn.Conv2d(in_size=(64, 64, 3), out_channels=32, kernel_size=3, channel_first=False)
        first = brainstate.nn.Conv2d(in_size=(3, 64, 64), out_channels=32, kernel_size=3, channel_first=True)

        self.assertEqual(last.out_size, (64, 64, 32))
        self.assertEqual(first.out_size, (32, 64, 64))
335
-
336
-
337
class TestChannelFormatConsistency(unittest.TestCase):
    """The output-channel axis follows the layer's ``channel_first`` setting."""

    def test_conv1d_output_channels(self):
        """Conv1d: channels land last for channels-last and at axis 1 for channels-first."""
        last_layer = brainstate.nn.Conv1d(in_size=(100, 16), out_channels=32, kernel_size=3)
        first_layer = brainstate.nn.Conv1d(
            in_size=(16, 100), out_channels=32, kernel_size=3, channel_first=True
        )

        last_out = last_layer(jnp.ones((4, 100, 16)))
        first_out = first_layer(jnp.ones((4, 16, 100)))

        # channels-last -> trailing axis; channels-first -> axis right after batch
        self.assertEqual(last_out.shape[-1], 32)
        self.assertEqual(first_out.shape[1], 32)

    def test_conv2d_output_channels(self):
        """Conv2d: same channel-axis placement rule as the 1D case."""
        last_layer = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=64, kernel_size=3)
        first_layer = brainstate.nn.Conv2d(
            in_size=(3, 32, 32), out_channels=64, kernel_size=3, channel_first=True
        )

        last_out = last_layer(jnp.ones((4, 32, 32, 3)))
        first_out = first_layer(jnp.ones((4, 3, 32, 32)))

        self.assertEqual(last_out.shape[-1], 64)
        self.assertEqual(first_out.shape[1], 64)
369
-
370
-
371
class TestReproducibility(unittest.TestCase):
    """Test reproducibility with fixed seeds."""

    def test_deterministic_output(self):
        """Test that a fixed-seed input yields reproducible layer outputs.

        The two layers are constructed independently, so their weights (and
        hence their outputs) are not expected to match each other; only their
        shapes are compared. Determinism is instead checked by applying each
        layer a second time to the same fixed-seed input and requiring the
        result to match the first application.
        """
        key = jax.random.PRNGKey(42)

        conv1 = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=16, kernel_size=3)
        conv2 = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=16, kernel_size=3)

        # Same random key -> identical input for every call below.
        x = jax.random.normal(key, (4, 32, 32, 3))

        y1 = conv1(x)
        y2 = conv2(x)

        self.assertEqual(y1.shape, y2.shape)
        # Re-applying the same layer to the same input must reproduce its output.
        self.assertTrue(bool(jnp.allclose(y1, conv1(x))))
        self.assertTrue(bool(jnp.allclose(y2, conv2(x))))
390
-
391
-
392
class TestRepr(unittest.TestCase):
    """``repr()`` of convolution layers exposes the class name and key settings."""

    def test_conv_repr_channels_last(self):
        """Channels-last repr mentions the class, the layout flag and both channel counts."""
        text = repr(brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=64, kernel_size=3))

        for fragment in ('Conv2d', 'channel_first=False', 'in_channels=3', 'out_channels=64'):
            self.assertIn(fragment, text)

    def test_conv_repr_channels_first(self):
        """Channels-first repr mentions the class and ``channel_first=True``."""
        text = repr(
            brainstate.nn.Conv2d(in_size=(3, 32, 32), out_channels=64, kernel_size=3, channel_first=True)
        )

        for fragment in ('Conv2d', 'channel_first=True'):
            self.assertIn(fragment, text)
412
-
413
-
414
class TestConvTranspose1d(unittest.TestCase):
    """Test cases for ConvTranspose1d layer."""

    def setUp(self):
        """Set up shared channels-last fixtures: ``(length, channels)`` input size."""
        self.in_size = (28, 16)
        self.out_channels = 8
        self.kernel_size = 4

    def test_basic_channels_last(self):
        """Test basic ConvTranspose1d with channels-last format."""
        conv_t = brainstate.nn.ConvTranspose1d(
            in_size=self.in_size,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=1
        )
        x = jnp.ones((2, 28, 16))
        y = conv_t(x)

        self.assertEqual(len(y.shape), 3)
        self.assertEqual(y.shape[0], 2)  # batch size
        self.assertEqual(y.shape[-1], self.out_channels)
        self.assertEqual(conv_t.in_channels, 16)
        self.assertEqual(conv_t.out_channels, 8)
        self.assertFalse(conv_t.channel_first)

    def test_basic_channels_first(self):
        """Test basic ConvTranspose1d with channels-first format."""
        conv_t = brainstate.nn.ConvTranspose1d(
            in_size=(16, 28),  # (C, L) for channels-first
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=1,
            channel_first=True
        )
        x = jnp.ones((2, 16, 28))
        y = conv_t(x)

        self.assertEqual(len(y.shape), 3)
        self.assertEqual(y.shape[0], 2)  # batch size
        self.assertEqual(y.shape[1], self.out_channels)  # channels first
        self.assertEqual(conv_t.in_channels, 16)
        self.assertTrue(conv_t.channel_first)

    def test_stride_upsampling(self):
        """Test transposed convolution with stride for upsampling."""
        conv_t = brainstate.nn.ConvTranspose1d(
            in_size=(28, 16),
            out_channels=8,
            kernel_size=4,
            stride=2,
            padding='SAME'
        )
        x = jnp.ones((2, 28, 16))
        y = conv_t(x)

        # With stride=2 the length axis must grow; only a strict increase is asserted.
        self.assertGreater(y.shape[1], x.shape[1])

    def test_with_bias(self):
        """Test that supplying ``b_init`` adds a ``'bias'`` entry to ``weight.value``."""
        conv_t = brainstate.nn.ConvTranspose1d(
            in_size=(50, 8),
            out_channels=16,
            kernel_size=3,
            b_init=brainstate.init.Constant(0.0)
        )
        x = jnp.ones((4, 50, 8))
        y = conv_t(x)

        # assertIn reports the container contents on failure, unlike assertTrue(... in ...).
        self.assertIn('bias', conv_t.weight.value)
        self.assertEqual(y.shape[-1], 16)

    def test_without_batch(self):
        """Test ConvTranspose1d on an unbatched input whose shape equals ``in_size``."""
        conv_t = brainstate.nn.ConvTranspose1d(
            in_size=(28, 16),
            out_channels=8,
            kernel_size=4
        )
        x = jnp.ones((28, 16))
        y = conv_t(x)

        self.assertEqual(len(y.shape), 2)
        self.assertEqual(y.shape[-1], 8)

    def test_groups(self):
        """Test grouped transposed convolution (16 in / 16 out channels, 4 groups)."""
        conv_t = brainstate.nn.ConvTranspose1d(
            in_size=(28, 16),
            out_channels=16,
            kernel_size=3,
            groups=4
        )
        x = jnp.ones((2, 28, 16))
        y = conv_t(x)

        self.assertEqual(y.shape[-1], 16)
513
-
514
-
515
class TestConvTranspose2d(unittest.TestCase):
    """Shape, layout and option checks for ``brainstate.nn.ConvTranspose2d``."""

    def setUp(self):
        """Shared channels-last fixture: a 16x16 map with 32 channels, 16 outputs, 4x4 kernel."""
        self.in_size = (16, 16, 32)
        self.out_channels = 16
        self.kernel_size = 4

    def test_basic_channels_last(self):
        """Channels-last: rank-4 output with batch first and ``out_channels`` last."""
        layer = brainstate.nn.ConvTranspose2d(
            in_size=self.in_size,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
        )
        out = layer(jnp.ones((4, 16, 16, 32)))

        self.assertEqual(len(out.shape), 4)
        self.assertEqual(out.shape[0], 4)  # batch axis untouched
        self.assertEqual(out.shape[-1], self.out_channels)
        self.assertEqual(layer.in_channels, 32)
        self.assertFalse(layer.channel_first)

    def test_basic_channels_first(self):
        """Channels-first: ``out_channels`` occupies axis 1 of the rank-4 output."""
        layer = brainstate.nn.ConvTranspose2d(
            in_size=(32, 16, 16),  # (C, H, W) layout
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            channel_first=True,
        )
        out = layer(jnp.ones((4, 32, 16, 16)))

        self.assertEqual(len(out.shape), 4)
        self.assertEqual(out.shape[1], self.out_channels)
        self.assertTrue(layer.channel_first)

    def test_stride_upsampling(self):
        """Stride 2 with 'SAME' padding strictly enlarges both spatial axes."""
        layer = brainstate.nn.ConvTranspose2d(
            in_size=(16, 16, 32),
            out_channels=16,
            kernel_size=4,
            stride=2,
            padding='SAME',
        )
        inp = jnp.ones((4, 16, 16, 32))
        out = layer(inp)

        # Only a strict increase is asserted, not an exact 2x factor.
        for axis in (1, 2):
            self.assertGreater(out.shape[axis], inp.shape[axis])

    def test_rectangular_kernel(self):
        """A ``(3, 5)`` kernel is stored as given and produces 16 output channels."""
        layer = brainstate.nn.ConvTranspose2d(
            in_size=(16, 16, 32),
            out_channels=16,
            kernel_size=(3, 5),
            stride=1,
        )
        out = layer(jnp.ones((2, 16, 16, 32)))

        self.assertEqual(layer.kernel_size, (3, 5))
        self.assertEqual(out.shape[-1], 16)

    def test_padding_valid(self):
        """'VALID' padding with stride 2 must upsample beyond the 16-pixel input."""
        layer = brainstate.nn.ConvTranspose2d(
            in_size=(16, 16, 32),
            out_channels=16,
            kernel_size=4,
            stride=2,
            padding='VALID',
        )
        out = layer(jnp.ones((2, 16, 16, 32)))

        # Textbook VALID transpose size: (in - 1) * stride + kernel = (16 - 1) * 2 + 4 = 34.
        # The backend's exact result may differ, so only "bigger than input" is asserted.
        self.assertGreater(out.shape[1], 16)

    def test_groups(self):
        """Grouped transposed convolution (32 in / 32 out channels, 4 groups)."""
        layer = brainstate.nn.ConvTranspose2d(
            in_size=(16, 16, 32),
            out_channels=32,
            kernel_size=3,
            groups=4,
        )
        out = layer(jnp.ones((2, 16, 16, 32)))

        self.assertEqual(out.shape[-1], 32)
615
-
616
-
617
class TestConvTranspose3d(unittest.TestCase):
    """Shape and layout checks for ``brainstate.nn.ConvTranspose3d``."""

    def setUp(self):
        """Shared channels-last fixture: an 8x8x8 volume with 16 channels, 8 outputs, kernel 4."""
        self.in_size = (8, 8, 8, 16)
        self.out_channels = 8
        self.kernel_size = 4

    def test_basic_channels_last(self):
        """Channels-last: rank-5 output, batch preserved, ``out_channels`` in the last axis."""
        layer = brainstate.nn.ConvTranspose3d(
            in_size=self.in_size,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
        )
        out = layer(jnp.ones((2, 8, 8, 8, 16)))

        self.assertEqual(len(out.shape), 5)
        self.assertEqual(out.shape[0], 2)  # batch axis untouched
        self.assertEqual(out.shape[-1], self.out_channels)
        self.assertEqual(layer.in_channels, 16)

    def test_basic_channels_first(self):
        """Channels-first: ``out_channels`` occupies axis 1 of the rank-5 output."""
        layer = brainstate.nn.ConvTranspose3d(
            in_size=(16, 8, 8, 8),  # channels followed by the three spatial axes
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            channel_first=True,
        )
        out = layer(jnp.ones((2, 16, 8, 8, 8)))

        self.assertEqual(len(out.shape), 5)
        self.assertEqual(out.shape[1], self.out_channels)
        self.assertTrue(layer.channel_first)

    def test_stride_upsampling(self):
        """Stride 2 with 'SAME' padding strictly enlarges all three spatial axes."""
        layer = brainstate.nn.ConvTranspose3d(
            in_size=(8, 8, 8, 16),
            out_channels=8,
            kernel_size=4,
            stride=2,
            padding='SAME',
        )
        inp = jnp.ones((2, 8, 8, 8, 16))
        out = layer(inp)

        # Only a strict increase is asserted, not an exact 2x factor.
        for axis in (1, 2, 3):
            self.assertGreater(out.shape[axis], inp.shape[axis])
672
-
673
-
674
class TestErrorHandlingConvTranspose(unittest.TestCase):
    """Invalid configurations and inputs for transposed convolutions must be rejected."""

    def test_invalid_groups(self):
        """Construction fails with AssertionError when out_channels (15) is not divisible by groups (4)."""
        with self.assertRaises(AssertionError):
            brainstate.nn.ConvTranspose2d(
                in_size=(16, 16, 32),
                out_channels=15,
                kernel_size=3,
                groups=4,
            )

    def test_dimension_mismatch(self):
        """Feeding 16 channels to a layer built for 32 raises ValueError."""
        layer = brainstate.nn.ConvTranspose2d(
            in_size=(16, 16, 32),
            out_channels=16,
            kernel_size=3,
        )
        bad_channels = jnp.ones((2, 16, 16, 16))

        with self.assertRaises(ValueError):
            layer(bad_channels)

    def test_invalid_input_shape(self):
        """An input with one axis too many for a 1D layer raises ValueError."""
        layer = brainstate.nn.ConvTranspose1d(
            in_size=(28, 16),
            out_channels=8,
            kernel_size=3,
        )
        too_many_axes = jnp.ones((2, 2, 28, 16))

        with self.assertRaises(ValueError):
            layer(too_many_axes)
710
-
711
-
712
class TestOutputShapesConvTranspose(unittest.TestCase):
    """Static ``out_size`` and runtime output-shape checks for transposed convolutions."""

    def test_out_size_attribute_1d(self):
        """A 1D layer exposes a two-element ``out_size`` (no batch axis)."""
        layer = brainstate.nn.ConvTranspose1d(
            in_size=(28, 16),
            out_channels=8,
            kernel_size=4,
            stride=2,
        )
        size = layer.out_size

        self.assertIsNotNone(size)
        self.assertEqual(len(size), 2)

    def test_out_size_attribute_2d(self):
        """A 2D layer exposes a three-element ``out_size`` (no batch axis)."""
        layer = brainstate.nn.ConvTranspose2d(
            in_size=(16, 16, 32),
            out_channels=16,
            kernel_size=4,
            stride=2,
        )
        size = layer.out_size

        self.assertIsNotNone(size)
        self.assertEqual(len(size), 3)

    def test_upsampling_factor(self):
        """Stride 2 with 'SAME' padding brings each 16-pixel spatial axis to at least 28."""
        layer = brainstate.nn.ConvTranspose2d(
            in_size=(16, 16, 32),
            out_channels=16,
            kernel_size=4,
            stride=2,
            padding='SAME',
        )
        out = layer(jnp.ones((2, 16, 16, 32)))

        # Loose lower bound standing in for "roughly doubles" (2 * 16 = 32).
        for axis in (1, 2):
            self.assertGreaterEqual(out.shape[axis], 28)
754
-
755
-
756
class TestChannelFormatConsistencyConvTranspose(unittest.TestCase):
    """Output-channel placement for channels-last vs. channels-first transposed convs."""

    def test_conv_transpose_1d_output_channels(self):
        """Channels land on the last axis (channels-last) or on axis 1 (channels-first)."""
        # Channels-last: in_size is (length, channels).
        last_fmt = brainstate.nn.ConvTranspose1d(in_size=(28, 16), out_channels=8, kernel_size=3)
        self.assertEqual(last_fmt(jnp.ones((2, 28, 16))).shape[-1], 8)

        # Channels-first: in_size is (channels, length).
        first_fmt = brainstate.nn.ConvTranspose1d(
            in_size=(16, 28), out_channels=8, kernel_size=3, channel_first=True,
        )
        self.assertEqual(first_fmt(jnp.ones((2, 16, 28))).shape[1], 8)

    def test_conv_transpose_2d_output_channels(self):
        """2D variant of the channel-placement check."""
        # Channels-last: in_size is (H, W, C).
        last_fmt = brainstate.nn.ConvTranspose2d(in_size=(16, 16, 32), out_channels=16, kernel_size=3)
        self.assertEqual(last_fmt(jnp.ones((2, 16, 16, 32))).shape[-1], 16)

        # Channels-first: in_size is (C, H, W).
        first_fmt = brainstate.nn.ConvTranspose2d(
            in_size=(32, 16, 16), out_channels=16, kernel_size=3, channel_first=True,
        )
        self.assertEqual(first_fmt(jnp.ones((2, 32, 16, 16))).shape[1], 16)
803
- self.assertEqual(y_first.shape[1], 16)
804
-
805
-
806
class TestReproducibilityConvTranspose(unittest.TestCase):
    """Determinism of repeated forward passes through one transposed-conv layer."""

    def test_deterministic_output(self):
        """Calling the same layer twice on the same input gives matching results."""
        layer = brainstate.nn.ConvTranspose2d(in_size=(16, 16, 32), out_channels=16, kernel_size=3)
        inputs = jnp.ones((2, 16, 16, 32))

        first_pass, second_pass = layer(inputs), layer(inputs)

        self.assertTrue(jnp.allclose(first_pass, second_pass))
822
-
823
-
824
class TestKernelShapeConvTranspose(unittest.TestCase):
    """Kernel-shape bookkeeping for grouped transposed convolutions."""

    def test_kernel_shape_1d(self):
        """1D layout is (kernel, out_channels, in_channels // groups)."""
        in_channels, groups = 16, 2
        layer = brainstate.nn.ConvTranspose1d(
            in_size=(28, in_channels), out_channels=8, kernel_size=4, groups=groups,
        )
        self.assertEqual(layer.kernel_shape, (4, 8, in_channels // groups))

    def test_kernel_shape_2d(self):
        """2D layout is (kernel_h, kernel_w, out_channels, in_channels // groups)."""
        in_channels, groups = 32, 4
        layer = brainstate.nn.ConvTranspose2d(
            in_size=(16, 16, in_channels), out_channels=16, kernel_size=4, groups=groups,
        )
        self.assertEqual(layer.kernel_shape, (4, 4, 16, in_channels // groups))
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+ import unittest
19
+
20
+ import jax
21
+ import jax.numpy as jnp
22
+
23
+ import brainstate
24
+
25
+
26
class TestConv1d(unittest.TestCase):
    """Behavioural checks for ``brainstate.nn.Conv1d``."""

    def test_basic_channels_last(self):
        """A (N, L, C) input maps to (N, L, out_channels) with the default settings."""
        layer = brainstate.nn.Conv1d(in_size=(100, 16), out_channels=32, kernel_size=5)
        out = layer(jnp.ones((4, 100, 16)))

        self.assertEqual(out.shape, (4, 100, 32))
        self.assertEqual((layer.in_channels, layer.out_channels), (16, 32))
        self.assertFalse(layer.channel_first)

    def test_basic_channels_first(self):
        """A (N, C, L) input maps to (N, out_channels, L) when channel_first=True."""
        layer = brainstate.nn.Conv1d(
            in_size=(16, 100), out_channels=32, kernel_size=5, channel_first=True,
        )
        out = layer(jnp.ones((4, 16, 100)))

        self.assertEqual(out.shape, (4, 32, 100))
        self.assertEqual((layer.in_channels, layer.out_channels), (16, 32))
        self.assertTrue(layer.channel_first)

    def test_without_batch(self):
        """An unbatched (L, C) input is accepted and the result stays unbatched."""
        layer = brainstate.nn.Conv1d(in_size=(50, 8), out_channels=16, kernel_size=3)
        self.assertEqual(layer(jnp.ones((50, 8))).shape, (50, 16))

    def test_stride(self):
        """Stride 2 with VALID padding gives length ceil((100 - 3 + 1) / 2) = 49."""
        layer = brainstate.nn.Conv1d(
            in_size=(100, 8), out_channels=16, kernel_size=3, stride=2, padding='VALID',
        )
        self.assertEqual(layer(jnp.ones((2, 100, 8))).shape, (2, 49, 16))

    def test_dilation(self):
        """A kernel dilated by rhs_dilation=2 still preserves the sequence length."""
        layer = brainstate.nn.Conv1d(
            in_size=(100, 8), out_channels=16, kernel_size=3, rhs_dilation=2,
        )
        self.assertEqual(layer(jnp.ones((2, 100, 8))).shape, (2, 100, 16))

    def test_groups(self):
        """groups=4 is stored on the layer and yields the requested output channels."""
        layer = brainstate.nn.Conv1d(in_size=(100, 16), out_channels=32, kernel_size=3, groups=4)
        out = layer(jnp.ones((2, 100, 16)))

        self.assertEqual(out.shape, (2, 100, 32))
        self.assertEqual(layer.groups, 4)

    def test_with_bias(self):
        """Supplying b_init registers a 'bias' entry among the weight parameters."""
        layer = brainstate.nn.Conv1d(
            in_size=(50, 8),
            out_channels=16,
            kernel_size=3,
            b_init=brainstate.init.Constant(0.0),
        )
        out = layer(jnp.ones((2, 50, 8)))

        self.assertEqual(out.shape, (2, 50, 16))
        self.assertIn('bias', layer.weight.value)
94
+
95
+
96
class TestConv2d(unittest.TestCase):
    """Behavioural checks for ``brainstate.nn.Conv2d``."""

    def test_basic_channels_last(self):
        """A (N, H, W, C) input maps to (N, H, W, out_channels)."""
        layer = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=64, kernel_size=3)
        out = layer(jnp.ones((8, 32, 32, 3)))

        self.assertEqual(out.shape, (8, 32, 32, 64))
        self.assertEqual((layer.in_channels, layer.out_channels), (3, 64))
        self.assertFalse(layer.channel_first)

    def test_basic_channels_first(self):
        """A (N, C, H, W) input maps to (N, out_channels, H, W) when channel_first=True."""
        layer = brainstate.nn.Conv2d(
            in_size=(3, 32, 32), out_channels=64, kernel_size=3, channel_first=True,
        )
        out = layer(jnp.ones((8, 3, 32, 32)))

        self.assertEqual(out.shape, (8, 64, 32, 32))
        self.assertEqual((layer.in_channels, layer.out_channels), (3, 64))
        self.assertTrue(layer.channel_first)

    def test_rectangular_kernel(self):
        """A (3, 5) kernel is stored as given and keeps the spatial size."""
        layer = brainstate.nn.Conv2d(in_size=(64, 64, 16), out_channels=32, kernel_size=(3, 5))
        out = layer(jnp.ones((4, 64, 64, 16)))

        self.assertEqual(out.shape, (4, 64, 64, 32))
        self.assertEqual(layer.kernel_size, (3, 5))

    def test_stride_2d(self):
        """Stride (2, 2) with VALID padding gives ceil((64 - 3 + 1) / 2) = 31 per axis."""
        layer = brainstate.nn.Conv2d(
            in_size=(64, 64, 3), out_channels=32, kernel_size=3, stride=(2, 2), padding='VALID',
        )
        self.assertEqual(layer(jnp.ones((4, 64, 64, 3))).shape, (4, 31, 31, 32))

    def test_depthwise_convolution(self):
        """groups equal to in_channels (depthwise) keeps the channel count."""
        layer = brainstate.nn.Conv2d(in_size=(32, 32, 16), out_channels=16, kernel_size=3, groups=16)
        out = layer(jnp.ones((4, 32, 32, 16)))

        self.assertEqual(out.shape, (4, 32, 32, 16))
        self.assertEqual(layer.groups, 16)

    def test_padding_same_vs_valid(self):
        """SAME keeps 32x32 for a 5x5 kernel, while VALID shrinks it to 28x28."""
        x = jnp.ones((2, 32, 32, 3))
        expected_shapes = {
            'SAME': (2, 32, 32, 16),   # size preserved
            'VALID': (2, 28, 28, 16),  # 32 - 5 + 1
        }
        for mode, shape in expected_shapes.items():
            with self.subTest(padding=mode):
                layer = brainstate.nn.Conv2d(
                    in_size=(32, 32, 3), out_channels=16, kernel_size=5, padding=mode,
                )
                self.assertEqual(layer(x).shape, shape)
159
+
160
+
161
class TestConv3d(unittest.TestCase):
    """Behavioural checks for ``brainstate.nn.Conv3d``."""

    def test_basic_channels_last(self):
        """A (N, D, H, W, C) input maps to (N, D, H, W, out_channels)."""
        layer = brainstate.nn.Conv3d(in_size=(16, 16, 16, 1), out_channels=32, kernel_size=3)
        out = layer(jnp.ones((2, 16, 16, 16, 1)))

        self.assertEqual(out.shape, (2, 16, 16, 16, 32))
        self.assertEqual((layer.in_channels, layer.out_channels), (1, 32))

    def test_basic_channels_first(self):
        """A (N, C, D, H, W) input maps to (N, out_channels, D, H, W) when channel_first=True."""
        layer = brainstate.nn.Conv3d(
            in_size=(1, 16, 16, 16), out_channels=32, kernel_size=3, channel_first=True,
        )
        out = layer(jnp.ones((2, 1, 16, 16, 16)))

        self.assertEqual(out.shape, (2, 32, 16, 16, 16))
        self.assertEqual((layer.in_channels, layer.out_channels), (1, 32))

    def test_video_data(self):
        """A clip laid out as (batch, frames, height, width, channels) keeps its geometry."""
        layer = brainstate.nn.Conv3d(in_size=(8, 32, 32, 3), out_channels=64, kernel_size=(3, 3, 3))
        clip = jnp.ones((4, 8, 32, 32, 3))
        self.assertEqual(layer(clip).shape, (4, 8, 32, 32, 64))
191
+
192
+
193
class TestScaledWSConv1d(unittest.TestCase):
    """Checks for the weight-standardised 1D convolution."""

    def test_basic(self):
        """Default construction keeps the length and exposes an ``eps`` value."""
        layer = brainstate.nn.ScaledWSConv1d(in_size=(100, 16), out_channels=32, kernel_size=5)
        out = layer(jnp.ones((4, 100, 16)))

        self.assertEqual(out.shape, (4, 100, 32))
        self.assertIsNotNone(layer.eps)

    def test_with_gain(self):
        """ws_gain=True adds a 'gain' entry to the weight parameters."""
        layer = brainstate.nn.ScaledWSConv1d(
            in_size=(100, 16), out_channels=32, kernel_size=5, ws_gain=True,
        )
        out = layer(jnp.ones((4, 100, 16)))

        self.assertEqual(out.shape, (4, 100, 32))
        self.assertIn('gain', layer.weight.value)

    def test_without_gain(self):
        """ws_gain=False leaves 'gain' out of the weight parameters."""
        layer = brainstate.nn.ScaledWSConv1d(
            in_size=(100, 16), out_channels=32, kernel_size=5, ws_gain=False,
        )
        out = layer(jnp.ones((4, 100, 16)))

        self.assertEqual(out.shape, (4, 100, 32))
        self.assertNotIn('gain', layer.weight.value)

    def test_custom_eps(self):
        """A user-supplied eps is stored unchanged."""
        custom_eps = 1e-5
        layer = brainstate.nn.ScaledWSConv1d(
            in_size=(100, 16), out_channels=32, kernel_size=5, eps=custom_eps,
        )
        self.assertEqual(layer.eps, custom_eps)
227
+
228
+
229
class TestScaledWSConv2d(unittest.TestCase):
    """Checks for the weight-standardised 2D convolution."""

    def test_basic_channels_last(self):
        """A (N, H, W, C) input maps to (N, H, W, out_channels)."""
        layer = brainstate.nn.ScaledWSConv2d(in_size=(64, 64, 3), out_channels=32, kernel_size=3)
        self.assertEqual(layer(jnp.ones((8, 64, 64, 3))).shape, (8, 64, 64, 32))

    def test_basic_channels_first(self):
        """A (N, C, H, W) input maps to (N, out_channels, H, W) when channel_first=True."""
        layer = brainstate.nn.ScaledWSConv2d(
            in_size=(3, 64, 64), out_channels=32, kernel_size=3, channel_first=True,
        )
        self.assertEqual(layer(jnp.ones((8, 3, 64, 64))).shape, (8, 32, 64, 64))

    def test_with_group_norm_style(self):
        """A gain-enabled, ungrouped layer of the kind paired with group normalisation."""
        config = dict(
            in_size=(32, 32, 16),
            out_channels=32,
            kernel_size=3,
            ws_gain=True,
            groups=1,
        )
        layer = brainstate.nn.ScaledWSConv2d(**config)
        self.assertEqual(layer(jnp.ones((4, 32, 32, 16))).shape, (4, 32, 32, 32))
261
+
262
+
263
class TestScaledWSConv3d(unittest.TestCase):
    """Checks for the weight-standardised 3D convolution."""

    def test_basic(self):
        """The channels-last layout keeps the three spatial axes."""
        layer = brainstate.nn.ScaledWSConv3d(in_size=(8, 16, 16, 3), out_channels=32, kernel_size=3)
        self.assertEqual(layer(jnp.ones((2, 8, 16, 16, 3))).shape, (2, 8, 16, 16, 32))

    def test_channels_first(self):
        """The channels-first layout puts out_channels on axis 1."""
        layer = brainstate.nn.ScaledWSConv3d(
            in_size=(3, 8, 16, 16), out_channels=32, kernel_size=3, channel_first=True,
        )
        self.assertEqual(layer(jnp.ones((2, 3, 8, 16, 16))).shape, (2, 32, 8, 16, 16))
281
+
282
+
283
class TestErrorHandling(unittest.TestCase):
    """Test error handling and edge cases for ``brainstate.nn.Conv2d``."""

    def test_invalid_input_shape(self):
        """A channel count that disagrees with ``in_size`` raises ValueError."""
        conv = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=16, kernel_size=3)
        x_wrong = jnp.ones((8, 32, 32, 16))  # 16 channels instead of the declared 3

        with self.assertRaises(ValueError):
            conv(x_wrong)

    def test_invalid_groups(self):
        """An out_channels value not divisible by ``groups`` is rejected at construction."""
        # 30 output channels cannot be split across 4 groups. The constructor is
        # expected to raise, so its result is deliberately not bound to a name.
        with self.assertRaises(AssertionError):
            brainstate.nn.Conv2d(in_size=(32, 32, 16), out_channels=30, kernel_size=3, groups=4)

    def test_dimension_mismatch(self):
        """A rank-3 input (one spatial axis) passed to a 2D layer raises ValueError."""
        conv = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=16, kernel_size=3)
        x_1d = jnp.ones((8, 32, 3))  # 1D instead of 2D

        with self.assertRaises(ValueError):
            conv(x_1d)
307
+
308
+
309
class TestOutputShapes(unittest.TestCase):
    """Spatial-size bookkeeping for Conv2d under different paddings."""

    def test_same_padding_preserves_size(self):
        """With stride 1, SAME padding keeps 32x32 for kernel sizes 3, 5 and 7."""
        x = jnp.ones((4, 32, 32, 3))
        for kernel_size in (3, 5, 7):
            layer = brainstate.nn.Conv2d(
                in_size=(32, 32, 3), out_channels=16, kernel_size=kernel_size, padding='SAME',
            )
            self.assertEqual(
                layer(x).shape, (4, 32, 32, 16), f"Failed for kernel_size={kernel_size}"
            )

    def test_valid_padding_reduces_size(self):
        """VALID padding with a 5x5 kernel yields 32 - 5 + 1 = 28 per spatial axis."""
        layer = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=16, kernel_size=5, padding='VALID')
        self.assertEqual(layer(jnp.ones((4, 32, 32, 3))).shape, (4, 28, 28, 16))

    def test_output_size_attribute(self):
        """``out_size`` places the channel axis according to ``channel_first``."""
        cases = (
            # (in_size, channel_first, expected out_size)
            ((64, 64, 3), False, (64, 64, 32)),
            ((3, 64, 64), True, (32, 64, 64)),
        )
        for in_size, channel_first, expected in cases:
            with self.subTest(channel_first=channel_first):
                layer = brainstate.nn.Conv2d(
                    in_size=in_size, out_channels=32, kernel_size=3, channel_first=channel_first,
                )
                self.assertEqual(layer.out_size, expected)
335
+
336
+
337
class TestChannelFormatConsistency(unittest.TestCase):
    """Output-channel placement for channels-last vs. channels-first convolutions."""

    def test_conv1d_output_channels(self):
        """Channels land on the last axis (channels-last) or axis 1 (channels-first)."""
        last_fmt = brainstate.nn.Conv1d(in_size=(100, 16), out_channels=32, kernel_size=3)
        first_fmt = brainstate.nn.Conv1d(
            in_size=(16, 100), out_channels=32, kernel_size=3, channel_first=True,
        )

        out_last = last_fmt(jnp.ones((4, 100, 16)))
        out_first = first_fmt(jnp.ones((4, 16, 100)))

        # Channels-last: the trailing axis; channels-first: the axis right after batch.
        self.assertEqual(out_last.shape[-1], 32)
        self.assertEqual(out_first.shape[1], 32)

    def test_conv2d_output_channels(self):
        """2D variant of the channel-placement check."""
        last_fmt = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=64, kernel_size=3)
        first_fmt = brainstate.nn.Conv2d(
            in_size=(3, 32, 32), out_channels=64, kernel_size=3, channel_first=True,
        )

        out_last = last_fmt(jnp.ones((4, 32, 32, 3)))
        out_first = first_fmt(jnp.ones((4, 3, 32, 32)))

        self.assertEqual(out_last.shape[-1], 64)
        self.assertEqual(out_first.shape[1], 64)
369
+
370
+
371
class TestReproducibility(unittest.TestCase):
    """Test reproducibility with fixed seeds."""

    def test_deterministic_output(self):
        """A fixed PRNG key reproduces the same input, and a layer's forward pass is repeatable.

        The two layers are constructed independently and may have different
        weights, so across layers only the output shapes are compared.
        Determinism of values is checked by reapplying the same layer.
        """
        key = jax.random.PRNGKey(42)

        conv1 = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=16, kernel_size=3)
        conv2 = brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=16, kernel_size=3)

        # Use same random key for input; resampling with that key must match exactly.
        x = jax.random.normal(key, (4, 32, 32, 3))
        self.assertTrue(jnp.array_equal(x, jax.random.normal(key, (4, 32, 32, 3))))

        y1 = conv1(x)
        y2 = conv2(x)

        self.assertEqual(y1.shape, y2.shape)
        # The same layer applied twice to the same input must agree.
        self.assertTrue(jnp.allclose(y1, conv1(x)))
390
+
391
+
392
class TestRepr(unittest.TestCase):
    """String-representation checks for Conv2d."""

    def test_conv_repr_channels_last(self):
        """The repr of a channels-last layer names the class, its format and its channels."""
        text = repr(brainstate.nn.Conv2d(in_size=(32, 32, 3), out_channels=64, kernel_size=3))

        for fragment in ('Conv2d', 'channel_first=False', 'in_channels=3', 'out_channels=64'):
            with self.subTest(fragment=fragment):
                self.assertIn(fragment, text)

    def test_conv_repr_channels_first(self):
        """The repr of a channels-first layer reports channel_first=True."""
        text = repr(
            brainstate.nn.Conv2d(in_size=(3, 32, 32), out_channels=64, kernel_size=3, channel_first=True)
        )

        for fragment in ('Conv2d', 'channel_first=True'):
            with self.subTest(fragment=fragment):
                self.assertIn(fragment, text)
412
+
413
+
414
class TestConvTranspose1d(unittest.TestCase):
    """Test cases for the ConvTranspose1d layer."""

    def setUp(self):
        """Shared defaults: a length-28, 16-channel (channels-last) input mapped to 8 channels."""
        self.in_size = (28, 16)
        self.out_channels = 8
        self.kernel_size = 4

    def test_basic_channels_last(self):
        """Test basic ConvTranspose1d with channels-last format."""
        conv_t = brainstate.nn.ConvTranspose1d(
            in_size=self.in_size,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=1,
        )
        x = jnp.ones((2, 28, 16))
        y = conv_t(x)

        self.assertEqual(len(y.shape), 3)
        self.assertEqual(y.shape[0], 2)  # batch size
        self.assertEqual(y.shape[-1], self.out_channels)
        self.assertEqual(conv_t.in_channels, 16)
        self.assertEqual(conv_t.out_channels, 8)
        self.assertFalse(conv_t.channel_first)

    def test_basic_channels_first(self):
        """Test basic ConvTranspose1d with channels-first format."""
        conv_t = brainstate.nn.ConvTranspose1d(
            in_size=(16, 28),  # (C, L) for channels-first
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=1,
            channel_first=True,
        )
        x = jnp.ones((2, 16, 28))
        y = conv_t(x)

        self.assertEqual(len(y.shape), 3)
        self.assertEqual(y.shape[0], 2)  # batch size
        self.assertEqual(y.shape[1], self.out_channels)  # channels first
        self.assertEqual(conv_t.in_channels, 16)
        self.assertTrue(conv_t.channel_first)

    def test_stride_upsampling(self):
        """Test transposed convolution with stride for upsampling."""
        conv_t = brainstate.nn.ConvTranspose1d(
            in_size=(28, 16),
            out_channels=8,
            kernel_size=4,
            stride=2,
            padding='SAME',
        )
        x = jnp.ones((2, 28, 16))
        y = conv_t(x)

        # With stride=2, the output length should exceed the input length.
        self.assertGreater(y.shape[1], x.shape[1])

    def test_with_bias(self):
        """Test ConvTranspose1d with bias."""
        conv_t = brainstate.nn.ConvTranspose1d(
            in_size=(50, 8),
            out_channels=16,
            kernel_size=3,
            b_init=brainstate.init.Constant(0.0),
        )
        x = jnp.ones((4, 50, 8))
        y = conv_t(x)

        # assertIn reports the container contents on failure, unlike assertTrue(... in ...).
        self.assertIn('bias', conv_t.weight.value)
        self.assertEqual(y.shape[-1], 16)

    def test_without_batch(self):
        """Test ConvTranspose1d without batch dimension."""
        conv_t = brainstate.nn.ConvTranspose1d(
            in_size=(28, 16),
            out_channels=8,
            kernel_size=4,
        )
        x = jnp.ones((28, 16))
        y = conv_t(x)

        self.assertEqual(len(y.shape), 2)
        self.assertEqual(y.shape[-1], 8)

    def test_groups(self):
        """Test grouped transposed convolution."""
        conv_t = brainstate.nn.ConvTranspose1d(
            in_size=(28, 16),
            out_channels=16,
            kernel_size=3,
            groups=4,
        )
        x = jnp.ones((2, 28, 16))
        y = conv_t(x)

        self.assertEqual(y.shape[-1], 16)
513
+
514
+
515
class TestConvTranspose2d(unittest.TestCase):
    """Shape and configuration checks for ``brainstate.nn.ConvTranspose2d``."""

    def setUp(self):
        """Shared defaults: in_size (16, 16, 32) mapped to 16 channels by a size-4 kernel."""
        self.in_size = (16, 16, 32)
        self.out_channels = 16
        self.kernel_size = 4

    def test_basic_channels_last(self):
        """In channels-last layout, the batch is kept and out_channels land on the last axis."""
        layer = brainstate.nn.ConvTranspose2d(
            in_size=self.in_size,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
        )
        out = layer(jnp.ones((4, 16, 16, 32)))

        self.assertEqual(len(out.shape), 4)
        self.assertEqual(out.shape[0], 4)  # batch preserved
        self.assertEqual(out.shape[-1], self.out_channels)
        self.assertEqual(layer.in_channels, 32)
        self.assertFalse(layer.channel_first)

    def test_basic_channels_first(self):
        """In channels-first layout, out_channels land on axis 1."""
        layer = brainstate.nn.ConvTranspose2d(
            in_size=(32, 16, 16),  # (C, H, W)
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            channel_first=True,
        )
        out = layer(jnp.ones((4, 32, 16, 16)))

        self.assertEqual(len(out.shape), 4)
        self.assertEqual(out.shape[1], self.out_channels)
        self.assertTrue(layer.channel_first)

    def test_stride_upsampling(self):
        """Stride 2 with SAME padding enlarges both spatial axes."""
        layer = brainstate.nn.ConvTranspose2d(
            in_size=(16, 16, 32), out_channels=16, kernel_size=4, stride=2, padding='SAME',
        )
        x = jnp.ones((4, 16, 16, 32))
        out = layer(x)

        for axis in (1, 2):
            with self.subTest(axis=axis):
                self.assertGreater(out.shape[axis], x.shape[axis])

    def test_rectangular_kernel(self):
        """A (3, 5) kernel is stored as given and yields the requested channels."""
        layer = brainstate.nn.ConvTranspose2d(
            in_size=(16, 16, 32), out_channels=16, kernel_size=(3, 5), stride=1,
        )
        out = layer(jnp.ones((2, 16, 16, 32)))

        self.assertEqual(layer.kernel_size, (3, 5))
        self.assertEqual(out.shape[-1], 16)

    def test_padding_valid(self):
        """VALID padding with stride 2 upsamples beyond the 16-wide input.

        The textbook size is (in - 1) * stride + kernel = (16 - 1) * 2 + 4 = 34.
        Only strict growth is asserted here.
        """
        layer = brainstate.nn.ConvTranspose2d(
            in_size=(16, 16, 32), out_channels=16, kernel_size=4, stride=2, padding='VALID',
        )
        out = layer(jnp.ones((2, 16, 16, 32)))
        self.assertGreater(out.shape[1], 16)

    def test_groups(self):
        """groups=4 still yields the requested 32 output channels."""
        layer = brainstate.nn.ConvTranspose2d(
            in_size=(16, 16, 32), out_channels=32, kernel_size=3, groups=4,
        )
        self.assertEqual(layer(jnp.ones((2, 16, 16, 32))).shape[-1], 32)
615
+
616
+
617
class TestConvTranspose3d(unittest.TestCase):
    """Shape and configuration checks for ``brainstate.nn.ConvTranspose3d``."""

    def setUp(self):
        """Shared defaults: in_size (8, 8, 8, 16) mapped to 8 channels by a size-4 kernel."""
        self.in_size = (8, 8, 8, 16)
        self.out_channels = 8
        self.kernel_size = 4

    def test_basic_channels_last(self):
        """In channels-last layout, the batch is kept and out_channels land on the last axis."""
        layer = brainstate.nn.ConvTranspose3d(
            in_size=self.in_size,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
        )
        out = layer(jnp.ones((2, 8, 8, 8, 16)))

        self.assertEqual(len(out.shape), 5)
        self.assertEqual(out.shape[0], 2)  # batch preserved
        self.assertEqual(out.shape[-1], self.out_channels)
        self.assertEqual(layer.in_channels, 16)

    def test_basic_channels_first(self):
        """In channels-first layout, out_channels land on axis 1."""
        layer = brainstate.nn.ConvTranspose3d(
            in_size=(16, 8, 8, 8),  # channels first, then the three spatial axes
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            channel_first=True,
        )
        out = layer(jnp.ones((2, 16, 8, 8, 8)))

        self.assertEqual(len(out.shape), 5)
        self.assertEqual(out.shape[1], self.out_channels)
        self.assertTrue(layer.channel_first)

    def test_stride_upsampling(self):
        """Stride 2 with SAME padding enlarges all three spatial axes."""
        layer = brainstate.nn.ConvTranspose3d(
            in_size=(8, 8, 8, 16), out_channels=8, kernel_size=4, stride=2, padding='SAME',
        )
        x = jnp.ones((2, 8, 8, 8, 16))
        out = layer(x)

        for axis in (1, 2, 3):
            with self.subTest(axis=axis):
                self.assertGreater(out.shape[axis], x.shape[axis])
672
+
673
+
674
class TestErrorHandlingConvTranspose(unittest.TestCase):
    """Error-path checks for the transposed convolution layers."""

    def test_invalid_groups(self):
        """An out_channels count that ``groups`` cannot divide trips an AssertionError."""
        # 15 output channels cannot be split evenly across 4 groups.
        bad_config = dict(in_size=(16, 16, 32), out_channels=15, kernel_size=3, groups=4)
        with self.assertRaises(AssertionError):
            brainstate.nn.ConvTranspose2d(**bad_config)

    def test_dimension_mismatch(self):
        """An input whose channel axis disagrees with ``in_size`` raises ValueError."""
        layer = brainstate.nn.ConvTranspose2d(in_size=(16, 16, 32), out_channels=16, kernel_size=3)
        # The layer was declared with 32 input channels; feed it 16.
        mismatched = jnp.ones((2, 16, 16, 16))
        self.assertRaises(ValueError, layer, mismatched)

    def test_invalid_input_shape(self):
        """An input with more leading axes than a 1D layer accepts raises ValueError."""
        layer = brainstate.nn.ConvTranspose1d(in_size=(28, 16), out_channels=8, kernel_size=3)
        # (batch, extra, length, channels) is one axis too many.
        too_deep = jnp.ones((2, 2, 28, 16))
        self.assertRaises(ValueError, layer, too_deep)
710
+
711
+
712
class TestOutputShapesConvTranspose(unittest.TestCase):
    """Checks on the output geometry of transposed convolutions."""

    def test_out_size_attribute_1d(self):
        """ConvTranspose1d reports a two-entry ``out_size``."""
        layer = brainstate.nn.ConvTranspose1d(
            in_size=(28, 16), out_channels=8, kernel_size=4, stride=2,
        )
        self.assertIsNotNone(layer.out_size)
        self.assertEqual(len(layer.out_size), 2)

    def test_out_size_attribute_2d(self):
        """ConvTranspose2d reports a three-entry ``out_size``."""
        layer = brainstate.nn.ConvTranspose2d(
            in_size=(16, 16, 32), out_channels=16, kernel_size=4, stride=2,
        )
        self.assertIsNotNone(layer.out_size)
        self.assertEqual(len(layer.out_size), 3)

    def test_upsampling_factor(self):
        """With stride 2 and SAME padding, each 16-wide spatial axis grows to at least 28."""
        layer = brainstate.nn.ConvTranspose2d(
            in_size=(16, 16, 32), out_channels=16, kernel_size=4, stride=2, padding='SAME',
        )
        out = layer(jnp.ones((2, 16, 16, 32)))

        # Axes 1 and 2 are the spatial axes in channels-last layout.
        for axis in (1, 2):
            with self.subTest(axis=axis):
                self.assertGreaterEqual(out.shape[axis], 28)
754
+
755
+
756
class TestChannelFormatConsistencyConvTranspose(unittest.TestCase):
    """Output-channel placement for channels-last vs. channels-first transposed convs."""

    def test_conv_transpose_1d_output_channels(self):
        """Channels land on the last axis (channels-last) or on axis 1 (channels-first)."""
        # Channels-last: in_size is (length, channels).
        last_fmt = brainstate.nn.ConvTranspose1d(in_size=(28, 16), out_channels=8, kernel_size=3)
        self.assertEqual(last_fmt(jnp.ones((2, 28, 16))).shape[-1], 8)

        # Channels-first: in_size is (channels, length).
        first_fmt = brainstate.nn.ConvTranspose1d(
            in_size=(16, 28), out_channels=8, kernel_size=3, channel_first=True,
        )
        self.assertEqual(first_fmt(jnp.ones((2, 16, 28))).shape[1], 8)

    def test_conv_transpose_2d_output_channels(self):
        """2D variant of the channel-placement check."""
        # Channels-last: in_size is (H, W, C).
        last_fmt = brainstate.nn.ConvTranspose2d(in_size=(16, 16, 32), out_channels=16, kernel_size=3)
        self.assertEqual(last_fmt(jnp.ones((2, 16, 16, 32))).shape[-1], 16)

        # Channels-first: in_size is (C, H, W).
        first_fmt = brainstate.nn.ConvTranspose2d(
            in_size=(32, 16, 16), out_channels=16, kernel_size=3, channel_first=True,
        )
        self.assertEqual(first_fmt(jnp.ones((2, 32, 16, 16))).shape[1], 16)
804
+
805
+
806
class TestReproducibilityConvTranspose(unittest.TestCase):
    """Determinism of repeated forward passes through one transposed-conv layer."""

    def test_deterministic_output(self):
        """Calling the same layer twice on the same input gives matching results."""
        layer = brainstate.nn.ConvTranspose2d(in_size=(16, 16, 32), out_channels=16, kernel_size=3)
        inputs = jnp.ones((2, 16, 16, 32))

        first_pass, second_pass = layer(inputs), layer(inputs)

        self.assertTrue(jnp.allclose(first_pass, second_pass))
822
+
823
+
824
class TestKernelShapeConvTranspose(unittest.TestCase):
    """Kernel-shape bookkeeping for grouped transposed convolutions."""

    def test_kernel_shape_1d(self):
        """1D layout is (kernel, out_channels, in_channels // groups)."""
        in_channels, groups = 16, 2
        layer = brainstate.nn.ConvTranspose1d(
            in_size=(28, in_channels), out_channels=8, kernel_size=4, groups=groups,
        )
        self.assertEqual(layer.kernel_shape, (4, 8, in_channels // groups))

    def test_kernel_shape_2d(self):
        """2D layout is (kernel_h, kernel_w, out_channels, in_channels // groups)."""
        in_channels, groups = 32, 4
        layer = brainstate.nn.ConvTranspose2d(
            in_size=(16, 16, in_channels), out_channels=16, kernel_size=4, groups=groups,
        )
        self.assertEqual(layer.kernel_shape, (4, 4, 16, in_channels // groups))