brainstate 0.1.0.post20250503__py2.py3-none-any.whl → 0.1.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111) hide show
  1. brainstate/__init__.py +1 -1
  2. brainstate/_compatible_import.py +10 -3
  3. brainstate/_state.py +178 -178
  4. brainstate/_utils.py +0 -1
  5. brainstate/augment/_autograd.py +0 -2
  6. brainstate/augment/_autograd_test.py +132 -133
  7. brainstate/augment/_eval_shape.py +0 -2
  8. brainstate/augment/_eval_shape_test.py +7 -9
  9. brainstate/augment/_mapping.py +2 -3
  10. brainstate/augment/_mapping_test.py +75 -76
  11. brainstate/augment/_random.py +0 -2
  12. brainstate/compile/_ad_checkpoint.py +0 -2
  13. brainstate/compile/_ad_checkpoint_test.py +6 -8
  14. brainstate/compile/_conditions.py +0 -2
  15. brainstate/compile/_conditions_test.py +35 -36
  16. brainstate/compile/_error_if.py +0 -2
  17. brainstate/compile/_error_if_test.py +10 -13
  18. brainstate/compile/_jit.py +9 -8
  19. brainstate/compile/_loop_collect_return.py +0 -2
  20. brainstate/compile/_loop_collect_return_test.py +7 -9
  21. brainstate/compile/_loop_no_collection.py +0 -2
  22. brainstate/compile/_loop_no_collection_test.py +7 -8
  23. brainstate/compile/_make_jaxpr.py +30 -17
  24. brainstate/compile/_make_jaxpr_test.py +20 -20
  25. brainstate/compile/_progress_bar.py +0 -1
  26. brainstate/compile/_unvmap.py +0 -1
  27. brainstate/compile/_util.py +0 -2
  28. brainstate/environ.py +0 -2
  29. brainstate/functional/_activations.py +0 -2
  30. brainstate/functional/_activations_test.py +61 -61
  31. brainstate/functional/_normalization.py +0 -2
  32. brainstate/functional/_others.py +0 -2
  33. brainstate/functional/_spikes.py +0 -1
  34. brainstate/graph/_graph_node.py +1 -3
  35. brainstate/graph/_graph_node_test.py +16 -18
  36. brainstate/graph/_graph_operation.py +4 -2
  37. brainstate/graph/_graph_operation_test.py +154 -156
  38. brainstate/init/_base.py +0 -2
  39. brainstate/init/_generic.py +0 -1
  40. brainstate/init/_random_inits.py +0 -1
  41. brainstate/init/_random_inits_test.py +20 -21
  42. brainstate/init/_regular_inits.py +0 -2
  43. brainstate/init/_regular_inits_test.py +4 -5
  44. brainstate/mixin.py +0 -2
  45. brainstate/nn/_collective_ops.py +0 -3
  46. brainstate/nn/_collective_ops_test.py +8 -8
  47. brainstate/nn/_common.py +0 -2
  48. brainstate/nn/_dyn_impl/_dynamics_neuron.py +0 -2
  49. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +18 -19
  50. brainstate/nn/_dyn_impl/_dynamics_synapse.py +0 -1
  51. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +9 -10
  52. brainstate/nn/_dyn_impl/_inputs.py +0 -1
  53. brainstate/nn/_dyn_impl/_rate_rnns.py +0 -1
  54. brainstate/nn/_dyn_impl/_rate_rnns_test.py +6 -7
  55. brainstate/nn/_dyn_impl/_readout.py +0 -1
  56. brainstate/nn/_dyn_impl/_readout_test.py +9 -10
  57. brainstate/nn/_dynamics/_dynamics_base.py +0 -1
  58. brainstate/nn/_dynamics/_dynamics_base_test.py +14 -15
  59. brainstate/nn/_dynamics/_projection_base.py +0 -1
  60. brainstate/nn/_dynamics/_state_delay.py +0 -2
  61. brainstate/nn/_dynamics/_synouts.py +0 -2
  62. brainstate/nn/_dynamics/_synouts_test.py +4 -5
  63. brainstate/nn/_elementwise/_dropout.py +0 -2
  64. brainstate/nn/_elementwise/_dropout_test.py +9 -9
  65. brainstate/nn/_elementwise/_elementwise.py +0 -2
  66. brainstate/nn/_elementwise/_elementwise_test.py +57 -59
  67. brainstate/nn/_event/_fixedprob_mv.py +0 -1
  68. brainstate/nn/_event/_fixedprob_mv_test.py +0 -1
  69. brainstate/nn/_event/_linear_mv.py +0 -2
  70. brainstate/nn/_event/_linear_mv_test.py +0 -1
  71. brainstate/nn/_exp_euler.py +0 -2
  72. brainstate/nn/_exp_euler_test.py +5 -6
  73. brainstate/nn/_interaction/_conv.py +0 -2
  74. brainstate/nn/_interaction/_conv_test.py +31 -33
  75. brainstate/nn/_interaction/_embedding.py +0 -1
  76. brainstate/nn/_interaction/_linear.py +0 -2
  77. brainstate/nn/_interaction/_linear_test.py +15 -17
  78. brainstate/nn/_interaction/_normalizations.py +0 -2
  79. brainstate/nn/_interaction/_normalizations_test.py +10 -12
  80. brainstate/nn/_interaction/_poolings.py +0 -2
  81. brainstate/nn/_interaction/_poolings_test.py +19 -21
  82. brainstate/nn/_module.py +0 -1
  83. brainstate/nn/_module_test.py +34 -37
  84. brainstate/nn/metrics.py +0 -2
  85. brainstate/optim/_base.py +0 -2
  86. brainstate/optim/_lr_scheduler.py +0 -1
  87. brainstate/optim/_lr_scheduler_test.py +3 -3
  88. brainstate/optim/_optax_optimizer.py +0 -2
  89. brainstate/optim/_optax_optimizer_test.py +8 -9
  90. brainstate/optim/_sgd_optimizer.py +0 -1
  91. brainstate/random/_rand_funs.py +0 -1
  92. brainstate/random/_rand_funs_test.py +183 -184
  93. brainstate/random/_rand_seed.py +0 -1
  94. brainstate/random/_rand_seed_test.py +10 -12
  95. brainstate/random/_rand_state.py +0 -1
  96. brainstate/surrogate.py +0 -1
  97. brainstate/typing.py +0 -2
  98. brainstate/util/_caller.py +4 -6
  99. brainstate/util/_others.py +0 -2
  100. brainstate/util/_pretty_pytree.py +201 -150
  101. brainstate/util/_pretty_repr.py +0 -2
  102. brainstate/util/_pretty_table.py +57 -3
  103. brainstate/util/_scaling.py +0 -2
  104. brainstate/util/_struct.py +0 -2
  105. brainstate/util/filter.py +0 -2
  106. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/METADATA +11 -5
  107. brainstate-0.1.2.dist-info/RECORD +133 -0
  108. brainstate-0.1.0.post20250503.dist-info/RECORD +0 -133
  109. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/LICENSE +0 -0
  110. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/WHEEL +0 -0
  111. {brainstate-0.1.0.post20250503.dist-info → brainstate-0.1.2.dist-info}/top_level.txt +0 -0
@@ -14,7 +14,6 @@
14
14
  # ==============================================================================
15
15
 
16
16
  # -*- coding: utf-8 -*-
17
- from __future__ import annotations
18
17
 
19
18
  import unittest
20
19
  from pprint import pprint
@@ -24,7 +23,7 @@ import jax
24
23
  import jax.numpy as jnp
25
24
  import pytest
26
25
 
27
- import brainstate as bst
26
+ import brainstate
28
27
  from brainstate.augment._autograd import _jacfwd
29
28
 
30
29
 
@@ -32,11 +31,11 @@ class TestPureFuncGrad(unittest.TestCase):
32
31
  def test_grad_pure_func_1(self):
33
32
  def call(a, b, c): return jnp.sum(a + b + c)
34
33
 
35
- bst.random.seed(1)
34
+ brainstate.random.seed(1)
36
35
  a = jnp.ones(10)
37
- b = bst.random.randn(10)
38
- c = bst.random.uniform(size=10)
39
- f_grad = bst.augment.grad(call, argnums=[0, 1, 2])
36
+ b = brainstate.random.randn(10)
37
+ c = brainstate.random.uniform(size=10)
38
+ f_grad = brainstate.augment.grad(call, argnums=[0, 1, 2])
40
39
  grads = f_grad(a, b, c)
41
40
 
42
41
  for g in grads: assert (g == 1.).all()
@@ -44,29 +43,29 @@ class TestPureFuncGrad(unittest.TestCase):
44
43
  def test_grad_pure_func_2(self):
45
44
  def call(a, b, c): return jnp.sum(a + b + c)
46
45
 
47
- bst.random.seed(1)
46
+ brainstate.random.seed(1)
48
47
  a = jnp.ones(10)
49
- b = bst.random.randn(10)
50
- c = bst.random.uniform(size=10)
51
- f_grad = bst.augment.grad(call)
48
+ b = brainstate.random.randn(10)
49
+ c = brainstate.random.uniform(size=10)
50
+ f_grad = brainstate.augment.grad(call)
52
51
  assert (f_grad(a, b, c) == 1.).all()
53
52
 
54
53
  def test_grad_pure_func_aux1(self):
55
54
  def call(a, b, c):
56
55
  return jnp.sum(a + b + c), (jnp.sin(100), jnp.exp(0.1))
57
56
 
58
- bst.random.seed(1)
59
- f_grad = bst.augment.grad(call, argnums=[0, 1, 2])
57
+ brainstate.random.seed(1)
58
+ f_grad = brainstate.augment.grad(call, argnums=[0, 1, 2])
60
59
  with pytest.raises(TypeError):
61
- f_grad(jnp.ones(10), bst.random.randn(10), bst.random.uniform(size=10))
60
+ f_grad(jnp.ones(10), brainstate.random.randn(10), brainstate.random.uniform(size=10))
62
61
 
63
62
  def test_grad_pure_func_aux2(self):
64
63
  def call(a, b, c):
65
64
  return jnp.sum(a + b + c), (jnp.sin(100), jnp.exp(0.1))
66
65
 
67
- bst.random.seed(1)
68
- f_grad = bst.augment.grad(call, argnums=[0, 1, 2], has_aux=True)
69
- grads, aux = f_grad(jnp.ones(10), bst.random.randn(10), bst.random.uniform(size=10))
66
+ brainstate.random.seed(1)
67
+ f_grad = brainstate.augment.grad(call, argnums=[0, 1, 2], has_aux=True)
68
+ grads, aux = f_grad(jnp.ones(10), brainstate.random.randn(10), brainstate.random.uniform(size=10))
70
69
  for g in grads: assert (g == 1.).all()
71
70
  assert aux[0] == jnp.sin(100)
72
71
  assert aux[1] == jnp.exp(0.1)
@@ -74,11 +73,11 @@ class TestPureFuncGrad(unittest.TestCase):
74
73
  def test_grad_pure_func_return1(self):
75
74
  def call(a, b, c): return jnp.sum(a + b + c)
76
75
 
77
- bst.random.seed(1)
76
+ brainstate.random.seed(1)
78
77
  a = jnp.ones(10)
79
- b = bst.random.randn(10)
80
- c = bst.random.uniform(size=10)
81
- f_grad = bst.augment.grad(call, return_value=True)
78
+ b = brainstate.random.randn(10)
79
+ c = brainstate.random.uniform(size=10)
80
+ f_grad = brainstate.augment.grad(call, return_value=True)
82
81
  grads, returns = f_grad(a, b, c)
83
82
  assert (grads == 1.).all()
84
83
  assert returns == jnp.sum(a + b + c)
@@ -87,11 +86,11 @@ class TestPureFuncGrad(unittest.TestCase):
87
86
  def call(a, b, c):
88
87
  return jnp.sum(a + b + c), (jnp.sin(100), jnp.exp(0.1))
89
88
 
90
- bst.random.seed(1)
89
+ brainstate.random.seed(1)
91
90
  a = jnp.ones(10)
92
- b = bst.random.randn(10)
93
- c = bst.random.uniform(size=10)
94
- f_grad = bst.augment.grad(call, return_value=True, has_aux=True)
91
+ b = brainstate.random.randn(10)
92
+ c = brainstate.random.uniform(size=10)
93
+ f_grad = brainstate.augment.grad(call, return_value=True, has_aux=True)
95
94
  grads, returns, aux = f_grad(a, b, c)
96
95
  assert (grads == 1.).all()
97
96
  assert returns == jnp.sum(a + b + c)
@@ -101,110 +100,110 @@ class TestPureFuncGrad(unittest.TestCase):
101
100
 
102
101
  class TestObjectFuncGrad(unittest.TestCase):
103
102
  def test_grad_ob1(self):
104
- class Test(bst.nn.Module):
103
+ class Test(brainstate.nn.Module):
105
104
  def __init__(self):
106
105
  super(Test, self).__init__()
107
106
 
108
- self.a = bst.ParamState(jnp.ones(10))
109
- self.b = bst.ParamState(bst.random.randn(10))
110
- self.c = bst.ParamState(bst.random.uniform(size=10))
107
+ self.a = brainstate.ParamState(jnp.ones(10))
108
+ self.b = brainstate.ParamState(brainstate.random.randn(10))
109
+ self.c = brainstate.ParamState(brainstate.random.uniform(size=10))
111
110
 
112
111
  def __call__(self):
113
112
  return jnp.sum(self.a.value + self.b.value + self.c.value)
114
113
 
115
- bst.random.seed(0)
114
+ brainstate.random.seed(0)
116
115
 
117
116
  t = Test()
118
- f_grad = bst.augment.grad(t, grad_states={'a': t.a, 'b': t.b, 'c': t.c})
117
+ f_grad = brainstate.augment.grad(t, grad_states={'a': t.a, 'b': t.b, 'c': t.c})
119
118
  grads = f_grad()
120
119
  for g in grads.values():
121
120
  assert (g == 1.).all()
122
121
 
123
122
  t = Test()
124
- f_grad = bst.augment.grad(t, grad_states=[t.a, t.b])
123
+ f_grad = brainstate.augment.grad(t, grad_states=[t.a, t.b])
125
124
  grads = f_grad()
126
125
  for g in grads: assert (g == 1.).all()
127
126
 
128
127
  t = Test()
129
- f_grad = bst.augment.grad(t, grad_states=t.a)
128
+ f_grad = brainstate.augment.grad(t, grad_states=t.a)
130
129
  grads = f_grad()
131
130
  assert (grads == 1.).all()
132
131
 
133
132
  t = Test()
134
- f_grad = bst.augment.grad(t, grad_states=t.states())
133
+ f_grad = brainstate.augment.grad(t, grad_states=t.states())
135
134
  grads = f_grad()
136
135
  for g in grads.values():
137
136
  assert (g == 1.).all()
138
137
 
139
138
  def test_grad_ob_aux(self):
140
- class Test(bst.nn.Module):
139
+ class Test(brainstate.nn.Module):
141
140
  def __init__(self):
142
141
  super(Test, self).__init__()
143
- self.a = bst.ParamState(jnp.ones(10))
144
- self.b = bst.ParamState(bst.random.randn(10))
145
- self.c = bst.ParamState(bst.random.uniform(size=10))
142
+ self.a = brainstate.ParamState(jnp.ones(10))
143
+ self.b = brainstate.ParamState(brainstate.random.randn(10))
144
+ self.c = brainstate.ParamState(brainstate.random.uniform(size=10))
146
145
 
147
146
  def __call__(self):
148
147
  return jnp.sum(self.a.value + self.b.value + self.c.value), (jnp.sin(100), jnp.exp(0.1))
149
148
 
150
- bst.random.seed(0)
149
+ brainstate.random.seed(0)
151
150
  t = Test()
152
- f_grad = bst.augment.grad(t, grad_states=[t.a, t.b], has_aux=True)
151
+ f_grad = brainstate.augment.grad(t, grad_states=[t.a, t.b], has_aux=True)
153
152
  grads, aux = f_grad()
154
153
  for g in grads: assert (g == 1.).all()
155
154
  assert aux[0] == jnp.sin(100)
156
155
  assert aux[1] == jnp.exp(0.1)
157
156
 
158
157
  t = Test()
159
- f_grad = bst.augment.grad(t, grad_states=t.a, has_aux=True)
158
+ f_grad = brainstate.augment.grad(t, grad_states=t.a, has_aux=True)
160
159
  grads, aux = f_grad()
161
160
  assert (grads == 1.).all()
162
161
  assert aux[0] == jnp.sin(100)
163
162
  assert aux[1] == jnp.exp(0.1)
164
163
 
165
164
  t = Test()
166
- f_grad = bst.augment.grad(t, grad_states=t.states(), has_aux=True)
165
+ f_grad = brainstate.augment.grad(t, grad_states=t.states(), has_aux=True)
167
166
  grads, aux = f_grad()
168
167
  self.assertTrue(len(grads) == len(t.states()))
169
168
 
170
169
  def test_grad_ob_return(self):
171
- class Test(bst.nn.Module):
170
+ class Test(brainstate.nn.Module):
172
171
  def __init__(self):
173
172
  super(Test, self).__init__()
174
- self.a = bst.ParamState(jnp.ones(10))
175
- self.b = bst.ParamState(bst.random.randn(10))
176
- self.c = bst.ParamState(bst.random.uniform(size=10))
173
+ self.a = brainstate.ParamState(jnp.ones(10))
174
+ self.b = brainstate.ParamState(brainstate.random.randn(10))
175
+ self.c = brainstate.ParamState(brainstate.random.uniform(size=10))
177
176
 
178
177
  def __call__(self):
179
178
  return jnp.sum(self.a.value + self.b.value + self.c.value)
180
179
 
181
- bst.random.seed(0)
180
+ brainstate.random.seed(0)
182
181
  t = Test()
183
- f_grad = bst.augment.grad(t, grad_states=[t.a, t.b], return_value=True)
182
+ f_grad = brainstate.augment.grad(t, grad_states=[t.a, t.b], return_value=True)
184
183
  grads, returns = f_grad()
185
184
  for g in grads: assert (g == 1.).all()
186
185
  assert returns == t()
187
186
 
188
187
  t = Test()
189
- f_grad = bst.augment.grad(t, grad_states=t.a, return_value=True)
188
+ f_grad = brainstate.augment.grad(t, grad_states=t.a, return_value=True)
190
189
  grads, returns = f_grad()
191
190
  assert (grads == 1.).all()
192
191
  assert returns == t()
193
192
 
194
193
  def test_grad_ob_aux_return(self):
195
- class Test(bst.nn.Module):
194
+ class Test(brainstate.nn.Module):
196
195
  def __init__(self):
197
196
  super(Test, self).__init__()
198
- self.a = bst.ParamState(jnp.ones(10))
199
- self.b = bst.ParamState(bst.random.randn(10))
200
- self.c = bst.ParamState(bst.random.uniform(size=10))
197
+ self.a = brainstate.ParamState(jnp.ones(10))
198
+ self.b = brainstate.ParamState(brainstate.random.randn(10))
199
+ self.c = brainstate.ParamState(brainstate.random.uniform(size=10))
201
200
 
202
201
  def __call__(self):
203
202
  return jnp.sum(self.a.value + self.b.value + self.c.value), (jnp.sin(100), jnp.exp(0.1))
204
203
 
205
- bst.random.seed(0)
204
+ brainstate.random.seed(0)
206
205
  t = Test()
207
- f_grad = bst.augment.grad(t, grad_states=[t.a, t.b], has_aux=True, return_value=True)
206
+ f_grad = brainstate.augment.grad(t, grad_states=[t.a, t.b], has_aux=True, return_value=True)
208
207
  grads, returns, aux = f_grad()
209
208
  for g in grads: assert (g == 1.).all()
210
209
  assert returns == jnp.sum(t.a.value + t.b.value + t.c.value)
@@ -212,7 +211,7 @@ class TestObjectFuncGrad(unittest.TestCase):
212
211
  assert aux[1] == jnp.exp(0.1)
213
212
 
214
213
  t = Test()
215
- f_grad = bst.augment.grad(t, grad_states=t.a, has_aux=True, return_value=True)
214
+ f_grad = brainstate.augment.grad(t, grad_states=t.a, has_aux=True, return_value=True)
216
215
  grads, returns, aux = f_grad()
217
216
  assert (grads == 1.).all()
218
217
  assert returns == jnp.sum(t.a.value + t.b.value + t.c.value)
@@ -220,101 +219,101 @@ class TestObjectFuncGrad(unittest.TestCase):
220
219
  assert aux[1] == jnp.exp(0.1)
221
220
 
222
221
  def test_grad_ob_argnums(self):
223
- class Test(bst.nn.Module):
222
+ class Test(brainstate.nn.Module):
224
223
  def __init__(self):
225
224
  super(Test, self).__init__()
226
- bst.random.seed()
227
- self.a = bst.ParamState(jnp.ones(10))
228
- self.b = bst.ParamState(bst.random.randn(10))
229
- self.c = bst.ParamState(bst.random.uniform(size=10))
225
+ brainstate.random.seed()
226
+ self.a = brainstate.ParamState(jnp.ones(10))
227
+ self.b = brainstate.ParamState(brainstate.random.randn(10))
228
+ self.c = brainstate.ParamState(brainstate.random.uniform(size=10))
230
229
 
231
230
  def __call__(self, d):
232
231
  return jnp.sum(self.a.value + self.b.value + self.c.value + 2 * d)
233
232
 
234
- bst.random.seed(0)
233
+ brainstate.random.seed(0)
235
234
 
236
235
  t = Test()
237
- f_grad = bst.augment.grad(t, t.states(), argnums=0)
238
- var_grads, arg_grads = f_grad(bst.random.random(10))
236
+ f_grad = brainstate.augment.grad(t, t.states(), argnums=0)
237
+ var_grads, arg_grads = f_grad(brainstate.random.random(10))
239
238
  for g in var_grads.values(): assert (g == 1.).all()
240
239
  assert (arg_grads == 2.).all()
241
240
 
242
241
  t = Test()
243
- f_grad = bst.augment.grad(t, t.states(), argnums=[0])
244
- var_grads, arg_grads = f_grad(bst.random.random(10))
242
+ f_grad = brainstate.augment.grad(t, t.states(), argnums=[0])
243
+ var_grads, arg_grads = f_grad(brainstate.random.random(10))
245
244
  for g in var_grads.values(): assert (g == 1.).all()
246
245
  assert (arg_grads[0] == 2.).all()
247
246
 
248
247
  t = Test()
249
- f_grad = bst.augment.grad(t, argnums=0)
250
- arg_grads = f_grad(bst.random.random(10))
248
+ f_grad = brainstate.augment.grad(t, argnums=0)
249
+ arg_grads = f_grad(brainstate.random.random(10))
251
250
  assert (arg_grads == 2.).all()
252
251
 
253
252
  t = Test()
254
- f_grad = bst.augment.grad(t, argnums=[0])
255
- arg_grads = f_grad(bst.random.random(10))
253
+ f_grad = brainstate.augment.grad(t, argnums=[0])
254
+ arg_grads = f_grad(brainstate.random.random(10))
256
255
  assert (arg_grads[0] == 2.).all()
257
256
 
258
257
  def test_grad_ob_argnums_aux(self):
259
- class Test(bst.nn.Module):
258
+ class Test(brainstate.nn.Module):
260
259
  def __init__(self):
261
260
  super(Test, self).__init__()
262
- self.a = bst.ParamState(jnp.ones(10))
263
- self.b = bst.ParamState(bst.random.randn(10))
264
- self.c = bst.ParamState(bst.random.uniform(size=10))
261
+ self.a = brainstate.ParamState(jnp.ones(10))
262
+ self.b = brainstate.ParamState(brainstate.random.randn(10))
263
+ self.c = brainstate.ParamState(brainstate.random.uniform(size=10))
265
264
 
266
265
  def __call__(self, d):
267
266
  return jnp.sum(self.a.value + self.b.value + self.c.value + 2 * d), (jnp.sin(100), jnp.exp(0.1))
268
267
 
269
- bst.random.seed(0)
268
+ brainstate.random.seed(0)
270
269
 
271
270
  t = Test()
272
- f_grad = bst.augment.grad(t, grad_states=t.states(), argnums=0, has_aux=True)
273
- (var_grads, arg_grads), aux = f_grad(bst.random.random(10))
271
+ f_grad = brainstate.augment.grad(t, grad_states=t.states(), argnums=0, has_aux=True)
272
+ (var_grads, arg_grads), aux = f_grad(brainstate.random.random(10))
274
273
  for g in var_grads.values(): assert (g == 1.).all()
275
274
  assert (arg_grads == 2.).all()
276
275
  assert aux[0] == jnp.sin(100)
277
276
  assert aux[1] == jnp.exp(0.1)
278
277
 
279
278
  t = Test()
280
- f_grad = bst.augment.grad(t, grad_states=t.states(), argnums=[0], has_aux=True)
281
- (var_grads, arg_grads), aux = f_grad(bst.random.random(10))
279
+ f_grad = brainstate.augment.grad(t, grad_states=t.states(), argnums=[0], has_aux=True)
280
+ (var_grads, arg_grads), aux = f_grad(brainstate.random.random(10))
282
281
  for g in var_grads.values(): assert (g == 1.).all()
283
282
  assert (arg_grads[0] == 2.).all()
284
283
  assert aux[0] == jnp.sin(100)
285
284
  assert aux[1] == jnp.exp(0.1)
286
285
 
287
286
  t = Test()
288
- f_grad = bst.augment.grad(t, argnums=0, has_aux=True)
289
- arg_grads, aux = f_grad(bst.random.random(10))
287
+ f_grad = brainstate.augment.grad(t, argnums=0, has_aux=True)
288
+ arg_grads, aux = f_grad(brainstate.random.random(10))
290
289
  assert (arg_grads == 2.).all()
291
290
  assert aux[0] == jnp.sin(100)
292
291
  assert aux[1] == jnp.exp(0.1)
293
292
 
294
293
  t = Test()
295
- f_grad = bst.augment.grad(t, argnums=[0], has_aux=True)
296
- arg_grads, aux = f_grad(bst.random.random(10))
294
+ f_grad = brainstate.augment.grad(t, argnums=[0], has_aux=True)
295
+ arg_grads, aux = f_grad(brainstate.random.random(10))
297
296
  assert (arg_grads[0] == 2.).all()
298
297
  assert aux[0] == jnp.sin(100)
299
298
  assert aux[1] == jnp.exp(0.1)
300
299
 
301
300
  def test_grad_ob_argnums_return(self):
302
- class Test(bst.nn.Module):
301
+ class Test(brainstate.nn.Module):
303
302
  def __init__(self):
304
303
  super(Test, self).__init__()
305
304
 
306
- self.a = bst.ParamState(jnp.ones(10))
307
- self.b = bst.ParamState(bst.random.randn(10))
308
- self.c = bst.ParamState(bst.random.uniform(size=10))
305
+ self.a = brainstate.ParamState(jnp.ones(10))
306
+ self.b = brainstate.ParamState(brainstate.random.randn(10))
307
+ self.c = brainstate.ParamState(brainstate.random.uniform(size=10))
309
308
 
310
309
  def __call__(self, d):
311
310
  return jnp.sum(self.a.value + self.b.value + self.c.value + 2 * d)
312
311
 
313
- bst.random.seed(0)
312
+ brainstate.random.seed(0)
314
313
 
315
314
  t = Test()
316
- f_grad = bst.augment.grad(t, t.states(), argnums=0, return_value=True)
317
- d = bst.random.random(10)
315
+ f_grad = brainstate.augment.grad(t, t.states(), argnums=0, return_value=True)
316
+ d = brainstate.random.random(10)
318
317
  (var_grads, arg_grads), loss = f_grad(d)
319
318
  for g in var_grads.values():
320
319
  assert (g == 1.).all()
@@ -322,8 +321,8 @@ class TestObjectFuncGrad(unittest.TestCase):
322
321
  assert loss == t(d)
323
322
 
324
323
  t = Test()
325
- f_grad = bst.augment.grad(t, t.states(), argnums=[0], return_value=True)
326
- d = bst.random.random(10)
324
+ f_grad = brainstate.augment.grad(t, t.states(), argnums=[0], return_value=True)
325
+ d = brainstate.random.random(10)
327
326
  (var_grads, arg_grads), loss = f_grad(d)
328
327
  for g in var_grads.values():
329
328
  assert (g == 1.).all()
@@ -331,35 +330,35 @@ class TestObjectFuncGrad(unittest.TestCase):
331
330
  assert loss == t(d)
332
331
 
333
332
  t = Test()
334
- f_grad = bst.augment.grad(t, argnums=0, return_value=True)
335
- d = bst.random.random(10)
333
+ f_grad = brainstate.augment.grad(t, argnums=0, return_value=True)
334
+ d = brainstate.random.random(10)
336
335
  arg_grads, loss = f_grad(d)
337
336
  assert (arg_grads == 2.).all()
338
337
  assert loss == t(d)
339
338
 
340
339
  t = Test()
341
- f_grad = bst.augment.grad(t, argnums=[0], return_value=True)
342
- d = bst.random.random(10)
340
+ f_grad = brainstate.augment.grad(t, argnums=[0], return_value=True)
341
+ d = brainstate.random.random(10)
343
342
  arg_grads, loss = f_grad(d)
344
343
  assert (arg_grads[0] == 2.).all()
345
344
  assert loss == t(d)
346
345
 
347
346
  def test_grad_ob_argnums_aux_return(self):
348
- class Test(bst.nn.Module):
347
+ class Test(brainstate.nn.Module):
349
348
  def __init__(self):
350
349
  super(Test, self).__init__()
351
- self.a = bst.ParamState(jnp.ones(10))
352
- self.b = bst.ParamState(bst.random.randn(10))
353
- self.c = bst.ParamState(bst.random.uniform(size=10))
350
+ self.a = brainstate.ParamState(jnp.ones(10))
351
+ self.b = brainstate.ParamState(brainstate.random.randn(10))
352
+ self.c = brainstate.ParamState(brainstate.random.uniform(size=10))
354
353
 
355
354
  def __call__(self, d):
356
355
  return jnp.sum(self.a.value + self.b.value + self.c.value + 2 * d), (jnp.sin(100), jnp.exp(0.1))
357
356
 
358
- bst.random.seed(0)
357
+ brainstate.random.seed(0)
359
358
 
360
359
  t = Test()
361
- f_grad = bst.augment.grad(t, grad_states=t.states(), argnums=0, has_aux=True, return_value=True)
362
- d = bst.random.random(10)
360
+ f_grad = brainstate.augment.grad(t, grad_states=t.states(), argnums=0, has_aux=True, return_value=True)
361
+ d = brainstate.random.random(10)
363
362
  (var_grads, arg_grads), loss, aux = f_grad(d)
364
363
  for g in var_grads.values(): assert (g == 1.).all()
365
364
  assert (arg_grads == 2.).all()
@@ -368,8 +367,8 @@ class TestObjectFuncGrad(unittest.TestCase):
368
367
  assert loss == t(d)[0]
369
368
 
370
369
  t = Test()
371
- f_grad = bst.augment.grad(t, grad_states=t.states(), argnums=[0], has_aux=True, return_value=True)
372
- d = bst.random.random(10)
370
+ f_grad = brainstate.augment.grad(t, grad_states=t.states(), argnums=[0], has_aux=True, return_value=True)
371
+ d = brainstate.random.random(10)
373
372
  (var_grads, arg_grads), loss, aux = f_grad(d)
374
373
  for g in var_grads.values(): assert (g == 1.).all()
375
374
  assert (arg_grads[0] == 2.).all()
@@ -378,8 +377,8 @@ class TestObjectFuncGrad(unittest.TestCase):
378
377
  assert loss == t(d)[0]
379
378
 
380
379
  t = Test()
381
- f_grad = bst.augment.grad(t, argnums=0, has_aux=True, return_value=True)
382
- d = bst.random.random(10)
380
+ f_grad = brainstate.augment.grad(t, argnums=0, has_aux=True, return_value=True)
381
+ d = brainstate.random.random(10)
383
382
  arg_grads, loss, aux = f_grad(d)
384
383
  assert (arg_grads == 2.).all()
385
384
  assert aux[0] == jnp.sin(100)
@@ -387,8 +386,8 @@ class TestObjectFuncGrad(unittest.TestCase):
387
386
  assert loss == t(d)[0]
388
387
 
389
388
  t = Test()
390
- f_grad = bst.augment.grad(t, argnums=[0], has_aux=True, return_value=True)
391
- d = bst.random.random(10)
389
+ f_grad = brainstate.augment.grad(t, argnums=[0], has_aux=True, return_value=True)
390
+ d = brainstate.random.random(10)
392
391
  arg_grads, loss, aux = f_grad(d)
393
392
  assert (arg_grads[0] == 2.).all()
394
393
  assert aux[0] == jnp.sin(100)
@@ -436,11 +435,11 @@ class TestPureFuncJacobian(unittest.TestCase):
436
435
  r = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], 4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])])
437
436
  return r
438
437
 
439
- br = bst.augment.jacrev(f1)(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
438
+ br = brainstate.augment.jacrev(f1)(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
440
439
  jr = jax.jacrev(f1)(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
441
440
  assert (br == jr).all()
442
441
 
443
- br = bst.augment.jacrev(f1, argnums=(0, 1))(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
442
+ br = brainstate.augment.jacrev(f1, argnums=(0, 1))(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
444
443
  jr = jax.jacrev(f1, argnums=(0, 1))(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
445
444
  assert (br[0] == jr[0]).all()
446
445
  assert (br[1] == jr[1]).all()
@@ -456,12 +455,12 @@ class TestPureFuncJacobian(unittest.TestCase):
456
455
  jr = jax.jacrev(f2)(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
457
456
  pprint(jr)
458
457
 
459
- br = bst.augment.jacrev(f2)(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
458
+ br = brainstate.augment.jacrev(f2)(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
460
459
  pprint(br)
461
460
  assert jnp.array_equal(br[0], jr[0])
462
461
  assert jnp.array_equal(br[1], jr[1])
463
462
 
464
- br = bst.augment.jacrev(f2)(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
463
+ br = brainstate.augment.jacrev(f2)(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
465
464
  pprint(br)
466
465
  assert jnp.array_equal(br[0], jr[0])
467
466
  assert jnp.array_equal(br[1], jr[1])
@@ -471,12 +470,12 @@ class TestPureFuncJacobian(unittest.TestCase):
471
470
  r2 = jnp.asarray([4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])])
472
471
  return r1, r2
473
472
 
474
- br = bst.augment.jacrev(f2)(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
473
+ br = brainstate.augment.jacrev(f2)(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
475
474
  pprint(br)
476
475
  assert jnp.array_equal(br[0], jr[0])
477
476
  assert jnp.array_equal(br[1], jr[1])
478
477
 
479
- br = bst.augment.jacrev(f2)(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
478
+ br = brainstate.augment.jacrev(f2)(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
480
479
  pprint(br)
481
480
  assert jnp.array_equal(br[0], jr[0])
482
481
  assert jnp.array_equal(br[1], jr[1])
@@ -492,14 +491,14 @@ class TestPureFuncJacobian(unittest.TestCase):
492
491
  jr = jax.jacrev(f3, argnums=(0, 1))(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
493
492
  pprint(jr)
494
493
 
495
- br = bst.augment.jacrev(f3, argnums=(0, 1))(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
494
+ br = brainstate.augment.jacrev(f3, argnums=(0, 1))(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
496
495
  pprint(br)
497
496
  assert jnp.array_equal(br[0][0], jr[0][0])
498
497
  assert jnp.array_equal(br[0][1], jr[0][1])
499
498
  assert jnp.array_equal(br[1][0], jr[1][0])
500
499
  assert jnp.array_equal(br[1][1], jr[1][1])
501
500
 
502
- br = bst.augment.jacrev(f3, argnums=(0, 1))(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
501
+ br = brainstate.augment.jacrev(f3, argnums=(0, 1))(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
503
502
  pprint(br)
504
503
  assert jnp.array_equal(br[0][0], jr[0][0])
505
504
  assert jnp.array_equal(br[0][1], jr[0][1])
@@ -511,14 +510,14 @@ class TestPureFuncJacobian(unittest.TestCase):
511
510
  r2 = jnp.asarray([4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])])
512
511
  return r1, r2
513
512
 
514
- br = bst.augment.jacrev(f3, argnums=(0, 1))(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
513
+ br = brainstate.augment.jacrev(f3, argnums=(0, 1))(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
515
514
  pprint(br)
516
515
  assert jnp.array_equal(br[0][0], jr[0][0])
517
516
  assert jnp.array_equal(br[0][1], jr[0][1])
518
517
  assert jnp.array_equal(br[1][0], jr[1][0])
519
518
  assert jnp.array_equal(br[1][1], jr[1][1])
520
519
 
521
- br = bst.augment.jacrev(f3, argnums=(0, 1))(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
520
+ br = brainstate.augment.jacrev(f3, argnums=(0, 1))(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
522
521
  pprint(br)
523
522
  assert jnp.array_equal(br[0][0], jr[0][0])
524
523
  assert jnp.array_equal(br[0][1], jr[0][1])
@@ -537,19 +536,19 @@ class TestPureFuncJacobian(unittest.TestCase):
537
536
  f2 = lambda *args: f1(*args)[0]
538
537
  jr = jax.jacrev(f2)(x, y) # jax jacobian
539
538
  pprint(jr)
540
- grads, aux = bst.augment.jacrev(f1, has_aux=True)(x, y)
539
+ grads, aux = brainstate.augment.jacrev(f1, has_aux=True)(x, y)
541
540
  assert (grads == jr).all()
542
541
  assert aux == (4 * x[1] ** 2 - 2 * x[2])
543
542
 
544
543
  jr = jax.jacrev(f2, argnums=(0, 1))(x, y) # jax jacobian
545
544
  pprint(jr)
546
- grads, aux = bst.augment.jacrev(f1, argnums=(0, 1), has_aux=True)(x, y)
545
+ grads, aux = brainstate.augment.jacrev(f1, argnums=(0, 1), has_aux=True)(x, y)
547
546
  assert (grads[0] == jr[0]).all()
548
547
  assert (grads[1] == jr[1]).all()
549
548
  assert aux == (4 * x[1] ** 2 - 2 * x[2])
550
549
 
551
550
  def test_jacrev_return_aux1(self):
552
- with bst.environ.context(precision=64):
551
+ with brainstate.environ.context(precision=64):
553
552
  def f1(x, y):
554
553
  a = 4 * x[1] ** 2 - 2 * x[2]
555
554
  r = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], a, x[2] * jnp.sin(x[0])])
@@ -564,12 +563,12 @@ class TestPureFuncJacobian(unittest.TestCase):
564
563
  _g2 = jax.jacrev(f2, argnums=(0, 1))(_x, _y) # jax jacobian
565
564
  pprint(_g2)
566
565
 
567
- grads, vec, aux = bst.augment.jacrev(f1, return_value=True, has_aux=True)(_x, _y)
566
+ grads, vec, aux = brainstate.augment.jacrev(f1, return_value=True, has_aux=True)(_x, _y)
568
567
  assert (grads == _g1).all()
569
568
  assert aux == _a
570
569
  assert (vec == _r).all()
571
570
 
572
- grads, vec, aux = bst.augment.jacrev(f1, return_value=True, argnums=(0, 1), has_aux=True)(_x, _y)
571
+ grads, vec, aux = brainstate.augment.jacrev(f1, return_value=True, argnums=(0, 1), has_aux=True)(_x, _y)
573
572
  assert (grads[0] == _g2[0]).all()
574
573
  assert (grads[1] == _g2[1]).all()
575
574
  assert aux == _a
@@ -585,11 +584,11 @@ class TestClassFuncJacobian(unittest.TestCase):
585
584
  _x = jnp.array([1., 2., 3.])
586
585
  _y = jnp.array([10., 5.])
587
586
 
588
- class Test(bst.nn.Module):
587
+ class Test(brainstate.nn.Module):
589
588
  def __init__(self):
590
589
  super(Test, self).__init__()
591
- self.x = bst.State(jnp.array([1., 2., 3.]))
592
- self.y = bst.State(jnp.array([10., 5.]))
590
+ self.x = brainstate.State(jnp.array([1., 2., 3.]))
591
+ self.y = brainstate.State(jnp.array([10., 5.]))
593
592
 
594
593
  def __call__(self, ):
595
594
  a = self.x.value[0] * self.y.value[0]
@@ -601,12 +600,12 @@ class TestClassFuncJacobian(unittest.TestCase):
601
600
 
602
601
  _jr = jax.jacrev(f1)(_x, _y)
603
602
  t = Test()
604
- br = bst.augment.jacrev(t, grad_states=t.x)()
603
+ br = brainstate.augment.jacrev(t, grad_states=t.x)()
605
604
  self.assertTrue((br == _jr).all())
606
605
 
607
606
  _jr = jax.jacrev(f1, argnums=(0, 1))(_x, _y)
608
607
  t = Test()
609
- br = bst.augment.jacrev(t, grad_states=[t.x, t.y])()
608
+ br = brainstate.augment.jacrev(t, grad_states=[t.x, t.y])()
610
609
  self.assertTrue((br[0] == _jr[0]).all())
611
610
  self.assertTrue((br[1] == _jr[1]).all())
612
611
 
@@ -1202,7 +1201,7 @@ class TestUnitAwareGrad(unittest.TestCase):
1202
1201
  return u.math.sum(x ** 2)
1203
1202
 
1204
1203
  x = jnp.array([1., 2., 3.]) * u.ms
1205
- g = bst.augment.grad(f, unit_aware=True)(x)
1204
+ g = brainstate.augment.grad(f, unit_aware=True)(x)
1206
1205
  self.assertTrue(u.math.allclose(g, 2 * x))
1207
1206
 
1208
1207
  def test_vector_grad1(self):
@@ -1210,7 +1209,7 @@ class TestUnitAwareGrad(unittest.TestCase):
1210
1209
  return x ** 3
1211
1210
 
1212
1211
  x = jnp.array([1., 2., 3.]) * u.ms
1213
- g = bst.augment.vector_grad(f, unit_aware=True)(x)
1212
+ g = brainstate.augment.vector_grad(f, unit_aware=True)(x)
1214
1213
  self.assertTrue(u.math.allclose(g, 3 * x ** 2))
1215
1214
 
1216
1215
  def test_jacrev1(self):
@@ -1222,7 +1221,7 @@ class TestUnitAwareGrad(unittest.TestCase):
1222
1221
  _x = jnp.array([1., 2., 3.]) * u.ms
1223
1222
  _y = jnp.array([10., 5.]) * u.ms
1224
1223
 
1225
- g = bst.augment.jacrev(f, unit_aware=True, argnums=(0, 1))(_x, _y)
1224
+ g = brainstate.augment.jacrev(f, unit_aware=True, argnums=(0, 1))(_x, _y)
1226
1225
  self.assertTrue(
1227
1226
  u.math.allclose(
1228
1227
  g[0],
@@ -1254,7 +1253,7 @@ class TestUnitAwareGrad(unittest.TestCase):
1254
1253
  _x = jnp.array([1., 2., 3.]) * u.ms
1255
1254
  _y = jnp.array([10., 5.]) * u.ms
1256
1255
 
1257
- g = bst.augment.jacfwd(f, unit_aware=True, argnums=(0, 1))(_x, _y)
1256
+ g = brainstate.augment.jacfwd(f, unit_aware=True, argnums=(0, 1))(_x, _y)
1258
1257
  self.assertTrue(
1259
1258
  u.math.allclose(
1260
1259
  g[0],
@@ -1283,7 +1282,7 @@ class TestUnitAwareGrad(unittest.TestCase):
1283
1282
  def scalar_function(x):
1284
1283
  return x ** 3 + 3 * x * unit * unit + 2 * unit * unit * unit
1285
1284
 
1286
- hess = bst.augment.hessian(scalar_function, unit_aware=True)
1285
+ hess = brainstate.augment.hessian(scalar_function, unit_aware=True)
1287
1286
  x = jnp.array(1.0) * unit
1288
1287
  res = hess(x)
1289
1288
  expected_hessian = jnp.array([[6.0]]) * unit