brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +169 -58
  2. brainstate/_compatible_import.py +340 -148
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +45 -55
  7. brainstate/_state.py +1652 -1605
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -563
  11. brainstate/environ_test.py +1223 -62
  12. brainstate/graph/__init__.py +22 -29
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1433 -365
  18. brainstate/mixin_test.py +1017 -77
  19. brainstate/nn/__init__.py +137 -135
  20. brainstate/nn/_activations.py +1100 -808
  21. brainstate/nn/_activations_test.py +354 -331
  22. brainstate/nn/_collective_ops.py +633 -514
  23. brainstate/nn/_collective_ops_test.py +774 -43
  24. brainstate/nn/_common.py +226 -178
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +2010 -501
  27. brainstate/nn/_conv_test.py +849 -238
  28. brainstate/nn/_delay.py +575 -588
  29. brainstate/nn/_delay_test.py +243 -238
  30. brainstate/nn/_dropout.py +618 -426
  31. brainstate/nn/_dropout_test.py +477 -100
  32. brainstate/nn/_dynamics.py +1267 -1343
  33. brainstate/nn/_dynamics_test.py +67 -78
  34. brainstate/nn/_elementwise.py +1298 -1119
  35. brainstate/nn/_elementwise_test.py +830 -169
  36. brainstate/nn/_embedding.py +408 -58
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
  42. brainstate/nn/_exp_euler.py +254 -92
  43. brainstate/nn/_exp_euler_test.py +377 -35
  44. brainstate/nn/_linear.py +744 -424
  45. brainstate/nn/_linear_test.py +475 -107
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +384 -377
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -975
  51. brainstate/nn/_normalizations_test.py +699 -73
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +2239 -1177
  55. brainstate/nn/_poolings_test.py +953 -217
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +216 -89
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +809 -553
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
  62. brainstate/random/__init__.py +270 -24
  63. brainstate/random/_rand_funs.py +3938 -3616
  64. brainstate/random/_rand_funs_test.py +640 -567
  65. brainstate/random/_rand_seed.py +675 -210
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1409
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
  72. brainstate/{augment → transform}/_autograd.py +1025 -778
  73. brainstate/{augment → transform}/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +220 -220
  76. brainstate/{compile → transform}/_error_if.py +94 -92
  77. brainstate/{compile → transform}/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +38 -38
  80. brainstate/{compile → transform}/_jit.py +399 -346
  81. brainstate/{compile → transform}/_jit_test.py +143 -143
  82. brainstate/{compile → transform}/_loop_collect_return.py +675 -536
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
  84. brainstate/{compile → transform}/_loop_no_collection.py +283 -184
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +255 -202
  91. brainstate/{augment → transform}/_random.py +171 -151
  92. brainstate/{compile → transform}/_unvmap.py +256 -159
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +837 -304
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +27 -50
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +945 -469
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +910 -523
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/augment/__init__.py +0 -30
  111. brainstate/augment/_eval_shape.py +0 -99
  112. brainstate/augment/_mapping.py +0 -1060
  113. brainstate/augment/_mapping_test.py +0 -597
  114. brainstate/compile/__init__.py +0 -38
  115. brainstate/compile/_ad_checkpoint.py +0 -204
  116. brainstate/compile/_conditions.py +0 -256
  117. brainstate/compile/_make_jaxpr.py +0 -888
  118. brainstate/compile/_make_jaxpr_test.py +0 -156
  119. brainstate/compile/_util.py +0 -147
  120. brainstate/functional/__init__.py +0 -27
  121. brainstate/graph/_graph_node.py +0 -244
  122. brainstate/graph/_graph_node_test.py +0 -73
  123. brainstate/graph/_graph_operation_test.py +0 -563
  124. brainstate/init/__init__.py +0 -26
  125. brainstate/init/_base.py +0 -52
  126. brainstate/init/_generic.py +0 -244
  127. brainstate/init/_regular_inits.py +0 -105
  128. brainstate/init/_regular_inits_test.py +0 -50
  129. brainstate/nn/_inputs.py +0 -608
  130. brainstate/nn/_ltp.py +0 -28
  131. brainstate/nn/_neuron.py +0 -705
  132. brainstate/nn/_neuron_test.py +0 -161
  133. brainstate/nn/_others.py +0 -46
  134. brainstate/nn/_projection.py +0 -486
  135. brainstate/nn/_rate_rnns_test.py +0 -63
  136. brainstate/nn/_readout.py +0 -209
  137. brainstate/nn/_readout_test.py +0 -53
  138. brainstate/nn/_stp.py +0 -236
  139. brainstate/nn/_synapse.py +0 -505
  140. brainstate/nn/_synapse_test.py +0 -131
  141. brainstate/nn/_synaptic_projection.py +0 -423
  142. brainstate/nn/_synouts.py +0 -162
  143. brainstate/nn/_synouts_test.py +0 -57
  144. brainstate/nn/metrics.py +0 -388
  145. brainstate/optim/__init__.py +0 -38
  146. brainstate/optim/_base.py +0 -64
  147. brainstate/optim/_lr_scheduler.py +0 -448
  148. brainstate/optim/_lr_scheduler_test.py +0 -50
  149. brainstate/optim/_optax_optimizer.py +0 -152
  150. brainstate/optim/_optax_optimizer_test.py +0 -53
  151. brainstate/optim/_sgd_optimizer.py +0 -1104
  152. brainstate/random/_random_for_unit.py +0 -52
  153. brainstate/surrogate.py +0 -1957
  154. brainstate/transform.py +0 -23
  155. brainstate/util/caller.py +0 -98
  156. brainstate/util/others.py +0 -540
  157. brainstate/util/pretty_pytree.py +0 -945
  158. brainstate/util/pretty_pytree_test.py +0 -159
  159. brainstate/util/pretty_table.py +0 -2954
  160. brainstate/util/scaling.py +0 -258
  161. brainstate-0.1.10.dist-info/RECORD +0 -130
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,589 @@
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import unittest
17
+
18
+ import brainstate
19
+ from brainstate._state import State
20
+ from brainstate.graph._node import (
21
+ Node,
22
+ _node_flatten,
23
+ _node_set_key,
24
+ _node_pop_key,
25
+ _node_create_empty,
26
+ _node_clear
27
+ )
28
+
29
+
30
class TestNode(unittest.TestCase):
    """Behavioural checks for the ``Node`` base class itself."""

    def test_node_creation(self):
        """A plain ``Node`` subclass can be built and keeps its attribute."""

        class SimpleNode(Node):
            def __init__(self, value):
                self.value = value

        created = SimpleNode(10)
        self.assertEqual(created.value, 10)
        self.assertIsInstance(created, Node)

    def test_node_subclass_registration(self):
        """Declaring a ``Node`` subclass leaves it usable as a ``Node``."""

        # Given a distinct name so it does not shadow the enclosing test class.
        class RegisteredNode(Node):
            pass

        # The original author notes that registration happens automatically
        # via __init_subclass__; this test only verifies that the subclass can
        # be instantiated and is recognised as a Node.
        self.assertIsInstance(RegisteredNode(), Node)

    def test_graph_invisible_attrs(self):
        """Names listed in ``graph_invisible_attrs`` are left out of flattening."""

        class NodeWithInvisible(Node):
            graph_invisible_attrs = ('_private', '_internal')

            def __init__(self):
                self.public = 1
                self._private = 2
                self._internal = 3

        flattened, _static = _node_flatten(NodeWithInvisible())
        flattened_names = [name for name, _ in flattened]

        # Only the attribute that was not declared invisible should be present.
        self.assertIn('public', flattened_names)
        for hidden_name in ('_private', '_internal'):
            self.assertNotIn(hidden_name, flattened_names)

    def test_deepcopy_using_treefy(self):
        """A ``treefy_split``/``treefy_merge`` round trip yields an independent copy."""

        class NodeWithData(Node):
            def __init__(self, data=None):
                self.nested = {'a': 1, 'b': [2, 3]}
                if data is not None:
                    self.data = data

        source = NodeWithData([1, 2, 3])

        # Split into definition + state, then rebuild a fresh instance from them.
        graph_def, tree_state = brainstate.graph.treefy_split(source)
        clone = brainstate.graph.treefy_merge(graph_def, tree_state)

        # The rebuilt node is a new object carrying equal data.
        self.assertIsNot(source, clone)
        self.assertEqual(source.data, clone.data)

        # Mutating the clone's list must leave the source's list untouched.
        clone.data.append(4)
        self.assertEqual(len(source.data), 3)
        self.assertEqual(len(clone.data), 4)

    def test_node_with_state(self):
        """A node may hold a ``State`` next to ordinary Python values."""

        class NodeWithState(Node):
            def __init__(self):
                self.normal = 20
                self.value = State(10)

        holder = NodeWithState()
        self.assertIsInstance(holder.value, State)
        self.assertEqual(holder.value.value, 10)
        self.assertEqual(holder.normal, 20)

    def test_complex_nested_structure(self):
        """Nested lists, dicts and tuples survive a treefy round trip unchanged."""

        class ComplexNode(Node):
            def __init__(self):
                self.list_data = [1, 2, [3, 4]]
                self.dict_data = {'a': 1, 'b': {'c': 2}}
                self.tuple_data = (1, 2, (3, 4))

        source = ComplexNode()

        graph_def, tree_state = brainstate.graph.treefy_split(source)
        clone = brainstate.graph.treefy_merge(graph_def, tree_state)

        # Compare each container attribute between the source and the rebuilt node.
        for attr_name in ('list_data', 'dict_data', 'tuple_data'):
            self.assertEqual(getattr(source, attr_name), getattr(clone, attr_name))
132
+
133
+
134
+ class TestNodeHelperFunctions(unittest.TestCase):
135
+ """Test suite for node helper functions."""
136
+
137
+ def test_node_flatten(self):
138
+ """Test _node_flatten function."""
139
+
140
+ class TestNode(Node):
141
+ def __init__(self):
142
+ self.b = 2
143
+ self.a = 1
144
+ self.c = 3
145
+
146
+ node = TestNode()
147
+ flattened, static = _node_flatten(node)
148
+
149
+ # Check that attributes are sorted
150
+ keys = [k for k, v in flattened]
151
+ self.assertEqual(keys, ['a', 'b', 'c'])
152
+
153
+ # Check values
154
+ values = [v for k, v in flattened]
155
+ self.assertEqual(values, [1, 2, 3])
156
+
157
+ # Check static contains type
158
+ self.assertEqual(static, (TestNode,))
159
+
160
+ def test_node_flatten_with_invisible(self):
161
+ """Test _node_flatten with invisible attributes."""
162
+
163
+ class TestNode(Node):
164
+ graph_invisible_attrs = ('hidden',)
165
+
166
+ def __init__(self):
167
+ self.visible = 1
168
+ self.hidden = 2
169
+
170
+ node = TestNode()
171
+ flattened, static = _node_flatten(node)
172
+
173
+ keys = [k for k, v in flattened]
174
+ self.assertIn('visible', keys)
175
+ self.assertNotIn('hidden', keys)
176
+
177
+ def test_node_set_key_simple(self):
178
+ """Test _node_set_key with simple values."""
179
+
180
+ class TestNode(Node):
181
+ pass
182
+
183
+ node = TestNode()
184
+ _node_set_key(node, 'attr', 10)
185
+ self.assertEqual(node.attr, 10)
186
+
187
+ _node_set_key(node, 'attr', 20)
188
+ self.assertEqual(node.attr, 20)
189
+
190
+ def test_node_set_key_with_state(self):
191
+ """Test _node_set_key with State objects."""
192
+
193
+ class TestNode(Node):
194
+ def __init__(self):
195
+ self.state_attr = State(10)
196
+
197
+ node = TestNode()
198
+
199
+ # Test setting a regular value
200
+ _node_set_key(node, 'regular_attr', 30)
201
+ self.assertEqual(node.regular_attr, 30)
202
+
203
+ # Test updating with a TreefyState
204
+ # We'll use the real TreefyState from graph operations
205
+ graphdef, states = brainstate.graph.treefy_split(node)
206
+
207
+ # The states should contain our State object wrapped as TreefyState
208
+ # When setting with TreefyState, it should update the existing State
209
+ initial_state = node.state_attr
210
+
211
+ # Create a new node and try to set the TreefyState
212
+ new_node = TestNode()
213
+ for key, value in states.to_flat().items():
214
+ if 'state_attr' in key:
215
+ _node_set_key(new_node, 'state_attr', value)
216
+ # The State object should be updated via update_from_ref
217
+ self.assertIsInstance(new_node.state_attr, State)
218
+
219
+ def test_node_set_key_invalid_key(self):
220
+ """Test _node_set_key with invalid key."""
221
+
222
+ class TestNode(Node):
223
+ pass
224
+
225
+ node = TestNode()
226
+
227
+ with self.assertRaises(KeyError) as context:
228
+ _node_set_key(node, 123, 'value')
229
+ self.assertIn('Invalid key', str(context.exception))
230
+
231
+ def test_node_pop_key(self):
232
+ """Test _node_pop_key function."""
233
+
234
+ class TestNode(Node):
235
+ def __init__(self):
236
+ self.attr1 = 10
237
+ self.attr2 = 20
238
+
239
+ node = TestNode()
240
+
241
+ # Pop existing attribute
242
+ value = _node_pop_key(node, 'attr1')
243
+ self.assertEqual(value, 10)
244
+ self.assertFalse(hasattr(node, 'attr1'))
245
+ self.assertTrue(hasattr(node, 'attr2'))
246
+
247
+ def test_node_pop_key_invalid(self):
248
+ """Test _node_pop_key with invalid key."""
249
+
250
+ class TestNode(Node):
251
+ pass
252
+
253
+ node = TestNode()
254
+
255
+ # Invalid key type
256
+ with self.assertRaises(KeyError) as context:
257
+ _node_pop_key(node, 123)
258
+ self.assertIn('Invalid key', str(context.exception))
259
+
260
+ # Non-existent key
261
+ with self.assertRaises(KeyError):
262
+ _node_pop_key(node, 'nonexistent')
263
+
264
+ def test_node_create_empty(self):
265
+ """Test _node_create_empty function."""
266
+
267
+ class TestNode(Node):
268
+ def __init__(self, value=None):
269
+ self.value = value
270
+ self.initialized = True
271
+
272
+ # Create empty node
273
+ node = _node_create_empty((TestNode,))
274
+
275
+ # Check it's the right type
276
+ self.assertIsInstance(node, TestNode)
277
+
278
+ # Check __init__ was not called
279
+ self.assertFalse(hasattr(node, 'value'))
280
+ self.assertFalse(hasattr(node, 'initialized'))
281
+
282
+ def test_node_clear(self):
283
+ """Test _node_clear function."""
284
+
285
+ class TestNode(Node):
286
+ def __init__(self):
287
+ self.attr1 = 10
288
+ self.attr2 = 20
289
+ self.attr3 = [1, 2, 3]
290
+
291
+ node = TestNode()
292
+
293
+ # Verify attributes exist
294
+ self.assertTrue(hasattr(node, 'attr1'))
295
+ self.assertTrue(hasattr(node, 'attr2'))
296
+ self.assertTrue(hasattr(node, 'attr3'))
297
+
298
+ # Clear the node
299
+ _node_clear(node)
300
+
301
+ # Verify attributes are gone
302
+ self.assertFalse(hasattr(node, 'attr1'))
303
+ self.assertFalse(hasattr(node, 'attr2'))
304
+ self.assertFalse(hasattr(node, 'attr3'))
305
+
306
+ # Verify node still exists and is valid
307
+ self.assertIsInstance(node, TestNode)
308
+
309
+
310
class TestNodeIntegration(unittest.TestCase):
    """Round-trip tests for nodes that contain other nodes."""

    def test_node_with_nested_nodes(self):
        """Child nodes held as attributes are rebuilt as new, equal objects."""

        class ChildNode(Node):
            def __init__(self, value=None):
                if value is None:
                    return
                self.value = value

        class ParentNode(Node):
            def __init__(self):
                self.data = [1, 2, 3]
                self.child1 = ChildNode(10)
                self.child2 = ChildNode(20)

        source = ParentNode()

        # Split into definition + state and merge back into a fresh graph.
        graph_def, tree_state = brainstate.graph.treefy_split(source)
        clone = brainstate.graph.treefy_merge(graph_def, tree_state)

        self.assertIsNot(source.child1, clone.child1)
        self.assertEqual(source.child1.value, clone.child1.value)

    def test_node_with_list_of_nodes(self):
        """Nodes stored in a list are each rebuilt with the same ``id`` attribute."""

        class ItemNode(Node):
            # The parameter is spelled ``item_id`` to avoid shadowing the
            # builtin; the stored attribute keeps the name ``id``.
            def __init__(self, item_id=None):
                if item_id is None:
                    return
                self.id = item_id

        class ContainerNode(Node):
            def __init__(self):
                self.items = list(map(ItemNode, range(3)))

        source = ContainerNode()
        graph_def, tree_state = brainstate.graph.treefy_split(source)
        clone = brainstate.graph.treefy_merge(graph_def, tree_state)

        self.assertEqual(len(source.items), len(clone.items))
        for source_item, clone_item in zip(source.items, clone.items):
            self.assertIsNot(source_item, clone_item)
            self.assertEqual(source_item.id, clone_item.id)

    def test_node_with_dict_of_nodes(self):
        """Nodes stored as dict values are rebuilt under the same keys."""

        class ValueNode(Node):
            def __init__(self, value=None):
                if value is None:
                    return
                self.value = value

        class DictNode(Node):
            def __init__(self):
                # Builds {'a': ValueNode(1), 'b': ValueNode(2), 'c': ValueNode(3)}.
                self.mapping = {
                    label: ValueNode(number)
                    for number, label in enumerate('abc', start=1)
                }

        source = DictNode()
        graph_def, tree_state = brainstate.graph.treefy_split(source)
        clone = brainstate.graph.treefy_merge(graph_def, tree_state)

        self.assertEqual(set(source.mapping), set(clone.mapping))
        for label, source_child in source.mapping.items():
            clone_child = clone.mapping[label]
            self.assertIsNot(source_child, clone_child)
            self.assertEqual(source_child.value, clone_child.value)
381
+
382
+
383
+ class TestStateRetrieve(unittest.TestCase):
384
+ """Tests for state retrieval from nodes."""
385
+
386
+ def test_list_of_states_1(self):
387
+ """Test retrieving states from a list."""
388
+
389
+ class Model(brainstate.graph.Node):
390
+ def __init__(self):
391
+ self.a = [1, 2, 3]
392
+ self.b = [brainstate.State(1), brainstate.State(2), brainstate.State(3)]
393
+
394
+ m = Model()
395
+ graphdef, states = brainstate.graph.treefy_split(m)
396
+ print(states.to_flat())
397
+ self.assertTrue(len(states.to_flat()) == 3)
398
+
399
+ def test_list_of_states_2(self):
400
+ """Test retrieving states from nested lists."""
401
+
402
+ class Model(brainstate.graph.Node):
403
+ def __init__(self):
404
+ self.a = [1, 2, 3]
405
+ self.b = [brainstate.State(1), [brainstate.State(2), brainstate.State(3)]]
406
+
407
+ m = Model()
408
+ graphdef, states = brainstate.graph.treefy_split(m)
409
+ print(states.to_flat())
410
+ self.assertTrue(len(states.to_flat()) == 3)
411
+
412
+ def test_list_of_node_1(self):
413
+ """Test retrieving states from a list of nodes."""
414
+
415
+ class Model(brainstate.graph.Node):
416
+ def __init__(self):
417
+ self.a = [1, 2, 3]
418
+ self.b = [brainstate.nn.Linear(1, 2), brainstate.nn.Linear(2, 3)]
419
+
420
+ m = Model()
421
+ graphdef, states = brainstate.graph.treefy_split(m)
422
+ print(states.to_flat())
423
+ self.assertTrue(len(states.to_flat()) == 2)
424
+
425
+ def test_list_of_node_2(self):
426
+ """Test retrieving states from nested structures of nodes."""
427
+
428
+ class Model(brainstate.graph.Node):
429
+ def __init__(self):
430
+ self.a = [1, 2, 3]
431
+ self.b = [brainstate.nn.Linear(1, 2), [brainstate.nn.Linear(2, 3)],
432
+ (brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5))]
433
+
434
+ m = Model()
435
+ graphdef, states = brainstate.graph.treefy_split(m)
436
+ print(states.to_flat())
437
+ self.assertTrue(len(states.to_flat()) == 4)
438
+
439
+ def test_mixed_states_and_nodes(self):
440
+ """Test nodes with mixed states and sub-nodes."""
441
+
442
+ class Model(brainstate.graph.Node):
443
+ def __init__(self):
444
+ self.state1 = brainstate.State(1.0)
445
+ self.state2 = brainstate.State(2.0)
446
+ self.linear = brainstate.nn.Linear(5, 10)
447
+ self.data = [1, 2, 3]
448
+
449
+ m = Model()
450
+ graphdef, states = brainstate.graph.treefy_split(m)
451
+
452
+ # Should have states from both direct State objects and Linear layer
453
+ flat_states = states.to_flat()
454
+ self.assertGreaterEqual(len(flat_states), 2) # At least the two direct states
455
+
456
+
457
+ class TestEdgeCases(unittest.TestCase):
458
+ """Test edge cases and error conditions."""
459
+
460
+ def test_empty_node(self):
461
+ """Test node with no attributes."""
462
+
463
+ class EmptyNode(Node):
464
+ pass
465
+
466
+ node = EmptyNode()
467
+ flattened, static = _node_flatten(node)
468
+
469
+ self.assertEqual(len(flattened), 0)
470
+ self.assertEqual(static, (EmptyNode,))
471
+
472
+ def test_node_with_none_values(self):
473
+ """Test node with None values."""
474
+
475
+ class NoneNode(Node):
476
+ def __init__(self):
477
+ self.none_val = None
478
+ self.real_val = 10
479
+
480
+ node = NoneNode()
481
+ flattened, static = _node_flatten(node)
482
+
483
+ values_dict = dict(flattened)
484
+ self.assertIsNone(values_dict['none_val'])
485
+ self.assertEqual(values_dict['real_val'], 10)
486
+
487
+ def test_node_with_special_attributes(self):
488
+ """Test node with special Python attributes."""
489
+
490
+ class SpecialNode(Node):
491
+ def __init__(self):
492
+ self.__dict__['special'] = 'value'
493
+ self.normal = 'normal'
494
+
495
+ node = SpecialNode()
496
+ self.assertEqual(node.special, 'value')
497
+ self.assertEqual(node.normal, 'normal')
498
+
499
+ def test_circular_reference(self):
500
+ """Test handling of circular references."""
501
+
502
+ class CircularNode(Node):
503
+ pass
504
+
505
+ node1 = CircularNode()
506
+ node2 = CircularNode()
507
+ node1.ref = node2
508
+ node2.ref = node1
509
+
510
+ # This should not cause infinite recursion with treefy
511
+ try:
512
+ graphdef, state = brainstate.graph.treefy_split(node1)
513
+ copied = brainstate.graph.treefy_merge(graphdef, state)
514
+ # Check that circular reference is preserved
515
+ self.assertIs(copied.ref.ref, copied)
516
+ except RecursionError:
517
+ self.fail("Treefy failed with circular reference")
518
+
519
+ def test_node_inheritance(self):
520
+ """Test node inheritance hierarchy."""
521
+
522
+ class BaseNode(Node):
523
+ def __init__(self):
524
+ self.base_attr = 'base'
525
+
526
+ class DerivedNode(BaseNode):
527
+ def __init__(self):
528
+ super().__init__()
529
+ self.derived_attr = 'derived'
530
+
531
+ node = DerivedNode()
532
+ self.assertEqual(node.base_attr, 'base')
533
+ self.assertEqual(node.derived_attr, 'derived')
534
+
535
+ flattened, static = _node_flatten(node)
536
+ keys = [k for k, v in flattened]
537
+ self.assertIn('base_attr', keys)
538
+ self.assertIn('derived_attr', keys)
539
+
540
+ def test_node_with_property(self):
541
+ """Test node with property decorators."""
542
+
543
+ class PropertyNode(Node):
544
+ def __init__(self):
545
+ self._value = 10
546
+
547
+ @property
548
+ def value(self):
549
+ return self._value
550
+
551
+ @value.setter
552
+ def value(self, val):
553
+ self._value = val
554
+
555
+ node = PropertyNode()
556
+ self.assertEqual(node.value, 10)
557
+
558
+ node.value = 20
559
+ self.assertEqual(node.value, 20)
560
+
561
+ # Only _value should appear in flattened
562
+ flattened, static = _node_flatten(node)
563
+ keys = [k for k, v in flattened]
564
+ self.assertIn('_value', keys)
565
+
566
+ def test_multiple_inheritance(self):
567
+ """Test node with multiple inheritance."""
568
+
569
+ class Mixin:
570
+ def mixin_method(self):
571
+ return 'mixin'
572
+
573
+ class MultiNode(Node, Mixin):
574
+ def __init__(self):
575
+ self.data = 'data'
576
+
577
+ node = MultiNode()
578
+ self.assertEqual(node.mixin_method(), 'mixin')
579
+ self.assertEqual(node.data, 'data')
580
+
581
+ # Test that it still works as a Node with treefy
582
+ graphdef, state = brainstate.graph.treefy_split(node)
583
+ copied = brainstate.graph.treefy_merge(graphdef, state)
584
+ self.assertIsNot(node, copied)
585
+ self.assertEqual(copied.data, 'data')
586
+
587
+
588
if __name__ == '__main__':
    # Run every test case in this module when the file is executed directly.
    unittest.main()