brainstate 0.1.1__py2.py3-none-any.whl → 0.1.3__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_compatible_import.py +12 -9
- brainstate/_state.py +1 -1
- brainstate/augment/_autograd_test.py +132 -133
- brainstate/augment/_eval_shape_test.py +7 -9
- brainstate/augment/_mapping_test.py +75 -76
- brainstate/compile/_ad_checkpoint_test.py +6 -8
- brainstate/compile/_conditions_test.py +35 -36
- brainstate/compile/_error_if_test.py +10 -13
- brainstate/compile/_loop_collect_return_test.py +7 -9
- brainstate/compile/_loop_no_collection_test.py +7 -8
- brainstate/compile/_make_jaxpr.py +29 -14
- brainstate/compile/_make_jaxpr_test.py +20 -20
- brainstate/functional/_activations_test.py +61 -61
- brainstate/graph/_graph_node_test.py +16 -18
- brainstate/graph/_graph_operation_test.py +154 -156
- brainstate/init/_random_inits_test.py +20 -21
- brainstate/init/_regular_inits_test.py +4 -5
- brainstate/mixin.py +1 -14
- brainstate/nn/__init__.py +81 -17
- brainstate/nn/_collective_ops_test.py +8 -8
- brainstate/nn/{_interaction/_conv.py → _conv.py} +1 -1
- brainstate/nn/{_interaction/_conv_test.py → _conv_test.py} +31 -33
- brainstate/nn/{_dynamics/_state_delay.py → _delay.py} +6 -2
- brainstate/nn/{_elementwise/_dropout.py → _dropout.py} +1 -1
- brainstate/nn/{_elementwise/_dropout_test.py → _dropout_test.py} +9 -9
- brainstate/nn/{_dynamics/_dynamics_base.py → _dynamics.py} +139 -18
- brainstate/nn/{_dynamics/_dynamics_base_test.py → _dynamics_test.py} +14 -15
- brainstate/nn/{_elementwise/_elementwise.py → _elementwise.py} +1 -1
- brainstate/nn/_elementwise_test.py +169 -0
- brainstate/nn/{_interaction/_embedding.py → _embedding.py} +1 -1
- brainstate/nn/_exp_euler_test.py +5 -6
- brainstate/nn/{_event/_fixedprob_mv.py → _fixedprob_mv.py} +1 -1
- brainstate/nn/{_event/_fixedprob_mv_test.py → _fixedprob_mv_test.py} +0 -1
- brainstate/nn/{_dyn_impl/_inputs.py → _inputs.py} +4 -5
- brainstate/nn/{_interaction/_linear.py → _linear.py} +2 -5
- brainstate/nn/{_event/_linear_mv.py → _linear_mv.py} +1 -1
- brainstate/nn/{_event/_linear_mv_test.py → _linear_mv_test.py} +0 -1
- brainstate/nn/{_interaction/_linear_test.py → _linear_test.py} +15 -17
- brainstate/nn/{_event/__init__.py → _ltp.py} +7 -5
- brainstate/nn/_module_test.py +34 -37
- brainstate/nn/{_dyn_impl/_dynamics_neuron.py → _neuron.py} +2 -2
- brainstate/nn/{_dyn_impl/_dynamics_neuron_test.py → _neuron_test.py} +18 -19
- brainstate/nn/{_interaction/_normalizations.py → _normalizations.py} +1 -1
- brainstate/nn/{_interaction/_normalizations_test.py → _normalizations_test.py} +10 -12
- brainstate/nn/{_interaction/_poolings.py → _poolings.py} +1 -1
- brainstate/nn/{_interaction/_poolings_test.py → _poolings_test.py} +20 -22
- brainstate/nn/{_dynamics/_projection_base.py → _projection.py} +35 -3
- brainstate/nn/{_dyn_impl/_rate_rnns.py → _rate_rnns.py} +2 -2
- brainstate/nn/{_dyn_impl/_rate_rnns_test.py → _rate_rnns_test.py} +6 -7
- brainstate/nn/{_dyn_impl/_readout.py → _readout.py} +3 -3
- brainstate/nn/{_dyn_impl/_readout_test.py → _readout_test.py} +9 -10
- brainstate/nn/_stp.py +236 -0
- brainstate/nn/{_dyn_impl/_dynamics_synapse.py → _synapse.py} +17 -206
- brainstate/nn/{_dyn_impl/_dynamics_synapse_test.py → _synapse_test.py} +9 -10
- brainstate/nn/_synaptic_projection.py +133 -0
- brainstate/nn/{_dynamics/_synouts.py → _synouts.py} +4 -1
- brainstate/nn/{_dynamics/_synouts_test.py → _synouts_test.py} +4 -5
- brainstate/optim/_lr_scheduler_test.py +3 -3
- brainstate/optim/_optax_optimizer_test.py +8 -9
- brainstate/random/_rand_funs_test.py +183 -184
- brainstate/random/_rand_seed_test.py +10 -12
- {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/METADATA +1 -1
- brainstate-0.1.3.dist-info/RECORD +131 -0
- brainstate/nn/_dyn_impl/__init__.py +0 -42
- brainstate/nn/_dynamics/__init__.py +0 -37
- brainstate/nn/_elementwise/__init__.py +0 -22
- brainstate/nn/_elementwise/_elementwise_test.py +0 -171
- brainstate/nn/_interaction/__init__.py +0 -41
- brainstate-0.1.1.dist-info/RECORD +0 -133
- {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/LICENSE +0 -0
- {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/WHEEL +0 -0
- {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/top_level.txt +0 -0
@@ -13,8 +13,6 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
|
-
|
18
16
|
import unittest
|
19
17
|
from collections.abc import Callable
|
20
18
|
from threading import Thread
|
@@ -23,67 +21,67 @@ import jax
|
|
23
21
|
import jax.numpy as jnp
|
24
22
|
from absl.testing import absltest, parameterized
|
25
23
|
|
26
|
-
import brainstate as bst
|
24
|
+
import brainstate
|
27
25
|
|
28
26
|
|
29
27
|
class TestIter(unittest.TestCase):
|
30
28
|
def test1(self):
|
31
|
-
class Model(
|
29
|
+
class Model(brainstate.nn.Module):
|
32
30
|
def __init__(self):
|
33
31
|
super().__init__()
|
34
|
-
self.a =
|
35
|
-
self.b =
|
36
|
-
self.c = [
|
37
|
-
self.d = {'x':
|
38
|
-
self.b.a =
|
32
|
+
self.a = brainstate.nn.Linear(1, 2)
|
33
|
+
self.b = brainstate.nn.Linear(2, 3)
|
34
|
+
self.c = [brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5)]
|
35
|
+
self.d = {'x': brainstate.nn.Linear(5, 6), 'y': brainstate.nn.Linear(6, 7)}
|
36
|
+
self.b.a = brainstate.nn.LIF(2)
|
39
37
|
|
40
|
-
for path, node in
|
38
|
+
for path, node in brainstate.graph.iter_leaf(Model()):
|
41
39
|
print(path, node)
|
42
|
-
for path, node in
|
40
|
+
for path, node in brainstate.graph.iter_node(Model()):
|
43
41
|
print(path, node)
|
44
|
-
for path, node in
|
42
|
+
for path, node in brainstate.graph.iter_node(Model(), allowed_hierarchy=(1, 1)):
|
45
43
|
print(path, node)
|
46
|
-
for path, node in
|
44
|
+
for path, node in brainstate.graph.iter_node(Model(), allowed_hierarchy=(2, 2)):
|
47
45
|
print(path, node)
|
48
46
|
|
49
47
|
def test_iter_leaf_v1(self):
|
50
|
-
class Linear(
|
48
|
+
class Linear(brainstate.nn.Module):
|
51
49
|
def __init__(self, din, dout):
|
52
50
|
super().__init__()
|
53
|
-
self.weight =
|
54
|
-
self.bias =
|
51
|
+
self.weight = brainstate.ParamState(brainstate.random.randn(din, dout))
|
52
|
+
self.bias = brainstate.ParamState(brainstate.random.randn(dout))
|
55
53
|
self.a = 1
|
56
54
|
|
57
55
|
module = Linear(3, 4)
|
58
56
|
graph = [module, module]
|
59
57
|
|
60
58
|
num = 0
|
61
|
-
for path, value in
|
59
|
+
for path, value in brainstate.graph.iter_leaf(graph):
|
62
60
|
print(path, type(value).__name__)
|
63
61
|
num += 1
|
64
62
|
|
65
63
|
assert num == 3
|
66
64
|
|
67
65
|
def test_iter_node_v1(self):
|
68
|
-
class Model(
|
66
|
+
class Model(brainstate.nn.Module):
|
69
67
|
def __init__(self):
|
70
68
|
super().__init__()
|
71
|
-
self.a =
|
72
|
-
self.b =
|
73
|
-
self.c = [
|
74
|
-
self.d = {'x':
|
75
|
-
self.b.a =
|
69
|
+
self.a = brainstate.nn.Linear(1, 2)
|
70
|
+
self.b = brainstate.nn.Linear(2, 3)
|
71
|
+
self.c = [brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5)]
|
72
|
+
self.d = {'x': brainstate.nn.Linear(5, 6), 'y': brainstate.nn.Linear(6, 7)}
|
73
|
+
self.b.a = brainstate.nn.LIF(2)
|
76
74
|
|
77
75
|
model = Model()
|
78
76
|
|
79
77
|
num = 0
|
80
|
-
for path, node in
|
78
|
+
for path, node in brainstate.graph.iter_node([model, model]):
|
81
79
|
print(path, node.__class__.__name__)
|
82
80
|
num += 1
|
83
81
|
assert num == 8
|
84
82
|
|
85
83
|
|
86
|
-
class List(bst.nn.Module):
|
84
|
+
class List(brainstate.nn.Module):
|
87
85
|
def __init__(self, items):
|
88
86
|
super().__init__()
|
89
87
|
self.items = list(items)
|
@@ -95,7 +93,7 @@ class List(bst.nn.Module):
|
|
95
93
|
self.items[idx] = value
|
96
94
|
|
97
95
|
|
98
|
-
class Dict(bst.nn.Module):
|
96
|
+
class Dict(brainstate.nn.Module):
|
99
97
|
def __init__(self, *args, **kwargs):
|
100
98
|
super().__init__()
|
101
99
|
self.items = dict(*args, **kwargs)
|
@@ -107,12 +105,12 @@ class Dict(bst.nn.Module):
|
|
107
105
|
self.items[key] = value
|
108
106
|
|
109
107
|
|
110
|
-
class StatefulLinear(bst.nn.Module):
|
108
|
+
class StatefulLinear(brainstate.nn.Module):
|
111
109
|
def __init__(self, din, dout):
|
112
110
|
super().__init__()
|
113
|
-
self.w =
|
114
|
-
self.b =
|
115
|
-
self.count =
|
111
|
+
self.w = brainstate.ParamState(brainstate.random.rand(din, dout))
|
112
|
+
self.b = brainstate.ParamState(jnp.zeros((dout,)))
|
113
|
+
self.count = brainstate.State(jnp.array(0, dtype=jnp.uint32))
|
116
114
|
|
117
115
|
def increment(self):
|
118
116
|
self.count.value += 1
|
@@ -124,44 +122,44 @@ class StatefulLinear(bst.nn.Module):
|
|
124
122
|
|
125
123
|
class TestGraphUtils(absltest.TestCase):
|
126
124
|
def test_flatten_treey_state(self):
|
127
|
-
a = {'a': 1, 'b':
|
128
|
-
g = [a, 3, a,
|
125
|
+
a = {'a': 1, 'b': brainstate.ParamState(2)}
|
126
|
+
g = [a, 3, a, brainstate.ParamState(4)]
|
129
127
|
|
130
|
-
refmap =
|
131
|
-
graphdef, states =
|
128
|
+
refmap = brainstate.graph.RefMap()
|
129
|
+
graphdef, states = brainstate.graph.flatten(g, ref_index=refmap, treefy_state=True)
|
132
130
|
|
133
131
|
states[0]['b'].value = 2
|
134
132
|
states[3].value = 4
|
135
133
|
|
136
|
-
assert isinstance(states[0]['b'],
|
137
|
-
assert isinstance(states[3],
|
138
|
-
assert isinstance(states,
|
134
|
+
assert isinstance(states[0]['b'], brainstate.TreefyState)
|
135
|
+
assert isinstance(states[3], brainstate.TreefyState)
|
136
|
+
assert isinstance(states, brainstate.util.NestedDict)
|
139
137
|
assert len(refmap) == 2
|
140
138
|
assert a['b'] in refmap
|
141
139
|
assert g[3] in refmap
|
142
140
|
|
143
141
|
def test_flatten(self):
|
144
|
-
a = {'a': 1, 'b':
|
145
|
-
g = [a, 3, a,
|
142
|
+
a = {'a': 1, 'b': brainstate.ParamState(2)}
|
143
|
+
g = [a, 3, a, brainstate.ParamState(4)]
|
146
144
|
|
147
|
-
refmap =
|
148
|
-
graphdef, states =
|
145
|
+
refmap = brainstate.graph.RefMap()
|
146
|
+
graphdef, states = brainstate.graph.flatten(g, ref_index=refmap, treefy_state=False)
|
149
147
|
|
150
148
|
states[0]['b'].value = 2
|
151
149
|
states[3].value = 4
|
152
150
|
|
153
|
-
assert isinstance(states[0]['b'],
|
154
|
-
assert isinstance(states[3],
|
151
|
+
assert isinstance(states[0]['b'], brainstate.State)
|
152
|
+
assert isinstance(states[3], brainstate.State)
|
155
153
|
assert len(refmap) == 2
|
156
154
|
assert a['b'] in refmap
|
157
155
|
assert g[3] in refmap
|
158
156
|
|
159
157
|
def test_unflatten_treey_state(self):
|
160
|
-
a =
|
161
|
-
g1 =
|
158
|
+
a = brainstate.graph.Dict(a=1, b=brainstate.ParamState(2))
|
159
|
+
g1 = brainstate.graph.List([a, 3, a, brainstate.ParamState(4)])
|
162
160
|
|
163
|
-
graphdef, references =
|
164
|
-
g =
|
161
|
+
graphdef, references = brainstate.graph.flatten(g1, treefy_state=True)
|
162
|
+
g = brainstate.graph.unflatten(graphdef, references)
|
165
163
|
|
166
164
|
print(graphdef)
|
167
165
|
print(references)
|
@@ -170,11 +168,11 @@ class TestGraphUtils(absltest.TestCase):
|
|
170
168
|
assert g1[0]['b'] is not g[0]['b']
|
171
169
|
|
172
170
|
def test_unflatten(self):
|
173
|
-
a =
|
174
|
-
g1 =
|
171
|
+
a = brainstate.graph.Dict(a=1, b=brainstate.ParamState(2))
|
172
|
+
g1 = brainstate.graph.List([a, 3, a, brainstate.ParamState(4)])
|
175
173
|
|
176
|
-
graphdef, references =
|
177
|
-
g =
|
174
|
+
graphdef, references = brainstate.graph.flatten(g1, treefy_state=False)
|
175
|
+
g = brainstate.graph.unflatten(graphdef, references)
|
178
176
|
|
179
177
|
print(graphdef)
|
180
178
|
print(references)
|
@@ -183,29 +181,29 @@ class TestGraphUtils(absltest.TestCase):
|
|
183
181
|
assert g1[0]['b'] is g[0]['b']
|
184
182
|
|
185
183
|
def test_unflatten_pytree(self):
|
186
|
-
a = {'a': 1, 'b':
|
187
|
-
g = [a, 3, a,
|
184
|
+
a = {'a': 1, 'b': brainstate.ParamState(2)}
|
185
|
+
g = [a, 3, a, brainstate.ParamState(4)]
|
188
186
|
|
189
|
-
graphdef, references =
|
190
|
-
g =
|
187
|
+
graphdef, references = brainstate.graph.treefy_split(g)
|
188
|
+
g = brainstate.graph.treefy_merge(graphdef, references)
|
191
189
|
|
192
190
|
assert g[0] is not g[2]
|
193
191
|
|
194
192
|
def test_unflatten_empty(self):
|
195
|
-
a = Dict({'a': 1, 'b':
|
196
|
-
g = List([a, 3, a,
|
193
|
+
a = Dict({'a': 1, 'b': brainstate.ParamState(2)})
|
194
|
+
g = List([a, 3, a, brainstate.ParamState(4)])
|
197
195
|
|
198
|
-
graphdef, references =
|
196
|
+
graphdef, references = brainstate.graph.treefy_split(g)
|
199
197
|
|
200
198
|
with self.assertRaisesRegex(ValueError, 'Expected key'):
|
201
|
-
|
199
|
+
brainstate.graph.unflatten(graphdef, brainstate.util.NestedDict({}))
|
202
200
|
|
203
201
|
def test_module_list(self):
|
204
202
|
ls = [
|
205
|
-
|
206
|
-
|
203
|
+
brainstate.nn.Linear(2, 2),
|
204
|
+
brainstate.nn.BatchNorm1d([10, 2]),
|
207
205
|
]
|
208
|
-
graphdef, statetree =
|
206
|
+
graphdef, statetree = brainstate.graph.treefy_split(ls)
|
209
207
|
|
210
208
|
assert statetree[0]['weight'].value['weight'].shape == (2, 2)
|
211
209
|
assert statetree[0]['weight'].value['bias'].shape == (2,)
|
@@ -215,47 +213,47 @@ class TestGraphUtils(absltest.TestCase):
|
|
215
213
|
assert statetree[1]['running_var'].value.shape == (1, 2)
|
216
214
|
|
217
215
|
def test_shared_variables(self):
|
218
|
-
v =
|
216
|
+
v = brainstate.ParamState(1)
|
219
217
|
g = [v, v]
|
220
218
|
|
221
|
-
graphdef, statetree =
|
219
|
+
graphdef, statetree = brainstate.graph.treefy_split(g)
|
222
220
|
assert len(statetree.to_flat()) == 1
|
223
221
|
|
224
|
-
g2 =
|
222
|
+
g2 = brainstate.graph.treefy_merge(graphdef, statetree)
|
225
223
|
assert g2[0] is g2[1]
|
226
224
|
|
227
225
|
def test_tied_weights(self):
|
228
|
-
class Foo(
|
226
|
+
class Foo(brainstate.nn.Module):
|
229
227
|
def __init__(self) -> None:
|
230
228
|
super().__init__()
|
231
|
-
self.bar =
|
232
|
-
self.baz =
|
229
|
+
self.bar = brainstate.nn.Linear(2, 2)
|
230
|
+
self.baz = brainstate.nn.Linear(2, 2)
|
233
231
|
|
234
232
|
# tie the weights
|
235
233
|
self.baz.weight = self.bar.weight
|
236
234
|
|
237
235
|
node = Foo()
|
238
|
-
graphdef, state =
|
236
|
+
graphdef, state = brainstate.graph.treefy_split(node)
|
239
237
|
|
240
238
|
assert len(state.to_flat()) == 1
|
241
239
|
|
242
|
-
node2 =
|
240
|
+
node2 = brainstate.graph.treefy_merge(graphdef, state)
|
243
241
|
|
244
242
|
assert node2.bar.weight is node2.baz.weight
|
245
243
|
|
246
244
|
def test_tied_weights_example(self):
|
247
|
-
class LinearTranspose(
|
245
|
+
class LinearTranspose(brainstate.nn.Module):
|
248
246
|
def __init__(self, dout: int, din: int, ) -> None:
|
249
247
|
super().__init__()
|
250
|
-
self.kernel =
|
248
|
+
self.kernel = brainstate.ParamState(brainstate.init.LecunNormal()((dout, din)))
|
251
249
|
|
252
250
|
def __call__(self, x):
|
253
251
|
return x @ self.kernel.value.T
|
254
252
|
|
255
|
-
class Encoder(
|
253
|
+
class Encoder(brainstate.nn.Module):
|
256
254
|
def __init__(self, ) -> None:
|
257
255
|
super().__init__()
|
258
|
-
self.embed =
|
256
|
+
self.embed = brainstate.nn.Embedding(10, 2)
|
259
257
|
self.linear_out = LinearTranspose(10, 2)
|
260
258
|
|
261
259
|
# tie the weights
|
@@ -266,7 +264,7 @@ class TestGraphUtils(absltest.TestCase):
|
|
266
264
|
return self.linear_out(x)
|
267
265
|
|
268
266
|
model = Encoder()
|
269
|
-
graphdef, state =
|
267
|
+
graphdef, state = brainstate.graph.treefy_split(model)
|
270
268
|
|
271
269
|
assert len(state.to_flat()) == 1
|
272
270
|
|
@@ -276,49 +274,49 @@ class TestGraphUtils(absltest.TestCase):
|
|
276
274
|
assert y.shape == (2, 10)
|
277
275
|
|
278
276
|
def test_state_variables_not_shared_with_graph(self):
|
279
|
-
class Foo(
|
277
|
+
class Foo(brainstate.graph.Node):
|
280
278
|
def __init__(self):
|
281
|
-
self.a =
|
279
|
+
self.a = brainstate.ParamState(1)
|
282
280
|
|
283
281
|
m = Foo()
|
284
|
-
graphdef, statetree =
|
282
|
+
graphdef, statetree = brainstate.graph.treefy_split(m)
|
285
283
|
|
286
|
-
assert isinstance(m.a,
|
287
|
-
assert issubclass(statetree.a.type,
|
284
|
+
assert isinstance(m.a, brainstate.ParamState)
|
285
|
+
assert issubclass(statetree.a.type, brainstate.ParamState)
|
288
286
|
assert m.a is not statetree.a
|
289
287
|
assert m.a.value == statetree.a.value
|
290
288
|
|
291
|
-
m2 =
|
289
|
+
m2 = brainstate.graph.treefy_merge(graphdef, statetree)
|
292
290
|
|
293
|
-
assert isinstance(m2.a,
|
294
|
-
assert issubclass(statetree.a.type,
|
291
|
+
assert isinstance(m2.a, brainstate.ParamState)
|
292
|
+
assert issubclass(statetree.a.type, brainstate.ParamState)
|
295
293
|
assert m2.a is not statetree.a
|
296
294
|
assert m2.a.value == statetree.a.value
|
297
295
|
|
298
296
|
def test_shared_state_variables_not_shared_with_graph(self):
|
299
|
-
class Foo(
|
297
|
+
class Foo(brainstate.graph.Node):
|
300
298
|
def __init__(self):
|
301
|
-
p =
|
299
|
+
p = brainstate.ParamState(1)
|
302
300
|
self.a = p
|
303
301
|
self.b = p
|
304
302
|
|
305
303
|
m = Foo()
|
306
|
-
graphdef, state =
|
304
|
+
graphdef, state = brainstate.graph.treefy_split(m)
|
307
305
|
|
308
|
-
assert isinstance(m.a,
|
309
|
-
assert isinstance(m.b,
|
310
|
-
assert issubclass(state.a.type,
|
306
|
+
assert isinstance(m.a, brainstate.ParamState)
|
307
|
+
assert isinstance(m.b, brainstate.ParamState)
|
308
|
+
assert issubclass(state.a.type, brainstate.ParamState)
|
311
309
|
assert 'b' not in state
|
312
310
|
assert m.a is not state.a
|
313
311
|
assert m.b is not state.a
|
314
312
|
assert m.a.value == state.a.value
|
315
313
|
assert m.b.value == state.a.value
|
316
314
|
|
317
|
-
m2 =
|
315
|
+
m2 = brainstate.graph.treefy_merge(graphdef, state)
|
318
316
|
|
319
|
-
assert isinstance(m2.a,
|
320
|
-
assert isinstance(m2.b,
|
321
|
-
assert issubclass(state.a.type,
|
317
|
+
assert isinstance(m2.a, brainstate.ParamState)
|
318
|
+
assert isinstance(m2.b, brainstate.ParamState)
|
319
|
+
assert issubclass(state.a.type, brainstate.ParamState)
|
322
320
|
assert m2.a is not state.a
|
323
321
|
assert m2.b is not state.a
|
324
322
|
assert m2.a.value == state.a.value
|
@@ -326,24 +324,24 @@ class TestGraphUtils(absltest.TestCase):
|
|
326
324
|
assert m2.a is m2.b
|
327
325
|
|
328
326
|
def test_pytree_node(self):
|
329
|
-
@
|
327
|
+
@brainstate.util.dataclass
|
330
328
|
class Tree:
|
331
|
-
a:
|
332
|
-
b: str =
|
329
|
+
a: brainstate.ParamState
|
330
|
+
b: str = brainstate.util.field(pytree_node=False)
|
333
331
|
|
334
|
-
class Foo(
|
332
|
+
class Foo(brainstate.graph.Node):
|
335
333
|
def __init__(self):
|
336
|
-
self.tree = Tree(
|
334
|
+
self.tree = Tree(brainstate.ParamState(1), 'a')
|
337
335
|
|
338
336
|
m = Foo()
|
339
337
|
|
340
|
-
graphdef, state =
|
338
|
+
graphdef, state = brainstate.graph.treefy_split(m)
|
341
339
|
|
342
340
|
assert 'tree' in state
|
343
341
|
assert 'a' in state.tree
|
344
342
|
assert graphdef.subgraphs['tree'].type.__name__ == 'PytreeType'
|
345
343
|
|
346
|
-
m2 =
|
344
|
+
m2 = brainstate.graph.treefy_merge(graphdef, state)
|
347
345
|
|
348
346
|
assert isinstance(m2.tree, Tree)
|
349
347
|
assert m2.tree.a.value == 1
|
@@ -352,36 +350,36 @@ class TestGraphUtils(absltest.TestCase):
|
|
352
350
|
assert m2.tree is not m.tree
|
353
351
|
|
354
352
|
def test_call_jit_update(self):
|
355
|
-
class Counter(
|
353
|
+
class Counter(brainstate.graph.Node):
|
356
354
|
def __init__(self):
|
357
|
-
self.count =
|
355
|
+
self.count = brainstate.ParamState(jnp.zeros(()))
|
358
356
|
|
359
357
|
def inc(self):
|
360
358
|
self.count.value += 1
|
361
359
|
return 1
|
362
360
|
|
363
|
-
graph_state =
|
361
|
+
graph_state = brainstate.graph.treefy_split(Counter())
|
364
362
|
|
365
363
|
@jax.jit
|
366
364
|
def update(graph_state):
|
367
|
-
out, graph_state =
|
365
|
+
out, graph_state = brainstate.graph.call(graph_state).inc()
|
368
366
|
self.assertEqual(out, 1)
|
369
367
|
return graph_state
|
370
368
|
|
371
369
|
graph_state = update(graph_state)
|
372
370
|
graph_state = update(graph_state)
|
373
371
|
|
374
|
-
counter =
|
372
|
+
counter = brainstate.graph.treefy_merge(*graph_state)
|
375
373
|
|
376
374
|
self.assertEqual(counter.count.value, 2)
|
377
375
|
|
378
376
|
def test_stateful_linear(self):
|
379
377
|
linear = StatefulLinear(3, 2)
|
380
|
-
linear_state =
|
378
|
+
linear_state = brainstate.graph.treefy_split(linear)
|
381
379
|
|
382
380
|
@jax.jit
|
383
381
|
def forward(x, pure_linear):
|
384
|
-
y, pure_linear =
|
382
|
+
y, pure_linear = brainstate.graph.call(pure_linear)(x)
|
385
383
|
return y, pure_linear
|
386
384
|
|
387
385
|
x = jnp.ones((1, 3))
|
@@ -389,7 +387,7 @@ class TestGraphUtils(absltest.TestCase):
|
|
389
387
|
y, linear_state = forward(x, linear_state)
|
390
388
|
|
391
389
|
self.assertEqual(linear.count.value, 0)
|
392
|
-
new_linear =
|
390
|
+
new_linear = brainstate.graph.treefy_merge(*linear_state)
|
393
391
|
self.assertEqual(new_linear.count.value, 2)
|
394
392
|
|
395
393
|
def test_getitem(self):
|
@@ -397,20 +395,20 @@ class TestGraphUtils(absltest.TestCase):
|
|
397
395
|
a=StatefulLinear(3, 2),
|
398
396
|
b=StatefulLinear(2, 1),
|
399
397
|
)
|
400
|
-
node_state =
|
401
|
-
_, node_state =
|
398
|
+
node_state = brainstate.graph.treefy_split(nodes)
|
399
|
+
_, node_state = brainstate.graph.call(node_state)['b'].increment()
|
402
400
|
|
403
|
-
nodes =
|
401
|
+
nodes = brainstate.graph.treefy_merge(*node_state)
|
404
402
|
|
405
403
|
self.assertEqual(nodes['a'].count.value, 0)
|
406
404
|
self.assertEqual(nodes['b'].count.value, 1)
|
407
405
|
|
408
406
|
|
409
|
-
class SimpleModule(
|
407
|
+
class SimpleModule(brainstate.nn.Module):
|
410
408
|
pass
|
411
409
|
|
412
410
|
|
413
|
-
class SimplePyTreeModule(
|
411
|
+
class SimplePyTreeModule(brainstate.nn.Module):
|
414
412
|
pass
|
415
413
|
|
416
414
|
|
@@ -420,13 +418,13 @@ class TestThreading(parameterized.TestCase):
|
|
420
418
|
(SimpleModule,),
|
421
419
|
(SimplePyTreeModule,),
|
422
420
|
)
|
423
|
-
def test_threading(self, module_fn: Callable[[],
|
421
|
+
def test_threading(self, module_fn: Callable[[], brainstate.nn.Module]):
|
424
422
|
x = module_fn()
|
425
423
|
|
426
424
|
class MyThread(Thread):
|
427
425
|
|
428
426
|
def run(self) -> None:
|
429
|
-
|
427
|
+
brainstate.graph.treefy_split(x)
|
430
428
|
|
431
429
|
thread = MyThread()
|
432
430
|
thread.start()
|
@@ -435,26 +433,26 @@ class TestThreading(parameterized.TestCase):
|
|
435
433
|
|
436
434
|
class TestGraphOperation(unittest.TestCase):
|
437
435
|
def test1(self):
|
438
|
-
class MyNode(
|
436
|
+
class MyNode(brainstate.graph.Node):
|
439
437
|
def __init__(self):
|
440
|
-
self.a =
|
441
|
-
self.b =
|
442
|
-
self.c = [
|
443
|
-
self.d = {'x':
|
438
|
+
self.a = brainstate.nn.Linear(2, 3)
|
439
|
+
self.b = brainstate.nn.Linear(3, 2)
|
440
|
+
self.c = [brainstate.nn.Linear(1, 2), brainstate.nn.Linear(1, 3)]
|
441
|
+
self.d = {'x': brainstate.nn.Linear(1, 3), 'y': brainstate.nn.Linear(1, 4)}
|
444
442
|
|
445
|
-
graphdef, statetree =
|
443
|
+
graphdef, statetree = brainstate.graph.flatten(MyNode())
|
446
444
|
# print(graphdef)
|
447
445
|
print(statetree)
|
448
446
|
# print(bst.graph.unflatten(graphdef, statetree))
|
449
447
|
|
450
448
|
def test_split(self):
|
451
|
-
class Foo(
|
449
|
+
class Foo(brainstate.graph.Node):
|
452
450
|
def __init__(self):
|
453
|
-
self.a =
|
454
|
-
self.b =
|
451
|
+
self.a = brainstate.nn.Linear(2, 2)
|
452
|
+
self.b = brainstate.nn.BatchNorm1d([10, 2])
|
455
453
|
|
456
454
|
node = Foo()
|
457
|
-
graphdef, params, others =
|
455
|
+
graphdef, params, others = brainstate.graph.treefy_split(node, brainstate.ParamState, ...)
|
458
456
|
|
459
457
|
print(params)
|
460
458
|
print(jax.tree.map(jnp.shape, params))
|
@@ -462,24 +460,24 @@ class TestGraphOperation(unittest.TestCase):
|
|
462
460
|
print(jax.tree.map(jnp.shape, others))
|
463
461
|
|
464
462
|
def test_merge(self):
|
465
|
-
class Foo(
|
463
|
+
class Foo(brainstate.graph.Node):
|
466
464
|
def __init__(self):
|
467
|
-
self.a =
|
468
|
-
self.b =
|
465
|
+
self.a = brainstate.nn.Linear(2, 2)
|
466
|
+
self.b = brainstate.nn.BatchNorm1d([10, 2])
|
469
467
|
|
470
468
|
node = Foo()
|
471
|
-
graphdef, params, others =
|
469
|
+
graphdef, params, others = brainstate.graph.treefy_split(node, brainstate.ParamState, ...)
|
472
470
|
|
473
|
-
new_node =
|
471
|
+
new_node = brainstate.graph.treefy_merge(graphdef, params, others)
|
474
472
|
|
475
473
|
assert isinstance(new_node, Foo)
|
476
|
-
assert isinstance(new_node.b,
|
477
|
-
assert isinstance(new_node.a,
|
474
|
+
assert isinstance(new_node.b, brainstate.nn.BatchNorm1d)
|
475
|
+
assert isinstance(new_node.a, brainstate.nn.Linear)
|
478
476
|
|
479
477
|
def test_update_states(self):
|
480
478
|
x = jnp.ones((1, 2))
|
481
479
|
y = jnp.ones((1, 3))
|
482
|
-
model =
|
480
|
+
model = brainstate.nn.Linear(2, 3)
|
483
481
|
|
484
482
|
def loss_fn(x, y):
|
485
483
|
return jnp.mean((y - model(x)) ** 2)
|
@@ -490,44 +488,44 @@ class TestGraphOperation(unittest.TestCase):
|
|
490
488
|
|
491
489
|
prev_loss = loss_fn(x, y)
|
492
490
|
weights = model.states()
|
493
|
-
grads =
|
491
|
+
grads = brainstate.augment.grad(loss_fn, weights)(x, y)
|
494
492
|
for key, val in grads.items():
|
495
493
|
sgd(weights[key], val)
|
496
494
|
assert loss_fn(x, y) < prev_loss
|
497
495
|
|
498
496
|
def test_pop_states(self):
|
499
|
-
class Model(
|
497
|
+
class Model(brainstate.nn.Module):
|
500
498
|
def __init__(self):
|
501
499
|
super().__init__()
|
502
|
-
self.a =
|
503
|
-
self.b =
|
500
|
+
self.a = brainstate.nn.Linear(2, 3)
|
501
|
+
self.b = brainstate.nn.LIF([10, 2])
|
504
502
|
|
505
503
|
model = Model()
|
506
|
-
with
|
507
|
-
|
504
|
+
with brainstate.catch_new_states('new'):
|
505
|
+
brainstate.nn.init_all_states(model)
|
508
506
|
# print(model.states())
|
509
507
|
self.assertTrue(len(model.states()) == 2)
|
510
|
-
model_states =
|
508
|
+
model_states = brainstate.graph.pop_states(model, 'new')
|
511
509
|
print(model_states)
|
512
510
|
self.assertTrue(len(model.states()) == 1)
|
513
511
|
assert not hasattr(model.b, 'V')
|
514
512
|
# print(model.states())
|
515
513
|
|
516
514
|
def test_treefy_split(self):
|
517
|
-
class MLP(
|
515
|
+
class MLP(brainstate.graph.Node):
|
518
516
|
def __init__(self, din: int, dmid: int, dout: int, n_layer: int = 3):
|
519
|
-
self.input =
|
520
|
-
self.layers = [
|
521
|
-
self.output =
|
517
|
+
self.input = brainstate.nn.Linear(din, dmid)
|
518
|
+
self.layers = [brainstate.nn.Linear(dmid, dmid) for _ in range(n_layer)]
|
519
|
+
self.output = brainstate.nn.Linear(dmid, dout)
|
522
520
|
|
523
521
|
def __call__(self, x):
|
524
|
-
x =
|
522
|
+
x = brainstate.functional.relu(self.input(x))
|
525
523
|
for layer in self.layers:
|
526
|
-
x =
|
524
|
+
x = brainstate.functional.relu(layer(x))
|
527
525
|
return self.output(x)
|
528
526
|
|
529
527
|
model = MLP(2, 1, 3)
|
530
|
-
graph_def, treefy_states =
|
528
|
+
graph_def, treefy_states = brainstate.graph.treefy_split(model)
|
531
529
|
|
532
530
|
print(graph_def)
|
533
531
|
print(treefy_states)
|
@@ -538,25 +536,25 @@ class TestGraphOperation(unittest.TestCase):
|
|
538
536
|
# print(nest_states)
|
539
537
|
|
540
538
|
def test_states(self):
|
541
|
-
class MLP(
|
539
|
+
class MLP(brainstate.graph.Node):
|
542
540
|
def __init__(self, din: int, dmid: int, dout: int, n_layer: int = 3):
|
543
|
-
self.input =
|
544
|
-
self.layers = [
|
545
|
-
self.output =
|
541
|
+
self.input = brainstate.nn.Linear(din, dmid)
|
542
|
+
self.layers = [brainstate.nn.Linear(dmid, dmid) for _ in range(n_layer)]
|
543
|
+
self.output = brainstate.nn.LIF(dout)
|
546
544
|
|
547
545
|
def __call__(self, x):
|
548
|
-
x =
|
546
|
+
x = brainstate.functional.relu(self.input(x))
|
549
547
|
for layer in self.layers:
|
550
|
-
x =
|
548
|
+
x = brainstate.functional.relu(layer(x))
|
551
549
|
return self.output(x)
|
552
550
|
|
553
|
-
model =
|
554
|
-
states =
|
551
|
+
model = brainstate.nn.init_all_states(MLP(2, 1, 3))
|
552
|
+
states = brainstate.graph.states(model)
|
555
553
|
print(states)
|
556
554
|
nest_states = states.to_nest()
|
557
555
|
print(nest_states)
|
558
556
|
|
559
|
-
params, others =
|
557
|
+
params, others = brainstate.graph.states(model, brainstate.ParamState, brainstate.ShortTermState)
|
560
558
|
print(params)
|
561
559
|
print(others)
|
562
560
|
|