brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +167 -169
- brainstate/_compatible_import.py +340 -340
- brainstate/_compatible_import_test.py +681 -681
- brainstate/_deprecation.py +210 -210
- brainstate/_deprecation_test.py +2297 -2319
- brainstate/_error.py +45 -45
- brainstate/_state.py +2157 -1652
- brainstate/_state_test.py +1129 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -1495
- brainstate/environ_test.py +1223 -1223
- brainstate/graph/__init__.py +22 -22
- brainstate/graph/_node.py +240 -240
- brainstate/graph/_node_test.py +589 -589
- brainstate/graph/_operation.py +1620 -1624
- brainstate/graph/_operation_test.py +1147 -1147
- brainstate/mixin.py +1447 -1433
- brainstate/mixin_test.py +1017 -1017
- brainstate/nn/__init__.py +146 -137
- brainstate/nn/_activations.py +1100 -1100
- brainstate/nn/_activations_test.py +354 -354
- brainstate/nn/_collective_ops.py +635 -633
- brainstate/nn/_collective_ops_test.py +774 -774
- brainstate/nn/_common.py +226 -226
- brainstate/nn/_common_test.py +134 -154
- brainstate/nn/_conv.py +2010 -2010
- brainstate/nn/_conv_test.py +849 -849
- brainstate/nn/_delay.py +575 -575
- brainstate/nn/_delay_test.py +243 -243
- brainstate/nn/_dropout.py +618 -618
- brainstate/nn/_dropout_test.py +480 -477
- brainstate/nn/_dynamics.py +870 -1267
- brainstate/nn/_dynamics_test.py +53 -67
- brainstate/nn/_elementwise.py +1298 -1298
- brainstate/nn/_elementwise_test.py +829 -829
- brainstate/nn/_embedding.py +408 -408
- brainstate/nn/_embedding_test.py +156 -156
- brainstate/nn/_event_fixedprob.py +233 -233
- brainstate/nn/_event_fixedprob_test.py +115 -115
- brainstate/nn/_event_linear.py +83 -83
- brainstate/nn/_event_linear_test.py +121 -121
- brainstate/nn/_exp_euler.py +254 -254
- brainstate/nn/_exp_euler_test.py +377 -377
- brainstate/nn/_linear.py +744 -744
- brainstate/nn/_linear_test.py +475 -475
- brainstate/nn/_metrics.py +1070 -1070
- brainstate/nn/_metrics_test.py +611 -611
- brainstate/nn/_module.py +391 -384
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -1334
- brainstate/nn/_normalizations_test.py +699 -699
- brainstate/nn/_paddings.py +1020 -1020
- brainstate/nn/_paddings_test.py +722 -722
- brainstate/nn/_poolings.py +2239 -2239
- brainstate/nn/_poolings_test.py +952 -952
- brainstate/nn/_rnns.py +946 -946
- brainstate/nn/_rnns_test.py +592 -592
- brainstate/nn/_utils.py +216 -216
- brainstate/nn/_utils_test.py +401 -401
- brainstate/nn/init.py +809 -809
- brainstate/nn/init_test.py +180 -180
- brainstate/random/__init__.py +270 -270
- brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
- brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
- brainstate/random/_impl.py +672 -0
- brainstate/random/{_rand_seed.py → _seed.py} +675 -675
- brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
- brainstate/random/{_rand_state.py → _state.py} +1320 -1617
- brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
- brainstate/transform/__init__.py +56 -59
- brainstate/transform/_ad_checkpoint.py +176 -176
- brainstate/transform/_ad_checkpoint_test.py +49 -49
- brainstate/transform/_autograd.py +1025 -1025
- brainstate/transform/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -316
- brainstate/transform/_conditions_test.py +220 -220
- brainstate/transform/_error_if.py +94 -94
- brainstate/transform/_error_if_test.py +52 -52
- brainstate/transform/_find_state.py +200 -0
- brainstate/transform/_find_state_test.py +84 -0
- brainstate/transform/_jit.py +399 -399
- brainstate/transform/_jit_test.py +143 -143
- brainstate/transform/_loop_collect_return.py +675 -675
- brainstate/transform/_loop_collect_return_test.py +58 -58
- brainstate/transform/_loop_no_collection.py +283 -283
- brainstate/transform/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2176 -2016
- brainstate/transform/_make_jaxpr_test.py +1634 -1510
- brainstate/transform/_mapping.py +607 -529
- brainstate/transform/_mapping_test.py +104 -194
- brainstate/transform/_progress_bar.py +255 -255
- brainstate/transform/_unvmap.py +256 -256
- brainstate/transform/_util.py +286 -286
- brainstate/typing.py +837 -837
- brainstate/typing_test.py +780 -780
- brainstate/util/__init__.py +27 -27
- brainstate/util/_others.py +1024 -1024
- brainstate/util/_others_test.py +962 -962
- brainstate/util/_pretty_pytree.py +1301 -1301
- brainstate/util/_pretty_pytree_test.py +675 -675
- brainstate/util/_pretty_repr.py +462 -462
- brainstate/util/_pretty_repr_test.py +696 -696
- brainstate/util/filter.py +945 -945
- brainstate/util/filter_test.py +911 -911
- brainstate/util/struct.py +910 -910
- brainstate/util/struct_test.py +602 -602
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
- brainstate-0.2.2.dist-info/RECORD +111 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
- brainstate/transform/_eval_shape.py +0 -145
- brainstate/transform/_eval_shape_test.py +0 -38
- brainstate/transform/_random.py +0 -171
- brainstate-0.2.1.dist-info/RECORD +0 -111
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/graph/_node_test.py
CHANGED
@@ -1,589 +1,589 @@
|
|
1
|
-
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
import unittest
|
17
|
-
|
18
|
-
import brainstate
|
19
|
-
from brainstate._state import State
|
20
|
-
from brainstate.graph._node import (
|
21
|
-
Node,
|
22
|
-
_node_flatten,
|
23
|
-
_node_set_key,
|
24
|
-
_node_pop_key,
|
25
|
-
_node_create_empty,
|
26
|
-
_node_clear
|
27
|
-
)
|
28
|
-
|
29
|
-
|
30
|
-
class TestNode(unittest.TestCase):
|
31
|
-
"""Test suite for the Node class."""
|
32
|
-
|
33
|
-
def test_node_creation(self):
|
34
|
-
"""Test basic node creation."""
|
35
|
-
|
36
|
-
class SimpleNode(Node):
|
37
|
-
def __init__(self, value):
|
38
|
-
self.value = value
|
39
|
-
|
40
|
-
node = SimpleNode(10)
|
41
|
-
self.assertEqual(node.value, 10)
|
42
|
-
self.assertIsInstance(node, Node)
|
43
|
-
|
44
|
-
def test_node_subclass_registration(self):
|
45
|
-
"""Test that Node subclasses are automatically registered."""
|
46
|
-
|
47
|
-
class TestNode(Node):
|
48
|
-
pass
|
49
|
-
|
50
|
-
# The subclass should be registered automatically via __init_subclass__
|
51
|
-
node = TestNode()
|
52
|
-
self.assertIsInstance(node, Node)
|
53
|
-
|
54
|
-
def test_graph_invisible_attrs(self):
|
55
|
-
"""Test that graph_invisible_attrs works correctly."""
|
56
|
-
|
57
|
-
class NodeWithInvisible(Node):
|
58
|
-
graph_invisible_attrs = ('_private', '_internal')
|
59
|
-
|
60
|
-
def __init__(self):
|
61
|
-
self.public = 1
|
62
|
-
self._private = 2
|
63
|
-
self._internal = 3
|
64
|
-
|
65
|
-
node = NodeWithInvisible()
|
66
|
-
flattened, static = _node_flatten(node)
|
67
|
-
|
68
|
-
# Check that only public attribute is in flattened
|
69
|
-
keys = [k for k, v in flattened]
|
70
|
-
self.assertIn('public', keys)
|
71
|
-
self.assertNotIn('_private', keys)
|
72
|
-
self.assertNotIn('_internal', keys)
|
73
|
-
|
74
|
-
def test_deepcopy_using_treefy(self):
|
75
|
-
"""Test deep copying of nodes using treefy_split/merge."""
|
76
|
-
|
77
|
-
class NodeWithData(Node):
|
78
|
-
def __init__(self, data=None):
|
79
|
-
if data is not None:
|
80
|
-
self.data = data
|
81
|
-
self.nested = {'a': 1, 'b': [2, 3]}
|
82
|
-
|
83
|
-
original = NodeWithData([1, 2, 3])
|
84
|
-
|
85
|
-
# Use treefy_split and treefy_merge to copy
|
86
|
-
graphdef, state = brainstate.graph.treefy_split(original)
|
87
|
-
# Create a new instance using treefy_merge
|
88
|
-
copied = brainstate.graph.treefy_merge(graphdef, state)
|
89
|
-
|
90
|
-
# Check that it's a different object
|
91
|
-
self.assertIsNot(original, copied)
|
92
|
-
|
93
|
-
# Check that data is present
|
94
|
-
self.assertEqual(original.data, copied.data)
|
95
|
-
|
96
|
-
# Modify copied data shouldn't affect original
|
97
|
-
copied.data.append(4)
|
98
|
-
self.assertEqual(len(original.data), 3)
|
99
|
-
self.assertEqual(len(copied.data), 4)
|
100
|
-
|
101
|
-
def test_node_with_state(self):
|
102
|
-
"""Test nodes containing State objects."""
|
103
|
-
|
104
|
-
class NodeWithState(Node):
|
105
|
-
def __init__(self):
|
106
|
-
self.value = State(10)
|
107
|
-
self.normal = 20
|
108
|
-
|
109
|
-
node = NodeWithState()
|
110
|
-
self.assertIsInstance(node.value, State)
|
111
|
-
self.assertEqual(node.value.value, 10)
|
112
|
-
self.assertEqual(node.normal, 20)
|
113
|
-
|
114
|
-
def test_complex_nested_structure(self):
|
115
|
-
"""Test nodes with complex nested structures using treefy."""
|
116
|
-
|
117
|
-
class ComplexNode(Node):
|
118
|
-
def __init__(self):
|
119
|
-
self.list_data = [1, 2, [3, 4]]
|
120
|
-
self.dict_data = {'a': 1, 'b': {'c': 2}}
|
121
|
-
self.tuple_data = (1, 2, (3, 4))
|
122
|
-
|
123
|
-
node = ComplexNode()
|
124
|
-
|
125
|
-
# Test using treefy_split and merge
|
126
|
-
graphdef, state = brainstate.graph.treefy_split(node)
|
127
|
-
copied = brainstate.graph.treefy_merge(graphdef, state)
|
128
|
-
|
129
|
-
self.assertEqual(node.list_data, copied.list_data)
|
130
|
-
self.assertEqual(node.dict_data, copied.dict_data)
|
131
|
-
self.assertEqual(node.tuple_data, copied.tuple_data)
|
132
|
-
|
133
|
-
|
134
|
-
class TestNodeHelperFunctions(unittest.TestCase):
|
135
|
-
"""Test suite for node helper functions."""
|
136
|
-
|
137
|
-
def test_node_flatten(self):
|
138
|
-
"""Test _node_flatten function."""
|
139
|
-
|
140
|
-
class TestNode(Node):
|
141
|
-
def __init__(self):
|
142
|
-
self.b = 2
|
143
|
-
self.a = 1
|
144
|
-
self.c = 3
|
145
|
-
|
146
|
-
node = TestNode()
|
147
|
-
flattened, static = _node_flatten(node)
|
148
|
-
|
149
|
-
# Check that attributes are sorted
|
150
|
-
keys = [k for k, v in flattened]
|
151
|
-
self.assertEqual(keys, ['a', 'b', 'c'])
|
152
|
-
|
153
|
-
# Check values
|
154
|
-
values = [v for k, v in flattened]
|
155
|
-
self.assertEqual(values, [1, 2, 3])
|
156
|
-
|
157
|
-
# Check static contains type
|
158
|
-
self.assertEqual(static, (TestNode,))
|
159
|
-
|
160
|
-
def test_node_flatten_with_invisible(self):
|
161
|
-
"""Test _node_flatten with invisible attributes."""
|
162
|
-
|
163
|
-
class TestNode(Node):
|
164
|
-
graph_invisible_attrs = ('hidden',)
|
165
|
-
|
166
|
-
def __init__(self):
|
167
|
-
self.visible = 1
|
168
|
-
self.hidden = 2
|
169
|
-
|
170
|
-
node = TestNode()
|
171
|
-
flattened, static = _node_flatten(node)
|
172
|
-
|
173
|
-
keys = [k for k, v in flattened]
|
174
|
-
self.assertIn('visible', keys)
|
175
|
-
self.assertNotIn('hidden', keys)
|
176
|
-
|
177
|
-
def test_node_set_key_simple(self):
|
178
|
-
"""Test _node_set_key with simple values."""
|
179
|
-
|
180
|
-
class TestNode(Node):
|
181
|
-
pass
|
182
|
-
|
183
|
-
node = TestNode()
|
184
|
-
_node_set_key(node, 'attr', 10)
|
185
|
-
self.assertEqual(node.attr, 10)
|
186
|
-
|
187
|
-
_node_set_key(node, 'attr', 20)
|
188
|
-
self.assertEqual(node.attr, 20)
|
189
|
-
|
190
|
-
def test_node_set_key_with_state(self):
|
191
|
-
"""Test _node_set_key with State objects."""
|
192
|
-
|
193
|
-
class TestNode(Node):
|
194
|
-
def __init__(self):
|
195
|
-
self.state_attr = State(10)
|
196
|
-
|
197
|
-
node = TestNode()
|
198
|
-
|
199
|
-
# Test setting a regular value
|
200
|
-
_node_set_key(node, 'regular_attr', 30)
|
201
|
-
self.assertEqual(node.regular_attr, 30)
|
202
|
-
|
203
|
-
# Test updating with a TreefyState
|
204
|
-
# We'll use the real TreefyState from graph operations
|
205
|
-
graphdef, states = brainstate.graph.treefy_split(node)
|
206
|
-
|
207
|
-
# The states should contain our State object wrapped as TreefyState
|
208
|
-
# When setting with TreefyState, it should update the existing State
|
209
|
-
initial_state = node.state_attr
|
210
|
-
|
211
|
-
# Create a new node and try to set the TreefyState
|
212
|
-
new_node = TestNode()
|
213
|
-
for key, value in states.to_flat().items():
|
214
|
-
if 'state_attr' in key:
|
215
|
-
_node_set_key(new_node, 'state_attr', value)
|
216
|
-
# The State object should be updated via update_from_ref
|
217
|
-
self.assertIsInstance(new_node.state_attr, State)
|
218
|
-
|
219
|
-
def test_node_set_key_invalid_key(self):
|
220
|
-
"""Test _node_set_key with invalid key."""
|
221
|
-
|
222
|
-
class TestNode(Node):
|
223
|
-
pass
|
224
|
-
|
225
|
-
node = TestNode()
|
226
|
-
|
227
|
-
with self.assertRaises(KeyError) as context:
|
228
|
-
_node_set_key(node, 123, 'value')
|
229
|
-
self.assertIn('Invalid key', str(context.exception))
|
230
|
-
|
231
|
-
def test_node_pop_key(self):
|
232
|
-
"""Test _node_pop_key function."""
|
233
|
-
|
234
|
-
class TestNode(Node):
|
235
|
-
def __init__(self):
|
236
|
-
self.attr1 = 10
|
237
|
-
self.attr2 = 20
|
238
|
-
|
239
|
-
node = TestNode()
|
240
|
-
|
241
|
-
# Pop existing attribute
|
242
|
-
value = _node_pop_key(node, 'attr1')
|
243
|
-
self.assertEqual(value, 10)
|
244
|
-
self.assertFalse(hasattr(node, 'attr1'))
|
245
|
-
self.assertTrue(hasattr(node, 'attr2'))
|
246
|
-
|
247
|
-
def test_node_pop_key_invalid(self):
|
248
|
-
"""Test _node_pop_key with invalid key."""
|
249
|
-
|
250
|
-
class TestNode(Node):
|
251
|
-
pass
|
252
|
-
|
253
|
-
node = TestNode()
|
254
|
-
|
255
|
-
# Invalid key type
|
256
|
-
with self.assertRaises(KeyError) as context:
|
257
|
-
_node_pop_key(node, 123)
|
258
|
-
self.assertIn('Invalid key', str(context.exception))
|
259
|
-
|
260
|
-
# Non-existent key
|
261
|
-
with self.assertRaises(KeyError):
|
262
|
-
_node_pop_key(node, 'nonexistent')
|
263
|
-
|
264
|
-
def test_node_create_empty(self):
|
265
|
-
"""Test _node_create_empty function."""
|
266
|
-
|
267
|
-
class TestNode(Node):
|
268
|
-
def __init__(self, value=None):
|
269
|
-
self.value = value
|
270
|
-
self.initialized = True
|
271
|
-
|
272
|
-
# Create empty node
|
273
|
-
node = _node_create_empty((TestNode,))
|
274
|
-
|
275
|
-
# Check it's the right type
|
276
|
-
self.assertIsInstance(node, TestNode)
|
277
|
-
|
278
|
-
# Check __init__ was not called
|
279
|
-
self.assertFalse(hasattr(node, 'value'))
|
280
|
-
self.assertFalse(hasattr(node, 'initialized'))
|
281
|
-
|
282
|
-
def test_node_clear(self):
|
283
|
-
"""Test _node_clear function."""
|
284
|
-
|
285
|
-
class TestNode(Node):
|
286
|
-
def __init__(self):
|
287
|
-
self.attr1 = 10
|
288
|
-
self.attr2 = 20
|
289
|
-
self.attr3 = [1, 2, 3]
|
290
|
-
|
291
|
-
node = TestNode()
|
292
|
-
|
293
|
-
# Verify attributes exist
|
294
|
-
self.assertTrue(hasattr(node, 'attr1'))
|
295
|
-
self.assertTrue(hasattr(node, 'attr2'))
|
296
|
-
self.assertTrue(hasattr(node, 'attr3'))
|
297
|
-
|
298
|
-
# Clear the node
|
299
|
-
_node_clear(node)
|
300
|
-
|
301
|
-
# Verify attributes are gone
|
302
|
-
self.assertFalse(hasattr(node, 'attr1'))
|
303
|
-
self.assertFalse(hasattr(node, 'attr2'))
|
304
|
-
self.assertFalse(hasattr(node, 'attr3'))
|
305
|
-
|
306
|
-
# Verify node still exists and is valid
|
307
|
-
self.assertIsInstance(node, TestNode)
|
308
|
-
|
309
|
-
|
310
|
-
class TestNodeIntegration(unittest.TestCase):
|
311
|
-
"""Integration tests for Node with the graph system."""
|
312
|
-
|
313
|
-
def test_node_with_nested_nodes(self):
|
314
|
-
"""Test nodes containing other nodes."""
|
315
|
-
|
316
|
-
class ChildNode(Node):
|
317
|
-
def __init__(self, value=None):
|
318
|
-
if value is not None:
|
319
|
-
self.value = value
|
320
|
-
|
321
|
-
class ParentNode(Node):
|
322
|
-
def __init__(self):
|
323
|
-
self.child1 = ChildNode(10)
|
324
|
-
self.child2 = ChildNode(20)
|
325
|
-
self.data = [1, 2, 3]
|
326
|
-
|
327
|
-
parent = ParentNode()
|
328
|
-
|
329
|
-
# Test using treefy_split and merge
|
330
|
-
graphdef, state = brainstate.graph.treefy_split(parent)
|
331
|
-
copied = brainstate.graph.treefy_merge(graphdef, state)
|
332
|
-
|
333
|
-
self.assertIsNot(parent.child1, copied.child1)
|
334
|
-
self.assertEqual(parent.child1.value, copied.child1.value)
|
335
|
-
|
336
|
-
def test_node_with_list_of_nodes(self):
|
337
|
-
"""Test nodes containing lists of other nodes."""
|
338
|
-
|
339
|
-
class ItemNode(Node):
|
340
|
-
def __init__(self, id=None):
|
341
|
-
if id is not None:
|
342
|
-
self.id = id
|
343
|
-
|
344
|
-
class ContainerNode(Node):
|
345
|
-
def __init__(self):
|
346
|
-
self.items = [ItemNode(i) for i in range(3)]
|
347
|
-
|
348
|
-
container = ContainerNode()
|
349
|
-
graphdef, state = brainstate.graph.treefy_split(container)
|
350
|
-
copied = brainstate.graph.treefy_merge(graphdef, state)
|
351
|
-
|
352
|
-
self.assertEqual(len(container.items), len(copied.items))
|
353
|
-
for orig, cp in zip(container.items, copied.items):
|
354
|
-
self.assertIsNot(orig, cp)
|
355
|
-
self.assertEqual(orig.id, cp.id)
|
356
|
-
|
357
|
-
def test_node_with_dict_of_nodes(self):
|
358
|
-
"""Test nodes containing dictionaries of other nodes."""
|
359
|
-
|
360
|
-
class ValueNode(Node):
|
361
|
-
def __init__(self, value=None):
|
362
|
-
if value is not None:
|
363
|
-
self.value = value
|
364
|
-
|
365
|
-
class DictNode(Node):
|
366
|
-
def __init__(self):
|
367
|
-
self.mapping = {
|
368
|
-
'a': ValueNode(1),
|
369
|
-
'b': ValueNode(2),
|
370
|
-
'c': ValueNode(3)
|
371
|
-
}
|
372
|
-
|
373
|
-
node = DictNode()
|
374
|
-
graphdef, state = brainstate.graph.treefy_split(node)
|
375
|
-
copied = brainstate.graph.treefy_merge(graphdef, state)
|
376
|
-
|
377
|
-
self.assertEqual(set(node.mapping.keys()), set(copied.mapping.keys()))
|
378
|
-
for key in node.mapping:
|
379
|
-
self.assertIsNot(node.mapping[key], copied.mapping[key])
|
380
|
-
self.assertEqual(node.mapping[key].value, copied.mapping[key].value)
|
381
|
-
|
382
|
-
|
383
|
-
class TestStateRetrieve(unittest.TestCase):
|
384
|
-
"""Tests for state retrieval from nodes."""
|
385
|
-
|
386
|
-
def test_list_of_states_1(self):
|
387
|
-
"""Test retrieving states from a list."""
|
388
|
-
|
389
|
-
class Model(brainstate.graph.Node):
|
390
|
-
def __init__(self):
|
391
|
-
self.a = [1, 2, 3]
|
392
|
-
self.b = [brainstate.State(1), brainstate.State(2), brainstate.State(3)]
|
393
|
-
|
394
|
-
m = Model()
|
395
|
-
graphdef, states = brainstate.graph.treefy_split(m)
|
396
|
-
print(states.to_flat())
|
397
|
-
self.assertTrue(len(states.to_flat()) == 3)
|
398
|
-
|
399
|
-
def test_list_of_states_2(self):
|
400
|
-
"""Test retrieving states from nested lists."""
|
401
|
-
|
402
|
-
class Model(brainstate.graph.Node):
|
403
|
-
def __init__(self):
|
404
|
-
self.a = [1, 2, 3]
|
405
|
-
self.b = [brainstate.State(1), [brainstate.State(2), brainstate.State(3)]]
|
406
|
-
|
407
|
-
m = Model()
|
408
|
-
graphdef, states = brainstate.graph.treefy_split(m)
|
409
|
-
print(states.to_flat())
|
410
|
-
self.assertTrue(len(states.to_flat()) == 3)
|
411
|
-
|
412
|
-
def test_list_of_node_1(self):
|
413
|
-
"""Test retrieving states from a list of nodes."""
|
414
|
-
|
415
|
-
class Model(brainstate.graph.Node):
|
416
|
-
def __init__(self):
|
417
|
-
self.a = [1, 2, 3]
|
418
|
-
self.b = [brainstate.nn.Linear(1, 2), brainstate.nn.Linear(2, 3)]
|
419
|
-
|
420
|
-
m = Model()
|
421
|
-
graphdef, states = brainstate.graph.treefy_split(m)
|
422
|
-
print(states.to_flat())
|
423
|
-
self.assertTrue(len(states.to_flat()) == 2)
|
424
|
-
|
425
|
-
def test_list_of_node_2(self):
|
426
|
-
"""Test retrieving states from nested structures of nodes."""
|
427
|
-
|
428
|
-
class Model(brainstate.graph.Node):
|
429
|
-
def __init__(self):
|
430
|
-
self.a = [1, 2, 3]
|
431
|
-
self.b = [brainstate.nn.Linear(1, 2), [brainstate.nn.Linear(2, 3)],
|
432
|
-
(brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5))]
|
433
|
-
|
434
|
-
m = Model()
|
435
|
-
graphdef, states = brainstate.graph.treefy_split(m)
|
436
|
-
print(states.to_flat())
|
437
|
-
self.assertTrue(len(states.to_flat()) == 4)
|
438
|
-
|
439
|
-
def test_mixed_states_and_nodes(self):
|
440
|
-
"""Test nodes with mixed states and sub-nodes."""
|
441
|
-
|
442
|
-
class Model(brainstate.graph.Node):
|
443
|
-
def __init__(self):
|
444
|
-
self.state1 = brainstate.State(1.0)
|
445
|
-
self.state2 = brainstate.State(2.0)
|
446
|
-
self.linear = brainstate.nn.Linear(5, 10)
|
447
|
-
self.data = [1, 2, 3]
|
448
|
-
|
449
|
-
m = Model()
|
450
|
-
graphdef, states = brainstate.graph.treefy_split(m)
|
451
|
-
|
452
|
-
# Should have states from both direct State objects and Linear layer
|
453
|
-
flat_states = states.to_flat()
|
454
|
-
self.assertGreaterEqual(len(flat_states), 2) # At least the two direct states
|
455
|
-
|
456
|
-
|
457
|
-
class TestEdgeCases(unittest.TestCase):
|
458
|
-
"""Test edge cases and error conditions."""
|
459
|
-
|
460
|
-
def test_empty_node(self):
|
461
|
-
"""Test node with no attributes."""
|
462
|
-
|
463
|
-
class EmptyNode(Node):
|
464
|
-
pass
|
465
|
-
|
466
|
-
node = EmptyNode()
|
467
|
-
flattened, static = _node_flatten(node)
|
468
|
-
|
469
|
-
self.assertEqual(len(flattened), 0)
|
470
|
-
self.assertEqual(static, (EmptyNode,))
|
471
|
-
|
472
|
-
def test_node_with_none_values(self):
|
473
|
-
"""Test node with None values."""
|
474
|
-
|
475
|
-
class NoneNode(Node):
|
476
|
-
def __init__(self):
|
477
|
-
self.none_val = None
|
478
|
-
self.real_val = 10
|
479
|
-
|
480
|
-
node = NoneNode()
|
481
|
-
flattened, static = _node_flatten(node)
|
482
|
-
|
483
|
-
values_dict = dict(flattened)
|
484
|
-
self.assertIsNone(values_dict['none_val'])
|
485
|
-
self.assertEqual(values_dict['real_val'], 10)
|
486
|
-
|
487
|
-
def test_node_with_special_attributes(self):
|
488
|
-
"""Test node with special Python attributes."""
|
489
|
-
|
490
|
-
class SpecialNode(Node):
|
491
|
-
def __init__(self):
|
492
|
-
self.__dict__['special'] = 'value'
|
493
|
-
self.normal = 'normal'
|
494
|
-
|
495
|
-
node = SpecialNode()
|
496
|
-
self.assertEqual(node.special, 'value')
|
497
|
-
self.assertEqual(node.normal, 'normal')
|
498
|
-
|
499
|
-
def test_circular_reference(self):
|
500
|
-
"""Test handling of circular references."""
|
501
|
-
|
502
|
-
class CircularNode(Node):
|
503
|
-
pass
|
504
|
-
|
505
|
-
node1 = CircularNode()
|
506
|
-
node2 = CircularNode()
|
507
|
-
node1.ref = node2
|
508
|
-
node2.ref = node1
|
509
|
-
|
510
|
-
# This should not cause infinite recursion with treefy
|
511
|
-
try:
|
512
|
-
graphdef, state = brainstate.graph.treefy_split(node1)
|
513
|
-
copied = brainstate.graph.treefy_merge(graphdef, state)
|
514
|
-
# Check that circular reference is preserved
|
515
|
-
self.assertIs(copied.ref.ref, copied)
|
516
|
-
except RecursionError:
|
517
|
-
self.fail("Treefy failed with circular reference")
|
518
|
-
|
519
|
-
def test_node_inheritance(self):
|
520
|
-
"""Test node inheritance hierarchy."""
|
521
|
-
|
522
|
-
class BaseNode(Node):
|
523
|
-
def __init__(self):
|
524
|
-
self.base_attr = 'base'
|
525
|
-
|
526
|
-
class DerivedNode(BaseNode):
|
527
|
-
def __init__(self):
|
528
|
-
super().__init__()
|
529
|
-
self.derived_attr = 'derived'
|
530
|
-
|
531
|
-
node = DerivedNode()
|
532
|
-
self.assertEqual(node.base_attr, 'base')
|
533
|
-
self.assertEqual(node.derived_attr, 'derived')
|
534
|
-
|
535
|
-
flattened, static = _node_flatten(node)
|
536
|
-
keys = [k for k, v in flattened]
|
537
|
-
self.assertIn('base_attr', keys)
|
538
|
-
self.assertIn('derived_attr', keys)
|
539
|
-
|
540
|
-
def test_node_with_property(self):
|
541
|
-
"""Test node with property decorators."""
|
542
|
-
|
543
|
-
class PropertyNode(Node):
|
544
|
-
def __init__(self):
|
545
|
-
self._value = 10
|
546
|
-
|
547
|
-
@property
|
548
|
-
def value(self):
|
549
|
-
return self._value
|
550
|
-
|
551
|
-
@value.setter
|
552
|
-
def value(self, val):
|
553
|
-
self._value = val
|
554
|
-
|
555
|
-
node = PropertyNode()
|
556
|
-
self.assertEqual(node.value, 10)
|
557
|
-
|
558
|
-
node.value = 20
|
559
|
-
self.assertEqual(node.value, 20)
|
560
|
-
|
561
|
-
# Only _value should appear in flattened
|
562
|
-
flattened, static = _node_flatten(node)
|
563
|
-
keys = [k for k, v in flattened]
|
564
|
-
self.assertIn('_value', keys)
|
565
|
-
|
566
|
-
def test_multiple_inheritance(self):
|
567
|
-
"""Test node with multiple inheritance."""
|
568
|
-
|
569
|
-
class Mixin:
|
570
|
-
def mixin_method(self):
|
571
|
-
return 'mixin'
|
572
|
-
|
573
|
-
class MultiNode(Node, Mixin):
|
574
|
-
def __init__(self):
|
575
|
-
self.data = 'data'
|
576
|
-
|
577
|
-
node = MultiNode()
|
578
|
-
self.assertEqual(node.mixin_method(), 'mixin')
|
579
|
-
self.assertEqual(node.data, 'data')
|
580
|
-
|
581
|
-
# Test that it still works as a Node with treefy
|
582
|
-
graphdef, state = brainstate.graph.treefy_split(node)
|
583
|
-
copied = brainstate.graph.treefy_merge(graphdef, state)
|
584
|
-
self.assertIsNot(node, copied)
|
585
|
-
self.assertEqual(copied.data, 'data')
|
586
|
-
|
587
|
-
|
588
|
-
if __name__ == '__main__':
|
589
|
-
unittest.main()
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import unittest
|
17
|
+
|
18
|
+
import brainstate
|
19
|
+
from brainstate._state import State
|
20
|
+
from brainstate.graph._node import (
|
21
|
+
Node,
|
22
|
+
_node_flatten,
|
23
|
+
_node_set_key,
|
24
|
+
_node_pop_key,
|
25
|
+
_node_create_empty,
|
26
|
+
_node_clear
|
27
|
+
)
|
28
|
+
|
29
|
+
|
30
|
+
class TestNode(unittest.TestCase):
|
31
|
+
"""Test suite for the Node class."""
|
32
|
+
|
33
|
+
def test_node_creation(self):
|
34
|
+
"""Test basic node creation."""
|
35
|
+
|
36
|
+
class SimpleNode(Node):
|
37
|
+
def __init__(self, value):
|
38
|
+
self.value = value
|
39
|
+
|
40
|
+
node = SimpleNode(10)
|
41
|
+
self.assertEqual(node.value, 10)
|
42
|
+
self.assertIsInstance(node, Node)
|
43
|
+
|
44
|
+
def test_node_subclass_registration(self):
|
45
|
+
"""Test that Node subclasses are automatically registered."""
|
46
|
+
|
47
|
+
class TestNode(Node):
|
48
|
+
pass
|
49
|
+
|
50
|
+
# The subclass should be registered automatically via __init_subclass__
|
51
|
+
node = TestNode()
|
52
|
+
self.assertIsInstance(node, Node)
|
53
|
+
|
54
|
+
def test_graph_invisible_attrs(self):
|
55
|
+
"""Test that graph_invisible_attrs works correctly."""
|
56
|
+
|
57
|
+
class NodeWithInvisible(Node):
|
58
|
+
graph_invisible_attrs = ('_private', '_internal')
|
59
|
+
|
60
|
+
def __init__(self):
|
61
|
+
self.public = 1
|
62
|
+
self._private = 2
|
63
|
+
self._internal = 3
|
64
|
+
|
65
|
+
node = NodeWithInvisible()
|
66
|
+
flattened, static = _node_flatten(node)
|
67
|
+
|
68
|
+
# Check that only public attribute is in flattened
|
69
|
+
keys = [k for k, v in flattened]
|
70
|
+
self.assertIn('public', keys)
|
71
|
+
self.assertNotIn('_private', keys)
|
72
|
+
self.assertNotIn('_internal', keys)
|
73
|
+
|
74
|
+
def test_deepcopy_using_treefy(self):
|
75
|
+
"""Test deep copying of nodes using treefy_split/merge."""
|
76
|
+
|
77
|
+
class NodeWithData(Node):
|
78
|
+
def __init__(self, data=None):
|
79
|
+
if data is not None:
|
80
|
+
self.data = data
|
81
|
+
self.nested = {'a': 1, 'b': [2, 3]}
|
82
|
+
|
83
|
+
original = NodeWithData([1, 2, 3])
|
84
|
+
|
85
|
+
# Use treefy_split and treefy_merge to copy
|
86
|
+
graphdef, state = brainstate.graph.treefy_split(original)
|
87
|
+
# Create a new instance using treefy_merge
|
88
|
+
copied = brainstate.graph.treefy_merge(graphdef, state)
|
89
|
+
|
90
|
+
# Check that it's a different object
|
91
|
+
self.assertIsNot(original, copied)
|
92
|
+
|
93
|
+
# Check that data is present
|
94
|
+
self.assertEqual(original.data, copied.data)
|
95
|
+
|
96
|
+
# Modify copied data shouldn't affect original
|
97
|
+
copied.data.append(4)
|
98
|
+
self.assertEqual(len(original.data), 3)
|
99
|
+
self.assertEqual(len(copied.data), 4)
|
100
|
+
|
101
|
+
def test_node_with_state(self):
|
102
|
+
"""Test nodes containing State objects."""
|
103
|
+
|
104
|
+
class NodeWithState(Node):
|
105
|
+
def __init__(self):
|
106
|
+
self.value = State(10)
|
107
|
+
self.normal = 20
|
108
|
+
|
109
|
+
node = NodeWithState()
|
110
|
+
self.assertIsInstance(node.value, State)
|
111
|
+
self.assertEqual(node.value.value, 10)
|
112
|
+
self.assertEqual(node.normal, 20)
|
113
|
+
|
114
|
+
def test_complex_nested_structure(self):
|
115
|
+
"""Test nodes with complex nested structures using treefy."""
|
116
|
+
|
117
|
+
class ComplexNode(Node):
|
118
|
+
def __init__(self):
|
119
|
+
self.list_data = [1, 2, [3, 4]]
|
120
|
+
self.dict_data = {'a': 1, 'b': {'c': 2}}
|
121
|
+
self.tuple_data = (1, 2, (3, 4))
|
122
|
+
|
123
|
+
node = ComplexNode()
|
124
|
+
|
125
|
+
# Test using treefy_split and merge
|
126
|
+
graphdef, state = brainstate.graph.treefy_split(node)
|
127
|
+
copied = brainstate.graph.treefy_merge(graphdef, state)
|
128
|
+
|
129
|
+
self.assertEqual(node.list_data, copied.list_data)
|
130
|
+
self.assertEqual(node.dict_data, copied.dict_data)
|
131
|
+
self.assertEqual(node.tuple_data, copied.tuple_data)
|
132
|
+
|
133
|
+
|
134
|
+
class TestNodeHelperFunctions(unittest.TestCase):
    """Test suite for node helper functions.

    Calls the private helpers ``_node_flatten``, ``_node_set_key``,
    ``_node_pop_key``, ``_node_create_empty`` and ``_node_clear`` directly on
    small throwaway ``Node`` subclasses defined inside each test.
    """

    def test_node_flatten(self):
        """Test _node_flatten function."""

        class TestNode(Node):
            def __init__(self):
                # Assigned out of alphabetical order on purpose, so the
                # ordering assertion below cannot pass by accident.
                self.b = 2
                self.a = 1
                self.c = 3

        node = TestNode()
        flattened, static = _node_flatten(node)

        # Check that attributes are sorted
        keys = [k for k, _ in flattened]
        self.assertEqual(keys, ['a', 'b', 'c'])

        # Check values
        values = [v for _, v in flattened]
        self.assertEqual(values, [1, 2, 3])

        # Check static contains type
        self.assertEqual(static, (TestNode,))

    def test_node_flatten_with_invisible(self):
        """Test _node_flatten with invisible attributes."""

        class TestNode(Node):
            # Names listed here are expected to be left out of the flattened
            # output; that is exactly what the assertions below check.
            graph_invisible_attrs = ('hidden',)

            def __init__(self):
                self.visible = 1
                self.hidden = 2

        node = TestNode()
        flattened, static = _node_flatten(node)

        keys = [k for k, _ in flattened]
        self.assertIn('visible', keys)
        self.assertNotIn('hidden', keys)

    def test_node_set_key_simple(self):
        """Test _node_set_key with simple values."""

        class TestNode(Node):
            pass

        node = TestNode()

        # First call creates the attribute ...
        _node_set_key(node, 'attr', 10)
        self.assertEqual(node.attr, 10)

        # ... and a second call with the same key overwrites it.
        _node_set_key(node, 'attr', 20)
        self.assertEqual(node.attr, 20)

    def test_node_set_key_with_state(self):
        """Test _node_set_key with State objects."""

        class TestNode(Node):
            def __init__(self):
                self.state_attr = State(10)

        node = TestNode()

        # Test setting a regular value
        _node_set_key(node, 'regular_attr', 30)
        self.assertEqual(node.regular_attr, 30)

        # Obtain the state entries from the real graph operations instead of
        # building them by hand.
        graphdef, states = brainstate.graph.treefy_split(node)

        # Create a new node and feed the matching split entry back into it.
        new_node = TestNode()
        matched = False
        for key, value in states.to_flat().items():
            if 'state_attr' in key:
                matched = True
                _node_set_key(new_node, 'state_attr', value)

        # Guard against a vacuous pass: without this, the test would succeed
        # even if the loop never reached _node_set_key.
        self.assertTrue(matched, 'treefy_split produced no entry for state_attr')

        # After the update the attribute must still be a State object.
        self.assertIsInstance(new_node.state_attr, State)

    def test_node_set_key_invalid_key(self):
        """Test _node_set_key with invalid key."""

        class TestNode(Node):
            pass

        node = TestNode()

        # A non-string key is expected to be rejected with KeyError.
        with self.assertRaises(KeyError) as context:
            _node_set_key(node, 123, 'value')
        self.assertIn('Invalid key', str(context.exception))

    def test_node_pop_key(self):
        """Test _node_pop_key function."""

        class TestNode(Node):
            def __init__(self):
                self.attr1 = 10
                self.attr2 = 20

        node = TestNode()

        # Pop existing attribute
        value = _node_pop_key(node, 'attr1')
        self.assertEqual(value, 10)

        # Only the popped attribute disappears; its sibling is untouched.
        self.assertFalse(hasattr(node, 'attr1'))
        self.assertTrue(hasattr(node, 'attr2'))

    def test_node_pop_key_invalid(self):
        """Test _node_pop_key with invalid key."""

        class TestNode(Node):
            pass

        node = TestNode()

        # Invalid key type
        with self.assertRaises(KeyError) as context:
            _node_pop_key(node, 123)
        self.assertIn('Invalid key', str(context.exception))

        # Non-existent key
        with self.assertRaises(KeyError):
            _node_pop_key(node, 'nonexistent')

    def test_node_create_empty(self):
        """Test _node_create_empty function."""

        class TestNode(Node):
            def __init__(self, value=None):
                self.value = value
                self.initialized = True

        # Create empty node
        node = _node_create_empty((TestNode,))

        # Check it's the right type
        self.assertIsInstance(node, TestNode)

        # Check __init__ was not called: neither attribute it sets exists.
        self.assertFalse(hasattr(node, 'value'))
        self.assertFalse(hasattr(node, 'initialized'))

    def test_node_clear(self):
        """Test _node_clear function."""

        class TestNode(Node):
            def __init__(self):
                self.attr1 = 10
                self.attr2 = 20
                self.attr3 = [1, 2, 3]

        node = TestNode()

        # Verify attributes exist
        self.assertTrue(hasattr(node, 'attr1'))
        self.assertTrue(hasattr(node, 'attr2'))
        self.assertTrue(hasattr(node, 'attr3'))

        # Clear the node
        _node_clear(node)

        # Verify attributes are gone
        self.assertFalse(hasattr(node, 'attr1'))
        self.assertFalse(hasattr(node, 'attr2'))
        self.assertFalse(hasattr(node, 'attr3'))

        # Verify node still exists and is valid
        self.assertIsInstance(node, TestNode)
|
308
|
+
|
309
|
+
|
310
|
+
class TestNodeIntegration(unittest.TestCase):
    """Integration tests for Node with the graph system.

    Every test builds a node that owns other nodes (directly, in a list, or
    in a dict), sends it through ``treefy_split`` / ``treefy_merge`` and then
    checks that the rebuilt children are distinct objects carrying the same
    attribute values as the originals.
    """

    @staticmethod
    def _split_then_merge(node):
        """Return a node rebuilt from the two pieces ``treefy_split`` yields."""
        definition, leaves = brainstate.graph.treefy_split(node)
        return brainstate.graph.treefy_merge(definition, leaves)

    def test_node_with_nested_nodes(self):
        """Test nodes containing other nodes."""

        class ChildNode(Node):
            def __init__(self, value=None):
                # Leave the attribute unset when no value is supplied.
                if value is None:
                    return
                self.value = value

        class ParentNode(Node):
            def __init__(self):
                self.child1 = ChildNode(10)
                self.child2 = ChildNode(20)
                self.data = [1, 2, 3]

        source = ParentNode()
        rebuilt = self._split_then_merge(source)

        # The child must be a fresh object holding an equal value.
        self.assertIsNot(source.child1, rebuilt.child1)
        self.assertEqual(source.child1.value, rebuilt.child1.value)

    def test_node_with_list_of_nodes(self):
        """Test nodes containing lists of other nodes."""

        class ItemNode(Node):
            def __init__(self, item_id=None):
                # Leave the attribute unset when no identifier is supplied.
                if item_id is None:
                    return
                self.id = item_id

        class ContainerNode(Node):
            def __init__(self):
                self.items = list(map(ItemNode, range(3)))

        source = ContainerNode()
        rebuilt = self._split_then_merge(source)

        self.assertEqual(len(source.items), len(rebuilt.items))
        for before, after in zip(source.items, rebuilt.items):
            self.assertIsNot(before, after)
            self.assertEqual(before.id, after.id)

    def test_node_with_dict_of_nodes(self):
        """Test nodes containing dictionaries of other nodes."""

        class ValueNode(Node):
            def __init__(self, value=None):
                # Leave the attribute unset when no value is supplied.
                if value is None:
                    return
                self.value = value

        class DictNode(Node):
            def __init__(self):
                self.mapping = {
                    'a': ValueNode(1),
                    'b': ValueNode(2),
                    'c': ValueNode(3),
                }

        source = DictNode()
        rebuilt = self._split_then_merge(source)

        expected_keys = set(source.mapping.keys())
        actual_keys = set(rebuilt.mapping.keys())
        self.assertEqual(expected_keys, actual_keys)

        for name, before in source.mapping.items():
            after = rebuilt.mapping[name]
            self.assertIsNot(before, after)
            self.assertEqual(before.value, after.value)
|
381
|
+
|
382
|
+
|
383
|
+
class TestStateRetrieve(unittest.TestCase):
    """Tests for state retrieval from nodes.

    Each test builds a small model, calls ``brainstate.graph.treefy_split``
    on it and checks how many entries the flattened state mapping holds.
    Counts are compared with ``assertEqual`` so that a failure reports the
    actual number of entries rather than just ``False is not true``.
    """

    def test_list_of_states_1(self):
        """Test retrieving states from a list."""

        class Model(brainstate.graph.Node):
            def __init__(self):
                # Plain integers; the expected count below leaves no room
                # for them to contribute entries.
                self.a = [1, 2, 3]
                self.b = [brainstate.State(1), brainstate.State(2), brainstate.State(3)]

        m = Model()
        graphdef, states = brainstate.graph.treefy_split(m)

        # One entry per State in the flat list.
        self.assertEqual(len(states.to_flat()), 3)

    def test_list_of_states_2(self):
        """Test retrieving states from nested lists."""

        class Model(brainstate.graph.Node):
            def __init__(self):
                self.a = [1, 2, 3]
                self.b = [brainstate.State(1), [brainstate.State(2), brainstate.State(3)]]

        m = Model()
        graphdef, states = brainstate.graph.treefy_split(m)

        # Nesting must not change the count: still three States in total.
        self.assertEqual(len(states.to_flat()), 3)

    def test_list_of_node_1(self):
        """Test retrieving states from a list of nodes."""

        class Model(brainstate.graph.Node):
            def __init__(self):
                self.a = [1, 2, 3]
                self.b = [brainstate.nn.Linear(1, 2), brainstate.nn.Linear(2, 3)]

        m = Model()
        graphdef, states = brainstate.graph.treefy_split(m)

        # The test expects one entry for each of the two Linear layers.
        self.assertEqual(len(states.to_flat()), 2)

    def test_list_of_node_2(self):
        """Test retrieving states from nested structures of nodes."""

        class Model(brainstate.graph.Node):
            def __init__(self):
                self.a = [1, 2, 3]
                # Layers sit directly in the list, in a nested list and in a
                # nested tuple.
                self.b = [
                    brainstate.nn.Linear(1, 2),
                    [brainstate.nn.Linear(2, 3)],
                    (brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5)),
                ]

        m = Model()
        graphdef, states = brainstate.graph.treefy_split(m)

        # Four Linear layers in total, wherever they are nested.
        self.assertEqual(len(states.to_flat()), 4)

    def test_mixed_states_and_nodes(self):
        """Test nodes with mixed states and sub-nodes."""

        class Model(brainstate.graph.Node):
            def __init__(self):
                self.state1 = brainstate.State(1.0)
                self.state2 = brainstate.State(2.0)
                self.linear = brainstate.nn.Linear(5, 10)
                self.data = [1, 2, 3]

        m = Model()
        graphdef, states = brainstate.graph.treefy_split(m)

        # Should have states from both direct State objects and Linear layer
        flat_states = states.to_flat()
        self.assertGreaterEqual(len(flat_states), 2)  # At least the two direct states
|
455
|
+
|
456
|
+
|
457
|
+
class TestEdgeCases(unittest.TestCase):
    """Test edge cases and error conditions.

    Covers nodes with no attributes, ``None`` values, attributes written
    straight into ``__dict__``, mutually referencing nodes, single and
    multiple inheritance, and properties.
    """

    def test_empty_node(self):
        """Test node with no attributes."""

        class EmptyNode(Node):
            pass

        node = EmptyNode()
        flattened, static = _node_flatten(node)

        self.assertEqual(len(flattened), 0)
        self.assertEqual(static, (EmptyNode,))

    def test_node_with_none_values(self):
        """Test node with None values."""

        class NoneNode(Node):
            def __init__(self):
                self.none_val = None
                self.real_val = 10

        node = NoneNode()
        flattened, static = _node_flatten(node)

        # A None-valued attribute must still be present in the output.
        values_dict = dict(flattened)
        self.assertIsNone(values_dict['none_val'])
        self.assertEqual(values_dict['real_val'], 10)

    def test_node_with_special_attributes(self):
        """Test node with special Python attributes."""

        class SpecialNode(Node):
            def __init__(self):
                # Written through __dict__ directly rather than by normal
                # attribute assignment.
                self.__dict__['special'] = 'value'
                self.normal = 'normal'

        node = SpecialNode()
        self.assertEqual(node.special, 'value')
        self.assertEqual(node.normal, 'normal')

    def test_circular_reference(self):
        """Test handling of circular references."""

        class CircularNode(Node):
            pass

        node1 = CircularNode()
        node2 = CircularNode()
        node1.ref = node2
        node2.ref = node1

        # This should not cause infinite recursion with treefy. Only the two
        # calls that could recurse are guarded; the assertion lives outside
        # the try so the handler cannot be mistaken for covering it.
        try:
            graphdef, state = brainstate.graph.treefy_split(node1)
            copied = brainstate.graph.treefy_merge(graphdef, state)
        except RecursionError:
            self.fail("Treefy failed with circular reference")

        # Check that circular reference is preserved
        self.assertIs(copied.ref.ref, copied)

    def test_node_inheritance(self):
        """Test node inheritance hierarchy."""

        class BaseNode(Node):
            def __init__(self):
                self.base_attr = 'base'

        class DerivedNode(BaseNode):
            def __init__(self):
                super().__init__()
                self.derived_attr = 'derived'

        node = DerivedNode()
        self.assertEqual(node.base_attr, 'base')
        self.assertEqual(node.derived_attr, 'derived')

        # Attributes set by either class in the hierarchy must be flattened.
        flattened, static = _node_flatten(node)
        keys = [k for k, _ in flattened]
        self.assertIn('base_attr', keys)
        self.assertIn('derived_attr', keys)

    def test_node_with_property(self):
        """Test node with property decorators."""

        class PropertyNode(Node):
            def __init__(self):
                self._value = 10

            @property
            def value(self):
                return self._value

            @value.setter
            def value(self, val):
                self._value = val

        node = PropertyNode()
        self.assertEqual(node.value, 10)

        node.value = 20
        self.assertEqual(node.value, 20)

        # Only _value should appear in flattened: the backing attribute is
        # present and the property name itself is not.
        flattened, static = _node_flatten(node)
        keys = [k for k, _ in flattened]
        self.assertIn('_value', keys)
        self.assertNotIn('value', keys)

    def test_multiple_inheritance(self):
        """Test node with multiple inheritance."""

        class Mixin:
            def mixin_method(self):
                return 'mixin'

        class MultiNode(Node, Mixin):
            def __init__(self):
                self.data = 'data'

        node = MultiNode()
        self.assertEqual(node.mixin_method(), 'mixin')
        self.assertEqual(node.data, 'data')

        # Test that it still works as a Node with treefy
        graphdef, state = brainstate.graph.treefy_split(node)
        copied = brainstate.graph.treefy_merge(graphdef, state)
        self.assertIsNot(node, copied)
        self.assertEqual(copied.data, 'data')
|
586
|
+
|
587
|
+
|
588
|
+
# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
|