brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +15 -28
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1301 @@
|
|
1
|
+
# The file is adapted from the Flax library (https://github.com/google/flax).
|
2
|
+
# The credit should go to the Flax authors.
|
3
|
+
#
|
4
|
+
# Copyright 2024 The Flax Authors.
|
5
|
+
#
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7
|
+
# you may not use this file except in compliance with the License.
|
8
|
+
# You may obtain a copy of the License at
|
9
|
+
#
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11
|
+
#
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15
|
+
# See the License for the specific language governing permissions and
|
16
|
+
# limitations under the License.
|
17
|
+
|
18
|
+
"""Pretty printing utilities for PyTree-like structures.
|
19
|
+
|
20
|
+
This module provides classes and functions for pretty printing and manipulating
|
21
|
+
tree-like data structures in BrainState. The main components are:
|
22
|
+
|
23
|
+
Classes:
|
24
|
+
- :class:`PrettyObject`: Base class for objects with pretty representation.
|
25
|
+
- :class:`PrettyDict`: Dictionary with pretty printing and tree utilities.
|
26
|
+
- :class:`NestedDict`: Nested mapping structure for hierarchical data.
|
27
|
+
- :class:`FlattedDict`: Flattened mapping with tuple keys for paths.
|
28
|
+
- :class:`PrettyList`: List with pretty printing capabilities.
|
29
|
+
|
30
|
+
Functions:
|
31
|
+
- :func:`flat_mapping`: Flatten a nested mapping to tuple keys.
|
32
|
+
- :func:`nest_mapping`: Unflatten a mapping back to nested structure.
|
33
|
+
|
34
|
+
All dictionary classes are registered as JAX pytrees and can be used with JAX
|
35
|
+
transformations. They support splitting, filtering, and merging operations for
|
36
|
+
organizing state in neural network models.
|
37
|
+
|
38
|
+
Example:
|
39
|
+
>>> from brainstate.util import NestedDict, flat_mapping
|
40
|
+
>>> state = NestedDict({'layer1': {'weight': 1, 'bias': 2}})
|
41
|
+
>>> flat = state.to_flat()
|
42
|
+
>>> print(flat)
|
43
|
+
FlattedDict({('layer1', 'weight'): 1, ('layer1', 'bias'): 2})
|
44
|
+
"""
|
45
|
+
|
46
|
+
from collections import abc
|
47
|
+
from typing import TypeVar, Hashable, Union, Iterable, Any, Optional, Tuple, Dict, Callable, Generator
|
48
|
+
|
49
|
+
import jax
|
50
|
+
|
51
|
+
from brainstate.typing import Filter, PathParts
|
52
|
+
from ._pretty_repr import PrettyRepr, PrettyType, PrettyAttr, yield_unique_pretty_repr_items, pretty_repr_object
|
53
|
+
from .filter import to_predicate
|
54
|
+
from .struct import dataclass
|
55
|
+
|
56
|
+
# Public API of this module.
__all__ = [
    # Mapping containers with pretty printing.
    'PrettyDict',
    'NestedDict',
    'FlattedDict',
    # Conversions between nested and flattened layouts.
    'flat_mapping',
    'nest_mapping',
    # Other pretty-printable containers / bases.
    'PrettyList',
    'PrettyObject',
]
|
65
|
+
|
66
|
+
# Generic type variables used in annotations throughout this module.
A = TypeVar('A')
K = TypeVar('K', bound=Hashable)  # key type; bound to Hashable
V = TypeVar('V')

# Mapping from a key path (``PathParts``) to a value.
FlattedStateMapping = Dict[PathParts, V]
# Callable with signature ``(Any) -> Any``.
ExtractValueFn = Callable[[Any], Any]
# Callable with signature ``(V, Any) -> V``.
SetValueFn = Callable[[V, Any], V]
|
73
|
+
|
74
|
+
|
75
|
+
class PrettyObject(PrettyRepr):
    """Base class giving an arbitrary object a generic pretty representation.

    :class:`PrettyRepr` objects describe themselves through the
    ``__pretty_repr__`` hook. This class implements that hook by delegating to
    ``yield_unique_pretty_repr_items`` together with the module's
    general-purpose object and attribute renderers
    (``_repr_object_general`` / ``_repr_attribute_general``).

    Example:
        >>> class MyTree(PrettyObject):
        ...     def __init__(self, value):
        ...         self.value = value
        >>> tree = MyTree(42)
        >>> print(tree)  # rendered through __pretty_repr__
    """

    def __pretty_repr__(self) -> Generator[Union[PrettyType, PrettyAttr], None, None]:
        """Yield the pretty-representation items describing this object.

        Yields:
            Union[PrettyType, PrettyAttr]: The items produced by
            ``yield_unique_pretty_repr_items`` for ``self``.
        """
        items = yield_unique_pretty_repr_items(
            self,
            repr_attr=_repr_attribute_general,
            repr_object=_repr_object_general,
        )
        yield from items

    def __pretty_repr_item__(self, k: Any, v: Any) -> Optional[Tuple[Any, Any]]:
        """Hook for customizing how a single ``(key, value)`` entry is shown.

        The base implementation returns the pair unchanged. Subclasses may
        override it to rewrite an entry, or return ``None`` to hide it.

        Args:
            k: The entry's key.
            v: The entry's value.

        Returns:
            Optional[Tuple[Any, Any]]: The ``(key, value)`` pair to display, or
            ``None`` to skip the entry.
        """
        pair = (k, v)
        return pair
|
127
|
+
|
128
|
+
|
129
|
+
# Alternative name bound to the very same class object as ``PrettyObject``.
PrettyReprTree = PrettyObject
|
130
|
+
|
131
|
+
|
132
|
+
# Declared with ``struct.dataclass`` so that the sentinel is compatible with JAX.
@dataclass
class _EmptyNode:
    """Sentinel type standing in for an empty mapping inside a flattened tree."""
|
137
|
+
|
138
|
+
|
139
|
+
IsLeafCallable = Callable[[Tuple[Any, ...], abc.Mapping[Any, Any]], bool]
|
140
|
+
_default_leaf: IsLeafCallable = lambda *args: False
|
141
|
+
# Shared sentinel instance: ``flat_mapping(..., keep_empty_nodes=True)`` stores it in
# place of an empty sub-mapping, and ``nest_mapping`` converts it back into ``{}``.
empty_node = _EmptyNode()
|
142
|
+
|
143
|
+
|
144
|
+
def flat_mapping(
    xs: abc.Mapping[Any, Any],
    /,
    *,
    keep_empty_nodes: bool = False,
    is_leaf: Optional[IsLeafCallable] = _default_leaf,
    sep: Optional[str] = None
) -> 'FlattedDict':
    """Flatten a nested mapping into a flat mapping with tuple or string keys.

    The nested keys are flattened to a tuple path. For example, ``{'a': {'b': 1}}``
    becomes ``{('a', 'b'): 1}``. See :func:`nest_mapping` on how to restore the
    nested structure.

    Args:
        xs: A nested mapping to flatten.
        keep_empty_nodes: If True, replaces empty mappings with the ``empty_node``
            sentinel. Otherwise, empty mappings are omitted from the result. An
            empty top-level ``xs`` always flattens to an empty mapping.
        is_leaf: Optional function that takes ``(path, mapping)`` and returns True if
            the mapping should be treated as a leaf (i.e., not flattened further).
            ``None`` behaves like the default, which treats no mapping as a leaf.
        sep: If specified, keys in the returned mapping will be ``sep``-joined strings
            instead of tuples. For example, with ``sep='/'``, ``('a', 'b')`` becomes
            ``'a/b'``. All path components must then be strings, because they are
            joined with ``str.join`` (which raises ``TypeError`` otherwise).

    Returns:
        FlattedDict: A flattened mapping where nested keys are converted to tuples or strings.

    Example:
        >>> xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}}
        >>> flat_xs = flat_mapping(xs)
        >>> flat_xs
        FlattedDict({('foo',): 1, ('bar', 'a'): 2})

        >>> # With separator
        >>> flat_mapping(xs, sep='/')
        FlattedDict({'foo': 1, 'bar/a': 2})

        >>> # Keep empty nodes
        >>> flat_mapping(xs, keep_empty_nodes=True)
        FlattedDict({('foo',): 1, ('bar', 'a'): 2, ('bar', 'b'): _EmptyNode()})

    Note:
        Empty mappings are ignored by default and will not be restored by
        :func:`nest_mapping` unless ``keep_empty_nodes=True``.
    """
    assert isinstance(xs, abc.Mapping), f'expected Mapping; got {type(xs).__qualname__}'

    # The signature advertises ``Optional``; treat ``None`` as "use the default
    # predicate" instead of failing later with "'NoneType' object is not callable".
    leaf_fn = _default_leaf if is_leaf is None else is_leaf

    def _key(path: Tuple[Any, ...]) -> Union[Tuple[Any, ...], str]:
        # Tuple paths are kept as-is; with ``sep`` they are joined into one string.
        return path if sep is None else sep.join(path)

    def _flatten(node: Any, prefix: Tuple[Any, ...]) -> Dict[Any, Any]:
        # Non-mappings and mappings the predicate marks as leaves end the recursion.
        if not isinstance(node, abc.Mapping) or leaf_fn(prefix, node):
            return {_key(prefix): node}

        result: Dict[Any, Any] = {}
        has_children = False
        for key, value in node.items():
            has_children = True
            result.update(_flatten(value, prefix + (key,)))

        if keep_empty_nodes and not has_children:
            # An empty top-level input yields an empty result rather than
            # ``{(): empty_node}``.
            return {} if prefix == () else {_key(prefix): empty_node}
        return result

    return FlattedDict(_flatten(xs, ()))
|
214
|
+
|
215
|
+
|
216
|
+
def nest_mapping(
    xs: Any,
    /,
    *,
    sep: Optional[str] = None
) -> 'NestedDict':
    """Rebuild a nested mapping from a flat one (the inverse of :func:`flat_mapping`).

    Each key of ``xs`` is a path. Every component except the last selects (or
    creates) an intermediate level, and the last component names the slot that
    receives the value. A value that is the ``empty_node`` sentinel is restored
    as an empty ``{}``.

    Args:
        xs: A flattened mapping with tuple keys, or ``sep``-joined string keys
            when ``sep`` is given.
        sep: Separator used to split string keys into path components. It must
            match the separator used in :func:`flat_mapping`. If None, keys are
            assumed to already be sequences of components (normally tuples).

    Returns:
        NestedDict: The reconstructed mapping. Intermediate levels created here
        are plain ``dict`` objects.

    Example:
        >>> flat_xs = {
        ...     ('foo',): 1,
        ...     ('bar', 'a'): 2,
        ... }
        >>> xs = nest_mapping(flat_xs)
        >>> xs
        NestedDict({'foo': 1, 'bar': {'a': 2}})

        >>> # With separator
        >>> flat_xs_str = {'foo': 1, 'bar/a': 2}
        >>> nest_mapping(flat_xs_str, sep='/')
        NestedDict({'foo': 1, 'bar': {'a': 2}})

    See Also:
        :func:`flat_mapping`: The inverse operation that flattens a nested mapping.
    """
    assert isinstance(xs, abc.Mapping), f'expected Mapping; got {type(xs).__qualname__}'
    root: Dict[Any, Any] = {}
    for flat_key, leaf in xs.items():
        parts = flat_key if sep is None else flat_key.split(sep)
        if leaf is empty_node:
            leaf = {}
        # Walk down (creating levels as needed) to the parent of the final slot.
        node = root
        for part in parts[:-1]:
            if part not in node:
                node[part] = {}
            node = node[part]
        node[parts[-1]] = leaf
    return NestedDict(root)
|
267
|
+
|
268
|
+
|
269
|
+
def _default_compare(x: Any, values: set) -> bool:
|
270
|
+
"""Check if an object's identity is in a set of values.
|
271
|
+
|
272
|
+
Args:
|
273
|
+
x: The object to check.
|
274
|
+
values: A set of object identities to compare against.
|
275
|
+
|
276
|
+
Returns:
|
277
|
+
bool: True if the object's id is in the values set.
|
278
|
+
"""
|
279
|
+
return id(x) in values
|
280
|
+
|
281
|
+
|
282
|
+
def _default_process(x: Any) -> int:
|
283
|
+
"""Get the identity of an object.
|
284
|
+
|
285
|
+
Args:
|
286
|
+
x: The object to process.
|
287
|
+
|
288
|
+
Returns:
|
289
|
+
int: The object's identity (id).
|
290
|
+
"""
|
291
|
+
return id(x)
|
292
|
+
|
293
|
+
|
294
|
+
class _AttributeKeyError(AttributeError, KeyError):
    """Error raised by :meth:`PrettyDict.__getattr__` for a missing key.

    It derives from both :class:`AttributeError` and :class:`KeyError`.
    Existing ``except KeyError`` handlers keep working. Meanwhile ``hasattr``,
    ``getattr(obj, name, default)``, :mod:`copy` and :mod:`pickle` see a
    missing attribute instead of crashing, because they only swallow
    ``AttributeError`` when probing for optional hooks such as
    ``__deepcopy__`` or ``__setstate__``.
    """


class PrettyDict(dict, PrettyRepr):
    """Base dictionary class with pretty representation and tree utilities.

    This class extends the built-in dict with pretty printing capabilities and
    provides base methods for tree operations. It serves as the parent class for
    :class:`NestedDict` and :class:`FlattedDict`.

    Attributes:
        __module__ (str): Module identifier set to 'brainstate.util'.
    """
    __module__ = 'brainstate.util'

    def __getattr__(self, key: K):  # type: ignore[misc]
        """Access dictionary items as attributes.

        Python only calls this for names not found by normal attribute lookup,
        so real attributes and methods take precedence over keys.

        Args:
            key: The dictionary key to access.

        Returns:
            The value associated with the key.

        Raises:
            AttributeError: If the key is not in the dictionary. The raised
                exception is also a :class:`KeyError`, so ``except KeyError``
                handlers still catch it.
        """
        try:
            return self[key]
        except KeyError:
            # The attribute protocol requires AttributeError. A bare KeyError made
            # copy.deepcopy, pickle and hasattr() crash on every instance.
            raise _AttributeKeyError(key) from None

    def treefy_state(self) -> Any:
        """Convert :class:`State` objects to a reference tree of the state.

        The mapping is flattened with ``jax.tree.flatten``. Every leaf that is a
        :class:`State` is replaced by ``leaf.to_state_ref()``, and the tree is
        rebuilt with the original tree definition.

        Returns:
            Any: A tree structure where State objects are replaced with their references.

        Example:
            >>> from brainstate._state import State
            >>> d = PrettyDict({'a': State(1), 'b': 2})
            >>> ref_tree = d.treefy_state()
        """
        # Local import, as in the original code; presumably avoids an import cycle.
        from brainstate._state import State
        leaves, treedef = jax.tree.flatten(self)
        leaves = [leaf.to_state_ref() if isinstance(leaf, State) else leaf for leaf in leaves]
        return treedef.unflatten(leaves)

    def to_dict(self) -> Dict[K, Union[Dict[K, Any], V]]:
        """Convert the :class:`PrettyDict` to a standard Python dictionary.

        The copy is shallow: nested values are shared with ``self``.

        Returns:
            Dict[K, Union[Dict[K, Any], V]]: A standard dictionary representation.
        """
        return dict(self)  # type: ignore

    def __repr__(self) -> str:
        """Generate a pretty string representation of the dictionary.

        Returns:
            str: A formatted string representation using pretty printing.
        """
        return pretty_repr_object(self)

    def __pretty_repr__(self) -> Generator[Union[PrettyType, PrettyAttr], None, None]:
        """Generate pretty representation items for this dictionary.

        Yields:
            Union[PrettyType, PrettyAttr]: Pretty representation items.
        """
        yield from yield_unique_pretty_repr_items(self, _default_repr_object, _default_repr_attr)

    def split(self, *filters: Filter) -> Union['PrettyDict', Tuple['PrettyDict', ...]]:
        """Split the dictionary based on filters (abstract method).

        This method must be implemented by subclasses.

        Args:
            *filters: Filter specifications to split the dictionary.

        Raises:
            NotImplementedError: This is an abstract method.
        """
        raise NotImplementedError

    def filter(self, *filters: Filter) -> Union['PrettyDict', Tuple['PrettyDict', ...]]:
        """Filter the dictionary based on filters (abstract method).

        This method must be implemented by subclasses.

        Args:
            *filters: Filter specifications to apply.

        Raises:
            NotImplementedError: This is an abstract method.
        """
        raise NotImplementedError

    def merge(self, *states: 'PrettyDict') -> 'PrettyDict':
        """Merge multiple dictionaries (abstract method).

        This method must be implemented by subclasses.

        Args:
            *states: Additional PrettyDict objects to merge.

        Raises:
            NotImplementedError: This is an abstract method.
        """
        raise NotImplementedError

    def subset(self, *filters: Filter) -> Union['PrettyDict', Tuple['PrettyDict', ...]]:
        """Select one or more sub-dictionaries using ``filters``.

        This is a thin alias for :meth:`filter`, so its exact semantics come
        from the subclass implementation. For example, :class:`NestedDict`'s
        ``filter`` accepts non-exhaustive filters and drops unmatched entries.
        Pass at least one :class:`Filter` (e.g., a :class:`State` type).

        Args:
            *filters: Filter specifications for subsetting.

        Returns:
            Union[PrettyDict, Tuple[PrettyDict, ...]]: Whatever :meth:`filter`
            returns for the same filters.
        """
        return self.filter(*filters)
|
417
|
+
|
418
|
+
|
419
|
+
class NestedStateRepr(PrettyRepr):
    """Display wrapper that renders a :class:`PrettyDict` as a bare ``{...}`` block.

    Args:
        state: The :class:`PrettyDict` to render.

    Attributes:
        state (PrettyDict): The wrapped mapping.
    """

    def __init__(self, state: PrettyDict) -> None:
        self.state = state

    def __pretty_repr__(self) -> Generator[Union[PrettyType, PrettyAttr], None, None]:
        """Yield a brace-delimited header followed by the wrapped mapping's entries.

        Any :class:`PrettyType` header produced by the wrapped mapping is
        dropped, so only this unnamed ``{``/``}`` header is emitted.

        Yields:
            Union[PrettyType, PrettyAttr]: The header, then the non-header items
            of ``self.state.__pretty_repr__()``.
        """
        yield PrettyType('', value_sep=': ', start='{', end='}')
        for item in self.state.__pretty_repr__():
            if not isinstance(item, PrettyType):
                yield item

    def __treescope_repr__(self, path: Any, subtree_renderer: Callable) -> Any:
        """Render the wrapped mapping for treescope.

        Nested :class:`PrettyDict` values are wrapped in :class:`NestedStateRepr`
        as well. The resulting plain dict is rendered at the same ``path``.

        Args:
            path: The current path in the tree structure.
            subtree_renderer: Callable used to render the children.

        Returns:
            Any: Whatever ``subtree_renderer`` returns.
        """
        children = {
            key: NestedStateRepr(value) if isinstance(value, PrettyDict) else value
            for key, value in self.state.items()
        }
        return subtree_renderer(children, path=path)
|
466
|
+
|
467
|
+
|
468
|
+
def _default_repr_object(node: PrettyDict) -> Generator[PrettyType, None, None]:
    """Yield the header used when rendering a :class:`PrettyDict`.

    Args:
        node: The mapping being rendered. It is not inspected; every mapping
            gets the same header.

    Yields:
        PrettyType: An unnamed header with ``{``/``}`` delimiters and
        ``value_sep=': '``.
    """
    header = PrettyType('', start='{', end='}', value_sep=': ')
    yield header
|
478
|
+
|
479
|
+
|
480
|
+
def _default_repr_attr(node: PrettyDict) -> Generator[PrettyAttr, None, None]:
    """Yield one :class:`PrettyAttr` per entry of a :class:`PrettyDict`.

    Values are adapted before display:

    * ``list`` values are wrapped in :class:`PrettyList`;
    * ``dict`` values (including any :class:`PrettyDict` subclass) are copied
      into a :class:`PrettyDict` and wrapped in :class:`NestedStateRepr` for a
      compact nested rendering;
    * everything else is passed through unchanged.

    Args:
        node: The mapping whose entries are rendered.

    Yields:
        PrettyAttr: ``PrettyAttr(repr(key), adapted_value)`` for each entry.
    """
    for key, value in node.items():
        if isinstance(value, list):
            value = PrettyList(value)
        elif isinstance(value, dict):
            value = NestedStateRepr(PrettyDict(value))
        yield PrettyAttr(repr(key), value)
|
503
|
+
|
504
|
+
|
505
|
+
class NestedDict(PrettyDict):
|
506
|
+
"""A pytree-like nested mapping structure for organizing hierarchical data.
|
507
|
+
|
508
|
+
This class represents a nested mapping from strings or integers to leaves, where
|
509
|
+
valid leaf types include :class:`State`, ``jax.Array``, ``numpy.ndarray``, or
|
510
|
+
nested :class:`NestedDict` and :class:`FlattedDict` structures.
|
511
|
+
|
512
|
+
:class:`NestedDict` is a JAX pytree and can be used with JAX transformations.
|
513
|
+
It provides methods for flattening to :class:`FlattedDict`, splitting/filtering
|
514
|
+
based on predicates, and merging multiple nested structures.
|
515
|
+
|
516
|
+
Attributes:
|
517
|
+
__module__ (str): Module identifier set to 'brainstate.util'.
|
518
|
+
|
519
|
+
Example:
|
520
|
+
>>> from brainstate.util import NestedDict
|
521
|
+
>>> state = NestedDict({
|
522
|
+
... 'layer1': {'weight': jnp.ones((3, 3)), 'bias': jnp.zeros(3)},
|
523
|
+
... 'layer2': {'weight': jnp.ones((3, 1))}
|
524
|
+
... })
|
525
|
+
>>> flat = state.to_flat()
|
526
|
+
>>> print(flat)
|
527
|
+
FlattedDict({('layer1', 'weight'): ..., ('layer1', 'bias'): ..., ...})
|
528
|
+
|
529
|
+
See Also:
|
530
|
+
:class:`FlattedDict`: The flattened counterpart with tuple keys.
|
531
|
+
:func:`flat_mapping`: Function to flatten a nested mapping.
|
532
|
+
:func:`nest_mapping`: Function to unflatten a flat mapping.
|
533
|
+
"""
|
534
|
+
__module__ = 'brainstate.util'
|
535
|
+
|
536
|
+
def __or__(self, other: 'NestedDict') -> 'NestedDict':
    """Return the union of this mapping and ``other`` (the ``|`` operator).

    Both operands are combined with :meth:`NestedDict.merge`, which applies
    ``self`` first and ``other`` second. On paths present in both, the value
    from ``other`` wins.

    Args:
        other: The :class:`NestedDict` to merge in. If it is falsy (empty),
            ``self`` itself is returned, not a copy.

    Returns:
        NestedDict: The merged mapping.
    """
    if other:
        assert isinstance(other, NestedDict), f'expected NestedDict; got {type(other).__qualname__}'
        return NestedDict.merge(self, other)
    return self
|
541
|
+
|
542
|
+
def __sub__(self, other: 'NestedDict') -> 'NestedDict':
    """Return the entries whose flattened paths do not occur in ``other`` (``-``).

    Both mappings are flattened with :meth:`to_flat`, and only paths are
    compared (values are ignored).

    Args:
        other: The :class:`NestedDict` whose paths are removed. If it is falsy
            (empty), ``self`` itself is returned, not a copy.

    Returns:
        NestedDict: A new mapping holding the remaining entries.
    """
    if other:
        assert isinstance(other, NestedDict), f'expected NestedDict; got {type(other).__qualname__}'
        mine = self.to_flat()
        excluded = other.to_flat()
        remaining = {path: leaf for path, leaf in mine.items() if path not in excluded}
        return NestedDict.from_flat(remaining)
    return self
|
551
|
+
|
552
|
+
def to_flat(self) -> 'FlattedDict':
    """Flatten this nested mapping into a :class:`FlattedDict` keyed by tuple paths.

    Uses :func:`flat_mapping` with its defaults, so empty sub-mappings are
    omitted from the result.

    Returns:
        FlattedDict: The flattened mapping.
    """
    flat = flat_mapping(self)
    return flat
|
560
|
+
|
561
|
+
@classmethod
def from_flat(
    cls,
    flat_dict: Union[abc.Mapping[PathParts, V], Iterable[Tuple[PathParts, V]]],
) -> 'NestedDict':
    """Build an instance of ``cls`` from a flat ``path -> value`` mapping.

    Args:
        flat_dict: Either a mapping from tuple paths to values, or an iterable
            of ``(path, value)`` pairs (anything accepted by ``dict(...)``).

    Returns:
        NestedDict: An instance of ``cls`` holding the nested structure
        reconstructed by :func:`nest_mapping`.
    """
    pairs = dict(flat_dict)
    return cls(nest_mapping(pairs))
|
574
|
+
|
575
|
+
def split(self, *filters: Filter) -> Union['NestedDict', Tuple['NestedDict', ...]]:
    """Partition this mapping into one :class:`NestedDict` per filter.

    Leaves are distributed among the filters by ``_split_nested_mapping``.
    The filters must be exhaustive: every leaf has to be claimed by some
    filter (use ``...`` as the last filter to collect everything else).
    Pass at least one :class:`Filter` (e.g. a :class:`State` type).

    Example usage::

        >>> import brainstate
        >>> params, others = state.split(brainstate.ParamState, ...)

    Args:
        *filters: The filters grouping the state into mutually exclusive substates.

    Returns:
        A single :class:`NestedDict` when exactly one filter is given, otherwise
        a tuple with one :class:`NestedDict` per filter.

    Raises:
        ValueError: If some leaves are not matched by any filter.
    """
    *groups, rest = _split_nested_mapping(self, *filters)
    if rest:
        raise ValueError(f'Non-exhaustive filters, got a non-empty remainder: {rest}.\n'
                         f'Use `...` to match all remaining elements.')
    return groups[0] if len(groups) == 1 else tuple(groups)  # type: ignore[bad-return-type]
|
616
|
+
|
617
|
+
def filter(self, *filters: Filter) -> Union['NestedDict', Tuple['NestedDict', ...]]:
|
618
|
+
"""
|
619
|
+
Filter a :class:`NestedDict` into one or more :class:`NestedDict`'s. The
|
620
|
+
user must pass at least one `:class:`Filter` (i.e. :class:`State`).
|
621
|
+
This method is similar to :meth:`split() <flax.nnx.NestedDict.state.split>`,
|
622
|
+
except the filters can be non-exhaustive.
|
623
|
+
|
624
|
+
Arguments:
|
625
|
+
first: The first filter
|
626
|
+
*filters: The optional, additional filters to group the state into mutually exclusive substates.
|
627
|
+
|
628
|
+
Returns:
|
629
|
+
One or more ``States`` equal to the number of filters passed.
|
630
|
+
"""
|
631
|
+
states_ = _split_nested_mapping(self, *filters)
|
632
|
+
assert len(states_) == len(filters) + 1, f'Expected {len(filters) + 1} states, got {len(states_)}'
|
633
|
+
if len(filters) == 1:
|
634
|
+
return states_[0]
|
635
|
+
else:
|
636
|
+
return tuple(states_[:-1])
|
637
|
+
|
638
|
+
@staticmethod
|
639
|
+
def merge(*states) -> 'NestedDict':
|
640
|
+
"""
|
641
|
+
The inverse of :meth:`split()`.
|
642
|
+
|
643
|
+
``merge`` takes one or more :class:`PrettyDict`'s and creates a new :class:`PrettyDict`.
|
644
|
+
|
645
|
+
Args:
|
646
|
+
*states: Additional :class:`PrettyDict` objects.
|
647
|
+
|
648
|
+
Returns:
|
649
|
+
The merged :class:`PrettyDict`.
|
650
|
+
"""
|
651
|
+
new_state: FlattedDict = FlattedDict()
|
652
|
+
for state in states:
|
653
|
+
if isinstance(state, NestedDict):
|
654
|
+
new_state.update(state.to_flat()) # type: ignore[attribute-error] # pytype is wrong here
|
655
|
+
elif isinstance(state, FlattedDict):
|
656
|
+
new_state.update(state)
|
657
|
+
else:
|
658
|
+
raise TypeError(f'Expected Nested or Flatted Mapping, got {type(state)} instead.')
|
659
|
+
return NestedDict.from_flat(new_state)
|
660
|
+
|
661
|
+
def to_pure_dict(self) -> Dict[str, Any]:
|
662
|
+
"""Convert to a pure nested dictionary structure.
|
663
|
+
|
664
|
+
This method creates a standard Python dictionary with the same nested structure
|
665
|
+
as this NestedDict, without any special class wrappers.
|
666
|
+
|
667
|
+
Returns:
|
668
|
+
Dict[str, Any]: A pure nested dictionary representation.
|
669
|
+
|
670
|
+
Example:
|
671
|
+
>>> nested = NestedDict({'a': {'b': 1, 'c': 2}})
|
672
|
+
>>> pure = nested.to_pure_dict()
|
673
|
+
>>> type(pure)
|
674
|
+
<class 'dict'>
|
675
|
+
"""
|
676
|
+
flat_values = {k: x for k, x in self.to_flat().items()}
|
677
|
+
return nest_mapping(flat_values).to_dict()
|
678
|
+
|
679
|
+
def replace_by_pure_dict(
|
680
|
+
self,
|
681
|
+
pure_dict: Dict[str, Any],
|
682
|
+
replace_fn: Optional[SetValueFn] = None
|
683
|
+
) -> None:
|
684
|
+
"""Replace values in this NestedDict using a pure dictionary.
|
685
|
+
|
686
|
+
This method updates the values in this NestedDict with values from a standard
|
687
|
+
Python dictionary. For :class:`State` objects with a ``replace`` method, the
|
688
|
+
replace method is called; otherwise, values are directly assigned.
|
689
|
+
|
690
|
+
Args:
|
691
|
+
pure_dict: A pure dictionary with matching structure containing new values.
|
692
|
+
replace_fn: Optional custom function to replace values. Takes ``(old_value, new_value)``
|
693
|
+
and returns the updated value. Defaults to calling ``replace()`` method if
|
694
|
+
available, otherwise direct assignment.
|
695
|
+
|
696
|
+
Raises:
|
697
|
+
ValueError: If a key in ``pure_dict`` is not found in this NestedDict.
|
698
|
+
|
699
|
+
Example:
|
700
|
+
>>> from brainstate._state import State
|
701
|
+
>>> nested = NestedDict({'a': State(1), 'b': 2})
|
702
|
+
>>> nested.replace_by_pure_dict({'a': 10, 'b': 20})
|
703
|
+
>>> nested['a'].value
|
704
|
+
10
|
705
|
+
"""
|
706
|
+
if replace_fn is None:
|
707
|
+
replace_fn = lambda x, v: x.replace(v) if hasattr(x, 'replace') else v
|
708
|
+
current_flat = self.to_flat()
|
709
|
+
for kp, v in flat_mapping(pure_dict).items():
|
710
|
+
if kp not in current_flat:
|
711
|
+
raise ValueError(f'key in pure_dict not available in state: {kp}')
|
712
|
+
current_flat[kp] = replace_fn(current_flat[kp], v)
|
713
|
+
self.update(nest_mapping(current_flat))
|
714
|
+
|
715
|
+
|
716
|
+
class FlattedDict(PrettyDict):
|
717
|
+
"""
|
718
|
+
A pytree-like structure that contains a :class:`Mapping` from strings or integers to leaves.
|
719
|
+
|
720
|
+
A valid leaf type is either :class:`State`, ``jax.Array``, ``numpy.ndarray`` or Python variables.
|
721
|
+
|
722
|
+
A :class:`NestedDict` can be generated by either calling :func:`states()` or
|
723
|
+
:func:`nodes()` on the :class:`Module`.
|
724
|
+
|
725
|
+
Example usage::
|
726
|
+
|
727
|
+
>>> import brainstate as brainstate
|
728
|
+
>>> import jax.numpy as jnp
|
729
|
+
>>>
|
730
|
+
>>> class Model(brainstate.nn.Module):
|
731
|
+
... def __init__(self):
|
732
|
+
... super().__init__()
|
733
|
+
... self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
|
734
|
+
... self.linear = brainstate.nn.Linear(2, 3)
|
735
|
+
... def __call__(self, x):
|
736
|
+
... return self.linear(self.batchnorm(x))
|
737
|
+
>>>
|
738
|
+
>>> model = Model()
|
739
|
+
|
740
|
+
>>> # retrieve the states of the model
|
741
|
+
>>> model.states() # with the same to the function of ``brainstate.graph.states()``
|
742
|
+
FlattedDict({
|
743
|
+
('batchnorm', 'running_mean'): LongTermState(
|
744
|
+
value=Array([[0., 0., 0.]], dtype=float32)
|
745
|
+
),
|
746
|
+
('batchnorm', 'running_var'): LongTermState(
|
747
|
+
value=Array([[1., 1., 1.]], dtype=float32)
|
748
|
+
),
|
749
|
+
('batchnorm', 'weight'): ParamState(
|
750
|
+
value={'bias': Array([[0., 0., 0.]], dtype=float32), 'scale': Array([[1., 1., 1.]], dtype=float32)}
|
751
|
+
),
|
752
|
+
('linear', 'weight'): ParamState(
|
753
|
+
value={'weight': Array([[-0.21467684, 0.7621282 , -0.50756454, -0.49047297],
|
754
|
+
[-0.90413696, 0.6711 , -0.1254792 , 0.50412565],
|
755
|
+
[ 0.23975602, 0.47905368, 1.4851435 , 0.16745673]], dtype=float32), 'bias': Array([0., 0., 0., 0.], dtype=float32)}
|
756
|
+
)
|
757
|
+
})
|
758
|
+
|
759
|
+
>>> # retrieve the nodes of the model
|
760
|
+
>>> model.nodes() # with the same to the function of ``brainstate.graph.nodes()``
|
761
|
+
FlattedDict({
|
762
|
+
('batchnorm',): BatchNorm1d(
|
763
|
+
in_size=(10, 3),
|
764
|
+
out_size=(10, 3),
|
765
|
+
affine=True,
|
766
|
+
bias_initializer=Constant(value=0.0, dtype=<class 'numpy.float32'>),
|
767
|
+
scale_initializer=Constant(value=1.0, dtype=<class 'numpy.float32'>),
|
768
|
+
dtype=<class 'numpy.float32'>,
|
769
|
+
track_running_stats=True,
|
770
|
+
momentum=Array(shape=(), dtype=float32),
|
771
|
+
epsilon=Array(shape=(), dtype=float32),
|
772
|
+
feature_axis=(1,),
|
773
|
+
axis_name=None,
|
774
|
+
axis_index_groups=None,
|
775
|
+
running_mean=LongTermState(
|
776
|
+
value=Array(shape=(1, 3), dtype=float32)
|
777
|
+
),
|
778
|
+
running_var=LongTermState(
|
779
|
+
value=Array(shape=(1, 3), dtype=float32)
|
780
|
+
),
|
781
|
+
weight=ParamState(
|
782
|
+
value={'bias': Array(shape=(1, 3), dtype=float32), 'scale': Array(shape=(1, 3), dtype=float32)}
|
783
|
+
)
|
784
|
+
),
|
785
|
+
('linear',): Linear(
|
786
|
+
in_size=(10, 3),
|
787
|
+
out_size=(10, 4),
|
788
|
+
w_mask=None,
|
789
|
+
weight=ParamState(
|
790
|
+
value={'bias': Array(shape=(4,), dtype=float32), 'weight': Array(shape=(3, 4), dtype=float32)}
|
791
|
+
)
|
792
|
+
),
|
793
|
+
(): Model(
|
794
|
+
batchnorm=BatchNorm1d(...),
|
795
|
+
linear=Linear(...)
|
796
|
+
)
|
797
|
+
})
|
798
|
+
"""
|
799
|
+
__module__ = 'brainstate.util'
|
800
|
+
|
801
|
+
def __or__(self, other: 'FlattedDict') -> 'FlattedDict':
|
802
|
+
if not other:
|
803
|
+
return self
|
804
|
+
assert isinstance(other, FlattedDict), f'expected NestedDict; got {type(other).__qualname__}'
|
805
|
+
return FlattedDict.merge(self, other)
|
806
|
+
|
807
|
+
def __sub__(self, other: 'FlattedDict') -> 'FlattedDict':
|
808
|
+
if not other:
|
809
|
+
return self
|
810
|
+
assert isinstance(other, FlattedDict), f'expected NestedDict; got {type(other).__qualname__}'
|
811
|
+
diff = {k: v for k, v in self.items() if k not in other}
|
812
|
+
return FlattedDict(diff)
|
813
|
+
|
814
|
+
def to_nest(self) -> NestedDict:
|
815
|
+
"""
|
816
|
+
Unflatten the flat mapping into a nested mapping.
|
817
|
+
|
818
|
+
Returns:
|
819
|
+
The nested mapping.
|
820
|
+
"""
|
821
|
+
return nest_mapping(self)
|
822
|
+
|
823
|
+
@classmethod
|
824
|
+
def from_nest(
|
825
|
+
cls, nested_dict: abc.Mapping[PathParts, V] | Iterable[tuple[PathParts, V]],
|
826
|
+
) -> 'FlattedDict':
|
827
|
+
"""
|
828
|
+
Create a :class:`NestedDict` from a flat mapping.
|
829
|
+
|
830
|
+
Args:
|
831
|
+
nested_dict: The flat mapping.
|
832
|
+
|
833
|
+
Returns:
|
834
|
+
The :class:`NestedDict`.
|
835
|
+
"""
|
836
|
+
return flat_mapping(nested_dict)
|
837
|
+
|
838
|
+
def split( # type: ignore[misc]
|
839
|
+
self,
|
840
|
+
first: Filter,
|
841
|
+
/,
|
842
|
+
*filters: Filter
|
843
|
+
) -> Union['FlattedDict', tuple['FlattedDict', ...]]:
|
844
|
+
"""
|
845
|
+
Split a :class:`FlattedDict` into one or more :class:`FlattedDict`'s. The
|
846
|
+
user must pass at least one `:class:`Filter` (i.e. :class:`State`),
|
847
|
+
and the filters must be exhaustive (i.e. they must cover all
|
848
|
+
:class:`State` types in the :class:`NestedDict`).
|
849
|
+
|
850
|
+
Arguments:
|
851
|
+
first: The first filter
|
852
|
+
*filters: The optional, additional filters to group the state into mutually exclusive substates.
|
853
|
+
|
854
|
+
Returns:
|
855
|
+
One or more ``States`` equal to the number of filters passed.
|
856
|
+
"""
|
857
|
+
filters = (first, *filters)
|
858
|
+
*states_, rest = _split_flatted_mapping(self, *filters)
|
859
|
+
if rest:
|
860
|
+
raise ValueError(f'Non-exhaustive filters, got a non-empty remainder: {rest}.\n'
|
861
|
+
f'Use `...` to match all remaining elements.')
|
862
|
+
|
863
|
+
states: FlattedDict | Tuple[FlattedDict, ...]
|
864
|
+
if len(states_) == 1:
|
865
|
+
states = states_[0]
|
866
|
+
else:
|
867
|
+
states = tuple(states_)
|
868
|
+
return states # type: ignore[bad-return-type]
|
869
|
+
|
870
|
+
def filter(
|
871
|
+
self,
|
872
|
+
first: Filter,
|
873
|
+
/,
|
874
|
+
*filters: Filter,
|
875
|
+
) -> Union['FlattedDict', Tuple['FlattedDict', ...]]:
|
876
|
+
"""
|
877
|
+
Filter a :class:`FlattedDict` into one or more :class:`FlattedDict`'s. The
|
878
|
+
user must pass at least one `:class:`Filter` (i.e. :class:`State`).
|
879
|
+
This method is similar to :meth:`split() <flax.nnx.NestedDict.state.split>`,
|
880
|
+
except the filters can be non-exhaustive.
|
881
|
+
|
882
|
+
Arguments:
|
883
|
+
first: The first filter
|
884
|
+
*filters: The optional, additional filters to group the state into mutually exclusive substates.
|
885
|
+
|
886
|
+
Returns:
|
887
|
+
One or more ``States`` equal to the number of filters passed.
|
888
|
+
"""
|
889
|
+
*states_, _rest = _split_flatted_mapping(self, first, *filters)
|
890
|
+
assert len(states_) == len(filters) + 1, f'Expected {len(filters) + 1} states, got {len(states_)}'
|
891
|
+
if len(states_) == 1:
|
892
|
+
states = states_[0]
|
893
|
+
else:
|
894
|
+
states = tuple(states_)
|
895
|
+
return states # type: ignore[bad-return-type]
|
896
|
+
|
897
|
+
@staticmethod
|
898
|
+
def merge(*states: Union['FlattedDict', 'NestedDict']) -> 'FlattedDict':
|
899
|
+
"""
|
900
|
+
The inverse of :meth:`split()`.
|
901
|
+
|
902
|
+
``merge`` takes one or more :class:`FlattedDict`'s and creates a new :class:`FlattedDict`.
|
903
|
+
|
904
|
+
Args:
|
905
|
+
state: A :class:`PrettyDict` object.
|
906
|
+
*states: Additional :class:`PrettyDict` objects.
|
907
|
+
|
908
|
+
Returns:
|
909
|
+
The merged :class:`PrettyDict`.
|
910
|
+
"""
|
911
|
+
new_state: FlattedStateMapping[V] = {}
|
912
|
+
for state in states:
|
913
|
+
if isinstance(state, NestedDict):
|
914
|
+
new_state.update(state.to_flat()) # type: ignore[attribute-error] # pytype is wrong here
|
915
|
+
elif isinstance(state, FlattedDict):
|
916
|
+
new_state.update(state)
|
917
|
+
else:
|
918
|
+
raise TypeError(f'Expected Nested or Flatted Mapping, got {type(state)} instead.')
|
919
|
+
return FlattedDict(new_state)
|
920
|
+
|
921
|
+
def to_dict_values(self) -> Dict[PathParts, Any]:
|
922
|
+
"""Convert a FlattedDict containing State objects to a dictionary of raw values.
|
923
|
+
|
924
|
+
This method extracts the underlying values from any :class:`State` objects in the
|
925
|
+
FlattedDict, creating a new dictionary with the same keys but where each State
|
926
|
+
object is replaced by its ``value`` attribute. Non-State objects are kept as is.
|
927
|
+
|
928
|
+
Returns:
|
929
|
+
Dict[PathParts, Any]: A dictionary with the same keys as the FlattedDict, but
|
930
|
+
where each State object is replaced by its value. Non-State objects remain
|
931
|
+
unchanged.
|
932
|
+
|
933
|
+
Example:
|
934
|
+
>>> from brainstate._state import ParamState
|
935
|
+
>>> flat_dict = FlattedDict({
|
936
|
+
... ('model', 'layer1', 'weight'): ParamState(value=jnp.ones((10, 5)))
|
937
|
+
... })
|
938
|
+
>>> values = flat_dict.to_dict_values()
|
939
|
+
>>> values[('model', 'layer1', 'weight')]
|
940
|
+
Array([[1., 1., ...]], dtype=float32)
|
941
|
+
"""
|
942
|
+
from brainstate._state import State
|
943
|
+
return {
|
944
|
+
k: v.value if isinstance(v, State) else v
|
945
|
+
for k, v in self.items()
|
946
|
+
}
|
947
|
+
|
948
|
+
def assign_dict_values(self, data: Dict[PathParts, Any]) -> None:
|
949
|
+
"""Assign values from a dictionary to this FlattedDict.
|
950
|
+
|
951
|
+
This method updates the values in the FlattedDict with values from the provided
|
952
|
+
dictionary. For keys that correspond to :class:`State` objects, the ``value``
|
953
|
+
attribute of the State is updated. For other keys, the value in the FlattedDict
|
954
|
+
is directly replaced with the new value.
|
955
|
+
|
956
|
+
Args:
|
957
|
+
data: A dictionary containing the values to assign, where keys must match
|
958
|
+
those in the FlattedDict.
|
959
|
+
|
960
|
+
Raises:
|
961
|
+
KeyError: If a key in the FlattedDict is not present in the provided dictionary.
|
962
|
+
|
963
|
+
Example:
|
964
|
+
>>> from brainstate._state import ParamState
|
965
|
+
>>> flat_dict = FlattedDict({
|
966
|
+
... ('model', 'weight'): ParamState(value=jnp.zeros((5, 5)))
|
967
|
+
... })
|
968
|
+
>>> flat_dict.assign_dict_values({('model', 'weight'): jnp.ones((5, 5))})
|
969
|
+
# The ParamState's value is now an array of ones
|
970
|
+
"""
|
971
|
+
from brainstate._state import State
|
972
|
+
for k in self.keys():
|
973
|
+
if k not in data:
|
974
|
+
raise KeyError(f'Invalid key: {k!r}')
|
975
|
+
val = self[k]
|
976
|
+
if isinstance(val, State):
|
977
|
+
val.value = data[k]
|
978
|
+
else:
|
979
|
+
self[k] = data[k]
|
980
|
+
|
981
|
+
|
982
|
+
def _split_nested_mapping(
|
983
|
+
mapping: 'NestedDict',
|
984
|
+
*filters: Filter,
|
985
|
+
) -> Tuple['NestedDict', ...]:
|
986
|
+
"""Split a nested mapping into multiple nested mappings based on filters.
|
987
|
+
|
988
|
+
This internal function partitions a NestedDict into multiple NestedDicts based on
|
989
|
+
filter predicates. Items that match each filter are placed in separate mappings,
|
990
|
+
with unmatched items going to the final mapping.
|
991
|
+
|
992
|
+
Args:
|
993
|
+
mapping: The NestedDict to split.
|
994
|
+
*filters: Filter specifications. The catch-all filters ``...`` or ``True`` can
|
995
|
+
only be used as the last filter.
|
996
|
+
|
997
|
+
Returns:
|
998
|
+
Tuple[NestedDict, ...]: A tuple of n+1 NestedDicts, where n is the number
|
999
|
+
of filters. The last mapping contains items that didn't match any filter.
|
1000
|
+
|
1001
|
+
Raises:
|
1002
|
+
ValueError: If ``...`` or ``True`` is used before the last filter position.
|
1003
|
+
AssertionError: If mapping is not a NestedDict.
|
1004
|
+
"""
|
1005
|
+
# Check if the filters are exhaustive
|
1006
|
+
for i, filter_ in enumerate(filters):
|
1007
|
+
if filter_ in (..., True) and i != len(filters) - 1:
|
1008
|
+
remaining_filters = filters[i + 1:]
|
1009
|
+
if not all(f in (..., True) for f in remaining_filters):
|
1010
|
+
raise ValueError('`...` or `True` can only be used as the last filters, '
|
1011
|
+
f'got {filter_} it at index {i}.')
|
1012
|
+
|
1013
|
+
# Change the filters to predicates
|
1014
|
+
predicates = tuple(map(to_predicate, filters))
|
1015
|
+
|
1016
|
+
# We have n + 1 state mappings, where n is the number of predicates
|
1017
|
+
# The last state mapping is for values that don't match any predicate
|
1018
|
+
flat_states: Tuple[FlattedStateMapping[V], ...] = tuple({} for _ in range(len(predicates) + 1))
|
1019
|
+
|
1020
|
+
assert isinstance(mapping, NestedDict), f'expected NestedDict; got {type(mapping).__qualname__}'
|
1021
|
+
flat_state = mapping.to_flat()
|
1022
|
+
for path, value in flat_state.items():
|
1023
|
+
for i, predicate in enumerate(predicates):
|
1024
|
+
if predicate(path, value):
|
1025
|
+
flat_states[i][path] = value # type: ignore[index]
|
1026
|
+
break
|
1027
|
+
else:
|
1028
|
+
# If we didn't break, set leaf to last state
|
1029
|
+
flat_states[-1][path] = value # type: ignore[index]
|
1030
|
+
|
1031
|
+
return tuple(NestedDict.from_flat(flat_state) for flat_state in flat_states)
|
1032
|
+
|
1033
|
+
|
1034
|
+
def _split_flatted_mapping(
|
1035
|
+
mapping: FlattedDict,
|
1036
|
+
*filters: Filter,
|
1037
|
+
) -> Tuple[FlattedDict, ...]:
|
1038
|
+
"""Split a flattened mapping into multiple flattened mappings based on filters.
|
1039
|
+
|
1040
|
+
This internal function partitions a FlattedDict into multiple FlattedDicts based on
|
1041
|
+
filter predicates. Items that match each filter are placed in separate mappings,
|
1042
|
+
with unmatched items going to the final mapping.
|
1043
|
+
|
1044
|
+
Args:
|
1045
|
+
mapping: The FlattedDict to split.
|
1046
|
+
*filters: Filter specifications. The catch-all filters ``...`` or ``True`` can
|
1047
|
+
only be used as the last filter.
|
1048
|
+
|
1049
|
+
Returns:
|
1050
|
+
Tuple[FlattedDict, ...]: A tuple of n+1 FlattedDicts, where n is the number
|
1051
|
+
of filters. The last mapping contains items that didn't match any filter.
|
1052
|
+
|
1053
|
+
Raises:
|
1054
|
+
ValueError: If ``...`` or ``True`` is used before the last filter position.
|
1055
|
+
AssertionError: If mapping is not a FlattedDict.
|
1056
|
+
"""
|
1057
|
+
# Check if the filters are exhaustive
|
1058
|
+
for i, filter_ in enumerate(filters):
|
1059
|
+
if filter_ in (..., True) and i != len(filters) - 1:
|
1060
|
+
remaining_filters = filters[i + 1:]
|
1061
|
+
if not all(f in (..., True) for f in remaining_filters):
|
1062
|
+
raise ValueError('`...` or `True` can only be used as the last filters, '
|
1063
|
+
f'got {filter_} it at index {i}.')
|
1064
|
+
|
1065
|
+
# Change the filters to predicates
|
1066
|
+
predicates = tuple(map(to_predicate, filters))
|
1067
|
+
|
1068
|
+
# We have n + 1 state mappings, where n is the number of predicates
|
1069
|
+
# The last state mapping is for values that don't match any predicate
|
1070
|
+
flat_states: Tuple[FlattedStateMapping[V], ...] = tuple({} for _ in range(len(predicates) + 1))
|
1071
|
+
|
1072
|
+
assert isinstance(mapping, FlattedDict), f'expected FlattedDict; got {type(mapping).__qualname__}'
|
1073
|
+
for path, value in mapping.items():
|
1074
|
+
for i, predicate in enumerate(predicates):
|
1075
|
+
if predicate(path, value):
|
1076
|
+
flat_states[i][path] = value # type: ignore[index]
|
1077
|
+
break
|
1078
|
+
else:
|
1079
|
+
# If we didn't break, set leaf to last state
|
1080
|
+
flat_states[-1][path] = value # type: ignore[index]
|
1081
|
+
|
1082
|
+
return tuple(FlattedDict(flat_state) for flat_state in flat_states)
|
1083
|
+
|
1084
|
+
|
1085
|
+
# Register :class:`NestedDict` as a pytree
|
1086
|
+
def _nest_flatten_with_keys(x: NestedDict) -> Tuple[Tuple[Tuple[jax.tree_util.DictKey, Any], ...], Tuple[K, ...]]:
|
1087
|
+
"""Flatten a NestedDict for JAX pytree registration with keys.
|
1088
|
+
|
1089
|
+
Args:
|
1090
|
+
x: The NestedDict to flatten.
|
1091
|
+
|
1092
|
+
Returns:
|
1093
|
+
Tuple containing:
|
1094
|
+
|
1095
|
+
- Tuple of (key, value) pairs where keys are wrapped in DictKey
|
1096
|
+
- Tuple of static keys for reconstruction
|
1097
|
+
"""
|
1098
|
+
items = sorted(x.items())
|
1099
|
+
children = tuple((jax.tree_util.DictKey(key), value) for key, value in items)
|
1100
|
+
return children, tuple(key for key, _ in items)
|
1101
|
+
|
1102
|
+
|
1103
|
+
def _nest_unflatten(
|
1104
|
+
static: Tuple[K, ...],
|
1105
|
+
leaves: Union[Tuple[V, ...], Tuple[Dict]],
|
1106
|
+
) -> NestedDict:
|
1107
|
+
"""Unflatten a NestedDict from pytree components.
|
1108
|
+
|
1109
|
+
Args:
|
1110
|
+
static: Tuple of keys for reconstruction.
|
1111
|
+
leaves: Tuple of leaf values.
|
1112
|
+
|
1113
|
+
Returns:
|
1114
|
+
NestedDict: Reconstructed NestedDict.
|
1115
|
+
"""
|
1116
|
+
return NestedDict(zip(static, leaves))
|
1117
|
+
|
1118
|
+
|
1119
|
+
jax.tree_util.register_pytree_with_keys(
|
1120
|
+
NestedDict,
|
1121
|
+
_nest_flatten_with_keys,
|
1122
|
+
_nest_unflatten
|
1123
|
+
) # type: ignore[arg-type]
|
1124
|
+
|
1125
|
+
|
1126
|
+
# Register :class:`FlattedDict` as a pytree
|
1127
|
+
|
1128
|
+
def _flat_unflatten(
|
1129
|
+
static: Tuple[K, ...],
|
1130
|
+
leaves: Union[Tuple[V, ...], Tuple[Dict]],
|
1131
|
+
) -> FlattedDict:
|
1132
|
+
"""Unflatten a FlattedDict from pytree components.
|
1133
|
+
|
1134
|
+
Args:
|
1135
|
+
static: Tuple of keys for reconstruction.
|
1136
|
+
leaves: Tuple of leaf values.
|
1137
|
+
|
1138
|
+
Returns:
|
1139
|
+
FlattedDict: Reconstructed FlattedDict.
|
1140
|
+
"""
|
1141
|
+
return FlattedDict(zip(static, leaves))
|
1142
|
+
|
1143
|
+
|
1144
|
+
jax.tree_util.register_pytree_with_keys(
|
1145
|
+
FlattedDict,
|
1146
|
+
_nest_flatten_with_keys,
|
1147
|
+
_flat_unflatten
|
1148
|
+
) # type: ignore[arg-type]
|
1149
|
+
|
1150
|
+
|
1151
|
+
@jax.tree_util.register_pytree_node_class
|
1152
|
+
class PrettyList(list, PrettyRepr):
|
1153
|
+
"""A list subclass with pretty representation and JAX pytree compatibility.
|
1154
|
+
|
1155
|
+
This class extends the built-in list with pretty printing capabilities and
|
1156
|
+
registers itself as a JAX pytree for use in JAX transformations.
|
1157
|
+
|
1158
|
+
Attributes:
|
1159
|
+
__module__ (str): Module identifier set to 'brainstate.util'.
|
1160
|
+
|
1161
|
+
Example:
|
1162
|
+
>>> from brainstate.util import PrettyList
|
1163
|
+
>>> lst = PrettyList([1, 2, {'a': 3}])
|
1164
|
+
>>> print(lst) # Pretty formatted output
|
1165
|
+
[1, 2, {'a': 3}]
|
1166
|
+
"""
|
1167
|
+
__module__ = 'brainstate.util'
|
1168
|
+
|
1169
|
+
def __pretty_repr__(self) -> Generator[Union[PrettyType, PrettyAttr], None, None]:
|
1170
|
+
"""Generate pretty representation items for this list.
|
1171
|
+
|
1172
|
+
Yields:
|
1173
|
+
Union[PrettyType, PrettyAttr]: Pretty representation items.
|
1174
|
+
"""
|
1175
|
+
yield from yield_unique_pretty_repr_items(self, _list_repr_object, _list_repr_attr)
|
1176
|
+
|
1177
|
+
def __repr__(self) -> str:
|
1178
|
+
"""Generate a pretty string representation of the list.
|
1179
|
+
|
1180
|
+
Returns:
|
1181
|
+
str: A formatted string representation using pretty printing.
|
1182
|
+
"""
|
1183
|
+
return pretty_repr_object(self)
|
1184
|
+
|
1185
|
+
def tree_flatten(self) -> Tuple[list, Tuple]:
|
1186
|
+
"""Flatten the list for JAX pytree operations.
|
1187
|
+
|
1188
|
+
Returns:
|
1189
|
+
Tuple containing:
|
1190
|
+
- The list items as children
|
1191
|
+
- Empty tuple as auxiliary data
|
1192
|
+
"""
|
1193
|
+
return list(self), ()
|
1194
|
+
|
1195
|
+
@classmethod
|
1196
|
+
def tree_unflatten(cls, aux_data: Tuple, children: list) -> 'PrettyList':
|
1197
|
+
"""Reconstruct a PrettyList from pytree components.
|
1198
|
+
|
1199
|
+
Args:
|
1200
|
+
aux_data: Auxiliary data (unused).
|
1201
|
+
children: List items to reconstruct from.
|
1202
|
+
|
1203
|
+
Returns:
|
1204
|
+
PrettyList: Reconstructed PrettyList.
|
1205
|
+
"""
|
1206
|
+
return cls(children)
|
1207
|
+
|
1208
|
+
|
1209
|
+
def _list_repr_attr(node: PrettyList) -> Generator[PrettyAttr, None, None]:
|
1210
|
+
"""Generate attribute representations for PrettyList items.
|
1211
|
+
|
1212
|
+
This function converts list and dict values to their pretty equivalents
|
1213
|
+
and wraps PrettyDict values in NestedStateRepr for compact display.
|
1214
|
+
|
1215
|
+
Args:
|
1216
|
+
node: The PrettyList whose items to represent.
|
1217
|
+
|
1218
|
+
Yields:
|
1219
|
+
PrettyAttr: Pretty attribute representations for each item with empty keys.
|
1220
|
+
"""
|
1221
|
+
for v in node:
|
1222
|
+
if isinstance(v, list):
|
1223
|
+
v = PrettyList(v)
|
1224
|
+
if isinstance(v, dict):
|
1225
|
+
v = PrettyDict(v)
|
1226
|
+
if isinstance(v, PrettyDict):
|
1227
|
+
v = NestedStateRepr(v)
|
1228
|
+
yield PrettyAttr('', v)
|
1229
|
+
|
1230
|
+
|
1231
|
+
def _list_repr_object(node: PrettyList) -> Generator[PrettyType, None, None]:
|
1232
|
+
"""Generate the object representation for PrettyList.
|
1233
|
+
|
1234
|
+
Args:
|
1235
|
+
node: The PrettyList to represent.
|
1236
|
+
|
1237
|
+
Yields:
|
1238
|
+
PrettyType: A type representation with list-like formatting.
|
1239
|
+
"""
|
1240
|
+
yield PrettyType('', value_sep='', start='[', end=']')
|
1241
|
+
|
1242
|
+
|
1243
|
+
def _repr_object_general(node: Any) -> Generator[PrettyType, None, None]:
|
1244
|
+
"""Generate a general object representation for any PrettyObject.
|
1245
|
+
|
1246
|
+
This function creates a pretty representation of an object that includes
|
1247
|
+
the type of the object with parentheses formatting.
|
1248
|
+
|
1249
|
+
Args:
|
1250
|
+
node: The object to be represented.
|
1251
|
+
|
1252
|
+
Yields:
|
1253
|
+
PrettyType: A PrettyType object representing the type of the node,
|
1254
|
+
with specified value separator, start, and end characters.
|
1255
|
+
"""
|
1256
|
+
yield PrettyType(type(node), value_sep='=', start='(', end=')')
|
1257
|
+
|
1258
|
+
|
1259
|
+
def _repr_attribute_general(node: Any) -> Generator[PrettyAttr, None, None]:
|
1260
|
+
"""Generate a pretty representation of the attributes of a general object.
|
1261
|
+
|
1262
|
+
This function iterates over the attributes of a given node and generates
|
1263
|
+
a pretty representation for each attribute. It handles conversion of lists
|
1264
|
+
and dictionaries to their pretty representation counterparts and yields a
|
1265
|
+
PrettyAttr object for each visible attribute.
|
1266
|
+
|
1267
|
+
The function respects the ``__pretty_repr_item__`` method if available on
|
1268
|
+
the node, allowing custom filtering or transformation of attribute items.
|
1269
|
+
|
1270
|
+
Args:
|
1271
|
+
node: The object whose attributes are to be represented.
|
1272
|
+
|
1273
|
+
Yields:
|
1274
|
+
PrettyAttr: A PrettyAttr object representing the key and value of
|
1275
|
+
each attribute in a pretty format.
|
1276
|
+
"""
|
1277
|
+
for k, v in vars(node).items():
|
1278
|
+
try:
|
1279
|
+
res = node.__pretty_repr_item__(k, v)
|
1280
|
+
if res is None:
|
1281
|
+
continue
|
1282
|
+
k, v = res
|
1283
|
+
except AttributeError:
|
1284
|
+
pass
|
1285
|
+
|
1286
|
+
if k is None:
|
1287
|
+
continue
|
1288
|
+
|
1289
|
+
# Convert list to PrettyList
|
1290
|
+
if isinstance(v, list):
|
1291
|
+
v = PrettyList(v)
|
1292
|
+
|
1293
|
+
# Convert dict to PrettyDict
|
1294
|
+
if isinstance(v, dict):
|
1295
|
+
v = PrettyDict(v)
|
1296
|
+
|
1297
|
+
# Convert PrettyDict to NestedStateRepr
|
1298
|
+
if isinstance(v, PrettyDict):
|
1299
|
+
v = NestedStateRepr(v)
|
1300
|
+
|
1301
|
+
yield PrettyAttr(k, v)
|