brainstate 0.1.1__py2.py3-none-any.whl → 0.1.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. brainstate/__init__.py +1 -1
  2. brainstate/_compatible_import.py +3 -0
  3. brainstate/_state.py +1 -1
  4. brainstate/augment/_autograd_test.py +132 -133
  5. brainstate/augment/_eval_shape_test.py +7 -9
  6. brainstate/augment/_mapping_test.py +75 -76
  7. brainstate/compile/_ad_checkpoint_test.py +6 -8
  8. brainstate/compile/_conditions_test.py +35 -36
  9. brainstate/compile/_error_if_test.py +10 -13
  10. brainstate/compile/_loop_collect_return_test.py +7 -9
  11. brainstate/compile/_loop_no_collection_test.py +7 -8
  12. brainstate/compile/_make_jaxpr.py +29 -14
  13. brainstate/compile/_make_jaxpr_test.py +20 -20
  14. brainstate/functional/_activations_test.py +61 -61
  15. brainstate/graph/_graph_node_test.py +16 -18
  16. brainstate/graph/_graph_operation_test.py +154 -156
  17. brainstate/init/_random_inits_test.py +20 -21
  18. brainstate/init/_regular_inits_test.py +4 -5
  19. brainstate/nn/_collective_ops_test.py +8 -8
  20. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +18 -19
  21. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +9 -10
  22. brainstate/nn/_dyn_impl/_rate_rnns_test.py +6 -7
  23. brainstate/nn/_dyn_impl/_readout_test.py +9 -10
  24. brainstate/nn/_dynamics/_dynamics_base_test.py +14 -15
  25. brainstate/nn/_dynamics/_synouts_test.py +4 -5
  26. brainstate/nn/_elementwise/_dropout_test.py +9 -9
  27. brainstate/nn/_elementwise/_elementwise_test.py +57 -59
  28. brainstate/nn/_event/_fixedprob_mv_test.py +0 -1
  29. brainstate/nn/_event/_linear_mv_test.py +0 -1
  30. brainstate/nn/_exp_euler_test.py +5 -6
  31. brainstate/nn/_interaction/_conv_test.py +31 -33
  32. brainstate/nn/_interaction/_linear_test.py +15 -17
  33. brainstate/nn/_interaction/_normalizations_test.py +10 -12
  34. brainstate/nn/_interaction/_poolings_test.py +19 -21
  35. brainstate/nn/_module_test.py +34 -37
  36. brainstate/optim/_lr_scheduler_test.py +3 -3
  37. brainstate/optim/_optax_optimizer_test.py +8 -9
  38. brainstate/random/_rand_funs_test.py +183 -184
  39. brainstate/random/_rand_seed_test.py +10 -12
  40. {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/METADATA +1 -1
  41. {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/RECORD +44 -44
  42. {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/LICENSE +0 -0
  43. {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/WHEEL +0 -0
  44. {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/top_level.txt +0 -0
@@ -14,30 +14,29 @@
14
14
  # ==============================================================================
15
15
 
16
16
  # -*- coding: utf-8 -*-
17
- from __future__ import annotations
18
17
 
19
18
  import unittest
20
19
 
21
- import brainstate as bst
20
+ import brainstate
22
21
 
23
22
 
24
23
  class TestNormalInit(unittest.TestCase):
25
24
 
26
25
  def test_normal_init1(self):
27
- init = bst.init.Normal()
26
+ init = brainstate.init.Normal()
28
27
  for size in [(100,), (10, 20), (10, 20, 30)]:
29
28
  weights = init(size)
30
29
  assert weights.shape == size
31
30
 
32
31
  def test_normal_init2(self):
33
- init = bst.init.Normal(scale=0.5)
32
+ init = brainstate.init.Normal(scale=0.5)
34
33
  for size in [(100,), (10, 20)]:
35
34
  weights = init(size)
36
35
  assert weights.shape == size
37
36
 
38
37
  def test_normal_init3(self):
39
- init1 = bst.init.Normal(scale=0.5, seed=10)
40
- init2 = bst.init.Normal(scale=0.5, seed=10)
38
+ init1 = brainstate.init.Normal(scale=0.5, seed=10)
39
+ init2 = brainstate.init.Normal(scale=0.5, seed=10)
41
40
  size = (10,)
42
41
  weights1 = init1(size)
43
42
  weights2 = init2(size)
@@ -47,13 +46,13 @@ class TestNormalInit(unittest.TestCase):
47
46
 
48
47
  class TestUniformInit(unittest.TestCase):
49
48
  def test_uniform_init1(self):
50
- init = bst.init.Normal()
49
+ init = brainstate.init.Normal()
51
50
  for size in [(100,), (10, 20), (10, 20, 30)]:
52
51
  weights = init(size)
53
52
  assert weights.shape == size
54
53
 
55
54
  def test_uniform_init2(self):
56
- init = bst.init.Uniform(min_val=10, max_val=20)
55
+ init = brainstate.init.Uniform(min_val=10, max_val=20)
57
56
  for size in [(100,), (10, 20)]:
58
57
  weights = init(size)
59
58
  assert weights.shape == size
@@ -61,20 +60,20 @@ class TestUniformInit(unittest.TestCase):
61
60
 
62
61
  class TestVarianceScaling(unittest.TestCase):
63
62
  def test_var_scaling1(self):
64
- init = bst.init.VarianceScaling(scale=1., mode='fan_in', distribution='truncated_normal')
63
+ init = brainstate.init.VarianceScaling(scale=1., mode='fan_in', distribution='truncated_normal')
65
64
  for size in [(10, 20), (10, 20, 30)]:
66
65
  weights = init(size)
67
66
  assert weights.shape == size
68
67
 
69
68
  def test_var_scaling2(self):
70
- init = bst.init.VarianceScaling(scale=2, mode='fan_out', distribution='normal')
69
+ init = brainstate.init.VarianceScaling(scale=2, mode='fan_out', distribution='normal')
71
70
  for size in [(10, 20), (10, 20, 30)]:
72
71
  weights = init(size)
73
72
  assert weights.shape == size
74
73
 
75
74
  def test_var_scaling3(self):
76
- init = bst.init.VarianceScaling(scale=2 / 4, mode='fan_avg', in_axis=0, out_axis=1,
77
- distribution='uniform')
75
+ init = brainstate.init.VarianceScaling(scale=2 / 4, mode='fan_avg', in_axis=0, out_axis=1,
76
+ distribution='uniform')
78
77
  for size in [(10, 20), (10, 20, 30)]:
79
78
  weights = init(size)
80
79
  assert weights.shape == size
@@ -82,7 +81,7 @@ class TestVarianceScaling(unittest.TestCase):
82
81
 
83
82
  class TestKaimingUniformUnit(unittest.TestCase):
84
83
  def test_kaiming_uniform_init(self):
85
- init = bst.init.KaimingUniform()
84
+ init = brainstate.init.KaimingUniform()
86
85
  for size in [(10, 20), (10, 20, 30)]:
87
86
  weights = init(size)
88
87
  assert weights.shape == size
@@ -90,7 +89,7 @@ class TestKaimingUniformUnit(unittest.TestCase):
90
89
 
91
90
  class TestKaimingNormalUnit(unittest.TestCase):
92
91
  def test_kaiming_normal_init(self):
93
- init = bst.init.KaimingNormal()
92
+ init = brainstate.init.KaimingNormal()
94
93
  for size in [(10, 20), (10, 20, 30)]:
95
94
  weights = init(size)
96
95
  assert weights.shape == size
@@ -98,7 +97,7 @@ class TestKaimingNormalUnit(unittest.TestCase):
98
97
 
99
98
  class TestXavierUniformUnit(unittest.TestCase):
100
99
  def test_xavier_uniform_init(self):
101
- init = bst.init.XavierUniform()
100
+ init = brainstate.init.XavierUniform()
102
101
  for size in [(10, 20), (10, 20, 30)]:
103
102
  weights = init(size)
104
103
  assert weights.shape == size
@@ -106,7 +105,7 @@ class TestXavierUniformUnit(unittest.TestCase):
106
105
 
107
106
  class TestXavierNormalUnit(unittest.TestCase):
108
107
  def test_xavier_normal_init(self):
109
- init = bst.init.XavierNormal()
108
+ init = brainstate.init.XavierNormal()
110
109
  for size in [(10, 20), (10, 20, 30)]:
111
110
  weights = init(size)
112
111
  assert weights.shape == size
@@ -114,7 +113,7 @@ class TestXavierNormalUnit(unittest.TestCase):
114
113
 
115
114
  class TestLecunUniformUnit(unittest.TestCase):
116
115
  def test_lecun_uniform_init(self):
117
- init = bst.init.LecunUniform()
116
+ init = brainstate.init.LecunUniform()
118
117
  for size in [(10, 20), (10, 20, 30)]:
119
118
  weights = init(size)
120
119
  assert weights.shape == size
@@ -122,7 +121,7 @@ class TestLecunUniformUnit(unittest.TestCase):
122
121
 
123
122
  class TestLecunNormalUnit(unittest.TestCase):
124
123
  def test_lecun_normal_init(self):
125
- init = bst.init.LecunNormal()
124
+ init = brainstate.init.LecunNormal()
126
125
  for size in [(10, 20), (10, 20, 30)]:
127
126
  weights = init(size)
128
127
  assert weights.shape == size
@@ -130,13 +129,13 @@ class TestLecunNormalUnit(unittest.TestCase):
130
129
 
131
130
  class TestOrthogonalUnit(unittest.TestCase):
132
131
  def test_orthogonal_init1(self):
133
- init = bst.init.Orthogonal()
132
+ init = brainstate.init.Orthogonal()
134
133
  for size in [(20, 20), (10, 20, 30)]:
135
134
  weights = init(size)
136
135
  assert weights.shape == size
137
136
 
138
137
  def test_orthogonal_init2(self):
139
- init = bst.init.Orthogonal(scale=2., axis=0)
138
+ init = brainstate.init.Orthogonal(scale=2., axis=0)
140
139
  for size in [(10, 20), (10, 20, 30)]:
141
140
  weights = init(size)
142
141
  assert weights.shape == size
@@ -144,7 +143,7 @@ class TestOrthogonalUnit(unittest.TestCase):
144
143
 
145
144
  class TestDeltaOrthogonalUnit(unittest.TestCase):
146
145
  def test_delta_orthogonal_init1(self):
147
- init = bst.init.DeltaOrthogonal()
146
+ init = brainstate.init.DeltaOrthogonal()
148
147
  for size in [(20, 20, 20), (10, 20, 30, 40), (50, 40, 30, 20, 20)]:
149
148
  weights = init(size)
150
149
  assert weights.shape == size
@@ -14,16 +14,15 @@
14
14
  # ==============================================================================
15
15
 
16
16
  # -*- coding: utf-8 -*-
17
- from __future__ import annotations
18
17
 
19
18
  import unittest
20
19
 
21
- import brainstate as bst
20
+ import brainstate
22
21
 
23
22
 
24
23
  class TestZeroInit(unittest.TestCase):
25
24
  def test_zero_init(self):
26
- init = bst.init.ZeroInit()
25
+ init = brainstate.init.ZeroInit()
27
26
  for size in [(100,), (10, 20), (10, 20, 30)]:
28
27
  weights = init(size)
29
28
  assert weights.shape == size
@@ -33,7 +32,7 @@ class TestOneInit(unittest.TestCase):
33
32
  def test_one_init(self):
34
33
  for size in [(100,), (10, 20), (10, 20, 30)]:
35
34
  for value in [0., 1., -1.]:
36
- init = bst.init.Constant(value=value)
35
+ init = brainstate.init.Constant(value=value)
37
36
  weights = init(size)
38
37
  assert weights.shape == size
39
38
  assert (weights == value).all()
@@ -43,7 +42,7 @@ class TestIdentityInit(unittest.TestCase):
43
42
  def test_identity_init(self):
44
43
  for size in [(100,), (10, 20)]:
45
44
  for value in [0., 1., -1.]:
46
- init = bst.init.Identity(value=value)
45
+ init = brainstate.init.Identity(value=value)
47
46
  weights = init(size)
48
47
  if len(size) == 1:
49
48
  assert weights.shape == (size[0], size[0])
@@ -16,21 +16,21 @@
16
16
  # -*- coding: utf-8 -*-
17
17
 
18
18
 
19
- import brainstate as bst
19
+ import brainstate
20
20
 
21
21
 
22
22
  class Test_vmap_init_all_states:
23
23
 
24
24
  def test_vmap_init_all_states(self):
25
- gru = bst.nn.GRUCell(1, 2)
26
- bst.nn.vmap_init_all_states(gru, axis_size=10)
25
+ gru = brainstate.nn.GRUCell(1, 2)
26
+ brainstate.nn.vmap_init_all_states(gru, axis_size=10)
27
27
  print(gru)
28
28
 
29
29
  def test_vmap_init_all_states_v2(self):
30
- @bst.compile.jit
30
+ @brainstate.compile.jit
31
31
  def init():
32
- gru = bst.nn.GRUCell(1, 2)
33
- bst.nn.vmap_init_all_states(gru, axis_size=10)
32
+ gru = brainstate.nn.GRUCell(1, 2)
33
+ brainstate.nn.vmap_init_all_states(gru, axis_size=10)
34
34
  print(gru)
35
35
 
36
36
  init()
@@ -38,6 +38,6 @@ class Test_vmap_init_all_states:
38
38
 
39
39
  class Test_init_all_states:
40
40
  def test_init_all_states(self):
41
- gru = bst.nn.GRUCell(1, 2)
42
- bst.nn.init_all_states(gru, batch_size=10)
41
+ gru = brainstate.nn.GRUCell(1, 2)
42
+ brainstate.nn.init_all_states(gru, batch_size=10)
43
43
  print(gru)
@@ -15,7 +15,6 @@
15
15
 
16
16
  # -*- coding: utf-8 -*-
17
17
 
18
- from __future__ import annotations
19
18
 
20
19
  import unittest
21
20
 
@@ -23,7 +22,7 @@ import brainunit as u
23
22
  import jax
24
23
  import jax.numpy as jnp
25
24
 
26
- import brainstate as bst
25
+ import brainstate
27
26
  from brainstate.nn import IF, LIF, ALIF
28
27
 
29
28
 
@@ -35,13 +34,13 @@ class TestNeuron(unittest.TestCase):
35
34
 
36
35
  def test_neuron_base_class(self):
37
36
  with self.assertRaises(NotImplementedError):
38
- bst.nn.Neuron(self.in_size).get_spike() # Neuron is an abstract base class
37
+ brainstate.nn.Neuron(self.in_size).get_spike() # Neuron is an abstract base class
39
38
 
40
39
  def generate_input(self):
41
- return bst.random.randn(self.time_steps, self.batch_size, self.in_size) * u.mA
40
+ return brainstate.random.randn(self.time_steps, self.batch_size, self.in_size) * u.mA
42
41
 
43
42
  def test_if_neuron(self):
44
- with bst.environ.context(dt=0.1 * u.ms):
43
+ with brainstate.environ.context(dt=0.1 * u.ms):
45
44
  neuron = IF(self.in_size)
46
45
  inputs = self.generate_input()
47
46
 
@@ -62,7 +61,7 @@ class TestNeuron(unittest.TestCase):
62
61
  self.assertTrue(jnp.all((spikes >= 0) & (spikes <= 1)))
63
62
 
64
63
  def test_lif_neuron(self):
65
- with bst.environ.context(dt=0.1 * u.ms):
64
+ with brainstate.environ.context(dt=0.1 * u.ms):
66
65
  tau = 20.0 * u.ms
67
66
  neuron = LIF(self.in_size, tau=tau)
68
67
  inputs = self.generate_input()
@@ -74,7 +73,7 @@ class TestNeuron(unittest.TestCase):
74
73
 
75
74
  # Test forward pass
76
75
  state = neuron.init_state(self.batch_size)
77
- call = bst.compile.jit(neuron)
76
+ call = brainstate.compile.jit(neuron)
78
77
 
79
78
  for t in range(self.time_steps):
80
79
  out = call(inputs[t])
@@ -94,8 +93,8 @@ class TestNeuron(unittest.TestCase):
94
93
 
95
94
  # Test forward pass
96
95
  neuron.init_state(self.batch_size)
97
- call = bst.compile.jit(neuron)
98
- with bst.environ.context(dt=0.1 * u.ms):
96
+ call = brainstate.compile.jit(neuron)
97
+ with brainstate.environ.context(dt=0.1 * u.ms):
99
98
  for t in range(self.time_steps):
100
99
  out = call(inputs[t])
101
100
  self.assertEqual(out.shape, (self.batch_size, self.in_size))
@@ -113,8 +112,8 @@ class TestNeuron(unittest.TestCase):
113
112
  neuron = NeuronClass(self.in_size, spk_reset='soft')
114
113
  inputs = self.generate_input()
115
114
  state = neuron.init_state(self.batch_size)
116
- call = bst.compile.jit(neuron)
117
- with bst.environ.context(dt=0.1 * u.ms):
115
+ call = brainstate.compile.jit(neuron)
116
+ with brainstate.environ.context(dt=0.1 * u.ms):
118
117
  for t in range(self.time_steps):
119
118
  out = call(inputs[t])
120
119
  self.assertTrue(jnp.all(neuron.V.value <= neuron.V_th))
@@ -124,8 +123,8 @@ class TestNeuron(unittest.TestCase):
124
123
  neuron = NeuronClass(self.in_size, spk_reset='hard')
125
124
  inputs = self.generate_input()
126
125
  state = neuron.init_state(self.batch_size)
127
- call = bst.compile.jit(neuron)
128
- with bst.environ.context(dt=0.1 * u.ms):
126
+ call = brainstate.compile.jit(neuron)
127
+ with brainstate.environ.context(dt=0.1 * u.ms):
129
128
  for t in range(self.time_steps):
130
129
  out = call(inputs[t])
131
130
  self.assertTrue(jnp.all((neuron.V.value < neuron.V_th) | (neuron.V.value == 0. * u.mV)))
@@ -135,8 +134,8 @@ class TestNeuron(unittest.TestCase):
135
134
  neuron = NeuronClass(self.in_size)
136
135
  inputs = self.generate_input()
137
136
  state = neuron.init_state(self.batch_size)
138
- call = bst.compile.jit(neuron)
139
- with bst.environ.context(dt=0.1 * u.ms):
137
+ call = brainstate.compile.jit(neuron)
138
+ with brainstate.environ.context(dt=0.1 * u.ms):
140
139
  for t in range(self.time_steps):
141
140
  out = call(inputs[t])
142
141
  self.assertFalse(jax.tree_util.tree_leaves(out)[0].aval.weak_type)
@@ -148,15 +147,15 @@ class TestNeuron(unittest.TestCase):
148
147
  self.assertEqual(neuron.in_size, in_size)
149
148
  self.assertEqual(neuron.out_size, in_size)
150
149
 
151
- inputs = bst.random.randn(self.time_steps, self.batch_size, *in_size) * u.mA
150
+ inputs = brainstate.random.randn(self.time_steps, self.batch_size, *in_size) * u.mA
152
151
  state = neuron.init_state(self.batch_size)
153
- call = bst.compile.jit(neuron)
154
- with bst.environ.context(dt=0.1 * u.ms):
152
+ call = brainstate.compile.jit(neuron)
153
+ with brainstate.environ.context(dt=0.1 * u.ms):
155
154
  for t in range(self.time_steps):
156
155
  out = call(inputs[t])
157
156
  self.assertEqual(out.shape, (self.batch_size, *in_size))
158
157
 
159
158
 
160
159
  if __name__ == '__main__':
161
- with bst.environ.context(dt=0.1):
160
+ with brainstate.environ.context(dt=0.1):
162
161
  unittest.main()
@@ -13,7 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
16
 
18
17
  import unittest
19
18
 
@@ -21,7 +20,7 @@ import brainunit as u
21
20
  import jax.numpy as jnp
22
21
  import pytest
23
22
 
24
- import brainstate as bst
23
+ import brainstate
25
24
  from brainstate.nn import Expon, STP, STD
26
25
 
27
26
 
@@ -32,7 +31,7 @@ class TestSynapse(unittest.TestCase):
32
31
  self.time_steps = 100
33
32
 
34
33
  def generate_input(self):
35
- return bst.random.randn(self.time_steps, self.batch_size, self.in_size) * u.mS
34
+ return brainstate.random.randn(self.time_steps, self.batch_size, self.in_size) * u.mS
36
35
 
37
36
  def test_expon_synapse(self):
38
37
  tau = 20.0 * u.ms
@@ -46,8 +45,8 @@ class TestSynapse(unittest.TestCase):
46
45
 
47
46
  # Test forward pass
48
47
  state = synapse.init_state(self.batch_size)
49
- call = bst.compile.jit(synapse)
50
- with bst.environ.context(dt=0.1 * u.ms):
48
+ call = brainstate.compile.jit(synapse)
49
+ with brainstate.environ.context(dt=0.1 * u.ms):
51
50
  for t in range(self.time_steps):
52
51
  out = call(inputs[t])
53
52
  self.assertEqual(out.shape, (self.batch_size, self.in_size))
@@ -75,7 +74,7 @@ class TestSynapse(unittest.TestCase):
75
74
 
76
75
  # Test forward pass
77
76
  state = synapse.init_state(self.batch_size)
78
- call = bst.compile.jit(synapse)
77
+ call = brainstate.compile.jit(synapse)
79
78
  for t in range(self.time_steps):
80
79
  out = call(inputs[t])
81
80
  self.assertEqual(out.shape, (self.batch_size, self.in_size))
@@ -118,15 +117,15 @@ class TestSynapse(unittest.TestCase):
118
117
  self.assertEqual(synapse.in_size, in_size)
119
118
  self.assertEqual(synapse.out_size, in_size)
120
119
 
121
- inputs = bst.random.randn(self.time_steps, self.batch_size, *in_size) * u.mS
120
+ inputs = brainstate.random.randn(self.time_steps, self.batch_size, *in_size) * u.mS
122
121
  state = synapse.init_state(self.batch_size)
123
- call = bst.compile.jit(synapse)
124
- with bst.environ.context(dt=0.1 * u.ms):
122
+ call = brainstate.compile.jit(synapse)
123
+ with brainstate.environ.context(dt=0.1 * u.ms):
125
124
  for t in range(self.time_steps):
126
125
  out = call(inputs[t])
127
126
  self.assertEqual(out.shape, (self.batch_size, *in_size))
128
127
 
129
128
 
130
129
  if __name__ == '__main__':
131
- with bst.environ.context(dt=0.1):
130
+ with brainstate.environ.context(dt=0.1):
132
131
  unittest.main()
@@ -13,13 +13,12 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
16
 
18
17
  import unittest
19
18
 
20
19
  import jax.numpy as jnp
21
20
 
22
- import brainstate as bst
21
+ import brainstate
23
22
 
24
23
 
25
24
  class TestRateRNNModels(unittest.TestCase):
@@ -30,31 +29,31 @@ class TestRateRNNModels(unittest.TestCase):
30
29
  self.x = jnp.ones((self.batch_size, self.num_in))
31
30
 
32
31
  def test_ValinaRNNCell(self):
33
- model = bst.nn.ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
32
+ model = brainstate.nn.ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
34
33
  model.init_state(batch_size=self.batch_size)
35
34
  output = model.update(self.x)
36
35
  self.assertEqual(output.shape, (self.batch_size, self.num_out))
37
36
 
38
37
  def test_GRUCell(self):
39
- model = bst.nn.GRUCell(num_in=self.num_in, num_out=self.num_out)
38
+ model = brainstate.nn.GRUCell(num_in=self.num_in, num_out=self.num_out)
40
39
  model.init_state(batch_size=self.batch_size)
41
40
  output = model.update(self.x)
42
41
  self.assertEqual(output.shape, (self.batch_size, self.num_out))
43
42
 
44
43
  def test_MGUCell(self):
45
- model = bst.nn.MGUCell(num_in=self.num_in, num_out=self.num_out)
44
+ model = brainstate.nn.MGUCell(num_in=self.num_in, num_out=self.num_out)
46
45
  model.init_state(batch_size=self.batch_size)
47
46
  output = model.update(self.x)
48
47
  self.assertEqual(output.shape, (self.batch_size, self.num_out))
49
48
 
50
49
  def test_LSTMCell(self):
51
- model = bst.nn.LSTMCell(num_in=self.num_in, num_out=self.num_out)
50
+ model = brainstate.nn.LSTMCell(num_in=self.num_in, num_out=self.num_out)
52
51
  model.init_state(batch_size=self.batch_size)
53
52
  output = model.update(self.x)
54
53
  self.assertEqual(output.shape, (self.batch_size, self.num_out))
55
54
 
56
55
  def test_URLSTMCell(self):
57
- model = bst.nn.URLSTMCell(num_in=self.num_in, num_out=self.num_out)
56
+ model = brainstate.nn.URLSTMCell(num_in=self.num_in, num_out=self.num_out)
58
57
  model.init_state(batch_size=self.batch_size)
59
58
  output = model.update(self.x)
60
59
  self.assertEqual(output.shape, (self.batch_size, self.num_out))
@@ -13,13 +13,12 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
16
 
18
17
  import unittest
19
18
 
20
19
  import jax.numpy as jnp
21
20
 
22
- import brainstate as bst
21
+ import brainstate
23
22
 
24
23
 
25
24
  class TestReadoutModels(unittest.TestCase):
@@ -32,23 +31,23 @@ class TestReadoutModels(unittest.TestCase):
32
31
  self.x = jnp.ones((self.batch_size, self.in_size))
33
32
 
34
33
  def test_LeakyRateReadout(self):
35
- with bst.environ.context(dt=0.1):
36
- model = bst.nn.LeakyRateReadout(in_size=self.in_size, out_size=self.out_size, tau=self.tau)
34
+ with brainstate.environ.context(dt=0.1):
35
+ model = brainstate.nn.LeakyRateReadout(in_size=self.in_size, out_size=self.out_size, tau=self.tau)
37
36
  model.init_state(batch_size=self.batch_size)
38
37
  output = model.update(self.x)
39
38
  self.assertEqual(output.shape, (self.batch_size, self.out_size))
40
39
 
41
40
  def test_LeakySpikeReadout(self):
42
- with bst.environ.context(dt=0.1):
43
- model = bst.nn.LeakySpikeReadout(in_size=self.in_size, tau=self.tau, V_th=self.V_th,
44
- V_initializer=bst.init.ZeroInit(),
45
- w_init=bst.init.KaimingNormal())
41
+ with brainstate.environ.context(dt=0.1):
42
+ model = brainstate.nn.LeakySpikeReadout(in_size=self.in_size, tau=self.tau, V_th=self.V_th,
43
+ V_initializer=brainstate.init.ZeroInit(),
44
+ w_init=brainstate.init.KaimingNormal())
46
45
  model.init_state(batch_size=self.batch_size)
47
- with bst.environ.context(t=0.):
46
+ with brainstate.environ.context(t=0.):
48
47
  output = model.update(self.x)
49
48
  self.assertEqual(output.shape, (self.batch_size, self.out_size))
50
49
 
51
50
 
52
51
  if __name__ == '__main__':
53
- with bst.environ.context(dt=0.1):
52
+ with brainstate.environ.context(dt=0.1):
54
53
  unittest.main()
@@ -15,47 +15,46 @@
15
15
 
16
16
  # -*- coding: utf-8 -*-
17
17
 
18
- from __future__ import annotations
19
18
 
20
19
  import unittest
21
20
 
22
21
  import numpy as np
23
22
 
24
- import brainstate as bst
23
+ import brainstate
25
24
 
26
25
 
27
26
  class TestModuleGroup(unittest.TestCase):
28
27
  def test_initialization(self):
29
- group = bst.nn.DynamicsGroup()
30
- self.assertIsInstance(group, bst.nn.DynamicsGroup)
28
+ group = brainstate.nn.DynamicsGroup()
29
+ self.assertIsInstance(group, brainstate.nn.DynamicsGroup)
31
30
 
32
31
 
33
32
  class TestProjection(unittest.TestCase):
34
33
  def test_initialization(self):
35
- proj = bst.nn.Projection()
36
- self.assertIsInstance(proj, bst.nn.Projection)
34
+ proj = brainstate.nn.Projection()
35
+ self.assertIsInstance(proj, brainstate.nn.Projection)
37
36
 
38
37
  def test_update_not_implemented(self):
39
- proj = bst.nn.Projection()
38
+ proj = brainstate.nn.Projection()
40
39
  with self.assertRaises(ValueError):
41
40
  proj.update()
42
41
 
43
42
 
44
43
  class TestDynamics(unittest.TestCase):
45
44
  def test_initialization(self):
46
- dyn = bst.nn.Dynamics(in_size=10)
47
- self.assertIsInstance(dyn, bst.nn.Dynamics)
45
+ dyn = brainstate.nn.Dynamics(in_size=10)
46
+ self.assertIsInstance(dyn, brainstate.nn.Dynamics)
48
47
  self.assertEqual(dyn.in_size, (10,))
49
48
  self.assertEqual(dyn.out_size, (10,))
50
49
 
51
50
  def test_size_validation(self):
52
51
  with self.assertRaises(ValueError):
53
- bst.nn.Dynamics(in_size=[])
52
+ brainstate.nn.Dynamics(in_size=[])
54
53
  with self.assertRaises(ValueError):
55
- bst.nn.Dynamics(in_size="invalid")
54
+ brainstate.nn.Dynamics(in_size="invalid")
56
55
 
57
56
  def test_input_handling(self):
58
- dyn = bst.nn.Dynamics(in_size=10)
57
+ dyn = brainstate.nn.Dynamics(in_size=10)
59
58
  dyn.add_current_input("test_current", lambda: np.random.rand(10))
60
59
  dyn.add_delta_input("test_delta", lambda: np.random.rand(10))
61
60
 
@@ -63,15 +62,15 @@ class TestDynamics(unittest.TestCase):
63
62
  self.assertIn("test_delta", dyn.delta_inputs)
64
63
 
65
64
  def test_duplicate_input_key(self):
66
- dyn = bst.nn.Dynamics(in_size=10)
65
+ dyn = brainstate.nn.Dynamics(in_size=10)
67
66
  dyn.add_current_input("test", lambda: np.random.rand(10))
68
67
  with self.assertRaises(ValueError):
69
68
  dyn.add_current_input("test", lambda: np.random.rand(10))
70
69
 
71
70
  def test_varshape(self):
72
- dyn = bst.nn.Dynamics(in_size=(2, 3))
71
+ dyn = brainstate.nn.Dynamics(in_size=(2, 3))
73
72
  self.assertEqual(dyn.varshape, (2, 3))
74
- dyn = bst.nn.Dynamics(in_size=(2, 3))
73
+ dyn = brainstate.nn.Dynamics(in_size=(2, 3))
75
74
  self.assertEqual(dyn.varshape, (2, 3))
76
75
 
77
76
 
@@ -13,7 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
16
 
18
17
  import unittest
19
18
 
@@ -21,7 +20,7 @@ import brainunit as u
21
20
  import jax.numpy as jnp
22
21
  import numpy as np
23
22
 
24
- import brainstate as bst
23
+ import brainstate
25
24
 
26
25
 
27
26
  class TestSynOutModels(unittest.TestCase):
@@ -35,19 +34,19 @@ class TestSynOutModels(unittest.TestCase):
35
34
  self.V_offset = jnp.array([0.0])
36
35
 
37
36
  def test_COBA(self):
38
- model = bst.nn.COBA(E=self.E)
37
+ model = brainstate.nn.COBA(E=self.E)
39
38
  output = model.update(self.conductance, self.potential)
40
39
  expected_output = self.conductance * (self.E - self.potential)
41
40
  np.testing.assert_array_almost_equal(output, expected_output)
42
41
 
43
42
  def test_CUBA(self):
44
- model = bst.nn.CUBA()
43
+ model = brainstate.nn.CUBA()
45
44
  output = model.update(self.conductance)
46
45
  expected_output = self.conductance * model.scale
47
46
  self.assertTrue(u.math.allclose(output, expected_output))
48
47
 
49
48
  def test_MgBlock(self):
50
- model = bst.nn.MgBlock(E=self.E, cc_Mg=self.cc_Mg, alpha=self.alpha, beta=self.beta, V_offset=self.V_offset)
49
+ model = brainstate.nn.MgBlock(E=self.E, cc_Mg=self.cc_Mg, alpha=self.alpha, beta=self.beta, V_offset=self.V_offset)
51
50
  output = model.update(self.conductance, self.potential)
52
51
  norm = (1 + self.cc_Mg / self.beta * jnp.exp(self.alpha * (self.V_offset - self.potential)))
53
52
  expected_output = self.conductance * (self.E - self.potential) / norm