brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,589 @@
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import unittest
17
+
18
+ import brainstate
19
+ from brainstate._state import State
20
+ from brainstate.graph._node import (
21
+ Node,
22
+ _node_flatten,
23
+ _node_set_key,
24
+ _node_pop_key,
25
+ _node_create_empty,
26
+ _node_clear
27
+ )
28
+
29
+
30
class TestNode(unittest.TestCase):
    """Test suite for the Node class."""

    def test_node_creation(self):
        """Test basic node creation."""

        class SimpleNode(Node):
            def __init__(self, value):
                self.value = value

        node = SimpleNode(10)
        self.assertEqual(node.value, 10)
        self.assertIsInstance(node, Node)

    def test_node_subclass_registration(self):
        """Test that Node subclasses are automatically registered.

        Registration is expected to happen via ``Node.__init_subclass__``.
        The registry itself is not inspected here; instead the subclass is
        round-tripped through ``treefy_split`` / ``treefy_merge``, which
        presumably relies on the type being known to the graph system.
        """

        # Named distinctly so it does not shadow this TestCase class.
        class RegisteredNode(Node):
            pass

        node = RegisteredNode()
        self.assertIsInstance(node, Node)

        graphdef, state = brainstate.graph.treefy_split(node)
        copied = brainstate.graph.treefy_merge(graphdef, state)
        self.assertIsInstance(copied, RegisteredNode)
        self.assertIsNot(copied, node)

    def test_graph_invisible_attrs(self):
        """Test that graph_invisible_attrs works correctly."""

        class NodeWithInvisible(Node):
            graph_invisible_attrs = ('_private', '_internal')

            def __init__(self):
                self.public = 1
                self._private = 2
                self._internal = 3

        node = NodeWithInvisible()
        flattened, _ = _node_flatten(node)

        # Only the public attribute should be exported by flattening.
        keys = [k for k, _ in flattened]
        self.assertIn('public', keys)
        self.assertNotIn('_private', keys)
        self.assertNotIn('_internal', keys)

    def test_deepcopy_using_treefy(self):
        """Test deep copying of nodes using treefy_split/merge."""

        class NodeWithData(Node):
            def __init__(self, data=None):
                if data is not None:
                    self.data = data
                self.nested = {'a': 1, 'b': [2, 3]}

        original = NodeWithData([1, 2, 3])

        # Split into (graph definition, state), then rebuild a new instance.
        graphdef, state = brainstate.graph.treefy_split(original)
        copied = brainstate.graph.treefy_merge(graphdef, state)

        # The rebuilt node is a distinct object ...
        self.assertIsNot(original, copied)

        # ... carrying equal data ...
        self.assertEqual(original.data, copied.data)

        # ... whose containers are not shared with the original.
        copied.data.append(4)
        self.assertEqual(len(original.data), 3)
        self.assertEqual(len(copied.data), 4)

    def test_node_with_state(self):
        """Test nodes containing State objects."""

        class NodeWithState(Node):
            def __init__(self):
                self.value = State(10)
                self.normal = 20

        node = NodeWithState()
        self.assertIsInstance(node.value, State)
        self.assertEqual(node.value.value, 10)
        self.assertEqual(node.normal, 20)

    def test_complex_nested_structure(self):
        """Test nodes with complex nested structures using treefy."""

        class ComplexNode(Node):
            def __init__(self):
                self.list_data = [1, 2, [3, 4]]
                self.dict_data = {'a': 1, 'b': {'c': 2}}
                self.tuple_data = (1, 2, (3, 4))

        node = ComplexNode()

        graphdef, state = brainstate.graph.treefy_split(node)
        copied = brainstate.graph.treefy_merge(graphdef, state)

        self.assertEqual(node.list_data, copied.list_data)
        self.assertEqual(node.dict_data, copied.dict_data)
        self.assertEqual(node.tuple_data, copied.tuple_data)
132
+
133
+
134
class TestNodeHelperFunctions(unittest.TestCase):
    """Test suite for node helper functions."""

    def test_node_flatten(self):
        """Test _node_flatten function."""

        # Local helper class; named so it does not shadow the TestNode TestCase.
        class DummyNode(Node):
            def __init__(self):
                self.b = 2
                self.a = 1
                self.c = 3

        node = DummyNode()
        flattened, static = _node_flatten(node)

        # Attributes are emitted sorted by name, regardless of assignment order.
        keys = [k for k, _ in flattened]
        self.assertEqual(keys, ['a', 'b', 'c'])

        values = [v for _, v in flattened]
        self.assertEqual(values, [1, 2, 3])

        # The static part records the node type.
        self.assertEqual(static, (DummyNode,))

    def test_node_flatten_with_invisible(self):
        """Test _node_flatten with invisible attributes."""

        class DummyNode(Node):
            graph_invisible_attrs = ('hidden',)

            def __init__(self):
                self.visible = 1
                self.hidden = 2

        node = DummyNode()
        flattened, _ = _node_flatten(node)

        keys = [k for k, _ in flattened]
        self.assertIn('visible', keys)
        self.assertNotIn('hidden', keys)

    def test_node_set_key_simple(self):
        """Test _node_set_key with simple values."""

        class DummyNode(Node):
            pass

        node = DummyNode()
        _node_set_key(node, 'attr', 10)
        self.assertEqual(node.attr, 10)

        # Setting again overwrites the previous value.
        _node_set_key(node, 'attr', 20)
        self.assertEqual(node.attr, 20)

    def test_node_set_key_with_state(self):
        """Test _node_set_key with State objects."""

        class DummyNode(Node):
            def __init__(self):
                self.state_attr = State(10)

        node = DummyNode()

        # A plain value is simply assigned as an attribute.
        _node_set_key(node, 'regular_attr', 30)
        self.assertEqual(node.regular_attr, 30)

        # Split the node so its State is exported in treefied form, then feed
        # that entry into a fresh node through _node_set_key. This presumably
        # updates the existing State in place rather than replacing it.
        _, states = brainstate.graph.treefy_split(node)

        new_node = DummyNode()
        found = False
        for key, value in states.to_flat().items():
            if 'state_attr' in key:
                found = True
                _node_set_key(new_node, 'state_attr', value)
        # Without this guard the test would pass vacuously if no key matched.
        self.assertTrue(found, "no 'state_attr' entry found in the split states")
        self.assertIsInstance(new_node.state_attr, State)

    def test_node_set_key_invalid_key(self):
        """Test _node_set_key with invalid key."""

        class DummyNode(Node):
            pass

        node = DummyNode()

        with self.assertRaises(KeyError) as context:
            _node_set_key(node, 123, 'value')
        # Must run after the `with` block; code after the raising call inside it never executes.
        self.assertIn('Invalid key', str(context.exception))

    def test_node_pop_key(self):
        """Test _node_pop_key function."""

        class DummyNode(Node):
            def __init__(self):
                self.attr1 = 10
                self.attr2 = 20

        node = DummyNode()

        # Popping returns the value and removes only that attribute.
        value = _node_pop_key(node, 'attr1')
        self.assertEqual(value, 10)
        self.assertFalse(hasattr(node, 'attr1'))
        self.assertTrue(hasattr(node, 'attr2'))

    def test_node_pop_key_invalid(self):
        """Test _node_pop_key with invalid key."""

        class DummyNode(Node):
            pass

        node = DummyNode()

        # Invalid key type.
        with self.assertRaises(KeyError) as context:
            _node_pop_key(node, 123)
        self.assertIn('Invalid key', str(context.exception))

        # Non-existent key.
        with self.assertRaises(KeyError):
            _node_pop_key(node, 'nonexistent')

    def test_node_create_empty(self):
        """Test _node_create_empty function."""

        class DummyNode(Node):
            def __init__(self, value=None):
                self.value = value
                self.initialized = True

        node = _node_create_empty((DummyNode,))

        self.assertIsInstance(node, DummyNode)

        # __init__ must not have been called.
        self.assertFalse(hasattr(node, 'value'))
        self.assertFalse(hasattr(node, 'initialized'))

    def test_node_clear(self):
        """Test _node_clear function."""

        class DummyNode(Node):
            def __init__(self):
                self.attr1 = 10
                self.attr2 = 20
                self.attr3 = [1, 2, 3]

        node = DummyNode()

        for name in ('attr1', 'attr2', 'attr3'):
            self.assertTrue(hasattr(node, name))

        _node_clear(node)

        for name in ('attr1', 'attr2', 'attr3'):
            self.assertFalse(hasattr(node, name))

        # The object itself survives clearing.
        self.assertIsInstance(node, DummyNode)
308
+
309
+
310
class TestNodeIntegration(unittest.TestCase):
    """Integration tests for Node with the graph system."""

    def test_node_with_nested_nodes(self):
        """Test nodes containing other nodes."""

        class ChildNode(Node):
            def __init__(self, value=None):
                if value is not None:
                    self.value = value

        class ParentNode(Node):
            def __init__(self):
                self.child1 = ChildNode(10)
                self.child2 = ChildNode(20)
                self.data = [1, 2, 3]

        parent = ParentNode()

        graphdef, state = brainstate.graph.treefy_split(parent)
        copied = brainstate.graph.treefy_merge(graphdef, state)

        # Child nodes are rebuilt as new objects with the same contents.
        self.assertIsNot(parent.child1, copied.child1)
        self.assertEqual(parent.child1.value, copied.child1.value)

    def test_node_with_list_of_nodes(self):
        """Test nodes containing lists of other nodes."""

        class ItemNode(Node):
            # `item_id` avoids shadowing the builtin `id`; callers pass it
            # positionally, and the stored attribute is still named `id`.
            def __init__(self, item_id=None):
                if item_id is not None:
                    self.id = item_id

        class ContainerNode(Node):
            def __init__(self):
                self.items = [ItemNode(i) for i in range(3)]

        container = ContainerNode()
        graphdef, state = brainstate.graph.treefy_split(container)
        copied = brainstate.graph.treefy_merge(graphdef, state)

        self.assertEqual(len(container.items), len(copied.items))
        for orig, cp in zip(container.items, copied.items):
            self.assertIsNot(orig, cp)
            self.assertEqual(orig.id, cp.id)

    def test_node_with_dict_of_nodes(self):
        """Test nodes containing dictionaries of other nodes."""

        class ValueNode(Node):
            def __init__(self, value=None):
                if value is not None:
                    self.value = value

        class DictNode(Node):
            def __init__(self):
                self.mapping = {
                    'a': ValueNode(1),
                    'b': ValueNode(2),
                    'c': ValueNode(3),
                }

        node = DictNode()
        graphdef, state = brainstate.graph.treefy_split(node)
        copied = brainstate.graph.treefy_merge(graphdef, state)

        self.assertEqual(set(node.mapping), set(copied.mapping))
        for key, original_child in node.mapping.items():
            self.assertIsNot(original_child, copied.mapping[key])
            self.assertEqual(original_child.value, copied.mapping[key].value)
381
+
382
+
383
class TestStateRetrieve(unittest.TestCase):
    """Tests for state retrieval from nodes."""

    def test_list_of_states_1(self):
        """Test retrieving states from a list."""

        class Model(brainstate.graph.Node):
            def __init__(self):
                self.a = [1, 2, 3]
                self.b = [brainstate.State(1), brainstate.State(2), brainstate.State(3)]

        m = Model()
        _, states = brainstate.graph.treefy_split(m)
        # Only the three State objects are exported; plain ints are not.
        self.assertEqual(len(states.to_flat()), 3)

    def test_list_of_states_2(self):
        """Test retrieving states from nested lists."""

        class Model(brainstate.graph.Node):
            def __init__(self):
                self.a = [1, 2, 3]
                self.b = [brainstate.State(1), [brainstate.State(2), brainstate.State(3)]]

        m = Model()
        _, states = brainstate.graph.treefy_split(m)
        self.assertEqual(len(states.to_flat()), 3)

    def test_list_of_node_1(self):
        """Test retrieving states from a list of nodes."""

        class Model(brainstate.graph.Node):
            def __init__(self):
                self.a = [1, 2, 3]
                self.b = [brainstate.nn.Linear(1, 2), brainstate.nn.Linear(2, 3)]

        m = Model()
        _, states = brainstate.graph.treefy_split(m)
        # One flattened state per Linear layer.
        self.assertEqual(len(states.to_flat()), 2)

    def test_list_of_node_2(self):
        """Test retrieving states from nested structures of nodes."""

        class Model(brainstate.graph.Node):
            def __init__(self):
                self.a = [1, 2, 3]
                self.b = [
                    brainstate.nn.Linear(1, 2),
                    [brainstate.nn.Linear(2, 3)],
                    (brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5)),
                ]

        m = Model()
        _, states = brainstate.graph.treefy_split(m)
        self.assertEqual(len(states.to_flat()), 4)

    def test_mixed_states_and_nodes(self):
        """Test nodes with mixed states and sub-nodes."""

        class Model(brainstate.graph.Node):
            def __init__(self):
                self.state1 = brainstate.State(1.0)
                self.state2 = brainstate.State(2.0)
                self.linear = brainstate.nn.Linear(5, 10)
                self.data = [1, 2, 3]

        m = Model()
        _, states = brainstate.graph.treefy_split(m)

        # Two direct State objects plus one flattened state for the Linear
        # layer (test_list_of_node_1 expects one state per Linear).
        flat_states = states.to_flat()
        self.assertEqual(len(flat_states), 3)
455
+
456
+
457
class TestEdgeCases(unittest.TestCase):
    """Test edge cases and error conditions."""

    def test_empty_node(self):
        """Test node with no attributes."""

        class EmptyNode(Node):
            pass

        node = EmptyNode()
        flattened, static = _node_flatten(node)

        self.assertEqual(len(flattened), 0)
        self.assertEqual(static, (EmptyNode,))

    def test_node_with_none_values(self):
        """Test node with None values."""

        class NoneNode(Node):
            def __init__(self):
                self.none_val = None
                self.real_val = 10

        node = NoneNode()
        flattened, _ = _node_flatten(node)

        # None-valued attributes are kept, not dropped.
        values_dict = dict(flattened)
        self.assertIsNone(values_dict['none_val'])
        self.assertEqual(values_dict['real_val'], 10)

    def test_node_with_special_attributes(self):
        """Test node with special Python attributes."""

        class SpecialNode(Node):
            def __init__(self):
                # Write straight into __dict__, bypassing __setattr__.
                self.__dict__['special'] = 'value'
                self.normal = 'normal'

        node = SpecialNode()
        self.assertEqual(node.special, 'value')
        self.assertEqual(node.normal, 'normal')

    def test_circular_reference(self):
        """Test handling of circular references."""

        class CircularNode(Node):
            pass

        node1 = CircularNode()
        node2 = CircularNode()
        node1.ref = node2
        node2.ref = node1

        # Keep the try minimal: only the calls that could recurse forever.
        try:
            graphdef, state = brainstate.graph.treefy_split(node1)
            copied = brainstate.graph.treefy_merge(graphdef, state)
        except RecursionError:
            self.fail("Treefy failed with circular reference")

        # The cycle must be rebuilt between the copied nodes.
        self.assertIsNot(copied, node1)
        self.assertIs(copied.ref.ref, copied)

    def test_node_inheritance(self):
        """Test node inheritance hierarchy."""

        class BaseNode(Node):
            def __init__(self):
                self.base_attr = 'base'

        class DerivedNode(BaseNode):
            def __init__(self):
                super().__init__()
                self.derived_attr = 'derived'

        node = DerivedNode()
        self.assertEqual(node.base_attr, 'base')
        self.assertEqual(node.derived_attr, 'derived')

        # Attributes set by both base and derived __init__ are flattened.
        flattened, _ = _node_flatten(node)
        keys = [k for k, _ in flattened]
        self.assertIn('base_attr', keys)
        self.assertIn('derived_attr', keys)

    def test_node_with_property(self):
        """Test node with property decorators."""

        class PropertyNode(Node):
            def __init__(self):
                self._value = 10

            @property
            def value(self):
                return self._value

            @value.setter
            def value(self, val):
                self._value = val

        node = PropertyNode()
        self.assertEqual(node.value, 10)

        node.value = 20
        self.assertEqual(node.value, 20)

        # Only the backing field `_value` is flattened; the property itself
        # lives on the class, not the instance, so it must not appear.
        flattened, _ = _node_flatten(node)
        keys = [k for k, _ in flattened]
        self.assertIn('_value', keys)
        self.assertNotIn('value', keys)

    def test_multiple_inheritance(self):
        """Test node with multiple inheritance."""

        class Mixin:
            def mixin_method(self):
                return 'mixin'

        class MultiNode(Node, Mixin):
            def __init__(self):
                self.data = 'data'

        node = MultiNode()
        self.assertEqual(node.mixin_method(), 'mixin')
        self.assertEqual(node.data, 'data')

        # The mixed-in class still round-trips as a Node.
        graphdef, state = brainstate.graph.treefy_split(node)
        copied = brainstate.graph.treefy_merge(graphdef, state)
        self.assertIsNot(node, copied)
        self.assertEqual(copied.data, 'data')
586
+
587
+
588
if __name__ == '__main__':
    # Run all TestCase classes in this module when executed as a script.
    unittest.main()