brainstate 0.2.0__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112) hide show
  1. brainstate/__init__.py +169 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2319 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +1652 -1652
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1624 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1433 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +137 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +633 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +154 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +477 -477
  32. brainstate/nn/_dynamics.py +1267 -1267
  33. brainstate/nn/_dynamics_test.py +67 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +384 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/_rand_funs.py +3938 -3938
  64. brainstate/random/_rand_funs_test.py +640 -640
  65. brainstate/random/_rand_seed.py +675 -675
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1617
  68. brainstate/random/_rand_state_test.py +551 -551
  69. brainstate/transform/__init__.py +59 -59
  70. brainstate/transform/_ad_checkpoint.py +176 -176
  71. brainstate/transform/_ad_checkpoint_test.py +49 -49
  72. brainstate/transform/_autograd.py +1025 -1025
  73. brainstate/transform/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -316
  75. brainstate/transform/_conditions_test.py +220 -220
  76. brainstate/transform/_error_if.py +94 -94
  77. brainstate/transform/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -145
  79. brainstate/transform/_eval_shape_test.py +38 -38
  80. brainstate/transform/_jit.py +399 -399
  81. brainstate/transform/_jit_test.py +143 -143
  82. brainstate/transform/_loop_collect_return.py +675 -675
  83. brainstate/transform/_loop_collect_return_test.py +58 -58
  84. brainstate/transform/_loop_no_collection.py +283 -283
  85. brainstate/transform/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -2016
  87. brainstate/transform/_make_jaxpr_test.py +1510 -1510
  88. brainstate/transform/_mapping.py +529 -529
  89. brainstate/transform/_mapping_test.py +194 -194
  90. brainstate/transform/_progress_bar.py +255 -255
  91. brainstate/transform/_random.py +171 -171
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -108
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate-0.2.0.dist-info/RECORD +0 -111
  111. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  112. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -1,675 +1,675 @@
1
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- """
17
- Comprehensive test suite for the pretty_pytree module.
18
-
19
- This test module provides extensive coverage of the pretty printing and tree
20
- manipulation functionality, including:
21
- - PrettyObject and pretty representation
22
- - Nested and flattened dictionary structures
23
- - Mapping flattening and unflattening
24
- - Split, filter, and merge operations
25
- - JAX pytree integration
26
- - State management utilities
27
- """
28
-
29
- import unittest
30
-
31
- import jax
32
- import jax.numpy as jnp
33
- import numpy as np
34
- from absl.testing import absltest
35
-
36
- import brainstate
37
- from brainstate.util._pretty_pytree import (
38
- PrettyObject,
39
- PrettyDict,
40
- NestedDict,
41
- FlattedDict,
42
- PrettyList,
43
- flat_mapping,
44
- nest_mapping,
45
- empty_node,
46
- _EmptyNode,
47
- )
48
-
49
-
50
- class TestNestedMapping(absltest.TestCase):
51
- def test_create_state(self):
52
- state = brainstate.util.NestedDict({'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}})
53
-
54
- assert state['a'].value == 1
55
- assert state['b']['c'].value == 2
56
-
57
- def test_get_attr(self):
58
- state = brainstate.util.NestedDict({'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}})
59
-
60
- assert state.a.value == 1
61
- assert state.b['c'].value == 2
62
-
63
- def test_set_attr(self):
64
- state = brainstate.util.NestedDict({'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}})
65
-
66
- state.a.value = 3
67
- state.b['c'].value = 4
68
-
69
- assert state['a'].value == 3
70
- assert state['b']['c'].value == 4
71
-
72
- def test_set_attr_variables(self):
73
- state = brainstate.util.NestedDict({'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}})
74
-
75
- state.a.value = 3
76
- state.b['c'].value = 4
77
-
78
- assert isinstance(state.a, brainstate.ParamState)
79
- assert state.a.value == 3
80
- assert isinstance(state.b['c'], brainstate.ParamState)
81
- assert state.b['c'].value == 4
82
-
83
- def test_add_nested_attr(self):
84
- state = brainstate.util.NestedDict({'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}})
85
- state.b['d'] = brainstate.ParamState(5)
86
-
87
- assert state['b']['d'].value == 5
88
-
89
- def test_delete_nested_attr(self):
90
- state = brainstate.util.NestedDict({'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}})
91
- del state['b']['c']
92
-
93
- assert 'c' not in state['b']
94
-
95
- def test_integer_access(self):
96
- class Foo(brainstate.nn.Module):
97
- def __init__(self):
98
- super().__init__()
99
- self.layers = [brainstate.nn.Linear(1, 2), brainstate.nn.Linear(2, 3)]
100
-
101
- module = Foo()
102
- state_refs = brainstate.graph.treefy_states(module)
103
-
104
- assert module.layers[0].weight.value['weight'].shape == (1, 2)
105
- assert state_refs.layers[0]['weight'].value['weight'].shape == (1, 2)
106
- assert module.layers[1].weight.value['weight'].shape == (2, 3)
107
- assert state_refs.layers[1]['weight'].value['weight'].shape == (2, 3)
108
-
109
- def test_pure_dict(self):
110
- module = brainstate.nn.Linear(4, 5)
111
- state_map = brainstate.graph.treefy_states(module)
112
- pure_dict = state_map.to_pure_dict()
113
- assert isinstance(pure_dict, dict)
114
- assert isinstance(pure_dict['weight'].value['weight'], jax.Array)
115
- assert isinstance(pure_dict['weight'].value['bias'], jax.Array)
116
-
117
-
118
- class TestSplit(unittest.TestCase):
119
- def test_split(self):
120
- class Model(brainstate.nn.Module):
121
- def __init__(self):
122
- super().__init__()
123
- self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
124
- self.linear = brainstate.nn.Linear([10, 3], [10, 4])
125
-
126
- def __call__(self, x):
127
- return self.linear(self.batchnorm(x))
128
-
129
- with brainstate.environ.context(fit=True):
130
- model = Model()
131
- x = brainstate.random.randn(1, 10, 3)
132
- y = model(x)
133
- self.assertEqual(y.shape, (1, 10, 4))
134
-
135
- state_map = brainstate.graph.treefy_states(model)
136
-
137
- with self.assertRaises(ValueError):
138
- params, others = state_map.split(brainstate.ParamState)
139
-
140
- params, others = state_map.split(brainstate.ParamState, ...)
141
- print()
142
- print(params)
143
- print(others)
144
-
145
- self.assertTrue(len(params.to_flat()) == 2)
146
- self.assertTrue(len(others.to_flat()) == 2)
147
-
148
-
149
- class TestStateMap2(unittest.TestCase):
150
- def test1(self):
151
- class Model(brainstate.nn.Module):
152
- def __init__(self):
153
- super().__init__()
154
- self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
155
- self.linear = brainstate.nn.Linear([10, 3], [10, 4])
156
-
157
- def __call__(self, x):
158
- return self.linear(self.batchnorm(x))
159
-
160
- with brainstate.environ.context(fit=True):
161
- model = Model()
162
- state_map = brainstate.graph.treefy_states(model).to_flat()
163
- state_map = brainstate.util.NestedDict(state_map)
164
-
165
-
166
- class TestFlattedMapping(unittest.TestCase):
167
- def test1(self):
168
- class Model(brainstate.nn.Module):
169
- def __init__(self):
170
- super().__init__()
171
- self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
172
- self.linear = brainstate.nn.Linear([10, 3], [10, 4])
173
-
174
- def __call__(self, x):
175
- return self.linear(self.batchnorm(x))
176
-
177
- model = Model()
178
- # print(model.states())
179
- # print(brainstate.graph.states(model))
180
- self.assertTrue(model.states() == brainstate.graph.states(model))
181
-
182
- print(model.nodes())
183
- # print(brainstate.graph.nodes(model))
184
- self.assertTrue(model.nodes() == brainstate.graph.nodes(model))
185
-
186
-
187
- class TestPrettyObject(unittest.TestCase):
188
- """Test PrettyObject functionality."""
189
-
190
- def test_pretty_object_basic(self):
191
- """Test basic PrettyObject creation and representation."""
192
- class MyObject(PrettyObject):
193
- def __init__(self, value):
194
- self.value = value
195
- self.name = "test"
196
-
197
- obj = MyObject(42)
198
- repr_str = repr(obj)
199
- self.assertIsInstance(repr_str, str)
200
- self.assertIn("MyObject", repr_str)
201
- self.assertIn("value", repr_str)
202
- self.assertIn("42", repr_str)
203
-
204
- def test_pretty_repr_item_filtering(self):
205
- """Test __pretty_repr_item__ filtering."""
206
- class FilteredObject(PrettyObject):
207
- def __init__(self):
208
- self.visible = "show"
209
- self.hidden = "hide"
210
-
211
- def __pretty_repr_item__(self, k, v):
212
- if k == "hidden":
213
- return None
214
- return k, v
215
-
216
- obj = FilteredObject()
217
- repr_str = repr(obj)
218
- self.assertIn("visible", repr_str)
219
- self.assertNotIn("hidden", repr_str)
220
-
221
- def test_pretty_repr_item_transformation(self):
222
- """Test __pretty_repr_item__ value transformation."""
223
- class TransformObject(PrettyObject):
224
- def __init__(self):
225
- self.value = 100
226
-
227
- def __pretty_repr_item__(self, k, v):
228
- if k == "value":
229
- return k, v * 2
230
- return k, v
231
-
232
- obj = TransformObject()
233
- repr_str = repr(obj)
234
- self.assertIn("200", repr_str)
235
-
236
-
237
- class TestFlatAndNestMapping(unittest.TestCase):
238
- """Test flat_mapping and nest_mapping functions."""
239
-
240
- def test_flat_mapping_basic(self):
241
- """Test basic flattening of nested dict."""
242
- nested = {'a': 1, 'b': {'c': 2, 'd': 3}}
243
- flat = flat_mapping(nested)
244
-
245
- self.assertIsInstance(flat, FlattedDict)
246
- self.assertEqual(flat[('a',)], 1)
247
- self.assertEqual(flat[('b', 'c')], 2)
248
- self.assertEqual(flat[('b', 'd')], 3)
249
-
250
- def test_flat_mapping_with_separator(self):
251
- """Test flattening with string separator."""
252
- nested = {'a': 1, 'b': {'c': 2}}
253
- flat = flat_mapping(nested, sep='/')
254
-
255
- self.assertEqual(flat['a'], 1)
256
- self.assertEqual(flat['b/c'], 2)
257
-
258
- def test_flat_mapping_empty_nodes(self):
259
- """Test flattening with keep_empty_nodes."""
260
- nested = {'a': 1, 'b': {}}
261
- flat = flat_mapping(nested, keep_empty_nodes=True)
262
-
263
- self.assertEqual(flat[('a',)], 1)
264
- self.assertIsInstance(flat[('b',)], _EmptyNode)
265
-
266
- def test_flat_mapping_without_empty_nodes(self):
267
- """Test flattening without keeping empty nodes."""
268
- nested = {'a': 1, 'b': {}}
269
- flat = flat_mapping(nested, keep_empty_nodes=False)
270
-
271
- self.assertIn(('a',), flat)
272
- self.assertNotIn(('b',), flat)
273
-
274
- def test_flat_mapping_is_leaf(self):
275
- """Test flattening with custom is_leaf function."""
276
- nested = {'a': 1, 'b': {'c': 2, 'd': 3}}
277
-
278
- def is_leaf(path, value):
279
- return len(path) >= 1
280
-
281
- flat = flat_mapping(nested, is_leaf=is_leaf)
282
- self.assertEqual(flat[('a',)], 1)
283
- self.assertEqual(flat[('b',)], {'c': 2, 'd': 3})
284
-
285
- def test_nest_mapping_basic(self):
286
- """Test basic unflattening."""
287
- flat = {('a',): 1, ('b', 'c'): 2, ('b', 'd'): 3}
288
- nested = nest_mapping(flat)
289
-
290
- self.assertIsInstance(nested, NestedDict)
291
- self.assertEqual(nested['a'], 1)
292
- self.assertEqual(nested['b']['c'], 2)
293
- self.assertEqual(nested['b']['d'], 3)
294
-
295
- def test_nest_mapping_with_separator(self):
296
- """Test unflattening with string separator."""
297
- flat = {'a': 1, 'b/c': 2, 'b/d': 3}
298
- nested = nest_mapping(flat, sep='/')
299
-
300
- self.assertEqual(nested['a'], 1)
301
- self.assertEqual(nested['b']['c'], 2)
302
- self.assertEqual(nested['b']['d'], 3)
303
-
304
- def test_nest_mapping_with_empty_node(self):
305
- """Test unflattening with empty nodes."""
306
- flat = {('a',): 1, ('b',): empty_node}
307
- nested = nest_mapping(flat)
308
-
309
- self.assertEqual(nested['a'], 1)
310
- self.assertEqual(nested['b'], {})
311
-
312
- def test_round_trip(self):
313
- """Test flatten -> unflatten round trip."""
314
- original = {'a': 1, 'b': {'c': 2, 'd': {'e': 3}}}
315
- flat = flat_mapping(original)
316
- restored = nest_mapping(flat)
317
-
318
- self.assertEqual(restored.to_dict(), original)
319
-
320
-
321
- class TestPrettyDict(unittest.TestCase):
322
- """Test PrettyDict functionality."""
323
-
324
- def test_pretty_dict_creation(self):
325
- """Test PrettyDict creation."""
326
- d = PrettyDict({'a': 1, 'b': 2})
327
- self.assertEqual(d['a'], 1)
328
- self.assertEqual(d['b'], 2)
329
-
330
- def test_pretty_dict_attribute_access(self):
331
- """Test accessing dict items as attributes."""
332
- d = PrettyDict({'a': 1, 'b': 2})
333
- self.assertEqual(d.a, 1)
334
- self.assertEqual(d.b, 2)
335
-
336
- def test_pretty_dict_repr(self):
337
- """Test PrettyDict representation."""
338
- d = PrettyDict({'a': 1, 'b': 2})
339
- repr_str = repr(d)
340
- self.assertIsInstance(repr_str, str)
341
- self.assertIn('a', repr_str)
342
-
343
- def test_to_dict(self):
344
- """Test conversion to regular dict."""
345
- d = PrettyDict({'a': 1, 'b': 2})
346
- regular = d.to_dict()
347
- self.assertIsInstance(regular, dict)
348
- self.assertEqual(regular, {'a': 1, 'b': 2})
349
-
350
-
351
- class TestNestedDictOperations(unittest.TestCase):
352
- """Test NestedDict additional operations."""
353
-
354
- def test_or_operator(self):
355
- """Test | operator for merging."""
356
- d1 = NestedDict({'a': 1})
357
- d2 = NestedDict({'b': 2})
358
- merged = d1 | d2
359
-
360
- self.assertIsInstance(merged, NestedDict)
361
- self.assertEqual(merged['a'], 1)
362
- self.assertEqual(merged['b'], 2)
363
-
364
- def test_sub_operator(self):
365
- """Test - operator for difference."""
366
- d1 = NestedDict({'a': 1, 'b': 2, 'c': 3})
367
- d2 = NestedDict({'b': 2})
368
- diff = d1 - d2
369
-
370
- flat_diff = diff.to_flat()
371
- self.assertIn(('a',), flat_diff.keys())
372
- self.assertIn(('c',), flat_diff.keys())
373
- # b should not be in diff
374
- has_b = any('b' in key for key in flat_diff.keys())
375
- self.assertFalse(has_b)
376
-
377
- def test_merge_static_method(self):
378
- """Test static merge method."""
379
- d1 = NestedDict({'a': 1})
380
- d2 = NestedDict({'b': 2})
381
- d3 = NestedDict({'c': 3})
382
- merged = NestedDict.merge(d1, d2, d3)
383
-
384
- self.assertEqual(merged['a'], 1)
385
- self.assertEqual(merged['b'], 2)
386
- self.assertEqual(merged['c'], 3)
387
-
388
- def test_to_pure_dict(self):
389
- """Test conversion to pure dict."""
390
- nested = NestedDict({'a': 1, 'b': {'c': 2}})
391
- pure = nested.to_pure_dict()
392
-
393
- self.assertIsInstance(pure, dict)
394
- self.assertNotIsInstance(pure, NestedDict)
395
- self.assertEqual(pure['a'], 1)
396
- self.assertEqual(pure['b']['c'], 2)
397
-
398
-
399
- class TestFlattedDictOperations(unittest.TestCase):
400
- """Test FlattedDict additional operations."""
401
-
402
- def test_or_operator(self):
403
- """Test | operator for merging."""
404
- d1 = FlattedDict({('a',): 1})
405
- d2 = FlattedDict({('b',): 2})
406
- merged = d1 | d2
407
-
408
- self.assertIsInstance(merged, FlattedDict)
409
- self.assertEqual(merged[('a',)], 1)
410
- self.assertEqual(merged[('b',)], 2)
411
-
412
- def test_sub_operator(self):
413
- """Test - operator for difference."""
414
- d1 = FlattedDict({('a',): 1, ('b',): 2, ('c',): 3})
415
- d2 = FlattedDict({('b',): 2})
416
- diff = d1 - d2
417
-
418
- self.assertIn(('a',), diff)
419
- self.assertIn(('c',), diff)
420
- self.assertNotIn(('b',), diff)
421
-
422
- def test_merge_static_method(self):
423
- """Test static merge method."""
424
- d1 = FlattedDict({('a',): 1})
425
- d2 = FlattedDict({('b',): 2})
426
- merged = FlattedDict.merge(d1, d2)
427
-
428
- self.assertEqual(merged[('a',)], 1)
429
- self.assertEqual(merged[('b',)], 2)
430
-
431
- def test_to_dict_values(self):
432
- """Test conversion to dictionary of values."""
433
- flat = FlattedDict({
434
- ('a',): brainstate.ParamState(jnp.array([1, 2, 3])),
435
- ('b',): 42
436
- })
437
- values = flat.to_dict_values()
438
-
439
- self.assertIsInstance(values[('a',)], jnp.ndarray)
440
- np.testing.assert_array_equal(values[('a',)], jnp.array([1, 2, 3]))
441
- self.assertEqual(values[('b',)], 42)
442
-
443
- def test_assign_dict_values(self):
444
- """Test assigning dictionary values."""
445
- flat = FlattedDict({
446
- ('a',): brainstate.ParamState(jnp.array([1, 2, 3])),
447
- ('b',): 42
448
- })
449
-
450
- new_values = {
451
- ('a',): jnp.array([4, 5, 6]),
452
- ('b',): 100
453
- }
454
-
455
- flat.assign_dict_values(new_values)
456
-
457
- np.testing.assert_array_equal(flat[('a',)].value, jnp.array([4, 5, 6]))
458
- self.assertEqual(flat[('b',)], 100)
459
-
460
- def test_assign_dict_values_missing_key(self):
461
- """Test assigning with missing key raises error."""
462
- flat = FlattedDict({('a',): 1})
463
-
464
- with self.assertRaises(KeyError):
465
- flat.assign_dict_values({('b',): 2})
466
-
467
-
468
- class TestPrettyList(unittest.TestCase):
469
- """Test PrettyList functionality."""
470
-
471
- def test_pretty_list_creation(self):
472
- """Test PrettyList creation."""
473
- lst = PrettyList([1, 2, 3])
474
- self.assertEqual(lst[0], 1)
475
- self.assertEqual(lst[1], 2)
476
- self.assertEqual(lst[2], 3)
477
-
478
- def test_pretty_list_repr(self):
479
- """Test PrettyList representation."""
480
- lst = PrettyList([1, 2, {'a': 3}])
481
- repr_str = repr(lst)
482
- self.assertIsInstance(repr_str, str)
483
- self.assertIn('1', repr_str)
484
-
485
- def test_tree_flatten(self):
486
- """Test JAX tree flattening."""
487
- lst = PrettyList([1, 2, 3])
488
- leaves, aux = lst.tree_flatten()
489
- self.assertEqual(leaves, [1, 2, 3])
490
- self.assertEqual(aux, ())
491
-
492
- def test_tree_unflatten(self):
493
- """Test JAX tree unflattening."""
494
- children = [1, 2, 3]
495
- lst = PrettyList.tree_unflatten((), children)
496
- self.assertIsInstance(lst, PrettyList)
497
- self.assertEqual(list(lst), [1, 2, 3])
498
-
499
-
500
- class TestFilterOperations(unittest.TestCase):
501
- """Test filter operations."""
502
-
503
- def test_nested_dict_filter(self):
504
- """Test filtering NestedDict."""
505
- nested = NestedDict({
506
- 'a': 1,
507
- 'b': 2,
508
- 'c': 3
509
- })
510
-
511
- filtered = nested.filter(lambda path, val: val >= 2)
512
-
513
- flat = filtered.to_flat()
514
- # Check that filtered values are present
515
- values = [v for v in flat.values()]
516
- self.assertIn(2, values)
517
- self.assertIn(3, values)
518
-
519
- def test_flatted_dict_filter(self):
520
- """Test filtering FlattedDict."""
521
- flat = FlattedDict({
522
- ('a',): 1,
523
- ('b',): 2,
524
- ('c',): 3
525
- })
526
-
527
- filtered = flat.filter(lambda path, val: val % 2 == 0)
528
- self.assertIn(('b',), filtered)
529
- self.assertNotIn(('a',), filtered)
530
-
531
- def test_ellipsis_filter_position(self):
532
- """Test that ... can only be used as last filter."""
533
- nested = NestedDict({'a': 1, 'b': 2, 'c': 3})
534
-
535
- with self.assertRaises(ValueError):
536
- # ... in middle should raise error
537
- nested.split(..., lambda path, val: val > 1)
538
-
539
-
540
- class TestJAXPytreeIntegration(unittest.TestCase):
541
- """Test JAX pytree integration."""
542
-
543
- def test_nested_dict_pytree_flatten(self):
544
- """Test NestedDict can be flattened as pytree."""
545
- nested = NestedDict({'a': 1, 'b': 2})
546
- leaves, treedef = jax.tree.flatten(nested)
547
-
548
- self.assertEqual(sorted(leaves), [1, 2])
549
-
550
- def test_nested_dict_pytree_unflatten(self):
551
- """Test NestedDict can be unflattened as pytree."""
552
- nested = NestedDict({'a': 1, 'b': 2})
553
- leaves, treedef = jax.tree.flatten(nested)
554
- restored = jax.tree.unflatten(treedef, leaves)
555
-
556
- self.assertIsInstance(restored, NestedDict)
557
- self.assertEqual(restored['a'], 1)
558
- self.assertEqual(restored['b'], 2)
559
-
560
- def test_flatted_dict_pytree_flatten(self):
561
- """Test FlattedDict can be flattened as pytree."""
562
- flat = FlattedDict({('a',): 1, ('b',): 2})
563
- leaves, treedef = jax.tree.flatten(flat)
564
-
565
- self.assertEqual(sorted(leaves), [1, 2])
566
-
567
- def test_flatted_dict_pytree_unflatten(self):
568
- """Test FlattedDict can be unflattened as pytree."""
569
- flat = FlattedDict({('a',): 1, ('b',): 2})
570
- leaves, treedef = jax.tree.flatten(flat)
571
- restored = jax.tree.unflatten(treedef, leaves)
572
-
573
- self.assertIsInstance(restored, FlattedDict)
574
- self.assertEqual(restored[('a',)], 1)
575
-
576
- def test_pretty_list_pytree(self):
577
- """Test PrettyList pytree operations."""
578
- lst = PrettyList([1, 2, 3])
579
- leaves, treedef = jax.tree.flatten(lst)
580
- restored = jax.tree.unflatten(treedef, leaves)
581
-
582
- self.assertIsInstance(restored, PrettyList)
583
- self.assertEqual(list(restored), [1, 2, 3])
584
-
585
- def test_jax_tree_map_nested_dict(self):
586
- """Test jax.tree.map with NestedDict."""
587
- nested = NestedDict({'a': 1, 'b': {'c': 2}})
588
- doubled = jax.tree.map(lambda x: x * 2, nested)
589
-
590
- self.assertEqual(doubled['a'], 2)
591
- self.assertEqual(doubled['b']['c'], 4)
592
-
593
- def test_jax_tree_map_flatted_dict(self):
594
- """Test jax.tree.map with FlattedDict."""
595
- flat = FlattedDict({('a',): 1, ('b', 'c'): 2})
596
- doubled = jax.tree.map(lambda x: x * 2, flat)
597
-
598
- self.assertEqual(doubled[('a',)], 2)
599
- self.assertEqual(doubled[('b', 'c')], 4)
600
-
601
- def test_jax_tree_map_pretty_list(self):
602
- """Test jax.tree.map with PrettyList."""
603
- lst = PrettyList([1, 2, 3])
604
- doubled = jax.tree.map(lambda x: x * 2, lst)
605
-
606
- self.assertEqual(list(doubled), [2, 4, 6])
607
-
608
-
609
- class TestEdgeCases(unittest.TestCase):
610
- """Test edge cases and error handling."""
611
-
612
- def test_empty_nested_dict(self):
613
- """Test empty NestedDict."""
614
- nested = NestedDict({})
615
- flat = nested.to_flat()
616
- self.assertEqual(len(flat), 0)
617
-
618
- def test_empty_flatted_dict(self):
619
- """Test empty FlattedDict."""
620
- flat = FlattedDict({})
621
- nested = flat.to_nest()
622
- self.assertEqual(len(nested), 0)
623
-
624
- def test_deeply_nested_structure(self):
625
- """Test deeply nested structure."""
626
- nested = NestedDict({
627
- 'a': {
628
- 'b': {
629
- 'c': {
630
- 'd': {
631
- 'e': 42
632
- }
633
- }
634
- }
635
- }
636
- })
637
- flat = nested.to_flat()
638
- self.assertEqual(flat[('a', 'b', 'c', 'd', 'e')], 42)
639
-
640
- def test_mixed_types_in_nested(self):
641
- """Test nested dict with mixed types."""
642
- nested = NestedDict({
643
- 'int': 1,
644
- 'float': 2.5,
645
- 'str': 'hello',
646
- 'list': [1, 2, 3],
647
- 'dict': {'nested': True}
648
- })
649
- flat = nested.to_flat()
650
-
651
- self.assertEqual(flat[('int',)], 1)
652
- self.assertEqual(flat[('float',)], 2.5)
653
- self.assertEqual(flat[('str',)], 'hello')
654
-
655
- def test_numeric_keys(self):
656
- """Test handling of numeric keys."""
657
- nested = NestedDict({
658
- 1: 'one',
659
- 2: {'a': 'two-a'}
660
- })
661
- flat = nested.to_flat()
662
-
663
- self.assertEqual(flat[(1,)], 'one')
664
- self.assertEqual(flat[(2, 'a')], 'two-a')
665
-
666
- def test_merge_with_overlapping_keys(self):
667
- """Test merging with overlapping keys."""
668
- d1 = NestedDict({'a': 1, 'b': 2})
669
- d2 = NestedDict({'b': 3, 'c': 4})
670
- merged = NestedDict.merge(d1, d2)
671
-
672
- # Later values should override
673
- self.assertEqual(merged['b'], 3)
674
- self.assertEqual(merged['a'], 1)
675
- self.assertEqual(merged['c'], 4)
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """
17
+ Comprehensive test suite for the pretty_pytree module.
18
+
19
+ This test module provides extensive coverage of the pretty printing and tree
20
+ manipulation functionality, including:
21
+ - PrettyObject and pretty representation
22
+ - Nested and flattened dictionary structures
23
+ - Mapping flattening and unflattening
24
+ - Split, filter, and merge operations
25
+ - JAX pytree integration
26
+ - State management utilities
27
+ """
28
+
29
+ import unittest
30
+
31
+ import jax
32
+ import jax.numpy as jnp
33
+ import numpy as np
34
+ from absl.testing import absltest
35
+
36
+ import brainstate
37
+ from brainstate.util._pretty_pytree import (
38
+ PrettyObject,
39
+ PrettyDict,
40
+ NestedDict,
41
+ FlattedDict,
42
+ PrettyList,
43
+ flat_mapping,
44
+ nest_mapping,
45
+ empty_node,
46
+ _EmptyNode,
47
+ )
48
+
49
+
50
+ class TestNestedMapping(absltest.TestCase):
51
+ def test_create_state(self):
52
+ state = brainstate.util.NestedDict({'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}})
53
+
54
+ assert state['a'].value == 1
55
+ assert state['b']['c'].value == 2
56
+
57
+ def test_get_attr(self):
58
+ state = brainstate.util.NestedDict({'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}})
59
+
60
+ assert state.a.value == 1
61
+ assert state.b['c'].value == 2
62
+
63
+ def test_set_attr(self):
64
+ state = brainstate.util.NestedDict({'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}})
65
+
66
+ state.a.value = 3
67
+ state.b['c'].value = 4
68
+
69
+ assert state['a'].value == 3
70
+ assert state['b']['c'].value == 4
71
+
72
+ def test_set_attr_variables(self):
73
+ state = brainstate.util.NestedDict({'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}})
74
+
75
+ state.a.value = 3
76
+ state.b['c'].value = 4
77
+
78
+ assert isinstance(state.a, brainstate.ParamState)
79
+ assert state.a.value == 3
80
+ assert isinstance(state.b['c'], brainstate.ParamState)
81
+ assert state.b['c'].value == 4
82
+
83
+ def test_add_nested_attr(self):
84
+ state = brainstate.util.NestedDict({'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}})
85
+ state.b['d'] = brainstate.ParamState(5)
86
+
87
+ assert state['b']['d'].value == 5
88
+
89
+ def test_delete_nested_attr(self):
90
+ state = brainstate.util.NestedDict({'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}})
91
+ del state['b']['c']
92
+
93
+ assert 'c' not in state['b']
94
+
95
+ def test_integer_access(self):
96
+ class Foo(brainstate.nn.Module):
97
+ def __init__(self):
98
+ super().__init__()
99
+ self.layers = [brainstate.nn.Linear(1, 2), brainstate.nn.Linear(2, 3)]
100
+
101
+ module = Foo()
102
+ state_refs = brainstate.graph.treefy_states(module)
103
+
104
+ assert module.layers[0].weight.value['weight'].shape == (1, 2)
105
+ assert state_refs.layers[0]['weight'].value['weight'].shape == (1, 2)
106
+ assert module.layers[1].weight.value['weight'].shape == (2, 3)
107
+ assert state_refs.layers[1]['weight'].value['weight'].shape == (2, 3)
108
+
109
+ def test_pure_dict(self):
110
+ module = brainstate.nn.Linear(4, 5)
111
+ state_map = brainstate.graph.treefy_states(module)
112
+ pure_dict = state_map.to_pure_dict()
113
+ assert isinstance(pure_dict, dict)
114
+ assert isinstance(pure_dict['weight'].value['weight'], jax.Array)
115
+ assert isinstance(pure_dict['weight'].value['bias'], jax.Array)
116
+
117
+
118
+ class TestSplit(unittest.TestCase):
119
+ def test_split(self):
120
+ class Model(brainstate.nn.Module):
121
+ def __init__(self):
122
+ super().__init__()
123
+ self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
124
+ self.linear = brainstate.nn.Linear([10, 3], [10, 4])
125
+
126
+ def __call__(self, x):
127
+ return self.linear(self.batchnorm(x))
128
+
129
+ with brainstate.environ.context(fit=True):
130
+ model = Model()
131
+ x = brainstate.random.randn(1, 10, 3)
132
+ y = model(x)
133
+ self.assertEqual(y.shape, (1, 10, 4))
134
+
135
+ state_map = brainstate.graph.treefy_states(model)
136
+
137
+ with self.assertRaises(ValueError):
138
+ params, others = state_map.split(brainstate.ParamState)
139
+
140
+ params, others = state_map.split(brainstate.ParamState, ...)
141
+ print()
142
+ print(params)
143
+ print(others)
144
+
145
+ self.assertTrue(len(params.to_flat()) == 2)
146
+ self.assertTrue(len(others.to_flat()) == 2)
147
+
148
+
149
+ class TestStateMap2(unittest.TestCase):
150
+ def test1(self):
151
+ class Model(brainstate.nn.Module):
152
+ def __init__(self):
153
+ super().__init__()
154
+ self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
155
+ self.linear = brainstate.nn.Linear([10, 3], [10, 4])
156
+
157
+ def __call__(self, x):
158
+ return self.linear(self.batchnorm(x))
159
+
160
+ with brainstate.environ.context(fit=True):
161
+ model = Model()
162
+ state_map = brainstate.graph.treefy_states(model).to_flat()
163
+ state_map = brainstate.util.NestedDict(state_map)
164
+
165
+
166
+ class TestFlattedMapping(unittest.TestCase):
167
+ def test1(self):
168
+ class Model(brainstate.nn.Module):
169
+ def __init__(self):
170
+ super().__init__()
171
+ self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
172
+ self.linear = brainstate.nn.Linear([10, 3], [10, 4])
173
+
174
+ def __call__(self, x):
175
+ return self.linear(self.batchnorm(x))
176
+
177
+ model = Model()
178
+ # print(model.states())
179
+ # print(brainstate.graph.states(model))
180
+ self.assertTrue(model.states() == brainstate.graph.states(model))
181
+
182
+ print(model.nodes())
183
+ # print(brainstate.graph.nodes(model))
184
+ self.assertTrue(model.nodes() == brainstate.graph.nodes(model))
185
+
186
+
187
+ class TestPrettyObject(unittest.TestCase):
188
+ """Test PrettyObject functionality."""
189
+
190
+ def test_pretty_object_basic(self):
191
+ """Test basic PrettyObject creation and representation."""
192
+ class MyObject(PrettyObject):
193
+ def __init__(self, value):
194
+ self.value = value
195
+ self.name = "test"
196
+
197
+ obj = MyObject(42)
198
+ repr_str = repr(obj)
199
+ self.assertIsInstance(repr_str, str)
200
+ self.assertIn("MyObject", repr_str)
201
+ self.assertIn("value", repr_str)
202
+ self.assertIn("42", repr_str)
203
+
204
+ def test_pretty_repr_item_filtering(self):
205
+ """Test __pretty_repr_item__ filtering."""
206
+ class FilteredObject(PrettyObject):
207
+ def __init__(self):
208
+ self.visible = "show"
209
+ self.hidden = "hide"
210
+
211
+ def __pretty_repr_item__(self, k, v):
212
+ if k == "hidden":
213
+ return None
214
+ return k, v
215
+
216
+ obj = FilteredObject()
217
+ repr_str = repr(obj)
218
+ self.assertIn("visible", repr_str)
219
+ self.assertNotIn("hidden", repr_str)
220
+
221
+ def test_pretty_repr_item_transformation(self):
222
+ """Test __pretty_repr_item__ value transformation."""
223
+ class TransformObject(PrettyObject):
224
+ def __init__(self):
225
+ self.value = 100
226
+
227
+ def __pretty_repr_item__(self, k, v):
228
+ if k == "value":
229
+ return k, v * 2
230
+ return k, v
231
+
232
+ obj = TransformObject()
233
+ repr_str = repr(obj)
234
+ self.assertIn("200", repr_str)
235
+
236
+
237
+ class TestFlatAndNestMapping(unittest.TestCase):
238
+ """Test flat_mapping and nest_mapping functions."""
239
+
240
+ def test_flat_mapping_basic(self):
241
+ """Test basic flattening of nested dict."""
242
+ nested = {'a': 1, 'b': {'c': 2, 'd': 3}}
243
+ flat = flat_mapping(nested)
244
+
245
+ self.assertIsInstance(flat, FlattedDict)
246
+ self.assertEqual(flat[('a',)], 1)
247
+ self.assertEqual(flat[('b', 'c')], 2)
248
+ self.assertEqual(flat[('b', 'd')], 3)
249
+
250
+ def test_flat_mapping_with_separator(self):
251
+ """Test flattening with string separator."""
252
+ nested = {'a': 1, 'b': {'c': 2}}
253
+ flat = flat_mapping(nested, sep='/')
254
+
255
+ self.assertEqual(flat['a'], 1)
256
+ self.assertEqual(flat['b/c'], 2)
257
+
258
+ def test_flat_mapping_empty_nodes(self):
259
+ """Test flattening with keep_empty_nodes."""
260
+ nested = {'a': 1, 'b': {}}
261
+ flat = flat_mapping(nested, keep_empty_nodes=True)
262
+
263
+ self.assertEqual(flat[('a',)], 1)
264
+ self.assertIsInstance(flat[('b',)], _EmptyNode)
265
+
266
+ def test_flat_mapping_without_empty_nodes(self):
267
+ """Test flattening without keeping empty nodes."""
268
+ nested = {'a': 1, 'b': {}}
269
+ flat = flat_mapping(nested, keep_empty_nodes=False)
270
+
271
+ self.assertIn(('a',), flat)
272
+ self.assertNotIn(('b',), flat)
273
+
274
+ def test_flat_mapping_is_leaf(self):
275
+ """Test flattening with custom is_leaf function."""
276
+ nested = {'a': 1, 'b': {'c': 2, 'd': 3}}
277
+
278
+ def is_leaf(path, value):
279
+ return len(path) >= 1
280
+
281
+ flat = flat_mapping(nested, is_leaf=is_leaf)
282
+ self.assertEqual(flat[('a',)], 1)
283
+ self.assertEqual(flat[('b',)], {'c': 2, 'd': 3})
284
+
285
+ def test_nest_mapping_basic(self):
286
+ """Test basic unflattening."""
287
+ flat = {('a',): 1, ('b', 'c'): 2, ('b', 'd'): 3}
288
+ nested = nest_mapping(flat)
289
+
290
+ self.assertIsInstance(nested, NestedDict)
291
+ self.assertEqual(nested['a'], 1)
292
+ self.assertEqual(nested['b']['c'], 2)
293
+ self.assertEqual(nested['b']['d'], 3)
294
+
295
+ def test_nest_mapping_with_separator(self):
296
+ """Test unflattening with string separator."""
297
+ flat = {'a': 1, 'b/c': 2, 'b/d': 3}
298
+ nested = nest_mapping(flat, sep='/')
299
+
300
+ self.assertEqual(nested['a'], 1)
301
+ self.assertEqual(nested['b']['c'], 2)
302
+ self.assertEqual(nested['b']['d'], 3)
303
+
304
+ def test_nest_mapping_with_empty_node(self):
305
+ """Test unflattening with empty nodes."""
306
+ flat = {('a',): 1, ('b',): empty_node}
307
+ nested = nest_mapping(flat)
308
+
309
+ self.assertEqual(nested['a'], 1)
310
+ self.assertEqual(nested['b'], {})
311
+
312
+ def test_round_trip(self):
313
+ """Test flatten -> unflatten round trip."""
314
+ original = {'a': 1, 'b': {'c': 2, 'd': {'e': 3}}}
315
+ flat = flat_mapping(original)
316
+ restored = nest_mapping(flat)
317
+
318
+ self.assertEqual(restored.to_dict(), original)
319
+
320
+
321
+ class TestPrettyDict(unittest.TestCase):
322
+ """Test PrettyDict functionality."""
323
+
324
+ def test_pretty_dict_creation(self):
325
+ """Test PrettyDict creation."""
326
+ d = PrettyDict({'a': 1, 'b': 2})
327
+ self.assertEqual(d['a'], 1)
328
+ self.assertEqual(d['b'], 2)
329
+
330
+ def test_pretty_dict_attribute_access(self):
331
+ """Test accessing dict items as attributes."""
332
+ d = PrettyDict({'a': 1, 'b': 2})
333
+ self.assertEqual(d.a, 1)
334
+ self.assertEqual(d.b, 2)
335
+
336
+ def test_pretty_dict_repr(self):
337
+ """Test PrettyDict representation."""
338
+ d = PrettyDict({'a': 1, 'b': 2})
339
+ repr_str = repr(d)
340
+ self.assertIsInstance(repr_str, str)
341
+ self.assertIn('a', repr_str)
342
+
343
+ def test_to_dict(self):
344
+ """Test conversion to regular dict."""
345
+ d = PrettyDict({'a': 1, 'b': 2})
346
+ regular = d.to_dict()
347
+ self.assertIsInstance(regular, dict)
348
+ self.assertEqual(regular, {'a': 1, 'b': 2})
349
+
350
+
351
+ class TestNestedDictOperations(unittest.TestCase):
352
+ """Test NestedDict additional operations."""
353
+
354
+ def test_or_operator(self):
355
+ """Test | operator for merging."""
356
+ d1 = NestedDict({'a': 1})
357
+ d2 = NestedDict({'b': 2})
358
+ merged = d1 | d2
359
+
360
+ self.assertIsInstance(merged, NestedDict)
361
+ self.assertEqual(merged['a'], 1)
362
+ self.assertEqual(merged['b'], 2)
363
+
364
+ def test_sub_operator(self):
365
+ """Test - operator for difference."""
366
+ d1 = NestedDict({'a': 1, 'b': 2, 'c': 3})
367
+ d2 = NestedDict({'b': 2})
368
+ diff = d1 - d2
369
+
370
+ flat_diff = diff.to_flat()
371
+ self.assertIn(('a',), flat_diff.keys())
372
+ self.assertIn(('c',), flat_diff.keys())
373
+ # b should not be in diff
374
+ has_b = any('b' in key for key in flat_diff.keys())
375
+ self.assertFalse(has_b)
376
+
377
+ def test_merge_static_method(self):
378
+ """Test static merge method."""
379
+ d1 = NestedDict({'a': 1})
380
+ d2 = NestedDict({'b': 2})
381
+ d3 = NestedDict({'c': 3})
382
+ merged = NestedDict.merge(d1, d2, d3)
383
+
384
+ self.assertEqual(merged['a'], 1)
385
+ self.assertEqual(merged['b'], 2)
386
+ self.assertEqual(merged['c'], 3)
387
+
388
+ def test_to_pure_dict(self):
389
+ """Test conversion to pure dict."""
390
+ nested = NestedDict({'a': 1, 'b': {'c': 2}})
391
+ pure = nested.to_pure_dict()
392
+
393
+ self.assertIsInstance(pure, dict)
394
+ self.assertNotIsInstance(pure, NestedDict)
395
+ self.assertEqual(pure['a'], 1)
396
+ self.assertEqual(pure['b']['c'], 2)
397
+
398
+
399
+ class TestFlattedDictOperations(unittest.TestCase):
400
+ """Test FlattedDict additional operations."""
401
+
402
+ def test_or_operator(self):
403
+ """Test | operator for merging."""
404
+ d1 = FlattedDict({('a',): 1})
405
+ d2 = FlattedDict({('b',): 2})
406
+ merged = d1 | d2
407
+
408
+ self.assertIsInstance(merged, FlattedDict)
409
+ self.assertEqual(merged[('a',)], 1)
410
+ self.assertEqual(merged[('b',)], 2)
411
+
412
+ def test_sub_operator(self):
413
+ """Test - operator for difference."""
414
+ d1 = FlattedDict({('a',): 1, ('b',): 2, ('c',): 3})
415
+ d2 = FlattedDict({('b',): 2})
416
+ diff = d1 - d2
417
+
418
+ self.assertIn(('a',), diff)
419
+ self.assertIn(('c',), diff)
420
+ self.assertNotIn(('b',), diff)
421
+
422
+ def test_merge_static_method(self):
423
+ """Test static merge method."""
424
+ d1 = FlattedDict({('a',): 1})
425
+ d2 = FlattedDict({('b',): 2})
426
+ merged = FlattedDict.merge(d1, d2)
427
+
428
+ self.assertEqual(merged[('a',)], 1)
429
+ self.assertEqual(merged[('b',)], 2)
430
+
431
+ def test_to_dict_values(self):
432
+ """Test conversion to dictionary of values."""
433
+ flat = FlattedDict({
434
+ ('a',): brainstate.ParamState(jnp.array([1, 2, 3])),
435
+ ('b',): 42
436
+ })
437
+ values = flat.to_dict_values()
438
+
439
+ self.assertIsInstance(values[('a',)], jnp.ndarray)
440
+ np.testing.assert_array_equal(values[('a',)], jnp.array([1, 2, 3]))
441
+ self.assertEqual(values[('b',)], 42)
442
+
443
+ def test_assign_dict_values(self):
444
+ """Test assigning dictionary values."""
445
+ flat = FlattedDict({
446
+ ('a',): brainstate.ParamState(jnp.array([1, 2, 3])),
447
+ ('b',): 42
448
+ })
449
+
450
+ new_values = {
451
+ ('a',): jnp.array([4, 5, 6]),
452
+ ('b',): 100
453
+ }
454
+
455
+ flat.assign_dict_values(new_values)
456
+
457
+ np.testing.assert_array_equal(flat[('a',)].value, jnp.array([4, 5, 6]))
458
+ self.assertEqual(flat[('b',)], 100)
459
+
460
+ def test_assign_dict_values_missing_key(self):
461
+ """Test assigning with missing key raises error."""
462
+ flat = FlattedDict({('a',): 1})
463
+
464
+ with self.assertRaises(KeyError):
465
+ flat.assign_dict_values({('b',): 2})
466
+
467
+
468
+ class TestPrettyList(unittest.TestCase):
469
+ """Test PrettyList functionality."""
470
+
471
+ def test_pretty_list_creation(self):
472
+ """Test PrettyList creation."""
473
+ lst = PrettyList([1, 2, 3])
474
+ self.assertEqual(lst[0], 1)
475
+ self.assertEqual(lst[1], 2)
476
+ self.assertEqual(lst[2], 3)
477
+
478
+ def test_pretty_list_repr(self):
479
+ """Test PrettyList representation."""
480
+ lst = PrettyList([1, 2, {'a': 3}])
481
+ repr_str = repr(lst)
482
+ self.assertIsInstance(repr_str, str)
483
+ self.assertIn('1', repr_str)
484
+
485
+ def test_tree_flatten(self):
486
+ """Test JAX tree flattening."""
487
+ lst = PrettyList([1, 2, 3])
488
+ leaves, aux = lst.tree_flatten()
489
+ self.assertEqual(leaves, [1, 2, 3])
490
+ self.assertEqual(aux, ())
491
+
492
+ def test_tree_unflatten(self):
493
+ """Test JAX tree unflattening."""
494
+ children = [1, 2, 3]
495
+ lst = PrettyList.tree_unflatten((), children)
496
+ self.assertIsInstance(lst, PrettyList)
497
+ self.assertEqual(list(lst), [1, 2, 3])
498
+
499
+
500
+ class TestFilterOperations(unittest.TestCase):
501
+ """Test filter operations."""
502
+
503
+ def test_nested_dict_filter(self):
504
+ """Test filtering NestedDict."""
505
+ nested = NestedDict({
506
+ 'a': 1,
507
+ 'b': 2,
508
+ 'c': 3
509
+ })
510
+
511
+ filtered = nested.filter(lambda path, val: val >= 2)
512
+
513
+ flat = filtered.to_flat()
514
+ # Check that filtered values are present
515
+ values = [v for v in flat.values()]
516
+ self.assertIn(2, values)
517
+ self.assertIn(3, values)
518
+
519
+ def test_flatted_dict_filter(self):
520
+ """Test filtering FlattedDict."""
521
+ flat = FlattedDict({
522
+ ('a',): 1,
523
+ ('b',): 2,
524
+ ('c',): 3
525
+ })
526
+
527
+ filtered = flat.filter(lambda path, val: val % 2 == 0)
528
+ self.assertIn(('b',), filtered)
529
+ self.assertNotIn(('a',), filtered)
530
+
531
+ def test_ellipsis_filter_position(self):
532
+ """Test that ... can only be used as last filter."""
533
+ nested = NestedDict({'a': 1, 'b': 2, 'c': 3})
534
+
535
+ with self.assertRaises(ValueError):
536
+ # ... in middle should raise error
537
+ nested.split(..., lambda path, val: val > 1)
538
+
539
+
540
+ class TestJAXPytreeIntegration(unittest.TestCase):
541
+ """Test JAX pytree integration."""
542
+
543
+ def test_nested_dict_pytree_flatten(self):
544
+ """Test NestedDict can be flattened as pytree."""
545
+ nested = NestedDict({'a': 1, 'b': 2})
546
+ leaves, treedef = jax.tree.flatten(nested)
547
+
548
+ self.assertEqual(sorted(leaves), [1, 2])
549
+
550
+ def test_nested_dict_pytree_unflatten(self):
551
+ """Test NestedDict can be unflattened as pytree."""
552
+ nested = NestedDict({'a': 1, 'b': 2})
553
+ leaves, treedef = jax.tree.flatten(nested)
554
+ restored = jax.tree.unflatten(treedef, leaves)
555
+
556
+ self.assertIsInstance(restored, NestedDict)
557
+ self.assertEqual(restored['a'], 1)
558
+ self.assertEqual(restored['b'], 2)
559
+
560
+ def test_flatted_dict_pytree_flatten(self):
561
+ """Test FlattedDict can be flattened as pytree."""
562
+ flat = FlattedDict({('a',): 1, ('b',): 2})
563
+ leaves, treedef = jax.tree.flatten(flat)
564
+
565
+ self.assertEqual(sorted(leaves), [1, 2])
566
+
567
+ def test_flatted_dict_pytree_unflatten(self):
568
+ """Test FlattedDict can be unflattened as pytree."""
569
+ flat = FlattedDict({('a',): 1, ('b',): 2})
570
+ leaves, treedef = jax.tree.flatten(flat)
571
+ restored = jax.tree.unflatten(treedef, leaves)
572
+
573
+ self.assertIsInstance(restored, FlattedDict)
574
+ self.assertEqual(restored[('a',)], 1)
575
+
576
+ def test_pretty_list_pytree(self):
577
+ """Test PrettyList pytree operations."""
578
+ lst = PrettyList([1, 2, 3])
579
+ leaves, treedef = jax.tree.flatten(lst)
580
+ restored = jax.tree.unflatten(treedef, leaves)
581
+
582
+ self.assertIsInstance(restored, PrettyList)
583
+ self.assertEqual(list(restored), [1, 2, 3])
584
+
585
+ def test_jax_tree_map_nested_dict(self):
586
+ """Test jax.tree.map with NestedDict."""
587
+ nested = NestedDict({'a': 1, 'b': {'c': 2}})
588
+ doubled = jax.tree.map(lambda x: x * 2, nested)
589
+
590
+ self.assertEqual(doubled['a'], 2)
591
+ self.assertEqual(doubled['b']['c'], 4)
592
+
593
+ def test_jax_tree_map_flatted_dict(self):
594
+ """Test jax.tree.map with FlattedDict."""
595
+ flat = FlattedDict({('a',): 1, ('b', 'c'): 2})
596
+ doubled = jax.tree.map(lambda x: x * 2, flat)
597
+
598
+ self.assertEqual(doubled[('a',)], 2)
599
+ self.assertEqual(doubled[('b', 'c')], 4)
600
+
601
+ def test_jax_tree_map_pretty_list(self):
602
+ """Test jax.tree.map with PrettyList."""
603
+ lst = PrettyList([1, 2, 3])
604
+ doubled = jax.tree.map(lambda x: x * 2, lst)
605
+
606
+ self.assertEqual(list(doubled), [2, 4, 6])
607
+
608
+
609
+ class TestEdgeCases(unittest.TestCase):
610
+ """Test edge cases and error handling."""
611
+
612
+ def test_empty_nested_dict(self):
613
+ """Test empty NestedDict."""
614
+ nested = NestedDict({})
615
+ flat = nested.to_flat()
616
+ self.assertEqual(len(flat), 0)
617
+
618
+ def test_empty_flatted_dict(self):
619
+ """Test empty FlattedDict."""
620
+ flat = FlattedDict({})
621
+ nested = flat.to_nest()
622
+ self.assertEqual(len(nested), 0)
623
+
624
+ def test_deeply_nested_structure(self):
625
+ """Test deeply nested structure."""
626
+ nested = NestedDict({
627
+ 'a': {
628
+ 'b': {
629
+ 'c': {
630
+ 'd': {
631
+ 'e': 42
632
+ }
633
+ }
634
+ }
635
+ }
636
+ })
637
+ flat = nested.to_flat()
638
+ self.assertEqual(flat[('a', 'b', 'c', 'd', 'e')], 42)
639
+
640
+ def test_mixed_types_in_nested(self):
641
+ """Test nested dict with mixed types."""
642
+ nested = NestedDict({
643
+ 'int': 1,
644
+ 'float': 2.5,
645
+ 'str': 'hello',
646
+ 'list': [1, 2, 3],
647
+ 'dict': {'nested': True}
648
+ })
649
+ flat = nested.to_flat()
650
+
651
+ self.assertEqual(flat[('int',)], 1)
652
+ self.assertEqual(flat[('float',)], 2.5)
653
+ self.assertEqual(flat[('str',)], 'hello')
654
+
655
+ def test_numeric_keys(self):
656
+ """Test handling of numeric keys."""
657
+ nested = NestedDict({
658
+ 1: 'one',
659
+ 2: {'a': 'two-a'}
660
+ })
661
+ flat = nested.to_flat()
662
+
663
+ self.assertEqual(flat[(1,)], 'one')
664
+ self.assertEqual(flat[(2, 'a')], 'two-a')
665
+
666
+ def test_merge_with_overlapping_keys(self):
667
+ """Test merging with overlapping keys."""
668
+ d1 = NestedDict({'a': 1, 'b': 2})
669
+ d2 = NestedDict({'b': 3, 'c': 4})
670
+ merged = NestedDict.merge(d1, d2)
671
+
672
+ # Later values should override
673
+ self.assertEqual(merged['b'], 3)
674
+ self.assertEqual(merged['a'], 1)
675
+ self.assertEqual(merged['c'], 4)