brainstate 0.1.1__py2.py3-none-any.whl → 0.1.3__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. brainstate/__init__.py +1 -1
  2. brainstate/_compatible_import.py +12 -9
  3. brainstate/_state.py +1 -1
  4. brainstate/augment/_autograd_test.py +132 -133
  5. brainstate/augment/_eval_shape_test.py +7 -9
  6. brainstate/augment/_mapping_test.py +75 -76
  7. brainstate/compile/_ad_checkpoint_test.py +6 -8
  8. brainstate/compile/_conditions_test.py +35 -36
  9. brainstate/compile/_error_if_test.py +10 -13
  10. brainstate/compile/_loop_collect_return_test.py +7 -9
  11. brainstate/compile/_loop_no_collection_test.py +7 -8
  12. brainstate/compile/_make_jaxpr.py +29 -14
  13. brainstate/compile/_make_jaxpr_test.py +20 -20
  14. brainstate/functional/_activations_test.py +61 -61
  15. brainstate/graph/_graph_node_test.py +16 -18
  16. brainstate/graph/_graph_operation_test.py +154 -156
  17. brainstate/init/_random_inits_test.py +20 -21
  18. brainstate/init/_regular_inits_test.py +4 -5
  19. brainstate/mixin.py +1 -14
  20. brainstate/nn/__init__.py +81 -17
  21. brainstate/nn/_collective_ops_test.py +8 -8
  22. brainstate/nn/{_interaction/_conv.py → _conv.py} +1 -1
  23. brainstate/nn/{_interaction/_conv_test.py → _conv_test.py} +31 -33
  24. brainstate/nn/{_dynamics/_state_delay.py → _delay.py} +6 -2
  25. brainstate/nn/{_elementwise/_dropout.py → _dropout.py} +1 -1
  26. brainstate/nn/{_elementwise/_dropout_test.py → _dropout_test.py} +9 -9
  27. brainstate/nn/{_dynamics/_dynamics_base.py → _dynamics.py} +139 -18
  28. brainstate/nn/{_dynamics/_dynamics_base_test.py → _dynamics_test.py} +14 -15
  29. brainstate/nn/{_elementwise/_elementwise.py → _elementwise.py} +1 -1
  30. brainstate/nn/_elementwise_test.py +169 -0
  31. brainstate/nn/{_interaction/_embedding.py → _embedding.py} +1 -1
  32. brainstate/nn/_exp_euler_test.py +5 -6
  33. brainstate/nn/{_event/_fixedprob_mv.py → _fixedprob_mv.py} +1 -1
  34. brainstate/nn/{_event/_fixedprob_mv_test.py → _fixedprob_mv_test.py} +0 -1
  35. brainstate/nn/{_dyn_impl/_inputs.py → _inputs.py} +4 -5
  36. brainstate/nn/{_interaction/_linear.py → _linear.py} +2 -5
  37. brainstate/nn/{_event/_linear_mv.py → _linear_mv.py} +1 -1
  38. brainstate/nn/{_event/_linear_mv_test.py → _linear_mv_test.py} +0 -1
  39. brainstate/nn/{_interaction/_linear_test.py → _linear_test.py} +15 -17
  40. brainstate/nn/{_event/__init__.py → _ltp.py} +7 -5
  41. brainstate/nn/_module_test.py +34 -37
  42. brainstate/nn/{_dyn_impl/_dynamics_neuron.py → _neuron.py} +2 -2
  43. brainstate/nn/{_dyn_impl/_dynamics_neuron_test.py → _neuron_test.py} +18 -19
  44. brainstate/nn/{_interaction/_normalizations.py → _normalizations.py} +1 -1
  45. brainstate/nn/{_interaction/_normalizations_test.py → _normalizations_test.py} +10 -12
  46. brainstate/nn/{_interaction/_poolings.py → _poolings.py} +1 -1
  47. brainstate/nn/{_interaction/_poolings_test.py → _poolings_test.py} +20 -22
  48. brainstate/nn/{_dynamics/_projection_base.py → _projection.py} +35 -3
  49. brainstate/nn/{_dyn_impl/_rate_rnns.py → _rate_rnns.py} +2 -2
  50. brainstate/nn/{_dyn_impl/_rate_rnns_test.py → _rate_rnns_test.py} +6 -7
  51. brainstate/nn/{_dyn_impl/_readout.py → _readout.py} +3 -3
  52. brainstate/nn/{_dyn_impl/_readout_test.py → _readout_test.py} +9 -10
  53. brainstate/nn/_stp.py +236 -0
  54. brainstate/nn/{_dyn_impl/_dynamics_synapse.py → _synapse.py} +17 -206
  55. brainstate/nn/{_dyn_impl/_dynamics_synapse_test.py → _synapse_test.py} +9 -10
  56. brainstate/nn/_synaptic_projection.py +133 -0
  57. brainstate/nn/{_dynamics/_synouts.py → _synouts.py} +4 -1
  58. brainstate/nn/{_dynamics/_synouts_test.py → _synouts_test.py} +4 -5
  59. brainstate/optim/_lr_scheduler_test.py +3 -3
  60. brainstate/optim/_optax_optimizer_test.py +8 -9
  61. brainstate/random/_rand_funs_test.py +183 -184
  62. brainstate/random/_rand_seed_test.py +10 -12
  63. {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/METADATA +1 -1
  64. brainstate-0.1.3.dist-info/RECORD +131 -0
  65. brainstate/nn/_dyn_impl/__init__.py +0 -42
  66. brainstate/nn/_dynamics/__init__.py +0 -37
  67. brainstate/nn/_elementwise/__init__.py +0 -22
  68. brainstate/nn/_elementwise/_elementwise_test.py +0 -171
  69. brainstate/nn/_interaction/__init__.py +0 -41
  70. brainstate-0.1.1.dist-info/RECORD +0 -133
  71. {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/LICENSE +0 -0
  72. {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/WHEEL +0 -0
  73. {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/top_level.txt +0 -0
@@ -14,14 +14,12 @@
14
14
  # ==============================================================================
15
15
 
16
16
 
17
- from __future__ import annotations
18
-
19
17
  import unittest
20
18
 
21
19
  import brainunit as u
22
20
  from absl.testing import parameterized
23
21
 
24
- import brainstate as bst
22
+ import brainstate
25
23
 
26
24
 
27
25
  class TestDense(parameterized.TestCase):
@@ -32,19 +30,19 @@ class TestDense(parameterized.TestCase):
32
30
  num_out=[20, ]
33
31
  )
34
32
  def test_Dense1(self, size, num_out):
35
- f = bst.nn.Linear(10, num_out)
36
- x = bst.random.random(size)
33
+ f = brainstate.nn.Linear(10, num_out)
34
+ x = brainstate.random.random(size)
37
35
  y = f(x)
38
36
  self.assertTrue(y.shape == size[:-1] + (num_out,))
39
37
 
40
38
 
41
39
  class TestSparseMatrix(unittest.TestCase):
42
40
  def test_csr(self):
43
- data = bst.random.rand(10, 20)
41
+ data = brainstate.random.rand(10, 20)
44
42
  data = data * (data > 0.9)
45
- f = bst.nn.SparseLinear(u.sparse.CSR.fromdense(data))
43
+ f = brainstate.nn.SparseLinear(u.sparse.CSR.fromdense(data))
46
44
 
47
- x = bst.random.rand(10)
45
+ x = brainstate.random.rand(10)
48
46
  y = f(x)
49
47
  self.assertTrue(
50
48
  u.math.allclose(
@@ -53,7 +51,7 @@ class TestSparseMatrix(unittest.TestCase):
53
51
  )
54
52
  )
55
53
 
56
- x = bst.random.rand(5, 10)
54
+ x = brainstate.random.rand(5, 10)
57
55
  y = f(x)
58
56
  self.assertTrue(
59
57
  u.math.allclose(
@@ -63,11 +61,11 @@ class TestSparseMatrix(unittest.TestCase):
63
61
  )
64
62
 
65
63
  def test_csc(self):
66
- data = bst.random.rand(10, 20)
64
+ data = brainstate.random.rand(10, 20)
67
65
  data = data * (data > 0.9)
68
- f = bst.nn.SparseLinear(u.sparse.CSC.fromdense(data))
66
+ f = brainstate.nn.SparseLinear(u.sparse.CSC.fromdense(data))
69
67
 
70
- x = bst.random.rand(10)
68
+ x = brainstate.random.rand(10)
71
69
  y = f(x)
72
70
  self.assertTrue(
73
71
  u.math.allclose(
@@ -76,7 +74,7 @@ class TestSparseMatrix(unittest.TestCase):
76
74
  )
77
75
  )
78
76
 
79
- x = bst.random.rand(5, 10)
77
+ x = brainstate.random.rand(5, 10)
80
78
  y = f(x)
81
79
  self.assertTrue(
82
80
  u.math.allclose(
@@ -86,11 +84,11 @@ class TestSparseMatrix(unittest.TestCase):
86
84
  )
87
85
 
88
86
  def test_coo(self):
89
- data = bst.random.rand(10, 20)
87
+ data = brainstate.random.rand(10, 20)
90
88
  data = data * (data > 0.9)
91
- f = bst.nn.SparseLinear(u.sparse.COO.fromdense(data))
89
+ f = brainstate.nn.SparseLinear(u.sparse.COO.fromdense(data))
92
90
 
93
- x = bst.random.rand(10)
91
+ x = brainstate.random.rand(10)
94
92
  y = f(x)
95
93
  self.assertTrue(
96
94
  u.math.allclose(
@@ -99,7 +97,7 @@ class TestSparseMatrix(unittest.TestCase):
99
97
  )
100
98
  )
101
99
 
102
- x = bst.random.rand(5, 10)
100
+ x = brainstate.random.rand(5, 10)
103
101
  y = f(x)
104
102
  self.assertTrue(
105
103
  u.math.allclose(
@@ -16,11 +16,13 @@
16
16
  # -*- coding: utf-8 -*-
17
17
 
18
18
 
19
- from ._fixedprob_mv import EventFixedProb, EventFixedNumConn
20
- from ._linear_mv import EventLinear
19
+ from ._synapse import Synapse
21
20
 
22
21
  __all__ = [
23
- 'EventLinear',
24
- 'EventFixedProb',
25
- 'EventFixedNumConn',
22
+ 'LongTermPlasticity',
26
23
  ]
24
+
25
+
26
+ class LongTermPlasticity(Synapse):
27
+ pass
28
+
@@ -13,20 +13,17 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  import unittest
19
17
 
20
18
  import jax.numpy as jnp
21
- import jaxlib.xla_extension
22
19
 
23
- import brainstate as bst
20
+ import brainstate
24
21
 
25
22
 
26
23
  class TestDelay(unittest.TestCase):
27
24
  def test_delay1(self):
28
- a = bst.State(bst.random.random(10, 20))
29
- delay = bst.nn.Delay(a.value)
25
+ a = brainstate.State(brainstate.random.random(10, 20))
26
+ delay = brainstate.nn.Delay(a.value)
30
27
  delay.register_entry('a', 1.)
31
28
  delay.register_entry('b', 2.)
32
29
  delay.register_entry('c', None)
@@ -36,7 +33,7 @@ class TestDelay(unittest.TestCase):
36
33
  delay.register_entry('c', 10.)
37
34
 
38
35
  def test_rotation_delay(self):
39
- rotation_delay = bst.nn.Delay(jnp.ones((1,)))
36
+ rotation_delay = brainstate.nn.Delay(jnp.ones((1,)))
40
37
  t0 = 0.
41
38
  t1, n1 = 1., 10
42
39
  t2, n2 = 2., 20
@@ -53,7 +50,7 @@ class TestDelay(unittest.TestCase):
53
50
  # print(rotation_delay.max_length)
54
51
 
55
52
  for i in range(100):
56
- bst.environ.set(i=i)
53
+ brainstate.environ.set(i=i)
57
54
  rotation_delay.update(jnp.ones((1,)) * i)
58
55
  # print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c2'), rotation_delay.at('c'))
59
56
  self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
@@ -61,7 +58,7 @@ class TestDelay(unittest.TestCase):
61
58
  self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
62
59
 
63
60
  def test_concat_delay(self):
64
- rotation_delay = bst.nn.Delay(jnp.ones([1]), delay_method='concat')
61
+ rotation_delay = brainstate.nn.Delay(jnp.ones([1]), delay_method='concat')
65
62
  t0 = 0.
66
63
  t1, n1 = 1., 10
67
64
  t2, n2 = 2., 20
@@ -74,7 +71,7 @@ class TestDelay(unittest.TestCase):
74
71
 
75
72
  print()
76
73
  for i in range(100):
77
- bst.environ.set(i=i)
74
+ brainstate.environ.set(i=i)
78
75
  rotation_delay.update(jnp.ones((1,)) * i)
79
76
  print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c'))
80
77
  self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
@@ -83,40 +80,40 @@ class TestDelay(unittest.TestCase):
83
80
  # bst.util.clear_buffer_memory()
84
81
 
85
82
  def test_jit_erro(self):
86
- rotation_delay = bst.nn.Delay(jnp.ones([1]), time=2., delay_method='concat', interp_method='round')
83
+ rotation_delay = brainstate.nn.Delay(jnp.ones([1]), time=2., delay_method='concat', interp_method='round')
87
84
  rotation_delay.init_state()
88
85
 
89
- with bst.environ.context(i=0, t=0, jit_error_check=True):
86
+ with brainstate.environ.context(i=0, t=0, jit_error_check=True):
90
87
  rotation_delay.retrieve_at_time(-2.0)
91
- with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
88
+ with self.assertRaises(Exception):
92
89
  rotation_delay.retrieve_at_time(-2.1)
93
90
  rotation_delay.retrieve_at_time(-2.01)
94
- with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
91
+ with self.assertRaises(Exception):
95
92
  rotation_delay.retrieve_at_time(-2.09)
96
- with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
93
+ with self.assertRaises(Exception):
97
94
  rotation_delay.retrieve_at_time(0.1)
98
- with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
95
+ with self.assertRaises(Exception):
99
96
  rotation_delay.retrieve_at_time(0.01)
100
97
 
101
98
  def test_round_interp(self):
102
99
  for shape in [(1,), (1, 1), (1, 1, 1)]:
103
100
  for delay_method in ['rotation', 'concat']:
104
- rotation_delay = bst.nn.Delay(jnp.ones(shape), time=2., delay_method=delay_method,
105
- interp_method='round')
101
+ rotation_delay = brainstate.nn.Delay(jnp.ones(shape), time=2., delay_method=delay_method,
102
+ interp_method='round')
106
103
  t0, n1 = 0.01, 0
107
104
  t1, n1 = 1.04, 10
108
105
  t2, n2 = 1.06, 11
109
106
  rotation_delay.init_state()
110
107
 
111
- @bst.compile.jit
108
+ @brainstate.compile.jit
112
109
  def retrieve(td, i):
113
- with bst.environ.context(i=i, t=i * bst.environ.get_dt()):
110
+ with brainstate.environ.context(i=i, t=i * brainstate.environ.get_dt()):
114
111
  return rotation_delay.retrieve_at_time(td)
115
112
 
116
113
  print()
117
114
  for i in range(100):
118
- t = i * bst.environ.get_dt()
119
- with bst.environ.context(i=i, t=t):
115
+ t = i * brainstate.environ.get_dt()
116
+ with brainstate.environ.context(i=i, t=t):
120
117
  rotation_delay.update(jnp.ones(shape) * i)
121
118
  print(i,
122
119
  retrieve(t - t0, i),
@@ -131,22 +128,22 @@ class TestDelay(unittest.TestCase):
131
128
  for delay_method in ['rotation', 'concat']:
132
129
  print(shape, delay_method)
133
130
 
134
- rotation_delay = bst.nn.Delay(jnp.ones(shape), time=2., delay_method=delay_method,
135
- interp_method='linear_interp')
131
+ rotation_delay = brainstate.nn.Delay(jnp.ones(shape), time=2., delay_method=delay_method,
132
+ interp_method='linear_interp')
136
133
  t0, n0 = 0.01, 0.1
137
134
  t1, n1 = 1.04, 10.4
138
135
  t2, n2 = 1.06, 10.6
139
136
  rotation_delay.init_state()
140
137
 
141
- @bst.compile.jit
138
+ @brainstate.compile.jit
142
139
  def retrieve(td, i):
143
- with bst.environ.context(i=i, t=i * bst.environ.get_dt()):
140
+ with brainstate.environ.context(i=i, t=i * brainstate.environ.get_dt()):
144
141
  return rotation_delay.retrieve_at_time(td)
145
142
 
146
143
  print()
147
144
  for i in range(100):
148
- t = i * bst.environ.get_dt()
149
- with bst.environ.context(i=i, t=t):
145
+ t = i * brainstate.environ.get_dt()
146
+ with brainstate.environ.context(i=i, t=t):
150
147
  rotation_delay.update(jnp.ones(shape) * i)
151
148
  print(i,
152
149
  retrieve(t - t0, i),
@@ -157,8 +154,8 @@ class TestDelay(unittest.TestCase):
157
154
  self.assertTrue(jnp.allclose(retrieve(t - t2, i), jnp.maximum(jnp.ones(shape) * i - n2, 0.)))
158
155
 
159
156
  def test_rotation_and_concat_delay(self):
160
- rotation_delay = bst.nn.Delay(jnp.ones((1,)))
161
- concat_delay = bst.nn.Delay(jnp.ones([1]), delay_method='concat')
157
+ rotation_delay = brainstate.nn.Delay(jnp.ones((1,)))
158
+ concat_delay = brainstate.nn.Delay(jnp.ones([1]), delay_method='concat')
162
159
  t0 = 0.
163
160
  t1, n1 = 1., 10
164
161
  t2, n2 = 2., 20
@@ -175,7 +172,7 @@ class TestDelay(unittest.TestCase):
175
172
 
176
173
  print()
177
174
  for i in range(100):
178
- bst.environ.set(i=i)
175
+ brainstate.environ.set(i=i)
179
176
  new = jnp.ones((1,)) * i
180
177
  rotation_delay.update(new)
181
178
  concat_delay.update(new)
@@ -186,17 +183,17 @@ class TestDelay(unittest.TestCase):
186
183
 
187
184
  class TestModule(unittest.TestCase):
188
185
  def test_states(self):
189
- class A(bst.nn.Module):
186
+ class A(brainstate.nn.Module):
190
187
  def __init__(self):
191
188
  super().__init__()
192
- self.a = bst.State(bst.random.random(10, 20))
193
- self.b = bst.State(bst.random.random(10, 20))
189
+ self.a = brainstate.State(brainstate.random.random(10, 20))
190
+ self.b = brainstate.State(brainstate.random.random(10, 20))
194
191
 
195
- class B(bst.nn.Module):
192
+ class B(brainstate.nn.Module):
196
193
  def __init__(self):
197
194
  super().__init__()
198
195
  self.a = A()
199
- self.b = bst.State(bst.random.random(10, 20))
196
+ self.b = brainstate.State(brainstate.random.random(10, 20))
200
197
 
201
198
  b = B()
202
199
  print()
@@ -207,5 +204,5 @@ class TestModule(unittest.TestCase):
207
204
 
208
205
 
209
206
  if __name__ == '__main__':
210
- with bst.environ.context(dt=0.1):
207
+ with brainstate.environ.context(dt=0.1):
211
208
  unittest.main()
@@ -22,9 +22,9 @@ import jax
22
22
 
23
23
  from brainstate import init, surrogate, environ
24
24
  from brainstate._state import HiddenState, ShortTermState
25
- from brainstate.nn._dynamics._dynamics_base import Dynamics
26
- from brainstate.nn._exp_euler import exp_euler_step
27
25
  from brainstate.typing import ArrayLike, Size
26
+ from ._dynamics import Dynamics
27
+ from ._exp_euler import exp_euler_step
28
28
 
29
29
  __all__ = [
30
30
  'Neuron', 'IF', 'LIF', 'LIFRef', 'ALIF',
@@ -15,7 +15,6 @@
15
15
 
16
16
  # -*- coding: utf-8 -*-
17
17
 
18
- from __future__ import annotations
19
18
 
20
19
  import unittest
21
20
 
@@ -23,7 +22,7 @@ import brainunit as u
23
22
  import jax
24
23
  import jax.numpy as jnp
25
24
 
26
- import brainstate as bst
25
+ import brainstate
27
26
  from brainstate.nn import IF, LIF, ALIF
28
27
 
29
28
 
@@ -35,13 +34,13 @@ class TestNeuron(unittest.TestCase):
35
34
 
36
35
  def test_neuron_base_class(self):
37
36
  with self.assertRaises(NotImplementedError):
38
- bst.nn.Neuron(self.in_size).get_spike() # Neuron is an abstract base class
37
+ brainstate.nn.Neuron(self.in_size).get_spike() # Neuron is an abstract base class
39
38
 
40
39
  def generate_input(self):
41
- return bst.random.randn(self.time_steps, self.batch_size, self.in_size) * u.mA
40
+ return brainstate.random.randn(self.time_steps, self.batch_size, self.in_size) * u.mA
42
41
 
43
42
  def test_if_neuron(self):
44
- with bst.environ.context(dt=0.1 * u.ms):
43
+ with brainstate.environ.context(dt=0.1 * u.ms):
45
44
  neuron = IF(self.in_size)
46
45
  inputs = self.generate_input()
47
46
 
@@ -62,7 +61,7 @@ class TestNeuron(unittest.TestCase):
62
61
  self.assertTrue(jnp.all((spikes >= 0) & (spikes <= 1)))
63
62
 
64
63
  def test_lif_neuron(self):
65
- with bst.environ.context(dt=0.1 * u.ms):
64
+ with brainstate.environ.context(dt=0.1 * u.ms):
66
65
  tau = 20.0 * u.ms
67
66
  neuron = LIF(self.in_size, tau=tau)
68
67
  inputs = self.generate_input()
@@ -74,7 +73,7 @@ class TestNeuron(unittest.TestCase):
74
73
 
75
74
  # Test forward pass
76
75
  state = neuron.init_state(self.batch_size)
77
- call = bst.compile.jit(neuron)
76
+ call = brainstate.compile.jit(neuron)
78
77
 
79
78
  for t in range(self.time_steps):
80
79
  out = call(inputs[t])
@@ -94,8 +93,8 @@ class TestNeuron(unittest.TestCase):
94
93
 
95
94
  # Test forward pass
96
95
  neuron.init_state(self.batch_size)
97
- call = bst.compile.jit(neuron)
98
- with bst.environ.context(dt=0.1 * u.ms):
96
+ call = brainstate.compile.jit(neuron)
97
+ with brainstate.environ.context(dt=0.1 * u.ms):
99
98
  for t in range(self.time_steps):
100
99
  out = call(inputs[t])
101
100
  self.assertEqual(out.shape, (self.batch_size, self.in_size))
@@ -113,8 +112,8 @@ class TestNeuron(unittest.TestCase):
113
112
  neuron = NeuronClass(self.in_size, spk_reset='soft')
114
113
  inputs = self.generate_input()
115
114
  state = neuron.init_state(self.batch_size)
116
- call = bst.compile.jit(neuron)
117
- with bst.environ.context(dt=0.1 * u.ms):
115
+ call = brainstate.compile.jit(neuron)
116
+ with brainstate.environ.context(dt=0.1 * u.ms):
118
117
  for t in range(self.time_steps):
119
118
  out = call(inputs[t])
120
119
  self.assertTrue(jnp.all(neuron.V.value <= neuron.V_th))
@@ -124,8 +123,8 @@ class TestNeuron(unittest.TestCase):
124
123
  neuron = NeuronClass(self.in_size, spk_reset='hard')
125
124
  inputs = self.generate_input()
126
125
  state = neuron.init_state(self.batch_size)
127
- call = bst.compile.jit(neuron)
128
- with bst.environ.context(dt=0.1 * u.ms):
126
+ call = brainstate.compile.jit(neuron)
127
+ with brainstate.environ.context(dt=0.1 * u.ms):
129
128
  for t in range(self.time_steps):
130
129
  out = call(inputs[t])
131
130
  self.assertTrue(jnp.all((neuron.V.value < neuron.V_th) | (neuron.V.value == 0. * u.mV)))
@@ -135,8 +134,8 @@ class TestNeuron(unittest.TestCase):
135
134
  neuron = NeuronClass(self.in_size)
136
135
  inputs = self.generate_input()
137
136
  state = neuron.init_state(self.batch_size)
138
- call = bst.compile.jit(neuron)
139
- with bst.environ.context(dt=0.1 * u.ms):
137
+ call = brainstate.compile.jit(neuron)
138
+ with brainstate.environ.context(dt=0.1 * u.ms):
140
139
  for t in range(self.time_steps):
141
140
  out = call(inputs[t])
142
141
  self.assertFalse(jax.tree_util.tree_leaves(out)[0].aval.weak_type)
@@ -148,15 +147,15 @@ class TestNeuron(unittest.TestCase):
148
147
  self.assertEqual(neuron.in_size, in_size)
149
148
  self.assertEqual(neuron.out_size, in_size)
150
149
 
151
- inputs = bst.random.randn(self.time_steps, self.batch_size, *in_size) * u.mA
150
+ inputs = brainstate.random.randn(self.time_steps, self.batch_size, *in_size) * u.mA
152
151
  state = neuron.init_state(self.batch_size)
153
- call = bst.compile.jit(neuron)
154
- with bst.environ.context(dt=0.1 * u.ms):
152
+ call = brainstate.compile.jit(neuron)
153
+ with brainstate.environ.context(dt=0.1 * u.ms):
155
154
  for t in range(self.time_steps):
156
155
  out = call(inputs[t])
157
156
  self.assertEqual(out.shape, (self.batch_size, *in_size))
158
157
 
159
158
 
160
159
  if __name__ == '__main__':
161
- with bst.environ.context(dt=0.1):
160
+ with brainstate.environ.context(dt=0.1):
162
161
  unittest.main()
@@ -22,8 +22,8 @@ import jax.numpy as jnp
22
22
 
23
23
  from brainstate import environ, init
24
24
  from brainstate._state import ParamState, BatchState
25
- from brainstate.nn._module import Module
26
25
  from brainstate.typing import DTypeLike, ArrayLike, Size, Axes
26
+ from ._module import Module
27
27
 
28
28
  __all__ = [
29
29
  'BatchNorm0d',
@@ -13,12 +13,10 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  from absl.testing import absltest
19
17
  from absl.testing import parameterized
20
18
 
21
- import brainstate as bst
19
+ import brainstate
22
20
 
23
21
 
24
22
  class Test_Normalization(parameterized.TestCase):
@@ -26,27 +24,27 @@ class Test_Normalization(parameterized.TestCase):
26
24
  fit=[True, False],
27
25
  )
28
26
  def test_BatchNorm1d(self, fit):
29
- net = bst.nn.BatchNorm1d((3, 10))
30
- bst.environ.set(fit=fit)
31
- input = bst.random.randn(1, 3, 10)
27
+ net = brainstate.nn.BatchNorm1d((3, 10))
28
+ brainstate.environ.set(fit=fit)
29
+ input = brainstate.random.randn(1, 3, 10)
32
30
  output = net(input)
33
31
 
34
32
  @parameterized.product(
35
33
  fit=[True, False]
36
34
  )
37
35
  def test_BatchNorm2d(self, fit):
38
- net = bst.nn.BatchNorm2d([3, 4, 10])
39
- bst.environ.set(fit=fit)
40
- input = bst.random.randn(1, 3, 4, 10)
36
+ net = brainstate.nn.BatchNorm2d([3, 4, 10])
37
+ brainstate.environ.set(fit=fit)
38
+ input = brainstate.random.randn(1, 3, 4, 10)
41
39
  output = net(input)
42
40
 
43
41
  @parameterized.product(
44
42
  fit=[True, False]
45
43
  )
46
44
  def test_BatchNorm3d(self, fit):
47
- net = bst.nn.BatchNorm3d([3, 4, 5, 10])
48
- bst.environ.set(fit=fit)
49
- input = bst.random.randn(1, 3, 4, 5, 10)
45
+ net = brainstate.nn.BatchNorm3d([3, 4, 5, 10])
46
+ brainstate.environ.set(fit=fit)
47
+ input = brainstate.random.randn(1, 3, 4, 5, 10)
50
48
  output = net(input)
51
49
 
52
50
  # @parameterized.product(
@@ -25,8 +25,8 @@ import jax.numpy as jnp
25
25
  import numpy as np
26
26
 
27
27
  from brainstate import environ
28
- from brainstate.nn._module import Module
29
28
  from brainstate.typing import Size
29
+ from ._module import Module
30
30
 
31
31
  __all__ = [
32
32
  'Flatten', 'Unflatten',
@@ -1,13 +1,11 @@
1
1
  # -*- coding: utf-8 -*-
2
2
 
3
- from __future__ import annotations
4
-
5
3
  import jax
6
4
  import numpy as np
7
5
  from absl.testing import absltest
8
6
  from absl.testing import parameterized
9
7
 
10
- import brainstate as bst
8
+ import brainstate
11
9
  import brainstate.nn as nn
12
10
 
13
11
 
@@ -18,7 +16,7 @@ class TestFlatten(parameterized.TestCase):
18
16
  (32, 8),
19
17
  (10, 20, 30),
20
18
  ]:
21
- arr = bst.random.rand(*size)
19
+ arr = brainstate.random.rand(*size)
22
20
  f = nn.Flatten(start_axis=0)
23
21
  out = f(arr)
24
22
  self.assertTrue(out.shape == (np.prod(size),))
@@ -29,21 +27,21 @@ class TestFlatten(parameterized.TestCase):
29
27
  (32, 8),
30
28
  (10, 20, 30),
31
29
  ]:
32
- arr = bst.random.rand(*size)
30
+ arr = brainstate.random.rand(*size)
33
31
  f = nn.Flatten(start_axis=1)
34
32
  out = f(arr)
35
33
  self.assertTrue(out.shape == (size[0], np.prod(size[1:])))
36
34
 
37
35
  def test_flatten3(self):
38
36
  size = (16, 32, 32, 8)
39
- arr = bst.random.rand(*size)
37
+ arr = brainstate.random.rand(*size)
40
38
  f = nn.Flatten(start_axis=0, in_size=(32, 8))
41
39
  out = f(arr)
42
40
  self.assertTrue(out.shape == (16, 32, 32 * 8))
43
41
 
44
42
  def test_flatten4(self):
45
43
  size = (16, 32, 32, 8)
46
- arr = bst.random.rand(*size)
44
+ arr = brainstate.random.rand(*size)
47
45
  f = nn.Flatten(start_axis=1, in_size=(32, 32, 8))
48
46
  out = f(arr)
49
47
  self.assertTrue(out.shape == (16, 32, 32 * 8))
@@ -58,7 +56,7 @@ class TestPool(parameterized.TestCase):
58
56
  super().__init__(*args, **kwargs)
59
57
 
60
58
  def test_MaxPool2d_v1(self):
61
- arr = bst.random.rand(16, 32, 32, 8)
59
+ arr = brainstate.random.rand(16, 32, 32, 8)
62
60
 
63
61
  out = nn.MaxPool2d(2, 2, channel_axis=-1)(arr)
64
62
  self.assertTrue(out.shape == (16, 16, 16, 8))
@@ -79,7 +77,7 @@ class TestPool(parameterized.TestCase):
79
77
  self.assertTrue(out.shape == (16, 17, 32, 5))
80
78
 
81
79
  def test_AvgPool2d_v1(self):
82
- arr = bst.random.rand(16, 32, 32, 8)
80
+ arr = brainstate.random.rand(16, 32, 32, 8)
83
81
 
84
82
  out = nn.AvgPool2d(2, 2, channel_axis=-1)(arr)
85
83
  self.assertTrue(out.shape == (16, 16, 16, 8))
@@ -105,9 +103,9 @@ class TestPool(parameterized.TestCase):
105
103
  for target_size in [10, 9, 8, 7, 6]
106
104
  )
107
105
  def test_adaptive_pool1d(self, target_size):
108
- from brainstate.nn._interaction._poolings import _adaptive_pool1d
106
+ from brainstate.nn._poolings import _adaptive_pool1d
109
107
 
110
- arr = bst.random.rand(100)
108
+ arr = brainstate.random.rand(100)
111
109
  op = jax.numpy.mean
112
110
 
113
111
  out = _adaptive_pool1d(arr, target_size, op)
@@ -119,7 +117,7 @@ class TestPool(parameterized.TestCase):
119
117
  self.assertTrue(out.shape == (target_size,))
120
118
 
121
119
  def test_AdaptiveAvgPool2d_v1(self):
122
- input = bst.random.randn(64, 8, 9)
120
+ input = brainstate.random.randn(64, 8, 9)
123
121
 
124
122
  output = nn.AdaptiveAvgPool2d((5, 7), channel_axis=0)(input)
125
123
  self.assertTrue(output.shape == (64, 5, 7))
@@ -137,8 +135,8 @@ class TestPool(parameterized.TestCase):
137
135
  self.assertTrue(output.shape == (64, 2, 3))
138
136
 
139
137
  def test_AdaptiveAvgPool2d_v2(self):
140
- bst.random.seed()
141
- input = bst.random.randn(128, 64, 32, 16)
138
+ brainstate.random.seed()
139
+ input = brainstate.random.randn(128, 64, 32, 16)
142
140
 
143
141
  output = nn.AdaptiveAvgPool2d((5, 7), channel_axis=0)(input)
144
142
  self.assertTrue(output.shape == (128, 64, 5, 7))
@@ -154,13 +152,13 @@ class TestPool(parameterized.TestCase):
154
152
  print()
155
153
 
156
154
  def test_AdaptiveAvgPool3d_v1(self):
157
- input = bst.random.randn(10, 128, 64, 32)
155
+ input = brainstate.random.randn(10, 128, 64, 32)
158
156
  net = nn.AdaptiveAvgPool3d(target_size=[6, 5, 3], channel_axis=0)
159
157
  output = net(input)
160
158
  self.assertTrue(output.shape == (10, 6, 5, 3))
161
159
 
162
160
  def test_AdaptiveAvgPool3d_v2(self):
163
- input = bst.random.randn(10, 20, 128, 64, 32)
161
+ input = brainstate.random.randn(10, 20, 128, 64, 32)
164
162
  net = nn.AdaptiveAvgPool3d(target_size=[6, 5, 3])
165
163
  output = net(input)
166
164
  self.assertTrue(output.shape == (10, 6, 5, 3, 32))
@@ -169,7 +167,7 @@ class TestPool(parameterized.TestCase):
169
167
  axis=(-1, 0, 1)
170
168
  )
171
169
  def test_AdaptiveMaxPool1d_v1(self, axis):
172
- input = bst.random.randn(32, 16)
170
+ input = brainstate.random.randn(32, 16)
173
171
  net = nn.AdaptiveMaxPool1d(target_size=4, channel_axis=axis)
174
172
  output = net(input)
175
173
 
@@ -177,7 +175,7 @@ class TestPool(parameterized.TestCase):
177
175
  axis=(-1, 0, 1, 2)
178
176
  )
179
177
  def test_AdaptiveMaxPool1d_v2(self, axis):
180
- input = bst.random.randn(2, 32, 16)
178
+ input = brainstate.random.randn(2, 32, 16)
181
179
  net = nn.AdaptiveMaxPool1d(target_size=4, channel_axis=axis)
182
180
  output = net(input)
183
181
 
@@ -185,7 +183,7 @@ class TestPool(parameterized.TestCase):
185
183
  axis=(-1, 0, 1, 2)
186
184
  )
187
185
  def test_AdaptiveMaxPool2d_v1(self, axis):
188
- input = bst.random.randn(32, 16, 12)
186
+ input = brainstate.random.randn(32, 16, 12)
189
187
  net = nn.AdaptiveAvgPool2d(target_size=[5, 4], channel_axis=axis)
190
188
  output = net(input)
191
189
 
@@ -193,7 +191,7 @@ class TestPool(parameterized.TestCase):
193
191
  axis=(-1, 0, 1, 2, 3)
194
192
  )
195
193
  def test_AdaptiveMaxPool2d_v2(self, axis):
196
- input = bst.random.randn(2, 32, 16, 12)
194
+ input = brainstate.random.randn(2, 32, 16, 12)
197
195
  net = nn.AdaptiveAvgPool2d(target_size=[5, 4], channel_axis=axis)
198
196
  output = net(input)
199
197
 
@@ -201,7 +199,7 @@ class TestPool(parameterized.TestCase):
201
199
  axis=(-1, 0, 1, 2, 3)
202
200
  )
203
201
  def test_AdaptiveMaxPool3d_v1(self, axis):
204
- input = bst.random.randn(2, 128, 64, 32)
202
+ input = brainstate.random.randn(2, 128, 64, 32)
205
203
  net = nn.AdaptiveMaxPool3d(target_size=[6, 5, 4], channel_axis=axis)
206
204
  output = net(input)
207
205
  print()
@@ -210,7 +208,7 @@ class TestPool(parameterized.TestCase):
210
208
  axis=(-1, 0, 1, 2, 3, 4)
211
209
  )
212
210
  def test_AdaptiveMaxPool3d_v1(self, axis):
213
- input = bst.random.randn(2, 128, 64, 32, 16)
211
+ input = brainstate.random.randn(2, 128, 64, 32, 16)
214
212
  net = nn.AdaptiveMaxPool3d(target_size=[6, 5, 4], channel_axis=axis)
215
213
  output = net(input)
216
214