brainstate 0.1.0.post20250503__py2.py3-none-any.whl → 0.1.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111)
  1. brainstate/__init__.py +1 -1
  2. brainstate/_compatible_import.py +10 -3
  3. brainstate/_state.py +178 -178
  4. brainstate/_utils.py +0 -1
  5. brainstate/augment/_autograd.py +0 -2
  6. brainstate/augment/_autograd_test.py +132 -133
  7. brainstate/augment/_eval_shape.py +0 -2
  8. brainstate/augment/_eval_shape_test.py +7 -9
  9. brainstate/augment/_mapping.py +2 -3
  10. brainstate/augment/_mapping_test.py +75 -76
  11. brainstate/augment/_random.py +0 -2
  12. brainstate/compile/_ad_checkpoint.py +0 -2
  13. brainstate/compile/_ad_checkpoint_test.py +6 -8
  14. brainstate/compile/_conditions.py +0 -2
  15. brainstate/compile/_conditions_test.py +35 -36
  16. brainstate/compile/_error_if.py +0 -2
  17. brainstate/compile/_error_if_test.py +10 -13
  18. brainstate/compile/_jit.py +9 -8
  19. brainstate/compile/_loop_collect_return.py +0 -2
  20. brainstate/compile/_loop_collect_return_test.py +7 -9
  21. brainstate/compile/_loop_no_collection.py +0 -2
  22. brainstate/compile/_loop_no_collection_test.py +7 -8
  23. brainstate/compile/_make_jaxpr.py +30 -17
  24. brainstate/compile/_make_jaxpr_test.py +20 -20
  25. brainstate/compile/_progress_bar.py +0 -1
  26. brainstate/compile/_unvmap.py +0 -1
  27. brainstate/compile/_util.py +0 -2
  28. brainstate/environ.py +0 -2
  29. brainstate/functional/_activations.py +0 -2
  30. brainstate/functional/_activations_test.py +61 -61
  31. brainstate/functional/_normalization.py +0 -2
  32. brainstate/functional/_others.py +0 -2
  33. brainstate/functional/_spikes.py +0 -1
  34. brainstate/graph/_graph_node.py +1 -3
  35. brainstate/graph/_graph_node_test.py +16 -18
  36. brainstate/graph/_graph_operation.py +4 -2
  37. brainstate/graph/_graph_operation_test.py +154 -156
  38. brainstate/init/_base.py +0 -2
  39. brainstate/init/_generic.py +0 -1
  40. brainstate/init/_random_inits.py +0 -1
  41. brainstate/init/_random_inits_test.py +20 -21
  42. brainstate/init/_regular_inits.py +0 -2
  43. brainstate/init/_regular_inits_test.py +4 -5
  44. brainstate/mixin.py +0 -2
  45. brainstate/nn/_collective_ops.py +0 -3
  46. brainstate/nn/_collective_ops_test.py +8 -8
  47. brainstate/nn/_common.py +0 -2
  48. brainstate/nn/_dyn_impl/_dynamics_neuron.py +0 -2
  49. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +18 -19
  50. brainstate/nn/_dyn_impl/_dynamics_synapse.py +0 -1
  51. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +9 -10
  52. brainstate/nn/_dyn_impl/_inputs.py +0 -1
  53. brainstate/nn/_dyn_impl/_rate_rnns.py +0 -1
  54. brainstate/nn/_dyn_impl/_rate_rnns_test.py +6 -7
  55. brainstate/nn/_dyn_impl/_readout.py +0 -1
  56. brainstate/nn/_dyn_impl/_readout_test.py +9 -10
  57. brainstate/nn/_dynamics/_dynamics_base.py +0 -1
  58. brainstate/nn/_dynamics/_dynamics_base_test.py +14 -15
  59. brainstate/nn/_dynamics/_projection_base.py +0 -1
  60. brainstate/nn/_dynamics/_state_delay.py +0 -2
  61. brainstate/nn/_dynamics/_synouts.py +0 -2
  62. brainstate/nn/_dynamics/_synouts_test.py +4 -5
  63. brainstate/nn/_elementwise/_dropout.py +0 -2
  64. brainstate/nn/_elementwise/_dropout_test.py +9 -9
  65. brainstate/nn/_elementwise/_elementwise.py +0 -2
  66. brainstate/nn/_elementwise/_elementwise_test.py +57 -59
  67. brainstate/nn/_event/_fixedprob_mv.py +0 -1
  68. brainstate/nn/_event/_fixedprob_mv_test.py +0 -1
  69. brainstate/nn/_event/_linear_mv.py +0 -2
  70. brainstate/nn/_event/_linear_mv_test.py +0 -1
  71. brainstate/nn/_exp_euler.py +0 -2
  72. brainstate/nn/_exp_euler_test.py +5 -6
  73. brainstate/nn/_interaction/_conv.py +0 -2
  74. brainstate/nn/_interaction/_conv_test.py +31 -33
  75. brainstate/nn/_interaction/_embedding.py +0 -1
  76. brainstate/nn/_interaction/_linear.py +0 -2
  77. brainstate/nn/_interaction/_linear_test.py +15 -17
  78. brainstate/nn/_interaction/_normalizations.py +0 -2
  79. brainstate/nn/_interaction/_normalizations_test.py +10 -12
  80. brainstate/nn/_interaction/_poolings.py +0 -2
  81. brainstate/nn/_interaction/_poolings_test.py +19 -21
  82. brainstate/nn/_module.py +0 -1
  83. brainstate/nn/_module_test.py +34 -37
  84. brainstate/nn/metrics.py +0 -2
  85. brainstate/optim/_base.py +0 -2
  86. brainstate/optim/_lr_scheduler.py +0 -1
  87. brainstate/optim/_lr_scheduler_test.py +3 -3
  88. brainstate/optim/_optax_optimizer.py +0 -2
  89. brainstate/optim/_optax_optimizer_test.py +8 -9
  90. brainstate/optim/_sgd_optimizer.py +0 -1
  91. brainstate/random/_rand_funs.py +0 -1
  92. brainstate/random/_rand_funs_test.py +183 -184
  93. brainstate/random/_rand_seed.py +0 -1
  94. brainstate/random/_rand_seed_test.py +10 -12
  95. brainstate/random/_rand_state.py +0 -1
  96. brainstate/surrogate.py +0 -1
  97. brainstate/typing.py +0 -2
  98. brainstate/util/_caller.py +4 -6
  99. brainstate/util/_others.py +0 -2
  100. brainstate/util/_pretty_pytree.py +201 -150
  101. brainstate/util/_pretty_repr.py +0 -2
  102. brainstate/util/_pretty_table.py +57 -3
  103. brainstate/util/_scaling.py +0 -2
  104. brainstate/util/_struct.py +0 -2
  105. brainstate/util/filter.py +0 -2
  106. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/METADATA +11 -5
  107. brainstate-0.1.2.dist-info/RECORD +133 -0
  108. brainstate-0.1.0.post20250503.dist-info/RECORD +0 -133
  109. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/LICENSE +0 -0
  110. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/WHEEL +0 -0
  111. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/top_level.txt +0 -0
@@ -13,20 +13,17 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  import unittest
19
17
 
20
18
  import jax.numpy as jnp
21
- import jaxlib.xla_extension
22
19
 
23
- import brainstate as bst
20
+ import brainstate
24
21
 
25
22
 
26
23
  class TestDelay(unittest.TestCase):
27
24
  def test_delay1(self):
28
- a = bst.State(bst.random.random(10, 20))
29
- delay = bst.nn.Delay(a.value)
25
+ a = brainstate.State(brainstate.random.random(10, 20))
26
+ delay = brainstate.nn.Delay(a.value)
30
27
  delay.register_entry('a', 1.)
31
28
  delay.register_entry('b', 2.)
32
29
  delay.register_entry('c', None)
@@ -36,7 +33,7 @@ class TestDelay(unittest.TestCase):
36
33
  delay.register_entry('c', 10.)
37
34
 
38
35
  def test_rotation_delay(self):
39
- rotation_delay = bst.nn.Delay(jnp.ones((1,)))
36
+ rotation_delay = brainstate.nn.Delay(jnp.ones((1,)))
40
37
  t0 = 0.
41
38
  t1, n1 = 1., 10
42
39
  t2, n2 = 2., 20
@@ -53,7 +50,7 @@ class TestDelay(unittest.TestCase):
53
50
  # print(rotation_delay.max_length)
54
51
 
55
52
  for i in range(100):
56
- bst.environ.set(i=i)
53
+ brainstate.environ.set(i=i)
57
54
  rotation_delay.update(jnp.ones((1,)) * i)
58
55
  # print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c2'), rotation_delay.at('c'))
59
56
  self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
@@ -61,7 +58,7 @@ class TestDelay(unittest.TestCase):
61
58
  self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
62
59
 
63
60
  def test_concat_delay(self):
64
- rotation_delay = bst.nn.Delay(jnp.ones([1]), delay_method='concat')
61
+ rotation_delay = brainstate.nn.Delay(jnp.ones([1]), delay_method='concat')
65
62
  t0 = 0.
66
63
  t1, n1 = 1., 10
67
64
  t2, n2 = 2., 20
@@ -74,7 +71,7 @@ class TestDelay(unittest.TestCase):
74
71
 
75
72
  print()
76
73
  for i in range(100):
77
- bst.environ.set(i=i)
74
+ brainstate.environ.set(i=i)
78
75
  rotation_delay.update(jnp.ones((1,)) * i)
79
76
  print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c'))
80
77
  self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
@@ -83,40 +80,40 @@ class TestDelay(unittest.TestCase):
83
80
  # bst.util.clear_buffer_memory()
84
81
 
85
82
  def test_jit_erro(self):
86
- rotation_delay = bst.nn.Delay(jnp.ones([1]), time=2., delay_method='concat', interp_method='round')
83
+ rotation_delay = brainstate.nn.Delay(jnp.ones([1]), time=2., delay_method='concat', interp_method='round')
87
84
  rotation_delay.init_state()
88
85
 
89
- with bst.environ.context(i=0, t=0, jit_error_check=True):
86
+ with brainstate.environ.context(i=0, t=0, jit_error_check=True):
90
87
  rotation_delay.retrieve_at_time(-2.0)
91
- with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
88
+ with self.assertRaises(Exception):
92
89
  rotation_delay.retrieve_at_time(-2.1)
93
90
  rotation_delay.retrieve_at_time(-2.01)
94
- with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
91
+ with self.assertRaises(Exception):
95
92
  rotation_delay.retrieve_at_time(-2.09)
96
- with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
93
+ with self.assertRaises(Exception):
97
94
  rotation_delay.retrieve_at_time(0.1)
98
- with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
95
+ with self.assertRaises(Exception):
99
96
  rotation_delay.retrieve_at_time(0.01)
100
97
 
101
98
  def test_round_interp(self):
102
99
  for shape in [(1,), (1, 1), (1, 1, 1)]:
103
100
  for delay_method in ['rotation', 'concat']:
104
- rotation_delay = bst.nn.Delay(jnp.ones(shape), time=2., delay_method=delay_method,
105
- interp_method='round')
101
+ rotation_delay = brainstate.nn.Delay(jnp.ones(shape), time=2., delay_method=delay_method,
102
+ interp_method='round')
106
103
  t0, n1 = 0.01, 0
107
104
  t1, n1 = 1.04, 10
108
105
  t2, n2 = 1.06, 11
109
106
  rotation_delay.init_state()
110
107
 
111
- @bst.compile.jit
108
+ @brainstate.compile.jit
112
109
  def retrieve(td, i):
113
- with bst.environ.context(i=i, t=i * bst.environ.get_dt()):
110
+ with brainstate.environ.context(i=i, t=i * brainstate.environ.get_dt()):
114
111
  return rotation_delay.retrieve_at_time(td)
115
112
 
116
113
  print()
117
114
  for i in range(100):
118
- t = i * bst.environ.get_dt()
119
- with bst.environ.context(i=i, t=t):
115
+ t = i * brainstate.environ.get_dt()
116
+ with brainstate.environ.context(i=i, t=t):
120
117
  rotation_delay.update(jnp.ones(shape) * i)
121
118
  print(i,
122
119
  retrieve(t - t0, i),
@@ -131,22 +128,22 @@ class TestDelay(unittest.TestCase):
131
128
  for delay_method in ['rotation', 'concat']:
132
129
  print(shape, delay_method)
133
130
 
134
- rotation_delay = bst.nn.Delay(jnp.ones(shape), time=2., delay_method=delay_method,
135
- interp_method='linear_interp')
131
+ rotation_delay = brainstate.nn.Delay(jnp.ones(shape), time=2., delay_method=delay_method,
132
+ interp_method='linear_interp')
136
133
  t0, n0 = 0.01, 0.1
137
134
  t1, n1 = 1.04, 10.4
138
135
  t2, n2 = 1.06, 10.6
139
136
  rotation_delay.init_state()
140
137
 
141
- @bst.compile.jit
138
+ @brainstate.compile.jit
142
139
  def retrieve(td, i):
143
- with bst.environ.context(i=i, t=i * bst.environ.get_dt()):
140
+ with brainstate.environ.context(i=i, t=i * brainstate.environ.get_dt()):
144
141
  return rotation_delay.retrieve_at_time(td)
145
142
 
146
143
  print()
147
144
  for i in range(100):
148
- t = i * bst.environ.get_dt()
149
- with bst.environ.context(i=i, t=t):
145
+ t = i * brainstate.environ.get_dt()
146
+ with brainstate.environ.context(i=i, t=t):
150
147
  rotation_delay.update(jnp.ones(shape) * i)
151
148
  print(i,
152
149
  retrieve(t - t0, i),
@@ -157,8 +154,8 @@ class TestDelay(unittest.TestCase):
157
154
  self.assertTrue(jnp.allclose(retrieve(t - t2, i), jnp.maximum(jnp.ones(shape) * i - n2, 0.)))
158
155
 
159
156
  def test_rotation_and_concat_delay(self):
160
- rotation_delay = bst.nn.Delay(jnp.ones((1,)))
161
- concat_delay = bst.nn.Delay(jnp.ones([1]), delay_method='concat')
157
+ rotation_delay = brainstate.nn.Delay(jnp.ones((1,)))
158
+ concat_delay = brainstate.nn.Delay(jnp.ones([1]), delay_method='concat')
162
159
  t0 = 0.
163
160
  t1, n1 = 1., 10
164
161
  t2, n2 = 2., 20
@@ -175,7 +172,7 @@ class TestDelay(unittest.TestCase):
175
172
 
176
173
  print()
177
174
  for i in range(100):
178
- bst.environ.set(i=i)
175
+ brainstate.environ.set(i=i)
179
176
  new = jnp.ones((1,)) * i
180
177
  rotation_delay.update(new)
181
178
  concat_delay.update(new)
@@ -186,17 +183,17 @@ class TestDelay(unittest.TestCase):
186
183
 
187
184
  class TestModule(unittest.TestCase):
188
185
  def test_states(self):
189
- class A(bst.nn.Module):
186
+ class A(brainstate.nn.Module):
190
187
  def __init__(self):
191
188
  super().__init__()
192
- self.a = bst.State(bst.random.random(10, 20))
193
- self.b = bst.State(bst.random.random(10, 20))
189
+ self.a = brainstate.State(brainstate.random.random(10, 20))
190
+ self.b = brainstate.State(brainstate.random.random(10, 20))
194
191
 
195
- class B(bst.nn.Module):
192
+ class B(brainstate.nn.Module):
196
193
  def __init__(self):
197
194
  super().__init__()
198
195
  self.a = A()
199
- self.b = bst.State(bst.random.random(10, 20))
196
+ self.b = brainstate.State(brainstate.random.random(10, 20))
200
197
 
201
198
  b = B()
202
199
  print()
@@ -207,5 +204,5 @@ class TestModule(unittest.TestCase):
207
204
 
208
205
 
209
206
  if __name__ == '__main__':
210
- with bst.environ.context(dt=0.1):
207
+ with brainstate.environ.context(dt=0.1):
211
208
  unittest.main()
brainstate/nn/metrics.py CHANGED
@@ -14,8 +14,6 @@
14
14
  # ==============================================================================
15
15
 
16
16
 
17
- from __future__ import annotations
18
-
19
17
  import typing as tp
20
18
  from dataclasses import dataclass
21
19
  from functools import partial
brainstate/optim/_base.py CHANGED
@@ -14,8 +14,6 @@
14
14
  # ==============================================================================
15
15
 
16
16
 
17
- from __future__ import annotations
18
-
19
17
  from typing import Dict, Hashable
20
18
 
21
19
  from brainstate._state import State, StateDictManager
@@ -14,7 +14,6 @@
14
14
  # ==============================================================================
15
15
 
16
16
  # -*- coding: utf-8 -*-
17
- from __future__ import annotations
18
17
 
19
18
  from typing import Sequence, Union
20
19
 
@@ -19,12 +19,12 @@ import unittest
19
19
 
20
20
  import jax.numpy as jnp
21
21
 
22
- import brainstate as bst
22
+ import brainstate
23
23
 
24
24
 
25
25
  class TestMultiStepLR(unittest.TestCase):
26
26
  def test1(self):
27
- lr = bst.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1)
27
+ lr = brainstate.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1)
28
28
  for i in range(40):
29
29
  r = lr(i)
30
30
  if i < 10:
@@ -37,7 +37,7 @@ class TestMultiStepLR(unittest.TestCase):
37
37
  self.assertTrue(jnp.allclose(r, 0.0001))
38
38
 
39
39
  def test2(self):
40
- lr = bst.compile.jit(bst.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1))
40
+ lr = brainstate.compile.jit(brainstate.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1))
41
41
  for i in range(40):
42
42
  r = lr(i)
43
43
  if i < 10:
@@ -14,8 +14,6 @@
14
14
  # ==============================================================================
15
15
 
16
16
 
17
- from __future__ import annotations
18
-
19
17
  import importlib.util
20
18
  from typing import Hashable, Dict, Optional
21
19
 
@@ -13,39 +13,38 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
16
 
18
17
  import unittest
19
18
 
20
19
  import jax
21
20
  import optax
22
21
 
23
- import brainstate as bst
22
+ import brainstate
24
23
 
25
24
 
26
25
  class TestOptaxOptimizer(unittest.TestCase):
27
26
  def test1(self):
28
- class Model(bst.nn.Module):
27
+ class Model(brainstate.nn.Module):
29
28
  def __init__(self):
30
29
  super().__init__()
31
- self.linear1 = bst.nn.Linear(2, 3)
32
- self.linear2 = bst.nn.Linear(3, 4)
30
+ self.linear1 = brainstate.nn.Linear(2, 3)
31
+ self.linear2 = brainstate.nn.Linear(3, 4)
33
32
 
34
33
  def __call__(self, x):
35
34
  return self.linear2(self.linear1(x))
36
35
 
37
- x = bst.random.randn(1, 2)
36
+ x = brainstate.random.randn(1, 2)
38
37
  y = jax.numpy.ones((1, 4))
39
38
 
40
39
  model = Model()
41
40
  tx = optax.adam(1e-3)
42
- optimizer = bst.optim.OptaxOptimizer(tx)
43
- optimizer.register_trainable_weights(model.states(bst.ParamState))
41
+ optimizer = brainstate.optim.OptaxOptimizer(tx)
42
+ optimizer.register_trainable_weights(model.states(brainstate.ParamState))
44
43
 
45
44
  loss_fn = lambda: ((model(x) - y) ** 2).mean()
46
45
  prev_loss = loss_fn()
47
46
 
48
- grads = bst.augment.grad(loss_fn, model.states(bst.ParamState))()
47
+ grads = brainstate.augment.grad(loss_fn, model.states(brainstate.ParamState))()
49
48
  optimizer.update(grads)
50
49
 
51
50
  new_loss = loss_fn()
@@ -14,7 +14,6 @@
14
14
  # ==============================================================================
15
15
 
16
16
  # -*- coding: utf-8 -*-
17
- from __future__ import annotations
18
17
 
19
18
  import functools
20
19
  from typing import Union, Dict, Optional, Tuple, Any, TypeVar
@@ -15,7 +15,6 @@
15
15
 
16
16
 
17
17
  # -*- coding: utf-8 -*-
18
- from __future__ import annotations
19
18
 
20
19
  from typing import Optional
21
20