brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +15 -28
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,675 @@
|
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""
|
17
|
+
Comprehensive test suite for the pretty_pytree module.
|
18
|
+
|
19
|
+
This test module provides extensive coverage of the pretty printing and tree
|
20
|
+
manipulation functionality, including:
|
21
|
+
- PrettyObject and pretty representation
|
22
|
+
- Nested and flattened dictionary structures
|
23
|
+
- Mapping flattening and unflattening
|
24
|
+
- Split, filter, and merge operations
|
25
|
+
- JAX pytree integration
|
26
|
+
- State management utilities
|
27
|
+
"""
|
28
|
+
|
29
|
+
import unittest
|
30
|
+
|
31
|
+
import jax
|
32
|
+
import jax.numpy as jnp
|
33
|
+
import numpy as np
|
34
|
+
from absl.testing import absltest
|
35
|
+
|
36
|
+
import brainstate
|
37
|
+
from brainstate.util._pretty_pytree import (
|
38
|
+
PrettyObject,
|
39
|
+
PrettyDict,
|
40
|
+
NestedDict,
|
41
|
+
FlattedDict,
|
42
|
+
PrettyList,
|
43
|
+
flat_mapping,
|
44
|
+
nest_mapping,
|
45
|
+
empty_node,
|
46
|
+
_EmptyNode,
|
47
|
+
)
|
48
|
+
|
49
|
+
|
50
|
+
class TestNestedMapping(absltest.TestCase):
    """Attribute/item access and mutation on ``brainstate.util.NestedDict``.

    The checks use ``self.assert*`` helpers rather than bare ``assert`` so they
    still run under ``python -O`` and report the compared values on failure.
    """

    @staticmethod
    def _make_state():
        # Fresh two-level mapping with ParamState leaves, rebuilt for every test
        # so that mutations made in one test cannot leak into another.
        return brainstate.util.NestedDict(
            {'a': brainstate.ParamState(1), 'b': {'c': brainstate.ParamState(2)}}
        )

    def test_create_state(self):
        state = self._make_state()

        self.assertEqual(state['a'].value, 1)
        self.assertEqual(state['b']['c'].value, 2)

    def test_get_attr(self):
        state = self._make_state()

        # Top-level keys are also reachable as attributes.
        self.assertEqual(state.a.value, 1)
        self.assertEqual(state.b['c'].value, 2)

    def test_set_attr(self):
        state = self._make_state()

        state.a.value = 3
        state.b['c'].value = 4

        # Writes made through attribute access are visible through item access.
        self.assertEqual(state['a'].value, 3)
        self.assertEqual(state['b']['c'].value, 4)

    def test_set_attr_variables(self):
        state = self._make_state()

        state.a.value = 3
        state.b['c'].value = 4

        # Assigning ``.value`` keeps the leaves as ``ParamState`` objects.
        self.assertIsInstance(state.a, brainstate.ParamState)
        self.assertEqual(state.a.value, 3)
        self.assertIsInstance(state.b['c'], brainstate.ParamState)
        self.assertEqual(state.b['c'].value, 4)

    def test_add_nested_attr(self):
        state = self._make_state()
        state.b['d'] = brainstate.ParamState(5)

        self.assertEqual(state['b']['d'].value, 5)

    def test_delete_nested_attr(self):
        state = self._make_state()
        del state['b']['c']

        self.assertNotIn('c', state['b'])

    def test_integer_access(self):
        class Foo(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.layers = [brainstate.nn.Linear(1, 2), brainstate.nn.Linear(2, 3)]

        module = Foo()
        state_refs = brainstate.graph.treefy_states(module)

        # List elements of the module are addressed by integer keys in the state tree.
        self.assertEqual(module.layers[0].weight.value['weight'].shape, (1, 2))
        self.assertEqual(state_refs.layers[0]['weight'].value['weight'].shape, (1, 2))
        self.assertEqual(module.layers[1].weight.value['weight'].shape, (2, 3))
        self.assertEqual(state_refs.layers[1]['weight'].value['weight'].shape, (2, 3))

    def test_pure_dict(self):
        module = brainstate.nn.Linear(4, 5)
        state_map = brainstate.graph.treefy_states(module)
        pure_dict = state_map.to_pure_dict()

        self.assertIsInstance(pure_dict, dict)
        self.assertIsInstance(pure_dict['weight'].value['weight'], jax.Array)
        self.assertIsInstance(pure_dict['weight'].value['bias'], jax.Array)
|
116
|
+
|
117
|
+
|
118
|
+
class TestSplit(unittest.TestCase):
    """Tests for ``split`` on the state tree of a small model."""

    def test_split(self):
        class Model(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
                self.linear = brainstate.nn.Linear([10, 3], [10, 4])

            def __call__(self, x):
                return self.linear(self.batchnorm(x))

        with brainstate.environ.context(fit=True):
            model = Model()
            x = brainstate.random.randn(1, 10, 3)
            y = model(x)
            self.assertEqual(y.shape, (1, 10, 4))

        state_map = brainstate.graph.treefy_states(model)

        # Splitting on ``ParamState`` alone must raise. This is presumably
        # because the non-parameter states are left unmatched without a
        # trailing ``...`` filter. The result is deliberately not unpacked:
        # a tuple-unpacking mismatch would also raise ValueError and make
        # this assertion pass for the wrong reason.
        with self.assertRaises(ValueError):
            state_map.split(brainstate.ParamState)

        params, others = state_map.split(brainstate.ParamState, ...)

        self.assertEqual(len(params.to_flat()), 2)
        self.assertEqual(len(others.to_flat()), 2)
|
147
|
+
|
148
|
+
|
149
|
+
class TestStateMap2(unittest.TestCase):
    """Re-wrapping a flattened state mapping in ``NestedDict``."""

    def test1(self):
        class Model(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
                self.linear = brainstate.nn.Linear([10, 3], [10, 4])

            def __call__(self, x):
                return self.linear(self.batchnorm(x))

        with brainstate.environ.context(fit=True):
            model = Model()
            flat = brainstate.graph.treefy_states(model).to_flat()
            state_map = brainstate.util.NestedDict(flat)

        # Beyond "does not raise", pin the visible results: the flattened
        # mapping is non-empty and keyed by tuple paths, and wrapping it
        # yields a NestedDict.
        self.assertGreater(len(flat), 0)
        self.assertTrue(all(isinstance(path, tuple) for path in flat))
        self.assertIsInstance(state_map, brainstate.util.NestedDict)
|
164
|
+
|
165
|
+
|
166
|
+
class TestFlattedMapping(unittest.TestCase):
    """``Module.states``/``Module.nodes`` must agree with the ``brainstate.graph`` helpers."""

    def test1(self):
        class Model(brainstate.nn.Module):
            def __init__(self):
                super().__init__()
                self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
                self.linear = brainstate.nn.Linear([10, 3], [10, 4])

            def __call__(self, x):
                return self.linear(self.batchnorm(x))

        model = Model()

        # assertEqual compares with ``==`` like the bare check did, but it also
        # reports both sides when they differ.
        self.assertEqual(model.states(), brainstate.graph.states(model))
        self.assertEqual(model.nodes(), brainstate.graph.nodes(model))
|
185
|
+
|
186
|
+
|
187
|
+
class TestPrettyObject(unittest.TestCase):
    """Exercise the ``repr`` produced by ``PrettyObject`` subclasses."""

    def test_pretty_object_basic(self):
        """A plain subclass shows its class name and attribute values."""

        class MyObject(PrettyObject):
            def __init__(self, value):
                self.value = value
                self.name = "test"

        text = repr(MyObject(42))
        self.assertIsInstance(text, str)
        for fragment in ("MyObject", "value", "42"):
            self.assertIn(fragment, text)

    def test_pretty_repr_item_filtering(self):
        """Returning ``None`` from ``__pretty_repr_item__`` hides an attribute."""

        class FilteredObject(PrettyObject):
            def __init__(self):
                self.visible = "show"
                self.hidden = "hide"

            def __pretty_repr_item__(self, k, v):
                return None if k == "hidden" else (k, v)

        text = repr(FilteredObject())
        self.assertIn("visible", text)
        self.assertNotIn("hidden", text)

    def test_pretty_repr_item_transformation(self):
        """``__pretty_repr_item__`` may rewrite the displayed value."""

        class TransformObject(PrettyObject):
            def __init__(self):
                self.value = 100

            def __pretty_repr_item__(self, k, v):
                return (k, v * 2) if k == "value" else (k, v)

        self.assertIn("200", repr(TransformObject()))
|
235
|
+
|
236
|
+
|
237
|
+
class TestFlatAndNestMapping(unittest.TestCase):
    """Conversions between nested mappings and path-keyed flat mappings."""

    def test_flat_mapping_basic(self):
        """Nested keys become tuple paths."""
        flat = flat_mapping({'a': 1, 'b': {'c': 2, 'd': 3}})

        self.assertIsInstance(flat, FlattedDict)
        for path, leaf in {('a',): 1, ('b', 'c'): 2, ('b', 'd'): 3}.items():
            self.assertEqual(flat[path], leaf)

    def test_flat_mapping_with_separator(self):
        """With ``sep`` the paths are joined into strings."""
        flat = flat_mapping({'a': 1, 'b': {'c': 2}}, sep='/')

        for joined, leaf in (('a', 1), ('b/c', 2)):
            self.assertEqual(flat[joined], leaf)

    def test_flat_mapping_empty_nodes(self):
        """``keep_empty_nodes=True`` records an empty sub-mapping as ``_EmptyNode``."""
        flat = flat_mapping({'a': 1, 'b': {}}, keep_empty_nodes=True)

        self.assertEqual(flat[('a',)], 1)
        self.assertIsInstance(flat[('b',)], _EmptyNode)

    def test_flat_mapping_without_empty_nodes(self):
        """``keep_empty_nodes=False`` drops an empty sub-mapping entirely."""
        flat = flat_mapping({'a': 1, 'b': {}}, keep_empty_nodes=False)

        self.assertIn(('a',), flat)
        self.assertNotIn(('b',), flat)

    def test_flat_mapping_is_leaf(self):
        """A custom ``is_leaf`` can stop recursion below the top level."""

        def stop_at_first_level(path, value):
            return len(path) > 0

        flat = flat_mapping({'a': 1, 'b': {'c': 2, 'd': 3}}, is_leaf=stop_at_first_level)

        self.assertEqual(flat[('a',)], 1)
        self.assertEqual(flat[('b',)], {'c': 2, 'd': 3})

    def test_nest_mapping_basic(self):
        """Tuple paths are expanded back into nested levels."""
        nested = nest_mapping({('a',): 1, ('b', 'c'): 2, ('b', 'd'): 3})

        self.assertIsInstance(nested, NestedDict)
        self.assertEqual(nested['a'], 1)
        inner = nested['b']
        self.assertEqual(inner['c'], 2)
        self.assertEqual(inner['d'], 3)

    def test_nest_mapping_with_separator(self):
        """With ``sep`` string keys are split into nested levels."""
        nested = nest_mapping({'a': 1, 'b/c': 2, 'b/d': 3}, sep='/')

        self.assertEqual(nested['a'], 1)
        inner = nested['b']
        self.assertEqual(inner['c'], 2)
        self.assertEqual(inner['d'], 3)

    def test_nest_mapping_with_empty_node(self):
        """An ``empty_node`` leaf is restored as an empty mapping."""
        nested = nest_mapping({('a',): 1, ('b',): empty_node})

        self.assertEqual(nested['a'], 1)
        self.assertEqual(nested['b'], {})

    def test_round_trip(self):
        """Flattening then nesting reproduces the original structure."""
        original = {'a': 1, 'b': {'c': 2, 'd': {'e': 3}}}
        restored = nest_mapping(flat_mapping(original))

        self.assertEqual(restored.to_dict(), original)
|
319
|
+
|
320
|
+
|
321
|
+
class TestPrettyDict(unittest.TestCase):
    """Behaviour of the ``PrettyDict`` mapping wrapper."""

    def test_pretty_dict_creation(self):
        """Items are readable by key after construction."""
        mapping = PrettyDict({'a': 1, 'b': 2})
        for key, expected in (('a', 1), ('b', 2)):
            self.assertEqual(mapping[key], expected)

    def test_pretty_dict_attribute_access(self):
        """Keys are also exposed as attributes."""
        mapping = PrettyDict({'a': 1, 'b': 2})
        self.assertEqual(mapping.a, 1)
        self.assertEqual(mapping.b, 2)

    def test_pretty_dict_repr(self):
        """``repr`` yields a string containing the key ``'a'``."""
        text = repr(PrettyDict({'a': 1, 'b': 2}))
        self.assertIsInstance(text, str)
        self.assertIn('a', text)

    def test_to_dict(self):
        """``to_dict`` returns a plain ``dict`` with the same items."""
        plain = PrettyDict({'a': 1, 'b': 2}).to_dict()
        self.assertIsInstance(plain, dict)
        self.assertEqual(plain, {'a': 1, 'b': 2})
|
349
|
+
|
350
|
+
|
351
|
+
class TestNestedDictOperations(unittest.TestCase):
    """Merge, difference and conversion helpers on ``NestedDict``."""

    def test_or_operator(self):
        """``|`` combines the keys of both operands into a new ``NestedDict``."""
        merged = NestedDict({'a': 1}) | NestedDict({'b': 2})

        self.assertIsInstance(merged, NestedDict)
        self.assertEqual(merged['a'], 1)
        self.assertEqual(merged['b'], 2)

    def test_sub_operator(self):
        """``-`` removes the paths present in the right operand."""
        left = NestedDict({'a': 1, 'b': 2, 'c': 3})
        right = NestedDict({'b': 2})
        remaining = (left - right).to_flat()

        self.assertIn(('a',), remaining.keys())
        self.assertIn(('c',), remaining.keys())
        # No surviving path may mention the removed key 'b'.
        self.assertFalse(any('b' in path for path in remaining.keys()))

    def test_merge_static_method(self):
        """``NestedDict.merge`` accepts several mappings at once."""
        merged = NestedDict.merge(
            NestedDict({'a': 1}), NestedDict({'b': 2}), NestedDict({'c': 3})
        )

        for key, expected in (('a', 1), ('b', 2), ('c', 3)):
            self.assertEqual(merged[key], expected)

    def test_to_pure_dict(self):
        """``to_pure_dict`` returns a plain ``dict`` rather than a ``NestedDict``."""
        pure = NestedDict({'a': 1, 'b': {'c': 2}}).to_pure_dict()

        self.assertIsInstance(pure, dict)
        self.assertNotIsInstance(pure, NestedDict)
        self.assertEqual(pure['a'], 1)
        self.assertEqual(pure['b']['c'], 2)
|
397
|
+
|
398
|
+
|
399
|
+
class TestFlattedDictOperations(unittest.TestCase):
    """Merge, difference and value-assignment helpers on ``FlattedDict``."""

    def test_or_operator(self):
        """``|`` combines both operands into a new ``FlattedDict``."""
        merged = FlattedDict({('a',): 1}) | FlattedDict({('b',): 2})

        self.assertIsInstance(merged, FlattedDict)
        self.assertEqual(merged[('a',)], 1)
        self.assertEqual(merged[('b',)], 2)

    def test_sub_operator(self):
        """``-`` removes the paths present in the right operand."""
        remaining = FlattedDict({('a',): 1, ('b',): 2, ('c',): 3}) - FlattedDict({('b',): 2})

        self.assertIn(('a',), remaining)
        self.assertIn(('c',), remaining)
        self.assertNotIn(('b',), remaining)

    def test_merge_static_method(self):
        """``FlattedDict.merge`` combines several flat mappings."""
        merged = FlattedDict.merge(FlattedDict({('a',): 1}), FlattedDict({('b',): 2}))

        self.assertEqual(merged[('a',)], 1)
        self.assertEqual(merged[('b',)], 2)

    def test_to_dict_values(self):
        """``State`` leaves are unwrapped to their values; other leaves pass through."""
        flat = FlattedDict({
            ('a',): brainstate.ParamState(jnp.array([1, 2, 3])),
            ('b',): 42,
        })
        values = flat.to_dict_values()

        unwrapped = values[('a',)]
        self.assertIsInstance(unwrapped, jnp.ndarray)
        np.testing.assert_array_equal(unwrapped, jnp.array([1, 2, 3]))
        self.assertEqual(values[('b',)], 42)

    def test_assign_dict_values(self):
        """``assign_dict_values`` updates a ``State`` leaf's value and replaces plain leaves."""
        flat = FlattedDict({
            ('a',): brainstate.ParamState(jnp.array([1, 2, 3])),
            ('b',): 42,
        })

        flat.assign_dict_values({('a',): jnp.array([4, 5, 6]), ('b',): 100})

        np.testing.assert_array_equal(flat[('a',)].value, jnp.array([4, 5, 6]))
        self.assertEqual(flat[('b',)], 100)

    def test_assign_dict_values_missing_key(self):
        """Assigning to a path that is not present raises ``KeyError``."""
        flat = FlattedDict({('a',): 1})

        with self.assertRaises(KeyError):
            flat.assign_dict_values({('b',): 2})
|
466
|
+
|
467
|
+
|
468
|
+
class TestPrettyList(unittest.TestCase):
    """``PrettyList`` indexing, ``repr`` and its pytree hooks."""

    def test_pretty_list_creation(self):
        """Elements keep their positions."""
        items = PrettyList([1, 2, 3])
        for index, expected in enumerate([1, 2, 3]):
            self.assertEqual(items[index], expected)

    def test_pretty_list_repr(self):
        """``repr`` yields a string that includes the element ``1``."""
        text = repr(PrettyList([1, 2, {'a': 3}]))
        self.assertIsInstance(text, str)
        self.assertIn('1', text)

    def test_tree_flatten(self):
        """``tree_flatten`` returns the elements and empty aux data."""
        children, aux_data = PrettyList([1, 2, 3]).tree_flatten()
        self.assertEqual(children, [1, 2, 3])
        self.assertEqual(aux_data, ())

    def test_tree_unflatten(self):
        """``tree_unflatten`` rebuilds a ``PrettyList`` from children."""
        rebuilt = PrettyList.tree_unflatten((), [1, 2, 3])
        self.assertIsInstance(rebuilt, PrettyList)
        self.assertEqual(list(rebuilt), [1, 2, 3])
|
498
|
+
|
499
|
+
|
500
|
+
class TestFilterOperations(unittest.TestCase):
    """Filtering and splitting with callable and ``...`` filters."""

    def test_nested_dict_filter(self):
        """Only leaves accepted by the predicate survive ``NestedDict.filter``."""
        nested = NestedDict({
            'a': 1,
            'b': 2,
            'c': 3,
        })

        filtered = nested.filter(lambda path, val: val >= 2)

        values = list(filtered.to_flat().values())
        self.assertIn(2, values)
        self.assertIn(3, values)
        # The rejected leaf must be gone; otherwise a no-op filter would pass.
        self.assertNotIn(1, values)

    def test_flatted_dict_filter(self):
        """Only paths accepted by the predicate survive ``FlattedDict.filter``."""
        flat = FlattedDict({
            ('a',): 1,
            ('b',): 2,
            ('c',): 3,
        })

        filtered = flat.filter(lambda path, val: val % 2 == 0)
        self.assertIn(('b',), filtered)
        # Both odd-valued entries must be excluded.
        self.assertNotIn(('a',), filtered)
        self.assertNotIn(('c',), filtered)

    def test_ellipsis_filter_position(self):
        """Test that ... can only be used as last filter."""
        nested = NestedDict({'a': 1, 'b': 2, 'c': 3})

        with self.assertRaises(ValueError):
            # ``...`` placed before another filter (here: first) must raise.
            nested.split(..., lambda path, val: val > 1)
|
538
|
+
|
539
|
+
|
540
|
+
class TestJAXPytreeIntegration(unittest.TestCase):
    """``NestedDict``, ``FlattedDict`` and ``PrettyList`` as JAX pytrees."""

    def test_nested_dict_pytree_flatten(self):
        """Test NestedDict can be flattened as pytree."""
        leaves, _ = jax.tree.flatten(NestedDict({'a': 1, 'b': 2}))

        self.assertEqual(sorted(leaves), [1, 2])

    def test_nested_dict_pytree_unflatten(self):
        """Test NestedDict can be unflattened as pytree."""
        nested = NestedDict({'a': 1, 'b': 2})
        leaves, treedef = jax.tree.flatten(nested)
        restored = jax.tree.unflatten(treedef, leaves)

        self.assertIsInstance(restored, NestedDict)
        self.assertEqual(restored['a'], 1)
        self.assertEqual(restored['b'], 2)

    def test_flatted_dict_pytree_flatten(self):
        """Test FlattedDict can be flattened as pytree."""
        leaves, _ = jax.tree.flatten(FlattedDict({('a',): 1, ('b',): 2}))

        self.assertEqual(sorted(leaves), [1, 2])

    def test_flatted_dict_pytree_unflatten(self):
        """Test FlattedDict can be unflattened as pytree."""
        flat = FlattedDict({('a',): 1, ('b',): 2})
        leaves, treedef = jax.tree.flatten(flat)
        restored = jax.tree.unflatten(treedef, leaves)

        self.assertIsInstance(restored, FlattedDict)
        # Every entry must survive the round trip, not just the first one.
        self.assertEqual(restored[('a',)], 1)
        self.assertEqual(restored[('b',)], 2)

    def test_pretty_list_pytree(self):
        """Test PrettyList pytree operations."""
        lst = PrettyList([1, 2, 3])
        leaves, treedef = jax.tree.flatten(lst)
        restored = jax.tree.unflatten(treedef, leaves)

        self.assertIsInstance(restored, PrettyList)
        self.assertEqual(list(restored), [1, 2, 3])

    def test_jax_tree_map_nested_dict(self):
        """Test jax.tree.map with NestedDict."""
        nested = NestedDict({'a': 1, 'b': {'c': 2}})
        doubled = jax.tree.map(lambda x: x * 2, nested)

        self.assertEqual(doubled['a'], 2)
        self.assertEqual(doubled['b']['c'], 4)

    def test_jax_tree_map_flatted_dict(self):
        """Test jax.tree.map with FlattedDict."""
        flat = FlattedDict({('a',): 1, ('b', 'c'): 2})
        doubled = jax.tree.map(lambda x: x * 2, flat)

        self.assertEqual(doubled[('a',)], 2)
        self.assertEqual(doubled[('b', 'c')], 4)

    def test_jax_tree_map_pretty_list(self):
        """Test jax.tree.map with PrettyList."""
        lst = PrettyList([1, 2, 3])
        doubled = jax.tree.map(lambda x: x * 2, lst)

        self.assertEqual(list(doubled), [2, 4, 6])
|
607
|
+
|
608
|
+
|
609
|
+
class TestEdgeCases(unittest.TestCase):
    """Test edge cases and error handling."""

    def test_empty_nested_dict(self):
        """Test empty NestedDict."""
        nested = NestedDict({})
        flat = nested.to_flat()
        self.assertEqual(len(flat), 0)

    def test_empty_flatted_dict(self):
        """Test empty FlattedDict."""
        flat = FlattedDict({})
        nested = flat.to_nest()
        self.assertEqual(len(nested), 0)

    def test_deeply_nested_structure(self):
        """Test deeply nested structure."""
        nested = NestedDict({
            'a': {
                'b': {
                    'c': {
                        'd': {
                            'e': 42
                        }
                    }
                }
            }
        })
        flat = nested.to_flat()
        self.assertEqual(flat[('a', 'b', 'c', 'd', 'e')], 42)

    def test_mixed_types_in_nested(self):
        """Test nested dict with mixed types."""
        nested = NestedDict({
            'int': 1,
            'float': 2.5,
            'str': 'hello',
            'list': [1, 2, 3],
            'dict': {'nested': True},
        })
        flat = nested.to_flat()

        self.assertEqual(flat[('int',)], 1)
        self.assertEqual(flat[('float',)], 2.5)
        self.assertEqual(flat[('str',)], 'hello')
        # The nested mapping is flattened into a two-element path.
        self.assertIs(flat[('dict', 'nested')], True)

    def test_numeric_keys(self):
        """Test handling of numeric keys."""
        nested = NestedDict({
            1: 'one',
            2: {'a': 'two-a'},
        })
        flat = nested.to_flat()

        self.assertEqual(flat[(1,)], 'one')
        self.assertEqual(flat[(2, 'a')], 'two-a')

    def test_merge_with_overlapping_keys(self):
        """Test merging with overlapping keys."""
        d1 = NestedDict({'a': 1, 'b': 2})
        d2 = NestedDict({'b': 3, 'c': 4})
        merged = NestedDict.merge(d1, d2)

        # Later values should override
        self.assertEqual(merged['b'], 3)
        self.assertEqual(merged['a'], 1)
        self.assertEqual(merged['c'], 4)
|