brainstate 0.1.0.post20250503__py2.py3-none-any.whl → 0.1.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111) hide show
  1. brainstate/__init__.py +1 -1
  2. brainstate/_compatible_import.py +10 -3
  3. brainstate/_state.py +178 -178
  4. brainstate/_utils.py +0 -1
  5. brainstate/augment/_autograd.py +0 -2
  6. brainstate/augment/_autograd_test.py +132 -133
  7. brainstate/augment/_eval_shape.py +0 -2
  8. brainstate/augment/_eval_shape_test.py +7 -9
  9. brainstate/augment/_mapping.py +2 -3
  10. brainstate/augment/_mapping_test.py +75 -76
  11. brainstate/augment/_random.py +0 -2
  12. brainstate/compile/_ad_checkpoint.py +0 -2
  13. brainstate/compile/_ad_checkpoint_test.py +6 -8
  14. brainstate/compile/_conditions.py +0 -2
  15. brainstate/compile/_conditions_test.py +35 -36
  16. brainstate/compile/_error_if.py +0 -2
  17. brainstate/compile/_error_if_test.py +10 -13
  18. brainstate/compile/_jit.py +9 -8
  19. brainstate/compile/_loop_collect_return.py +0 -2
  20. brainstate/compile/_loop_collect_return_test.py +7 -9
  21. brainstate/compile/_loop_no_collection.py +0 -2
  22. brainstate/compile/_loop_no_collection_test.py +7 -8
  23. brainstate/compile/_make_jaxpr.py +30 -17
  24. brainstate/compile/_make_jaxpr_test.py +20 -20
  25. brainstate/compile/_progress_bar.py +0 -1
  26. brainstate/compile/_unvmap.py +0 -1
  27. brainstate/compile/_util.py +0 -2
  28. brainstate/environ.py +0 -2
  29. brainstate/functional/_activations.py +0 -2
  30. brainstate/functional/_activations_test.py +61 -61
  31. brainstate/functional/_normalization.py +0 -2
  32. brainstate/functional/_others.py +0 -2
  33. brainstate/functional/_spikes.py +0 -1
  34. brainstate/graph/_graph_node.py +1 -3
  35. brainstate/graph/_graph_node_test.py +16 -18
  36. brainstate/graph/_graph_operation.py +4 -2
  37. brainstate/graph/_graph_operation_test.py +154 -156
  38. brainstate/init/_base.py +0 -2
  39. brainstate/init/_generic.py +0 -1
  40. brainstate/init/_random_inits.py +0 -1
  41. brainstate/init/_random_inits_test.py +20 -21
  42. brainstate/init/_regular_inits.py +0 -2
  43. brainstate/init/_regular_inits_test.py +4 -5
  44. brainstate/mixin.py +0 -2
  45. brainstate/nn/_collective_ops.py +0 -3
  46. brainstate/nn/_collective_ops_test.py +8 -8
  47. brainstate/nn/_common.py +0 -2
  48. brainstate/nn/_dyn_impl/_dynamics_neuron.py +0 -2
  49. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +18 -19
  50. brainstate/nn/_dyn_impl/_dynamics_synapse.py +0 -1
  51. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +9 -10
  52. brainstate/nn/_dyn_impl/_inputs.py +0 -1
  53. brainstate/nn/_dyn_impl/_rate_rnns.py +0 -1
  54. brainstate/nn/_dyn_impl/_rate_rnns_test.py +6 -7
  55. brainstate/nn/_dyn_impl/_readout.py +0 -1
  56. brainstate/nn/_dyn_impl/_readout_test.py +9 -10
  57. brainstate/nn/_dynamics/_dynamics_base.py +0 -1
  58. brainstate/nn/_dynamics/_dynamics_base_test.py +14 -15
  59. brainstate/nn/_dynamics/_projection_base.py +0 -1
  60. brainstate/nn/_dynamics/_state_delay.py +0 -2
  61. brainstate/nn/_dynamics/_synouts.py +0 -2
  62. brainstate/nn/_dynamics/_synouts_test.py +4 -5
  63. brainstate/nn/_elementwise/_dropout.py +0 -2
  64. brainstate/nn/_elementwise/_dropout_test.py +9 -9
  65. brainstate/nn/_elementwise/_elementwise.py +0 -2
  66. brainstate/nn/_elementwise/_elementwise_test.py +57 -59
  67. brainstate/nn/_event/_fixedprob_mv.py +0 -1
  68. brainstate/nn/_event/_fixedprob_mv_test.py +0 -1
  69. brainstate/nn/_event/_linear_mv.py +0 -2
  70. brainstate/nn/_event/_linear_mv_test.py +0 -1
  71. brainstate/nn/_exp_euler.py +0 -2
  72. brainstate/nn/_exp_euler_test.py +5 -6
  73. brainstate/nn/_interaction/_conv.py +0 -2
  74. brainstate/nn/_interaction/_conv_test.py +31 -33
  75. brainstate/nn/_interaction/_embedding.py +0 -1
  76. brainstate/nn/_interaction/_linear.py +0 -2
  77. brainstate/nn/_interaction/_linear_test.py +15 -17
  78. brainstate/nn/_interaction/_normalizations.py +0 -2
  79. brainstate/nn/_interaction/_normalizations_test.py +10 -12
  80. brainstate/nn/_interaction/_poolings.py +0 -2
  81. brainstate/nn/_interaction/_poolings_test.py +19 -21
  82. brainstate/nn/_module.py +0 -1
  83. brainstate/nn/_module_test.py +34 -37
  84. brainstate/nn/metrics.py +0 -2
  85. brainstate/optim/_base.py +0 -2
  86. brainstate/optim/_lr_scheduler.py +0 -1
  87. brainstate/optim/_lr_scheduler_test.py +3 -3
  88. brainstate/optim/_optax_optimizer.py +0 -2
  89. brainstate/optim/_optax_optimizer_test.py +8 -9
  90. brainstate/optim/_sgd_optimizer.py +0 -1
  91. brainstate/random/_rand_funs.py +0 -1
  92. brainstate/random/_rand_funs_test.py +183 -184
  93. brainstate/random/_rand_seed.py +0 -1
  94. brainstate/random/_rand_seed_test.py +10 -12
  95. brainstate/random/_rand_state.py +0 -1
  96. brainstate/surrogate.py +0 -1
  97. brainstate/typing.py +0 -2
  98. brainstate/util/_caller.py +4 -6
  99. brainstate/util/_others.py +0 -2
  100. brainstate/util/_pretty_pytree.py +201 -150
  101. brainstate/util/_pretty_repr.py +0 -2
  102. brainstate/util/_pretty_table.py +57 -3
  103. brainstate/util/_scaling.py +0 -2
  104. brainstate/util/_struct.py +0 -2
  105. brainstate/util/filter.py +0 -2
  106. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/METADATA +11 -5
  107. brainstate-0.1.2.dist-info/RECORD +133 -0
  108. brainstate-0.1.0.post20250503.dist-info/RECORD +0 -133
  109. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/LICENSE +0 -0
  110. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/WHEEL +0 -0
  111. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/top_level.txt +0 -0
@@ -1,13 +1,11 @@
1
1
  # -*- coding: utf-8 -*-
2
2
 
3
- from __future__ import annotations
4
-
5
3
  import jax.numpy as jnp
6
4
  import pytest
7
5
  from absl.testing import absltest
8
6
  from absl.testing import parameterized
9
7
 
10
- import brainstate as bst
8
+ import brainstate
11
9
 
12
10
 
13
11
  class TestConv(parameterized.TestCase):
@@ -19,8 +17,8 @@ class TestConv(parameterized.TestCase):
19
17
  img = img.at[0, x:x + 10, y:y + 10, k].set(1.0)
20
18
  img = img.at[1, x:x + 20, y:y + 20, k].set(3.0)
21
19
 
22
- net = bst.nn.Conv2d((200, 198, 4), out_channels=32, kernel_size=(3, 3),
23
- stride=(2, 1), padding='VALID', groups=4)
20
+ net = brainstate.nn.Conv2d((200, 198, 4), out_channels=32, kernel_size=(3, 3),
21
+ stride=(2, 1), padding='VALID', groups=4)
24
22
  out = net(img)
25
23
  print("out shape: ", out.shape)
26
24
  self.assertEqual(out.shape, (2, 99, 196, 32))
@@ -30,7 +28,7 @@ class TestConv(parameterized.TestCase):
30
28
  # plt.show()
31
29
 
32
30
  def test_conv1D(self):
33
- model = bst.nn.Conv1d((5, 3), out_channels=32, kernel_size=(3,))
31
+ model = brainstate.nn.Conv1d((5, 3), out_channels=32, kernel_size=(3,))
34
32
  input = jnp.ones((2, 5, 3))
35
33
  out = model(input)
36
34
  print("out shape: ", out.shape)
@@ -41,7 +39,7 @@ class TestConv(parameterized.TestCase):
41
39
  # plt.show()
42
40
 
43
41
  def test_conv2D(self):
44
- model = bst.nn.Conv2d((5, 5, 3), out_channels=32, kernel_size=(3, 3))
42
+ model = brainstate.nn.Conv2d((5, 5, 3), out_channels=32, kernel_size=(3, 3))
45
43
  input = jnp.ones((2, 5, 5, 3))
46
44
 
47
45
  out = model(input)
@@ -49,7 +47,7 @@ class TestConv(parameterized.TestCase):
49
47
  self.assertEqual(out.shape, (2, 5, 5, 32))
50
48
 
51
49
  def test_conv3D(self):
52
- model = bst.nn.Conv3d((5, 5, 5, 3), out_channels=32, kernel_size=(3, 3, 3))
50
+ model = brainstate.nn.Conv3d((5, 5, 5, 3), out_channels=32, kernel_size=(3, 3, 3))
53
51
  input = jnp.ones((2, 5, 5, 5, 3))
54
52
  out = model(input)
55
53
  print("out shape: ", out.shape)
@@ -62,13 +60,13 @@ class TestConvTranspose1d(parameterized.TestCase):
62
60
 
63
61
  x = jnp.ones((1, 8, 3))
64
62
  for use_bias in [True, False]:
65
- conv_transpose_module = bst.nn.ConvTranspose1d(
63
+ conv_transpose_module = brainstate.nn.ConvTranspose1d(
66
64
  in_channels=3,
67
65
  out_channels=4,
68
66
  kernel_size=(3,),
69
67
  padding='VALID',
70
- w_initializer=bst.init.Constant(1.),
71
- b_initializer=bst.init.Constant(1.) if use_bias else None,
68
+ w_initializer=brainstate.init.Constant(1.),
69
+ b_initializer=brainstate.init.Constant(1.) if use_bias else None,
72
70
  )
73
71
  self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4))
74
72
  y = conv_transpose_module(x)
@@ -91,14 +89,14 @@ class TestConvTranspose1d(parameterized.TestCase):
91
89
 
92
90
  x = jnp.ones((1, 8, 3))
93
91
  m = jnp.tril(jnp.ones((3, 3, 4)))
94
- conv_transpose_module = bst.nn.ConvTranspose1d(
92
+ conv_transpose_module = brainstate.nn.ConvTranspose1d(
95
93
  in_channels=3,
96
94
  out_channels=4,
97
95
  kernel_size=(3,),
98
96
  padding='VALID',
99
97
  mask=m,
100
- w_initializer=bst.init.Constant(),
101
- b_initializer=bst.init.Constant(),
98
+ w_initializer=brainstate.init.Constant(),
99
+ b_initializer=brainstate.init.Constant(),
102
100
  )
103
101
  self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4))
104
102
  y = conv_transpose_module(x)
@@ -119,14 +117,14 @@ class TestConvTranspose1d(parameterized.TestCase):
119
117
 
120
118
  data = jnp.ones([1, 3, 1])
121
119
  for use_bias in [True, False]:
122
- net = bst.nn.ConvTranspose1d(
120
+ net = brainstate.nn.ConvTranspose1d(
123
121
  in_channels=1,
124
122
  out_channels=1,
125
123
  kernel_size=3,
126
124
  stride=1,
127
125
  padding="SAME",
128
- w_initializer=bst.init.Constant(),
129
- b_initializer=bst.init.Constant() if use_bias else None,
126
+ w_initializer=brainstate.init.Constant(),
127
+ b_initializer=brainstate.init.Constant() if use_bias else None,
130
128
  )
131
129
  out = net(data)
132
130
  self.assertEqual(out.shape, (1, 3, 1))
@@ -143,13 +141,13 @@ class TestConvTranspose2d(parameterized.TestCase):
143
141
 
144
142
  x = jnp.ones((1, 8, 8, 3))
145
143
  for use_bias in [True, False]:
146
- conv_transpose_module = bst.nn.ConvTranspose2d(
144
+ conv_transpose_module = brainstate.nn.ConvTranspose2d(
147
145
  in_channels=3,
148
146
  out_channels=4,
149
147
  kernel_size=(3, 3),
150
148
  padding='VALID',
151
- w_initializer=bst.init.Constant(),
152
- b_initializer=bst.init.Constant() if use_bias else None,
149
+ w_initializer=brainstate.init.Constant(),
150
+ b_initializer=brainstate.init.Constant() if use_bias else None,
153
151
  )
154
152
  self.assertEqual(conv_transpose_module.w.shape, (3, 3, 3, 4))
155
153
  y = conv_transpose_module(x)
@@ -159,13 +157,13 @@ class TestConvTranspose2d(parameterized.TestCase):
159
157
 
160
158
  x = jnp.ones((1, 8, 8, 3))
161
159
  m = jnp.tril(jnp.ones((3, 3, 3, 4)))
162
- conv_transpose_module = bst.nn.ConvTranspose2d(
160
+ conv_transpose_module = brainstate.nn.ConvTranspose2d(
163
161
  in_channels=3,
164
162
  out_channels=4,
165
163
  kernel_size=(3, 3),
166
164
  padding='VALID',
167
165
  mask=m,
168
- w_initializer=bst.init.Constant(),
166
+ w_initializer=brainstate.init.Constant(),
169
167
  )
170
168
  y = conv_transpose_module(x)
171
169
  print(y.shape)
@@ -174,14 +172,14 @@ class TestConvTranspose2d(parameterized.TestCase):
174
172
 
175
173
  x = jnp.ones((1, 8, 8, 3))
176
174
  for use_bias in [True, False]:
177
- conv_transpose_module = bst.nn.ConvTranspose2d(
175
+ conv_transpose_module = brainstate.nn.ConvTranspose2d(
178
176
  in_channels=3,
179
177
  out_channels=4,
180
178
  kernel_size=(3, 3),
181
179
  stride=1,
182
180
  padding='SAME',
183
- w_initializer=bst.init.Constant(),
184
- b_initializer=bst.init.Constant() if use_bias else None,
181
+ w_initializer=brainstate.init.Constant(),
182
+ b_initializer=brainstate.init.Constant() if use_bias else None,
185
183
  )
186
184
  y = conv_transpose_module(x)
187
185
  print(y.shape)
@@ -193,13 +191,13 @@ class TestConvTranspose3d(parameterized.TestCase):
193
191
 
194
192
  x = jnp.ones((1, 8, 8, 8, 3))
195
193
  for use_bias in [True, False]:
196
- conv_transpose_module = bst.nn.ConvTranspose3d(
194
+ conv_transpose_module = brainstate.nn.ConvTranspose3d(
197
195
  in_channels=3,
198
196
  out_channels=4,
199
197
  kernel_size=(3, 3, 3),
200
198
  padding='VALID',
201
- w_initializer=bst.init.Constant(),
202
- b_initializer=bst.init.Constant() if use_bias else None,
199
+ w_initializer=brainstate.init.Constant(),
200
+ b_initializer=brainstate.init.Constant() if use_bias else None,
203
201
  )
204
202
  y = conv_transpose_module(x)
205
203
  print(y.shape)
@@ -208,13 +206,13 @@ class TestConvTranspose3d(parameterized.TestCase):
208
206
 
209
207
  x = jnp.ones((1, 8, 8, 8, 3))
210
208
  m = jnp.tril(jnp.ones((3, 3, 3, 3, 4)))
211
- conv_transpose_module = bst.nn.ConvTranspose3d(
209
+ conv_transpose_module = brainstate.nn.ConvTranspose3d(
212
210
  in_channels=3,
213
211
  out_channels=4,
214
212
  kernel_size=(3, 3, 3),
215
213
  padding='VALID',
216
214
  mask=m,
217
- w_initializer=bst.init.Constant(),
215
+ w_initializer=brainstate.init.Constant(),
218
216
  )
219
217
  y = conv_transpose_module(x)
220
218
  print(y.shape)
@@ -223,14 +221,14 @@ class TestConvTranspose3d(parameterized.TestCase):
223
221
 
224
222
  x = jnp.ones((1, 8, 8, 8, 3))
225
223
  for use_bias in [True, False]:
226
- conv_transpose_module = bst.nn.ConvTranspose3d(
224
+ conv_transpose_module = brainstate.nn.ConvTranspose3d(
227
225
  in_channels=3,
228
226
  out_channels=4,
229
227
  kernel_size=(3, 3, 3),
230
228
  stride=1,
231
229
  padding='SAME',
232
- w_initializer=bst.init.Constant(),
233
- b_initializer=bst.init.Constant() if use_bias else None,
230
+ w_initializer=brainstate.init.Constant(),
231
+ b_initializer=brainstate.init.Constant() if use_bias else None,
234
232
  )
235
233
  y = conv_transpose_module(x)
236
234
  print(y.shape)
@@ -12,7 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from __future__ import annotations
16
15
 
17
16
  from typing import Optional, Callable, Union
18
17
 
@@ -15,8 +15,6 @@
15
15
 
16
16
  # -*- coding: utf-8 -*-
17
17
 
18
- from __future__ import annotations
19
-
20
18
  from typing import Callable, Union, Optional
21
19
 
22
20
  import brainunit as u
@@ -14,14 +14,12 @@
14
14
  # ==============================================================================
15
15
 
16
16
 
17
- from __future__ import annotations
18
-
19
17
  import unittest
20
18
 
21
19
  import brainunit as u
22
20
  from absl.testing import parameterized
23
21
 
24
- import brainstate as bst
22
+ import brainstate
25
23
 
26
24
 
27
25
  class TestDense(parameterized.TestCase):
@@ -32,19 +30,19 @@ class TestDense(parameterized.TestCase):
32
30
  num_out=[20, ]
33
31
  )
34
32
  def test_Dense1(self, size, num_out):
35
- f = bst.nn.Linear(10, num_out)
36
- x = bst.random.random(size)
33
+ f = brainstate.nn.Linear(10, num_out)
34
+ x = brainstate.random.random(size)
37
35
  y = f(x)
38
36
  self.assertTrue(y.shape == size[:-1] + (num_out,))
39
37
 
40
38
 
41
39
  class TestSparseMatrix(unittest.TestCase):
42
40
  def test_csr(self):
43
- data = bst.random.rand(10, 20)
41
+ data = brainstate.random.rand(10, 20)
44
42
  data = data * (data > 0.9)
45
- f = bst.nn.SparseLinear(u.sparse.CSR.fromdense(data))
43
+ f = brainstate.nn.SparseLinear(u.sparse.CSR.fromdense(data))
46
44
 
47
- x = bst.random.rand(10)
45
+ x = brainstate.random.rand(10)
48
46
  y = f(x)
49
47
  self.assertTrue(
50
48
  u.math.allclose(
@@ -53,7 +51,7 @@ class TestSparseMatrix(unittest.TestCase):
53
51
  )
54
52
  )
55
53
 
56
- x = bst.random.rand(5, 10)
54
+ x = brainstate.random.rand(5, 10)
57
55
  y = f(x)
58
56
  self.assertTrue(
59
57
  u.math.allclose(
@@ -63,11 +61,11 @@ class TestSparseMatrix(unittest.TestCase):
63
61
  )
64
62
 
65
63
  def test_csc(self):
66
- data = bst.random.rand(10, 20)
64
+ data = brainstate.random.rand(10, 20)
67
65
  data = data * (data > 0.9)
68
- f = bst.nn.SparseLinear(u.sparse.CSC.fromdense(data))
66
+ f = brainstate.nn.SparseLinear(u.sparse.CSC.fromdense(data))
69
67
 
70
- x = bst.random.rand(10)
68
+ x = brainstate.random.rand(10)
71
69
  y = f(x)
72
70
  self.assertTrue(
73
71
  u.math.allclose(
@@ -76,7 +74,7 @@ class TestSparseMatrix(unittest.TestCase):
76
74
  )
77
75
  )
78
76
 
79
- x = bst.random.rand(5, 10)
77
+ x = brainstate.random.rand(5, 10)
80
78
  y = f(x)
81
79
  self.assertTrue(
82
80
  u.math.allclose(
@@ -86,11 +84,11 @@ class TestSparseMatrix(unittest.TestCase):
86
84
  )
87
85
 
88
86
  def test_coo(self):
89
- data = bst.random.rand(10, 20)
87
+ data = brainstate.random.rand(10, 20)
90
88
  data = data * (data > 0.9)
91
- f = bst.nn.SparseLinear(u.sparse.COO.fromdense(data))
89
+ f = brainstate.nn.SparseLinear(u.sparse.COO.fromdense(data))
92
90
 
93
- x = bst.random.rand(10)
91
+ x = brainstate.random.rand(10)
94
92
  y = f(x)
95
93
  self.assertTrue(
96
94
  u.math.allclose(
@@ -99,7 +97,7 @@ class TestSparseMatrix(unittest.TestCase):
99
97
  )
100
98
  )
101
99
 
102
- x = bst.random.rand(5, 10)
100
+ x = brainstate.random.rand(5, 10)
103
101
  y = f(x)
104
102
  self.assertTrue(
105
103
  u.math.allclose(
@@ -15,8 +15,6 @@
15
15
 
16
16
  # -*- coding: utf-8 -*-
17
17
 
18
- from __future__ import annotations
19
-
20
18
  from typing import Callable, Union, Sequence, Optional, Any
21
19
 
22
20
  import jax
@@ -13,12 +13,10 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  from absl.testing import absltest
19
17
  from absl.testing import parameterized
20
18
 
21
- import brainstate as bst
19
+ import brainstate
22
20
 
23
21
 
24
22
  class Test_Normalization(parameterized.TestCase):
@@ -26,27 +24,27 @@ class Test_Normalization(parameterized.TestCase):
26
24
  fit=[True, False],
27
25
  )
28
26
  def test_BatchNorm1d(self, fit):
29
- net = bst.nn.BatchNorm1d((3, 10))
30
- bst.environ.set(fit=fit)
31
- input = bst.random.randn(1, 3, 10)
27
+ net = brainstate.nn.BatchNorm1d((3, 10))
28
+ brainstate.environ.set(fit=fit)
29
+ input = brainstate.random.randn(1, 3, 10)
32
30
  output = net(input)
33
31
 
34
32
  @parameterized.product(
35
33
  fit=[True, False]
36
34
  )
37
35
  def test_BatchNorm2d(self, fit):
38
- net = bst.nn.BatchNorm2d([3, 4, 10])
39
- bst.environ.set(fit=fit)
40
- input = bst.random.randn(1, 3, 4, 10)
36
+ net = brainstate.nn.BatchNorm2d([3, 4, 10])
37
+ brainstate.environ.set(fit=fit)
38
+ input = brainstate.random.randn(1, 3, 4, 10)
41
39
  output = net(input)
42
40
 
43
41
  @parameterized.product(
44
42
  fit=[True, False]
45
43
  )
46
44
  def test_BatchNorm3d(self, fit):
47
- net = bst.nn.BatchNorm3d([3, 4, 5, 10])
48
- bst.environ.set(fit=fit)
49
- input = bst.random.randn(1, 3, 4, 5, 10)
45
+ net = brainstate.nn.BatchNorm3d([3, 4, 5, 10])
46
+ brainstate.environ.set(fit=fit)
47
+ input = brainstate.random.randn(1, 3, 4, 5, 10)
50
48
  output = net(input)
51
49
 
52
50
  # @parameterized.product(
@@ -15,8 +15,6 @@
15
15
 
16
16
  # -*- coding: utf-8 -*-
17
17
 
18
- from __future__ import annotations
19
-
20
18
  import functools
21
19
  from typing import Sequence, Optional
22
20
  from typing import Union, Tuple, Callable, List
@@ -1,13 +1,11 @@
1
1
  # -*- coding: utf-8 -*-
2
2
 
3
- from __future__ import annotations
4
-
5
3
  import jax
6
4
  import numpy as np
7
5
  from absl.testing import absltest
8
6
  from absl.testing import parameterized
9
7
 
10
- import brainstate as bst
8
+ import brainstate
11
9
  import brainstate.nn as nn
12
10
 
13
11
 
@@ -18,7 +16,7 @@ class TestFlatten(parameterized.TestCase):
18
16
  (32, 8),
19
17
  (10, 20, 30),
20
18
  ]:
21
- arr = bst.random.rand(*size)
19
+ arr = brainstate.random.rand(*size)
22
20
  f = nn.Flatten(start_axis=0)
23
21
  out = f(arr)
24
22
  self.assertTrue(out.shape == (np.prod(size),))
@@ -29,21 +27,21 @@ class TestFlatten(parameterized.TestCase):
29
27
  (32, 8),
30
28
  (10, 20, 30),
31
29
  ]:
32
- arr = bst.random.rand(*size)
30
+ arr = brainstate.random.rand(*size)
33
31
  f = nn.Flatten(start_axis=1)
34
32
  out = f(arr)
35
33
  self.assertTrue(out.shape == (size[0], np.prod(size[1:])))
36
34
 
37
35
  def test_flatten3(self):
38
36
  size = (16, 32, 32, 8)
39
- arr = bst.random.rand(*size)
37
+ arr = brainstate.random.rand(*size)
40
38
  f = nn.Flatten(start_axis=0, in_size=(32, 8))
41
39
  out = f(arr)
42
40
  self.assertTrue(out.shape == (16, 32, 32 * 8))
43
41
 
44
42
  def test_flatten4(self):
45
43
  size = (16, 32, 32, 8)
46
- arr = bst.random.rand(*size)
44
+ arr = brainstate.random.rand(*size)
47
45
  f = nn.Flatten(start_axis=1, in_size=(32, 32, 8))
48
46
  out = f(arr)
49
47
  self.assertTrue(out.shape == (16, 32, 32 * 8))
@@ -58,7 +56,7 @@ class TestPool(parameterized.TestCase):
58
56
  super().__init__(*args, **kwargs)
59
57
 
60
58
  def test_MaxPool2d_v1(self):
61
- arr = bst.random.rand(16, 32, 32, 8)
59
+ arr = brainstate.random.rand(16, 32, 32, 8)
62
60
 
63
61
  out = nn.MaxPool2d(2, 2, channel_axis=-1)(arr)
64
62
  self.assertTrue(out.shape == (16, 16, 16, 8))
@@ -79,7 +77,7 @@ class TestPool(parameterized.TestCase):
79
77
  self.assertTrue(out.shape == (16, 17, 32, 5))
80
78
 
81
79
  def test_AvgPool2d_v1(self):
82
- arr = bst.random.rand(16, 32, 32, 8)
80
+ arr = brainstate.random.rand(16, 32, 32, 8)
83
81
 
84
82
  out = nn.AvgPool2d(2, 2, channel_axis=-1)(arr)
85
83
  self.assertTrue(out.shape == (16, 16, 16, 8))
@@ -107,7 +105,7 @@ class TestPool(parameterized.TestCase):
107
105
  def test_adaptive_pool1d(self, target_size):
108
106
  from brainstate.nn._interaction._poolings import _adaptive_pool1d
109
107
 
110
- arr = bst.random.rand(100)
108
+ arr = brainstate.random.rand(100)
111
109
  op = jax.numpy.mean
112
110
 
113
111
  out = _adaptive_pool1d(arr, target_size, op)
@@ -119,7 +117,7 @@ class TestPool(parameterized.TestCase):
119
117
  self.assertTrue(out.shape == (target_size,))
120
118
 
121
119
  def test_AdaptiveAvgPool2d_v1(self):
122
- input = bst.random.randn(64, 8, 9)
120
+ input = brainstate.random.randn(64, 8, 9)
123
121
 
124
122
  output = nn.AdaptiveAvgPool2d((5, 7), channel_axis=0)(input)
125
123
  self.assertTrue(output.shape == (64, 5, 7))
@@ -137,8 +135,8 @@ class TestPool(parameterized.TestCase):
137
135
  self.assertTrue(output.shape == (64, 2, 3))
138
136
 
139
137
  def test_AdaptiveAvgPool2d_v2(self):
140
- bst.random.seed()
141
- input = bst.random.randn(128, 64, 32, 16)
138
+ brainstate.random.seed()
139
+ input = brainstate.random.randn(128, 64, 32, 16)
142
140
 
143
141
  output = nn.AdaptiveAvgPool2d((5, 7), channel_axis=0)(input)
144
142
  self.assertTrue(output.shape == (128, 64, 5, 7))
@@ -154,13 +152,13 @@ class TestPool(parameterized.TestCase):
154
152
  print()
155
153
 
156
154
  def test_AdaptiveAvgPool3d_v1(self):
157
- input = bst.random.randn(10, 128, 64, 32)
155
+ input = brainstate.random.randn(10, 128, 64, 32)
158
156
  net = nn.AdaptiveAvgPool3d(target_size=[6, 5, 3], channel_axis=0)
159
157
  output = net(input)
160
158
  self.assertTrue(output.shape == (10, 6, 5, 3))
161
159
 
162
160
  def test_AdaptiveAvgPool3d_v2(self):
163
- input = bst.random.randn(10, 20, 128, 64, 32)
161
+ input = brainstate.random.randn(10, 20, 128, 64, 32)
164
162
  net = nn.AdaptiveAvgPool3d(target_size=[6, 5, 3])
165
163
  output = net(input)
166
164
  self.assertTrue(output.shape == (10, 6, 5, 3, 32))
@@ -169,7 +167,7 @@ class TestPool(parameterized.TestCase):
169
167
  axis=(-1, 0, 1)
170
168
  )
171
169
  def test_AdaptiveMaxPool1d_v1(self, axis):
172
- input = bst.random.randn(32, 16)
170
+ input = brainstate.random.randn(32, 16)
173
171
  net = nn.AdaptiveMaxPool1d(target_size=4, channel_axis=axis)
174
172
  output = net(input)
175
173
 
@@ -177,7 +175,7 @@ class TestPool(parameterized.TestCase):
177
175
  axis=(-1, 0, 1, 2)
178
176
  )
179
177
  def test_AdaptiveMaxPool1d_v2(self, axis):
180
- input = bst.random.randn(2, 32, 16)
178
+ input = brainstate.random.randn(2, 32, 16)
181
179
  net = nn.AdaptiveMaxPool1d(target_size=4, channel_axis=axis)
182
180
  output = net(input)
183
181
 
@@ -185,7 +183,7 @@ class TestPool(parameterized.TestCase):
185
183
  axis=(-1, 0, 1, 2)
186
184
  )
187
185
  def test_AdaptiveMaxPool2d_v1(self, axis):
188
- input = bst.random.randn(32, 16, 12)
186
+ input = brainstate.random.randn(32, 16, 12)
189
187
  net = nn.AdaptiveAvgPool2d(target_size=[5, 4], channel_axis=axis)
190
188
  output = net(input)
191
189
 
@@ -193,7 +191,7 @@ class TestPool(parameterized.TestCase):
193
191
  axis=(-1, 0, 1, 2, 3)
194
192
  )
195
193
  def test_AdaptiveMaxPool2d_v2(self, axis):
196
- input = bst.random.randn(2, 32, 16, 12)
194
+ input = brainstate.random.randn(2, 32, 16, 12)
197
195
  net = nn.AdaptiveAvgPool2d(target_size=[5, 4], channel_axis=axis)
198
196
  output = net(input)
199
197
 
@@ -201,7 +199,7 @@ class TestPool(parameterized.TestCase):
201
199
  axis=(-1, 0, 1, 2, 3)
202
200
  )
203
201
  def test_AdaptiveMaxPool3d_v1(self, axis):
204
- input = bst.random.randn(2, 128, 64, 32)
202
+ input = brainstate.random.randn(2, 128, 64, 32)
205
203
  net = nn.AdaptiveMaxPool3d(target_size=[6, 5, 4], channel_axis=axis)
206
204
  output = net(input)
207
205
  print()
@@ -210,7 +208,7 @@ class TestPool(parameterized.TestCase):
210
208
  axis=(-1, 0, 1, 2, 3, 4)
211
209
  )
212
210
  def test_AdaptiveMaxPool3d_v1(self, axis):
213
- input = bst.random.randn(2, 128, 64, 32, 16)
211
+ input = brainstate.random.randn(2, 128, 64, 32, 16)
214
212
  net = nn.AdaptiveMaxPool3d(target_size=[6, 5, 4], channel_axis=axis)
215
213
  output = net(input)
216
214
 
brainstate/nn/_module.py CHANGED
@@ -25,7 +25,6 @@ The basic classes include:
25
25
  - ``Sequential``: The class for a sequential of modules, which update the modules sequentially.
26
26
 
27
27
  """
28
- from __future__ import annotations
29
28
 
30
29
  import warnings
31
30
  from typing import Sequence, Optional, Tuple, Union, TYPE_CHECKING, Callable