brainstate 0.1.1__py2.py3-none-any.whl → 0.1.3__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. brainstate/__init__.py +1 -1
  2. brainstate/_compatible_import.py +12 -9
  3. brainstate/_state.py +1 -1
  4. brainstate/augment/_autograd_test.py +132 -133
  5. brainstate/augment/_eval_shape_test.py +7 -9
  6. brainstate/augment/_mapping_test.py +75 -76
  7. brainstate/compile/_ad_checkpoint_test.py +6 -8
  8. brainstate/compile/_conditions_test.py +35 -36
  9. brainstate/compile/_error_if_test.py +10 -13
  10. brainstate/compile/_loop_collect_return_test.py +7 -9
  11. brainstate/compile/_loop_no_collection_test.py +7 -8
  12. brainstate/compile/_make_jaxpr.py +29 -14
  13. brainstate/compile/_make_jaxpr_test.py +20 -20
  14. brainstate/functional/_activations_test.py +61 -61
  15. brainstate/graph/_graph_node_test.py +16 -18
  16. brainstate/graph/_graph_operation_test.py +154 -156
  17. brainstate/init/_random_inits_test.py +20 -21
  18. brainstate/init/_regular_inits_test.py +4 -5
  19. brainstate/mixin.py +1 -14
  20. brainstate/nn/__init__.py +81 -17
  21. brainstate/nn/_collective_ops_test.py +8 -8
  22. brainstate/nn/{_interaction/_conv.py → _conv.py} +1 -1
  23. brainstate/nn/{_interaction/_conv_test.py → _conv_test.py} +31 -33
  24. brainstate/nn/{_dynamics/_state_delay.py → _delay.py} +6 -2
  25. brainstate/nn/{_elementwise/_dropout.py → _dropout.py} +1 -1
  26. brainstate/nn/{_elementwise/_dropout_test.py → _dropout_test.py} +9 -9
  27. brainstate/nn/{_dynamics/_dynamics_base.py → _dynamics.py} +139 -18
  28. brainstate/nn/{_dynamics/_dynamics_base_test.py → _dynamics_test.py} +14 -15
  29. brainstate/nn/{_elementwise/_elementwise.py → _elementwise.py} +1 -1
  30. brainstate/nn/_elementwise_test.py +169 -0
  31. brainstate/nn/{_interaction/_embedding.py → _embedding.py} +1 -1
  32. brainstate/nn/_exp_euler_test.py +5 -6
  33. brainstate/nn/{_event/_fixedprob_mv.py → _fixedprob_mv.py} +1 -1
  34. brainstate/nn/{_event/_fixedprob_mv_test.py → _fixedprob_mv_test.py} +0 -1
  35. brainstate/nn/{_dyn_impl/_inputs.py → _inputs.py} +4 -5
  36. brainstate/nn/{_interaction/_linear.py → _linear.py} +2 -5
  37. brainstate/nn/{_event/_linear_mv.py → _linear_mv.py} +1 -1
  38. brainstate/nn/{_event/_linear_mv_test.py → _linear_mv_test.py} +0 -1
  39. brainstate/nn/{_interaction/_linear_test.py → _linear_test.py} +15 -17
  40. brainstate/nn/{_event/__init__.py → _ltp.py} +7 -5
  41. brainstate/nn/_module_test.py +34 -37
  42. brainstate/nn/{_dyn_impl/_dynamics_neuron.py → _neuron.py} +2 -2
  43. brainstate/nn/{_dyn_impl/_dynamics_neuron_test.py → _neuron_test.py} +18 -19
  44. brainstate/nn/{_interaction/_normalizations.py → _normalizations.py} +1 -1
  45. brainstate/nn/{_interaction/_normalizations_test.py → _normalizations_test.py} +10 -12
  46. brainstate/nn/{_interaction/_poolings.py → _poolings.py} +1 -1
  47. brainstate/nn/{_interaction/_poolings_test.py → _poolings_test.py} +20 -22
  48. brainstate/nn/{_dynamics/_projection_base.py → _projection.py} +35 -3
  49. brainstate/nn/{_dyn_impl/_rate_rnns.py → _rate_rnns.py} +2 -2
  50. brainstate/nn/{_dyn_impl/_rate_rnns_test.py → _rate_rnns_test.py} +6 -7
  51. brainstate/nn/{_dyn_impl/_readout.py → _readout.py} +3 -3
  52. brainstate/nn/{_dyn_impl/_readout_test.py → _readout_test.py} +9 -10
  53. brainstate/nn/_stp.py +236 -0
  54. brainstate/nn/{_dyn_impl/_dynamics_synapse.py → _synapse.py} +17 -206
  55. brainstate/nn/{_dyn_impl/_dynamics_synapse_test.py → _synapse_test.py} +9 -10
  56. brainstate/nn/_synaptic_projection.py +133 -0
  57. brainstate/nn/{_dynamics/_synouts.py → _synouts.py} +4 -1
  58. brainstate/nn/{_dynamics/_synouts_test.py → _synouts_test.py} +4 -5
  59. brainstate/optim/_lr_scheduler_test.py +3 -3
  60. brainstate/optim/_optax_optimizer_test.py +8 -9
  61. brainstate/random/_rand_funs_test.py +183 -184
  62. brainstate/random/_rand_seed_test.py +10 -12
  63. {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/METADATA +1 -1
  64. brainstate-0.1.3.dist-info/RECORD +131 -0
  65. brainstate/nn/_dyn_impl/__init__.py +0 -42
  66. brainstate/nn/_dynamics/__init__.py +0 -37
  67. brainstate/nn/_elementwise/__init__.py +0 -22
  68. brainstate/nn/_elementwise/_elementwise_test.py +0 -171
  69. brainstate/nn/_interaction/__init__.py +0 -41
  70. brainstate-0.1.1.dist-info/RECORD +0 -133
  71. {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/LICENSE +0 -0
  72. {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/WHEEL +0 -0
  73. {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/top_level.txt +0 -0
@@ -13,8 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
-
18
16
  import unittest
19
17
  from collections.abc import Callable
20
18
  from threading import Thread
@@ -23,67 +21,67 @@ import jax
23
21
  import jax.numpy as jnp
24
22
  from absl.testing import absltest, parameterized
25
23
 
26
- import brainstate as bst
24
+ import brainstate
27
25
 
28
26
 
29
27
  class TestIter(unittest.TestCase):
30
28
  def test1(self):
31
- class Model(bst.nn.Module):
29
+ class Model(brainstate.nn.Module):
32
30
  def __init__(self):
33
31
  super().__init__()
34
- self.a = bst.nn.Linear(1, 2)
35
- self.b = bst.nn.Linear(2, 3)
36
- self.c = [bst.nn.Linear(3, 4), bst.nn.Linear(4, 5)]
37
- self.d = {'x': bst.nn.Linear(5, 6), 'y': bst.nn.Linear(6, 7)}
38
- self.b.a = bst.nn.LIF(2)
32
+ self.a = brainstate.nn.Linear(1, 2)
33
+ self.b = brainstate.nn.Linear(2, 3)
34
+ self.c = [brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5)]
35
+ self.d = {'x': brainstate.nn.Linear(5, 6), 'y': brainstate.nn.Linear(6, 7)}
36
+ self.b.a = brainstate.nn.LIF(2)
39
37
 
40
- for path, node in bst.graph.iter_leaf(Model()):
38
+ for path, node in brainstate.graph.iter_leaf(Model()):
41
39
  print(path, node)
42
- for path, node in bst.graph.iter_node(Model()):
40
+ for path, node in brainstate.graph.iter_node(Model()):
43
41
  print(path, node)
44
- for path, node in bst.graph.iter_node(Model(), allowed_hierarchy=(1, 1)):
42
+ for path, node in brainstate.graph.iter_node(Model(), allowed_hierarchy=(1, 1)):
45
43
  print(path, node)
46
- for path, node in bst.graph.iter_node(Model(), allowed_hierarchy=(2, 2)):
44
+ for path, node in brainstate.graph.iter_node(Model(), allowed_hierarchy=(2, 2)):
47
45
  print(path, node)
48
46
 
49
47
  def test_iter_leaf_v1(self):
50
- class Linear(bst.nn.Module):
48
+ class Linear(brainstate.nn.Module):
51
49
  def __init__(self, din, dout):
52
50
  super().__init__()
53
- self.weight = bst.ParamState(bst.random.randn(din, dout))
54
- self.bias = bst.ParamState(bst.random.randn(dout))
51
+ self.weight = brainstate.ParamState(brainstate.random.randn(din, dout))
52
+ self.bias = brainstate.ParamState(brainstate.random.randn(dout))
55
53
  self.a = 1
56
54
 
57
55
  module = Linear(3, 4)
58
56
  graph = [module, module]
59
57
 
60
58
  num = 0
61
- for path, value in bst.graph.iter_leaf(graph):
59
+ for path, value in brainstate.graph.iter_leaf(graph):
62
60
  print(path, type(value).__name__)
63
61
  num += 1
64
62
 
65
63
  assert num == 3
66
64
 
67
65
  def test_iter_node_v1(self):
68
- class Model(bst.nn.Module):
66
+ class Model(brainstate.nn.Module):
69
67
  def __init__(self):
70
68
  super().__init__()
71
- self.a = bst.nn.Linear(1, 2)
72
- self.b = bst.nn.Linear(2, 3)
73
- self.c = [bst.nn.Linear(3, 4), bst.nn.Linear(4, 5)]
74
- self.d = {'x': bst.nn.Linear(5, 6), 'y': bst.nn.Linear(6, 7)}
75
- self.b.a = bst.nn.LIF(2)
69
+ self.a = brainstate.nn.Linear(1, 2)
70
+ self.b = brainstate.nn.Linear(2, 3)
71
+ self.c = [brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5)]
72
+ self.d = {'x': brainstate.nn.Linear(5, 6), 'y': brainstate.nn.Linear(6, 7)}
73
+ self.b.a = brainstate.nn.LIF(2)
76
74
 
77
75
  model = Model()
78
76
 
79
77
  num = 0
80
- for path, node in bst.graph.iter_node([model, model]):
78
+ for path, node in brainstate.graph.iter_node([model, model]):
81
79
  print(path, node.__class__.__name__)
82
80
  num += 1
83
81
  assert num == 8
84
82
 
85
83
 
86
- class List(bst.nn.Module):
84
+ class List(brainstate.nn.Module):
87
85
  def __init__(self, items):
88
86
  super().__init__()
89
87
  self.items = list(items)
@@ -95,7 +93,7 @@ class List(bst.nn.Module):
95
93
  self.items[idx] = value
96
94
 
97
95
 
98
- class Dict(bst.nn.Module):
96
+ class Dict(brainstate.nn.Module):
99
97
  def __init__(self, *args, **kwargs):
100
98
  super().__init__()
101
99
  self.items = dict(*args, **kwargs)
@@ -107,12 +105,12 @@ class Dict(bst.nn.Module):
107
105
  self.items[key] = value
108
106
 
109
107
 
110
- class StatefulLinear(bst.nn.Module):
108
+ class StatefulLinear(brainstate.nn.Module):
111
109
  def __init__(self, din, dout):
112
110
  super().__init__()
113
- self.w = bst.ParamState(bst.random.rand(din, dout))
114
- self.b = bst.ParamState(jnp.zeros((dout,)))
115
- self.count = bst.State(jnp.array(0, dtype=jnp.uint32))
111
+ self.w = brainstate.ParamState(brainstate.random.rand(din, dout))
112
+ self.b = brainstate.ParamState(jnp.zeros((dout,)))
113
+ self.count = brainstate.State(jnp.array(0, dtype=jnp.uint32))
116
114
 
117
115
  def increment(self):
118
116
  self.count.value += 1
@@ -124,44 +122,44 @@ class StatefulLinear(bst.nn.Module):
124
122
 
125
123
  class TestGraphUtils(absltest.TestCase):
126
124
  def test_flatten_treey_state(self):
127
- a = {'a': 1, 'b': bst.ParamState(2)}
128
- g = [a, 3, a, bst.ParamState(4)]
125
+ a = {'a': 1, 'b': brainstate.ParamState(2)}
126
+ g = [a, 3, a, brainstate.ParamState(4)]
129
127
 
130
- refmap = bst.graph.RefMap()
131
- graphdef, states = bst.graph.flatten(g, ref_index=refmap, treefy_state=True)
128
+ refmap = brainstate.graph.RefMap()
129
+ graphdef, states = brainstate.graph.flatten(g, ref_index=refmap, treefy_state=True)
132
130
 
133
131
  states[0]['b'].value = 2
134
132
  states[3].value = 4
135
133
 
136
- assert isinstance(states[0]['b'], bst.TreefyState)
137
- assert isinstance(states[3], bst.TreefyState)
138
- assert isinstance(states, bst.util.NestedDict)
134
+ assert isinstance(states[0]['b'], brainstate.TreefyState)
135
+ assert isinstance(states[3], brainstate.TreefyState)
136
+ assert isinstance(states, brainstate.util.NestedDict)
139
137
  assert len(refmap) == 2
140
138
  assert a['b'] in refmap
141
139
  assert g[3] in refmap
142
140
 
143
141
  def test_flatten(self):
144
- a = {'a': 1, 'b': bst.ParamState(2)}
145
- g = [a, 3, a, bst.ParamState(4)]
142
+ a = {'a': 1, 'b': brainstate.ParamState(2)}
143
+ g = [a, 3, a, brainstate.ParamState(4)]
146
144
 
147
- refmap = bst.graph.RefMap()
148
- graphdef, states = bst.graph.flatten(g, ref_index=refmap, treefy_state=False)
145
+ refmap = brainstate.graph.RefMap()
146
+ graphdef, states = brainstate.graph.flatten(g, ref_index=refmap, treefy_state=False)
149
147
 
150
148
  states[0]['b'].value = 2
151
149
  states[3].value = 4
152
150
 
153
- assert isinstance(states[0]['b'], bst.State)
154
- assert isinstance(states[3], bst.State)
151
+ assert isinstance(states[0]['b'], brainstate.State)
152
+ assert isinstance(states[3], brainstate.State)
155
153
  assert len(refmap) == 2
156
154
  assert a['b'] in refmap
157
155
  assert g[3] in refmap
158
156
 
159
157
  def test_unflatten_treey_state(self):
160
- a = bst.graph.Dict(a=1, b=bst.ParamState(2))
161
- g1 = bst.graph.List([a, 3, a, bst.ParamState(4)])
158
+ a = brainstate.graph.Dict(a=1, b=brainstate.ParamState(2))
159
+ g1 = brainstate.graph.List([a, 3, a, brainstate.ParamState(4)])
162
160
 
163
- graphdef, references = bst.graph.flatten(g1, treefy_state=True)
164
- g = bst.graph.unflatten(graphdef, references)
161
+ graphdef, references = brainstate.graph.flatten(g1, treefy_state=True)
162
+ g = brainstate.graph.unflatten(graphdef, references)
165
163
 
166
164
  print(graphdef)
167
165
  print(references)
@@ -170,11 +168,11 @@ class TestGraphUtils(absltest.TestCase):
170
168
  assert g1[0]['b'] is not g[0]['b']
171
169
 
172
170
  def test_unflatten(self):
173
- a = bst.graph.Dict(a=1, b=bst.ParamState(2))
174
- g1 = bst.graph.List([a, 3, a, bst.ParamState(4)])
171
+ a = brainstate.graph.Dict(a=1, b=brainstate.ParamState(2))
172
+ g1 = brainstate.graph.List([a, 3, a, brainstate.ParamState(4)])
175
173
 
176
- graphdef, references = bst.graph.flatten(g1, treefy_state=False)
177
- g = bst.graph.unflatten(graphdef, references)
174
+ graphdef, references = brainstate.graph.flatten(g1, treefy_state=False)
175
+ g = brainstate.graph.unflatten(graphdef, references)
178
176
 
179
177
  print(graphdef)
180
178
  print(references)
@@ -183,29 +181,29 @@ class TestGraphUtils(absltest.TestCase):
183
181
  assert g1[0]['b'] is g[0]['b']
184
182
 
185
183
  def test_unflatten_pytree(self):
186
- a = {'a': 1, 'b': bst.ParamState(2)}
187
- g = [a, 3, a, bst.ParamState(4)]
184
+ a = {'a': 1, 'b': brainstate.ParamState(2)}
185
+ g = [a, 3, a, brainstate.ParamState(4)]
188
186
 
189
- graphdef, references = bst.graph.treefy_split(g)
190
- g = bst.graph.treefy_merge(graphdef, references)
187
+ graphdef, references = brainstate.graph.treefy_split(g)
188
+ g = brainstate.graph.treefy_merge(graphdef, references)
191
189
 
192
190
  assert g[0] is not g[2]
193
191
 
194
192
  def test_unflatten_empty(self):
195
- a = Dict({'a': 1, 'b': bst.ParamState(2)})
196
- g = List([a, 3, a, bst.ParamState(4)])
193
+ a = Dict({'a': 1, 'b': brainstate.ParamState(2)})
194
+ g = List([a, 3, a, brainstate.ParamState(4)])
197
195
 
198
- graphdef, references = bst.graph.treefy_split(g)
196
+ graphdef, references = brainstate.graph.treefy_split(g)
199
197
 
200
198
  with self.assertRaisesRegex(ValueError, 'Expected key'):
201
- bst.graph.unflatten(graphdef, bst.util.NestedDict({}))
199
+ brainstate.graph.unflatten(graphdef, brainstate.util.NestedDict({}))
202
200
 
203
201
  def test_module_list(self):
204
202
  ls = [
205
- bst.nn.Linear(2, 2),
206
- bst.nn.BatchNorm1d([10, 2]),
203
+ brainstate.nn.Linear(2, 2),
204
+ brainstate.nn.BatchNorm1d([10, 2]),
207
205
  ]
208
- graphdef, statetree = bst.graph.treefy_split(ls)
206
+ graphdef, statetree = brainstate.graph.treefy_split(ls)
209
207
 
210
208
  assert statetree[0]['weight'].value['weight'].shape == (2, 2)
211
209
  assert statetree[0]['weight'].value['bias'].shape == (2,)
@@ -215,47 +213,47 @@ class TestGraphUtils(absltest.TestCase):
215
213
  assert statetree[1]['running_var'].value.shape == (1, 2)
216
214
 
217
215
  def test_shared_variables(self):
218
- v = bst.ParamState(1)
216
+ v = brainstate.ParamState(1)
219
217
  g = [v, v]
220
218
 
221
- graphdef, statetree = bst.graph.treefy_split(g)
219
+ graphdef, statetree = brainstate.graph.treefy_split(g)
222
220
  assert len(statetree.to_flat()) == 1
223
221
 
224
- g2 = bst.graph.treefy_merge(graphdef, statetree)
222
+ g2 = brainstate.graph.treefy_merge(graphdef, statetree)
225
223
  assert g2[0] is g2[1]
226
224
 
227
225
  def test_tied_weights(self):
228
- class Foo(bst.nn.Module):
226
+ class Foo(brainstate.nn.Module):
229
227
  def __init__(self) -> None:
230
228
  super().__init__()
231
- self.bar = bst.nn.Linear(2, 2)
232
- self.baz = bst.nn.Linear(2, 2)
229
+ self.bar = brainstate.nn.Linear(2, 2)
230
+ self.baz = brainstate.nn.Linear(2, 2)
233
231
 
234
232
  # tie the weights
235
233
  self.baz.weight = self.bar.weight
236
234
 
237
235
  node = Foo()
238
- graphdef, state = bst.graph.treefy_split(node)
236
+ graphdef, state = brainstate.graph.treefy_split(node)
239
237
 
240
238
  assert len(state.to_flat()) == 1
241
239
 
242
- node2 = bst.graph.treefy_merge(graphdef, state)
240
+ node2 = brainstate.graph.treefy_merge(graphdef, state)
243
241
 
244
242
  assert node2.bar.weight is node2.baz.weight
245
243
 
246
244
  def test_tied_weights_example(self):
247
- class LinearTranspose(bst.nn.Module):
245
+ class LinearTranspose(brainstate.nn.Module):
248
246
  def __init__(self, dout: int, din: int, ) -> None:
249
247
  super().__init__()
250
- self.kernel = bst.ParamState(bst.init.LecunNormal()((dout, din)))
248
+ self.kernel = brainstate.ParamState(brainstate.init.LecunNormal()((dout, din)))
251
249
 
252
250
  def __call__(self, x):
253
251
  return x @ self.kernel.value.T
254
252
 
255
- class Encoder(bst.nn.Module):
253
+ class Encoder(brainstate.nn.Module):
256
254
  def __init__(self, ) -> None:
257
255
  super().__init__()
258
- self.embed = bst.nn.Embedding(10, 2)
256
+ self.embed = brainstate.nn.Embedding(10, 2)
259
257
  self.linear_out = LinearTranspose(10, 2)
260
258
 
261
259
  # tie the weights
@@ -266,7 +264,7 @@ class TestGraphUtils(absltest.TestCase):
266
264
  return self.linear_out(x)
267
265
 
268
266
  model = Encoder()
269
- graphdef, state = bst.graph.treefy_split(model)
267
+ graphdef, state = brainstate.graph.treefy_split(model)
270
268
 
271
269
  assert len(state.to_flat()) == 1
272
270
 
@@ -276,49 +274,49 @@ class TestGraphUtils(absltest.TestCase):
276
274
  assert y.shape == (2, 10)
277
275
 
278
276
  def test_state_variables_not_shared_with_graph(self):
279
- class Foo(bst.graph.Node):
277
+ class Foo(brainstate.graph.Node):
280
278
  def __init__(self):
281
- self.a = bst.ParamState(1)
279
+ self.a = brainstate.ParamState(1)
282
280
 
283
281
  m = Foo()
284
- graphdef, statetree = bst.graph.treefy_split(m)
282
+ graphdef, statetree = brainstate.graph.treefy_split(m)
285
283
 
286
- assert isinstance(m.a, bst.ParamState)
287
- assert issubclass(statetree.a.type, bst.ParamState)
284
+ assert isinstance(m.a, brainstate.ParamState)
285
+ assert issubclass(statetree.a.type, brainstate.ParamState)
288
286
  assert m.a is not statetree.a
289
287
  assert m.a.value == statetree.a.value
290
288
 
291
- m2 = bst.graph.treefy_merge(graphdef, statetree)
289
+ m2 = brainstate.graph.treefy_merge(graphdef, statetree)
292
290
 
293
- assert isinstance(m2.a, bst.ParamState)
294
- assert issubclass(statetree.a.type, bst.ParamState)
291
+ assert isinstance(m2.a, brainstate.ParamState)
292
+ assert issubclass(statetree.a.type, brainstate.ParamState)
295
293
  assert m2.a is not statetree.a
296
294
  assert m2.a.value == statetree.a.value
297
295
 
298
296
  def test_shared_state_variables_not_shared_with_graph(self):
299
- class Foo(bst.graph.Node):
297
+ class Foo(brainstate.graph.Node):
300
298
  def __init__(self):
301
- p = bst.ParamState(1)
299
+ p = brainstate.ParamState(1)
302
300
  self.a = p
303
301
  self.b = p
304
302
 
305
303
  m = Foo()
306
- graphdef, state = bst.graph.treefy_split(m)
304
+ graphdef, state = brainstate.graph.treefy_split(m)
307
305
 
308
- assert isinstance(m.a, bst.ParamState)
309
- assert isinstance(m.b, bst.ParamState)
310
- assert issubclass(state.a.type, bst.ParamState)
306
+ assert isinstance(m.a, brainstate.ParamState)
307
+ assert isinstance(m.b, brainstate.ParamState)
308
+ assert issubclass(state.a.type, brainstate.ParamState)
311
309
  assert 'b' not in state
312
310
  assert m.a is not state.a
313
311
  assert m.b is not state.a
314
312
  assert m.a.value == state.a.value
315
313
  assert m.b.value == state.a.value
316
314
 
317
- m2 = bst.graph.treefy_merge(graphdef, state)
315
+ m2 = brainstate.graph.treefy_merge(graphdef, state)
318
316
 
319
- assert isinstance(m2.a, bst.ParamState)
320
- assert isinstance(m2.b, bst.ParamState)
321
- assert issubclass(state.a.type, bst.ParamState)
317
+ assert isinstance(m2.a, brainstate.ParamState)
318
+ assert isinstance(m2.b, brainstate.ParamState)
319
+ assert issubclass(state.a.type, brainstate.ParamState)
322
320
  assert m2.a is not state.a
323
321
  assert m2.b is not state.a
324
322
  assert m2.a.value == state.a.value
@@ -326,24 +324,24 @@ class TestGraphUtils(absltest.TestCase):
326
324
  assert m2.a is m2.b
327
325
 
328
326
  def test_pytree_node(self):
329
- @bst.util.dataclass
327
+ @brainstate.util.dataclass
330
328
  class Tree:
331
- a: bst.ParamState
332
- b: str = bst.util.field(pytree_node=False)
329
+ a: brainstate.ParamState
330
+ b: str = brainstate.util.field(pytree_node=False)
333
331
 
334
- class Foo(bst.graph.Node):
332
+ class Foo(brainstate.graph.Node):
335
333
  def __init__(self):
336
- self.tree = Tree(bst.ParamState(1), 'a')
334
+ self.tree = Tree(brainstate.ParamState(1), 'a')
337
335
 
338
336
  m = Foo()
339
337
 
340
- graphdef, state = bst.graph.treefy_split(m)
338
+ graphdef, state = brainstate.graph.treefy_split(m)
341
339
 
342
340
  assert 'tree' in state
343
341
  assert 'a' in state.tree
344
342
  assert graphdef.subgraphs['tree'].type.__name__ == 'PytreeType'
345
343
 
346
- m2 = bst.graph.treefy_merge(graphdef, state)
344
+ m2 = brainstate.graph.treefy_merge(graphdef, state)
347
345
 
348
346
  assert isinstance(m2.tree, Tree)
349
347
  assert m2.tree.a.value == 1
@@ -352,36 +350,36 @@ class TestGraphUtils(absltest.TestCase):
352
350
  assert m2.tree is not m.tree
353
351
 
354
352
  def test_call_jit_update(self):
355
- class Counter(bst.graph.Node):
353
+ class Counter(brainstate.graph.Node):
356
354
  def __init__(self):
357
- self.count = bst.ParamState(jnp.zeros(()))
355
+ self.count = brainstate.ParamState(jnp.zeros(()))
358
356
 
359
357
  def inc(self):
360
358
  self.count.value += 1
361
359
  return 1
362
360
 
363
- graph_state = bst.graph.treefy_split(Counter())
361
+ graph_state = brainstate.graph.treefy_split(Counter())
364
362
 
365
363
  @jax.jit
366
364
  def update(graph_state):
367
- out, graph_state = bst.graph.call(graph_state).inc()
365
+ out, graph_state = brainstate.graph.call(graph_state).inc()
368
366
  self.assertEqual(out, 1)
369
367
  return graph_state
370
368
 
371
369
  graph_state = update(graph_state)
372
370
  graph_state = update(graph_state)
373
371
 
374
- counter = bst.graph.treefy_merge(*graph_state)
372
+ counter = brainstate.graph.treefy_merge(*graph_state)
375
373
 
376
374
  self.assertEqual(counter.count.value, 2)
377
375
 
378
376
  def test_stateful_linear(self):
379
377
  linear = StatefulLinear(3, 2)
380
- linear_state = bst.graph.treefy_split(linear)
378
+ linear_state = brainstate.graph.treefy_split(linear)
381
379
 
382
380
  @jax.jit
383
381
  def forward(x, pure_linear):
384
- y, pure_linear = bst.graph.call(pure_linear)(x)
382
+ y, pure_linear = brainstate.graph.call(pure_linear)(x)
385
383
  return y, pure_linear
386
384
 
387
385
  x = jnp.ones((1, 3))
@@ -389,7 +387,7 @@ class TestGraphUtils(absltest.TestCase):
389
387
  y, linear_state = forward(x, linear_state)
390
388
 
391
389
  self.assertEqual(linear.count.value, 0)
392
- new_linear = bst.graph.treefy_merge(*linear_state)
390
+ new_linear = brainstate.graph.treefy_merge(*linear_state)
393
391
  self.assertEqual(new_linear.count.value, 2)
394
392
 
395
393
  def test_getitem(self):
@@ -397,20 +395,20 @@ class TestGraphUtils(absltest.TestCase):
397
395
  a=StatefulLinear(3, 2),
398
396
  b=StatefulLinear(2, 1),
399
397
  )
400
- node_state = bst.graph.treefy_split(nodes)
401
- _, node_state = bst.graph.call(node_state)['b'].increment()
398
+ node_state = brainstate.graph.treefy_split(nodes)
399
+ _, node_state = brainstate.graph.call(node_state)['b'].increment()
402
400
 
403
- nodes = bst.graph.treefy_merge(*node_state)
401
+ nodes = brainstate.graph.treefy_merge(*node_state)
404
402
 
405
403
  self.assertEqual(nodes['a'].count.value, 0)
406
404
  self.assertEqual(nodes['b'].count.value, 1)
407
405
 
408
406
 
409
- class SimpleModule(bst.nn.Module):
407
+ class SimpleModule(brainstate.nn.Module):
410
408
  pass
411
409
 
412
410
 
413
- class SimplePyTreeModule(bst.nn.Module):
411
+ class SimplePyTreeModule(brainstate.nn.Module):
414
412
  pass
415
413
 
416
414
 
@@ -420,13 +418,13 @@ class TestThreading(parameterized.TestCase):
420
418
  (SimpleModule,),
421
419
  (SimplePyTreeModule,),
422
420
  )
423
- def test_threading(self, module_fn: Callable[[], bst.nn.Module]):
421
+ def test_threading(self, module_fn: Callable[[], brainstate.nn.Module]):
424
422
  x = module_fn()
425
423
 
426
424
  class MyThread(Thread):
427
425
 
428
426
  def run(self) -> None:
429
- bst.graph.treefy_split(x)
427
+ brainstate.graph.treefy_split(x)
430
428
 
431
429
  thread = MyThread()
432
430
  thread.start()
@@ -435,26 +433,26 @@ class TestThreading(parameterized.TestCase):
435
433
 
436
434
  class TestGraphOperation(unittest.TestCase):
437
435
  def test1(self):
438
- class MyNode(bst.graph.Node):
436
+ class MyNode(brainstate.graph.Node):
439
437
  def __init__(self):
440
- self.a = bst.nn.Linear(2, 3)
441
- self.b = bst.nn.Linear(3, 2)
442
- self.c = [bst.nn.Linear(1, 2), bst.nn.Linear(1, 3)]
443
- self.d = {'x': bst.nn.Linear(1, 3), 'y': bst.nn.Linear(1, 4)}
438
+ self.a = brainstate.nn.Linear(2, 3)
439
+ self.b = brainstate.nn.Linear(3, 2)
440
+ self.c = [brainstate.nn.Linear(1, 2), brainstate.nn.Linear(1, 3)]
441
+ self.d = {'x': brainstate.nn.Linear(1, 3), 'y': brainstate.nn.Linear(1, 4)}
444
442
 
445
- graphdef, statetree = bst.graph.flatten(MyNode())
443
+ graphdef, statetree = brainstate.graph.flatten(MyNode())
446
444
  # print(graphdef)
447
445
  print(statetree)
448
446
  # print(bst.graph.unflatten(graphdef, statetree))
449
447
 
450
448
  def test_split(self):
451
- class Foo(bst.graph.Node):
449
+ class Foo(brainstate.graph.Node):
452
450
  def __init__(self):
453
- self.a = bst.nn.Linear(2, 2)
454
- self.b = bst.nn.BatchNorm1d([10, 2])
451
+ self.a = brainstate.nn.Linear(2, 2)
452
+ self.b = brainstate.nn.BatchNorm1d([10, 2])
455
453
 
456
454
  node = Foo()
457
- graphdef, params, others = bst.graph.treefy_split(node, bst.ParamState, ...)
455
+ graphdef, params, others = brainstate.graph.treefy_split(node, brainstate.ParamState, ...)
458
456
 
459
457
  print(params)
460
458
  print(jax.tree.map(jnp.shape, params))
@@ -462,24 +460,24 @@ class TestGraphOperation(unittest.TestCase):
462
460
  print(jax.tree.map(jnp.shape, others))
463
461
 
464
462
  def test_merge(self):
465
- class Foo(bst.graph.Node):
463
+ class Foo(brainstate.graph.Node):
466
464
  def __init__(self):
467
- self.a = bst.nn.Linear(2, 2)
468
- self.b = bst.nn.BatchNorm1d([10, 2])
465
+ self.a = brainstate.nn.Linear(2, 2)
466
+ self.b = brainstate.nn.BatchNorm1d([10, 2])
469
467
 
470
468
  node = Foo()
471
- graphdef, params, others = bst.graph.treefy_split(node, bst.ParamState, ...)
469
+ graphdef, params, others = brainstate.graph.treefy_split(node, brainstate.ParamState, ...)
472
470
 
473
- new_node = bst.graph.treefy_merge(graphdef, params, others)
471
+ new_node = brainstate.graph.treefy_merge(graphdef, params, others)
474
472
 
475
473
  assert isinstance(new_node, Foo)
476
- assert isinstance(new_node.b, bst.nn.BatchNorm1d)
477
- assert isinstance(new_node.a, bst.nn.Linear)
474
+ assert isinstance(new_node.b, brainstate.nn.BatchNorm1d)
475
+ assert isinstance(new_node.a, brainstate.nn.Linear)
478
476
 
479
477
  def test_update_states(self):
480
478
  x = jnp.ones((1, 2))
481
479
  y = jnp.ones((1, 3))
482
- model = bst.nn.Linear(2, 3)
480
+ model = brainstate.nn.Linear(2, 3)
483
481
 
484
482
  def loss_fn(x, y):
485
483
  return jnp.mean((y - model(x)) ** 2)
@@ -490,44 +488,44 @@ class TestGraphOperation(unittest.TestCase):
490
488
 
491
489
  prev_loss = loss_fn(x, y)
492
490
  weights = model.states()
493
- grads = bst.augment.grad(loss_fn, weights)(x, y)
491
+ grads = brainstate.augment.grad(loss_fn, weights)(x, y)
494
492
  for key, val in grads.items():
495
493
  sgd(weights[key], val)
496
494
  assert loss_fn(x, y) < prev_loss
497
495
 
498
496
  def test_pop_states(self):
499
- class Model(bst.nn.Module):
497
+ class Model(brainstate.nn.Module):
500
498
  def __init__(self):
501
499
  super().__init__()
502
- self.a = bst.nn.Linear(2, 3)
503
- self.b = bst.nn.LIF([10, 2])
500
+ self.a = brainstate.nn.Linear(2, 3)
501
+ self.b = brainstate.nn.LIF([10, 2])
504
502
 
505
503
  model = Model()
506
- with bst.catch_new_states('new'):
507
- bst.nn.init_all_states(model)
504
+ with brainstate.catch_new_states('new'):
505
+ brainstate.nn.init_all_states(model)
508
506
  # print(model.states())
509
507
  self.assertTrue(len(model.states()) == 2)
510
- model_states = bst.graph.pop_states(model, 'new')
508
+ model_states = brainstate.graph.pop_states(model, 'new')
511
509
  print(model_states)
512
510
  self.assertTrue(len(model.states()) == 1)
513
511
  assert not hasattr(model.b, 'V')
514
512
  # print(model.states())
515
513
 
516
514
  def test_treefy_split(self):
517
- class MLP(bst.graph.Node):
515
+ class MLP(brainstate.graph.Node):
518
516
  def __init__(self, din: int, dmid: int, dout: int, n_layer: int = 3):
519
- self.input = bst.nn.Linear(din, dmid)
520
- self.layers = [bst.nn.Linear(dmid, dmid) for _ in range(n_layer)]
521
- self.output = bst.nn.Linear(dmid, dout)
517
+ self.input = brainstate.nn.Linear(din, dmid)
518
+ self.layers = [brainstate.nn.Linear(dmid, dmid) for _ in range(n_layer)]
519
+ self.output = brainstate.nn.Linear(dmid, dout)
522
520
 
523
521
  def __call__(self, x):
524
- x = bst.functional.relu(self.input(x))
522
+ x = brainstate.functional.relu(self.input(x))
525
523
  for layer in self.layers:
526
- x = bst.functional.relu(layer(x))
524
+ x = brainstate.functional.relu(layer(x))
527
525
  return self.output(x)
528
526
 
529
527
  model = MLP(2, 1, 3)
530
- graph_def, treefy_states = bst.graph.treefy_split(model)
528
+ graph_def, treefy_states = brainstate.graph.treefy_split(model)
531
529
 
532
530
  print(graph_def)
533
531
  print(treefy_states)
@@ -538,25 +536,25 @@ class TestGraphOperation(unittest.TestCase):
538
536
  # print(nest_states)
539
537
 
540
538
  def test_states(self):
541
- class MLP(bst.graph.Node):
539
+ class MLP(brainstate.graph.Node):
542
540
  def __init__(self, din: int, dmid: int, dout: int, n_layer: int = 3):
543
- self.input = bst.nn.Linear(din, dmid)
544
- self.layers = [bst.nn.Linear(dmid, dmid) for _ in range(n_layer)]
545
- self.output = bst.nn.LIF(dout)
541
+ self.input = brainstate.nn.Linear(din, dmid)
542
+ self.layers = [brainstate.nn.Linear(dmid, dmid) for _ in range(n_layer)]
543
+ self.output = brainstate.nn.LIF(dout)
546
544
 
547
545
  def __call__(self, x):
548
- x = bst.functional.relu(self.input(x))
546
+ x = brainstate.functional.relu(self.input(x))
549
547
  for layer in self.layers:
550
- x = bst.functional.relu(layer(x))
548
+ x = brainstate.functional.relu(layer(x))
551
549
  return self.output(x)
552
550
 
553
- model = bst.nn.init_all_states(MLP(2, 1, 3))
554
- states = bst.graph.states(model)
551
+ model = brainstate.nn.init_all_states(MLP(2, 1, 3))
552
+ states = brainstate.graph.states(model)
555
553
  print(states)
556
554
  nest_states = states.to_nest()
557
555
  print(nest_states)
558
556
 
559
- params, others = bst.graph.states(model, bst.ParamState, bst.ShortTermState)
557
+ params, others = brainstate.graph.states(model, brainstate.ParamState, brainstate.ShortTermState)
560
558
  print(params)
561
559
  print(others)
562
560