brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. brainstate/__init__.py +167 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2297 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +2157 -1652
  8. brainstate/_state_test.py +1129 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1620 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1447 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +146 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +635 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +134 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +480 -477
  32. brainstate/nn/_dynamics.py +870 -1267
  33. brainstate/nn/_dynamics_test.py +53 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +391 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
  64. brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
  65. brainstate/random/_impl.py +672 -0
  66. brainstate/random/{_rand_seed.py → _seed.py} +675 -675
  67. brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
  68. brainstate/random/{_rand_state.py → _state.py} +1320 -1617
  69. brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
  70. brainstate/transform/__init__.py +56 -59
  71. brainstate/transform/_ad_checkpoint.py +176 -176
  72. brainstate/transform/_ad_checkpoint_test.py +49 -49
  73. brainstate/transform/_autograd.py +1025 -1025
  74. brainstate/transform/_autograd_test.py +1289 -1289
  75. brainstate/transform/_conditions.py +316 -316
  76. brainstate/transform/_conditions_test.py +220 -220
  77. brainstate/transform/_error_if.py +94 -94
  78. brainstate/transform/_error_if_test.py +52 -52
  79. brainstate/transform/_find_state.py +200 -0
  80. brainstate/transform/_find_state_test.py +84 -0
  81. brainstate/transform/_jit.py +399 -399
  82. brainstate/transform/_jit_test.py +143 -143
  83. brainstate/transform/_loop_collect_return.py +675 -675
  84. brainstate/transform/_loop_collect_return_test.py +58 -58
  85. brainstate/transform/_loop_no_collection.py +283 -283
  86. brainstate/transform/_loop_no_collection_test.py +50 -50
  87. brainstate/transform/_make_jaxpr.py +2176 -2016
  88. brainstate/transform/_make_jaxpr_test.py +1634 -1510
  89. brainstate/transform/_mapping.py +607 -529
  90. brainstate/transform/_mapping_test.py +104 -194
  91. brainstate/transform/_progress_bar.py +255 -255
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
  108. brainstate-0.2.2.dist-info/RECORD +111 -0
  109. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/transform/_eval_shape.py +0 -145
  111. brainstate/transform/_eval_shape_test.py +0 -38
  112. brainstate/transform/_random.py +0 -171
  113. brainstate-0.2.1.dist-info/RECORD +0 -111
  114. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  115. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
@@ -1,1147 +1,1147 @@
1
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- import unittest
17
- from collections.abc import Callable
18
- from threading import Thread
19
-
20
- import jax
21
- import jax.numpy as jnp
22
- from absl.testing import absltest, parameterized
23
-
24
- import pytest
25
- pytest.skip("skipping tests", allow_module_level=True)
26
-
27
- import brainstate
28
- import braintools
29
- import brainpy
30
-
31
-
32
class TestIter(unittest.TestCase):
    """Smoke tests for ``brainstate.graph.iter_leaf`` and ``iter_node``."""

    def test1(self):
        """Print every leaf/node of a nested model under several traversals."""
        class Model(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = brainstate.nn.Linear(1, 2)
                self.b = brainstate.nn.Linear(2, 3)
                self.c = [brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5)]
                self.d = {'x': brainstate.nn.Linear(5, 6), 'y': brainstate.nn.Linear(6, 7)}
                self.b.a = brainpy.LIF(2)

        # (iterator, keyword arguments); every traversal walks a fresh Model.
        traversals = [
            (brainstate.graph.iter_leaf, {}),
            (brainstate.graph.iter_node, {}),
            (brainstate.graph.iter_node, {'allowed_hierarchy': (1, 1)}),
            (brainstate.graph.iter_node, {'allowed_hierarchy': (2, 2)}),
        ]
        for iterate, options in traversals:
            for path, node in iterate(Model(), **options):
                print(path, node)

    def test_iter_leaf_v1(self):
        """Iterating ``[module, module]`` yields 3 leaves in total."""
        class Linear(brainstate.nn.Module):
            def __init__(self, din, dout):
                super().__init__()
                self.weight = brainstate.ParamState(brainstate.random.randn(din, dout))
                self.bias = brainstate.ParamState(brainstate.random.randn(dout))
                self.a = 1

        shared = Linear(3, 4)

        n_leaves = 0
        for n_leaves, (path, leaf) in enumerate(brainstate.graph.iter_leaf([shared, shared]), start=1):
            print(path, type(leaf).__name__)

        assert n_leaves == 3

    def test_iter_node_v1(self):
        """Iterating ``[model, model]`` visits 8 nodes in total."""
        class Model(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = brainstate.nn.Linear(1, 2)
                self.b = brainstate.nn.Linear(2, 3)
                self.c = [brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5)]
                self.d = {'x': brainstate.nn.Linear(5, 6), 'y': brainstate.nn.Linear(6, 7)}
                self.b.a = brainpy.LIF(2)

        model = Model()

        n_nodes = 0
        for n_nodes, (path, node) in enumerate(brainstate.graph.iter_node([model, model]), start=1):
            print(path, node.__class__.__name__)
        assert n_nodes == 8
87
-
88
-
89
class List(brainstate.nn.Module):
    """Minimal list-like module that stores its children in ``self.items``."""

    def __init__(self, items):
        super().__init__()
        # Copy the incoming iterable into a fresh list before attaching it.
        children = list(items)
        self.items = children

    def __getitem__(self, idx):
        """Return the child stored at ``idx``."""
        children = self.items
        return children[idx]

    def __setitem__(self, idx, value):
        """Replace the child stored at ``idx`` with ``value``."""
        children = self.items
        children[idx] = value
99
-
100
-
101
class Dict(brainstate.nn.Module):
    """Minimal dict-like module that stores its children in ``self.items``."""

    def __init__(self, *args, **kwargs):
        super().__init__()
        # Accepts exactly what the ``dict`` constructor accepts.
        contents = dict(*args, **kwargs)
        self.items = contents

    def __getitem__(self, key):
        """Return the child stored under ``key``."""
        contents = self.items
        return contents[key]

    def __setitem__(self, key, value):
        """Store ``value`` under ``key``."""
        contents = self.items
        contents[key] = value
111
-
112
-
113
class StatefulLinear(brainstate.nn.Module):
    """Affine layer ``x @ w + b`` that also counts how often it is used."""

    def __init__(self, din, dout):
        super().__init__()
        self.w = brainstate.ParamState(brainstate.random.rand(din, dout))
        self.b = brainstate.ParamState(jnp.zeros((dout,)))
        # Usage counter, kept as a uint32 scalar state.
        self.count = brainstate.State(jnp.array(0, dtype=jnp.uint32))

    def increment(self):
        """Bump the usage counter by one."""
        self.count.value = self.count.value + 1

    def __call__(self, x):
        """Count this call, then apply the affine map to ``x``."""
        self.increment()
        return x @ self.w.value + self.b.value
126
-
127
-
128
class TestGraphUtils(absltest.TestCase):
    """Tests for graph flattening, treefy split/merge and state sharing.

    Uses TestCase assertion methods (rather than bare ``assert``) so the
    checks produce informative failure messages and are not stripped when
    Python runs with ``-O``.
    """

    def test_flatten_treey_state(self):
        """``flatten(..., treefy_state=True)`` returns ``TreefyState`` leaves."""
        # The method name keeps its historical "treey" spelling so existing
        # test selectors continue to match.
        a = {'a': 1, 'b': brainstate.ParamState(2)}
        g = [a, 3, a, brainstate.ParamState(4)]

        refmap = brainstate.graph.RefMap()
        graphdef, states = brainstate.graph.flatten(g, ref_index=refmap, treefy_state=True)

        states[0]['b'].value = 2
        states[3].value = 4

        self.assertIsInstance(states[0]['b'], brainstate.TreefyState)
        self.assertIsInstance(states[3], brainstate.TreefyState)
        self.assertIsInstance(states, brainstate.util.NestedDict)
        # Exactly the two State objects end up in the reference map.
        self.assertEqual(len(refmap), 2)
        self.assertIn(a['b'], refmap)
        self.assertIn(g[3], refmap)

    def test_flatten(self):
        """``flatten(..., treefy_state=False)`` returns the original ``State`` objects."""
        a = {'a': 1, 'b': brainstate.ParamState(2)}
        g = [a, 3, a, brainstate.ParamState(4)]

        refmap = brainstate.graph.RefMap()
        graphdef, states = brainstate.graph.flatten(g, ref_index=refmap, treefy_state=False)

        states[0]['b'].value = 2
        states[3].value = 4

        self.assertIsInstance(states[0]['b'], brainstate.State)
        self.assertIsInstance(states[3], brainstate.State)
        self.assertEqual(len(refmap), 2)
        self.assertIn(a['b'], refmap)
        self.assertIn(g[3], refmap)

    def test_unflatten_pytree(self):
        """A plain dict referenced twice comes back as two distinct objects."""
        a = {'a': 1, 'b': brainstate.ParamState(2)}
        g = [a, 3, a, brainstate.ParamState(4)]

        graphdef, references = brainstate.graph.treefy_split(g)
        g = brainstate.graph.treefy_merge(graphdef, references)

        self.assertIsNot(g[0], g[2])

    def test_unflatten_empty(self):
        """Unflattening with an empty state mapping raises ``ValueError``."""
        a = Dict({'a': 1, 'b': brainstate.ParamState(2)})
        g = List([a, 3, a, brainstate.ParamState(4)])

        graphdef, references = brainstate.graph.treefy_split(g)

        with self.assertRaisesRegex(ValueError, 'Expected key'):
            brainstate.graph.unflatten(graphdef, brainstate.util.NestedDict({}))

    def test_module_list(self):
        """Split a plain list of modules and check the parameter shapes."""
        ls = [
            brainstate.nn.Linear(2, 2),
            brainstate.nn.BatchNorm1d([10, 2]),
        ]
        graphdef, statetree = brainstate.graph.treefy_split(ls)

        self.assertEqual(statetree[0]['weight'].value['weight'].shape, (2, 2))
        self.assertEqual(statetree[0]['weight'].value['bias'].shape, (2,))
        self.assertEqual(statetree[1]['weight'].value['scale'].shape, (1, 2))
        self.assertEqual(statetree[1]['weight'].value['bias'].shape, (1, 2))
        self.assertEqual(statetree[1]['running_mean'].value.shape, (1, 2))
        self.assertEqual(statetree[1]['running_var'].value.shape, (1, 2))

    def test_shared_variables(self):
        """A state referenced twice is stored once and re-shared on merge."""
        v = brainstate.ParamState(1)
        g = [v, v]

        graphdef, statetree = brainstate.graph.treefy_split(g)
        self.assertEqual(len(statetree.to_flat()), 1)

        g2 = brainstate.graph.treefy_merge(graphdef, statetree)
        self.assertIs(g2[0], g2[1])

    def test_tied_weights(self):
        """Tied weights survive a split/merge round trip as one object."""
        class Foo(brainstate.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bar = brainstate.nn.Linear(2, 2)
                self.baz = brainstate.nn.Linear(2, 2)

                # tie the weights
                self.baz.weight = self.bar.weight

        node = Foo()
        graphdef, state = brainstate.graph.treefy_split(node)

        self.assertEqual(len(state.to_flat()), 1)

        node2 = brainstate.graph.treefy_merge(graphdef, state)

        self.assertIs(node2.bar.weight, node2.baz.weight)

    def test_tied_weights_example(self):
        """An embedding tied to an output projection is split as one state."""
        class LinearTranspose(brainstate.nn.Module):
            def __init__(self, dout: int, din: int) -> None:
                super().__init__()
                self.kernel = brainstate.ParamState(braintools.init.LecunNormal()((dout, din)))

            def __call__(self, x):
                return x @ self.kernel.value.T

        class Encoder(brainstate.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.embed = brainstate.nn.Embedding(10, 2)
                self.linear_out = LinearTranspose(10, 2)

                # tie the weights
                self.linear_out.kernel = self.embed.weight

            def __call__(self, x):
                x = self.embed(x)
                return self.linear_out(x)

        model = Encoder()
        graphdef, state = brainstate.graph.treefy_split(model)

        self.assertEqual(len(state.to_flat()), 1)

        x = jax.random.randint(jax.random.key(0), (2,), 0, 10)
        y = model(x)

        self.assertEqual(y.shape, (2, 10))

    def test_state_variables_not_shared_with_graph(self):
        """The split state tree holds copies, not the node's own State objects."""
        class Foo(brainstate.graph.Node):
            def __init__(self):
                self.a = brainstate.ParamState(1)

        m = Foo()
        graphdef, statetree = brainstate.graph.treefy_split(m)

        self.assertIsInstance(m.a, brainstate.ParamState)
        self.assertTrue(issubclass(statetree.a.type, brainstate.ParamState))
        self.assertIsNot(m.a, statetree.a)
        self.assertEqual(m.a.value, statetree.a.value)

        m2 = brainstate.graph.treefy_merge(graphdef, statetree)

        self.assertIsInstance(m2.a, brainstate.ParamState)
        self.assertTrue(issubclass(statetree.a.type, brainstate.ParamState))
        self.assertIsNot(m2.a, statetree.a)
        self.assertEqual(m2.a.value, statetree.a.value)

    def test_shared_state_variables_not_shared_with_graph(self):
        """A state aliased under two attributes is split once and re-aliased on merge."""
        class Foo(brainstate.graph.Node):
            def __init__(self):
                p = brainstate.ParamState(1)
                self.a = p
                self.b = p

        m = Foo()
        graphdef, state = brainstate.graph.treefy_split(m)

        self.assertIsInstance(m.a, brainstate.ParamState)
        self.assertIsInstance(m.b, brainstate.ParamState)
        self.assertTrue(issubclass(state.a.type, brainstate.ParamState))
        self.assertNotIn('b', state)
        self.assertIsNot(m.a, state.a)
        self.assertIsNot(m.b, state.a)
        self.assertEqual(m.a.value, state.a.value)
        self.assertEqual(m.b.value, state.a.value)

        m2 = brainstate.graph.treefy_merge(graphdef, state)

        self.assertIsInstance(m2.a, brainstate.ParamState)
        self.assertIsInstance(m2.b, brainstate.ParamState)
        self.assertTrue(issubclass(state.a.type, brainstate.ParamState))
        self.assertIsNot(m2.a, state.a)
        self.assertIsNot(m2.b, state.a)
        self.assertEqual(m2.a.value, state.a.value)
        self.assertEqual(m2.b.value, state.a.value)
        self.assertIs(m2.a, m2.b)

    def test_pytree_node(self):
        """A dataclass pytree held by a node round-trips through split/merge."""
        @brainstate.util.dataclass
        class Tree:
            a: brainstate.ParamState
            b: str = brainstate.util.field(pytree_node=False)

        class Foo(brainstate.graph.Node):
            def __init__(self):
                self.tree = Tree(brainstate.ParamState(1), 'a')

        m = Foo()

        graphdef, state = brainstate.graph.treefy_split(m)

        self.assertIn('tree', state)
        self.assertIn('a', state.tree)
        self.assertEqual(graphdef.subgraphs['tree'].type.__name__, 'PytreeType')

        m2 = brainstate.graph.treefy_merge(graphdef, state)

        self.assertIsInstance(m2.tree, Tree)
        self.assertEqual(m2.tree.a.value, 1)
        self.assertEqual(m2.tree.b, 'a')
        self.assertIsNot(m2.tree.a, m.tree.a)
        self.assertIsNot(m2.tree, m.tree)
330
-
331
-
332
class SimpleModule(brainstate.nn.Module):
    """Empty ``brainstate.nn.Module`` subclass; a trivial input for ``TestThreading``."""
    pass
334
-
335
-
336
class SimplePyTreeModule(brainstate.nn.Module):
    """Empty ``brainstate.nn.Module`` subclass used as a ``TestThreading`` input."""
    # NOTE(review): this is currently identical to SimpleModule (same base
    # class, empty body), so the second parameterization of the threading
    # test exercises the same code path. Presumably a pytree-flavoured module
    # was intended here -- TODO confirm.
    pass
338
-
339
-
340
class TestThreading(parameterized.TestCase):
    """Check that graph splitting works when called from a worker thread."""

    @parameterized.parameters(
        (SimpleModule,),
        (SimplePyTreeModule,),
    )
    def test_threading(self, module_fn: Callable[[], brainstate.nn.Module]):
        """Run ``treefy_split`` in a separate thread and surface any error.

        Args:
            module_fn: Zero-argument callable (here: a module class) producing
                the node to split.
        """
        x = module_fn()
        # Exceptions raised inside ``Thread.run`` are not propagated by
        # ``join()`` -- they only reach ``threading.excepthook``. Collect them
        # here so the main thread can fail the test.
        errors = []

        class MyThread(Thread):

            def run(self) -> None:
                try:
                    brainstate.graph.treefy_split(x)
                except Exception as err:
                    errors.append(err)

        thread = MyThread()
        thread.start()
        thread.join()
        if errors:
            raise errors[0]
357
-
358
-
359
class TestGraphOperation(unittest.TestCase):
    """Integration tests for graph flatten/split/merge and state helpers."""

    def test1(self):
        """Smoke test: ``flatten`` accepts nested modules, lists and dicts."""
        class MyNode(brainstate.graph.Node):
            def __init__(self):
                self.a = brainstate.nn.Linear(2, 3)
                self.b = brainstate.nn.Linear(3, 2)
                self.c = [brainstate.nn.Linear(1, 2), brainstate.nn.Linear(1, 3)]
                self.d = {'x': brainstate.nn.Linear(1, 3), 'y': brainstate.nn.Linear(1, 4)}

        graphdef, statetree = brainstate.graph.flatten(MyNode())
        print(statetree)

    def test_split(self):
        """Split a node into parameter states and all remaining states."""
        class Foo(brainstate.graph.Node):
            def __init__(self):
                self.a = brainstate.nn.Linear(2, 2)
                self.b = brainstate.nn.BatchNorm1d([10, 2])

        node = Foo()
        graphdef, params, others = brainstate.graph.treefy_split(node, brainstate.ParamState, ...)

        print(params)
        print(jax.tree.map(jnp.shape, params))

        print(jax.tree.map(jnp.shape, others))

    def test_merge(self):
        """Merging split states rebuilds a node with the original types."""
        class Foo(brainstate.graph.Node):
            def __init__(self):
                self.a = brainstate.nn.Linear(2, 2)
                self.b = brainstate.nn.BatchNorm1d([10, 2])

        node = Foo()
        graphdef, params, others = brainstate.graph.treefy_split(node, brainstate.ParamState, ...)

        new_node = brainstate.graph.treefy_merge(graphdef, params, others)

        self.assertIsInstance(new_node, Foo)
        self.assertIsInstance(new_node.b, brainstate.nn.BatchNorm1d)
        self.assertIsInstance(new_node.a, brainstate.nn.Linear)

    def test_update_states(self):
        """One SGD step on the states of a ``Linear`` layer lowers the loss."""
        x = jnp.ones((1, 2))
        y = jnp.ones((1, 3))
        model = brainstate.nn.Linear(2, 3)

        def loss_fn(x, y):
            return jnp.mean((y - model(x)) ** 2)

        def sgd(ps, gs):
            updates = jax.tree.map(lambda p, g: p - 0.1 * g, ps.value, gs)
            ps.value = updates

        prev_loss = loss_fn(x, y)
        weights = model.states()
        grads = brainstate.augment.grad(loss_fn, weights)(x, y)
        for key, val in grads.items():
            sgd(weights[key], val)
        self.assertLess(loss_fn(x, y), prev_loss)

    def test_pop_states(self):
        """States created under ``catch_new_states`` can be popped by tag."""
        class Model(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = brainstate.nn.Linear(2, 3)
                self.b = brainpy.LIF([10, 2])

        model = Model()
        with brainstate.catch_new_states('new'):
            brainstate.nn.init_all_states(model)
        self.assertEqual(len(model.states()), 2)
        model_states = brainstate.graph.pop_states(model, 'new')
        print(model_states)
        self.assertEqual(len(model.states()), 1)
        # After popping, ``model.b`` must no longer expose ``V``.
        self.assertFalse(hasattr(model.b, 'V'))

    def test_treefy_split(self):
        """Smoke test: ``treefy_split`` on a small MLP node."""
        class MLP(brainstate.graph.Node):
            def __init__(self, din: int, dmid: int, dout: int, n_layer: int = 3):
                self.input = brainstate.nn.Linear(din, dmid)
                self.layers = [brainstate.nn.Linear(dmid, dmid) for _ in range(n_layer)]
                self.output = brainstate.nn.Linear(dmid, dout)

            def __call__(self, x):
                x = brainstate.functional.relu(self.input(x))
                for layer in self.layers:
                    x = brainstate.functional.relu(layer(x))
                return self.output(x)

        model = MLP(2, 1, 3)
        graph_def, treefy_states = brainstate.graph.treefy_split(model)

        print(graph_def)
        print(treefy_states)

    def test_states(self):
        """Smoke test: ``graph.states`` with and without state-type filters."""
        class MLP(brainstate.graph.Node):
            def __init__(self, din: int, dmid: int, dout: int, n_layer: int = 3):
                self.input = brainstate.nn.Linear(din, dmid)
                self.layers = [brainstate.nn.Linear(dmid, dmid) for _ in range(n_layer)]
                self.output = brainpy.LIF(dout)

            def __call__(self, x):
                x = brainstate.functional.relu(self.input(x))
                for layer in self.layers:
                    x = brainstate.functional.relu(layer(x))
                return self.output(x)

        model = brainstate.nn.init_all_states(MLP(2, 1, 3))
        states = brainstate.graph.states(model)
        print(states)
        nest_states = states.to_nest()
        print(nest_states)

        params, others = brainstate.graph.states(model, brainstate.ParamState, brainstate.ShortTermState)
        print(params)
        print(others)
485
-
486
-
487
class TestRefMap(unittest.TestCase):
    """Test RefMap class functionality."""

    def test_refmap_basic_operations(self):
        """Test basic RefMap operations: insert, lookup, iterate, delete."""
        ref_map = brainstate.graph.RefMap()

        # Test empty RefMap
        self.assertEqual(len(ref_map), 0)
        self.assertNotIn(object(), ref_map)

        # Test adding items
        obj1 = object()
        obj2 = object()
        ref_map[obj1] = 'value1'
        ref_map[obj2] = 'value2'

        self.assertEqual(len(ref_map), 2)
        self.assertIn(obj1, ref_map)
        self.assertIn(obj2, ref_map)
        self.assertEqual(ref_map[obj1], 'value1')
        self.assertEqual(ref_map[obj2], 'value2')

        # Test iteration
        keys = list(ref_map)
        self.assertIn(obj1, keys)
        self.assertIn(obj2, keys)

        # Test deletion
        del ref_map[obj1]
        self.assertEqual(len(ref_map), 1)
        self.assertNotIn(obj1, ref_map)
        self.assertIn(obj2, ref_map)

    def test_refmap_initialization_with_mapping(self):
        """Test RefMap initialization with a mapping."""
        obj1, obj2 = object(), object()
        mapping = {obj1: 'value1', obj2: 'value2'}
        ref_map = brainstate.graph.RefMap(mapping)

        self.assertEqual(len(ref_map), 2)
        self.assertEqual(ref_map[obj1], 'value1')
        self.assertEqual(ref_map[obj2], 'value2')

    def test_refmap_initialization_with_iterable(self):
        """Test RefMap initialization with an iterable of (key, value) pairs."""
        obj1, obj2 = object(), object()
        pairs = [(obj1, 'value1'), (obj2, 'value2')]
        ref_map = brainstate.graph.RefMap(pairs)

        self.assertEqual(len(ref_map), 2)
        self.assertEqual(ref_map[obj1], 'value1')
        self.assertEqual(ref_map[obj2], 'value2')

    def test_refmap_same_object_different_instances(self):
        """Test RefMap keeps equal-content objects with different ids apart."""
        # Create two lists with same content but different ids
        list1 = [1, 2, 3]
        list2 = [1, 2, 3]

        ref_map = brainstate.graph.RefMap()
        ref_map[list1] = 'list1'
        ref_map[list2] = 'list2'

        # Should have 2 entries since they have different ids
        self.assertEqual(len(ref_map), 2)
        self.assertEqual(ref_map[list1], 'list1')
        self.assertEqual(ref_map[list2], 'list2')

    def test_refmap_update(self):
        """Test RefMap ``update`` and overwriting an existing key."""
        obj1, obj2, obj3 = object(), object(), object()
        ref_map = brainstate.graph.RefMap()
        ref_map[obj1] = 'value1'

        # Update with mapping: the new entries must be stored with their values.
        ref_map.update({obj2: 'value2', obj3: 'value3'})
        self.assertEqual(len(ref_map), 3)
        self.assertEqual(ref_map[obj2], 'value2')
        self.assertEqual(ref_map[obj3], 'value3')

        # Update existing key: value replaced, no new entry added.
        ref_map[obj1] = 'new_value1'
        self.assertEqual(ref_map[obj1], 'new_value1')
        self.assertEqual(len(ref_map), 3)

    def test_refmap_str_repr(self):
        """Test RefMap string representation."""
        ref_map = brainstate.graph.RefMap()
        obj = object()
        ref_map[obj] = 'value'

        str_repr = str(ref_map)
        self.assertIsInstance(str_repr, str)
        # Check that __str__ calls __repr__
        self.assertEqual(str_repr, repr(ref_map))
580
-
581
-
582
class TestHelperFunctions(unittest.TestCase):
    """Tests for the private predicate helpers in ``brainstate.graph._operation``."""

    def test_is_state_leaf(self):
        """A state reference is a state leaf; the State itself and plain values are not."""
        from brainstate.graph._operation import _is_state_leaf

        param = brainstate.ParamState(1)
        # ``to_state_ref()`` produces the TreefyState form of the state.
        ref = param.to_state_ref()

        self.assertTrue(_is_state_leaf(ref))
        for candidate in (param, 1, "string", None):
            self.assertFalse(_is_state_leaf(candidate))

    def test_is_node_leaf(self):
        """A ``State`` is a node leaf; plain values are not."""
        from brainstate.graph._operation import _is_node_leaf

        param = brainstate.ParamState(1)

        self.assertTrue(_is_node_leaf(param))
        for candidate in (1, "string", None):
            self.assertFalse(_is_node_leaf(candidate))

    def test_is_node(self):
        """Graph nodes and pytree containers are nodes; scalars and strings are not."""
        from brainstate.graph._operation import _is_node

        # Graph node
        self.assertTrue(_is_node(brainstate.nn.Module()))

        # Pytree containers
        for container in ([1, 2, 3], {'a': 1}):
            self.assertTrue(_is_node(container))

        # Non-nodes
        for candidate in (1, "string"):
            self.assertFalse(_is_node(candidate))

    def test_is_pytree_node(self):
        """Lists, dicts and tuples are pytree nodes; scalars, strings and arrays are not."""
        from brainstate.graph._operation import _is_pytree_node

        for container in ([1, 2, 3], {'a': 1}, (1, 2)):
            self.assertTrue(_is_pytree_node(container))

        for candidate in (1, "string", jnp.array([1, 2])):
            self.assertFalse(_is_pytree_node(candidate))

    def test_is_graph_node(self):
        """Only types registered with ``register_graph_node_type`` are graph nodes."""
        from brainstate.graph._operation import _is_graph_node

        # A plain class that is deliberately NOT registered; it serves as a
        # negative example below.
        class CustomNode:
            pass

        self.assertTrue(_is_graph_node(brainstate.nn.Module()))

        # Pytree containers and unregistered classes are not graph nodes.
        for candidate in ([1, 2, 3], {'a': 1}, CustomNode()):
            self.assertFalse(_is_graph_node(candidate))
654
-
655
-
656
class TestRegisterGraphNodeType(unittest.TestCase):
    """Tests for ``brainstate.graph.register_graph_node_type``."""

    def test_register_custom_node_type(self):
        """A registered class is detected as a graph node and its callbacks are used."""
        from brainstate.graph._operation import _is_graph_node, _get_node_impl

        class CustomNode:
            def __init__(self):
                self.data = {}

        def flatten_custom(node):
            return [*node.data.items()], None

        def set_key_custom(node, key, value):
            node.data[key] = value

        def pop_key_custom(node, key):
            return node.data.pop(key)

        def create_empty_custom(metadata):
            return CustomNode()

        def clear_custom(node):
            node.data.clear()

        # Callbacks are passed positionally, in this order.
        # NOTE(review): the registration is never undone in this test;
        # presumably it persists for the rest of the session.
        callbacks = (
            flatten_custom,
            set_key_custom,
            pop_key_custom,
            create_empty_custom,
            clear_custom,
        )
        brainstate.graph.register_graph_node_type(CustomNode, *callbacks)

        instance = CustomNode()
        self.assertTrue(_is_graph_node(instance))

        instance.data['key1'] = 'value1'
        impl = _get_node_impl(instance)

        # flatten
        items, metadata = impl.flatten(instance)
        self.assertEqual(list(items), [('key1', 'value1')])

        # set_key
        impl.set_key(instance, 'key2', 'value2')
        self.assertEqual(instance.data['key2'], 'value2')

        # pop_key
        popped = impl.pop_key(instance, 'key1')
        self.assertEqual(popped, 'value1')
        self.assertNotIn('key1', instance.data)

        # create_empty
        fresh = impl.create_empty(None)
        self.assertIsInstance(fresh, CustomNode)
        self.assertEqual(fresh.data, {})

        # clear
        impl.clear(instance)
        self.assertEqual(instance.data, {})
721
-
722
-
723
class TestHashableMapping(unittest.TestCase):
    """Test HashableMapping class."""

    def test_hashable_mapping_basic(self):
        """Test length, membership, lookup and iteration of HashableMapping."""
        from brainstate.graph._operation import HashableMapping

        mapping = {'a': 1, 'b': 2}
        hm = HashableMapping(mapping)

        # Test basic operations
        self.assertEqual(len(hm), 2)
        self.assertIn('a', hm)
        self.assertNotIn('c', hm)
        self.assertEqual(hm['a'], 1)
        self.assertEqual(hm['b'], 2)

        # Test iteration
        keys = list(hm)
        self.assertEqual(set(keys), {'a', 'b'})

    def test_hashable_mapping_hash(self):
        """Equal HashableMappings hash equally and collapse in a set."""
        from brainstate.graph._operation import HashableMapping

        hm1 = HashableMapping({'a': 1, 'b': 2})
        hm2 = HashableMapping({'a': 1, 'b': 2})
        hm3 = HashableMapping({'a': 1, 'b': 3})

        # Equal mappings should have same hash
        self.assertEqual(hash(hm1), hash(hm2))
        self.assertEqual(hm1, hm2)

        # Different mappings should not be equal
        self.assertNotEqual(hm1, hm3)

        # Can be used in sets
        s = {hm1, hm2, hm3}
        self.assertEqual(len(s), 2)  # hm1 and hm2 are the same

    def test_hashable_mapping_from_iterable(self):
        """Test HashableMapping creation from an iterable of pairs."""
        from brainstate.graph._operation import HashableMapping

        pairs = [('a', 1), ('b', 2)]
        hm = HashableMapping(pairs)

        self.assertEqual(len(hm), 2)
        self.assertEqual(hm['a'], 1)
        self.assertEqual(hm['b'], 2)
773
-
774
-
775
class TestNodeDefAndNodeRef(unittest.TestCase):
    """Construction and attribute access of ``NodeRef`` and ``NodeDef``."""

    def test_noderef_creation(self):
        """``NodeRef`` keeps the type and index it was built with."""
        ref = brainstate.graph.NodeRef(type=brainstate.nn.Module, index=42)

        self.assertEqual(ref.type, brainstate.nn.Module)
        self.assertEqual(ref.index, 42)

    def test_nodedef_creation(self):
        """``NodeDef.create`` wraps subgraphs/static fields in ``HashableMapping``."""
        from brainstate.graph._operation import HashableMapping

        node_def = brainstate.graph.NodeDef.create(
            type=brainstate.nn.Module,
            index=1,
            attributes=('a', 'b'),
            subgraphs=[],
            static_fields=[('static', 'value')],
            leaves=[],
            metadata=None,
            index_mapping=None,
        )

        expected = (
            ('type', brainstate.nn.Module),
            ('index', 1),
            ('attributes', ('a', 'b')),
        )
        for attr, want in expected:
            self.assertEqual(getattr(node_def, attr), want)
        for attr in ('subgraphs', 'static_fields'):
            self.assertIsInstance(getattr(node_def, attr), HashableMapping)
        self.assertEqual(node_def.static_fields['static'], 'value')
        for attr in ('metadata', 'index_mapping'):
            self.assertIsNone(getattr(node_def, attr))

    def test_nodedef_with_index_mapping(self):
        """A dict passed as ``index_mapping`` is exposed for lookup."""
        node_def = brainstate.graph.NodeDef.create(
            type=brainstate.nn.Module,
            index=1,
            attributes=(),
            subgraphs=[],
            static_fields=[],
            leaves=[],
            metadata=None,
            index_mapping={1: 2, 3: 4},
        )

        self.assertIsNotNone(node_def.index_mapping)
        for src, dst in ((1, 2), (3, 4)):
            self.assertEqual(node_def.index_mapping[src], dst)
828
-
829
-
830
class TestGraphDefAndClone(unittest.TestCase):
    """Behaviour of ``graph.graphdef`` and ``graph.clone``."""

    def test_graphdef_function(self):
        """graphdef(model) is a NodeDef identical to the one flatten() yields."""
        layer = brainstate.nn.Linear(2, 3)
        definition = brainstate.graph.graphdef(layer)

        self.assertIsInstance(definition, brainstate.graph.NodeDef)
        self.assertEqual(definition.type, brainstate.nn.Linear)

        flattened_def = brainstate.graph.flatten(layer)[0]
        self.assertEqual(definition, flattened_def)

    def test_clone_function(self):
        """clone() produces an independent copy whose states are not shared."""
        layer = brainstate.nn.Linear(2, 3)
        duplicate = brainstate.graph.clone(layer)

        self.assertIsInstance(duplicate, brainstate.nn.Linear)
        self.assertIsNot(layer, duplicate)
        self.assertIsNot(layer.weight, duplicate.weight)

        # Mutating the source must leave the copy untouched.
        snapshot = duplicate.weight.value['weight'].copy()
        layer.weight.value = jax.tree.map(lambda leaf: leaf + 1, layer.weight.value)
        self.assertTrue(jnp.allclose(duplicate.weight.value['weight'], snapshot))

    def test_clone_with_shared_variables(self):
        """Aliased states stay aliased inside the clone, but not with the source."""

        class SharedModel(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.shared_weight = brainstate.ParamState(jnp.ones((2, 2)))
                self.layer1 = brainstate.nn.Linear(2, 2)
                self.layer2 = brainstate.nn.Linear(2, 2)
                # layer2 reuses layer1's parameter state object
                self.layer2.weight = self.layer1.weight

        source = SharedModel()
        duplicate = brainstate.graph.clone(source)

        self.assertIs(duplicate.layer1.weight, duplicate.layer2.weight)
        self.assertIsNot(duplicate.layer1.weight, source.layer1.weight)
883
-
884
-
885
class TestNodesFunction(unittest.TestCase):
    """Selecting graph nodes with ``graph.nodes``."""

    def test_nodes_without_filters(self):
        """Without filters every node, including the root at path (), is returned."""

        class Model(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = brainstate.nn.Linear(2, 3)
                self.b = brainstate.nn.Linear(3, 4)

        model = Model()
        found_nodes = brainstate.graph.nodes(model)

        self.assertIsInstance(found_nodes, brainstate.util.FlattedDict)

        found_paths = {path for path, _ in found_nodes.items()}
        for expected in (('a',), ('b',), ()):
            self.assertIn(expected, found_paths)

    def test_nodes_with_filter(self):
        """A predicate filter keeps only the matching nodes."""

        class CustomModule(brainstate.nn.Module):
            pass

        class Model(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = brainstate.nn.Linear(2, 3)
                self.custom = CustomModule()

        def is_linear(path, node):
            return isinstance(node, brainstate.nn.Linear)

        model = Model()
        selected = brainstate.graph.nodes(model, is_linear)

        self.assertIsInstance(selected, brainstate.util.FlattedDict)
        matches = list(selected.values())
        self.assertEqual(len(matches), 1)
        self.assertIsInstance(matches[0], brainstate.nn.Linear)

    def test_nodes_with_hierarchy(self):
        """allowed_hierarchy=(1, 1) restricts results to direct children."""

        class Model(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.layer1 = brainstate.nn.Linear(2, 3)
                self.layer1.sublayer = brainstate.nn.Linear(3, 3)

        model = Model()
        first_level = brainstate.graph.nodes(model, allowed_hierarchy=(1, 1))
        found_paths = [path for path, _ in first_level.items()]

        self.assertIn(('layer1',), found_paths)
        # The grandchild lives at depth 2 and must be excluded.
        self.assertNotIn(('layer1', 'sublayer'), found_paths)
953
-
954
-
955
class TestStatic(unittest.TestCase):
    """Test Static class functionality."""

    def test_static_basic(self):
        """Static keeps a reference to the exact object it wraps."""
        from brainstate.graph._operation import Static

        value = {'key': 'value'}
        static = Static(value)

        self.assertEqual(static.value, value)
        self.assertIs(static.value, value)

    def test_static_is_pytree_leaf(self):
        """Static contributes no leaves when a pytree is flattened.

        Despite the method name, ``Static`` is *not* a pytree leaf. The
        assertions below show that ``tree_flatten`` yields zero leaves for it,
        so the wrapped value is never exposed to JAX transformations as array
        data.
        """
        from brainstate.graph._operation import Static

        static = Static({'key': 'value'})

        # On its own, a Static flattens to an empty leaf list.
        leaves, _ = jax.tree_util.tree_flatten(static)
        self.assertEqual(len(leaves), 0)

        # Embedded in a larger structure, it is simply skipped.
        tree = {'a': 1, 'b': static, 'c': [2, 3]}
        leaves, _ = jax.tree_util.tree_flatten(tree)

        self.assertNotIn(static, leaves)
        # JAX flattens dict entries in sorted key order. 'b' adds nothing,
        # so exactly the other values remain.
        self.assertEqual(leaves, [1, 2, 3])

    def test_static_equality_and_hash(self):
        """Static compares and hashes by its wrapped value."""
        from brainstate.graph._operation import Static

        static1 = Static(42)
        static2 = Static(42)
        static3 = Static(43)

        # Dataclass frozen=True provides equality
        self.assertEqual(static1, static2)
        self.assertNotEqual(static1, static3)

        # Can be hashed due to frozen=True. Unequal hashes are not guaranteed
        # in general, but they hold for these small distinct ints.
        self.assertEqual(hash(static1), hash(static2))
        self.assertNotEqual(hash(static1), hash(static3))
1000
-
1001
-
1002
class TestErrorHandling(unittest.TestCase):
    """Invalid inputs to graph operations must fail with the expected errors."""

    def test_flatten_with_invalid_ref_index(self):
        """A plain dict is rejected as ``ref_index`` (a RefMap is required)."""
        layer = brainstate.nn.Linear(2, 3)

        with self.assertRaises(AssertionError):
            brainstate.graph.flatten(layer, ref_index={})

    def test_unflatten_with_invalid_graphdef(self):
        """unflatten rejects a graphdef argument that is not a GraphDef."""
        empty_state = brainstate.util.NestedDict({})

        with self.assertRaises(AssertionError):
            brainstate.graph.unflatten("not_a_graphdef", empty_state)

    def test_pop_states_without_filters(self):
        """pop_states needs at least one filter."""
        layer = brainstate.nn.Linear(2, 3)

        with self.assertRaisesRegex(ValueError, 'Expected at least one filter'):
            brainstate.graph.pop_states(layer)

    def test_update_states_immutable_node(self):
        """update_states cannot write into an immutable pytree node such as a tuple."""
        immutable = (1, 2, brainstate.ParamState(3))
        replacement = brainstate.util.NestedDict({0: brainstate.TreefyState(int, 10)})

        with self.assertRaises(ValueError):
            brainstate.graph.update_states(immutable, replacement)

    def test_get_node_impl_with_state(self):
        """_get_node_impl refuses State objects."""
        from brainstate.graph._operation import _get_node_impl

        param = brainstate.ParamState(1)

        with self.assertRaisesRegex(ValueError, 'State is not a node'):
            _get_node_impl(param)

    def test_split_with_non_exhaustive_filters(self):
        """_split_flatted errors when some entries match no filter."""
        from brainstate.graph._operation import _split_flatted

        def only_ones(path, value):
            return value == 1

        flatted = [(('a',), 1), (('b',), 2)]

        with self.assertRaisesRegex(ValueError, 'Non-exhaustive filters'):
            _split_flatted(flatted, (only_ones,))

    def test_invalid_filter_order(self):
        """``...`` is only accepted as the final filter."""
        from brainstate.graph._operation import _filters_to_predicates

        misordered = (..., lambda p, v: True)

        with self.assertRaisesRegex(ValueError, 'can only be used as the last filters'):
            _filters_to_predicates(misordered)
1075
-
1076
-
1077
class TestIntegration(unittest.TestCase):
    """End-to-end graph scenarios: nesting, shared references and cycles."""

    def test_complex_graph_operations(self):
        """Split/merge preserves aliasing; update_states writes new values back."""

        class SubModule(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = brainstate.ParamState(jnp.ones((2, 2)))

        class ComplexModel(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.shared = SubModule()
                self.layer1 = brainstate.nn.Linear(2, 3)
                self.layer2 = brainstate.nn.Linear(3, 4)
                self.layer2.shared_ref = self.shared  # first alias
                self.nested = {
                    'a': brainstate.nn.Linear(4, 5),
                    'b': [brainstate.nn.Linear(5, 6), self.shared],  # second alias
                }

        model = ComplexModel()

        graphdef, state = brainstate.graph.treefy_split(model)
        rebuilt = brainstate.graph.treefy_merge(graphdef, state)

        # Both aliases must resolve to the single rebuilt submodule.
        for alias in (rebuilt.layer2.shared_ref, rebuilt.nested['b'][1]):
            self.assertIs(rebuilt.shared, alias)

        doubled = jax.tree.map(lambda leaf: leaf * 2, state)
        brainstate.graph.update_states(model, doubled)

        expected = jnp.ones((2, 2)) * 2
        self.assertTrue(jnp.allclose(model.shared.weight.value, expected))

    def test_recursive_structure(self):
        """Cyclic module references survive a split/merge round trip."""

        class RecursiveModule(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = brainstate.ParamState(1)
                self.child = None

        # parent -> child -> parent
        parent = RecursiveModule()
        child = RecursiveModule()
        parent.child = child
        child.child = parent

        graphdef, state = brainstate.graph.treefy_split(parent)
        rebuilt = brainstate.graph.treefy_merge(graphdef, state)

        self.assertIsNotNone(rebuilt.child)
        self.assertIs(rebuilt.child.child, rebuilt)
1144
-
1145
-
1146
if __name__ == '__main__':
    # Allow running this test module directly via absl's test runner.
    absltest.main()
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import unittest
17
+ from collections.abc import Callable
18
+ from threading import Thread
19
+
20
+ import jax
21
+ import jax.numpy as jnp
22
+ from absl.testing import absltest, parameterized
23
+
24
import pytest
# NOTE(review): this unconditionally skips the whole module at import time.
# The imports below and every test class in this file never execute, including
# tests that need no optional dependency. If the skip exists only because an
# optional package (e.g. brainpy) may be missing, consider
# `pytest.importorskip(...)` instead — confirm the intent.
pytest.skip("skipping tests", allow_module_level=True)
26
+
27
+ import brainstate
28
+ import braintools
29
+ import brainpy
30
+
31
+
32
class TestIter(unittest.TestCase):
    """Tests for ``graph.iter_leaf`` and ``graph.iter_node``."""

    def test1(self):
        """Smoke test: both iterators run on a nested model, with and without hierarchy limits."""

        class Model(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = brainstate.nn.Linear(1, 2)
                self.b = brainstate.nn.Linear(2, 3)
                self.c = [brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5)]
                self.d = {'x': brainstate.nn.Linear(5, 6), 'y': brainstate.nn.Linear(6, 7)}
                self.b.a = brainpy.LIF(2)

        for path, node in brainstate.graph.iter_leaf(Model()):
            print(path, node)
        for path, node in brainstate.graph.iter_node(Model()):
            print(path, node)
        for path, node in brainstate.graph.iter_node(Model(), allowed_hierarchy=(1, 1)):
            print(path, node)
        for path, node in brainstate.graph.iter_node(Model(), allowed_hierarchy=(2, 2)):
            print(path, node)

    def test_iter_leaf_v1(self):
        """iter_leaf over a list holding the same module twice yields 3 leaves."""

        class Linear(brainstate.nn.Module):
            def __init__(self, din, dout):
                super().__init__()
                self.weight = brainstate.ParamState(brainstate.random.randn(din, dout))
                self.bias = brainstate.ParamState(brainstate.random.randn(dout))
                self.a = 1

        module = Linear(3, 4)
        # The same module object appears twice. The expected count below
        # presumably reflects it being visited only once.
        graph = [module, module]

        num = 0
        for path, value in brainstate.graph.iter_leaf(graph):
            print(path, type(value).__name__)
            num += 1

        self.assertEqual(num, 3)

    def test_iter_node_v1(self):
        """iter_node over a list holding the same model twice yields 8 nodes."""

        class Model(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = brainstate.nn.Linear(1, 2)
                self.b = brainstate.nn.Linear(2, 3)
                self.c = [brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5)]
                self.d = {'x': brainstate.nn.Linear(5, 6), 'y': brainstate.nn.Linear(6, 7)}
                self.b.a = brainpy.LIF(2)

        model = Model()

        num = 0
        for path, node in brainstate.graph.iter_node([model, model]):
            print(path, node.__class__.__name__)
            num += 1

        self.assertEqual(num, 8)
87
+
88
+
89
class List(brainstate.nn.Module):
    """Indexable module that stores its children in a plain Python list."""

    def __init__(self, items):
        super().__init__()
        self.items = [*items]

    def __getitem__(self, idx):
        storage = self.items
        return storage[idx]

    def __setitem__(self, idx, value):
        storage = self.items
        storage[idx] = value
99
+
100
+
101
class Dict(brainstate.nn.Module):
    """Key-addressable module that stores its children in a plain dict.

    Accepts the same constructor arguments as ``dict``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        contents = dict(*args, **kwargs)
        self.items = contents

    def __getitem__(self, key):
        storage = self.items
        return storage[key]

    def __setitem__(self, key, value):
        storage = self.items
        storage[key] = value
111
+
112
+
113
class StatefulLinear(brainstate.nn.Module):
    """Affine layer that also counts how many times it has been called.

    ``w``/``b`` are ParamStates; ``count`` is a plain State holding a uint32
    scalar that is bumped on every forward pass.
    """

    def __init__(self, din, dout):
        super().__init__()
        self.w = brainstate.ParamState(brainstate.random.rand(din, dout))
        self.b = brainstate.ParamState(jnp.zeros((dout,)))
        self.count = brainstate.State(jnp.array(0, dtype=jnp.uint32))

    def increment(self):
        self.count.value += 1

    def __call__(self, x):
        self.increment()
        weight, bias = self.w.value, self.b.value
        return x @ weight + bias
126
+
127
+
128
class TestGraphUtils(absltest.TestCase):
    """flatten / unflatten and treefy_split / treefy_merge on small graphs.

    Uses ``self.assert*`` rather than bare ``assert``: bare asserts are
    stripped under ``python -O`` and give poor failure messages.
    """

    def test_flatten_treey_state(self):
        """flatten(..., treefy_state=True) yields TreefyState leaves and fills the RefMap."""
        a = {'a': 1, 'b': brainstate.ParamState(2)}
        g = [a, 3, a, brainstate.ParamState(4)]

        refmap = brainstate.graph.RefMap()
        _, states = brainstate.graph.flatten(g, ref_index=refmap, treefy_state=True)

        states[0]['b'].value = 2
        states[3].value = 4

        self.assertIsInstance(states[0]['b'], brainstate.TreefyState)
        self.assertIsInstance(states[3], brainstate.TreefyState)
        self.assertIsInstance(states, brainstate.util.NestedDict)
        # Exactly the two ParamState objects end up in the RefMap.
        self.assertEqual(len(refmap), 2)
        self.assertIn(a['b'], refmap)
        self.assertIn(g[3], refmap)

    def test_flatten(self):
        """flatten(..., treefy_state=False) yields the original State objects."""
        a = {'a': 1, 'b': brainstate.ParamState(2)}
        g = [a, 3, a, brainstate.ParamState(4)]

        refmap = brainstate.graph.RefMap()
        _, states = brainstate.graph.flatten(g, ref_index=refmap, treefy_state=False)

        states[0]['b'].value = 2
        states[3].value = 4

        self.assertIsInstance(states[0]['b'], brainstate.State)
        self.assertIsInstance(states[3], brainstate.State)
        self.assertEqual(len(refmap), 2)
        self.assertIn(a['b'], refmap)
        self.assertIn(g[3], refmap)

    def test_unflatten_pytree(self):
        """A plain dict appearing twice comes back as two distinct objects."""
        a = {'a': 1, 'b': brainstate.ParamState(2)}
        g = [a, 3, a, brainstate.ParamState(4)]

        graphdef, references = brainstate.graph.treefy_split(g)
        g = brainstate.graph.treefy_merge(graphdef, references)

        self.assertIsNot(g[0], g[2])

    def test_unflatten_empty(self):
        """unflatten with an empty state reports the missing key."""
        a = Dict({'a': 1, 'b': brainstate.ParamState(2)})
        g = List([a, 3, a, brainstate.ParamState(4)])

        graphdef, _ = brainstate.graph.treefy_split(g)

        with self.assertRaisesRegex(ValueError, 'Expected key'):
            brainstate.graph.unflatten(graphdef, brainstate.util.NestedDict({}))

    def test_module_list(self):
        """treefy_split on a list of modules exposes each module's state shapes."""
        ls = [
            brainstate.nn.Linear(2, 2),
            brainstate.nn.BatchNorm1d([10, 2]),
        ]
        _, statetree = brainstate.graph.treefy_split(ls)

        self.assertEqual(statetree[0]['weight'].value['weight'].shape, (2, 2))
        self.assertEqual(statetree[0]['weight'].value['bias'].shape, (2,))
        self.assertEqual(statetree[1]['weight'].value['scale'].shape, (1, 2))
        self.assertEqual(statetree[1]['weight'].value['bias'].shape, (1, 2))
        self.assertEqual(statetree[1]['running_mean'].value.shape, (1, 2))
        self.assertEqual(statetree[1]['running_var'].value.shape, (1, 2))

    def test_shared_variables(self):
        """A state referenced twice is stored once and re-shared after merge."""
        v = brainstate.ParamState(1)
        g = [v, v]

        graphdef, statetree = brainstate.graph.treefy_split(g)
        self.assertEqual(len(statetree.to_flat()), 1)

        g2 = brainstate.graph.treefy_merge(graphdef, statetree)
        self.assertIs(g2[0], g2[1])

    def test_tied_weights(self):
        """Tied weights between submodules survive split/merge."""

        class Foo(brainstate.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bar = brainstate.nn.Linear(2, 2)
                self.baz = brainstate.nn.Linear(2, 2)

                # tie the weights
                self.baz.weight = self.bar.weight

        node = Foo()
        graphdef, state = brainstate.graph.treefy_split(node)

        self.assertEqual(len(state.to_flat()), 1)

        node2 = brainstate.graph.treefy_merge(graphdef, state)

        self.assertIs(node2.bar.weight, node2.baz.weight)

    def test_tied_weights_example(self):
        """Embedding/output weight tying yields a single state and a working forward pass."""

        class LinearTranspose(brainstate.nn.Module):
            def __init__(self, dout: int, din: int, ) -> None:
                super().__init__()
                self.kernel = brainstate.ParamState(braintools.init.LecunNormal()((dout, din)))

            def __call__(self, x):
                return x @ self.kernel.value.T

        class Encoder(brainstate.nn.Module):
            def __init__(self, ) -> None:
                super().__init__()
                self.embed = brainstate.nn.Embedding(10, 2)
                self.linear_out = LinearTranspose(10, 2)

                # tie the weights
                self.linear_out.kernel = self.embed.weight

            def __call__(self, x):
                x = self.embed(x)
                return self.linear_out(x)

        model = Encoder()
        _, state = brainstate.graph.treefy_split(model)

        self.assertEqual(len(state.to_flat()), 1)

        x = jax.random.randint(jax.random.key(0), (2,), 0, 10)
        y = model(x)

        self.assertEqual(y.shape, (2, 10))

    def test_state_variables_not_shared_with_graph(self):
        """Split produces TreefyState copies; merge rebuilds fresh ParamStates."""

        class Foo(brainstate.graph.Node):
            def __init__(self):
                self.a = brainstate.ParamState(1)

        m = Foo()
        graphdef, statetree = brainstate.graph.treefy_split(m)

        self.assertIsInstance(m.a, brainstate.ParamState)
        self.assertTrue(issubclass(statetree.a.type, brainstate.ParamState))
        self.assertIsNot(m.a, statetree.a)
        self.assertEqual(m.a.value, statetree.a.value)

        m2 = brainstate.graph.treefy_merge(graphdef, statetree)

        self.assertIsInstance(m2.a, brainstate.ParamState)
        self.assertTrue(issubclass(statetree.a.type, brainstate.ParamState))
        self.assertIsNot(m2.a, statetree.a)
        self.assertEqual(m2.a.value, statetree.a.value)

    def test_shared_state_variables_not_shared_with_graph(self):
        """An attribute-aliased state is stored once and re-aliased after merge."""

        class Foo(brainstate.graph.Node):
            def __init__(self):
                p = brainstate.ParamState(1)
                self.a = p
                self.b = p

        m = Foo()
        graphdef, state = brainstate.graph.treefy_split(m)

        self.assertIsInstance(m.a, brainstate.ParamState)
        self.assertIsInstance(m.b, brainstate.ParamState)
        self.assertTrue(issubclass(state.a.type, brainstate.ParamState))
        self.assertNotIn('b', state)
        self.assertIsNot(m.a, state.a)
        self.assertIsNot(m.b, state.a)
        self.assertEqual(m.a.value, state.a.value)
        self.assertEqual(m.b.value, state.a.value)

        m2 = brainstate.graph.treefy_merge(graphdef, state)

        self.assertIsInstance(m2.a, brainstate.ParamState)
        self.assertIsInstance(m2.b, brainstate.ParamState)
        self.assertTrue(issubclass(state.a.type, brainstate.ParamState))
        self.assertIsNot(m2.a, state.a)
        self.assertIsNot(m2.b, state.a)
        self.assertEqual(m2.a.value, state.a.value)
        self.assertEqual(m2.b.value, state.a.value)
        self.assertIs(m2.a, m2.b)

    def test_pytree_node(self):
        """A dataclass pytree attribute is split as a PytreeType subgraph and rebuilt."""

        @brainstate.util.dataclass
        class Tree:
            a: brainstate.ParamState
            b: str = brainstate.util.field(pytree_node=False)

        class Foo(brainstate.graph.Node):
            def __init__(self):
                self.tree = Tree(brainstate.ParamState(1), 'a')

        m = Foo()

        graphdef, state = brainstate.graph.treefy_split(m)

        self.assertIn('tree', state)
        self.assertIn('a', state.tree)
        self.assertEqual(graphdef.subgraphs['tree'].type.__name__, 'PytreeType')

        m2 = brainstate.graph.treefy_merge(graphdef, state)

        self.assertIsInstance(m2.tree, Tree)
        self.assertEqual(m2.tree.a.value, 1)
        self.assertEqual(m2.tree.b, 'a')
        self.assertIsNot(m2.tree.a, m.tree.a)
        self.assertIsNot(m2.tree, m.tree)
330
+
331
+
332
class SimpleModule(brainstate.nn.Module):
    """Empty module used as a fixture for the threading test."""


class SimplePyTreeModule(brainstate.nn.Module):
    """Second empty fixture for the threading test.

    NOTE(review): this is currently identical to ``SimpleModule`` (same base
    class, no body). The name suggests a pytree-based module was intended —
    confirm.
    """
338
+
339
+
340
class TestThreading(parameterized.TestCase):
    """treefy_split must work when called from a thread other than the main one."""

    @parameterized.parameters(
        (SimpleModule,),
        (SimplePyTreeModule,),
    )
    def test_threading(self, module_fn: Callable[[], brainstate.nn.Module]):
        x = module_fn()
        errors = []

        class MyThread(Thread):

            def run(self) -> None:
                # An exception escaping Thread.run is only reported by
                # threading.excepthook and would NOT fail the test. Record it
                # so the main thread can re-raise it.
                try:
                    brainstate.graph.treefy_split(x)
                except BaseException as err:
                    errors.append(err)

        thread = MyThread()
        thread.start()
        thread.join()

        if errors:
            raise errors[0]
357
+
358
+
359
class TestGraphOperation(unittest.TestCase):
    """Tests for flatten, treefy_split/merge, states and pop_states on small models."""

    def test1(self):
        """Smoke test: flatten a node holding modules in attributes, lists and dicts."""

        class MyNode(brainstate.graph.Node):
            def __init__(self):
                self.a = brainstate.nn.Linear(2, 3)
                self.b = brainstate.nn.Linear(3, 2)
                self.c = [brainstate.nn.Linear(1, 2), brainstate.nn.Linear(1, 3)]
                self.d = {'x': brainstate.nn.Linear(1, 3), 'y': brainstate.nn.Linear(1, 4)}

        _, statetree = brainstate.graph.flatten(MyNode())
        print(statetree)

    def test_split(self):
        """Smoke test: split into ParamState and everything else."""

        class Foo(brainstate.graph.Node):
            def __init__(self):
                self.a = brainstate.nn.Linear(2, 2)
                self.b = brainstate.nn.BatchNorm1d([10, 2])

        node = Foo()
        _, params, others = brainstate.graph.treefy_split(node, brainstate.ParamState, ...)

        print(params)
        print(jax.tree.map(jnp.shape, params))

        print(jax.tree.map(jnp.shape, others))

    def test_merge(self):
        """Merging the split parts rebuilds a node of the original types."""

        class Foo(brainstate.graph.Node):
            def __init__(self):
                self.a = brainstate.nn.Linear(2, 2)
                self.b = brainstate.nn.BatchNorm1d([10, 2])

        node = Foo()
        graphdef, params, others = brainstate.graph.treefy_split(node, brainstate.ParamState, ...)

        new_node = brainstate.graph.treefy_merge(graphdef, params, others)

        self.assertIsInstance(new_node, Foo)
        self.assertIsInstance(new_node.b, brainstate.nn.BatchNorm1d)
        self.assertIsInstance(new_node.a, brainstate.nn.Linear)

    def test_update_states(self):
        """One manual SGD step on the model's states lowers the loss."""
        x = jnp.ones((1, 2))
        y = jnp.ones((1, 3))
        model = brainstate.nn.Linear(2, 3)

        def loss_fn(x, y):
            return jnp.mean((y - model(x)) ** 2)

        def sgd(ps, gs):
            updates = jax.tree.map(lambda p, g: p - 0.1 * g, ps.value, gs)
            ps.value = updates

        prev_loss = loss_fn(x, y)
        weights = model.states()
        grads = brainstate.augment.grad(loss_fn, weights)(x, y)
        for key, val in grads.items():
            sgd(weights[key], val)
        self.assertLess(loss_fn(x, y), prev_loss)

    def test_pop_states(self):
        """States caught under a tag can be popped from the model by that tag."""

        class Model(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = brainstate.nn.Linear(2, 3)
                self.b = brainpy.LIF([10, 2])

        model = Model()
        with brainstate.catch_new_states('new'):
            brainstate.nn.init_all_states(model)
        self.assertEqual(len(model.states()), 2)
        model_states = brainstate.graph.pop_states(model, 'new')
        print(model_states)
        self.assertEqual(len(model.states()), 1)
        self.assertFalse(hasattr(model.b, 'V'))

    def test_treefy_split(self):
        """Smoke test: treefy_split on an MLP built from a graph Node."""

        class MLP(brainstate.graph.Node):
            def __init__(self, din: int, dmid: int, dout: int, n_layer: int = 3):
                self.input = brainstate.nn.Linear(din, dmid)
                self.layers = [brainstate.nn.Linear(dmid, dmid) for _ in range(n_layer)]
                self.output = brainstate.nn.Linear(dmid, dout)

            def __call__(self, x):
                x = brainstate.functional.relu(self.input(x))
                for layer in self.layers:
                    x = brainstate.functional.relu(layer(x))
                return self.output(x)

        model = MLP(2, 1, 3)
        graph_def, treefy_states = brainstate.graph.treefy_split(model)

        print(graph_def)
        print(treefy_states)

    def test_states(self):
        """Smoke test: graph.states with and without type filters."""

        class MLP(brainstate.graph.Node):
            def __init__(self, din: int, dmid: int, dout: int, n_layer: int = 3):
                self.input = brainstate.nn.Linear(din, dmid)
                self.layers = [brainstate.nn.Linear(dmid, dmid) for _ in range(n_layer)]
                self.output = brainpy.LIF(dout)

            def __call__(self, x):
                x = brainstate.functional.relu(self.input(x))
                for layer in self.layers:
                    x = brainstate.functional.relu(layer(x))
                return self.output(x)

        model = brainstate.nn.init_all_states(MLP(2, 1, 3))
        states = brainstate.graph.states(model)
        print(states)
        nest_states = states.to_nest()
        print(nest_states)

        params, others = brainstate.graph.states(model, brainstate.ParamState, brainstate.ShortTermState)
        print(params)
        print(others)
485
+
486
+
487
class TestRefMap(unittest.TestCase):
    """Test RefMap class functionality (an identity-keyed mapping)."""

    def test_refmap_basic_operations(self):
        """Test basic RefMap operations: set, get, contains, iterate, delete."""
        ref_map = brainstate.graph.RefMap()

        # Test empty RefMap
        self.assertEqual(len(ref_map), 0)
        self.assertNotIn(object(), ref_map)

        # Test adding items
        obj1 = object()
        obj2 = object()
        ref_map[obj1] = 'value1'
        ref_map[obj2] = 'value2'

        self.assertEqual(len(ref_map), 2)
        self.assertIn(obj1, ref_map)
        self.assertIn(obj2, ref_map)
        self.assertEqual(ref_map[obj1], 'value1')
        self.assertEqual(ref_map[obj2], 'value2')

        # Test iteration
        keys = list(ref_map)
        self.assertIn(obj1, keys)
        self.assertIn(obj2, keys)

        # Test deletion
        del ref_map[obj1]
        self.assertEqual(len(ref_map), 1)
        self.assertNotIn(obj1, ref_map)
        self.assertIn(obj2, ref_map)

    def test_refmap_initialization_with_mapping(self):
        """Test RefMap initialization with a mapping."""
        obj1, obj2 = object(), object()
        mapping = {obj1: 'value1', obj2: 'value2'}
        ref_map = brainstate.graph.RefMap(mapping)

        self.assertEqual(len(ref_map), 2)
        self.assertEqual(ref_map[obj1], 'value1')
        self.assertEqual(ref_map[obj2], 'value2')

    def test_refmap_initialization_with_iterable(self):
        """Test RefMap initialization with an iterable of (key, value) pairs."""
        obj1, obj2 = object(), object()
        pairs = [(obj1, 'value1'), (obj2, 'value2')]
        ref_map = brainstate.graph.RefMap(pairs)

        self.assertEqual(len(ref_map), 2)
        self.assertEqual(ref_map[obj1], 'value1')
        self.assertEqual(ref_map[obj2], 'value2')

    def test_refmap_same_object_different_instances(self):
        """Equal-but-distinct objects (here, unhashable lists) get separate entries."""
        list1 = [1, 2, 3]
        list2 = [1, 2, 3]

        ref_map = brainstate.graph.RefMap()
        ref_map[list1] = 'list1'
        ref_map[list2] = 'list2'

        # Should have 2 entries since they have different ids
        self.assertEqual(len(ref_map), 2)
        self.assertEqual(ref_map[list1], 'list1')
        self.assertEqual(ref_map[list2], 'list2')

    def test_refmap_update(self):
        """Test RefMap.update and overwriting an existing key."""
        obj1, obj2, obj3 = object(), object(), object()
        ref_map = brainstate.graph.RefMap()
        ref_map[obj1] = 'value1'

        # Update with mapping: new keys are added with their values.
        ref_map.update({obj2: 'value2', obj3: 'value3'})
        self.assertEqual(len(ref_map), 3)
        self.assertEqual(ref_map[obj1], 'value1')
        self.assertEqual(ref_map[obj2], 'value2')
        self.assertEqual(ref_map[obj3], 'value3')

        # Update existing key
        ref_map[obj1] = 'new_value1'
        self.assertEqual(ref_map[obj1], 'new_value1')
        self.assertEqual(len(ref_map), 3)

    def test_refmap_str_repr(self):
        """Test RefMap string representation."""
        ref_map = brainstate.graph.RefMap()
        obj = object()
        ref_map[obj] = 'value'

        str_repr = str(ref_map)
        self.assertIsInstance(str_repr, str)
        # Check that __str__ calls __repr__
        self.assertEqual(str_repr, repr(ref_map))
580
+
581
+
582
class TestHelperFunctions(unittest.TestCase):
    """Test helper functions in the _operation module."""

    def test_is_state_leaf(self):
        """Test _is_state_leaf function."""
        from brainstate.graph._operation import _is_state_leaf

        # to_state_ref() produces the TreefyState reference form of a State.
        state = brainstate.ParamState(1)
        treefy_state = state.to_state_ref()

        self.assertTrue(_is_state_leaf(treefy_state))
        # The live State object itself is not a state leaf.
        self.assertFalse(_is_state_leaf(state))
        for value in (1, "string", None):
            with self.subTest(value=value):
                self.assertFalse(_is_state_leaf(value))

    def test_is_node_leaf(self):
        """Test _is_node_leaf function."""
        from brainstate.graph._operation import _is_node_leaf

        self.assertTrue(_is_node_leaf(brainstate.ParamState(1)))
        for value in (1, "string", None):
            with self.subTest(value=value):
                self.assertFalse(_is_node_leaf(value))

    def test_is_node(self):
        """Test _is_node function."""
        from brainstate.graph._operation import _is_node

        # Graph nodes count as nodes.
        self.assertTrue(_is_node(brainstate.nn.Module()))

        # So do pytree containers.
        for value in ([1, 2, 3], {'a': 1}):
            with self.subTest(value=value):
                self.assertTrue(_is_node(value))

        # Scalars and strings are not nodes.
        for value in (1, "string"):
            with self.subTest(value=value):
                self.assertFalse(_is_node(value))

    def test_is_pytree_node(self):
        """Test _is_pytree_node function."""
        from brainstate.graph._operation import _is_pytree_node

        for value in ([1, 2, 3], {'a': 1}, (1, 2)):
            with self.subTest(value=value):
                self.assertTrue(_is_pytree_node(value))

        # Scalars, strings and JAX arrays are not pytree container nodes.
        for value in (1, "string", jnp.array([1, 2])):
            with self.subTest(value=value):
                self.assertFalse(_is_pytree_node(value))

    def test_is_graph_node(self):
        """Test _is_graph_node function."""
        from brainstate.graph._operation import _is_graph_node

        # A plain class that is deliberately NOT registered with
        # register_graph_node_type; it serves as a negative case below.
        class CustomNode:
            pass

        # Graph nodes are those registered with register_graph_node_type.
        self.assertTrue(_is_graph_node(brainstate.nn.Module()))

        # Plain pytrees and unregistered types are not graph nodes.
        for value in ([1, 2, 3], {'a': 1}, CustomNode()):
            with self.subTest(value=value):
                self.assertFalse(_is_graph_node(value))
+
655
+
656
class TestRegisterGraphNodeType(unittest.TestCase):
    """Exercise brainstate.graph.register_graph_node_type with a user-defined type."""

    def test_register_custom_node_type(self):
        """A registered custom type is recognized and its callbacks are dispatched."""
        from brainstate.graph._operation import _is_graph_node, _get_node_impl

        class CustomNode:
            def __init__(self):
                self.data = {}

        # Callbacks backing the custom node: children live in the `data` dict,
        # and flatten reports no metadata.
        def flatten(node):
            return list(node.data.items()), None

        def set_key(node, key, value):
            node.data[key] = value

        def pop_key(node, key):
            return node.data.pop(key)

        def create_empty(metadata):
            return CustomNode()

        def clear(node):
            node.data.clear()

        brainstate.graph.register_graph_node_type(
            CustomNode, flatten, set_key, pop_key, create_empty, clear,
        )

        node = CustomNode()
        self.assertTrue(_is_graph_node(node))

        node.data['key1'] = 'value1'
        impl = _get_node_impl(node)

        # flatten exposes the current children.
        children, _ = impl.flatten(node)
        self.assertEqual(list(children), [('key1', 'value1')])

        # set_key adds a child.
        impl.set_key(node, 'key2', 'value2')
        self.assertEqual(node.data['key2'], 'value2')

        # pop_key removes a child and returns it.
        popped = impl.pop_key(node, 'key1')
        self.assertEqual(popped, 'value1')
        self.assertNotIn('key1', node.data)

        # create_empty builds a fresh, childless instance.
        fresh = impl.create_empty(None)
        self.assertIsInstance(fresh, CustomNode)
        self.assertEqual(fresh.data, {})

        # clear drops every child.
        impl.clear(node)
        self.assertEqual(node.data, {})
+
722
+
723
class TestHashableMapping(unittest.TestCase):
    """Test HashableMapping class."""

    def test_hashable_mapping_basic(self):
        """Test basic HashableMapping operations."""
        from brainstate.graph._operation import HashableMapping

        hm = HashableMapping({'a': 1, 'b': 2})

        # Mapping protocol: len, membership and item access.
        self.assertEqual(len(hm), 2)
        self.assertIn('a', hm)
        self.assertNotIn('c', hm)
        self.assertEqual(hm['a'], 1)
        self.assertEqual(hm['b'], 2)

        # Iteration yields the keys (order is not asserted).
        self.assertEqual(set(hm), {'a', 'b'})

    def test_hashable_mapping_hash(self):
        """Test HashableMapping hashing."""
        from brainstate.graph._operation import HashableMapping

        hm1 = HashableMapping({'a': 1, 'b': 2})
        hm2 = HashableMapping({'a': 1, 'b': 2})
        hm3 = HashableMapping({'a': 1, 'b': 3})

        # Equal mappings compare equal and hash equally.
        self.assertEqual(hash(hm1), hash(hm2))
        self.assertEqual(hm1, hm2)

        # Different mappings should not be equal.
        self.assertNotEqual(hm1, hm3)

        # Usable as set members; hm1 and hm2 collapse into one element.
        self.assertEqual(len({hm1, hm2, hm3}), 2)

    def test_hashable_mapping_from_iterable(self):
        """Test HashableMapping creation from an iterable of key/value pairs."""
        from brainstate.graph._operation import HashableMapping

        hm = HashableMapping([('a', 1), ('b', 2)])

        self.assertEqual(len(hm), 2)
        self.assertEqual(hm['a'], 1)
        self.assertEqual(hm['b'], 2)
+
774
+
775
class TestNodeDefAndNodeRef(unittest.TestCase):
    """Construction and attribute access of NodeRef and NodeDef."""

    def test_noderef_creation(self):
        """A NodeRef exposes the type and index it was built with."""
        ref = brainstate.graph.NodeRef(type=brainstate.nn.Module, index=42)

        self.assertEqual(ref.type, brainstate.nn.Module)
        self.assertEqual(ref.index, 42)

    def test_nodedef_creation(self):
        """NodeDef.create stores subgraphs and static_fields as HashableMapping."""
        from brainstate.graph._operation import HashableMapping

        definition = brainstate.graph.NodeDef.create(
            type=brainstate.nn.Module,
            index=1,
            attributes=('a', 'b'),
            subgraphs=[],
            static_fields=[('static', 'value')],
            leaves=[],
            metadata=None,
            index_mapping=None,
        )

        self.assertEqual(definition.type, brainstate.nn.Module)
        self.assertEqual(definition.index, 1)
        self.assertEqual(definition.attributes, ('a', 'b'))
        # The pair sequences passed in come back as hashable mappings.
        for field in (definition.subgraphs, definition.static_fields):
            self.assertIsInstance(field, HashableMapping)
        self.assertEqual(definition.static_fields['static'], 'value')
        self.assertIsNone(definition.metadata)
        self.assertIsNone(definition.index_mapping)

    def test_nodedef_with_index_mapping(self):
        """An index_mapping passed to NodeDef.create is retrievable by key."""
        definition = brainstate.graph.NodeDef.create(
            type=brainstate.nn.Module,
            index=1,
            attributes=(),
            subgraphs=[],
            static_fields=[],
            leaves=[],
            metadata=None,
            index_mapping={1: 2, 3: 4},
        )

        self.assertIsNotNone(definition.index_mapping)
        for src, dst in ((1, 2), (3, 4)):
            self.assertEqual(definition.index_mapping[src], dst)
+
829
+
830
class TestGraphDefAndClone(unittest.TestCase):
    """Test graphdef and clone functions."""

    def test_graphdef_function(self):
        """Test graphdef function returns correct GraphDef."""
        model = brainstate.nn.Linear(2, 3)
        graphdef = brainstate.graph.graphdef(model)

        self.assertIsInstance(graphdef, brainstate.graph.NodeDef)
        self.assertEqual(graphdef.type, brainstate.nn.Linear)

        # graphdef() must agree with the structure half of flatten().
        graphdef2, _ = brainstate.graph.flatten(model)
        self.assertEqual(graphdef, graphdef2)

    def test_clone_function(self):
        """Test clone creates a deep copy whose states are independent."""
        model = brainstate.nn.Linear(2, 3)
        cloned = brainstate.graph.clone(model)

        self.assertIsInstance(cloned, brainstate.nn.Linear)
        self.assertIsNot(model, cloned)
        # States must not be shared between original and clone.
        self.assertIsNot(model.weight, cloned.weight)

        # Snapshot the clone's weight, then mutate the original.
        cloned_weight_before = cloned.weight.value['weight'].copy()
        model.weight.value = jax.tree.map(lambda x: x + 1, model.weight.value)

        # The clone must be unaffected by the mutation.
        self.assertTrue(
            jnp.allclose(cloned.weight.value['weight'], cloned_weight_before)
        )

    def test_clone_with_shared_variables(self):
        """Test cloning preserves shared variable structure."""

        class SharedModel(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.layer1 = brainstate.nn.Linear(2, 2)
                self.layer2 = brainstate.nn.Linear(2, 2)
                # layer2 reuses layer1's weight State object.
                self.layer2.weight = self.layer1.weight

        model = SharedModel()
        cloned = brainstate.graph.clone(model)

        # Sharing is preserved inside the clone...
        self.assertIs(cloned.layer1.weight, cloned.layer2.weight)
        # ...but the clone does not share with the original.
        self.assertIsNot(cloned.layer1.weight, model.layer1.weight)
        # The cloned shared weight carries the original values.
        self.assertTrue(jnp.allclose(
            cloned.layer1.weight.value['weight'],
            model.layer1.weight.value['weight'],
        ))
+
884
+
885
class TestNodesFunction(unittest.TestCase):
    """Selecting graph nodes with brainstate.graph.nodes."""

    def test_nodes_without_filters(self):
        """Without filters every node, including the root, is returned."""

        class Model(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = brainstate.nn.Linear(2, 3)
                self.b = brainstate.nn.Linear(3, 4)

        model = Model()
        found = brainstate.graph.nodes(model)
        self.assertIsInstance(found, brainstate.util.FlattedDict)

        paths = list(found.keys())
        # () is the path of the root model itself.
        for expected in (('a',), ('b',), ()):
            self.assertIn(expected, paths)

    def test_nodes_with_filter(self):
        """A single predicate filter keeps only the matching nodes."""

        class CustomModule(brainstate.nn.Module):
            pass

        class Model(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = brainstate.nn.Linear(2, 3)
                self.custom = CustomModule()

        model = Model()

        def is_linear(path, node):
            return isinstance(node, brainstate.nn.Linear)

        selected = brainstate.graph.nodes(model, is_linear)
        self.assertIsInstance(selected, brainstate.util.FlattedDict)

        # Only the Linear submodule passes the predicate.
        matches = list(selected.values())
        self.assertEqual(len(matches), 1)
        self.assertIsInstance(matches[0], brainstate.nn.Linear)

    def test_nodes_with_hierarchy(self):
        """allowed_hierarchy limits which nesting levels are returned."""

        class Model(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.layer1 = brainstate.nn.Linear(2, 3)
                self.layer1.sublayer = brainstate.nn.Linear(3, 3)

        model = Model()

        # (1, 1) keeps only level-1 nodes.
        top_level = brainstate.graph.nodes(model, allowed_hierarchy=(1, 1))
        paths = list(top_level.keys())

        self.assertIn(('layer1',), paths)
        # The level-2 sublayer is excluded.
        self.assertNotIn(('layer1', 'sublayer'), paths)
+
954
+
955
class TestStatic(unittest.TestCase):
    """Test Static class functionality."""

    def test_static_basic(self):
        """Static keeps a reference to the wrapped value (no copy)."""
        from brainstate.graph._operation import Static

        value = {'key': 'value'}
        static = Static(value)

        self.assertEqual(static.value, value)
        self.assertIs(static.value, value)

    def test_static_is_pytree_leaf(self):
        """Test that Static contributes no leaves to pytree operations.

        Despite this test's name, Static is not a pytree *leaf*. Its wrapped
        value is carried in the tree structure as static data. Flattening
        therefore yields zero leaves for it, and unflattening restores it
        unchanged.
        """
        from brainstate.graph._operation import Static

        static = Static({'key': 'value'})

        # On its own, Static flattens to no leaves at all.
        leaves, _ = jax.tree_util.tree_flatten(static)
        self.assertEqual(len(leaves), 0)

        # Inside a larger structure only the ordinary leaves are collected.
        tree = {'a': 1, 'b': static, 'c': [2, 3]}
        leaves, treedef = jax.tree_util.tree_flatten(tree)
        self.assertNotIn(static, leaves)
        self.assertEqual(leaves, [1, 2, 3])

        # The static value lives in the treedef and survives a round trip.
        rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
        self.assertEqual(rebuilt['b'], static)

    def test_static_equality_and_hash(self):
        """Static compares and hashes by its wrapped value."""
        from brainstate.graph._operation import Static

        static1 = Static(42)
        static2 = Static(42)
        static3 = Static(43)

        # Equality compares the wrapped values.
        self.assertEqual(static1, static2)
        self.assertNotEqual(static1, static3)

        # frozen=True makes instances hashable; equal instances hash equally.
        self.assertEqual(hash(static1), hash(static2))
        # Holds for these small ints in CPython, although distinct values are
        # not guaranteed to have distinct hashes in general.
        self.assertNotEqual(hash(static1), hash(static3))
+
1001
+
1002
class TestErrorHandling(unittest.TestCase):
    """Invalid inputs to graph operations must raise the expected errors."""

    def test_flatten_with_invalid_ref_index(self):
        """flatten rejects a ref_index that is not a RefMap."""
        model = brainstate.nn.Linear(2, 3)
        # NOTE(review): the AssertionError presumably comes from an `assert` in
        # the library, which `python -O` strips; verify if tests run optimized.
        with self.assertRaises(AssertionError):
            brainstate.graph.flatten(model, ref_index={})

    def test_unflatten_with_invalid_graphdef(self):
        """unflatten rejects a graphdef that is not a GraphDef."""
        empty_state = brainstate.util.NestedDict({})
        with self.assertRaises(AssertionError):
            brainstate.graph.unflatten("not_a_graphdef", empty_state)

    def test_pop_states_without_filters(self):
        """pop_states requires at least one filter."""
        model = brainstate.nn.Linear(2, 3)
        with self.assertRaisesRegex(ValueError, 'Expected at least one filter'):
            brainstate.graph.pop_states(model)

    def test_update_states_immutable_node(self):
        """update_states cannot write into an immutable pytree node (a tuple)."""
        node = (1, 2, brainstate.ParamState(3))
        state = brainstate.util.NestedDict({0: brainstate.TreefyState(int, 10)})
        with self.assertRaises(ValueError):
            brainstate.graph.update_states(node, state)

    def test_get_node_impl_with_state(self):
        """_get_node_impl refuses State objects."""
        from brainstate.graph._operation import _get_node_impl

        param = brainstate.ParamState(1)
        with self.assertRaisesRegex(ValueError, 'State is not a node'):
            _get_node_impl(param)

    def test_split_with_non_exhaustive_filters(self):
        """_split_flatted fails when some item matches no filter."""
        from brainstate.graph._operation import _split_flatted

        flatted = [(('a',), 1), (('b',), 2)]
        # The single predicate claims only ('a',), leaving ('b',) unmatched.
        only_first = (lambda path, value: value == 1,)
        with self.assertRaisesRegex(ValueError, 'Non-exhaustive filters'):
            _split_flatted(flatted, only_first)

    def test_invalid_filter_order(self):
        """`...` is only accepted as the final filter."""
        from brainstate.graph._operation import _filters_to_predicates

        misordered = (..., lambda p, v: True)
        with self.assertRaisesRegex(ValueError, 'can only be used as the last filters'):
            _filters_to_predicates(misordered)
+
1076
+
1077
class TestIntegration(unittest.TestCase):
    """End-to-end scenarios combining several graph operations."""

    def test_complex_graph_operations(self):
        """A multi-level graph with shared submodules survives split/merge and updates."""

        class SubModule(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = brainstate.ParamState(jnp.ones((2, 2)))

        class ComplexModel(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.shared = SubModule()
                self.layer1 = brainstate.nn.Linear(2, 3)
                self.layer2 = brainstate.nn.Linear(3, 4)
                # Second path to the same SubModule instance.
                self.layer2.shared_ref = self.shared
                # Third path, nested inside plain containers.
                self.nested = {
                    'a': brainstate.nn.Linear(4, 5),
                    'b': [brainstate.nn.Linear(5, 6), self.shared],
                }

        model = ComplexModel()

        graphdef, state = brainstate.graph.treefy_split(model)
        rebuilt = brainstate.graph.treefy_merge(graphdef, state)

        # Every alias of `shared` must resolve to one object in the rebuilt graph.
        self.assertIs(rebuilt.shared, rebuilt.layer2.shared_ref)
        self.assertIs(rebuilt.shared, rebuilt.nested['b'][1])

        # Doubling every leaf and writing it back updates the original model.
        doubled = jax.tree.map(lambda x: x * 2, state)
        brainstate.graph.update_states(model, doubled)

        expected = jnp.ones((2, 2)) * 2
        self.assertTrue(jnp.allclose(model.shared.weight.value, expected))

    def test_recursive_structure(self):
        """Circular parent/child references split and merge without infinite recursion."""

        class RecursiveModule(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = brainstate.ParamState(1)
                self.child = None

        parent, child = RecursiveModule(), RecursiveModule()
        # Close the loop: parent -> child -> parent.
        parent.child = child
        child.child = parent

        graphdef, state = brainstate.graph.treefy_split(parent)
        rebuilt = brainstate.graph.treefy_merge(graphdef, state)

        self.assertIsNotNone(rebuilt.child)
        # The cycle must be reproduced, not unrolled.
        self.assertIs(rebuilt.child.child, rebuilt)
+
1145
+
1146
if __name__ == '__main__':
    # Run this module's test cases through absl's test runner.
    absltest.main()