brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +169 -58
  2. brainstate/_compatible_import.py +340 -148
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +45 -55
  7. brainstate/_state.py +1652 -1605
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -563
  11. brainstate/environ_test.py +1223 -62
  12. brainstate/graph/__init__.py +22 -29
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1433 -365
  18. brainstate/mixin_test.py +1017 -77
  19. brainstate/nn/__init__.py +137 -135
  20. brainstate/nn/_activations.py +1100 -808
  21. brainstate/nn/_activations_test.py +354 -331
  22. brainstate/nn/_collective_ops.py +633 -514
  23. brainstate/nn/_collective_ops_test.py +774 -43
  24. brainstate/nn/_common.py +226 -178
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +2010 -501
  27. brainstate/nn/_conv_test.py +849 -238
  28. brainstate/nn/_delay.py +575 -588
  29. brainstate/nn/_delay_test.py +243 -238
  30. brainstate/nn/_dropout.py +618 -426
  31. brainstate/nn/_dropout_test.py +477 -100
  32. brainstate/nn/_dynamics.py +1267 -1343
  33. brainstate/nn/_dynamics_test.py +67 -78
  34. brainstate/nn/_elementwise.py +1298 -1119
  35. brainstate/nn/_elementwise_test.py +830 -169
  36. brainstate/nn/_embedding.py +408 -58
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
  42. brainstate/nn/_exp_euler.py +254 -92
  43. brainstate/nn/_exp_euler_test.py +377 -35
  44. brainstate/nn/_linear.py +744 -424
  45. brainstate/nn/_linear_test.py +475 -107
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +384 -377
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -975
  51. brainstate/nn/_normalizations_test.py +699 -73
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +2239 -1177
  55. brainstate/nn/_poolings_test.py +953 -217
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +216 -89
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +809 -553
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
  62. brainstate/random/__init__.py +270 -24
  63. brainstate/random/_rand_funs.py +3938 -3616
  64. brainstate/random/_rand_funs_test.py +640 -567
  65. brainstate/random/_rand_seed.py +675 -210
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1409
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
  72. brainstate/{augment → transform}/_autograd.py +1025 -778
  73. brainstate/{augment → transform}/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +220 -220
  76. brainstate/{compile → transform}/_error_if.py +94 -92
  77. brainstate/{compile → transform}/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +38 -38
  80. brainstate/{compile → transform}/_jit.py +399 -346
  81. brainstate/{compile → transform}/_jit_test.py +143 -143
  82. brainstate/{compile → transform}/_loop_collect_return.py +675 -536
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
  84. brainstate/{compile → transform}/_loop_no_collection.py +283 -184
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +255 -202
  91. brainstate/{augment → transform}/_random.py +171 -151
  92. brainstate/{compile → transform}/_unvmap.py +256 -159
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +837 -304
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +27 -50
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +945 -469
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +910 -523
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/augment/__init__.py +0 -30
  111. brainstate/augment/_eval_shape.py +0 -99
  112. brainstate/augment/_mapping.py +0 -1060
  113. brainstate/augment/_mapping_test.py +0 -597
  114. brainstate/compile/__init__.py +0 -38
  115. brainstate/compile/_ad_checkpoint.py +0 -204
  116. brainstate/compile/_conditions.py +0 -256
  117. brainstate/compile/_make_jaxpr.py +0 -888
  118. brainstate/compile/_make_jaxpr_test.py +0 -156
  119. brainstate/compile/_util.py +0 -147
  120. brainstate/functional/__init__.py +0 -27
  121. brainstate/graph/_graph_node.py +0 -244
  122. brainstate/graph/_graph_node_test.py +0 -73
  123. brainstate/graph/_graph_operation_test.py +0 -563
  124. brainstate/init/__init__.py +0 -26
  125. brainstate/init/_base.py +0 -52
  126. brainstate/init/_generic.py +0 -244
  127. brainstate/init/_regular_inits.py +0 -105
  128. brainstate/init/_regular_inits_test.py +0 -50
  129. brainstate/nn/_inputs.py +0 -608
  130. brainstate/nn/_ltp.py +0 -28
  131. brainstate/nn/_neuron.py +0 -705
  132. brainstate/nn/_neuron_test.py +0 -161
  133. brainstate/nn/_others.py +0 -46
  134. brainstate/nn/_projection.py +0 -486
  135. brainstate/nn/_rate_rnns_test.py +0 -63
  136. brainstate/nn/_readout.py +0 -209
  137. brainstate/nn/_readout_test.py +0 -53
  138. brainstate/nn/_stp.py +0 -236
  139. brainstate/nn/_synapse.py +0 -505
  140. brainstate/nn/_synapse_test.py +0 -131
  141. brainstate/nn/_synaptic_projection.py +0 -423
  142. brainstate/nn/_synouts.py +0 -162
  143. brainstate/nn/_synouts_test.py +0 -57
  144. brainstate/nn/metrics.py +0 -388
  145. brainstate/optim/__init__.py +0 -38
  146. brainstate/optim/_base.py +0 -64
  147. brainstate/optim/_lr_scheduler.py +0 -448
  148. brainstate/optim/_lr_scheduler_test.py +0 -50
  149. brainstate/optim/_optax_optimizer.py +0 -152
  150. brainstate/optim/_optax_optimizer_test.py +0 -53
  151. brainstate/optim/_sgd_optimizer.py +0 -1104
  152. brainstate/random/_random_for_unit.py +0 -52
  153. brainstate/surrogate.py +0 -1957
  154. brainstate/transform.py +0 -23
  155. brainstate/util/caller.py +0 -98
  156. brainstate/util/others.py +0 -540
  157. brainstate/util/pretty_pytree.py +0 -945
  158. brainstate/util/pretty_pytree_test.py +0 -159
  159. brainstate/util/pretty_table.py +0 -2954
  160. brainstate/util/scaling.py +0 -258
  161. brainstate-0.1.10.dist-info/RECORD +0 -130
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,675 @@
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """
17
+ Comprehensive test suite for the pretty_pytree module.
18
+
19
+ This test module provides extensive coverage of the pretty printing and tree
20
+ manipulation functionality, including:
21
+ - PrettyObject and pretty representation
22
+ - Nested and flattened dictionary structures
23
+ - Mapping flattening and unflattening
24
+ - Split, filter, and merge operations
25
+ - JAX pytree integration
26
+ - State management utilities
27
+ """
28
+
29
+ import unittest
30
+
31
+ import jax
32
+ import jax.numpy as jnp
33
+ import numpy as np
34
+ from absl.testing import absltest
35
+
36
+ import brainstate
37
+ from brainstate.util._pretty_pytree import (
38
+ PrettyObject,
39
+ PrettyDict,
40
+ NestedDict,
41
+ FlattedDict,
42
+ PrettyList,
43
+ flat_mapping,
44
+ nest_mapping,
45
+ empty_node,
46
+ _EmptyNode,
47
+ )
48
+
49
+
50
class TestNestedMapping(absltest.TestCase):
    """Item/attribute access and mutation on ``brainstate.util.NestedDict``.

    Checks use ``self.assert*`` instead of bare ``assert`` so they still run
    under ``python -O`` and report both operands when they fail.
    """

    @staticmethod
    def _make_state():
        """Return a fresh two-level NestedDict of ParamStates (a=1, b.c=2)."""
        return brainstate.util.NestedDict(
            {'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}}
        )

    def test_create_state(self):
        state = self._make_state()

        self.assertEqual(state['a'].value, 1)
        self.assertEqual(state['b']['c'].value, 2)

    def test_get_attr(self):
        state = self._make_state()

        # Top-level keys are also reachable as attributes.
        self.assertEqual(state.a.value, 1)
        self.assertEqual(state.b['c'].value, 2)

    def test_set_attr(self):
        state = self._make_state()

        state.a.value = 3
        state.b['c'].value = 4

        # Writes made through attribute access are visible through item access.
        self.assertEqual(state['a'].value, 3)
        self.assertEqual(state['b']['c'].value, 4)

    def test_set_attr_variables(self):
        state = self._make_state()

        state.a.value = 3
        state.b['c'].value = 4

        # Assigning ``.value`` keeps the entries as ParamState objects.
        self.assertIsInstance(state.a, brainstate.ParamState)
        self.assertEqual(state.a.value, 3)
        self.assertIsInstance(state.b['c'], brainstate.ParamState)
        self.assertEqual(state.b['c'].value, 4)

    def test_add_nested_attr(self):
        state = self._make_state()
        state.b['d'] = brainstate.ParamState(5)

        self.assertEqual(state['b']['d'].value, 5)

    def test_delete_nested_attr(self):
        state = self._make_state()
        del state['b']['c']

        self.assertNotIn('c', state['b'])

    def test_integer_access(self):
        class Foo(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.layers = [brainstate.nn.Linear(1, 2), brainstate.nn.Linear(2, 3)]

        module = Foo()
        state_refs = brainstate.graph.treefy_states(module)

        # Submodules held in a list are addressed by integer index in the state tree.
        self.assertEqual(module.layers[0].weight.value['weight'].shape, (1, 2))
        self.assertEqual(state_refs.layers[0]['weight'].value['weight'].shape, (1, 2))
        self.assertEqual(module.layers[1].weight.value['weight'].shape, (2, 3))
        self.assertEqual(state_refs.layers[1]['weight'].value['weight'].shape, (2, 3))

    def test_pure_dict(self):
        module = brainstate.nn.Linear(4, 5)
        state_map = brainstate.graph.treefy_states(module)
        pure_dict = state_map.to_pure_dict()

        self.assertIsInstance(pure_dict, dict)
        self.assertIsInstance(pure_dict['weight'].value['weight'], jax.Array)
        self.assertIsInstance(pure_dict['weight'].value['bias'], jax.Array)
116
+
117
+
118
class TestSplit(unittest.TestCase):
    """``split`` partitions a state tree according to a sequence of filters."""

    def test_split(self):
        class Model(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
                self.linear = brainstate.nn.Linear([10, 3], [10, 4])

            def __call__(self, x):
                return self.linear(self.batchnorm(x))

        with brainstate.environ.context(fit=True):
            model = Model()
            x = brainstate.random.randn(1, 10, 3)
            y = model(x)
            self.assertEqual(y.shape, (1, 10, 4))

        state_map = brainstate.graph.treefy_states(model)

        # A filter set that does not cover every state must be rejected ...
        with self.assertRaises(ValueError):
            state_map.split(brainstate.ParamState)

        # ... while a trailing ``...`` catch-all makes the partition exhaustive.
        params, others = state_map.split(brainstate.ParamState, ...)

        self.assertEqual(len(params.to_flat()), 2)
        self.assertEqual(len(others.to_flat()), 2)
147
+
148
+
149
class TestStateMap2(unittest.TestCase):
    """Re-wrapping a flattened state tree in a ``NestedDict``."""

    def test1(self):
        class Model(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
                self.linear = brainstate.nn.Linear([10, 3], [10, 4])

            def __call__(self, x):
                return self.linear(self.batchnorm(x))

        with brainstate.environ.context(fit=True):
            model = Model()
            flat = brainstate.graph.treefy_states(model).to_flat()
            state_map = brainstate.util.NestedDict(flat)

        # The previous version of this test asserted nothing; at minimum the
        # model must expose some states and re-wrapping must yield a NestedDict.
        self.assertGreater(len(flat), 0)
        self.assertIsInstance(state_map, brainstate.util.NestedDict)
164
+
165
+
166
class TestFlattedMapping(unittest.TestCase):
    """``Module.states()``/``nodes()`` agree with the ``brainstate.graph`` helpers."""

    def test1(self):
        class Model(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
                self.linear = brainstate.nn.Linear([10, 3], [10, 4])

            def __call__(self, x):
                return self.linear(self.batchnorm(x))

        model = Model()

        # assertEqual (rather than assertTrue(a == b)) shows both operands on failure.
        self.assertEqual(model.states(), brainstate.graph.states(model))
        self.assertEqual(model.nodes(), brainstate.graph.nodes(model))
185
+
186
+
187
class TestPrettyObject(unittest.TestCase):
    """Exercise the repr machinery that ``PrettyObject`` provides to subclasses."""

    def test_pretty_object_basic(self):
        """A plain subclass renders its class name and attribute values."""

        class MyObject(PrettyObject):
            def __init__(self, value):
                self.value = value
                self.name = "test"

        text = repr(MyObject(42))

        self.assertIsInstance(text, str)
        for fragment in ("MyObject", "value", "42"):
            self.assertIn(fragment, text)

    def test_pretty_repr_item_filtering(self):
        """Returning ``None`` from ``__pretty_repr_item__`` hides an attribute."""

        class FilteredObject(PrettyObject):
            def __init__(self):
                self.visible = "show"
                self.hidden = "hide"

            def __pretty_repr_item__(self, k, v):
                return None if k == "hidden" else (k, v)

        text = repr(FilteredObject())

        self.assertIn("visible", text)
        self.assertNotIn("hidden", text)

    def test_pretty_repr_item_transformation(self):
        """``__pretty_repr_item__`` may rewrite the displayed value."""

        class TransformObject(PrettyObject):
            def __init__(self):
                self.value = 100

            def __pretty_repr_item__(self, k, v):
                return (k, v * 2) if k == "value" else (k, v)

        self.assertIn("200", repr(TransformObject()))
235
+
236
+
237
class TestFlatAndNestMapping(unittest.TestCase):
    """Behaviour of the ``flat_mapping`` / ``nest_mapping`` pair."""

    def test_flat_mapping_basic(self):
        """Nested dicts flatten into tuple-path keys held by a FlattedDict."""
        result = flat_mapping({'a': 1, 'b': {'c': 2, 'd': 3}})

        self.assertIsInstance(result, FlattedDict)
        expected = {('a',): 1, ('b', 'c'): 2, ('b', 'd'): 3}
        for path, value in expected.items():
            self.assertEqual(result[path], value)

    def test_flat_mapping_with_separator(self):
        """With ``sep`` the path components are joined into string keys."""
        result = flat_mapping({'a': 1, 'b': {'c': 2}}, sep='/')

        self.assertEqual(result['a'], 1)
        self.assertEqual(result['b/c'], 2)

    def test_flat_mapping_empty_nodes(self):
        """``keep_empty_nodes=True`` records empty sub-dicts as ``_EmptyNode``."""
        kept = flat_mapping({'a': 1, 'b': {}}, keep_empty_nodes=True)

        self.assertEqual(kept[('a',)], 1)
        self.assertIsInstance(kept[('b',)], _EmptyNode)

    def test_flat_mapping_without_empty_nodes(self):
        """``keep_empty_nodes=False`` drops empty sub-dicts entirely."""
        dropped = flat_mapping({'a': 1, 'b': {}}, keep_empty_nodes=False)

        self.assertIn(('a',), dropped)
        self.assertNotIn(('b',), dropped)

    def test_flat_mapping_is_leaf(self):
        """A custom ``is_leaf`` can stop recursion early."""

        def stop_below_root(path, value):
            # Anything below the root is a leaf, so only one level is unrolled.
            return len(path) > 0

        result = flat_mapping({'a': 1, 'b': {'c': 2, 'd': 3}}, is_leaf=stop_below_root)

        self.assertEqual(result[('a',)], 1)
        self.assertEqual(result[('b',)], {'c': 2, 'd': 3})

    def test_nest_mapping_basic(self):
        """Tuple-path keys are rebuilt into a NestedDict."""
        tree = nest_mapping({('a',): 1, ('b', 'c'): 2, ('b', 'd'): 3})

        self.assertIsInstance(tree, NestedDict)
        self.assertEqual(tree['a'], 1)
        self.assertEqual(tree['b']['c'], 2)
        self.assertEqual(tree['b']['d'], 3)

    def test_nest_mapping_with_separator(self):
        """String keys are split on ``sep`` before nesting."""
        tree = nest_mapping({'a': 1, 'b/c': 2, 'b/d': 3}, sep='/')

        self.assertEqual(tree['a'], 1)
        inner = tree['b']
        self.assertEqual(inner['c'], 2)
        self.assertEqual(inner['d'], 3)

    def test_nest_mapping_with_empty_node(self):
        """``empty_node`` markers turn back into empty dicts."""
        tree = nest_mapping({('a',): 1, ('b',): empty_node})

        self.assertEqual(tree['a'], 1)
        self.assertEqual(tree['b'], {})

    def test_round_trip(self):
        """Flattening then nesting reproduces the original plain dict."""
        source = {'a': 1, 'b': {'c': 2, 'd': {'e': 3}}}

        self.assertEqual(nest_mapping(flat_mapping(source)).to_dict(), source)
319
+
320
+
321
class TestPrettyDict(unittest.TestCase):
    """Behaviour of the ``PrettyDict`` mapping wrapper."""

    def setUp(self):
        super().setUp()
        # Every test gets its own fresh instance with the same contents.
        self.pd = PrettyDict({'a': 1, 'b': 2})

    def test_pretty_dict_creation(self):
        """Entries are reachable with ``[]``."""
        for key, expected in (('a', 1), ('b', 2)):
            self.assertEqual(self.pd[key], expected)

    def test_pretty_dict_attribute_access(self):
        """Entries are also reachable as attributes."""
        for name, expected in (('a', 1), ('b', 2)):
            self.assertEqual(getattr(self.pd, name), expected)

    def test_pretty_dict_repr(self):
        """``repr`` yields a string mentioning the keys."""
        text = repr(self.pd)

        self.assertIsInstance(text, str)
        self.assertIn('a', text)

    def test_to_dict(self):
        """``to_dict`` produces an ordinary dict with the same contents."""
        plain = self.pd.to_dict()

        self.assertIsInstance(plain, dict)
        self.assertEqual(plain, {'a': 1, 'b': 2})
349
+
350
+
351
class TestNestedDictOperations(unittest.TestCase):
    """Set-like operators and conversions on ``NestedDict``."""

    def test_or_operator(self):
        """``|`` combines two NestedDicts into a new NestedDict."""
        merged = NestedDict({'a': 1}) | NestedDict({'b': 2})

        self.assertIsInstance(merged, NestedDict)
        self.assertEqual(merged['a'], 1)
        self.assertEqual(merged['b'], 2)

    def test_sub_operator(self):
        """``-`` removes the right-hand operand's paths."""
        minuend = NestedDict({'a': 1, 'b': 2, 'c': 3})
        subtrahend = NestedDict({'b': 2})
        paths = (minuend - subtrahend).to_flat().keys()

        self.assertIn(('a',), paths)
        self.assertIn(('c',), paths)
        # No surviving path may mention the subtracted key 'b'.
        self.assertFalse(any('b' in path for path in paths))

    def test_merge_static_method(self):
        """``NestedDict.merge`` accepts any number of inputs."""
        parts = [NestedDict({'a': 1}), NestedDict({'b': 2}), NestedDict({'c': 3})]
        merged = NestedDict.merge(*parts)

        for key, expected in (('a', 1), ('b', 2), ('c', 3)):
            self.assertEqual(merged[key], expected)

    def test_to_pure_dict(self):
        """``to_pure_dict`` returns a builtin dict, not a NestedDict."""
        pure = NestedDict({'a': 1, 'b': {'c': 2}}).to_pure_dict()

        self.assertIsInstance(pure, dict)
        self.assertNotIsInstance(pure, NestedDict)
        self.assertEqual(pure['a'], 1)
        self.assertEqual(pure['b']['c'], 2)
397
+
398
+
399
class TestFlattedDictOperations(unittest.TestCase):
    """Set-like operators and value bulk-access on ``FlattedDict``."""

    @staticmethod
    def _param_and_scalar():
        """Return a fresh FlattedDict with one ParamState entry and one plain int."""
        return FlattedDict({
            ('a',): brainstate.ParamState(jnp.array([1, 2, 3])),
            ('b',): 42,
        })

    def test_or_operator(self):
        """``|`` combines two FlattedDicts into a new FlattedDict."""
        merged = FlattedDict({('a',): 1}) | FlattedDict({('b',): 2})

        self.assertIsInstance(merged, FlattedDict)
        self.assertEqual(merged[('a',)], 1)
        self.assertEqual(merged[('b',)], 2)

    def test_sub_operator(self):
        """``-`` drops the paths present in the right-hand operand."""
        remaining = (
            FlattedDict({('a',): 1, ('b',): 2, ('c',): 3})
            - FlattedDict({('b',): 2})
        )

        self.assertIn(('a',), remaining)
        self.assertIn(('c',), remaining)
        self.assertNotIn(('b',), remaining)

    def test_merge_static_method(self):
        """``FlattedDict.merge`` unions its inputs."""
        merged = FlattedDict.merge(FlattedDict({('a',): 1}), FlattedDict({('b',): 2}))

        self.assertEqual(merged[('a',)], 1)
        self.assertEqual(merged[('b',)], 2)

    def test_to_dict_values(self):
        """State entries are unwrapped to their values; plain entries pass through."""
        values = self._param_and_scalar().to_dict_values()

        self.assertIsInstance(values[('a',)], jnp.ndarray)
        np.testing.assert_array_equal(values[('a',)], jnp.array([1, 2, 3]))
        self.assertEqual(values[('b',)], 42)

    def test_assign_dict_values(self):
        """State entries get their ``.value`` updated; plain entries are replaced."""
        flat = self._param_and_scalar()

        flat.assign_dict_values({('a',): jnp.array([4, 5, 6]), ('b',): 100})

        np.testing.assert_array_equal(flat[('a',)].value, jnp.array([4, 5, 6]))
        self.assertEqual(flat[('b',)], 100)

    def test_assign_dict_values_missing_key(self):
        """Assigning to a path that does not exist raises ``KeyError``."""
        flat = FlattedDict({('a',): 1})

        with self.assertRaises(KeyError):
            flat.assign_dict_values({('b',): 2})
466
+
467
+
468
class TestPrettyList(unittest.TestCase):
    """Behaviour of the ``PrettyList`` sequence wrapper."""

    def test_pretty_list_creation(self):
        """Items are addressable by position."""
        items = PrettyList([1, 2, 3])

        for index, expected in enumerate([1, 2, 3]):
            self.assertEqual(items[index], expected)

    def test_pretty_list_repr(self):
        """``repr`` yields a string containing the elements."""
        text = repr(PrettyList([1, 2, {'a': 3}]))

        self.assertIsInstance(text, str)
        self.assertIn('1', text)

    def test_tree_flatten(self):
        """``tree_flatten`` exposes the items as children with empty aux data."""
        children, aux_data = PrettyList([1, 2, 3]).tree_flatten()

        self.assertEqual(children, [1, 2, 3])
        self.assertEqual(aux_data, ())

    def test_tree_unflatten(self):
        """``tree_unflatten`` rebuilds a PrettyList from its children."""
        rebuilt = PrettyList.tree_unflatten((), [1, 2, 3])

        self.assertIsInstance(rebuilt, PrettyList)
        self.assertEqual(list(rebuilt), [1, 2, 3])
498
+
499
+
500
class TestFilterOperations(unittest.TestCase):
    """Predicate-based ``filter`` / ``split`` on the pretty dict types."""

    def test_nested_dict_filter(self):
        """``filter`` keeps exactly the entries whose predicate holds."""
        nested = NestedDict({
            'a': 1,
            'b': 2,
            'c': 3,
        })

        filtered = nested.filter(lambda path, val: val >= 2)

        values = list(filtered.to_flat().values())
        # Matching entries survive ...
        self.assertIn(2, values)
        self.assertIn(3, values)
        # ... and the non-matching entry is dropped. Without this check, a
        # filter that returned its input unchanged would still pass.
        self.assertNotIn(1, values)

    def test_flatted_dict_filter(self):
        """``filter`` on a FlattedDict keeps only matching paths."""
        flat = FlattedDict({
            ('a',): 1,
            ('b',): 2,
            ('c',): 3,
        })

        filtered = flat.filter(lambda path, val: val % 2 == 0)

        self.assertIn(('b',), filtered)
        self.assertNotIn(('a',), filtered)

    def test_ellipsis_filter_position(self):
        """``...`` is only allowed as the last filter passed to ``split``."""
        nested = NestedDict({'a': 1, 'b': 2, 'c': 3})

        with self.assertRaises(ValueError):
            nested.split(..., lambda path, val: val > 1)
538
+
539
+
540
class TestJAXPytreeIntegration(unittest.TestCase):
    """The pretty container types behave as registered JAX pytrees."""

    def test_nested_dict_pytree_flatten(self):
        """NestedDict leaves are its values."""
        leaves, _ = jax.tree.flatten(NestedDict({'a': 1, 'b': 2}))

        self.assertEqual(sorted(leaves), [1, 2])

    def test_nested_dict_pytree_unflatten(self):
        """Flatten/unflatten round-trips a NestedDict."""
        nested = NestedDict({'a': 1, 'b': 2})
        leaves, treedef = jax.tree.flatten(nested)
        restored = jax.tree.unflatten(treedef, leaves)

        self.assertIsInstance(restored, NestedDict)
        self.assertEqual(restored['a'], 1)
        self.assertEqual(restored['b'], 2)

    def test_flatted_dict_pytree_flatten(self):
        """FlattedDict leaves are its values."""
        leaves, _ = jax.tree.flatten(FlattedDict({('a',): 1, ('b',): 2}))

        self.assertEqual(sorted(leaves), [1, 2])

    def test_flatted_dict_pytree_unflatten(self):
        """Flatten/unflatten round-trips a FlattedDict."""
        flat = FlattedDict({('a',): 1, ('b',): 2})
        leaves, treedef = jax.tree.flatten(flat)
        restored = jax.tree.unflatten(treedef, leaves)

        self.assertIsInstance(restored, FlattedDict)
        self.assertEqual(restored[('a',)], 1)
        # Check every entry so a treedef that dropped or permuted keys is caught.
        self.assertEqual(restored[('b',)], 2)

    def test_pretty_list_pytree(self):
        """Flatten/unflatten round-trips a PrettyList."""
        lst = PrettyList([1, 2, 3])
        leaves, treedef = jax.tree.flatten(lst)
        restored = jax.tree.unflatten(treedef, leaves)

        self.assertIsInstance(restored, PrettyList)
        self.assertEqual(list(restored), [1, 2, 3])

    def test_jax_tree_map_nested_dict(self):
        """``jax.tree.map`` reaches values nested inside a NestedDict."""
        doubled = jax.tree.map(lambda x: x * 2, NestedDict({'a': 1, 'b': {'c': 2}}))

        self.assertEqual(doubled['a'], 2)
        self.assertEqual(doubled['b']['c'], 4)

    def test_jax_tree_map_flatted_dict(self):
        """``jax.tree.map`` preserves FlattedDict tuple-path keys."""
        doubled = jax.tree.map(lambda x: x * 2, FlattedDict({('a',): 1, ('b', 'c'): 2}))

        self.assertEqual(doubled[('a',)], 2)
        self.assertEqual(doubled[('b', 'c')], 4)

    def test_jax_tree_map_pretty_list(self):
        """``jax.tree.map`` maps over PrettyList elements."""
        doubled = jax.tree.map(lambda x: x * 2, PrettyList([1, 2, 3]))

        self.assertEqual(list(doubled), [2, 4, 6])
607
+
608
+
609
class TestEdgeCases(unittest.TestCase):
    """Edge cases: empty containers, deep nesting, mixed value and key types."""

    def test_empty_nested_dict(self):
        """An empty NestedDict flattens to an empty mapping."""
        flat = NestedDict({}).to_flat()

        self.assertEqual(len(flat), 0)

    def test_empty_flatted_dict(self):
        """An empty FlattedDict nests to an empty mapping."""
        nested = FlattedDict({}).to_nest()

        self.assertEqual(len(nested), 0)

    def test_deeply_nested_structure(self):
        """Paths of arbitrary depth become a single tuple key."""
        nested = NestedDict({
            'a': {
                'b': {
                    'c': {
                        'd': {
                            'e': 42,
                        },
                    },
                },
            },
        })

        flat = nested.to_flat()

        self.assertEqual(flat[('a', 'b', 'c', 'd', 'e')], 42)

    def test_mixed_types_in_nested(self):
        """Scalar leaves of various types survive flattening; sub-dicts are expanded."""
        nested = NestedDict({
            'int': 1,
            'float': 2.5,
            'str': 'hello',
            'list': [1, 2, 3],
            'dict': {'nested': True},
        })

        flat = nested.to_flat()

        self.assertEqual(flat[('int',)], 1)
        self.assertEqual(flat[('float',)], 2.5)
        self.assertEqual(flat[('str',)], 'hello')
        # The nested 'dict' entry was built but never checked before; it is
        # flattened into a two-component path like any other sub-dict.
        self.assertTrue(flat[('dict', 'nested')])

    def test_numeric_keys(self):
        """Non-string keys are kept as-is inside the tuple paths."""
        nested = NestedDict({
            1: 'one',
            2: {'a': 'two-a'},
        })

        flat = nested.to_flat()

        self.assertEqual(flat[(1,)], 'one')
        self.assertEqual(flat[(2, 'a')], 'two-a')

    def test_merge_with_overlapping_keys(self):
        """On key collisions, the later argument to ``merge`` wins."""
        d1 = NestedDict({'a': 1, 'b': 2})
        d2 = NestedDict({'b': 3, 'c': 4})

        merged = NestedDict.merge(d1, d2)

        self.assertEqual(merged['b'], 3)
        self.assertEqual(merged['a'], 1)
        self.assertEqual(merged['c'], 4)