brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163) hide show
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +95 -29
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.9.dist-info/RECORD +0 -130
  161. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,675 @@
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """
17
+ Comprehensive test suite for the pretty_pytree module.
18
+
19
+ This test module provides extensive coverage of the pretty printing and tree
20
+ manipulation functionality, including:
21
+ - PrettyObject and pretty representation
22
+ - Nested and flattened dictionary structures
23
+ - Mapping flattening and unflattening
24
+ - Split, filter, and merge operations
25
+ - JAX pytree integration
26
+ - State management utilities
27
+ """
28
+
29
+ import unittest
30
+
31
+ import jax
32
+ import jax.numpy as jnp
33
+ import numpy as np
34
+ from absl.testing import absltest
35
+
36
+ import brainstate
37
+ from brainstate.util._pretty_pytree import (
38
+ PrettyObject,
39
+ PrettyDict,
40
+ NestedDict,
41
+ FlattedDict,
42
+ PrettyList,
43
+ flat_mapping,
44
+ nest_mapping,
45
+ empty_node,
46
+ _EmptyNode,
47
+ )
48
+
49
+
50
class TestNestedMapping(absltest.TestCase):
    """Checks item and attribute access on ``NestedDict`` trees of ``ParamState`` leaves.

    All checks use ``self.assert*`` methods instead of bare ``assert`` so that
    failures report both operands and the checks survive ``python -O``.
    """

    @staticmethod
    def _make_state():
        """Return a fresh ``NestedDict`` with ``a -> ParamState(1)`` and ``b.c -> ParamState(2)``."""
        return brainstate.util.NestedDict(
            {'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}}
        )

    def test_create_state(self):
        """Leaves are reachable through plain item access."""
        state = self._make_state()

        self.assertEqual(state['a'].value, 1)
        self.assertEqual(state['b']['c'].value, 2)

    def test_get_attr(self):
        """Top-level keys are also reachable as attributes."""
        state = self._make_state()

        self.assertEqual(state.a.value, 1)
        self.assertEqual(state.b['c'].value, 2)

    def test_set_attr(self):
        """Values written via attribute access are visible via item access."""
        state = self._make_state()

        state.a.value = 3
        state.b['c'].value = 4

        self.assertEqual(state['a'].value, 3)
        self.assertEqual(state['b']['c'].value, 4)

    def test_set_attr_variables(self):
        """Updating ``.value`` keeps the leaves as ``ParamState`` instances."""
        state = self._make_state()

        state.a.value = 3
        state.b['c'].value = 4

        self.assertIsInstance(state.a, brainstate.ParamState)
        self.assertEqual(state.a.value, 3)
        self.assertIsInstance(state.b['c'], brainstate.ParamState)
        self.assertEqual(state.b['c'].value, 4)

    def test_add_nested_attr(self):
        """A new leaf can be inserted into a nested sub-mapping."""
        state = self._make_state()
        state.b['d'] = brainstate.ParamState(5)

        self.assertEqual(state['b']['d'].value, 5)

    def test_delete_nested_attr(self):
        """Deleting a nested key removes it from the sub-mapping."""
        state = self._make_state()
        del state['b']['c']

        self.assertNotIn('c', state['b'])

    def test_integer_access(self):
        """Modules stored in a list are addressable by integer index in the state tree."""
        class Foo(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.layers = [brainstate.nn.Linear(1, 2), brainstate.nn.Linear(2, 3)]

        module = Foo()
        state_refs = brainstate.graph.treefy_states(module)

        self.assertEqual(module.layers[0].weight.value['weight'].shape, (1, 2))
        self.assertEqual(state_refs.layers[0]['weight'].value['weight'].shape, (1, 2))
        self.assertEqual(module.layers[1].weight.value['weight'].shape, (2, 3))
        self.assertEqual(state_refs.layers[1]['weight'].value['weight'].shape, (2, 3))

    def test_pure_dict(self):
        """``to_pure_dict`` yields a ``dict`` whose weight entries hold JAX arrays."""
        module = brainstate.nn.Linear(4, 5)
        state_map = brainstate.graph.treefy_states(module)
        pure_dict = state_map.to_pure_dict()

        self.assertIsInstance(pure_dict, dict)
        self.assertIsInstance(pure_dict['weight'].value['weight'], jax.Array)
        self.assertIsInstance(pure_dict['weight'].value['bias'], jax.Array)
116
+
117
+
118
class TestSplit(unittest.TestCase):
    """Checks ``split`` on the state tree of a small BatchNorm + Linear model."""

    def test_split(self):
        """A lone ``ParamState`` filter is expected to raise; adding ``...`` gives two groups."""
        class Model(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
                self.linear = brainstate.nn.Linear([10, 3], [10, 4])

            def __call__(self, x):
                return self.linear(self.batchnorm(x))

        # NOTE(review): the extent of this context block was reconstructed from
        # a flattened listing -- confirm against the repository.
        with brainstate.environ.context(fit=True):
            model = Model()
            x = brainstate.random.randn(1, 10, 3)
            y = model(x)
            self.assertEqual(y.shape, (1, 10, 4))

        state_map = brainstate.graph.treefy_states(model)

        # The test expects ValueError when ParamState is the only filter given.
        with self.assertRaises(ValueError):
            params, others = state_map.split(brainstate.ParamState)

        params, others = state_map.split(brainstate.ParamState, ...)

        # assertEqual (rather than assertTrue on a comparison) reports the
        # actual length when the check fails.
        self.assertEqual(len(params.to_flat()), 2)
        self.assertEqual(len(others.to_flat()), 2)
147
+
148
+
149
class TestStateMap2(unittest.TestCase):
    """Smoke test: rebuild a ``NestedDict`` from a flattened module state tree."""

    def test1(self):
        """Passes as long as flattening and re-wrapping raise nothing (no assertions)."""
        class Model(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
                self.linear = brainstate.nn.Linear([10, 3], [10, 4])

            def __call__(self, x):
                return self.linear(self.batchnorm(x))

        # NOTE(review): the extent of this context block was reconstructed from
        # a flattened listing -- confirm against the repository.
        with brainstate.environ.context(fit=True):
            net = Model()
            flat_states = brainstate.graph.treefy_states(net).to_flat()
            brainstate.util.NestedDict(flat_states)
164
+
165
+
166
class TestFlattedMapping(unittest.TestCase):
    """Checks that module accessors agree with the ``brainstate.graph`` functions."""

    def test1(self):
        """``model.states()`` / ``model.nodes()`` equal ``graph.states`` / ``graph.nodes``."""
        class Model(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
                self.linear = brainstate.nn.Linear([10, 3], [10, 4])

            def __call__(self, x):
                return self.linear(self.batchnorm(x))

        model = Model()

        # assertEqual shows both operands on failure, which replaces the
        # debug print that used to sit here.
        self.assertEqual(model.states(), brainstate.graph.states(model))
        self.assertEqual(model.nodes(), brainstate.graph.nodes(model))
185
+
186
+
187
class TestPrettyObject(unittest.TestCase):
    """Checks on the ``repr`` produced for ``PrettyObject`` subclasses."""

    def test_pretty_object_basic(self):
        """The repr is a string naming the class, an attribute and its value."""
        class MyObject(PrettyObject):
            def __init__(self, value):
                self.value = value
                self.name = "test"

        text = repr(MyObject(42))

        self.assertIsInstance(text, str)
        for fragment in ("MyObject", "value", "42"):
            self.assertIn(fragment, text)

    def test_pretty_repr_item_filtering(self):
        """Returning ``None`` from ``__pretty_repr_item__`` drops that attribute."""
        class FilteredObject(PrettyObject):
            def __init__(self):
                self.visible = "show"
                self.hidden = "hide"

            def __pretty_repr_item__(self, k, v):
                return None if k == "hidden" else (k, v)

        text = repr(FilteredObject())

        self.assertIn("visible", text)
        self.assertNotIn("hidden", text)

    def test_pretty_repr_item_transformation(self):
        """A value rewritten by ``__pretty_repr_item__`` shows up in the repr."""
        class TransformObject(PrettyObject):
            def __init__(self):
                self.value = 100

            def __pretty_repr_item__(self, k, v):
                shown = v * 2 if k == "value" else v
                return k, shown

        text = repr(TransformObject())

        self.assertIn("200", text)
235
+
236
+
237
class TestFlatAndNestMapping(unittest.TestCase):
    """Checks for the ``flat_mapping`` / ``nest_mapping`` pair."""

    def test_flat_mapping_basic(self):
        """Nested keys become tuple paths in a ``FlattedDict``."""
        result = flat_mapping({'a': 1, 'b': {'c': 2, 'd': 3}})

        self.assertIsInstance(result, FlattedDict)
        self.assertEqual(result[('a',)], 1)
        self.assertEqual(result[('b', 'c')], 2)
        self.assertEqual(result[('b', 'd')], 3)

    def test_flat_mapping_with_separator(self):
        """With ``sep`` the paths are joined into plain strings."""
        result = flat_mapping({'a': 1, 'b': {'c': 2}}, sep='/')

        self.assertEqual(result['a'], 1)
        self.assertEqual(result['b/c'], 2)

    def test_flat_mapping_empty_nodes(self):
        """``keep_empty_nodes=True`` records an empty sub-dict as an ``_EmptyNode``."""
        source = {'a': 1, 'b': {}}
        result = flat_mapping(source, keep_empty_nodes=True)

        self.assertEqual(result[('a',)], 1)
        self.assertIsInstance(result[('b',)], _EmptyNode)

    def test_flat_mapping_without_empty_nodes(self):
        """``keep_empty_nodes=False`` leaves an empty sub-dict out of the result."""
        source = {'a': 1, 'b': {}}
        result = flat_mapping(source, keep_empty_nodes=False)

        self.assertIn(('a',), result)
        self.assertNotIn(('b',), result)

    def test_flat_mapping_is_leaf(self):
        """An ``is_leaf`` predicate can stop descent at the first level."""
        source = {'a': 1, 'b': {'c': 2, 'd': 3}}
        stop_at_depth_one = lambda path, value: len(path) > 0  # noqa: E731

        result = flat_mapping(source, is_leaf=stop_at_depth_one)

        self.assertEqual(result[('a',)], 1)
        self.assertEqual(result[('b',)], {'c': 2, 'd': 3})

    def test_nest_mapping_basic(self):
        """Tuple paths are expanded back into a ``NestedDict``."""
        result = nest_mapping({('a',): 1, ('b', 'c'): 2, ('b', 'd'): 3})

        self.assertIsInstance(result, NestedDict)
        self.assertEqual(result['a'], 1)
        self.assertEqual(result['b']['c'], 2)
        self.assertEqual(result['b']['d'], 3)

    def test_nest_mapping_with_separator(self):
        """String keys are split on ``sep`` when nesting."""
        result = nest_mapping({'a': 1, 'b/c': 2, 'b/d': 3}, sep='/')

        self.assertEqual(result['a'], 1)
        self.assertEqual(result['b']['c'], 2)
        self.assertEqual(result['b']['d'], 3)

    def test_nest_mapping_with_empty_node(self):
        """An ``empty_node`` value comes back as an empty mapping."""
        result = nest_mapping({('a',): 1, ('b',): empty_node})

        self.assertEqual(result['a'], 1)
        self.assertEqual(result['b'], {})

    def test_round_trip(self):
        """Flattening and then nesting reproduces the original dict."""
        original = {'a': 1, 'b': {'c': 2, 'd': {'e': 3}}}

        rebuilt = nest_mapping(flat_mapping(original))

        self.assertEqual(rebuilt.to_dict(), original)
319
+
320
+
321
class TestPrettyDict(unittest.TestCase):
    """Checks for ``PrettyDict`` construction, access, repr and conversion."""

    def test_pretty_dict_creation(self):
        """Items passed to the constructor are readable by key."""
        mapping = PrettyDict({'a': 1, 'b': 2})

        self.assertEqual(mapping['a'], 1)
        self.assertEqual(mapping['b'], 2)

    def test_pretty_dict_attribute_access(self):
        """Items are readable as attributes as well."""
        mapping = PrettyDict({'a': 1, 'b': 2})

        self.assertEqual(mapping.a, 1)
        self.assertEqual(mapping.b, 2)

    def test_pretty_dict_repr(self):
        """The repr is a string that mentions a key."""
        text = repr(PrettyDict({'a': 1, 'b': 2}))

        self.assertIsInstance(text, str)
        self.assertIn('a', text)

    def test_to_dict(self):
        """``to_dict`` returns a ``dict`` equal to the original items."""
        plain = PrettyDict({'a': 1, 'b': 2}).to_dict()

        self.assertIsInstance(plain, dict)
        self.assertEqual(plain, {'a': 1, 'b': 2})
349
+
350
+
351
class TestNestedDictOperations(unittest.TestCase):
    """Checks for ``NestedDict`` operators, ``merge`` and ``to_pure_dict``."""

    def test_or_operator(self):
        """``|`` combines two mappings into a ``NestedDict``."""
        left = NestedDict({'a': 1})
        right = NestedDict({'b': 2})

        combined = left | right

        self.assertIsInstance(combined, NestedDict)
        self.assertEqual(combined['a'], 1)
        self.assertEqual(combined['b'], 2)

    def test_sub_operator(self):
        """``-`` keeps only the entries missing from the right operand."""
        minuend = NestedDict({'a': 1, 'b': 2, 'c': 3})
        subtrahend = NestedDict({'b': 2})

        remaining = (minuend - subtrahend).to_flat()

        self.assertIn(('a',), remaining.keys())
        self.assertIn(('c',), remaining.keys())
        # no surviving path may mention 'b'
        self.assertFalse(any('b' in path for path in remaining.keys()))

    def test_merge_static_method(self):
        """``NestedDict.merge`` accepts several mappings at once."""
        parts = (NestedDict({'a': 1}), NestedDict({'b': 2}), NestedDict({'c': 3}))

        combined = NestedDict.merge(*parts)

        self.assertEqual(combined['a'], 1)
        self.assertEqual(combined['b'], 2)
        self.assertEqual(combined['c'], 3)

    def test_to_pure_dict(self):
        """``to_pure_dict`` returns a ``dict`` that is not a ``NestedDict``."""
        plain = NestedDict({'a': 1, 'b': {'c': 2}}).to_pure_dict()

        self.assertIsInstance(plain, dict)
        self.assertNotIsInstance(plain, NestedDict)
        self.assertEqual(plain['a'], 1)
        self.assertEqual(plain['b']['c'], 2)
397
+
398
+
399
class TestFlattedDictOperations(unittest.TestCase):
    """Checks for ``FlattedDict`` operators, ``merge`` and value (de)assignment."""

    def test_or_operator(self):
        """``|`` combines two mappings into a ``FlattedDict``."""
        left = FlattedDict({('a',): 1})
        right = FlattedDict({('b',): 2})

        combined = left | right

        self.assertIsInstance(combined, FlattedDict)
        self.assertEqual(combined[('a',)], 1)
        self.assertEqual(combined[('b',)], 2)

    def test_sub_operator(self):
        """``-`` drops the paths present in the right operand."""
        minuend = FlattedDict({('a',): 1, ('b',): 2, ('c',): 3})
        subtrahend = FlattedDict({('b',): 2})

        remaining = minuend - subtrahend

        self.assertIn(('a',), remaining)
        self.assertIn(('c',), remaining)
        self.assertNotIn(('b',), remaining)

    def test_merge_static_method(self):
        """``FlattedDict.merge`` combines its arguments."""
        left = FlattedDict({('a',): 1})
        right = FlattedDict({('b',): 2})

        combined = FlattedDict.merge(left, right)

        self.assertEqual(combined[('a',)], 1)
        self.assertEqual(combined[('b',)], 2)

    def test_to_dict_values(self):
        """``to_dict_values`` unwraps a state to its array and passes a plain value through."""
        entries = {
            ('a',): brainstate.ParamState(jnp.array([1, 2, 3])),
            ('b',): 42,
        }
        unwrapped = FlattedDict(entries).to_dict_values()

        self.assertIsInstance(unwrapped[('a',)], jnp.ndarray)
        np.testing.assert_array_equal(unwrapped[('a',)], jnp.array([1, 2, 3]))
        self.assertEqual(unwrapped[('b',)], 42)

    def test_assign_dict_values(self):
        """``assign_dict_values`` updates a state's value and replaces a plain entry."""
        entries = {
            ('a',): brainstate.ParamState(jnp.array([1, 2, 3])),
            ('b',): 42,
        }
        target = FlattedDict(entries)
        replacement = {('a',): jnp.array([4, 5, 6]), ('b',): 100}

        target.assign_dict_values(replacement)

        np.testing.assert_array_equal(target[('a',)].value, jnp.array([4, 5, 6]))
        self.assertEqual(target[('b',)], 100)

    def test_assign_dict_values_missing_key(self):
        """Assigning to a path that does not exist raises ``KeyError``."""
        target = FlattedDict({('a',): 1})

        with self.assertRaises(KeyError):
            target.assign_dict_values({('b',): 2})
466
+
467
+
468
class TestPrettyList(unittest.TestCase):
    """Checks for ``PrettyList`` indexing, repr and its flatten/unflatten hooks."""

    def test_pretty_list_creation(self):
        """Elements are readable by index."""
        items = PrettyList([1, 2, 3])

        self.assertEqual(items[0], 1)
        self.assertEqual(items[1], 2)
        self.assertEqual(items[2], 3)

    def test_pretty_list_repr(self):
        """The repr is a string that mentions an element."""
        text = repr(PrettyList([1, 2, {'a': 3}]))

        self.assertIsInstance(text, str)
        self.assertIn('1', text)

    def test_tree_flatten(self):
        """``tree_flatten`` returns the elements plus an empty aux tuple."""
        children, aux_data = PrettyList([1, 2, 3]).tree_flatten()

        self.assertEqual(children, [1, 2, 3])
        self.assertEqual(aux_data, ())

    def test_tree_unflatten(self):
        """``tree_unflatten`` rebuilds a ``PrettyList`` from its children."""
        rebuilt = PrettyList.tree_unflatten((), [1, 2, 3])

        self.assertIsInstance(rebuilt, PrettyList)
        self.assertEqual(list(rebuilt), [1, 2, 3])
498
+
499
+
500
class TestFilterOperations(unittest.TestCase):
    """Checks for ``filter`` on both mapping flavours and for ``...`` placement in ``split``."""

    def test_nested_dict_filter(self):
        """``NestedDict.filter`` keeps matching leaves and drops the rest."""
        nested = NestedDict({
            'a': 1,
            'b': 2,
            'c': 3,
        })

        filtered = nested.filter(lambda path, val: val >= 2)

        values = list(filtered.to_flat().values())
        self.assertIn(2, values)
        self.assertIn(3, values)
        # The rejected leaf must be gone too; checking presence alone would
        # pass even if the filter returned everything.
        self.assertNotIn(1, values)

    def test_flatted_dict_filter(self):
        """``FlattedDict.filter`` keeps even values and drops odd ones."""
        flat = FlattedDict({
            ('a',): 1,
            ('b',): 2,
            ('c',): 3,
        })

        filtered = flat.filter(lambda path, val: val % 2 == 0)

        self.assertIn(('b',), filtered)
        self.assertNotIn(('a',), filtered)
        self.assertNotIn(('c',), filtered)

    def test_ellipsis_filter_position(self):
        """``...`` followed by another filter is expected to raise ``ValueError``."""
        nested = NestedDict({'a': 1, 'b': 2, 'c': 3})

        with self.assertRaises(ValueError):
            # `...` is passed first here, i.e. not as the last filter
            nested.split(..., lambda path, val: val > 1)
538
+
539
+
540
class TestJAXPytreeIntegration(unittest.TestCase):
    """Checks that the pretty containers work with ``jax.tree`` functions."""

    def test_nested_dict_pytree_flatten(self):
        """Flattening a ``NestedDict`` yields its leaf values."""
        leaves, _ = jax.tree.flatten(NestedDict({'a': 1, 'b': 2}))

        self.assertEqual(sorted(leaves), [1, 2])

    def test_nested_dict_pytree_unflatten(self):
        """Flatten followed by unflatten gives back a ``NestedDict``."""
        leaves, structure = jax.tree.flatten(NestedDict({'a': 1, 'b': 2}))

        rebuilt = jax.tree.unflatten(structure, leaves)

        self.assertIsInstance(rebuilt, NestedDict)
        self.assertEqual(rebuilt['a'], 1)
        self.assertEqual(rebuilt['b'], 2)

    def test_flatted_dict_pytree_flatten(self):
        """Flattening a ``FlattedDict`` yields its leaf values."""
        leaves, _ = jax.tree.flatten(FlattedDict({('a',): 1, ('b',): 2}))

        self.assertEqual(sorted(leaves), [1, 2])

    def test_flatted_dict_pytree_unflatten(self):
        """Flatten followed by unflatten gives back a ``FlattedDict``."""
        leaves, structure = jax.tree.flatten(FlattedDict({('a',): 1, ('b',): 2}))

        rebuilt = jax.tree.unflatten(structure, leaves)

        self.assertIsInstance(rebuilt, FlattedDict)
        self.assertEqual(rebuilt[('a',)], 1)

    def test_pretty_list_pytree(self):
        """Flatten followed by unflatten gives back a ``PrettyList``."""
        leaves, structure = jax.tree.flatten(PrettyList([1, 2, 3]))

        rebuilt = jax.tree.unflatten(structure, leaves)

        self.assertIsInstance(rebuilt, PrettyList)
        self.assertEqual(list(rebuilt), [1, 2, 3])

    def test_jax_tree_map_nested_dict(self):
        """``jax.tree.map`` reaches leaves at every nesting level."""
        source = NestedDict({'a': 1, 'b': {'c': 2}})

        result = jax.tree.map(lambda x: x * 2, source)

        self.assertEqual(result['a'], 2)
        self.assertEqual(result['b']['c'], 4)

    def test_jax_tree_map_flatted_dict(self):
        """``jax.tree.map`` keeps the tuple paths of a ``FlattedDict``."""
        source = FlattedDict({('a',): 1, ('b', 'c'): 2})

        result = jax.tree.map(lambda x: x * 2, source)

        self.assertEqual(result[('a',)], 2)
        self.assertEqual(result[('b', 'c')], 4)

    def test_jax_tree_map_pretty_list(self):
        """``jax.tree.map`` maps over the elements of a ``PrettyList``."""
        result = jax.tree.map(lambda x: x * 2, PrettyList([1, 2, 3]))

        self.assertEqual(list(result), [2, 4, 6])
607
+
608
+
609
class TestEdgeCases(unittest.TestCase):
    """Boundary inputs: empty mappings, deep nesting, mixed values, odd keys, overlaps."""

    def test_empty_nested_dict(self):
        """An empty ``NestedDict`` flattens to an empty mapping."""
        flattened = NestedDict({}).to_flat()

        self.assertEqual(len(flattened), 0)

    def test_empty_flatted_dict(self):
        """An empty ``FlattedDict`` nests to an empty mapping."""
        rebuilt = FlattedDict({}).to_nest()

        self.assertEqual(len(rebuilt), 0)

    def test_deeply_nested_structure(self):
        """Five levels of nesting flatten to a single five-element path."""
        deep = NestedDict({'a': {'b': {'c': {'d': {'e': 42}}}}})

        flattened = deep.to_flat()

        self.assertEqual(flattened[('a', 'b', 'c', 'd', 'e')], 42)

    def test_mixed_types_in_nested(self):
        """Scalar leaves of different types survive flattening unchanged."""
        payload = {
            'int': 1,
            'float': 2.5,
            'str': 'hello',
            'list': [1, 2, 3],
            'dict': {'nested': True},
        }

        flattened = NestedDict(payload).to_flat()

        self.assertEqual(flattened[('int',)], 1)
        self.assertEqual(flattened[('float',)], 2.5)
        self.assertEqual(flattened[('str',)], 'hello')

    def test_numeric_keys(self):
        """Integer keys are kept as integers inside the flattened paths."""
        payload = {1: 'one', 2: {'a': 'two-a'}}

        flattened = NestedDict(payload).to_flat()

        self.assertEqual(flattened[(1,)], 'one')
        self.assertEqual(flattened[(2, 'a')], 'two-a')

    def test_merge_with_overlapping_keys(self):
        """On a key clash ``merge`` takes the value from the later argument."""
        earlier = NestedDict({'a': 1, 'b': 2})
        later = NestedDict({'b': 3, 'c': 4})

        combined = NestedDict.merge(earlier, later)

        # the later mapping wins for the shared key
        self.assertEqual(combined['b'], 3)
        self.assertEqual(combined['a'], 1)
        self.assertEqual(combined['c'], 4)