brainstate 0.1.1__py2.py3-none-any.whl → 0.1.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. brainstate/__init__.py +1 -1
  2. brainstate/_compatible_import.py +3 -0
  3. brainstate/_state.py +1 -1
  4. brainstate/augment/_autograd_test.py +132 -133
  5. brainstate/augment/_eval_shape_test.py +7 -9
  6. brainstate/augment/_mapping_test.py +75 -76
  7. brainstate/compile/_ad_checkpoint_test.py +6 -8
  8. brainstate/compile/_conditions_test.py +35 -36
  9. brainstate/compile/_error_if_test.py +10 -13
  10. brainstate/compile/_loop_collect_return_test.py +7 -9
  11. brainstate/compile/_loop_no_collection_test.py +7 -8
  12. brainstate/compile/_make_jaxpr.py +29 -14
  13. brainstate/compile/_make_jaxpr_test.py +20 -20
  14. brainstate/functional/_activations_test.py +61 -61
  15. brainstate/graph/_graph_node_test.py +16 -18
  16. brainstate/graph/_graph_operation_test.py +154 -156
  17. brainstate/init/_random_inits_test.py +20 -21
  18. brainstate/init/_regular_inits_test.py +4 -5
  19. brainstate/nn/_collective_ops_test.py +8 -8
  20. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +18 -19
  21. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +9 -10
  22. brainstate/nn/_dyn_impl/_rate_rnns_test.py +6 -7
  23. brainstate/nn/_dyn_impl/_readout_test.py +9 -10
  24. brainstate/nn/_dynamics/_dynamics_base_test.py +14 -15
  25. brainstate/nn/_dynamics/_synouts_test.py +4 -5
  26. brainstate/nn/_elementwise/_dropout_test.py +9 -9
  27. brainstate/nn/_elementwise/_elementwise_test.py +57 -59
  28. brainstate/nn/_event/_fixedprob_mv_test.py +0 -1
  29. brainstate/nn/_event/_linear_mv_test.py +0 -1
  30. brainstate/nn/_exp_euler_test.py +5 -6
  31. brainstate/nn/_interaction/_conv_test.py +31 -33
  32. brainstate/nn/_interaction/_linear_test.py +15 -17
  33. brainstate/nn/_interaction/_normalizations_test.py +10 -12
  34. brainstate/nn/_interaction/_poolings_test.py +19 -21
  35. brainstate/nn/_module_test.py +34 -37
  36. brainstate/optim/_lr_scheduler_test.py +3 -3
  37. brainstate/optim/_optax_optimizer_test.py +8 -9
  38. brainstate/random/_rand_funs_test.py +183 -184
  39. brainstate/random/_rand_seed_test.py +10 -12
  40. {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/METADATA +1 -1
  41. {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/RECORD +44 -44
  42. {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/LICENSE +0 -0
  43. {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/WHEEL +0 -0
  44. {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/top_level.txt +0 -0
@@ -18,19 +18,19 @@ import unittest
18
18
 
19
19
  import numpy as np
20
20
 
21
- import brainstate as bst
21
+ import brainstate
22
22
 
23
23
 
24
24
  class TestDropout(unittest.TestCase):
25
25
 
26
26
  def test_dropout(self):
27
27
  # Create a Dropout layer with a dropout rate of 0.5
28
- dropout_layer = bst.nn.Dropout(0.5)
28
+ dropout_layer = brainstate.nn.Dropout(0.5)
29
29
 
30
30
  # Input data
31
31
  input_data = np.arange(20)
32
32
 
33
- with bst.environ.context(fit=True):
33
+ with brainstate.environ.context(fit=True):
34
34
  # Apply dropout
35
35
  output_data = dropout_layer(input_data)
36
36
 
@@ -47,10 +47,10 @@ class TestDropout(unittest.TestCase):
47
47
  np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements)
48
48
 
49
49
  def test_DropoutFixed(self):
50
- dropout_layer = bst.nn.DropoutFixed(in_size=(2, 3), prob=0.5)
50
+ dropout_layer = brainstate.nn.DropoutFixed(in_size=(2, 3), prob=0.5)
51
51
  dropout_layer.init_state(batch_size=2)
52
52
  input_data = np.random.randn(2, 2, 3)
53
- with bst.environ.context(fit=True):
53
+ with brainstate.environ.context(fit=True):
54
54
  output_data = dropout_layer.update(input_data)
55
55
  self.assertEqual(input_data.shape, output_data.shape)
56
56
  self.assertTrue(np.any(output_data == 0))
@@ -72,9 +72,9 @@ class TestDropout(unittest.TestCase):
72
72
  # np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements, decimal=4)
73
73
 
74
74
  def test_Dropout2d(self):
75
- dropout_layer = bst.nn.Dropout2d(prob=0.5)
75
+ dropout_layer = brainstate.nn.Dropout2d(prob=0.5)
76
76
  input_data = np.random.randn(2, 3, 4, 5)
77
- with bst.environ.context(fit=True):
77
+ with brainstate.environ.context(fit=True):
78
78
  output_data = dropout_layer(input_data)
79
79
  self.assertEqual(input_data.shape, output_data.shape)
80
80
  self.assertTrue(np.any(output_data == 0))
@@ -84,9 +84,9 @@ class TestDropout(unittest.TestCase):
84
84
  np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements, decimal=4)
85
85
 
86
86
  def test_Dropout3d(self):
87
- dropout_layer = bst.nn.Dropout3d(prob=0.5)
87
+ dropout_layer = brainstate.nn.Dropout3d(prob=0.5)
88
88
  input_data = np.random.randn(2, 3, 4, 5, 6)
89
- with bst.environ.context(fit=True):
89
+ with brainstate.environ.context(fit=True):
90
90
  output_data = dropout_layer(input_data)
91
91
  self.assertEqual(input_data.shape, output_data.shape)
92
92
  self.assertTrue(np.any(output_data == 0))
@@ -13,157 +13,155 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  from absl.testing import absltest
19
17
  from absl.testing import parameterized
20
18
 
21
- import brainstate as bst
19
+ import brainstate
22
20
 
23
21
 
24
22
  class Test_Activation(parameterized.TestCase):
25
23
 
26
24
  def test_Threshold(self):
27
- threshold_layer = bst.nn.Threshold(5, 20)
28
- input = bst.random.randn(2)
25
+ threshold_layer = brainstate.nn.Threshold(5, 20)
26
+ input = brainstate.random.randn(2)
29
27
  output = threshold_layer(input)
30
28
 
31
29
  def test_ReLU(self):
32
- ReLU_layer = bst.nn.ReLU()
33
- input = bst.random.randn(2)
30
+ ReLU_layer = brainstate.nn.ReLU()
31
+ input = brainstate.random.randn(2)
34
32
  output = ReLU_layer(input)
35
33
 
36
34
  def test_RReLU(self):
37
- RReLU_layer = bst.nn.RReLU(lower=0, upper=1)
38
- input = bst.random.randn(2)
35
+ RReLU_layer = brainstate.nn.RReLU(lower=0, upper=1)
36
+ input = brainstate.random.randn(2)
39
37
  output = RReLU_layer(input)
40
38
 
41
39
  def test_Hardtanh(self):
42
- Hardtanh_layer = bst.nn.Hardtanh(min_val=0, max_val=1, )
43
- input = bst.random.randn(2)
40
+ Hardtanh_layer = brainstate.nn.Hardtanh(min_val=0, max_val=1, )
41
+ input = brainstate.random.randn(2)
44
42
  output = Hardtanh_layer(input)
45
43
 
46
44
  def test_ReLU6(self):
47
- ReLU6_layer = bst.nn.ReLU6()
48
- input = bst.random.randn(2)
45
+ ReLU6_layer = brainstate.nn.ReLU6()
46
+ input = brainstate.random.randn(2)
49
47
  output = ReLU6_layer(input)
50
48
 
51
49
  def test_Sigmoid(self):
52
- Sigmoid_layer = bst.nn.Sigmoid()
53
- input = bst.random.randn(2)
50
+ Sigmoid_layer = brainstate.nn.Sigmoid()
51
+ input = brainstate.random.randn(2)
54
52
  output = Sigmoid_layer(input)
55
53
 
56
54
  def test_Hardsigmoid(self):
57
- Hardsigmoid_layer = bst.nn.Hardsigmoid()
58
- input = bst.random.randn(2)
55
+ Hardsigmoid_layer = brainstate.nn.Hardsigmoid()
56
+ input = brainstate.random.randn(2)
59
57
  output = Hardsigmoid_layer(input)
60
58
 
61
59
  def test_Tanh(self):
62
- Tanh_layer = bst.nn.Tanh()
63
- input = bst.random.randn(2)
60
+ Tanh_layer = brainstate.nn.Tanh()
61
+ input = brainstate.random.randn(2)
64
62
  output = Tanh_layer(input)
65
63
 
66
64
  def test_SiLU(self):
67
- SiLU_layer = bst.nn.SiLU()
68
- input = bst.random.randn(2)
65
+ SiLU_layer = brainstate.nn.SiLU()
66
+ input = brainstate.random.randn(2)
69
67
  output = SiLU_layer(input)
70
68
 
71
69
  def test_Mish(self):
72
- Mish_layer = bst.nn.Mish()
73
- input = bst.random.randn(2)
70
+ Mish_layer = brainstate.nn.Mish()
71
+ input = brainstate.random.randn(2)
74
72
  output = Mish_layer(input)
75
73
 
76
74
  def test_Hardswish(self):
77
- Hardswish_layer = bst.nn.Hardswish()
78
- input = bst.random.randn(2)
75
+ Hardswish_layer = brainstate.nn.Hardswish()
76
+ input = brainstate.random.randn(2)
79
77
  output = Hardswish_layer(input)
80
78
 
81
79
  def test_ELU(self):
82
- ELU_layer = bst.nn.ELU(alpha=0.5, )
83
- input = bst.random.randn(2)
80
+ ELU_layer = brainstate.nn.ELU(alpha=0.5, )
81
+ input = brainstate.random.randn(2)
84
82
  output = ELU_layer(input)
85
83
 
86
84
  def test_CELU(self):
87
- CELU_layer = bst.nn.CELU(alpha=0.5, )
88
- input = bst.random.randn(2)
85
+ CELU_layer = brainstate.nn.CELU(alpha=0.5, )
86
+ input = brainstate.random.randn(2)
89
87
  output = CELU_layer(input)
90
88
 
91
89
  def test_SELU(self):
92
- SELU_layer = bst.nn.SELU()
93
- input = bst.random.randn(2)
90
+ SELU_layer = brainstate.nn.SELU()
91
+ input = brainstate.random.randn(2)
94
92
  output = SELU_layer(input)
95
93
 
96
94
  def test_GLU(self):
97
- GLU_layer = bst.nn.GLU()
98
- input = bst.random.randn(4, 2)
95
+ GLU_layer = brainstate.nn.GLU()
96
+ input = brainstate.random.randn(4, 2)
99
97
  output = GLU_layer(input)
100
98
 
101
99
  @parameterized.product(
102
100
  approximate=['tanh', 'none']
103
101
  )
104
102
  def test_GELU(self, approximate):
105
- GELU_layer = bst.nn.GELU()
106
- input = bst.random.randn(2)
103
+ GELU_layer = brainstate.nn.GELU()
104
+ input = brainstate.random.randn(2)
107
105
  output = GELU_layer(input)
108
106
 
109
107
  def test_Hardshrink(self):
110
- Hardshrink_layer = bst.nn.Hardshrink(lambd=1)
111
- input = bst.random.randn(2)
108
+ Hardshrink_layer = brainstate.nn.Hardshrink(lambd=1)
109
+ input = brainstate.random.randn(2)
112
110
  output = Hardshrink_layer(input)
113
111
 
114
112
  def test_LeakyReLU(self):
115
- LeakyReLU_layer = bst.nn.LeakyReLU()
116
- input = bst.random.randn(2)
113
+ LeakyReLU_layer = brainstate.nn.LeakyReLU()
114
+ input = brainstate.random.randn(2)
117
115
  output = LeakyReLU_layer(input)
118
116
 
119
117
  def test_LogSigmoid(self):
120
- LogSigmoid_layer = bst.nn.LogSigmoid()
121
- input = bst.random.randn(2)
118
+ LogSigmoid_layer = brainstate.nn.LogSigmoid()
119
+ input = brainstate.random.randn(2)
122
120
  output = LogSigmoid_layer(input)
123
121
 
124
122
  def test_Softplus(self):
125
- Softplus_layer = bst.nn.Softplus()
126
- input = bst.random.randn(2)
123
+ Softplus_layer = brainstate.nn.Softplus()
124
+ input = brainstate.random.randn(2)
127
125
  output = Softplus_layer(input)
128
126
 
129
127
  def test_Softshrink(self):
130
- Softshrink_layer = bst.nn.Softshrink(lambd=1)
131
- input = bst.random.randn(2)
128
+ Softshrink_layer = brainstate.nn.Softshrink(lambd=1)
129
+ input = brainstate.random.randn(2)
132
130
  output = Softshrink_layer(input)
133
131
 
134
132
  def test_PReLU(self):
135
- PReLU_layer = bst.nn.PReLU(num_parameters=2, init=0.5)
136
- input = bst.random.randn(2)
133
+ PReLU_layer = brainstate.nn.PReLU(num_parameters=2, init=0.5)
134
+ input = brainstate.random.randn(2)
137
135
  output = PReLU_layer(input)
138
136
 
139
137
  def test_Softsign(self):
140
- Softsign_layer = bst.nn.Softsign()
141
- input = bst.random.randn(2)
138
+ Softsign_layer = brainstate.nn.Softsign()
139
+ input = brainstate.random.randn(2)
142
140
  output = Softsign_layer(input)
143
141
 
144
142
  def test_Tanhshrink(self):
145
- Tanhshrink_layer = bst.nn.Tanhshrink()
146
- input = bst.random.randn(2)
143
+ Tanhshrink_layer = brainstate.nn.Tanhshrink()
144
+ input = brainstate.random.randn(2)
147
145
  output = Tanhshrink_layer(input)
148
146
 
149
147
  def test_Softmin(self):
150
- Softmin_layer = bst.nn.Softmin(dim=2)
151
- input = bst.random.randn(2, 3, 4)
148
+ Softmin_layer = brainstate.nn.Softmin(dim=2)
149
+ input = brainstate.random.randn(2, 3, 4)
152
150
  output = Softmin_layer(input)
153
151
 
154
152
  def test_Softmax(self):
155
- Softmax_layer = bst.nn.Softmax(dim=2)
156
- input = bst.random.randn(2, 3, 4)
153
+ Softmax_layer = brainstate.nn.Softmax(dim=2)
154
+ input = brainstate.random.randn(2, 3, 4)
157
155
  output = Softmax_layer(input)
158
156
 
159
157
  def test_Softmax2d(self):
160
- Softmax2d_layer = bst.nn.Softmax2d()
161
- input = bst.random.randn(2, 3, 12, 13)
158
+ Softmax2d_layer = brainstate.nn.Softmax2d()
159
+ input = brainstate.random.randn(2, 3, 12, 13)
162
160
  output = Softmax2d_layer(input)
163
161
 
164
162
  def test_LogSoftmax(self):
165
- LogSoftmax_layer = bst.nn.LogSoftmax(dim=2)
166
- input = bst.random.randn(2, 3, 4)
163
+ LogSoftmax_layer = brainstate.nn.LogSoftmax(dim=2)
164
+ input = brainstate.random.randn(2, 3, 4)
167
165
  output = LogSoftmax_layer(input)
168
166
 
169
167
 
@@ -13,7 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
16
 
18
17
  import jax.numpy
19
18
  import jax.numpy as jnp
@@ -13,7 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
16
 
18
17
  import jax
19
18
  import jax.numpy as jnp
@@ -13,13 +13,12 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
16
 
18
17
  import unittest
19
18
 
20
19
  import brainunit as u
21
20
 
22
- import brainstate as bst
21
+ import brainstate
23
22
 
24
23
 
25
24
  class TestExpEuler(unittest.TestCase):
@@ -27,10 +26,10 @@ class TestExpEuler(unittest.TestCase):
27
26
  def fun(x, tau):
28
27
  return -x / tau
29
28
 
30
- with bst.environ.context(dt=0.1):
29
+ with brainstate.environ.context(dt=0.1):
31
30
  with self.assertRaises(AssertionError):
32
- r = bst.nn.exp_euler_step(fun, 1.0 * u.mV, 1. * u.ms)
31
+ r = brainstate.nn.exp_euler_step(fun, 1.0 * u.mV, 1. * u.ms)
33
32
 
34
- with bst.environ.context(dt=1. * u.ms):
35
- r = bst.nn.exp_euler_step(fun, 1.0 * u.mV, 1. * u.ms)
33
+ with brainstate.environ.context(dt=1. * u.ms):
34
+ r = brainstate.nn.exp_euler_step(fun, 1.0 * u.mV, 1. * u.ms)
36
35
  print(r)
@@ -1,13 +1,11 @@
1
1
  # -*- coding: utf-8 -*-
2
2
 
3
- from __future__ import annotations
4
-
5
3
  import jax.numpy as jnp
6
4
  import pytest
7
5
  from absl.testing import absltest
8
6
  from absl.testing import parameterized
9
7
 
10
- import brainstate as bst
8
+ import brainstate
11
9
 
12
10
 
13
11
  class TestConv(parameterized.TestCase):
@@ -19,8 +17,8 @@ class TestConv(parameterized.TestCase):
19
17
  img = img.at[0, x:x + 10, y:y + 10, k].set(1.0)
20
18
  img = img.at[1, x:x + 20, y:y + 20, k].set(3.0)
21
19
 
22
- net = bst.nn.Conv2d((200, 198, 4), out_channels=32, kernel_size=(3, 3),
23
- stride=(2, 1), padding='VALID', groups=4)
20
+ net = brainstate.nn.Conv2d((200, 198, 4), out_channels=32, kernel_size=(3, 3),
21
+ stride=(2, 1), padding='VALID', groups=4)
24
22
  out = net(img)
25
23
  print("out shape: ", out.shape)
26
24
  self.assertEqual(out.shape, (2, 99, 196, 32))
@@ -30,7 +28,7 @@ class TestConv(parameterized.TestCase):
30
28
  # plt.show()
31
29
 
32
30
  def test_conv1D(self):
33
- model = bst.nn.Conv1d((5, 3), out_channels=32, kernel_size=(3,))
31
+ model = brainstate.nn.Conv1d((5, 3), out_channels=32, kernel_size=(3,))
34
32
  input = jnp.ones((2, 5, 3))
35
33
  out = model(input)
36
34
  print("out shape: ", out.shape)
@@ -41,7 +39,7 @@ class TestConv(parameterized.TestCase):
41
39
  # plt.show()
42
40
 
43
41
  def test_conv2D(self):
44
- model = bst.nn.Conv2d((5, 5, 3), out_channels=32, kernel_size=(3, 3))
42
+ model = brainstate.nn.Conv2d((5, 5, 3), out_channels=32, kernel_size=(3, 3))
45
43
  input = jnp.ones((2, 5, 5, 3))
46
44
 
47
45
  out = model(input)
@@ -49,7 +47,7 @@ class TestConv(parameterized.TestCase):
49
47
  self.assertEqual(out.shape, (2, 5, 5, 32))
50
48
 
51
49
  def test_conv3D(self):
52
- model = bst.nn.Conv3d((5, 5, 5, 3), out_channels=32, kernel_size=(3, 3, 3))
50
+ model = brainstate.nn.Conv3d((5, 5, 5, 3), out_channels=32, kernel_size=(3, 3, 3))
53
51
  input = jnp.ones((2, 5, 5, 5, 3))
54
52
  out = model(input)
55
53
  print("out shape: ", out.shape)
@@ -62,13 +60,13 @@ class TestConvTranspose1d(parameterized.TestCase):
62
60
 
63
61
  x = jnp.ones((1, 8, 3))
64
62
  for use_bias in [True, False]:
65
- conv_transpose_module = bst.nn.ConvTranspose1d(
63
+ conv_transpose_module = brainstate.nn.ConvTranspose1d(
66
64
  in_channels=3,
67
65
  out_channels=4,
68
66
  kernel_size=(3,),
69
67
  padding='VALID',
70
- w_initializer=bst.init.Constant(1.),
71
- b_initializer=bst.init.Constant(1.) if use_bias else None,
68
+ w_initializer=brainstate.init.Constant(1.),
69
+ b_initializer=brainstate.init.Constant(1.) if use_bias else None,
72
70
  )
73
71
  self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4))
74
72
  y = conv_transpose_module(x)
@@ -91,14 +89,14 @@ class TestConvTranspose1d(parameterized.TestCase):
91
89
 
92
90
  x = jnp.ones((1, 8, 3))
93
91
  m = jnp.tril(jnp.ones((3, 3, 4)))
94
- conv_transpose_module = bst.nn.ConvTranspose1d(
92
+ conv_transpose_module = brainstate.nn.ConvTranspose1d(
95
93
  in_channels=3,
96
94
  out_channels=4,
97
95
  kernel_size=(3,),
98
96
  padding='VALID',
99
97
  mask=m,
100
- w_initializer=bst.init.Constant(),
101
- b_initializer=bst.init.Constant(),
98
+ w_initializer=brainstate.init.Constant(),
99
+ b_initializer=brainstate.init.Constant(),
102
100
  )
103
101
  self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4))
104
102
  y = conv_transpose_module(x)
@@ -119,14 +117,14 @@ class TestConvTranspose1d(parameterized.TestCase):
119
117
 
120
118
  data = jnp.ones([1, 3, 1])
121
119
  for use_bias in [True, False]:
122
- net = bst.nn.ConvTranspose1d(
120
+ net = brainstate.nn.ConvTranspose1d(
123
121
  in_channels=1,
124
122
  out_channels=1,
125
123
  kernel_size=3,
126
124
  stride=1,
127
125
  padding="SAME",
128
- w_initializer=bst.init.Constant(),
129
- b_initializer=bst.init.Constant() if use_bias else None,
126
+ w_initializer=brainstate.init.Constant(),
127
+ b_initializer=brainstate.init.Constant() if use_bias else None,
130
128
  )
131
129
  out = net(data)
132
130
  self.assertEqual(out.shape, (1, 3, 1))
@@ -143,13 +141,13 @@ class TestConvTranspose2d(parameterized.TestCase):
143
141
 
144
142
  x = jnp.ones((1, 8, 8, 3))
145
143
  for use_bias in [True, False]:
146
- conv_transpose_module = bst.nn.ConvTranspose2d(
144
+ conv_transpose_module = brainstate.nn.ConvTranspose2d(
147
145
  in_channels=3,
148
146
  out_channels=4,
149
147
  kernel_size=(3, 3),
150
148
  padding='VALID',
151
- w_initializer=bst.init.Constant(),
152
- b_initializer=bst.init.Constant() if use_bias else None,
149
+ w_initializer=brainstate.init.Constant(),
150
+ b_initializer=brainstate.init.Constant() if use_bias else None,
153
151
  )
154
152
  self.assertEqual(conv_transpose_module.w.shape, (3, 3, 3, 4))
155
153
  y = conv_transpose_module(x)
@@ -159,13 +157,13 @@ class TestConvTranspose2d(parameterized.TestCase):
159
157
 
160
158
  x = jnp.ones((1, 8, 8, 3))
161
159
  m = jnp.tril(jnp.ones((3, 3, 3, 4)))
162
- conv_transpose_module = bst.nn.ConvTranspose2d(
160
+ conv_transpose_module = brainstate.nn.ConvTranspose2d(
163
161
  in_channels=3,
164
162
  out_channels=4,
165
163
  kernel_size=(3, 3),
166
164
  padding='VALID',
167
165
  mask=m,
168
- w_initializer=bst.init.Constant(),
166
+ w_initializer=brainstate.init.Constant(),
169
167
  )
170
168
  y = conv_transpose_module(x)
171
169
  print(y.shape)
@@ -174,14 +172,14 @@ class TestConvTranspose2d(parameterized.TestCase):
174
172
 
175
173
  x = jnp.ones((1, 8, 8, 3))
176
174
  for use_bias in [True, False]:
177
- conv_transpose_module = bst.nn.ConvTranspose2d(
175
+ conv_transpose_module = brainstate.nn.ConvTranspose2d(
178
176
  in_channels=3,
179
177
  out_channels=4,
180
178
  kernel_size=(3, 3),
181
179
  stride=1,
182
180
  padding='SAME',
183
- w_initializer=bst.init.Constant(),
184
- b_initializer=bst.init.Constant() if use_bias else None,
181
+ w_initializer=brainstate.init.Constant(),
182
+ b_initializer=brainstate.init.Constant() if use_bias else None,
185
183
  )
186
184
  y = conv_transpose_module(x)
187
185
  print(y.shape)
@@ -193,13 +191,13 @@ class TestConvTranspose3d(parameterized.TestCase):
193
191
 
194
192
  x = jnp.ones((1, 8, 8, 8, 3))
195
193
  for use_bias in [True, False]:
196
- conv_transpose_module = bst.nn.ConvTranspose3d(
194
+ conv_transpose_module = brainstate.nn.ConvTranspose3d(
197
195
  in_channels=3,
198
196
  out_channels=4,
199
197
  kernel_size=(3, 3, 3),
200
198
  padding='VALID',
201
- w_initializer=bst.init.Constant(),
202
- b_initializer=bst.init.Constant() if use_bias else None,
199
+ w_initializer=brainstate.init.Constant(),
200
+ b_initializer=brainstate.init.Constant() if use_bias else None,
203
201
  )
204
202
  y = conv_transpose_module(x)
205
203
  print(y.shape)
@@ -208,13 +206,13 @@ class TestConvTranspose3d(parameterized.TestCase):
208
206
 
209
207
  x = jnp.ones((1, 8, 8, 8, 3))
210
208
  m = jnp.tril(jnp.ones((3, 3, 3, 3, 4)))
211
- conv_transpose_module = bst.nn.ConvTranspose3d(
209
+ conv_transpose_module = brainstate.nn.ConvTranspose3d(
212
210
  in_channels=3,
213
211
  out_channels=4,
214
212
  kernel_size=(3, 3, 3),
215
213
  padding='VALID',
216
214
  mask=m,
217
- w_initializer=bst.init.Constant(),
215
+ w_initializer=brainstate.init.Constant(),
218
216
  )
219
217
  y = conv_transpose_module(x)
220
218
  print(y.shape)
@@ -223,14 +221,14 @@ class TestConvTranspose3d(parameterized.TestCase):
223
221
 
224
222
  x = jnp.ones((1, 8, 8, 8, 3))
225
223
  for use_bias in [True, False]:
226
- conv_transpose_module = bst.nn.ConvTranspose3d(
224
+ conv_transpose_module = brainstate.nn.ConvTranspose3d(
227
225
  in_channels=3,
228
226
  out_channels=4,
229
227
  kernel_size=(3, 3, 3),
230
228
  stride=1,
231
229
  padding='SAME',
232
- w_initializer=bst.init.Constant(),
233
- b_initializer=bst.init.Constant() if use_bias else None,
230
+ w_initializer=brainstate.init.Constant(),
231
+ b_initializer=brainstate.init.Constant() if use_bias else None,
234
232
  )
235
233
  y = conv_transpose_module(x)
236
234
  print(y.shape)
@@ -14,14 +14,12 @@
14
14
  # ==============================================================================
15
15
 
16
16
 
17
- from __future__ import annotations
18
-
19
17
  import unittest
20
18
 
21
19
  import brainunit as u
22
20
  from absl.testing import parameterized
23
21
 
24
- import brainstate as bst
22
+ import brainstate
25
23
 
26
24
 
27
25
  class TestDense(parameterized.TestCase):
@@ -32,19 +30,19 @@ class TestDense(parameterized.TestCase):
32
30
  num_out=[20, ]
33
31
  )
34
32
  def test_Dense1(self, size, num_out):
35
- f = bst.nn.Linear(10, num_out)
36
- x = bst.random.random(size)
33
+ f = brainstate.nn.Linear(10, num_out)
34
+ x = brainstate.random.random(size)
37
35
  y = f(x)
38
36
  self.assertTrue(y.shape == size[:-1] + (num_out,))
39
37
 
40
38
 
41
39
  class TestSparseMatrix(unittest.TestCase):
42
40
  def test_csr(self):
43
- data = bst.random.rand(10, 20)
41
+ data = brainstate.random.rand(10, 20)
44
42
  data = data * (data > 0.9)
45
- f = bst.nn.SparseLinear(u.sparse.CSR.fromdense(data))
43
+ f = brainstate.nn.SparseLinear(u.sparse.CSR.fromdense(data))
46
44
 
47
- x = bst.random.rand(10)
45
+ x = brainstate.random.rand(10)
48
46
  y = f(x)
49
47
  self.assertTrue(
50
48
  u.math.allclose(
@@ -53,7 +51,7 @@ class TestSparseMatrix(unittest.TestCase):
53
51
  )
54
52
  )
55
53
 
56
- x = bst.random.rand(5, 10)
54
+ x = brainstate.random.rand(5, 10)
57
55
  y = f(x)
58
56
  self.assertTrue(
59
57
  u.math.allclose(
@@ -63,11 +61,11 @@ class TestSparseMatrix(unittest.TestCase):
63
61
  )
64
62
 
65
63
  def test_csc(self):
66
- data = bst.random.rand(10, 20)
64
+ data = brainstate.random.rand(10, 20)
67
65
  data = data * (data > 0.9)
68
- f = bst.nn.SparseLinear(u.sparse.CSC.fromdense(data))
66
+ f = brainstate.nn.SparseLinear(u.sparse.CSC.fromdense(data))
69
67
 
70
- x = bst.random.rand(10)
68
+ x = brainstate.random.rand(10)
71
69
  y = f(x)
72
70
  self.assertTrue(
73
71
  u.math.allclose(
@@ -76,7 +74,7 @@ class TestSparseMatrix(unittest.TestCase):
76
74
  )
77
75
  )
78
76
 
79
- x = bst.random.rand(5, 10)
77
+ x = brainstate.random.rand(5, 10)
80
78
  y = f(x)
81
79
  self.assertTrue(
82
80
  u.math.allclose(
@@ -86,11 +84,11 @@ class TestSparseMatrix(unittest.TestCase):
86
84
  )
87
85
 
88
86
  def test_coo(self):
89
- data = bst.random.rand(10, 20)
87
+ data = brainstate.random.rand(10, 20)
90
88
  data = data * (data > 0.9)
91
- f = bst.nn.SparseLinear(u.sparse.COO.fromdense(data))
89
+ f = brainstate.nn.SparseLinear(u.sparse.COO.fromdense(data))
92
90
 
93
- x = bst.random.rand(10)
91
+ x = brainstate.random.rand(10)
94
92
  y = f(x)
95
93
  self.assertTrue(
96
94
  u.math.allclose(
@@ -99,7 +97,7 @@ class TestSparseMatrix(unittest.TestCase):
99
97
  )
100
98
  )
101
99
 
102
- x = bst.random.rand(5, 10)
100
+ x = brainstate.random.rand(5, 10)
103
101
  y = f(x)
104
102
  self.assertTrue(
105
103
  u.math.allclose(
@@ -13,12 +13,10 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  from absl.testing import absltest
19
17
  from absl.testing import parameterized
20
18
 
21
- import brainstate as bst
19
+ import brainstate
22
20
 
23
21
 
24
22
  class Test_Normalization(parameterized.TestCase):
@@ -26,27 +24,27 @@ class Test_Normalization(parameterized.TestCase):
26
24
  fit=[True, False],
27
25
  )
28
26
  def test_BatchNorm1d(self, fit):
29
- net = bst.nn.BatchNorm1d((3, 10))
30
- bst.environ.set(fit=fit)
31
- input = bst.random.randn(1, 3, 10)
27
+ net = brainstate.nn.BatchNorm1d((3, 10))
28
+ brainstate.environ.set(fit=fit)
29
+ input = brainstate.random.randn(1, 3, 10)
32
30
  output = net(input)
33
31
 
34
32
  @parameterized.product(
35
33
  fit=[True, False]
36
34
  )
37
35
  def test_BatchNorm2d(self, fit):
38
- net = bst.nn.BatchNorm2d([3, 4, 10])
39
- bst.environ.set(fit=fit)
40
- input = bst.random.randn(1, 3, 4, 10)
36
+ net = brainstate.nn.BatchNorm2d([3, 4, 10])
37
+ brainstate.environ.set(fit=fit)
38
+ input = brainstate.random.randn(1, 3, 4, 10)
41
39
  output = net(input)
42
40
 
43
41
  @parameterized.product(
44
42
  fit=[True, False]
45
43
  )
46
44
  def test_BatchNorm3d(self, fit):
47
- net = bst.nn.BatchNorm3d([3, 4, 5, 10])
48
- bst.environ.set(fit=fit)
49
- input = bst.random.randn(1, 3, 4, 5, 10)
45
+ net = brainstate.nn.BatchNorm3d([3, 4, 5, 10])
46
+ brainstate.environ.set(fit=fit)
47
+ input = brainstate.random.randn(1, 3, 4, 5, 10)
50
48
  output = net(input)
51
49
 
52
50
  # @parameterized.product(