brainstate 0.2.0__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +169 -169
- brainstate/_compatible_import.py +340 -340
- brainstate/_compatible_import_test.py +681 -681
- brainstate/_deprecation.py +210 -210
- brainstate/_deprecation_test.py +2319 -2319
- brainstate/_error.py +45 -45
- brainstate/_state.py +1652 -1652
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -1495
- brainstate/environ_test.py +1223 -1223
- brainstate/graph/__init__.py +22 -22
- brainstate/graph/_node.py +240 -240
- brainstate/graph/_node_test.py +589 -589
- brainstate/graph/_operation.py +1624 -1624
- brainstate/graph/_operation_test.py +1147 -1147
- brainstate/mixin.py +1433 -1433
- brainstate/mixin_test.py +1017 -1017
- brainstate/nn/__init__.py +137 -137
- brainstate/nn/_activations.py +1100 -1100
- brainstate/nn/_activations_test.py +354 -354
- brainstate/nn/_collective_ops.py +633 -633
- brainstate/nn/_collective_ops_test.py +774 -774
- brainstate/nn/_common.py +226 -226
- brainstate/nn/_common_test.py +154 -154
- brainstate/nn/_conv.py +2010 -2010
- brainstate/nn/_conv_test.py +849 -849
- brainstate/nn/_delay.py +575 -575
- brainstate/nn/_delay_test.py +243 -243
- brainstate/nn/_dropout.py +618 -618
- brainstate/nn/_dropout_test.py +477 -477
- brainstate/nn/_dynamics.py +1267 -1267
- brainstate/nn/_dynamics_test.py +67 -67
- brainstate/nn/_elementwise.py +1298 -1298
- brainstate/nn/_elementwise_test.py +829 -829
- brainstate/nn/_embedding.py +408 -408
- brainstate/nn/_embedding_test.py +156 -156
- brainstate/nn/_event_fixedprob.py +233 -233
- brainstate/nn/_event_fixedprob_test.py +115 -115
- brainstate/nn/_event_linear.py +83 -83
- brainstate/nn/_event_linear_test.py +121 -121
- brainstate/nn/_exp_euler.py +254 -254
- brainstate/nn/_exp_euler_test.py +377 -377
- brainstate/nn/_linear.py +744 -744
- brainstate/nn/_linear_test.py +475 -475
- brainstate/nn/_metrics.py +1070 -1070
- brainstate/nn/_metrics_test.py +611 -611
- brainstate/nn/_module.py +384 -384
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -1334
- brainstate/nn/_normalizations_test.py +699 -699
- brainstate/nn/_paddings.py +1020 -1020
- brainstate/nn/_paddings_test.py +722 -722
- brainstate/nn/_poolings.py +2239 -2239
- brainstate/nn/_poolings_test.py +952 -952
- brainstate/nn/_rnns.py +946 -946
- brainstate/nn/_rnns_test.py +592 -592
- brainstate/nn/_utils.py +216 -216
- brainstate/nn/_utils_test.py +401 -401
- brainstate/nn/init.py +809 -809
- brainstate/nn/init_test.py +180 -180
- brainstate/random/__init__.py +270 -270
- brainstate/random/_rand_funs.py +3938 -3938
- brainstate/random/_rand_funs_test.py +640 -640
- brainstate/random/_rand_seed.py +675 -675
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1617 -1617
- brainstate/random/_rand_state_test.py +551 -551
- brainstate/transform/__init__.py +59 -59
- brainstate/transform/_ad_checkpoint.py +176 -176
- brainstate/transform/_ad_checkpoint_test.py +49 -49
- brainstate/transform/_autograd.py +1025 -1025
- brainstate/transform/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -316
- brainstate/transform/_conditions_test.py +220 -220
- brainstate/transform/_error_if.py +94 -94
- brainstate/transform/_error_if_test.py +52 -52
- brainstate/transform/_eval_shape.py +145 -145
- brainstate/transform/_eval_shape_test.py +38 -38
- brainstate/transform/_jit.py +399 -399
- brainstate/transform/_jit_test.py +143 -143
- brainstate/transform/_loop_collect_return.py +675 -675
- brainstate/transform/_loop_collect_return_test.py +58 -58
- brainstate/transform/_loop_no_collection.py +283 -283
- brainstate/transform/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2016 -2016
- brainstate/transform/_make_jaxpr_test.py +1510 -1510
- brainstate/transform/_mapping.py +529 -529
- brainstate/transform/_mapping_test.py +194 -194
- brainstate/transform/_progress_bar.py +255 -255
- brainstate/transform/_random.py +171 -171
- brainstate/transform/_unvmap.py +256 -256
- brainstate/transform/_util.py +286 -286
- brainstate/typing.py +837 -837
- brainstate/typing_test.py +780 -780
- brainstate/util/__init__.py +27 -27
- brainstate/util/_others.py +1024 -1024
- brainstate/util/_others_test.py +962 -962
- brainstate/util/_pretty_pytree.py +1301 -1301
- brainstate/util/_pretty_pytree_test.py +675 -675
- brainstate/util/_pretty_repr.py +462 -462
- brainstate/util/_pretty_repr_test.py +696 -696
- brainstate/util/filter.py +945 -945
- brainstate/util/filter_test.py +911 -911
- brainstate/util/struct.py +910 -910
- brainstate/util/struct_test.py +602 -602
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -108
- brainstate-0.2.1.dist-info/RECORD +111 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
- brainstate-0.2.0.dist-info/RECORD +0 -111
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -1,1147 +1,1147 @@
|
|
1
|
-
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
import unittest
|
17
|
-
from collections.abc import Callable
|
18
|
-
from threading import Thread
|
19
|
-
|
20
|
-
import jax
|
21
|
-
import jax.numpy as jnp
|
22
|
-
from absl.testing import absltest, parameterized
|
23
|
-
|
24
|
-
import pytest
|
25
|
-
pytest.skip("skipping tests", allow_module_level=True)
|
26
|
-
|
27
|
-
import brainstate
|
28
|
-
import braintools
|
29
|
-
import brainpy
|
30
|
-
|
31
|
-
|
32
|
-
class TestIter(unittest.TestCase):
    """Smoke tests for ``brainstate.graph.iter_leaf`` and ``brainstate.graph.iter_node``."""

    @staticmethod
    def _build_model_cls():
        """Define and return the nested-module class shared by the tests below.

        The class is created when this helper is called (i.e. at test time),
        exactly as when it was written inline in each test.
        """

        class Model(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = brainstate.nn.Linear(1, 2)
                self.b = brainstate.nn.Linear(2, 3)
                self.c = [brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5)]
                self.d = {'x': brainstate.nn.Linear(5, 6), 'y': brainstate.nn.Linear(6, 7)}
                self.b.a = brainpy.LIF(2)

        return Model

    def test1(self):
        Model = self._build_model_cls()

        # Every traversal is run on its own freshly constructed model.
        traversals = (
            lambda: brainstate.graph.iter_leaf(Model()),
            lambda: brainstate.graph.iter_node(Model()),
            lambda: brainstate.graph.iter_node(Model(), allowed_hierarchy=(1, 1)),
            lambda: brainstate.graph.iter_node(Model(), allowed_hierarchy=(2, 2)),
        )
        for make_iterator in traversals:
            for path, node in make_iterator():
                print(path, node)

    def test_iter_leaf_v1(self):
        class Linear(brainstate.nn.Module):
            def __init__(self, din, dout):
                super().__init__()
                self.weight = brainstate.ParamState(brainstate.random.randn(din, dout))
                self.bias = brainstate.ParamState(brainstate.random.randn(dout))
                self.a = 1

        module = Linear(3, 4)
        graph = [module, module]

        leaf_count = 0
        for path, value in brainstate.graph.iter_leaf(graph):
            print(path, type(value).__name__)
            leaf_count += 1

        # The same module is listed twice; the expected 3 matches its three
        # attributes (weight, bias, a), so the repeat is presumably not revisited.
        assert leaf_count == 3

    def test_iter_node_v1(self):
        model = self._build_model_cls()()

        node_count = 0
        for path, node in brainstate.graph.iter_node([model, model]):
            print(path, node.__class__.__name__)
            node_count += 1
        assert node_count == 8
|
87
|
-
|
88
|
-
|
89
|
-
class List(brainstate.nn.Module):
    """Minimal list-like module; its entries live in the plain list ``self.items``."""

    def __init__(self, items):
        super().__init__()
        # Copy the incoming iterable into a fresh list.
        self.items = [*items]

    def __getitem__(self, idx):
        entries = self.items
        return entries[idx]

    def __setitem__(self, idx, value):
        entries = self.items
        entries[idx] = value
|
99
|
-
|
100
|
-
|
101
|
-
class Dict(brainstate.nn.Module):
    """Minimal mapping-like module; its entries live in the plain dict ``self.items``."""

    def __init__(self, *args, **kwargs):
        super().__init__()
        # Accepts exactly the arguments of the built-in ``dict`` constructor.
        self.items = dict(*args, **kwargs)

    def __getitem__(self, key):
        entries = self.items
        return entries[key]

    def __setitem__(self, key, value):
        entries = self.items
        entries[key] = value
|
111
|
-
|
112
|
-
|
113
|
-
class StatefulLinear(brainstate.nn.Module):
    """Affine layer ``x @ w + b`` that also counts how often it is invoked."""

    def __init__(self, din, dout):
        super().__init__()
        self.w = brainstate.ParamState(brainstate.random.rand(din, dout))
        self.b = brainstate.ParamState(jnp.zeros((dout,)))
        # Unsigned 32-bit call counter held in a plain (non-parameter) State.
        self.count = brainstate.State(jnp.array(0, dtype=jnp.uint32))

    def increment(self):
        """Add one to the call counter."""
        self.count.value += 1

    def __call__(self, x):
        """Bump the call counter, then apply the affine map to ``x``."""
        self.increment()
        return x @ self.w.value + self.b.value
|
126
|
-
|
127
|
-
|
128
|
-
class TestGraphUtils(absltest.TestCase):
    """Round-trip tests for ``flatten``/``unflatten`` and ``treefy_split``/``treefy_merge``."""

    @staticmethod
    def _flatten_sample(treefy_state):
        """Flatten ``[d, 3, d, ParamState(4)]`` with a shared dict ``d`` and touch two leaves.

        Returns ``(shared_dict, graph, ref_map, flat_states)``.
        """
        shared = {'a': 1, 'b': brainstate.ParamState(2)}
        graph = [shared, 3, shared, brainstate.ParamState(4)]

        ref_map = brainstate.graph.RefMap()
        _, flat = brainstate.graph.flatten(graph, ref_index=ref_map, treefy_state=treefy_state)

        flat[0]['b'].value = 2
        flat[3].value = 4
        return shared, graph, ref_map, flat

    def test_flatten_treey_state(self):
        shared, graph, ref_map, flat = self._flatten_sample(treefy_state=True)

        # With ``treefy_state=True`` the leaves are TreefyState instances inside a NestedDict.
        assert isinstance(flat[0]['b'], brainstate.TreefyState)
        assert isinstance(flat[3], brainstate.TreefyState)
        assert isinstance(flat, brainstate.util.NestedDict)
        # The reference map holds exactly the two ParamState objects.
        assert len(ref_map) == 2
        assert shared['b'] in ref_map
        assert graph[3] in ref_map

    def test_flatten(self):
        shared, graph, ref_map, flat = self._flatten_sample(treefy_state=False)

        # With ``treefy_state=False`` the leaves are ``brainstate.State`` instances.
        assert isinstance(flat[0]['b'], brainstate.State)
        assert isinstance(flat[3], brainstate.State)
        assert len(ref_map) == 2
        assert shared['b'] in ref_map
        assert graph[3] in ref_map

    def test_unflatten_pytree(self):
        shared = {'a': 1, 'b': brainstate.ParamState(2)}
        graph = [shared, 3, shared, brainstate.ParamState(4)]

        graph_def, state_refs = brainstate.graph.treefy_split(graph)
        rebuilt = brainstate.graph.treefy_merge(graph_def, state_refs)

        # The two dict entries come back as separate objects after the round trip.
        assert rebuilt[0] is not rebuilt[2]

    def test_unflatten_empty(self):
        shared = Dict({'a': 1, 'b': brainstate.ParamState(2)})
        graph = List([shared, 3, shared, brainstate.ParamState(4)])

        graph_def, _ = brainstate.graph.treefy_split(graph)

        # Unflattening against an empty state tree must report the missing key.
        with self.assertRaisesRegex(ValueError, 'Expected key'):
            brainstate.graph.unflatten(graph_def, brainstate.util.NestedDict({}))

    def test_module_list(self):
        layers = [
            brainstate.nn.Linear(2, 2),
            brainstate.nn.BatchNorm1d([10, 2]),
        ]
        _, state_tree = brainstate.graph.treefy_split(layers)

        linear_weight = state_tree[0]['weight'].value
        assert linear_weight['weight'].shape == (2, 2)
        assert linear_weight['bias'].shape == (2,)

        norm = state_tree[1]
        assert norm['weight'].value['scale'].shape == (1, 2)
        assert norm['weight'].value['bias'].shape == (1, 2)
        assert norm['running_mean'].value.shape == (1, 2)
        assert norm['running_var'].value.shape == (1, 2)

    def test_shared_variables(self):
        param = brainstate.ParamState(1)
        graph = [param, param]

        graph_def, state_tree = brainstate.graph.treefy_split(graph)
        # A state referenced twice is stored only once.
        assert len(state_tree.to_flat()) == 1

        merged = brainstate.graph.treefy_merge(graph_def, state_tree)
        # ...and the aliasing is restored on merge.
        assert merged[0] is merged[1]

    def test_tied_weights(self):
        class Foo(brainstate.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bar = brainstate.nn.Linear(2, 2)
                self.baz = brainstate.nn.Linear(2, 2)
                # tie the weights
                self.baz.weight = self.bar.weight

        graph_def, state = brainstate.graph.treefy_split(Foo())
        assert len(state.to_flat()) == 1

        restored = brainstate.graph.treefy_merge(graph_def, state)
        assert restored.bar.weight is restored.baz.weight

    def test_tied_weights_example(self):
        class LinearTranspose(brainstate.nn.Module):
            def __init__(self, dout: int, din: int) -> None:
                super().__init__()
                self.kernel = brainstate.ParamState(braintools.init.LecunNormal()((dout, din)))

            def __call__(self, x):
                return x @ self.kernel.value.T

        class Encoder(brainstate.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.embed = brainstate.nn.Embedding(10, 2)
                self.linear_out = LinearTranspose(10, 2)
                # tie the weights: the output projection reuses the embedding table
                self.linear_out.kernel = self.embed.weight

            def __call__(self, x):
                return self.linear_out(self.embed(x))

        model = Encoder()
        _, state = brainstate.graph.treefy_split(model)
        assert len(state.to_flat()) == 1

        tokens = jax.random.randint(jax.random.key(0), (2,), 0, 10)
        logits = model(tokens)
        assert logits.shape == (2, 10)

    def test_state_variables_not_shared_with_graph(self):
        class Foo(brainstate.graph.Node):
            def __init__(self):
                self.a = brainstate.ParamState(1)

        m = Foo()
        graph_def, state_tree = brainstate.graph.treefy_split(m)

        def check_detached(node):
            # The node holds a ParamState that is a different object from,
            # but value-equal to, the entry in the split state tree.
            assert isinstance(node.a, brainstate.ParamState)
            assert issubclass(state_tree.a.type, brainstate.ParamState)
            assert node.a is not state_tree.a
            assert node.a.value == state_tree.a.value

        check_detached(m)
        check_detached(brainstate.graph.treefy_merge(graph_def, state_tree))

    def test_shared_state_variables_not_shared_with_graph(self):
        class Foo(brainstate.graph.Node):
            def __init__(self):
                p = brainstate.ParamState(1)
                self.a = p
                self.b = p

        m = Foo()
        graph_def, state = brainstate.graph.treefy_split(m)

        def check_detached(node):
            # Both aliases are ParamStates distinct from, but value-equal to,
            # the single split entry ``state.a``.
            for attr in (node.a, node.b):
                assert isinstance(attr, brainstate.ParamState)
            assert issubclass(state.a.type, brainstate.ParamState)
            for attr in (node.a, node.b):
                assert attr is not state.a
                assert attr.value == state.a.value

        check_detached(m)
        # The alias ``b`` is not stored separately in the split state.
        assert 'b' not in state

        m2 = brainstate.graph.treefy_merge(graph_def, state)
        check_detached(m2)
        # Aliasing is restored on merge.
        assert m2.a is m2.b

    def test_pytree_node(self):
        @brainstate.util.dataclass
        class Tree:
            a: brainstate.ParamState
            b: str = brainstate.util.field(pytree_node=False)

        class Foo(brainstate.graph.Node):
            def __init__(self):
                self.tree = Tree(brainstate.ParamState(1), 'a')

        m = Foo()
        graph_def, state = brainstate.graph.treefy_split(m)

        assert 'tree' in state
        assert 'a' in state.tree
        assert graph_def.subgraphs['tree'].type.__name__ == 'PytreeType'

        m2 = brainstate.graph.treefy_merge(graph_def, state)

        assert isinstance(m2.tree, Tree)
        assert m2.tree.a.value == 1
        assert m2.tree.b == 'a'
        # Merging produces new objects instead of reusing the originals.
        assert m2.tree.a is not m.tree.a
        assert m2.tree is not m.tree
|
330
|
-
|
331
|
-
|
332
|
-
class SimpleModule(brainstate.nn.Module):
    """Empty module; passed as a parameterized input to ``TestThreading``."""
|
334
|
-
|
335
|
-
|
336
|
-
class SimplePyTreeModule(brainstate.nn.Module):
    """Empty module; passed as a parameterized input to ``TestThreading``."""
|
338
|
-
|
339
|
-
|
340
|
-
class TestThreading(parameterized.TestCase):
    """Check that ``brainstate.graph.treefy_split`` works from a non-main thread."""

    @parameterized.parameters(
        (SimpleModule,),
        (SimplePyTreeModule,),
    )
    def test_threading(self, module_fn: Callable[[], brainstate.nn.Module]):
        x = module_fn()

        # An exception raised inside ``Thread.run`` is only reported through
        # ``threading.excepthook``; it never propagates out of ``join()``.
        # Record it here so a failure in the worker actually fails the test.
        errors = []

        class MyThread(Thread):

            def run(self) -> None:
                try:
                    brainstate.graph.treefy_split(x)
                except Exception as exc:  # re-raised in the main thread below
                    errors.append(exc)

        thread = MyThread()
        thread.start()
        thread.join()

        if errors:
            raise errors[0]
|
357
|
-
|
358
|
-
|
359
|
-
class TestGraphOperation(unittest.TestCase):
    """End-to-end checks of flatten/split/merge and state helpers in ``brainstate.graph``."""

    def test1(self):
        class MyNode(brainstate.graph.Node):
            def __init__(self):
                self.a = brainstate.nn.Linear(2, 3)
                self.b = brainstate.nn.Linear(3, 2)
                self.c = [brainstate.nn.Linear(1, 2), brainstate.nn.Linear(1, 3)]
                self.d = {'x': brainstate.nn.Linear(1, 3), 'y': brainstate.nn.Linear(1, 4)}

        graphdef, statetree = brainstate.graph.flatten(MyNode())
        print(statetree)

    def test_split(self):
        class Foo(brainstate.graph.Node):
            def __init__(self):
                self.a = brainstate.nn.Linear(2, 2)
                self.b = brainstate.nn.BatchNorm1d([10, 2])

        node = Foo()
        # ``...`` collects every state not matched by the ParamState filter.
        graphdef, params, others = brainstate.graph.treefy_split(node, brainstate.ParamState, ...)

        print(params)
        print(jax.tree.map(jnp.shape, params))

        print(jax.tree.map(jnp.shape, others))

    def test_merge(self):
        class Foo(brainstate.graph.Node):
            def __init__(self):
                self.a = brainstate.nn.Linear(2, 2)
                self.b = brainstate.nn.BatchNorm1d([10, 2])

        node = Foo()
        graphdef, params, others = brainstate.graph.treefy_split(node, brainstate.ParamState, ...)

        new_node = brainstate.graph.treefy_merge(graphdef, params, others)

        self.assertIsInstance(new_node, Foo)
        self.assertIsInstance(new_node.b, brainstate.nn.BatchNorm1d)
        self.assertIsInstance(new_node.a, brainstate.nn.Linear)

    def test_update_states(self):
        x = jnp.ones((1, 2))
        y = jnp.ones((1, 3))
        model = brainstate.nn.Linear(2, 3)

        def loss_fn(x, y):
            return jnp.mean((y - model(x)) ** 2)

        def sgd(ps, gs):
            updates = jax.tree.map(lambda p, g: p - 0.1 * g, ps.value, gs)
            ps.value = updates

        prev_loss = loss_fn(x, y)
        weights = model.states()
        # Autodiff transforms live in ``brainstate.transform`` in this release
        # (the package ships ``brainstate/transform/_autograd.py`` and no
        # ``augment`` package).
        grads = brainstate.transform.grad(loss_fn, weights)(x, y)
        for key, val in grads.items():
            sgd(weights[key], val)
        self.assertLess(loss_fn(x, y), prev_loss)

    def test_pop_states(self):
        class Model(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = brainstate.nn.Linear(2, 3)
                self.b = brainpy.LIF([10, 2])

        model = Model()
        # States created during initialization are tagged 'new' so they can be popped below.
        with brainstate.catch_new_states('new'):
            brainstate.nn.init_all_states(model)
        self.assertEqual(len(model.states()), 2)
        model_states = brainstate.graph.pop_states(model, 'new')
        print(model_states)
        self.assertEqual(len(model.states()), 1)
        self.assertFalse(hasattr(model.b, 'V'))

    def test_treefy_split(self):
        class MLP(brainstate.graph.Node):
            def __init__(self, din: int, dmid: int, dout: int, n_layer: int = 3):
                self.input = brainstate.nn.Linear(din, dmid)
                self.layers = [brainstate.nn.Linear(dmid, dmid) for _ in range(n_layer)]
                self.output = brainstate.nn.Linear(dmid, dout)

            def __call__(self, x):
                # Activations are exported from ``brainstate.nn`` (nn/_activations.py).
                x = brainstate.nn.relu(self.input(x))
                for layer in self.layers:
                    x = brainstate.nn.relu(layer(x))
                return self.output(x)

        model = MLP(2, 1, 3)
        graph_def, treefy_states = brainstate.graph.treefy_split(model)

        print(graph_def)
        print(treefy_states)

    def test_states(self):
        class MLP(brainstate.graph.Node):
            def __init__(self, din: int, dmid: int, dout: int, n_layer: int = 3):
                self.input = brainstate.nn.Linear(din, dmid)
                self.layers = [brainstate.nn.Linear(dmid, dmid) for _ in range(n_layer)]
                self.output = brainpy.LIF(dout)

            def __call__(self, x):
                x = brainstate.nn.relu(self.input(x))
                for layer in self.layers:
                    x = brainstate.nn.relu(layer(x))
                return self.output(x)

        model = brainstate.nn.init_all_states(MLP(2, 1, 3))
        states = brainstate.graph.states(model)
        print(states)
        nest_states = states.to_nest()
        print(nest_states)

        params, short_term = brainstate.graph.states(model, brainstate.ParamState, brainstate.ShortTermState)
        print(params)
        print(short_term)
|
485
|
-
|
486
|
-
|
487
|
-
class TestRefMap(unittest.TestCase):
    """Test RefMap class functionality."""

    def test_refmap_basic_operations(self):
        """Test basic RefMap operations."""
        ref_map = brainstate.graph.RefMap()

        # Test empty RefMap
        self.assertEqual(len(ref_map), 0)
        self.assertNotIn(object(), ref_map)

        # Test adding items
        obj1 = object()
        obj2 = object()
        ref_map[obj1] = 'value1'
        ref_map[obj2] = 'value2'

        self.assertEqual(len(ref_map), 2)
        self.assertIn(obj1, ref_map)
        self.assertIn(obj2, ref_map)
        self.assertEqual(ref_map[obj1], 'value1')
        self.assertEqual(ref_map[obj2], 'value2')

        # Test iteration
        keys = list(ref_map)
        self.assertIn(obj1, keys)
        self.assertIn(obj2, keys)

        # Test deletion
        del ref_map[obj1]
        self.assertEqual(len(ref_map), 1)
        self.assertNotIn(obj1, ref_map)
        self.assertIn(obj2, ref_map)

    def test_refmap_initialization_with_mapping(self):
        """Test RefMap initialization with a mapping."""
        obj1, obj2 = object(), object()
        mapping = {obj1: 'value1', obj2: 'value2'}
        ref_map = brainstate.graph.RefMap(mapping)

        self.assertEqual(len(ref_map), 2)
        self.assertEqual(ref_map[obj1], 'value1')
        self.assertEqual(ref_map[obj2], 'value2')

    def test_refmap_initialization_with_iterable(self):
        """Test RefMap initialization with an iterable."""
        obj1, obj2 = object(), object()
        pairs = [(obj1, 'value1'), (obj2, 'value2')]
        ref_map = brainstate.graph.RefMap(pairs)

        self.assertEqual(len(ref_map), 2)
        self.assertEqual(ref_map[obj1], 'value1')
        self.assertEqual(ref_map[obj2], 'value2')

    def test_refmap_same_object_different_instances(self):
        """Test RefMap handles same content objects with different ids."""
        # Create two lists with same content but different ids
        list1 = [1, 2, 3]
        list2 = [1, 2, 3]

        ref_map = brainstate.graph.RefMap()
        ref_map[list1] = 'list1'
        ref_map[list2] = 'list2'

        # Should have 2 entries since they have different ids
        self.assertEqual(len(ref_map), 2)
        self.assertEqual(ref_map[list1], 'list1')
        self.assertEqual(ref_map[list2], 'list2')

    def test_refmap_update(self):
        """Test RefMap update method."""
        obj1, obj2, obj3 = object(), object(), object()
        ref_map = brainstate.graph.RefMap()
        ref_map[obj1] = 'value1'

        # Update with mapping
        ref_map.update({obj2: 'value2', obj3: 'value3'})
        self.assertEqual(len(ref_map), 3)
        self.assertEqual(ref_map[obj2], 'value2')
        self.assertEqual(ref_map[obj3], 'value3')

        # Update existing key: the value changes but no new entry is added
        ref_map[obj1] = 'new_value1'
        self.assertEqual(ref_map[obj1], 'new_value1')
        self.assertEqual(len(ref_map), 3)

    def test_refmap_str_repr(self):
        """Test RefMap string representation."""
        ref_map = brainstate.graph.RefMap()
        obj = object()
        ref_map[obj] = 'value'

        str_repr = str(ref_map)
        self.assertIsInstance(str_repr, str)
        # Check that __str__ calls __repr__
        self.assertEqual(str_repr, repr(ref_map))
|
580
|
-
|
581
|
-
|
582
|
-
class TestHelperFunctions(unittest.TestCase):
    """Test helper functions in the _operation module."""

    def test_is_state_leaf(self):
        """Test _is_state_leaf function."""
        from brainstate.graph._operation import _is_state_leaf

        # The object returned by ``to_state_ref`` is a state leaf; the State
        # itself and plain Python values are not.
        param = brainstate.ParamState(1)
        state_ref = param.to_state_ref()

        self.assertTrue(_is_state_leaf(state_ref))
        for candidate in (param, 1, "string", None):
            self.assertFalse(_is_state_leaf(candidate))

    def test_is_node_leaf(self):
        """Test _is_node_leaf function."""
        from brainstate.graph._operation import _is_node_leaf

        # A State is a node leaf; plain Python values are not.
        param = brainstate.ParamState(1)

        self.assertTrue(_is_node_leaf(param))
        for candidate in (1, "string", None):
            self.assertFalse(_is_node_leaf(candidate))

    def test_is_node(self):
        """Test _is_node function."""
        from brainstate.graph._operation import _is_node

        # Modules and standard containers both count as nodes.
        self.assertTrue(_is_node(brainstate.nn.Module()))
        for container in ([1, 2, 3], {'a': 1}):
            self.assertTrue(_is_node(container))

        # Scalars and strings do not.
        for scalar in (1, "string"):
            self.assertFalse(_is_node(scalar))

    def test_is_pytree_node(self):
        """Test _is_pytree_node function."""
        from brainstate.graph._operation import _is_pytree_node

        # Lists, dicts and tuples are pytree nodes ...
        for tree in ([1, 2, 3], {'a': 1}, (1, 2)):
            self.assertTrue(_is_pytree_node(tree))

        # ... while scalars, strings and arrays are not.
        for leaf in (1, "string", jnp.array([1, 2])):
            self.assertFalse(_is_pytree_node(leaf))

    def test_is_graph_node(self):
        """Test _is_graph_node function."""
        from brainstate.graph._operation import _is_graph_node

        # A plain class that is deliberately NOT registered with
        # register_graph_node_type.
        class CustomNode:
            pass

        # Modules are graph nodes.
        self.assertTrue(_is_graph_node(brainstate.nn.Module()))

        # Plain containers and unregistered classes are not.
        for candidate in ([1, 2, 3], {'a': 1}, CustomNode()):
            self.assertFalse(_is_graph_node(candidate))
|
654
|
-
|
655
|
-
|
656
|
-
class TestRegisterGraphNodeType(unittest.TestCase):
    """Test register_graph_node_type functionality."""

    def test_register_custom_node_type(self):
        """Test registering a custom graph node type."""
        from brainstate.graph._operation import _is_graph_node, _get_node_impl

        class CustomNode:
            def __init__(self):
                self.data = {}

        # Callbacks passed positionally to register_graph_node_type. The checks
        # below exercise each one through the matching ``node_impl`` attribute
        # (flatten, set_key, pop_key, create_empty, clear).
        def flatten_custom(node):
            return list(node.data.items()), None

        def set_key_custom(node, key, value):
            node.data[key] = value

        def pop_key_custom(node, key):
            return node.data.pop(key)

        def create_empty_custom(metadata):
            return CustomNode()

        def clear_custom(node):
            node.data.clear()

        brainstate.graph.register_graph_node_type(
            CustomNode,
            flatten_custom,
            set_key_custom,
            pop_key_custom,
            create_empty_custom,
            clear_custom,
        )

        # Once registered, instances are recognised as graph nodes.
        node = CustomNode()
        self.assertTrue(_is_graph_node(node))

        node.data['key1'] = 'value1'
        impl = _get_node_impl(node)

        # flatten
        items, _ = impl.flatten(node)
        self.assertEqual(list(items), [('key1', 'value1')])

        # set_key
        impl.set_key(node, 'key2', 'value2')
        self.assertEqual(node.data['key2'], 'value2')

        # pop_key
        popped = impl.pop_key(node, 'key1')
        self.assertEqual(popped, 'value1')
        self.assertNotIn('key1', node.data)

        # create_empty
        fresh = impl.create_empty(None)
        self.assertIsInstance(fresh, CustomNode)
        self.assertEqual(fresh.data, {})

        # clear
        impl.clear(node)
        self.assertEqual(node.data, {})
|
721
|
-
|
722
|
-
|
723
|
-
class TestHashableMapping(unittest.TestCase):
    """Tests for ``HashableMapping``: mapping protocol, hashing and construction."""

    def test_hashable_mapping_basic(self):
        """Test basic HashableMapping operations."""
        from brainstate.graph._operation import HashableMapping

        hm = HashableMapping({'a': 1, 'b': 2})

        # Mapping protocol: length, membership and item access.
        self.assertEqual(len(hm), 2)
        self.assertIn('a', hm)
        self.assertNotIn('c', hm)
        for key, expected in (('a', 1), ('b', 2)):
            self.assertEqual(hm[key], expected)

        # Iteration yields the keys.
        self.assertEqual(set(hm), {'a', 'b'})

    def test_hashable_mapping_hash(self):
        """Test HashableMapping hashing."""
        from brainstate.graph._operation import HashableMapping

        first = HashableMapping({'a': 1, 'b': 2})
        twin = HashableMapping({'a': 1, 'b': 2})
        different = HashableMapping({'a': 1, 'b': 3})

        # Equal contents give equal hashes and compare equal.
        self.assertEqual(hash(first), hash(twin))
        self.assertEqual(first, twin)

        # Different contents compare unequal.
        self.assertNotEqual(first, different)

        # In a set, the two equal mappings collapse into one element.
        self.assertEqual(len({first, twin, different}), 2)

    def test_hashable_mapping_from_iterable(self):
        """Test HashableMapping creation from iterable."""
        from brainstate.graph._operation import HashableMapping

        hm = HashableMapping([('a', 1), ('b', 2)])

        self.assertEqual(len(hm), 2)
        for key, expected in (('a', 1), ('b', 2)):
            self.assertEqual(hm[key], expected)
|
773
|
-
|
774
|
-
|
775
|
-
class TestNodeDefAndNodeRef(unittest.TestCase):
    """Exercise construction of ``NodeRef`` and ``NodeDef`` records."""

    def test_noderef_creation(self):
        """A NodeRef exposes the ``type`` and ``index`` it was built with."""
        module_type = brainstate.nn.Module
        ref = brainstate.graph.NodeRef(type=module_type, index=42)

        self.assertEqual(ref.type, module_type)
        self.assertEqual(ref.index, 42)

    def test_nodedef_creation(self):
        """``NodeDef.create`` keeps scalar fields and wraps pair sequences."""
        from brainstate.graph._operation import HashableMapping

        module_type = brainstate.nn.Module
        definition = brainstate.graph.NodeDef.create(
            type=module_type,
            index=1,
            attributes=('a', 'b'),
            subgraphs=[],
            static_fields=[('static', 'value')],
            leaves=[],
            metadata=None,
            index_mapping=None,
        )

        # Scalar/tuple fields round-trip unchanged.
        self.assertEqual(definition.type, module_type)
        self.assertEqual(definition.index, 1)
        self.assertEqual(definition.attributes, ('a', 'b'))

        # Pair sequences come back as HashableMapping instances.
        for field in (definition.subgraphs, definition.static_fields):
            self.assertIsInstance(field, HashableMapping)
        self.assertEqual(definition.static_fields['static'], 'value')

        self.assertIsNone(definition.metadata)
        self.assertIsNone(definition.index_mapping)

    def test_nodedef_with_index_mapping(self):
        """A supplied ``index_mapping`` is retained and indexable by key."""
        index_mapping = {1: 2, 3: 4}
        definition = brainstate.graph.NodeDef.create(
            type=brainstate.nn.Module,
            index=1,
            attributes=(),
            subgraphs=[],
            static_fields=[],
            leaves=[],
            metadata=None,
            index_mapping=index_mapping,
        )

        self.assertIsNotNone(definition.index_mapping)
        for src, dst in index_mapping.items():
            self.assertEqual(definition.index_mapping[src], dst)
|
828
|
-
|
829
|
-
|
830
|
-
class TestGraphDefAndClone(unittest.TestCase):
    """Tests for ``graph.graphdef`` and ``graph.clone``."""

    def test_graphdef_function(self):
        """``graphdef`` returns the same NodeDef that ``flatten`` produces."""
        layer = brainstate.nn.Linear(2, 3)
        gdef = brainstate.graph.graphdef(layer)

        self.assertIsInstance(gdef, brainstate.graph.NodeDef)
        self.assertEqual(gdef.type, brainstate.nn.Linear)

        # Cross-check against the definition half of ``flatten``.
        flat_gdef, _ = brainstate.graph.flatten(layer)
        self.assertEqual(gdef, flat_gdef)

    def test_clone_function(self):
        """``clone`` yields an independent deep copy of a module."""
        source = brainstate.nn.Linear(2, 3)
        duplicate = brainstate.graph.clone(source)

        self.assertIsInstance(duplicate, brainstate.nn.Linear)
        self.assertIsNot(source, duplicate)
        # The State objects must be distinct as well, not just the module.
        self.assertIsNot(source.weight, duplicate.weight)

        # Snapshot the clone, then perturb every leaf of the original.
        snapshot = duplicate.weight.value['weight'].copy()
        source.weight.value = jax.tree.map(lambda leaf: leaf + 1, source.weight.value)

        self.assertTrue(jnp.allclose(duplicate.weight.value['weight'], snapshot))

    def test_clone_with_shared_variables(self):
        """States aliased inside a module stay aliased inside its clone only."""

        class SharedModel(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.shared_weight = brainstate.ParamState(jnp.ones((2, 2)))
                self.layer1 = brainstate.nn.Linear(2, 2)
                self.layer2 = brainstate.nn.Linear(2, 2)
                # Tie layer2's weight State to layer1's.
                self.layer2.weight = self.layer1.weight

        original = SharedModel()
        duplicate = brainstate.graph.clone(original)

        tied = duplicate.layer1.weight
        self.assertIs(tied, duplicate.layer2.weight)
        self.assertIsNot(tied, original.layer1.weight)
|
883
|
-
|
884
|
-
|
885
|
-
class TestNodesFunction(unittest.TestCase):
    """Tests for ``graph.nodes`` and its filtering options."""

    def test_nodes_without_filters(self):
        """Without filters every node, the root included, is returned."""

        class Model(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = brainstate.nn.Linear(2, 3)
                self.b = brainstate.nn.Linear(3, 4)

        model = Model()
        found = brainstate.graph.nodes(model)
        self.assertIsInstance(found, brainstate.util.FlattedDict)

        paths = [p for p, _ in found.items()]
        # The empty path denotes the root model itself.
        for expected in (('a',), ('b',), ()):
            self.assertIn(expected, paths)

    def test_nodes_with_filter(self):
        """A predicate filter keeps only the matching nodes."""

        class CustomModule(brainstate.nn.Module):
            pass

        class Model(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = brainstate.nn.Linear(2, 3)
                self.custom = CustomModule()

        model = Model()

        def is_linear(path, node):
            return isinstance(node, brainstate.nn.Linear)

        selected = brainstate.graph.nodes(model, is_linear)
        self.assertIsInstance(selected, brainstate.util.FlattedDict)

        # Only the single Linear child survives the filter.
        matches = list(selected.values())
        self.assertEqual(len(matches), 1)
        self.assertIsInstance(matches[0], brainstate.nn.Linear)

    def test_nodes_with_hierarchy(self):
        """``allowed_hierarchy=(1, 1)`` keeps direct children, drops grandchildren."""

        class Model(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.layer1 = brainstate.nn.Linear(2, 3)
                self.layer1.sublayer = brainstate.nn.Linear(3, 3)

        model = Model()
        first_level = brainstate.graph.nodes(model, allowed_hierarchy=(1, 1))
        paths = [p for p, _ in first_level.items()]

        self.assertIn(('layer1',), paths)
        self.assertNotIn(('layer1', 'sublayer'), paths)
|
953
|
-
|
954
|
-
|
955
|
-
class TestStatic(unittest.TestCase):
    """Tests for the ``Static`` wrapper."""

    def test_static_basic(self):
        """``Static`` stores the wrapped object by reference."""
        from brainstate.graph._operation import Static

        value = {'key': 'value'}
        static = Static(value)

        self.assertEqual(static.value, value)
        self.assertIs(static.value, value)

    def test_static_is_pytree_leaf(self):
        """``Static`` contributes no leaves when flattened with JAX.

        Despite the test name, ``Static`` is *not* a pytree leaf. A leaf would
        flatten to ``[static]``, whereas this flattens to an empty list. The
        wrapped value is therefore hidden from JAX leaf traversal (presumably
        carried in the treedef's auxiliary data; that is not asserted here).
        """
        from brainstate.graph._operation import Static

        static = Static({'key': 'value'})

        leaves, _ = jax.tree_util.tree_flatten(static)
        self.assertEqual(len(leaves), 0)

        # Inside a larger structure only the ordinary leaves remain. JAX
        # flattens dicts in sorted-key order, so 'a', then 'b' (no leaves),
        # then 'c'.
        tree = {'a': 1, 'b': static, 'c': [2, 3]}
        leaves, _ = jax.tree_util.tree_flatten(tree)
        self.assertNotIn(static, leaves)
        self.assertEqual(leaves, [1, 2, 3])

    def test_static_equality_and_hash(self):
        """``Static`` compares and hashes by its wrapped value."""
        from brainstate.graph._operation import Static

        static1 = Static(42)
        static2 = Static(42)
        static3 = Static(43)

        # Dataclass frozen=True provides value equality.
        self.assertEqual(static1, static2)
        self.assertNotEqual(static1, static3)

        # Equal objects must hash equally (frozen=True makes them hashable).
        self.assertEqual(hash(static1), hash(static2))
        # NOTE: distinct hashes for unequal objects is not a general __hash__
        # guarantee. It is expected to hold for these specific values.
        self.assertNotEqual(hash(static1), hash(static3))
|
1000
|
-
|
1001
|
-
|
1002
|
-
class TestErrorHandling(unittest.TestCase):
    """Error paths and edge cases of the graph API."""

    def test_flatten_with_invalid_ref_index(self):
        """``flatten`` rejects a ``ref_index`` that is not a ``RefMap``."""
        layer = brainstate.nn.Linear(2, 3)
        # A plain dict is not accepted as the reference index.
        self.assertRaises(AssertionError, brainstate.graph.flatten, layer, ref_index={})

    def test_unflatten_with_invalid_graphdef(self):
        """``unflatten`` rejects a graph definition that is not a GraphDef."""
        empty = brainstate.util.NestedDict({})
        self.assertRaises(AssertionError, brainstate.graph.unflatten, "not_a_graphdef", empty)

    def test_pop_states_without_filters(self):
        """``pop_states`` requires at least one filter."""
        layer = brainstate.nn.Linear(2, 3)
        with self.assertRaisesRegex(ValueError, 'Expected at least one filter'):
            brainstate.graph.pop_states(layer)

    def test_update_states_immutable_node(self):
        """``update_states`` cannot write into an immutable pytree (a tuple)."""
        immutable = (1, 2, brainstate.ParamState(3))
        replacement = brainstate.util.NestedDict({0: brainstate.TreefyState(int, 10)})

        with self.assertRaises(ValueError):
            brainstate.graph.update_states(immutable, replacement)

    def test_get_node_impl_with_state(self):
        """``_get_node_impl`` refuses State objects."""
        from brainstate.graph._operation import _get_node_impl

        param = brainstate.ParamState(1)
        with self.assertRaisesRegex(ValueError, 'State is not a node'):
            _get_node_impl(param)

    def test_split_with_non_exhaustive_filters(self):
        """Every flattened entry must be claimed by some filter."""
        from brainstate.graph._operation import _split_flatted

        flatted = [(('a',), 1), (('b',), 2)]
        # The only filter matches the first entry, leaving ('b',) unclaimed.
        only_first = (lambda path, value: value == 1,)

        with self.assertRaisesRegex(ValueError, 'Non-exhaustive filters'):
            _split_flatted(flatted, only_first)

    def test_invalid_filter_order(self):
        """``...`` is only allowed as the final filter."""
        from brainstate.graph._operation import _filters_to_predicates

        misordered = (..., lambda p, v: True)
        with self.assertRaisesRegex(ValueError, 'can only be used as the last filters'):
            _filters_to_predicates(misordered)
|
1075
|
-
|
1076
|
-
|
1077
|
-
class TestIntegration(unittest.TestCase):
    """End-to-end scenarios that combine several graph operations."""

    def test_complex_graph_operations(self):
        """Split/merge keeps shared references; ``update_states`` writes through."""

        class SubModule(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = brainstate.ParamState(jnp.ones((2, 2)))

        class ComplexModel(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.shared = SubModule()
                self.layer1 = brainstate.nn.Linear(2, 3)
                self.layer2 = brainstate.nn.Linear(3, 4)
                # The same SubModule instance is reachable from three paths.
                self.layer2.shared_ref = self.shared
                self.nested = {
                    'a': brainstate.nn.Linear(4, 5),
                    'b': [brainstate.nn.Linear(5, 6), self.shared],
                }

        model = ComplexModel()

        graph_def, state_tree = brainstate.graph.treefy_split(model)
        rebuilt = brainstate.graph.treefy_merge(graph_def, state_tree)

        shared = rebuilt.shared
        self.assertIs(shared, rebuilt.layer2.shared_ref)
        self.assertIs(shared, rebuilt.nested['b'][1])

        # Double every state value and push the result back into the model.
        doubled = jax.tree.map(lambda leaf: leaf * 2, state_tree)
        brainstate.graph.update_states(model, doubled)

        expected = jnp.ones((2, 2)) * 2
        self.assertTrue(jnp.allclose(model.shared.weight.value, expected))

    def test_recursive_structure(self):
        """Reference cycles between modules survive split/merge."""

        class RecursiveModule(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = brainstate.ParamState(1)
                self.child = None

        parent, child = RecursiveModule(), RecursiveModule()
        # Close the loop: parent -> child -> parent.
        parent.child = child
        child.child = parent

        graph_def, state_tree = brainstate.graph.treefy_split(parent)
        rebuilt = brainstate.graph.treefy_merge(graph_def, state_tree)

        self.assertIsNotNone(rebuilt.child)
        self.assertIs(rebuilt.child.child, rebuilt)
|
1144
|
-
|
1145
|
-
|
1146
|
-
if __name__ == '__main__':
    # Hand control to absl's test runner when this module is executed directly.
    absltest.main()
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import unittest
|
17
|
+
from collections.abc import Callable
|
18
|
+
from threading import Thread
|
19
|
+
|
20
|
+
import jax
|
21
|
+
import jax.numpy as jnp
|
22
|
+
from absl.testing import absltest, parameterized
|
23
|
+
|
24
|
+
import pytest
|
25
|
+
pytest.skip("skipping tests", allow_module_level=True)
|
26
|
+
|
27
|
+
import brainstate
|
28
|
+
import braintools
|
29
|
+
import brainpy
|
30
|
+
|
31
|
+
|
32
|
+
class TestIter(unittest.TestCase):
    """Tests for ``graph.iter_leaf`` and ``graph.iter_node``."""

    @staticmethod
    def _make_model():
        """Build a module tree mixing Linear children, list/dict containers and a nested LIF."""

        class Model(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = brainstate.nn.Linear(1, 2)
                self.b = brainstate.nn.Linear(2, 3)
                self.c = [brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5)]
                self.d = {'x': brainstate.nn.Linear(5, 6), 'y': brainstate.nn.Linear(6, 7)}
                self.b.a = brainpy.LIF(2)

        return Model()

    def test1(self):
        """Leaf and node iteration over the model tree both yield items.

        The original version only printed results, so it could never fail
        except by raising. These assertions require each traversal to
        produce something.
        """
        leaves = list(brainstate.graph.iter_leaf(self._make_model()))
        self.assertTrue(leaves)

        nodes = list(brainstate.graph.iter_node(self._make_model()))
        self.assertTrue(nodes)

        # Each single hierarchy level is expected to be non-empty for this
        # model (presumably a/b at level 1 and b.a, c[*], d[*] at level 2).
        for level in (1, 2):
            limited = list(
                brainstate.graph.iter_node(self._make_model(), allowed_hierarchy=(level, level))
            )
            self.assertTrue(limited, msg=f'no nodes found at hierarchy level {level}')

    def test_iter_leaf_v1(self):
        """A module referenced twice in the graph has its leaves visited once."""

        class Linear(brainstate.nn.Module):
            def __init__(self, din, dout):
                super().__init__()
                self.weight = brainstate.ParamState(brainstate.random.randn(din, dout))
                self.bias = brainstate.ParamState(brainstate.random.randn(dout))
                self.a = 1

        module = Linear(3, 4)
        # Three leaves: presumably weight, bias and the plain attribute ``a``.
        count = sum(1 for _ in brainstate.graph.iter_leaf([module, module]))
        self.assertEqual(count, 3)

    def test_iter_node_v1(self):
        """A model referenced twice in a list has its nodes visited once."""
        model = self._make_model()
        count = sum(1 for _ in brainstate.graph.iter_node([model, model]))
        self.assertEqual(count, 8)
|
87
|
+
|
88
|
+
|
89
|
+
class List(brainstate.nn.Module):
    """Minimal list-like module whose elements live in ``self.items``."""

    def __init__(self, items):
        super().__init__()
        # Copy into a fresh list so the caller's iterable is not aliased.
        self.items = [*items]

    def __getitem__(self, idx):
        """Return the element stored at ``idx``."""
        storage = self.items
        return storage[idx]

    def __setitem__(self, idx, value):
        """Replace the element stored at ``idx``."""
        storage = self.items
        storage[idx] = value
|
99
|
+
|
100
|
+
|
101
|
+
class Dict(brainstate.nn.Module):
    """Minimal dict-like module whose entries live in ``self.items``."""

    def __init__(self, *args, **kwargs):
        super().__init__()
        # Accepts exactly what the built-in ``dict`` constructor accepts.
        mapping = dict(*args, **kwargs)
        self.items = mapping

    def __getitem__(self, key):
        """Return the value stored under ``key``."""
        storage = self.items
        return storage[key]

    def __setitem__(self, key, value):
        """Store ``value`` under ``key``."""
        storage = self.items
        storage[key] = value
|
111
|
+
|
112
|
+
|
113
|
+
class StatefulLinear(brainstate.nn.Module):
    """Affine layer that also counts how many times it has been used."""

    def __init__(self, din, dout):
        super().__init__()
        w_init = brainstate.random.rand(din, dout)
        self.w = brainstate.ParamState(w_init)
        self.b = brainstate.ParamState(jnp.zeros((dout,)))
        # Usage counter, bumped by ``increment`` and by every forward call.
        self.count = brainstate.State(jnp.array(0, dtype=jnp.uint32))

    def increment(self):
        """Advance the usage counter by one."""
        self.count.value = self.count.value + 1

    def __call__(self, x):
        """Bump the counter, then return ``x @ w + b``."""
        self.count.value = self.count.value + 1
        weight = self.w.value
        bias = self.b.value
        return x @ weight + bias
|
126
|
+
|
127
|
+
|
128
|
+
class TestGraphUtils(absltest.TestCase):
    """Flatten/split/merge behaviour for shared, tied and pytree-held states."""

    @staticmethod
    def _make_shared_graph():
        """Return ``(shared, tail, graph)`` where ``graph`` holds ``shared`` twice."""
        shared = {'a': 1, 'b': brainstate.ParamState(2)}
        tail = brainstate.ParamState(4)
        return shared, tail, [shared, 3, shared, tail]

    def test_flatten_treey_state(self):
        """``flatten(..., treefy_state=True)`` returns TreefyState entries in a NestedDict."""
        shared, tail, graph = self._make_shared_graph()

        refmap = brainstate.graph.RefMap()
        _, states = brainstate.graph.flatten(graph, ref_index=refmap, treefy_state=True)

        states[0]['b'].value = 2
        states[3].value = 4

        self.assertIsInstance(states[0]['b'], brainstate.TreefyState)
        self.assertIsInstance(states[3], brainstate.TreefyState)
        self.assertIsInstance(states, brainstate.util.NestedDict)
        # Exactly the two distinct ParamStates end up in the RefMap.
        self.assertEqual(len(refmap), 2)
        self.assertIn(shared['b'], refmap)
        self.assertIn(tail, refmap)

    def test_flatten(self):
        """``flatten(..., treefy_state=False)`` returns ``State`` instances."""
        shared, tail, graph = self._make_shared_graph()

        refmap = brainstate.graph.RefMap()
        _, states = brainstate.graph.flatten(graph, ref_index=refmap, treefy_state=False)

        states[0]['b'].value = 2
        states[3].value = 4

        self.assertIsInstance(states[0]['b'], brainstate.State)
        self.assertIsInstance(states[3], brainstate.State)
        self.assertEqual(len(refmap), 2)
        self.assertIn(shared['b'], refmap)
        self.assertIn(tail, refmap)

    def test_unflatten_pytree(self):
        """Aliasing of a plain dict is not restored by split/merge."""
        _, _, graph = self._make_shared_graph()

        graph_def, refs = brainstate.graph.treefy_split(graph)
        rebuilt = brainstate.graph.treefy_merge(graph_def, refs)

        # Presumably because plain dicts are handled as pytrees, not graph nodes.
        self.assertIsNot(rebuilt[0], rebuilt[2])

    def test_unflatten_empty(self):
        """``unflatten`` with an empty state tree reports the missing key."""
        a = Dict({'a': 1, 'b': brainstate.ParamState(2)})
        g = List([a, 3, a, brainstate.ParamState(4)])

        graph_def, _ = brainstate.graph.treefy_split(g)

        with self.assertRaisesRegex(ValueError, 'Expected key'):
            brainstate.graph.unflatten(graph_def, brainstate.util.NestedDict({}))

    def test_module_list(self):
        """Splitting a list of layers indexes the state tree by list position."""
        layers = [
            brainstate.nn.Linear(2, 2),
            brainstate.nn.BatchNorm1d([10, 2]),
        ]
        _, statetree = brainstate.graph.treefy_split(layers)

        linear, norm = statetree[0], statetree[1]
        self.assertEqual(linear['weight'].value['weight'].shape, (2, 2))
        self.assertEqual(linear['weight'].value['bias'].shape, (2,))
        self.assertEqual(norm['weight'].value['scale'].shape, (1, 2))
        self.assertEqual(norm['weight'].value['bias'].shape, (1, 2))
        self.assertEqual(norm['running_mean'].value.shape, (1, 2))
        self.assertEqual(norm['running_var'].value.shape, (1, 2))

    def test_shared_variables(self):
        """One State referenced twice is split once and re-shared on merge."""
        v = brainstate.ParamState(1)

        graph_def, statetree = brainstate.graph.treefy_split([v, v])
        self.assertEqual(len(statetree.to_flat()), 1)

        rebuilt = brainstate.graph.treefy_merge(graph_def, statetree)
        self.assertIs(rebuilt[0], rebuilt[1])

    def test_tied_weights(self):
        """Tying one Linear's weight to another's survives split/merge."""

        class Foo(brainstate.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bar = brainstate.nn.Linear(2, 2)
                self.baz = brainstate.nn.Linear(2, 2)
                # tie the weights
                self.baz.weight = self.bar.weight

        graph_def, state = brainstate.graph.treefy_split(Foo())
        # The tied weight appears only once in the flat state.
        self.assertEqual(len(state.to_flat()), 1)

        rebuilt = brainstate.graph.treefy_merge(graph_def, state)
        self.assertIs(rebuilt.bar.weight, rebuilt.baz.weight)

    def test_tied_weights_example(self):
        """An embedding table is reused as a transposed output projection."""

        class LinearTranspose(brainstate.nn.Module):
            def __init__(self, dout: int, din: int) -> None:
                super().__init__()
                self.kernel = brainstate.ParamState(braintools.init.LecunNormal()((dout, din)))

            def __call__(self, x):
                return x @ self.kernel.value.T

        class Encoder(brainstate.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.embed = brainstate.nn.Embedding(10, 2)
                self.linear_out = LinearTranspose(10, 2)
                # tie the weights
                self.linear_out.kernel = self.embed.weight

            def __call__(self, x):
                return self.linear_out(self.embed(x))

        model = Encoder()
        _, state = brainstate.graph.treefy_split(model)
        self.assertEqual(len(state.to_flat()), 1)

        tokens = jax.random.randint(jax.random.key(0), (2,), 0, 10)
        self.assertEqual(model(tokens).shape, (2, 10))

    def test_state_variables_not_shared_with_graph(self):
        """Split yields state copies; merge yields fresh State objects."""

        class Foo(brainstate.graph.Node):
            def __init__(self):
                self.a = brainstate.ParamState(1)

        m = Foo()
        graph_def, statetree = brainstate.graph.treefy_split(m)

        def check_detached(node):
            # ``node.a`` is a real ParamState that is distinct from, but
            # value-equal to, the entry in the split state tree.
            self.assertIsInstance(node.a, brainstate.ParamState)
            self.assertTrue(issubclass(statetree.a.type, brainstate.ParamState))
            self.assertIsNot(node.a, statetree.a)
            self.assertEqual(node.a.value, statetree.a.value)

        check_detached(m)
        m2 = brainstate.graph.treefy_merge(graph_def, statetree)
        check_detached(m2)

    def test_shared_state_variables_not_shared_with_graph(self):
        """An aliased State is split once and re-aliased on merge."""

        class Foo(brainstate.graph.Node):
            def __init__(self):
                p = brainstate.ParamState(1)
                self.a = p
                self.b = p

        m = Foo()
        graph_def, state = brainstate.graph.treefy_split(m)

        def check_detached(node):
            # Both attributes are ParamStates distinct from, but value-equal
            # to, the single split entry ``state.a``.
            self.assertIsInstance(node.a, brainstate.ParamState)
            self.assertIsInstance(node.b, brainstate.ParamState)
            self.assertTrue(issubclass(state.a.type, brainstate.ParamState))
            for attr in (node.a, node.b):
                self.assertIsNot(attr, state.a)
                self.assertEqual(attr.value, state.a.value)

        check_detached(m)
        # The alias ``b`` is not stored separately in the split state.
        self.assertNotIn('b', state)

        m2 = brainstate.graph.treefy_merge(graph_def, state)
        check_detached(m2)
        self.assertIs(m2.a, m2.b)

    def test_pytree_node(self):
        """A dataclass pytree attribute is split as a ``PytreeType`` subgraph."""

        @brainstate.util.dataclass
        class Tree:
            a: brainstate.ParamState
            b: str = brainstate.util.field(pytree_node=False)

        class Foo(brainstate.graph.Node):
            def __init__(self):
                self.tree = Tree(brainstate.ParamState(1), 'a')

        m = Foo()
        graph_def, state = brainstate.graph.treefy_split(m)

        self.assertIn('tree', state)
        self.assertIn('a', state.tree)
        self.assertEqual(graph_def.subgraphs['tree'].type.__name__, 'PytreeType')

        m2 = brainstate.graph.treefy_merge(graph_def, state)

        self.assertIsInstance(m2.tree, Tree)
        self.assertEqual(m2.tree.a.value, 1)
        self.assertEqual(m2.tree.b, 'a')
        # Merging builds new objects instead of reusing the originals.
        self.assertIsNot(m2.tree.a, m.tree.a)
        self.assertIsNot(m2.tree, m.tree)
|
330
|
+
|
331
|
+
|
332
|
+
class SimpleModule(brainstate.nn.Module):
    """Empty module used as a minimal fixture by the threading test."""
|
334
|
+
|
335
|
+
|
336
|
+
class SimplePyTreeModule(brainstate.nn.Module):
    """Second empty module fixture for the threading test.

    NOTE(review): this is currently identical to ``SimpleModule``. The name
    suggests it was meant to exercise a pytree-style module; confirm intent.
    """
|
338
|
+
|
339
|
+
|
340
|
+
class TestThreading(parameterized.TestCase):
    """Check that graph splitting works when invoked from a non-main thread."""

    @parameterized.parameters(
        (SimpleModule,),
        (SimplePyTreeModule,),
    )
    def test_threading(self, module_fn: Callable[[], brainstate.nn.Module]):
        """``treefy_split`` succeeds inside a worker thread.

        Exceptions raised in ``Thread.run`` are routed to
        ``threading.excepthook`` and never propagate to ``join()``. Without
        the capture below, a failure in the worker would leave this test
        green.
        """
        x = module_fn()
        errors = []

        class MyThread(Thread):

            def run(self) -> None:
                try:
                    brainstate.graph.treefy_split(x)
                except BaseException as exc:  # re-raised in the main thread below
                    errors.append(exc)

        thread = MyThread()
        thread.start()
        thread.join()

        if errors:
            raise errors[0]
|
357
|
+
|
358
|
+
|
359
|
+
class TestGraphOperation(unittest.TestCase):
|
360
|
+
def test1(self):
    """Smoke test: ``flatten`` handles a Node holding modules, a list and a dict."""

    class MyNode(brainstate.graph.Node):
        def __init__(self):
            self.a = brainstate.nn.Linear(2, 3)
            self.b = brainstate.nn.Linear(3, 2)
            self.c = [brainstate.nn.Linear(1, 2), brainstate.nn.Linear(1, 3)]
            self.d = {'x': brainstate.nn.Linear(1, 3), 'y': brainstate.nn.Linear(1, 4)}

    _, flat_states = brainstate.graph.flatten(MyNode())
    print(flat_states)
|
372
|
+
|
373
|
+
def test_split(self):
    """``treefy_split`` partitions states into ParamState and everything else."""

    class Foo(brainstate.graph.Node):
        def __init__(self):
            self.a = brainstate.nn.Linear(2, 2)
            self.b = brainstate.nn.BatchNorm1d([10, 2])

    node = Foo()
    _, params, others = brainstate.graph.treefy_split(node, brainstate.ParamState, ...)

    print(params)
    # Show the array shapes held by each partition.
    for partition in (params, others):
        print(jax.tree.map(jnp.shape, partition))
|
386
|
+
|
387
|
+
def test_merge(self):
    """Merging the partitions of a split rebuilds a node of the same structure."""

    class Foo(brainstate.graph.Node):
        def __init__(self):
            self.a = brainstate.nn.Linear(2, 2)
            self.b = brainstate.nn.BatchNorm1d([10, 2])

    node = Foo()
    graph_def, params, others = brainstate.graph.treefy_split(node, brainstate.ParamState, ...)

    merged = brainstate.graph.treefy_merge(graph_def, params, others)

    self.assertIsInstance(merged, Foo)
    self.assertIsInstance(merged.b, brainstate.nn.BatchNorm1d)
    self.assertIsInstance(merged.a, brainstate.nn.Linear)
|
401
|
+
|
402
|
+
def test_update_states(self):
    """One manual SGD step over ``model.states()`` lowers the MSE loss."""
    x = jnp.ones((1, 2))
    y = jnp.ones((1, 3))
    model = brainstate.nn.Linear(2, 3)

    def loss_fn(x, y):
        return jnp.mean((y - model(x)) ** 2)

    prev_loss = loss_fn(x, y)
    weights = model.states()
    grads = brainstate.augment.grad(loss_fn, weights)(x, y)

    for key, grad in grads.items():
        state = weights[key]
        # Plain gradient-descent update with learning rate 0.1.
        state.value = jax.tree.map(lambda p, g: p - 0.1 * g, state.value, grad)

    self.assertLess(loss_fn(x, y), prev_loss)
|
420
|
+
|
421
|
+
def test_pop_states(self):
    """States created under ``catch_new_states('new')`` can be popped by tag."""

    class Model(brainstate.nn.Module):
        def __init__(self):
            super().__init__()
            self.a = brainstate.nn.Linear(2, 3)
            self.b = brainpy.LIF([10, 2])

    model = Model()
    with brainstate.catch_new_states('new'):
        brainstate.nn.init_all_states(model)
    # assertEqual reports the actual count on failure; assertTrue(x == n) does not.
    self.assertEqual(len(model.states()), 2)

    brainstate.graph.pop_states(model, 'new')

    # One state remains after removing the tagged ones (presumably the
    # Linear weight), and the LIF no longer carries ``V``.
    self.assertEqual(len(model.states()), 1)
    self.assertFalse(hasattr(model.b, 'V'))
|
438
|
+
|
439
|
+
def test_treefy_split(self):
|
440
|
+
class MLP(brainstate.graph.Node):
|
441
|
+
def __init__(self, din: int, dmid: int, dout: int, n_layer: int = 3):
|
442
|
+
self.input = brainstate.nn.Linear(din, dmid)
|
443
|
+
self.layers = [brainstate.nn.Linear(dmid, dmid) for _ in range(n_layer)]
|
444
|
+
self.output = brainstate.nn.Linear(dmid, dout)
|
445
|
+
|
446
|
+
def __call__(self, x):
|
447
|
+
x = brainstate.functional.relu(self.input(x))
|
448
|
+
for layer in self.layers:
|
449
|
+
x = brainstate.functional.relu(layer(x))
|
450
|
+
return self.output(x)
|
451
|
+
|
452
|
+
model = MLP(2, 1, 3)
|
453
|
+
graph_def, treefy_states = brainstate.graph.treefy_split(model)
|
454
|
+
|
455
|
+
print(graph_def)
|
456
|
+
print(treefy_states)
|
457
|
+
|
458
|
+
# states = brainstate.graph.states(model)
|
459
|
+
# print(states)
|
460
|
+
# nest_states = states.to_nest()
|
461
|
+
# print(nest_states)
|
462
|
+
|
463
|
+
def test_states(self):
|
464
|
+
class MLP(brainstate.graph.Node):
|
465
|
+
def __init__(self, din: int, dmid: int, dout: int, n_layer: int = 3):
|
466
|
+
self.input = brainstate.nn.Linear(din, dmid)
|
467
|
+
self.layers = [brainstate.nn.Linear(dmid, dmid) for _ in range(n_layer)]
|
468
|
+
self.output = brainpy.LIF(dout)
|
469
|
+
|
470
|
+
def __call__(self, x):
|
471
|
+
x = brainstate.functional.relu(self.input(x))
|
472
|
+
for layer in self.layers:
|
473
|
+
x = brainstate.functional.relu(layer(x))
|
474
|
+
return self.output(x)
|
475
|
+
|
476
|
+
model = brainstate.nn.init_all_states(MLP(2, 1, 3))
|
477
|
+
states = brainstate.graph.states(model)
|
478
|
+
print(states)
|
479
|
+
nest_states = states.to_nest()
|
480
|
+
print(nest_states)
|
481
|
+
|
482
|
+
params, others = brainstate.graph.states(model, brainstate.ParamState, brainstate.ShortTermState)
|
483
|
+
print(params)
|
484
|
+
print(others)
|
485
|
+
|
486
|
+
|
487
|
+
class TestRefMap(unittest.TestCase):
|
488
|
+
"""Test RefMap class functionality."""
|
489
|
+
|
490
|
+
def test_refmap_basic_operations(self):
|
491
|
+
"""Test basic RefMap operations."""
|
492
|
+
ref_map = brainstate.graph.RefMap()
|
493
|
+
|
494
|
+
# Test empty RefMap
|
495
|
+
self.assertEqual(len(ref_map), 0)
|
496
|
+
self.assertFalse(object() in ref_map)
|
497
|
+
|
498
|
+
# Test adding items
|
499
|
+
obj1 = object()
|
500
|
+
obj2 = object()
|
501
|
+
ref_map[obj1] = 'value1'
|
502
|
+
ref_map[obj2] = 'value2'
|
503
|
+
|
504
|
+
self.assertEqual(len(ref_map), 2)
|
505
|
+
self.assertTrue(obj1 in ref_map)
|
506
|
+
self.assertTrue(obj2 in ref_map)
|
507
|
+
self.assertEqual(ref_map[obj1], 'value1')
|
508
|
+
self.assertEqual(ref_map[obj2], 'value2')
|
509
|
+
|
510
|
+
# Test iteration
|
511
|
+
keys = list(ref_map)
|
512
|
+
self.assertIn(obj1, keys)
|
513
|
+
self.assertIn(obj2, keys)
|
514
|
+
|
515
|
+
# Test deletion
|
516
|
+
del ref_map[obj1]
|
517
|
+
self.assertEqual(len(ref_map), 1)
|
518
|
+
self.assertFalse(obj1 in ref_map)
|
519
|
+
self.assertTrue(obj2 in ref_map)
|
520
|
+
|
521
|
+
def test_refmap_initialization_with_mapping(self):
|
522
|
+
"""Test RefMap initialization with a mapping."""
|
523
|
+
obj1, obj2 = object(), object()
|
524
|
+
mapping = {obj1: 'value1', obj2: 'value2'}
|
525
|
+
ref_map = brainstate.graph.RefMap(mapping)
|
526
|
+
|
527
|
+
self.assertEqual(len(ref_map), 2)
|
528
|
+
self.assertEqual(ref_map[obj1], 'value1')
|
529
|
+
self.assertEqual(ref_map[obj2], 'value2')
|
530
|
+
|
531
|
+
def test_refmap_initialization_with_iterable(self):
|
532
|
+
"""Test RefMap initialization with an iterable."""
|
533
|
+
obj1, obj2 = object(), object()
|
534
|
+
pairs = [(obj1, 'value1'), (obj2, 'value2')]
|
535
|
+
ref_map = brainstate.graph.RefMap(pairs)
|
536
|
+
|
537
|
+
self.assertEqual(len(ref_map), 2)
|
538
|
+
self.assertEqual(ref_map[obj1], 'value1')
|
539
|
+
self.assertEqual(ref_map[obj2], 'value2')
|
540
|
+
|
541
|
+
def test_refmap_same_object_different_instances(self):
|
542
|
+
"""Test RefMap handles same content objects with different ids."""
|
543
|
+
# Create two lists with same content but different ids
|
544
|
+
list1 = [1, 2, 3]
|
545
|
+
list2 = [1, 2, 3]
|
546
|
+
|
547
|
+
ref_map = brainstate.graph.RefMap()
|
548
|
+
ref_map[list1] = 'list1'
|
549
|
+
ref_map[list2] = 'list2'
|
550
|
+
|
551
|
+
# Should have 2 entries since they have different ids
|
552
|
+
self.assertEqual(len(ref_map), 2)
|
553
|
+
self.assertEqual(ref_map[list1], 'list1')
|
554
|
+
self.assertEqual(ref_map[list2], 'list2')
|
555
|
+
|
556
|
+
def test_refmap_update(self):
|
557
|
+
"""Test RefMap update method."""
|
558
|
+
obj1, obj2, obj3 = object(), object(), object()
|
559
|
+
ref_map = brainstate.graph.RefMap()
|
560
|
+
ref_map[obj1] = 'value1'
|
561
|
+
|
562
|
+
# Update with mapping
|
563
|
+
ref_map.update({obj2: 'value2', obj3: 'value3'})
|
564
|
+
self.assertEqual(len(ref_map), 3)
|
565
|
+
|
566
|
+
# Update existing key
|
567
|
+
ref_map[obj1] = 'new_value1'
|
568
|
+
self.assertEqual(ref_map[obj1], 'new_value1')
|
569
|
+
|
570
|
+
def test_refmap_str_repr(self):
|
571
|
+
"""Test RefMap string representation."""
|
572
|
+
ref_map = brainstate.graph.RefMap()
|
573
|
+
obj = object()
|
574
|
+
ref_map[obj] = 'value'
|
575
|
+
|
576
|
+
str_repr = str(ref_map)
|
577
|
+
self.assertIsInstance(str_repr, str)
|
578
|
+
# Check that __str__ calls __repr__
|
579
|
+
self.assertEqual(str_repr, repr(ref_map))
|
580
|
+
|
581
|
+
|
582
|
+
class TestHelperFunctions(unittest.TestCase):
|
583
|
+
"""Test helper functions in the _operation module."""
|
584
|
+
|
585
|
+
def test_is_state_leaf(self):
|
586
|
+
"""Test _is_state_leaf function."""
|
587
|
+
from brainstate.graph._operation import _is_state_leaf
|
588
|
+
|
589
|
+
# Create TreefyState instance
|
590
|
+
state = brainstate.ParamState(1)
|
591
|
+
treefy_state = state.to_state_ref()
|
592
|
+
|
593
|
+
self.assertTrue(_is_state_leaf(treefy_state))
|
594
|
+
self.assertFalse(_is_state_leaf(state))
|
595
|
+
self.assertFalse(_is_state_leaf(1))
|
596
|
+
self.assertFalse(_is_state_leaf("string"))
|
597
|
+
self.assertFalse(_is_state_leaf(None))
|
598
|
+
|
599
|
+
def test_is_node_leaf(self):
|
600
|
+
"""Test _is_node_leaf function."""
|
601
|
+
from brainstate.graph._operation import _is_node_leaf
|
602
|
+
|
603
|
+
state = brainstate.ParamState(1)
|
604
|
+
|
605
|
+
self.assertTrue(_is_node_leaf(state))
|
606
|
+
self.assertFalse(_is_node_leaf(1))
|
607
|
+
self.assertFalse(_is_node_leaf("string"))
|
608
|
+
self.assertFalse(_is_node_leaf(None))
|
609
|
+
|
610
|
+
def test_is_node(self):
|
611
|
+
"""Test _is_node function."""
|
612
|
+
from brainstate.graph._operation import _is_node
|
613
|
+
|
614
|
+
# Test with graph nodes
|
615
|
+
node = brainstate.nn.Module()
|
616
|
+
self.assertTrue(_is_node(node))
|
617
|
+
|
618
|
+
# Test with pytree nodes
|
619
|
+
self.assertTrue(_is_node([1, 2, 3]))
|
620
|
+
self.assertTrue(_is_node({'a': 1}))
|
621
|
+
|
622
|
+
# Test with non-nodes
|
623
|
+
self.assertFalse(_is_node(1))
|
624
|
+
self.assertFalse(_is_node("string"))
|
625
|
+
|
626
|
+
def test_is_pytree_node(self):
|
627
|
+
"""Test _is_pytree_node function."""
|
628
|
+
from brainstate.graph._operation import _is_pytree_node
|
629
|
+
|
630
|
+
self.assertTrue(_is_pytree_node([1, 2, 3]))
|
631
|
+
self.assertTrue(_is_pytree_node({'a': 1}))
|
632
|
+
self.assertTrue(_is_pytree_node((1, 2)))
|
633
|
+
|
634
|
+
self.assertFalse(_is_pytree_node(1))
|
635
|
+
self.assertFalse(_is_pytree_node("string"))
|
636
|
+
self.assertFalse(_is_pytree_node(jnp.array([1, 2])))
|
637
|
+
|
638
|
+
def test_is_graph_node(self):
|
639
|
+
"""Test _is_graph_node function."""
|
640
|
+
from brainstate.graph._operation import _is_graph_node
|
641
|
+
|
642
|
+
# Register a custom type for testing
|
643
|
+
class CustomNode:
|
644
|
+
pass
|
645
|
+
|
646
|
+
# Graph nodes are those registered with register_graph_node_type
|
647
|
+
node = brainstate.nn.Module()
|
648
|
+
self.assertTrue(_is_graph_node(node))
|
649
|
+
|
650
|
+
# Non-registered types
|
651
|
+
self.assertFalse(_is_graph_node([1, 2, 3]))
|
652
|
+
self.assertFalse(_is_graph_node({'a': 1}))
|
653
|
+
self.assertFalse(_is_graph_node(CustomNode()))
|
654
|
+
|
655
|
+
|
656
|
+
class TestRegisterGraphNodeType(unittest.TestCase):
|
657
|
+
"""Test register_graph_node_type functionality."""
|
658
|
+
|
659
|
+
def test_register_custom_node_type(self):
|
660
|
+
"""Test registering a custom graph node type."""
|
661
|
+
from brainstate.graph._operation import _is_graph_node, _get_node_impl
|
662
|
+
|
663
|
+
class CustomNode:
|
664
|
+
def __init__(self):
|
665
|
+
self.data = {}
|
666
|
+
|
667
|
+
def flatten_custom(node):
|
668
|
+
return list(node.data.items()), None
|
669
|
+
|
670
|
+
def set_key_custom(node, key, value):
|
671
|
+
node.data[key] = value
|
672
|
+
|
673
|
+
def pop_key_custom(node, key):
|
674
|
+
return node.data.pop(key)
|
675
|
+
|
676
|
+
def create_empty_custom(metadata):
|
677
|
+
return CustomNode()
|
678
|
+
|
679
|
+
def clear_custom(node):
|
680
|
+
node.data.clear()
|
681
|
+
|
682
|
+
# Register the custom node type
|
683
|
+
brainstate.graph.register_graph_node_type(
|
684
|
+
CustomNode,
|
685
|
+
flatten_custom,
|
686
|
+
set_key_custom,
|
687
|
+
pop_key_custom,
|
688
|
+
create_empty_custom,
|
689
|
+
clear_custom
|
690
|
+
)
|
691
|
+
|
692
|
+
# Test that the node is recognized
|
693
|
+
node = CustomNode()
|
694
|
+
self.assertTrue(_is_graph_node(node))
|
695
|
+
|
696
|
+
# Test node operations
|
697
|
+
node.data['key1'] = 'value1'
|
698
|
+
node_impl = _get_node_impl(node)
|
699
|
+
|
700
|
+
# Test flatten
|
701
|
+
items, metadata = node_impl.flatten(node)
|
702
|
+
self.assertEqual(list(items), [('key1', 'value1')])
|
703
|
+
|
704
|
+
# Test set_key
|
705
|
+
node_impl.set_key(node, 'key2', 'value2')
|
706
|
+
self.assertEqual(node.data['key2'], 'value2')
|
707
|
+
|
708
|
+
# Test pop_key
|
709
|
+
value = node_impl.pop_key(node, 'key1')
|
710
|
+
self.assertEqual(value, 'value1')
|
711
|
+
self.assertNotIn('key1', node.data)
|
712
|
+
|
713
|
+
# Test create_empty
|
714
|
+
new_node = node_impl.create_empty(None)
|
715
|
+
self.assertIsInstance(new_node, CustomNode)
|
716
|
+
self.assertEqual(new_node.data, {})
|
717
|
+
|
718
|
+
# Test clear
|
719
|
+
node_impl.clear(node)
|
720
|
+
self.assertEqual(node.data, {})
|
721
|
+
|
722
|
+
|
723
|
+
class TestHashableMapping(unittest.TestCase):
|
724
|
+
"""Test HashableMapping class."""
|
725
|
+
|
726
|
+
def test_hashable_mapping_basic(self):
|
727
|
+
"""Test basic HashableMapping operations."""
|
728
|
+
from brainstate.graph._operation import HashableMapping
|
729
|
+
|
730
|
+
mapping = {'a': 1, 'b': 2}
|
731
|
+
hm = HashableMapping(mapping)
|
732
|
+
|
733
|
+
# Test basic operations
|
734
|
+
self.assertEqual(len(hm), 2)
|
735
|
+
self.assertTrue('a' in hm)
|
736
|
+
self.assertFalse('c' in hm)
|
737
|
+
self.assertEqual(hm['a'], 1)
|
738
|
+
self.assertEqual(hm['b'], 2)
|
739
|
+
|
740
|
+
# Test iteration
|
741
|
+
keys = list(hm)
|
742
|
+
self.assertEqual(set(keys), {'a', 'b'})
|
743
|
+
|
744
|
+
def test_hashable_mapping_hash(self):
|
745
|
+
"""Test HashableMapping hashing."""
|
746
|
+
from brainstate.graph._operation import HashableMapping
|
747
|
+
|
748
|
+
hm1 = HashableMapping({'a': 1, 'b': 2})
|
749
|
+
hm2 = HashableMapping({'a': 1, 'b': 2})
|
750
|
+
hm3 = HashableMapping({'a': 1, 'b': 3})
|
751
|
+
|
752
|
+
# Equal mappings should have same hash
|
753
|
+
self.assertEqual(hash(hm1), hash(hm2))
|
754
|
+
self.assertEqual(hm1, hm2)
|
755
|
+
|
756
|
+
# Different mappings should not be equal
|
757
|
+
self.assertNotEqual(hm1, hm3)
|
758
|
+
|
759
|
+
# Can be used in sets
|
760
|
+
s = {hm1, hm2, hm3}
|
761
|
+
self.assertEqual(len(s), 2) # hm1 and hm2 are the same
|
762
|
+
|
763
|
+
def test_hashable_mapping_from_iterable(self):
|
764
|
+
"""Test HashableMapping creation from iterable."""
|
765
|
+
from brainstate.graph._operation import HashableMapping
|
766
|
+
|
767
|
+
pairs = [('a', 1), ('b', 2)]
|
768
|
+
hm = HashableMapping(pairs)
|
769
|
+
|
770
|
+
self.assertEqual(len(hm), 2)
|
771
|
+
self.assertEqual(hm['a'], 1)
|
772
|
+
self.assertEqual(hm['b'], 2)
|
773
|
+
|
774
|
+
|
775
|
+
class TestNodeDefAndNodeRef(unittest.TestCase):
|
776
|
+
"""Test NodeDef and NodeRef classes."""
|
777
|
+
|
778
|
+
def test_noderef_creation(self):
|
779
|
+
"""Test NodeRef creation and attributes."""
|
780
|
+
node_ref = brainstate.graph.NodeRef(
|
781
|
+
type=brainstate.nn.Module,
|
782
|
+
index=42
|
783
|
+
)
|
784
|
+
|
785
|
+
self.assertEqual(node_ref.type, brainstate.nn.Module)
|
786
|
+
self.assertEqual(node_ref.index, 42)
|
787
|
+
|
788
|
+
def test_nodedef_creation(self):
|
789
|
+
"""Test NodeDef creation and attributes."""
|
790
|
+
from brainstate.graph._operation import HashableMapping
|
791
|
+
|
792
|
+
nodedef = brainstate.graph.NodeDef.create(
|
793
|
+
type=brainstate.nn.Module,
|
794
|
+
index=1,
|
795
|
+
attributes=('a', 'b'),
|
796
|
+
subgraphs=[],
|
797
|
+
static_fields=[('static', 'value')],
|
798
|
+
leaves=[],
|
799
|
+
metadata=None,
|
800
|
+
index_mapping=None
|
801
|
+
)
|
802
|
+
|
803
|
+
self.assertEqual(nodedef.type, brainstate.nn.Module)
|
804
|
+
self.assertEqual(nodedef.index, 1)
|
805
|
+
self.assertEqual(nodedef.attributes, ('a', 'b'))
|
806
|
+
self.assertIsInstance(nodedef.subgraphs, HashableMapping)
|
807
|
+
self.assertIsInstance(nodedef.static_fields, HashableMapping)
|
808
|
+
self.assertEqual(nodedef.static_fields['static'], 'value')
|
809
|
+
self.assertIsNone(nodedef.metadata)
|
810
|
+
self.assertIsNone(nodedef.index_mapping)
|
811
|
+
|
812
|
+
def test_nodedef_with_index_mapping(self):
|
813
|
+
"""Test NodeDef with index_mapping."""
|
814
|
+
nodedef = brainstate.graph.NodeDef.create(
|
815
|
+
type=brainstate.nn.Module,
|
816
|
+
index=1,
|
817
|
+
attributes=(),
|
818
|
+
subgraphs=[],
|
819
|
+
static_fields=[],
|
820
|
+
leaves=[],
|
821
|
+
metadata=None,
|
822
|
+
index_mapping={1: 2, 3: 4}
|
823
|
+
)
|
824
|
+
|
825
|
+
self.assertIsNotNone(nodedef.index_mapping)
|
826
|
+
self.assertEqual(nodedef.index_mapping[1], 2)
|
827
|
+
self.assertEqual(nodedef.index_mapping[3], 4)
|
828
|
+
|
829
|
+
|
830
|
+
class TestGraphDefAndClone(unittest.TestCase):
|
831
|
+
"""Test graphdef and clone functions."""
|
832
|
+
|
833
|
+
def test_graphdef_function(self):
|
834
|
+
"""Test graphdef function returns correct GraphDef."""
|
835
|
+
model = brainstate.nn.Linear(2, 3)
|
836
|
+
graphdef = brainstate.graph.graphdef(model)
|
837
|
+
|
838
|
+
self.assertIsInstance(graphdef, brainstate.graph.NodeDef)
|
839
|
+
self.assertEqual(graphdef.type, brainstate.nn.Linear)
|
840
|
+
|
841
|
+
# Compare with flatten result
|
842
|
+
graphdef2, _ = brainstate.graph.flatten(model)
|
843
|
+
self.assertEqual(graphdef, graphdef2)
|
844
|
+
|
845
|
+
def test_clone_function(self):
|
846
|
+
"""Test clone creates a deep copy."""
|
847
|
+
model = brainstate.nn.Linear(2, 3)
|
848
|
+
cloned = brainstate.graph.clone(model)
|
849
|
+
|
850
|
+
# Check types
|
851
|
+
self.assertIsInstance(cloned, brainstate.nn.Linear)
|
852
|
+
self.assertIsNot(model, cloned)
|
853
|
+
|
854
|
+
# Check that states are not shared
|
855
|
+
self.assertIsNot(model.weight, cloned.weight)
|
856
|
+
|
857
|
+
# Modify original and check clone is unaffected
|
858
|
+
original_weight = cloned.weight.value['weight'].copy()
|
859
|
+
model.weight.value = jax.tree.map(lambda x: x + 1, model.weight.value)
|
860
|
+
|
861
|
+
# Clone should be unchanged
|
862
|
+
self.assertTrue(jnp.allclose(cloned.weight.value['weight'], original_weight))
|
863
|
+
|
864
|
+
def test_clone_with_shared_variables(self):
|
865
|
+
"""Test cloning preserves shared variable structure."""
|
866
|
+
|
867
|
+
class SharedModel(brainstate.nn.Module):
|
868
|
+
def __init__(self):
|
869
|
+
super().__init__()
|
870
|
+
self.shared_weight = brainstate.ParamState(jnp.ones((2, 2)))
|
871
|
+
self.layer1 = brainstate.nn.Linear(2, 2)
|
872
|
+
self.layer2 = brainstate.nn.Linear(2, 2)
|
873
|
+
# Share weights
|
874
|
+
self.layer2.weight = self.layer1.weight
|
875
|
+
|
876
|
+
model = SharedModel()
|
877
|
+
cloned = brainstate.graph.clone(model)
|
878
|
+
|
879
|
+
# Check that sharing is preserved
|
880
|
+
self.assertIs(cloned.layer1.weight, cloned.layer2.weight)
|
881
|
+
# But not shared with original
|
882
|
+
self.assertIsNot(cloned.layer1.weight, model.layer1.weight)
|
883
|
+
|
884
|
+
|
885
|
+
class TestNodesFunction(unittest.TestCase):
|
886
|
+
"""Test nodes function for filtering graph nodes."""
|
887
|
+
|
888
|
+
def test_nodes_without_filters(self):
|
889
|
+
"""Test nodes function without filters."""
|
890
|
+
|
891
|
+
class Model(brainstate.nn.Module):
|
892
|
+
def __init__(self):
|
893
|
+
super().__init__()
|
894
|
+
self.a = brainstate.nn.Linear(2, 3)
|
895
|
+
self.b = brainstate.nn.Linear(3, 4)
|
896
|
+
|
897
|
+
model = Model()
|
898
|
+
all_nodes = brainstate.graph.nodes(model)
|
899
|
+
|
900
|
+
# Should return all nodes as FlattedDict
|
901
|
+
self.assertIsInstance(all_nodes, brainstate.util.FlattedDict)
|
902
|
+
|
903
|
+
# Check that nodes are present
|
904
|
+
paths = [path for path, _ in all_nodes.items()]
|
905
|
+
self.assertIn(('a',), paths)
|
906
|
+
self.assertIn(('b',), paths)
|
907
|
+
self.assertIn((), paths) # The model itself
|
908
|
+
|
909
|
+
def test_nodes_with_filter(self):
|
910
|
+
"""Test nodes function with a single filter."""
|
911
|
+
|
912
|
+
class CustomModule(brainstate.nn.Module):
|
913
|
+
pass
|
914
|
+
|
915
|
+
class Model(brainstate.nn.Module):
|
916
|
+
def __init__(self):
|
917
|
+
super().__init__()
|
918
|
+
self.linear = brainstate.nn.Linear(2, 3)
|
919
|
+
self.custom = CustomModule()
|
920
|
+
|
921
|
+
model = Model()
|
922
|
+
|
923
|
+
# Filter for Linear modules
|
924
|
+
linear_nodes = brainstate.graph.nodes(
|
925
|
+
model,
|
926
|
+
lambda path, node: isinstance(node, brainstate.nn.Linear)
|
927
|
+
)
|
928
|
+
|
929
|
+
self.assertIsInstance(linear_nodes, brainstate.util.FlattedDict)
|
930
|
+
# Should only contain the Linear module
|
931
|
+
nodes_list = list(linear_nodes.values())
|
932
|
+
self.assertEqual(len(nodes_list), 1)
|
933
|
+
self.assertIsInstance(nodes_list[0], brainstate.nn.Linear)
|
934
|
+
|
935
|
+
def test_nodes_with_hierarchy(self):
|
936
|
+
"""Test nodes function with hierarchy limits."""
|
937
|
+
|
938
|
+
class Model(brainstate.nn.Module):
|
939
|
+
def __init__(self):
|
940
|
+
super().__init__()
|
941
|
+
self.layer1 = brainstate.nn.Linear(2, 3)
|
942
|
+
self.layer1.sublayer = brainstate.nn.Linear(3, 3)
|
943
|
+
|
944
|
+
model = Model()
|
945
|
+
|
946
|
+
# Get only level 1 nodes
|
947
|
+
level1_nodes = brainstate.graph.nodes(model, allowed_hierarchy=(1, 1))
|
948
|
+
paths = [path for path, _ in level1_nodes.items()]
|
949
|
+
|
950
|
+
self.assertIn(('layer1',), paths)
|
951
|
+
# Sublayer should not be included at level 1
|
952
|
+
self.assertNotIn(('layer1', 'sublayer'), paths)
|
953
|
+
|
954
|
+
|
955
|
+
class TestStatic(unittest.TestCase):
|
956
|
+
"""Test Static class functionality."""
|
957
|
+
|
958
|
+
def test_static_basic(self):
|
959
|
+
"""Test basic Static wrapper."""
|
960
|
+
from brainstate.graph._operation import Static
|
961
|
+
|
962
|
+
value = {'key': 'value'}
|
963
|
+
static = Static(value)
|
964
|
+
|
965
|
+
self.assertEqual(static.value, value)
|
966
|
+
self.assertIs(static.value, value)
|
967
|
+
|
968
|
+
def test_static_is_pytree_leaf(self):
|
969
|
+
"""Test that Static is treated as a pytree leaf."""
|
970
|
+
from brainstate.graph._operation import Static
|
971
|
+
|
972
|
+
static = Static({'key': 'value'})
|
973
|
+
|
974
|
+
# Should be treated as a leaf in pytree operations
|
975
|
+
leaves, treedef = jax.tree_util.tree_flatten(static)
|
976
|
+
self.assertEqual(len(leaves), 0) # Static has no leaves
|
977
|
+
|
978
|
+
# Test in a structure
|
979
|
+
tree = {'a': 1, 'b': static, 'c': [2, 3]}
|
980
|
+
leaves, treedef = jax.tree_util.tree_flatten(tree)
|
981
|
+
|
982
|
+
# static should not be in leaves since it's registered as static
|
983
|
+
self.assertNotIn(static, leaves)
|
984
|
+
|
985
|
+
def test_static_equality_and_hash(self):
|
986
|
+
"""Test Static equality and hashing."""
|
987
|
+
from brainstate.graph._operation import Static
|
988
|
+
|
989
|
+
static1 = Static(42)
|
990
|
+
static2 = Static(42)
|
991
|
+
static3 = Static(43)
|
992
|
+
|
993
|
+
# Dataclass frozen=True provides equality
|
994
|
+
self.assertEqual(static1, static2)
|
995
|
+
self.assertNotEqual(static1, static3)
|
996
|
+
|
997
|
+
# Can be hashed due to frozen=True
|
998
|
+
self.assertEqual(hash(static1), hash(static2))
|
999
|
+
self.assertNotEqual(hash(static1), hash(static3))
|
1000
|
+
|
1001
|
+
|
1002
|
+
class TestErrorHandling(unittest.TestCase):
|
1003
|
+
"""Test error handling and edge cases."""
|
1004
|
+
|
1005
|
+
def test_flatten_with_invalid_ref_index(self):
|
1006
|
+
"""Test flatten with invalid ref_index."""
|
1007
|
+
model = brainstate.nn.Linear(2, 3)
|
1008
|
+
|
1009
|
+
# Should raise assertion error with non-RefMap
|
1010
|
+
with self.assertRaises(AssertionError):
|
1011
|
+
brainstate.graph.flatten(model, ref_index={})
|
1012
|
+
|
1013
|
+
def test_unflatten_with_invalid_graphdef(self):
|
1014
|
+
"""Test unflatten with invalid graphdef."""
|
1015
|
+
state = brainstate.util.NestedDict({})
|
1016
|
+
|
1017
|
+
# Should raise assertion error with non-GraphDef
|
1018
|
+
with self.assertRaises(AssertionError):
|
1019
|
+
brainstate.graph.unflatten("not_a_graphdef", state)
|
1020
|
+
|
1021
|
+
def test_pop_states_without_filters(self):
|
1022
|
+
"""Test pop_states raises error without filters."""
|
1023
|
+
model = brainstate.nn.Linear(2, 3)
|
1024
|
+
|
1025
|
+
with self.assertRaises(ValueError) as context:
|
1026
|
+
brainstate.graph.pop_states(model)
|
1027
|
+
|
1028
|
+
self.assertIn('Expected at least one filter', str(context.exception))
|
1029
|
+
|
1030
|
+
def test_update_states_immutable_node(self):
|
1031
|
+
"""Test update_states on immutable pytree node."""
|
1032
|
+
# Create a pytree node (tuple is immutable)
|
1033
|
+
node = (1, 2, brainstate.ParamState(3))
|
1034
|
+
state = brainstate.util.NestedDict({0: brainstate.TreefyState(int, 10)})
|
1035
|
+
|
1036
|
+
# Should raise ValueError when trying to update immutable node
|
1037
|
+
with self.assertRaises(ValueError):
|
1038
|
+
brainstate.graph.update_states(node, state)
|
1039
|
+
|
1040
|
+
def test_get_node_impl_with_state(self):
|
1041
|
+
"""Test _get_node_impl raises error for State objects."""
|
1042
|
+
from brainstate.graph._operation import _get_node_impl
|
1043
|
+
|
1044
|
+
state = brainstate.ParamState(1)
|
1045
|
+
|
1046
|
+
with self.assertRaises(ValueError) as context:
|
1047
|
+
_get_node_impl(state)
|
1048
|
+
|
1049
|
+
self.assertIn('State is not a node', str(context.exception))
|
1050
|
+
|
1051
|
+
def test_split_with_non_exhaustive_filters(self):
|
1052
|
+
"""Test split with non-exhaustive filters."""
|
1053
|
+
from brainstate.graph._operation import _split_flatted
|
1054
|
+
|
1055
|
+
flatted = [(('a',), 1), (('b',), 2)]
|
1056
|
+
filters = (lambda path, value: value == 1,) # Only matches first item
|
1057
|
+
|
1058
|
+
# Should raise ValueError for non-exhaustive filters
|
1059
|
+
with self.assertRaises(ValueError) as context:
|
1060
|
+
_split_flatted(flatted, filters)
|
1061
|
+
|
1062
|
+
self.assertIn('Non-exhaustive filters', str(context.exception))
|
1063
|
+
|
1064
|
+
def test_invalid_filter_order(self):
|
1065
|
+
"""Test filters with ... not at the end."""
|
1066
|
+
from brainstate.graph._operation import _filters_to_predicates
|
1067
|
+
|
1068
|
+
# ... must be the last filter
|
1069
|
+
filters = (..., lambda p, v: True)
|
1070
|
+
|
1071
|
+
with self.assertRaises(ValueError) as context:
|
1072
|
+
_filters_to_predicates(filters)
|
1073
|
+
|
1074
|
+
self.assertIn('can only be used as the last filters', str(context.exception))
|
1075
|
+
|
1076
|
+
|
1077
|
+
class TestIntegration(unittest.TestCase):
|
1078
|
+
"""Integration tests for complex scenarios."""
|
1079
|
+
|
1080
|
+
def test_complex_graph_operations(self):
|
1081
|
+
"""Test complex graph with multiple levels and shared references."""
|
1082
|
+
|
1083
|
+
class SubModule(brainstate.nn.Module):
|
1084
|
+
def __init__(self):
|
1085
|
+
super().__init__()
|
1086
|
+
self.weight = brainstate.ParamState(jnp.ones((2, 2)))
|
1087
|
+
|
1088
|
+
class ComplexModel(brainstate.nn.Module):
|
1089
|
+
def __init__(self):
|
1090
|
+
super().__init__()
|
1091
|
+
self.shared = SubModule()
|
1092
|
+
self.layer1 = brainstate.nn.Linear(2, 3)
|
1093
|
+
self.layer2 = brainstate.nn.Linear(3, 4)
|
1094
|
+
self.layer2.shared_ref = self.shared # Create a reference
|
1095
|
+
self.nested = {
|
1096
|
+
'a': brainstate.nn.Linear(4, 5),
|
1097
|
+
'b': [brainstate.nn.Linear(5, 6), self.shared] # Another reference
|
1098
|
+
}
|
1099
|
+
|
1100
|
+
model = ComplexModel()
|
1101
|
+
|
1102
|
+
# Test flatten/unflatten preserves structure
|
1103
|
+
graphdef, state = brainstate.graph.treefy_split(model)
|
1104
|
+
reconstructed = brainstate.graph.treefy_merge(graphdef, state)
|
1105
|
+
|
1106
|
+
# Check shared references are preserved
|
1107
|
+
self.assertIs(reconstructed.shared, reconstructed.layer2.shared_ref)
|
1108
|
+
self.assertIs(reconstructed.shared, reconstructed.nested['b'][1])
|
1109
|
+
|
1110
|
+
# Test state updates
|
1111
|
+
new_state = jax.tree.map(lambda x: x * 2, state)
|
1112
|
+
brainstate.graph.update_states(model, new_state)
|
1113
|
+
|
1114
|
+
# Verify updates applied
|
1115
|
+
self.assertTrue(jnp.allclose(
|
1116
|
+
model.shared.weight.value,
|
1117
|
+
jnp.ones((2, 2)) * 2
|
1118
|
+
))
|
1119
|
+
|
1120
|
+
def test_recursive_structure(self):
|
1121
|
+
"""Test handling of recursive/circular references."""
|
1122
|
+
|
1123
|
+
class RecursiveModule(brainstate.nn.Module):
|
1124
|
+
def __init__(self):
|
1125
|
+
super().__init__()
|
1126
|
+
self.weight = brainstate.ParamState(1)
|
1127
|
+
self.child = None
|
1128
|
+
|
1129
|
+
# Create circular reference
|
1130
|
+
parent = RecursiveModule()
|
1131
|
+
child = RecursiveModule()
|
1132
|
+
parent.child = child
|
1133
|
+
child.child = parent # Circular reference
|
1134
|
+
|
1135
|
+
# Should handle circular references without infinite recursion
|
1136
|
+
graphdef, state = brainstate.graph.treefy_split(parent)
|
1137
|
+
|
1138
|
+
# Should be able to reconstruct
|
1139
|
+
reconstructed = brainstate.graph.treefy_merge(graphdef, state)
|
1140
|
+
|
1141
|
+
# Check structure is preserved
|
1142
|
+
self.assertIsNotNone(reconstructed.child)
|
1143
|
+
self.assertIs(reconstructed.child.child, reconstructed)
|
1144
|
+
|
1145
|
+
|
1146
|
+
if __name__ == '__main__':
|
1147
|
+
absltest.main()
|