brainstate 0.2.0__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112) hide show
  1. brainstate/__init__.py +169 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2319 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +1652 -1652
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1624 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1433 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +137 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +633 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +154 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +477 -477
  32. brainstate/nn/_dynamics.py +1267 -1267
  33. brainstate/nn/_dynamics_test.py +67 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +384 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/_rand_funs.py +3938 -3938
  64. brainstate/random/_rand_funs_test.py +640 -640
  65. brainstate/random/_rand_seed.py +675 -675
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1617
  68. brainstate/random/_rand_state_test.py +551 -551
  69. brainstate/transform/__init__.py +59 -59
  70. brainstate/transform/_ad_checkpoint.py +176 -176
  71. brainstate/transform/_ad_checkpoint_test.py +49 -49
  72. brainstate/transform/_autograd.py +1025 -1025
  73. brainstate/transform/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -316
  75. brainstate/transform/_conditions_test.py +220 -220
  76. brainstate/transform/_error_if.py +94 -94
  77. brainstate/transform/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -145
  79. brainstate/transform/_eval_shape_test.py +38 -38
  80. brainstate/transform/_jit.py +399 -399
  81. brainstate/transform/_jit_test.py +143 -143
  82. brainstate/transform/_loop_collect_return.py +675 -675
  83. brainstate/transform/_loop_collect_return_test.py +58 -58
  84. brainstate/transform/_loop_no_collection.py +283 -283
  85. brainstate/transform/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -2016
  87. brainstate/transform/_make_jaxpr_test.py +1510 -1510
  88. brainstate/transform/_mapping.py +529 -529
  89. brainstate/transform/_mapping_test.py +194 -194
  90. brainstate/transform/_progress_bar.py +255 -255
  91. brainstate/transform/_random.py +171 -171
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -108
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate-0.2.0.dist-info/RECORD +0 -111
  111. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  112. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -1,589 +1,589 @@
1
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- import unittest
17
-
18
- import brainstate
19
- from brainstate._state import State
20
- from brainstate.graph._node import (
21
- Node,
22
- _node_flatten,
23
- _node_set_key,
24
- _node_pop_key,
25
- _node_create_empty,
26
- _node_clear
27
- )
28
-
29
-
30
- class TestNode(unittest.TestCase):
31
- """Test suite for the Node class."""
32
-
33
- def test_node_creation(self):
34
- """Test basic node creation."""
35
-
36
- class SimpleNode(Node):
37
- def __init__(self, value):
38
- self.value = value
39
-
40
- node = SimpleNode(10)
41
- self.assertEqual(node.value, 10)
42
- self.assertIsInstance(node, Node)
43
-
44
- def test_node_subclass_registration(self):
45
- """Test that Node subclasses are automatically registered."""
46
-
47
- class TestNode(Node):
48
- pass
49
-
50
- # The subclass should be registered automatically via __init_subclass__
51
- node = TestNode()
52
- self.assertIsInstance(node, Node)
53
-
54
- def test_graph_invisible_attrs(self):
55
- """Test that graph_invisible_attrs works correctly."""
56
-
57
- class NodeWithInvisible(Node):
58
- graph_invisible_attrs = ('_private', '_internal')
59
-
60
- def __init__(self):
61
- self.public = 1
62
- self._private = 2
63
- self._internal = 3
64
-
65
- node = NodeWithInvisible()
66
- flattened, static = _node_flatten(node)
67
-
68
- # Check that only public attribute is in flattened
69
- keys = [k for k, v in flattened]
70
- self.assertIn('public', keys)
71
- self.assertNotIn('_private', keys)
72
- self.assertNotIn('_internal', keys)
73
-
74
- def test_deepcopy_using_treefy(self):
75
- """Test deep copying of nodes using treefy_split/merge."""
76
-
77
- class NodeWithData(Node):
78
- def __init__(self, data=None):
79
- if data is not None:
80
- self.data = data
81
- self.nested = {'a': 1, 'b': [2, 3]}
82
-
83
- original = NodeWithData([1, 2, 3])
84
-
85
- # Use treefy_split and treefy_merge to copy
86
- graphdef, state = brainstate.graph.treefy_split(original)
87
- # Create a new instance using treefy_merge
88
- copied = brainstate.graph.treefy_merge(graphdef, state)
89
-
90
- # Check that it's a different object
91
- self.assertIsNot(original, copied)
92
-
93
- # Check that data is present
94
- self.assertEqual(original.data, copied.data)
95
-
96
- # Modify copied data shouldn't affect original
97
- copied.data.append(4)
98
- self.assertEqual(len(original.data), 3)
99
- self.assertEqual(len(copied.data), 4)
100
-
101
- def test_node_with_state(self):
102
- """Test nodes containing State objects."""
103
-
104
- class NodeWithState(Node):
105
- def __init__(self):
106
- self.value = State(10)
107
- self.normal = 20
108
-
109
- node = NodeWithState()
110
- self.assertIsInstance(node.value, State)
111
- self.assertEqual(node.value.value, 10)
112
- self.assertEqual(node.normal, 20)
113
-
114
- def test_complex_nested_structure(self):
115
- """Test nodes with complex nested structures using treefy."""
116
-
117
- class ComplexNode(Node):
118
- def __init__(self):
119
- self.list_data = [1, 2, [3, 4]]
120
- self.dict_data = {'a': 1, 'b': {'c': 2}}
121
- self.tuple_data = (1, 2, (3, 4))
122
-
123
- node = ComplexNode()
124
-
125
- # Test using treefy_split and merge
126
- graphdef, state = brainstate.graph.treefy_split(node)
127
- copied = brainstate.graph.treefy_merge(graphdef, state)
128
-
129
- self.assertEqual(node.list_data, copied.list_data)
130
- self.assertEqual(node.dict_data, copied.dict_data)
131
- self.assertEqual(node.tuple_data, copied.tuple_data)
132
-
133
-
134
- class TestNodeHelperFunctions(unittest.TestCase):
135
- """Test suite for node helper functions."""
136
-
137
- def test_node_flatten(self):
138
- """Test _node_flatten function."""
139
-
140
- class TestNode(Node):
141
- def __init__(self):
142
- self.b = 2
143
- self.a = 1
144
- self.c = 3
145
-
146
- node = TestNode()
147
- flattened, static = _node_flatten(node)
148
-
149
- # Check that attributes are sorted
150
- keys = [k for k, v in flattened]
151
- self.assertEqual(keys, ['a', 'b', 'c'])
152
-
153
- # Check values
154
- values = [v for k, v in flattened]
155
- self.assertEqual(values, [1, 2, 3])
156
-
157
- # Check static contains type
158
- self.assertEqual(static, (TestNode,))
159
-
160
- def test_node_flatten_with_invisible(self):
161
- """Test _node_flatten with invisible attributes."""
162
-
163
- class TestNode(Node):
164
- graph_invisible_attrs = ('hidden',)
165
-
166
- def __init__(self):
167
- self.visible = 1
168
- self.hidden = 2
169
-
170
- node = TestNode()
171
- flattened, static = _node_flatten(node)
172
-
173
- keys = [k for k, v in flattened]
174
- self.assertIn('visible', keys)
175
- self.assertNotIn('hidden', keys)
176
-
177
- def test_node_set_key_simple(self):
178
- """Test _node_set_key with simple values."""
179
-
180
- class TestNode(Node):
181
- pass
182
-
183
- node = TestNode()
184
- _node_set_key(node, 'attr', 10)
185
- self.assertEqual(node.attr, 10)
186
-
187
- _node_set_key(node, 'attr', 20)
188
- self.assertEqual(node.attr, 20)
189
-
190
- def test_node_set_key_with_state(self):
191
- """Test _node_set_key with State objects."""
192
-
193
- class TestNode(Node):
194
- def __init__(self):
195
- self.state_attr = State(10)
196
-
197
- node = TestNode()
198
-
199
- # Test setting a regular value
200
- _node_set_key(node, 'regular_attr', 30)
201
- self.assertEqual(node.regular_attr, 30)
202
-
203
- # Test updating with a TreefyState
204
- # We'll use the real TreefyState from graph operations
205
- graphdef, states = brainstate.graph.treefy_split(node)
206
-
207
- # The states should contain our State object wrapped as TreefyState
208
- # When setting with TreefyState, it should update the existing State
209
- initial_state = node.state_attr
210
-
211
- # Create a new node and try to set the TreefyState
212
- new_node = TestNode()
213
- for key, value in states.to_flat().items():
214
- if 'state_attr' in key:
215
- _node_set_key(new_node, 'state_attr', value)
216
- # The State object should be updated via update_from_ref
217
- self.assertIsInstance(new_node.state_attr, State)
218
-
219
- def test_node_set_key_invalid_key(self):
220
- """Test _node_set_key with invalid key."""
221
-
222
- class TestNode(Node):
223
- pass
224
-
225
- node = TestNode()
226
-
227
- with self.assertRaises(KeyError) as context:
228
- _node_set_key(node, 123, 'value')
229
- self.assertIn('Invalid key', str(context.exception))
230
-
231
- def test_node_pop_key(self):
232
- """Test _node_pop_key function."""
233
-
234
- class TestNode(Node):
235
- def __init__(self):
236
- self.attr1 = 10
237
- self.attr2 = 20
238
-
239
- node = TestNode()
240
-
241
- # Pop existing attribute
242
- value = _node_pop_key(node, 'attr1')
243
- self.assertEqual(value, 10)
244
- self.assertFalse(hasattr(node, 'attr1'))
245
- self.assertTrue(hasattr(node, 'attr2'))
246
-
247
- def test_node_pop_key_invalid(self):
248
- """Test _node_pop_key with invalid key."""
249
-
250
- class TestNode(Node):
251
- pass
252
-
253
- node = TestNode()
254
-
255
- # Invalid key type
256
- with self.assertRaises(KeyError) as context:
257
- _node_pop_key(node, 123)
258
- self.assertIn('Invalid key', str(context.exception))
259
-
260
- # Non-existent key
261
- with self.assertRaises(KeyError):
262
- _node_pop_key(node, 'nonexistent')
263
-
264
- def test_node_create_empty(self):
265
- """Test _node_create_empty function."""
266
-
267
- class TestNode(Node):
268
- def __init__(self, value=None):
269
- self.value = value
270
- self.initialized = True
271
-
272
- # Create empty node
273
- node = _node_create_empty((TestNode,))
274
-
275
- # Check it's the right type
276
- self.assertIsInstance(node, TestNode)
277
-
278
- # Check __init__ was not called
279
- self.assertFalse(hasattr(node, 'value'))
280
- self.assertFalse(hasattr(node, 'initialized'))
281
-
282
- def test_node_clear(self):
283
- """Test _node_clear function."""
284
-
285
- class TestNode(Node):
286
- def __init__(self):
287
- self.attr1 = 10
288
- self.attr2 = 20
289
- self.attr3 = [1, 2, 3]
290
-
291
- node = TestNode()
292
-
293
- # Verify attributes exist
294
- self.assertTrue(hasattr(node, 'attr1'))
295
- self.assertTrue(hasattr(node, 'attr2'))
296
- self.assertTrue(hasattr(node, 'attr3'))
297
-
298
- # Clear the node
299
- _node_clear(node)
300
-
301
- # Verify attributes are gone
302
- self.assertFalse(hasattr(node, 'attr1'))
303
- self.assertFalse(hasattr(node, 'attr2'))
304
- self.assertFalse(hasattr(node, 'attr3'))
305
-
306
- # Verify node still exists and is valid
307
- self.assertIsInstance(node, TestNode)
308
-
309
-
310
- class TestNodeIntegration(unittest.TestCase):
311
- """Integration tests for Node with the graph system."""
312
-
313
- def test_node_with_nested_nodes(self):
314
- """Test nodes containing other nodes."""
315
-
316
- class ChildNode(Node):
317
- def __init__(self, value=None):
318
- if value is not None:
319
- self.value = value
320
-
321
- class ParentNode(Node):
322
- def __init__(self):
323
- self.child1 = ChildNode(10)
324
- self.child2 = ChildNode(20)
325
- self.data = [1, 2, 3]
326
-
327
- parent = ParentNode()
328
-
329
- # Test using treefy_split and merge
330
- graphdef, state = brainstate.graph.treefy_split(parent)
331
- copied = brainstate.graph.treefy_merge(graphdef, state)
332
-
333
- self.assertIsNot(parent.child1, copied.child1)
334
- self.assertEqual(parent.child1.value, copied.child1.value)
335
-
336
- def test_node_with_list_of_nodes(self):
337
- """Test nodes containing lists of other nodes."""
338
-
339
- class ItemNode(Node):
340
- def __init__(self, id=None):
341
- if id is not None:
342
- self.id = id
343
-
344
- class ContainerNode(Node):
345
- def __init__(self):
346
- self.items = [ItemNode(i) for i in range(3)]
347
-
348
- container = ContainerNode()
349
- graphdef, state = brainstate.graph.treefy_split(container)
350
- copied = brainstate.graph.treefy_merge(graphdef, state)
351
-
352
- self.assertEqual(len(container.items), len(copied.items))
353
- for orig, cp in zip(container.items, copied.items):
354
- self.assertIsNot(orig, cp)
355
- self.assertEqual(orig.id, cp.id)
356
-
357
- def test_node_with_dict_of_nodes(self):
358
- """Test nodes containing dictionaries of other nodes."""
359
-
360
- class ValueNode(Node):
361
- def __init__(self, value=None):
362
- if value is not None:
363
- self.value = value
364
-
365
- class DictNode(Node):
366
- def __init__(self):
367
- self.mapping = {
368
- 'a': ValueNode(1),
369
- 'b': ValueNode(2),
370
- 'c': ValueNode(3)
371
- }
372
-
373
- node = DictNode()
374
- graphdef, state = brainstate.graph.treefy_split(node)
375
- copied = brainstate.graph.treefy_merge(graphdef, state)
376
-
377
- self.assertEqual(set(node.mapping.keys()), set(copied.mapping.keys()))
378
- for key in node.mapping:
379
- self.assertIsNot(node.mapping[key], copied.mapping[key])
380
- self.assertEqual(node.mapping[key].value, copied.mapping[key].value)
381
-
382
-
383
- class TestStateRetrieve(unittest.TestCase):
384
- """Tests for state retrieval from nodes."""
385
-
386
- def test_list_of_states_1(self):
387
- """Test retrieving states from a list."""
388
-
389
- class Model(brainstate.graph.Node):
390
- def __init__(self):
391
- self.a = [1, 2, 3]
392
- self.b = [brainstate.State(1), brainstate.State(2), brainstate.State(3)]
393
-
394
- m = Model()
395
- graphdef, states = brainstate.graph.treefy_split(m)
396
- print(states.to_flat())
397
- self.assertTrue(len(states.to_flat()) == 3)
398
-
399
- def test_list_of_states_2(self):
400
- """Test retrieving states from nested lists."""
401
-
402
- class Model(brainstate.graph.Node):
403
- def __init__(self):
404
- self.a = [1, 2, 3]
405
- self.b = [brainstate.State(1), [brainstate.State(2), brainstate.State(3)]]
406
-
407
- m = Model()
408
- graphdef, states = brainstate.graph.treefy_split(m)
409
- print(states.to_flat())
410
- self.assertTrue(len(states.to_flat()) == 3)
411
-
412
- def test_list_of_node_1(self):
413
- """Test retrieving states from a list of nodes."""
414
-
415
- class Model(brainstate.graph.Node):
416
- def __init__(self):
417
- self.a = [1, 2, 3]
418
- self.b = [brainstate.nn.Linear(1, 2), brainstate.nn.Linear(2, 3)]
419
-
420
- m = Model()
421
- graphdef, states = brainstate.graph.treefy_split(m)
422
- print(states.to_flat())
423
- self.assertTrue(len(states.to_flat()) == 2)
424
-
425
- def test_list_of_node_2(self):
426
- """Test retrieving states from nested structures of nodes."""
427
-
428
- class Model(brainstate.graph.Node):
429
- def __init__(self):
430
- self.a = [1, 2, 3]
431
- self.b = [brainstate.nn.Linear(1, 2), [brainstate.nn.Linear(2, 3)],
432
- (brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5))]
433
-
434
- m = Model()
435
- graphdef, states = brainstate.graph.treefy_split(m)
436
- print(states.to_flat())
437
- self.assertTrue(len(states.to_flat()) == 4)
438
-
439
- def test_mixed_states_and_nodes(self):
440
- """Test nodes with mixed states and sub-nodes."""
441
-
442
- class Model(brainstate.graph.Node):
443
- def __init__(self):
444
- self.state1 = brainstate.State(1.0)
445
- self.state2 = brainstate.State(2.0)
446
- self.linear = brainstate.nn.Linear(5, 10)
447
- self.data = [1, 2, 3]
448
-
449
- m = Model()
450
- graphdef, states = brainstate.graph.treefy_split(m)
451
-
452
- # Should have states from both direct State objects and Linear layer
453
- flat_states = states.to_flat()
454
- self.assertGreaterEqual(len(flat_states), 2) # At least the two direct states
455
-
456
-
457
- class TestEdgeCases(unittest.TestCase):
458
- """Test edge cases and error conditions."""
459
-
460
- def test_empty_node(self):
461
- """Test node with no attributes."""
462
-
463
- class EmptyNode(Node):
464
- pass
465
-
466
- node = EmptyNode()
467
- flattened, static = _node_flatten(node)
468
-
469
- self.assertEqual(len(flattened), 0)
470
- self.assertEqual(static, (EmptyNode,))
471
-
472
- def test_node_with_none_values(self):
473
- """Test node with None values."""
474
-
475
- class NoneNode(Node):
476
- def __init__(self):
477
- self.none_val = None
478
- self.real_val = 10
479
-
480
- node = NoneNode()
481
- flattened, static = _node_flatten(node)
482
-
483
- values_dict = dict(flattened)
484
- self.assertIsNone(values_dict['none_val'])
485
- self.assertEqual(values_dict['real_val'], 10)
486
-
487
- def test_node_with_special_attributes(self):
488
- """Test node with special Python attributes."""
489
-
490
- class SpecialNode(Node):
491
- def __init__(self):
492
- self.__dict__['special'] = 'value'
493
- self.normal = 'normal'
494
-
495
- node = SpecialNode()
496
- self.assertEqual(node.special, 'value')
497
- self.assertEqual(node.normal, 'normal')
498
-
499
- def test_circular_reference(self):
500
- """Test handling of circular references."""
501
-
502
- class CircularNode(Node):
503
- pass
504
-
505
- node1 = CircularNode()
506
- node2 = CircularNode()
507
- node1.ref = node2
508
- node2.ref = node1
509
-
510
- # This should not cause infinite recursion with treefy
511
- try:
512
- graphdef, state = brainstate.graph.treefy_split(node1)
513
- copied = brainstate.graph.treefy_merge(graphdef, state)
514
- # Check that circular reference is preserved
515
- self.assertIs(copied.ref.ref, copied)
516
- except RecursionError:
517
- self.fail("Treefy failed with circular reference")
518
-
519
- def test_node_inheritance(self):
520
- """Test node inheritance hierarchy."""
521
-
522
- class BaseNode(Node):
523
- def __init__(self):
524
- self.base_attr = 'base'
525
-
526
- class DerivedNode(BaseNode):
527
- def __init__(self):
528
- super().__init__()
529
- self.derived_attr = 'derived'
530
-
531
- node = DerivedNode()
532
- self.assertEqual(node.base_attr, 'base')
533
- self.assertEqual(node.derived_attr, 'derived')
534
-
535
- flattened, static = _node_flatten(node)
536
- keys = [k for k, v in flattened]
537
- self.assertIn('base_attr', keys)
538
- self.assertIn('derived_attr', keys)
539
-
540
- def test_node_with_property(self):
541
- """Test node with property decorators."""
542
-
543
- class PropertyNode(Node):
544
- def __init__(self):
545
- self._value = 10
546
-
547
- @property
548
- def value(self):
549
- return self._value
550
-
551
- @value.setter
552
- def value(self, val):
553
- self._value = val
554
-
555
- node = PropertyNode()
556
- self.assertEqual(node.value, 10)
557
-
558
- node.value = 20
559
- self.assertEqual(node.value, 20)
560
-
561
- # Only _value should appear in flattened
562
- flattened, static = _node_flatten(node)
563
- keys = [k for k, v in flattened]
564
- self.assertIn('_value', keys)
565
-
566
- def test_multiple_inheritance(self):
567
- """Test node with multiple inheritance."""
568
-
569
- class Mixin:
570
- def mixin_method(self):
571
- return 'mixin'
572
-
573
- class MultiNode(Node, Mixin):
574
- def __init__(self):
575
- self.data = 'data'
576
-
577
- node = MultiNode()
578
- self.assertEqual(node.mixin_method(), 'mixin')
579
- self.assertEqual(node.data, 'data')
580
-
581
- # Test that it still works as a Node with treefy
582
- graphdef, state = brainstate.graph.treefy_split(node)
583
- copied = brainstate.graph.treefy_merge(graphdef, state)
584
- self.assertIsNot(node, copied)
585
- self.assertEqual(copied.data, 'data')
586
-
587
-
588
- if __name__ == '__main__':
589
- unittest.main()
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import unittest
17
+
18
+ import brainstate
19
+ from brainstate._state import State
20
+ from brainstate.graph._node import (
21
+ Node,
22
+ _node_flatten,
23
+ _node_set_key,
24
+ _node_pop_key,
25
+ _node_create_empty,
26
+ _node_clear
27
+ )
28
+
29
+
30
+ class TestNode(unittest.TestCase):
31
+ """Test suite for the Node class."""
32
+
33
+ def test_node_creation(self):
34
+ """Test basic node creation."""
35
+
36
+ class SimpleNode(Node):
37
+ def __init__(self, value):
38
+ self.value = value
39
+
40
+ node = SimpleNode(10)
41
+ self.assertEqual(node.value, 10)
42
+ self.assertIsInstance(node, Node)
43
+
44
+ def test_node_subclass_registration(self):
45
+ """Test that Node subclasses are automatically registered."""
46
+
47
+ class TestNode(Node):
48
+ pass
49
+
50
+ # The subclass should be registered automatically via __init_subclass__
51
+ node = TestNode()
52
+ self.assertIsInstance(node, Node)
53
+
54
+ def test_graph_invisible_attrs(self):
55
+ """Test that graph_invisible_attrs works correctly."""
56
+
57
+ class NodeWithInvisible(Node):
58
+ graph_invisible_attrs = ('_private', '_internal')
59
+
60
+ def __init__(self):
61
+ self.public = 1
62
+ self._private = 2
63
+ self._internal = 3
64
+
65
+ node = NodeWithInvisible()
66
+ flattened, static = _node_flatten(node)
67
+
68
+ # Check that only public attribute is in flattened
69
+ keys = [k for k, v in flattened]
70
+ self.assertIn('public', keys)
71
+ self.assertNotIn('_private', keys)
72
+ self.assertNotIn('_internal', keys)
73
+
74
+ def test_deepcopy_using_treefy(self):
75
+ """Test deep copying of nodes using treefy_split/merge."""
76
+
77
+ class NodeWithData(Node):
78
+ def __init__(self, data=None):
79
+ if data is not None:
80
+ self.data = data
81
+ self.nested = {'a': 1, 'b': [2, 3]}
82
+
83
+ original = NodeWithData([1, 2, 3])
84
+
85
+ # Use treefy_split and treefy_merge to copy
86
+ graphdef, state = brainstate.graph.treefy_split(original)
87
+ # Create a new instance using treefy_merge
88
+ copied = brainstate.graph.treefy_merge(graphdef, state)
89
+
90
+ # Check that it's a different object
91
+ self.assertIsNot(original, copied)
92
+
93
+ # Check that data is present
94
+ self.assertEqual(original.data, copied.data)
95
+
96
+ # Modify copied data shouldn't affect original
97
+ copied.data.append(4)
98
+ self.assertEqual(len(original.data), 3)
99
+ self.assertEqual(len(copied.data), 4)
100
+
101
+ def test_node_with_state(self):
102
+ """Test nodes containing State objects."""
103
+
104
+ class NodeWithState(Node):
105
+ def __init__(self):
106
+ self.value = State(10)
107
+ self.normal = 20
108
+
109
+ node = NodeWithState()
110
+ self.assertIsInstance(node.value, State)
111
+ self.assertEqual(node.value.value, 10)
112
+ self.assertEqual(node.normal, 20)
113
+
114
+ def test_complex_nested_structure(self):
115
+ """Test nodes with complex nested structures using treefy."""
116
+
117
+ class ComplexNode(Node):
118
+ def __init__(self):
119
+ self.list_data = [1, 2, [3, 4]]
120
+ self.dict_data = {'a': 1, 'b': {'c': 2}}
121
+ self.tuple_data = (1, 2, (3, 4))
122
+
123
+ node = ComplexNode()
124
+
125
+ # Test using treefy_split and merge
126
+ graphdef, state = brainstate.graph.treefy_split(node)
127
+ copied = brainstate.graph.treefy_merge(graphdef, state)
128
+
129
+ self.assertEqual(node.list_data, copied.list_data)
130
+ self.assertEqual(node.dict_data, copied.dict_data)
131
+ self.assertEqual(node.tuple_data, copied.tuple_data)
132
+
133
+
134
+ class TestNodeHelperFunctions(unittest.TestCase):
135
+ """Test suite for node helper functions."""
136
+
137
+ def test_node_flatten(self):
138
+ """Test _node_flatten function."""
139
+
140
+ class TestNode(Node):
141
+ def __init__(self):
142
+ self.b = 2
143
+ self.a = 1
144
+ self.c = 3
145
+
146
+ node = TestNode()
147
+ flattened, static = _node_flatten(node)
148
+
149
+ # Check that attributes are sorted
150
+ keys = [k for k, v in flattened]
151
+ self.assertEqual(keys, ['a', 'b', 'c'])
152
+
153
+ # Check values
154
+ values = [v for k, v in flattened]
155
+ self.assertEqual(values, [1, 2, 3])
156
+
157
+ # Check static contains type
158
+ self.assertEqual(static, (TestNode,))
159
+
160
+ def test_node_flatten_with_invisible(self):
161
+ """Test _node_flatten with invisible attributes."""
162
+
163
+ class TestNode(Node):
164
+ graph_invisible_attrs = ('hidden',)
165
+
166
+ def __init__(self):
167
+ self.visible = 1
168
+ self.hidden = 2
169
+
170
+ node = TestNode()
171
+ flattened, static = _node_flatten(node)
172
+
173
+ keys = [k for k, v in flattened]
174
+ self.assertIn('visible', keys)
175
+ self.assertNotIn('hidden', keys)
176
+
177
+ def test_node_set_key_simple(self):
178
+ """Test _node_set_key with simple values."""
179
+
180
+ class TestNode(Node):
181
+ pass
182
+
183
+ node = TestNode()
184
+ _node_set_key(node, 'attr', 10)
185
+ self.assertEqual(node.attr, 10)
186
+
187
+ _node_set_key(node, 'attr', 20)
188
+ self.assertEqual(node.attr, 20)
189
+
190
+ def test_node_set_key_with_state(self):
191
+ """Test _node_set_key with State objects."""
192
+
193
+ class TestNode(Node):
194
+ def __init__(self):
195
+ self.state_attr = State(10)
196
+
197
+ node = TestNode()
198
+
199
+ # Test setting a regular value
200
+ _node_set_key(node, 'regular_attr', 30)
201
+ self.assertEqual(node.regular_attr, 30)
202
+
203
+ # Test updating with a TreefyState
204
+ # We'll use the real TreefyState from graph operations
205
+ graphdef, states = brainstate.graph.treefy_split(node)
206
+
207
+ # The states should contain our State object wrapped as TreefyState
208
+ # When setting with TreefyState, it should update the existing State
209
+ initial_state = node.state_attr
210
+
211
+ # Create a new node and try to set the TreefyState
212
+ new_node = TestNode()
213
+ for key, value in states.to_flat().items():
214
+ if 'state_attr' in key:
215
+ _node_set_key(new_node, 'state_attr', value)
216
+ # The State object should be updated via update_from_ref
217
+ self.assertIsInstance(new_node.state_attr, State)
218
+
219
+ def test_node_set_key_invalid_key(self):
220
+ """Test _node_set_key with invalid key."""
221
+
222
+ class TestNode(Node):
223
+ pass
224
+
225
+ node = TestNode()
226
+
227
+ with self.assertRaises(KeyError) as context:
228
+ _node_set_key(node, 123, 'value')
229
+ self.assertIn('Invalid key', str(context.exception))
230
+
231
+ def test_node_pop_key(self):
232
+ """Test _node_pop_key function."""
233
+
234
+ class TestNode(Node):
235
+ def __init__(self):
236
+ self.attr1 = 10
237
+ self.attr2 = 20
238
+
239
+ node = TestNode()
240
+
241
+ # Pop existing attribute
242
+ value = _node_pop_key(node, 'attr1')
243
+ self.assertEqual(value, 10)
244
+ self.assertFalse(hasattr(node, 'attr1'))
245
+ self.assertTrue(hasattr(node, 'attr2'))
246
+
247
+ def test_node_pop_key_invalid(self):
248
+ """Test _node_pop_key with invalid key."""
249
+
250
+ class TestNode(Node):
251
+ pass
252
+
253
+ node = TestNode()
254
+
255
+ # Invalid key type
256
+ with self.assertRaises(KeyError) as context:
257
+ _node_pop_key(node, 123)
258
+ self.assertIn('Invalid key', str(context.exception))
259
+
260
+ # Non-existent key
261
+ with self.assertRaises(KeyError):
262
+ _node_pop_key(node, 'nonexistent')
263
+
264
def test_node_create_empty(self):
    """Check that _node_create_empty builds the type without running __init__."""

    class InitTrackingNode(Node):
        def __init__(self, value=None):
            self.value = value
            self.initialized = True

    # The factory is given a one-element tuple holding the node type.
    blank = _node_create_empty((InitTrackingNode,))

    self.assertIsInstance(blank, InitTrackingNode)

    # Neither attribute assigned in __init__ may exist on the new object,
    # which shows that the constructor was bypassed.
    for attr_name in ('value', 'initialized'):
        self.assertFalse(hasattr(blank, attr_name))
281
+
282
def test_node_clear(self):
    """Check that _node_clear strips the attributes set in __init__."""

    class ThreeAttrNode(Node):
        def __init__(self):
            self.attr1 = 10
            self.attr2 = 20
            self.attr3 = [1, 2, 3]

    target = ThreeAttrNode()
    attr_names = ('attr1', 'attr2', 'attr3')

    # Precondition: all three attributes are present before clearing.
    for attr_name in attr_names:
        self.assertTrue(hasattr(target, attr_name))

    _node_clear(target)

    # Postcondition: none of them is reachable any more.
    for attr_name in attr_names:
        self.assertFalse(hasattr(target, attr_name))

    # The object itself is still an instance of its class.
    self.assertIsInstance(target, ThreeAttrNode)
308
+
309
+
310
class TestNodeIntegration(unittest.TestCase):
    """Round-trip nodes holding other nodes through treefy_split / treefy_merge."""

    @staticmethod
    def _split_then_merge(node):
        """Return the object obtained by merging the graphdef and state split from *node*."""
        graphdef, state = brainstate.graph.treefy_split(node)
        return brainstate.graph.treefy_merge(graphdef, state)

    def test_node_with_nested_nodes(self):
        """Child nodes stored as attributes are rebuilt as distinct objects with equal values."""

        class ChildNode(Node):
            def __init__(self, value=None):
                if value is not None:
                    self.value = value

        class ParentNode(Node):
            def __init__(self):
                self.child1 = ChildNode(10)
                self.child2 = ChildNode(20)
                self.data = [1, 2, 3]

        parent = ParentNode()
        rebuilt = self._split_then_merge(parent)

        # The merged child is a new object carrying the same value.
        self.assertIsNot(parent.child1, rebuilt.child1)
        self.assertEqual(parent.child1.value, rebuilt.child1.value)

    def test_node_with_list_of_nodes(self):
        """Nodes stored in a list attribute are rebuilt element by element."""

        class ItemNode(Node):
            def __init__(self, item_id=None):
                if item_id is not None:
                    self.id = item_id

        class ContainerNode(Node):
            def __init__(self):
                self.items = [ItemNode(i) for i in range(3)]

        container = ContainerNode()
        rebuilt = self._split_then_merge(container)

        self.assertEqual(len(container.items), len(rebuilt.items))
        for original_item, rebuilt_item in zip(container.items, rebuilt.items):
            self.assertIsNot(original_item, rebuilt_item)
            self.assertEqual(original_item.id, rebuilt_item.id)

    def test_node_with_dict_of_nodes(self):
        """Nodes stored in a dict attribute keep their keys and values after a round trip."""

        class ValueNode(Node):
            def __init__(self, value=None):
                if value is not None:
                    self.value = value

        class DictNode(Node):
            def __init__(self):
                self.mapping = {
                    'a': ValueNode(1),
                    'b': ValueNode(2),
                    'c': ValueNode(3)
                }

        node = DictNode()
        rebuilt = self._split_then_merge(node)

        original_keys = set(node.mapping.keys())
        rebuilt_keys = set(rebuilt.mapping.keys())
        self.assertEqual(original_keys, rebuilt_keys)

        for key in node.mapping:
            original_child = node.mapping[key]
            rebuilt_child = rebuilt.mapping[key]
            self.assertIsNot(original_child, rebuilt_child)
            self.assertEqual(original_child.value, rebuilt_child.value)
381
+
382
+
383
class TestStateRetrieve(unittest.TestCase):
    """Check how many flat state entries treefy_split collects from a node.

    Each test builds a small model, splits it with
    ``brainstate.graph.treefy_split`` and inspects ``states.to_flat()``.
    Counts are compared with ``assertEqual`` so that a failure reports the
    actual number of entries instead of a bare ``False is not true``.
    """

    def test_list_of_states_1(self):
        """A flat list of three State objects is expected to give three entries."""

        class Model(brainstate.graph.Node):
            def __init__(self):
                self.a = [1, 2, 3]
                self.b = [brainstate.State(1), brainstate.State(2), brainstate.State(3)]

        m = Model()
        graphdef, states = brainstate.graph.treefy_split(m)
        self.assertEqual(len(states.to_flat()), 3)

    def test_list_of_states_2(self):
        """State objects spread over nested lists are still expected to give three entries."""

        class Model(brainstate.graph.Node):
            def __init__(self):
                self.a = [1, 2, 3]
                self.b = [brainstate.State(1), [brainstate.State(2), brainstate.State(3)]]

        m = Model()
        graphdef, states = brainstate.graph.treefy_split(m)
        self.assertEqual(len(states.to_flat()), 3)

    def test_list_of_node_1(self):
        """A list of two Linear layers is expected to give two flat entries."""

        class Model(brainstate.graph.Node):
            def __init__(self):
                self.a = [1, 2, 3]
                self.b = [brainstate.nn.Linear(1, 2), brainstate.nn.Linear(2, 3)]

        m = Model()
        graphdef, states = brainstate.graph.treefy_split(m)
        self.assertEqual(len(states.to_flat()), 2)

    def test_list_of_node_2(self):
        """Four Linear layers in nested list/tuple containers are expected to give four entries."""

        class Model(brainstate.graph.Node):
            def __init__(self):
                self.a = [1, 2, 3]
                self.b = [brainstate.nn.Linear(1, 2), [brainstate.nn.Linear(2, 3)],
                          (brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5))]

        m = Model()
        graphdef, states = brainstate.graph.treefy_split(m)
        self.assertEqual(len(states.to_flat()), 4)

    def test_mixed_states_and_nodes(self):
        """A node holding direct State attributes and a Linear sub-node yields at least two entries."""

        class Model(brainstate.graph.Node):
            def __init__(self):
                self.state1 = brainstate.State(1.0)
                self.state2 = brainstate.State(2.0)
                self.linear = brainstate.nn.Linear(5, 10)
                self.data = [1, 2, 3]

        m = Model()
        graphdef, states = brainstate.graph.treefy_split(m)

        # Should have states from both direct State objects and Linear layer
        flat_states = states.to_flat()
        self.assertGreaterEqual(len(flat_states), 2)  # At least the two direct states
455
+
456
+
457
class TestEdgeCases(unittest.TestCase):
    """Boundary scenarios for Node flattening and treefy round trips."""

    def test_empty_node(self):
        """A node without attributes flattens to zero entries plus its type tuple."""

        class EmptyNode(Node):
            pass

        blank = EmptyNode()
        children, static = _node_flatten(blank)

        self.assertEqual(len(children), 0)
        self.assertEqual(static, (EmptyNode,))

    def test_node_with_none_values(self):
        """An attribute holding None is kept in the flattened output."""

        class NoneNode(Node):
            def __init__(self):
                self.none_val = None
                self.real_val = 10

        subject = NoneNode()
        children, static = _node_flatten(subject)

        # The flattened output is a sequence of (name, value) pairs.
        by_name = dict(children)
        self.assertIsNone(by_name['none_val'])
        self.assertEqual(by_name['real_val'], 10)

    def test_node_with_special_attributes(self):
        """Attributes written straight into the instance dict read back normally."""

        class SpecialNode(Node):
            def __init__(self):
                # Written directly into the instance dict instead of via setattr.
                vars(self)['special'] = 'value'
                self.normal = 'normal'

        subject = SpecialNode()
        self.assertEqual(subject.special, 'value')
        self.assertEqual(subject.normal, 'normal')

    def test_circular_reference(self):
        """Two nodes referencing each other survive a split/merge round trip."""

        class CircularNode(Node):
            pass

        first = CircularNode()
        second = CircularNode()
        first.ref = second
        second.ref = first

        # A RecursionError anywhere in the round trip is reported as a
        # test failure rather than an error.
        try:
            graphdef, state = brainstate.graph.treefy_split(first)
            rebuilt = brainstate.graph.treefy_merge(graphdef, state)
            # Following the cycle twice must lead back to the rebuilt root.
            self.assertIs(rebuilt.ref.ref, rebuilt)
        except RecursionError:
            self.fail("Treefy failed with circular reference")

    def test_node_inheritance(self):
        """Attributes set by a base class and by a subclass both get flattened."""

        class BaseNode(Node):
            def __init__(self):
                self.base_attr = 'base'

        class DerivedNode(BaseNode):
            def __init__(self):
                super().__init__()
                self.derived_attr = 'derived'

        subject = DerivedNode()
        self.assertEqual(subject.base_attr, 'base')
        self.assertEqual(subject.derived_attr, 'derived')

        children, static = _node_flatten(subject)
        names = [name for name, _ in children]
        for expected in ('base_attr', 'derived_attr'):
            self.assertIn(expected, names)

    def test_node_with_property(self):
        """A property with a setter works on a Node and its backing field is flattened."""

        class PropertyNode(Node):
            def __init__(self):
                self._value = 10

            @property
            def value(self):
                return self._value

            @value.setter
            def value(self, val):
                self._value = val

        subject = PropertyNode()
        self.assertEqual(subject.value, 10)

        subject.value = 20
        self.assertEqual(subject.value, 20)

        # The backing attribute must be among the flattened names.
        children, static = _node_flatten(subject)
        names = [name for name, _ in children]
        self.assertIn('_value', names)

    def test_multiple_inheritance(self):
        """A Node that also inherits from a plain mixin still round-trips."""

        class Mixin:
            def mixin_method(self):
                return 'mixin'

        class MultiNode(Node, Mixin):
            def __init__(self):
                self.data = 'data'

        subject = MultiNode()
        self.assertEqual(subject.mixin_method(), 'mixin')
        self.assertEqual(subject.data, 'data')

        # The round trip yields a different object with the same data.
        graphdef, state = brainstate.graph.treefy_split(subject)
        rebuilt = brainstate.graph.treefy_merge(graphdef, state)
        self.assertIsNot(subject, rebuilt)
        self.assertEqual(rebuilt.data, 'data')
586
+
587
+
588
if __name__ == '__main__':
    # Run the test cases of this module when it is executed as a script.
    unittest.main()