brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163) hide show
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1301 @@
1
+ # The file is adapted from the Flax library (https://github.com/google/flax).
2
+ # The credit should go to the Flax authors.
3
+ #
4
+ # Copyright 2024 The Flax Authors.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Pretty printing utilities for PyTree-like structures.
19
+
20
+ This module provides classes and functions for pretty printing and manipulating
21
+ tree-like data structures in BrainState. The main components are:
22
+
23
+ Classes:
24
+ - :class:`PrettyObject`: Base class for objects with pretty representation.
25
+ - :class:`PrettyDict`: Dictionary with pretty printing and tree utilities.
26
+ - :class:`NestedDict`: Nested mapping structure for hierarchical data.
27
+ - :class:`FlattedDict`: Flattened mapping with tuple keys for paths.
28
+ - :class:`PrettyList`: List with pretty printing capabilities.
29
+
30
+ Functions:
31
+ - :func:`flat_mapping`: Flatten a nested mapping to tuple keys.
32
+ - :func:`nest_mapping`: Unflatten a mapping back to nested structure.
33
+
34
+ All dictionary classes are registered as JAX pytrees and can be used with JAX
35
+ transformations. They support splitting, filtering, and merging operations for
36
+ organizing state in neural network models.
37
+
38
+ Example:
39
+ >>> from brainstate.util import NestedDict, flat_mapping
40
+ >>> state = NestedDict({'layer1': {'weight': 1, 'bias': 2}})
41
+ >>> flat = state.to_flat()
42
+ >>> print(flat)
43
+ FlattedDict({('layer1', 'weight'): 1, ('layer1', 'bias'): 2})
44
+ """
45
+
46
+ from collections import abc
47
+ from typing import TypeVar, Hashable, Union, Iterable, Any, Optional, Tuple, Dict, Callable, Generator
48
+
49
+ import jax
50
+
51
+ from brainstate.typing import Filter, PathParts
52
+ from ._pretty_repr import PrettyRepr, PrettyType, PrettyAttr, yield_unique_pretty_repr_items, pretty_repr_object
53
+ from .filter import to_predicate
54
+ from .struct import dataclass
55
+
56
# Public names exported by this module.
__all__ = [
    'PrettyDict',
    'NestedDict',
    'FlattedDict',
    'flat_mapping',
    'nest_mapping',
    'PrettyList',
    'PrettyObject',
]

# Generic type variables used in the annotations below.
A = TypeVar('A')
K = TypeVar('K', bound=Hashable)  # dictionary keys must be hashable
V = TypeVar('V')

# A flattened mapping from path tuples (``PathParts``) to leaf values.
FlattedStateMapping = Dict[PathParts, V]
# Callable mapping a leaf to some derived value (presumably used to unwrap
# leaves into plain values -- TODO confirm intended use).
ExtractValueFn = Callable[[Any], Any]
# Callable taking ``(old_leaf, new_value)`` and returning the updated leaf
# (see ``NestedDict.replace_by_pure_dict``).
SetValueFn = Callable[[V, Any], V]
73
+
74
+
75
class PrettyObject(PrettyRepr):
    """Mixin that renders an object and its attributes as a readable tree.

    Subclasses inherit a ``__pretty_repr__`` implementation built on
    ``yield_unique_pretty_repr_items``, combined with the module's general
    object/attribute representation helpers. Per-item output can be customised
    by overriding :meth:`__pretty_repr_item__`.

    Example::

        class MyTree(PrettyObject):
            def __init__(self, value):
                self.value = value

        print(MyTree(42))  # rendered through ``__pretty_repr__``
    """

    def __pretty_repr__(self) -> Generator[Union[PrettyType, PrettyAttr], None, None]:
        """Yield the pretty-representation items describing this object.

        Yields:
            Union[PrettyType, PrettyAttr]: Items produced by
            ``yield_unique_pretty_repr_items`` using the general object and
            attribute renderers.
        """
        items = yield_unique_pretty_repr_items(
            self,
            repr_object=_repr_object_general,
            repr_attr=_repr_attribute_general,
        )
        yield from items

    def __pretty_repr_item__(self, k: Any, v: Any) -> Optional[Tuple[Any, Any]]:
        """Hook that decides how a single ``(key, value)`` item is displayed.

        The base implementation shows every item unchanged. Subclasses may return
        a transformed pair, or ``None`` to hide the item.

        Args:
            k: The item's key.
            v: The item's value.

        Returns:
            Optional[Tuple[Any, Any]]: The pair to display, or ``None`` to skip it.
        """
        pair = (k, v)
        return pair
127
+
128
+
129
# Alias of :class:`PrettyObject` (presumably kept for code using the older name).
PrettyReprTree = PrettyObject


# The empty node is a struct.dataclass so that it is compatible with JAX.
@dataclass
class _EmptyNode:
    """Sentinel type marking an empty sub-mapping inside a flattened tree."""


# Signature of ``is_leaf`` predicates: ``(path, mapping) -> bool``.
IsLeafCallable = Callable[[Tuple[Any, ...], abc.Mapping[Any, Any]], bool]


def _default_leaf(*_args) -> bool:
    """Default ``is_leaf`` predicate: never treat a mapping as a leaf."""
    return False


# Shared sentinel instance; ``nest_mapping`` compares against it by identity.
empty_node = _EmptyNode()
142
+
143
+
144
def flat_mapping(
    xs: abc.Mapping[Any, Any],
    /,
    *,
    keep_empty_nodes: bool = False,
    is_leaf: Optional[IsLeafCallable] = _default_leaf,
    sep: Optional[str] = None
) -> 'FlattedDict':
    """Flatten a nested mapping into a flat mapping with tuple or string keys.

    The nested keys are flattened to a tuple path. For example, ``{'a': {'b': 1}}``
    becomes ``{('a', 'b'): 1}``. See :func:`nest_mapping` on how to restore the
    nested structure.

    Args:
        xs: A nested mapping to flatten.
        keep_empty_nodes: If True, replaces empty sub-mappings with the ``empty_node``
            sentinel. Otherwise, empty mappings are omitted from the result. An
            empty top-level ``xs`` always produces an empty result.
        is_leaf: Optional function that takes ``(path, mapping)`` and returns True if
            the mapping should be treated as a leaf (i.e., not flattened further).
            The default, and also ``None``, treats every mapping as a non-leaf.
        sep: If specified, keys in the returned mapping will be ``sep``-joined strings
            instead of tuples. For example, with ``sep='/'``, ``('a', 'b')`` becomes
            ``'a/b'``. Keys are joined with :meth:`str.join`, so every key along a
            path must then be a string.

    Returns:
        FlattedDict: A flattened mapping where nested keys are converted to tuples or strings.

    Example:
        >>> xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}}
        >>> flat_xs = flat_mapping(xs)
        >>> flat_xs
        FlattedDict({('foo',): 1, ('bar', 'a'): 2})

        >>> # With separator
        >>> flat_mapping(xs, sep='/')
        FlattedDict({'foo': 1, 'bar/a': 2})

        >>> # Keep empty nodes
        >>> flat_mapping(xs, keep_empty_nodes=True)
        FlattedDict({('foo',): 1, ('bar', 'a'): 2, ('bar', 'b'): _EmptyNode()})

    Note:
        Empty mappings are ignored by default and will not be restored by
        :func:`nest_mapping` unless ``keep_empty_nodes=True``.
    """
    assert isinstance(xs, abc.Mapping), f'expected Mapping; got {type(xs).__qualname__}'

    # ``is_leaf`` is annotated ``Optional``: treat an explicit ``None`` like the
    # default instead of failing later with "'NoneType' object is not callable".
    leaf_test = _default_leaf if is_leaf is None else is_leaf

    if sep is None:
        def _key(path: Tuple[Any, ...]) -> Union[Tuple[Any, ...], str]:
            return path
    else:
        def _key(path: Tuple[Any, ...]) -> Union[Tuple[Any, ...], str]:
            return sep.join(path)

    def _flatten(node: Any, prefix: Tuple[Any, ...]) -> Dict[Any, Any]:
        # Non-mappings, and mappings the caller marks as leaves, end the recursion.
        if not isinstance(node, abc.Mapping) or leaf_test(prefix, node):
            return {_key(prefix): node}

        result = {}
        is_empty = True
        for key, value in node.items():
            is_empty = False
            result.update(_flatten(value, prefix + (key,)))
        if keep_empty_nodes and is_empty:
            if prefix == ():  # when the whole input is empty
                return {}
            # Record the empty sub-mapping so ``nest_mapping`` can restore it.
            return {_key(prefix): empty_node}
        return result

    return FlattedDict(_flatten(xs, ()))
214
+
215
+
216
def nest_mapping(
    xs: Any,
    /,
    *,
    sep: Optional[str] = None
) -> 'NestedDict':
    """Rebuild a nested mapping from a mapping keyed by paths.

    This reverses :func:`flat_mapping`: every key is read as a path, and
    intermediate dictionaries are created as needed along it. Values equal to the
    ``empty_node`` sentinel are restored as fresh empty dictionaries.

    Args:
        xs: A flattened mapping whose keys are tuple paths, or ``sep``-joined strings.
        sep: Separator used to split string keys into paths. It must match the
            separator passed to :func:`flat_mapping()`. When ``None``, keys are used
            as paths directly (normally tuples).

    Returns:
        NestedDict: The reconstructed nested mapping.

    Example:
        >>> flat_xs = {
        ...     ('foo',): 1,
        ...     ('bar', 'a'): 2,
        ... }
        >>> xs = nest_mapping(flat_xs)
        >>> xs
        NestedDict({'foo': 1, 'bar': {'a': 2}})

        >>> # With separator
        >>> flat_xs_str = {'foo': 1, 'bar/a': 2}
        >>> nest_mapping(flat_xs_str, sep='/')
        NestedDict({'foo': 1, 'bar': {'a': 2}})

    See Also:
        :func:`flat_mapping`: The inverse operation that flattens a nested mapping.
    """
    assert isinstance(xs, abc.Mapping), f'expected Mapping; got {type(xs).__qualname__}'
    root: Dict[Any, Any] = {}
    for flat_key, leaf in xs.items():
        parts = flat_key if sep is None else flat_key.split(sep)
        if leaf is empty_node:
            # Each restored empty node gets its own new dict.
            leaf = {}
        node = root
        # Walk or create the parent containers for every part except the last.
        for part in parts[:-1]:
            if part not in node:
                node[part] = {}
            node = node[part]
        node[parts[-1]] = leaf
    return NestedDict(root)
267
+
268
+
269
+ def _default_compare(x: Any, values: set) -> bool:
270
+ """Check if an object's identity is in a set of values.
271
+
272
+ Args:
273
+ x: The object to check.
274
+ values: A set of object identities to compare against.
275
+
276
+ Returns:
277
+ bool: True if the object's id is in the values set.
278
+ """
279
+ return id(x) in values
280
+
281
+
282
+ def _default_process(x: Any) -> int:
283
+ """Get the identity of an object.
284
+
285
+ Args:
286
+ x: The object to process.
287
+
288
+ Returns:
289
+ int: The object's identity (id).
290
+ """
291
+ return id(x)
292
+
293
+
294
class _MissingAttributeError(AttributeError, KeyError):
    """Raised when attribute-style access on a :class:`PrettyDict` finds no key.

    It subclasses :class:`AttributeError` so that ``hasattr``,
    ``getattr(obj, name, default)``, :mod:`copy` and :mod:`pickle` treat a
    missing key as a missing attribute, as the ``__getattr__`` protocol requires.
    It also subclasses :class:`KeyError`, so handlers written for the previous
    behaviour keep working.
    """


class PrettyDict(dict, PrettyRepr):
    """Base dictionary class with pretty representation and tree utilities.

    This class extends the built-in dict with pretty printing capabilities and
    provides base methods for tree operations. It serves as the parent class for
    :class:`NestedDict` and :class:`FlattedDict`.

    Attributes:
        __module__ (str): Module identifier set to 'brainstate.util'.
    """
    __module__ = 'brainstate.util'

    def __getattr__(self, key: K):  # type: ignore[misc]
        """Access dictionary items as attributes.

        Python only calls ``__getattr__`` after normal attribute lookup fails, so
        real attributes and methods always take precedence over keys.

        Args:
            key: The dictionary key to access.

        Returns:
            The value associated with the key.

        Raises:
            AttributeError: If the key is not found in the dictionary. The raised
                exception is also a :class:`KeyError` (with ``args == (key,)``).
        """
        try:
            return self[key]
        except KeyError:
            # A plain KeyError here breaks ``hasattr``, ``copy.deepcopy`` and
            # unpickling, which probe optional attributes such as
            # ``__deepcopy__`` / ``__setstate__`` and expect AttributeError.
            raise _MissingAttributeError(key) from None

    def treefy_state(self) -> Any:
        """Convert :class:`State` objects to a reference tree of the state.

        This method flattens the dictionary with ``jax.tree.flatten`` and replaces
        every leaf that is a :class:`State` with ``leaf.to_state_ref()``. Only
        :class:`State` objects that appear as pytree leaves are converted.

        Returns:
            Any: A tree structure where State objects are replaced with their references.

        Example:
            >>> from brainstate._state import State
            >>> d = PrettyDict({'a': State(1), 'b': 2})
            >>> ref_tree = d.treefy_state()
        """
        from brainstate._state import State
        leaves, treedef = jax.tree.flatten(self)
        leaves = jax.tree.map(lambda x: x.to_state_ref() if isinstance(x, State) else x, leaves)
        return treedef.unflatten(leaves)

    def to_dict(self) -> Dict[K, Union[Dict[K, Any], V]]:
        """Convert the :class:`PrettyDict` to a standard Python dictionary.

        The conversion is shallow: nested values are not converted.

        Returns:
            Dict[K, Union[Dict[K, Any], V]]: A standard dictionary representation.
        """
        return dict(self)  # type: ignore

    def __repr__(self) -> str:
        """Generate a pretty string representation of the dictionary.

        Returns:
            str: A formatted string representation using pretty printing.
        """
        return pretty_repr_object(self)

    def __pretty_repr__(self) -> Generator[Union[PrettyType, PrettyAttr], None, None]:
        """Generate pretty representation items for this dictionary.

        Yields:
            Union[PrettyType, PrettyAttr]: Pretty representation items.
        """
        yield from yield_unique_pretty_repr_items(self, _default_repr_object, _default_repr_attr)

    def split(self, *filters: Filter) -> Union['PrettyDict', Tuple['PrettyDict', ...]]:
        """Split the dictionary based on filters (abstract method).

        This method must be implemented by subclasses.

        Args:
            *filters: Filter specifications to split the dictionary.

        Raises:
            NotImplementedError: This is an abstract method.
        """
        raise NotImplementedError

    def filter(self, *filters: Filter) -> Union['PrettyDict', Tuple['PrettyDict', ...]]:
        """Filter the dictionary based on filters (abstract method).

        This method must be implemented by subclasses.

        Args:
            *filters: Filter specifications to apply.

        Raises:
            NotImplementedError: This is an abstract method.
        """
        raise NotImplementedError

    def merge(self, *states: 'PrettyDict') -> 'PrettyDict':
        """Merge multiple dictionaries (abstract method).

        This method must be implemented by subclasses.

        Args:
            *states: Additional PrettyDict objects to merge.

        Raises:
            NotImplementedError: This is an abstract method.
        """
        raise NotImplementedError

    def subset(self, *filters: Filter) -> Union['PrettyDict', Tuple['PrettyDict', ...]]:
        """Subset a :class:`PrettyDict` into one or more :class:`PrettyDict` instances.

        This is an alias of :meth:`filter`. The user must pass at least one
        :class:`Filter` (e.g., :class:`State`). Because it delegates to
        :meth:`filter`, the filters do not need to be exhaustive. Use :meth:`split`
        when every entry must be covered.

        Args:
            *filters: Filter specifications for subsetting.

        Returns:
            Union[PrettyDict, Tuple[PrettyDict, ...]]: One or more subsetted
            dictionaries.
        """
        return self.filter(*filters)
417
+
418
+
419
class NestedStateRepr(PrettyRepr):
    """Compact, brace-delimited pretty-printing wrapper around a :class:`PrettyDict`.

    It is used when a dictionary appears as a value inside another dictionary.
    The wrapped dict's own header is replaced with a ``{...}`` header, and its
    entries are forwarded unchanged.

    Args:
        state: The PrettyDict to display.

    Attributes:
        state (PrettyDict): The wrapped dictionary.
    """

    def __init__(self, state: PrettyDict) -> None:
        self.state = state

    def __pretty_repr__(self) -> Generator[Union[PrettyType, PrettyAttr], None, None]:
        """Yield a ``{...}`` header followed by the wrapped dict's entries.

        Yields:
            Union[PrettyType, PrettyAttr]: The brace header, then every non-header
            item produced by the wrapped state's ``__pretty_repr__``.
        """
        yield PrettyType('', value_sep=': ', start='{', end='}')
        # Drop the wrapped dict's own PrettyType header(s); keep only its entries.
        yield from (
            item for item in self.state.__pretty_repr__()
            if not isinstance(item, PrettyType)
        )

    def __treescope_repr__(self, path: Any, subtree_renderer: Callable) -> Any:
        """Render the wrapped state for treescope-style debugging tools.

        Nested :class:`PrettyDict` values are wrapped again so that they render
        in the same compact style.

        Args:
            path: The current path in the tree structure.
            subtree_renderer: Callable used to render the children.

        Returns:
            Any: Whatever ``subtree_renderer`` returns for the children mapping.
        """
        children = {
            key: NestedStateRepr(value) if isinstance(value, PrettyDict) else value
            for key, value in self.state.items()
        }
        # Render as the dictionary itself at the same path.
        return subtree_renderer(children, path=path)
466
+
467
+
468
def _default_repr_object(node: PrettyDict) -> Generator[PrettyType, None, None]:
    """Yield the dict-style header used when printing a :class:`PrettyDict`.

    Args:
        node: The dictionary being printed. It is not inspected; the header is
            the same for every node.

    Yields:
        PrettyType: An untitled header with ``': '`` separators inside braces.
    """
    header = PrettyType('', value_sep=': ', start='{', end='}')
    yield header
478
+
479
+
480
def _default_repr_attr(node: PrettyDict) -> Generator[PrettyAttr, None, None]:
    """Yield one :class:`PrettyAttr` per entry of a :class:`PrettyDict`.

    Before display, each value is converted as follows:

    * ``list`` values become :class:`PrettyList`.
    * ``dict`` values, including :class:`PrettyDict` subclasses, are copied into a
      plain :class:`PrettyDict`.
    * :class:`PrettyDict` values are wrapped in :class:`NestedStateRepr` so they
      print compactly.

    Args:
        node: The dictionary whose entries are rendered.

    Yields:
        PrettyAttr: ``PrettyAttr(repr(key), converted_value)`` for each entry.
    """
    for key, value in node.items():
        shown = PrettyList(value) if isinstance(value, list) else value
        if isinstance(shown, dict):
            shown = PrettyDict(shown)
        if isinstance(shown, PrettyDict):
            shown = NestedStateRepr(shown)
        yield PrettyAttr(repr(key), shown)
503
+
504
+
505
class NestedDict(PrettyDict):
    """A pytree-like nested mapping structure for organizing hierarchical data.

    This class represents a nested mapping from strings or integers to leaves, where
    valid leaf types include :class:`State`, ``jax.Array``, ``numpy.ndarray``, or
    nested :class:`NestedDict` and :class:`FlattedDict` structures.

    :class:`NestedDict` is a JAX pytree and can be used with JAX transformations.
    It provides methods for flattening to :class:`FlattedDict`, splitting/filtering
    based on predicates, and merging multiple nested structures.

    Attributes:
        __module__ (str): Module identifier set to 'brainstate.util'.

    Example:
        >>> import jax.numpy as jnp
        >>> from brainstate.util import NestedDict
        >>> state = NestedDict({
        ...     'layer1': {'weight': jnp.ones((3, 3)), 'bias': jnp.zeros(3)},
        ...     'layer2': {'weight': jnp.ones((3, 1))}
        ... })
        >>> flat = state.to_flat()
        >>> print(flat)
        FlattedDict({('layer1', 'weight'): ..., ('layer1', 'bias'): ..., ...})

    See Also:
        :class:`FlattedDict`: The flattened counterpart with tuple keys.
        :func:`flat_mapping`: Function to flatten a nested mapping.
        :func:`nest_mapping`: Function to unflatten a flat mapping.
    """
    __module__ = 'brainstate.util'

    def __or__(self, other: 'NestedDict') -> 'NestedDict':
        """Return a new :class:`NestedDict` with the entries of ``self`` and ``other``.

        Entries are merged path by path, and on conflicting paths ``other`` wins.
        If ``other`` is empty, ``self`` is returned unchanged (not copied).
        """
        if not other:
            return self
        assert isinstance(other, NestedDict), f'expected NestedDict; got {type(other).__qualname__}'
        return NestedDict.merge(self, other)

    def __sub__(self, other: 'NestedDict') -> 'NestedDict':
        """Return the entries of ``self`` whose flattened paths are absent from ``other``.

        If ``other`` is empty, ``self`` is returned unchanged (not copied).
        """
        if not other:
            return self

        assert isinstance(other, NestedDict), f'expected NestedDict; got {type(other).__qualname__}'
        self_flat = self.to_flat()
        other_flat = other.to_flat()
        diff = {k: v for k, v in self_flat.items() if k not in other_flat}
        return NestedDict.from_flat(diff)

    def to_flat(self) -> 'FlattedDict':
        """
        Flatten the nested mapping into a flat mapping.

        Empty sub-mappings are omitted (see :func:`flat_mapping`).

        Returns:
            The flattened mapping.
        """
        return flat_mapping(self)

    @classmethod
    def from_flat(
        cls,
        flat_dict: Union[abc.Mapping[PathParts, V], Iterable[Tuple[PathParts, V]]]
    ) -> 'NestedDict':
        """
        Create a :class:`NestedDict` from a flat mapping.

        Args:
            flat_dict: The flat mapping, or an iterable of ``(path, value)`` pairs.

        Returns:
            The :class:`NestedDict`.
        """
        nested_state = nest_mapping(dict(flat_dict))
        return cls(nested_state)

    def split(self, *filters: Filter) -> Union['NestedDict', Tuple['NestedDict', ...]]:
        """
        Split a :class:`NestedDict` into one or more :class:`NestedDict`'s.

        The user must pass at least one :class:`Filter` (e.g. :class:`State`), and
        the filters must be exhaustive (i.e. they must cover all :class:`State`
        types in the :class:`NestedDict`). Pass ``...`` as the last filter to
        collect every remaining entry.

        Example usage::

            >>> import brainstate

            >>> class Model(brainstate.nn.Module):
            ...     def __init__(self):
            ...         super().__init__()
            ...         self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
            ...         self.linear = brainstate.nn.Linear(2, 3)

            >>> model = Model()
            >>> state_map = brainstate.graph.treefy_states(model)
            >>> param, others = state_map.split(brainstate.ParamState, ...)

        Args:
            *filters: One or more filters that group the state into mutually
                exclusive substates.

        Returns:
            A single :class:`NestedDict` when exactly one group is produced,
            otherwise a tuple with one :class:`NestedDict` per filter.

        Raises:
            ValueError: If some entries are matched by none of the filters.
        """
        *states_, rest = _split_nested_mapping(self, *filters)
        if rest:
            raise ValueError(f'Non-exhaustive filters, got a non-empty remainder: {rest}.\n'
                             f'Use `...` to match all remaining elements.')

        states: Union[NestedDict, Tuple[NestedDict, ...]]
        if len(states_) == 1:
            states = states_[0]
        else:
            states = tuple(states_)
        return states  # type: ignore[bad-return-type]

    def filter(self, *filters: Filter) -> Union['NestedDict', Tuple['NestedDict', ...]]:
        """
        Filter a :class:`NestedDict` into one or more :class:`NestedDict`'s.

        The user must pass at least one :class:`Filter` (e.g. :class:`State`).
        This method is similar to :meth:`split`, except the filters can be
        non-exhaustive: entries that match none of the filters are discarded
        instead of raising an error.

        Args:
            *filters: One or more filters that group the state into mutually
                exclusive substates.

        Returns:
            A single :class:`NestedDict` when one filter is passed, otherwise a
            tuple with one :class:`NestedDict` per filter.
        """
        states_ = _split_nested_mapping(self, *filters)
        assert len(states_) == len(filters) + 1, f'Expected {len(filters) + 1} states, got {len(states_)}'
        if len(filters) == 1:
            return states_[0]
        else:
            # The last element is the unmatched remainder, which is dropped.
            return tuple(states_[:-1])

    @staticmethod
    def merge(*states) -> 'NestedDict':
        """
        The inverse of :meth:`split`.

        ``merge`` takes one or more :class:`NestedDict` / :class:`FlattedDict`
        objects and combines them path by path into a new :class:`NestedDict`.
        When paths collide, later arguments override earlier ones.

        Note that this is a static method: call it as
        ``NestedDict.merge(a, b, ...)``.

        Args:
            *states: The :class:`NestedDict` or :class:`FlattedDict` objects to merge.

        Returns:
            The merged :class:`NestedDict`.

        Raises:
            TypeError: If an argument is neither a :class:`NestedDict` nor a
                :class:`FlattedDict`.
        """
        new_state: FlattedDict = FlattedDict()
        for state in states:
            if isinstance(state, NestedDict):
                new_state.update(state.to_flat())  # type: ignore[attribute-error] # pytype is wrong here
            elif isinstance(state, FlattedDict):
                new_state.update(state)
            else:
                raise TypeError(f'Expected Nested or Flatted Mapping, got {type(state)} instead.')
        return NestedDict.from_flat(new_state)

    def to_pure_dict(self, extract_fn: Optional[ExtractValueFn] = None) -> Dict[str, Any]:
        """Convert to a pure nested dictionary structure.

        This method creates standard Python dictionaries with the same nested
        structure as this NestedDict, without any special class wrappers on the
        containers. By default, leaves are kept as they are (for example,
        :class:`State` objects are not unwrapped). Pass ``extract_fn`` to transform
        them. Empty sub-mappings are dropped, because the conversion goes through
        :meth:`to_flat`.

        Args:
            extract_fn: Optional callable applied to every leaf, e.g.
                ``lambda s: s.value`` to unwrap state objects. ``None`` (the
                default) keeps leaves unchanged.

        Returns:
            Dict[str, Any]: A pure nested dictionary representation.

        Example:
            >>> nested = NestedDict({'a': {'b': 1, 'c': 2}})
            >>> pure = nested.to_pure_dict()
            >>> type(pure)
            <class 'dict'>
        """
        if extract_fn is None:
            extract_fn = lambda leaf: leaf
        flat_values = {k: extract_fn(x) for k, x in self.to_flat().items()}
        return nest_mapping(flat_values).to_dict()

    def replace_by_pure_dict(
        self,
        pure_dict: Dict[str, Any],
        replace_fn: Optional[SetValueFn] = None
    ) -> None:
        """Replace values in this NestedDict using a pure dictionary.

        This method updates the values in this NestedDict with values from a
        standard Python dictionary. For leaves with a ``replace`` method (such as
        :class:`State` objects), ``replace(new_value)`` is called. Other leaves,
        including strings and bytes, are assigned directly. The top-level entries
        are then updated in place from the rebuilt nested structure.

        Args:
            pure_dict: A pure dictionary with matching structure containing new values.
            replace_fn: Optional custom function to replace values. Takes
                ``(old_value, new_value)`` and returns the updated value.

        Raises:
            ValueError: If a key in ``pure_dict`` is not found in this NestedDict.

        Example:
            >>> from brainstate._state import State
            >>> nested = NestedDict({'a': State(1), 'b': 2})
            >>> nested.replace_by_pure_dict({'a': 10, 'b': 20})
            >>> nested['a'].value
            10
        """
        if replace_fn is None:
            def _default_replace(old: Any, new: Any) -> Any:
                # ``str``/``bytes`` expose an unrelated two-argument ``replace``;
                # calling it with one argument would raise, so assign directly.
                if hasattr(old, 'replace') and not isinstance(old, (str, bytes)):
                    return old.replace(new)
                return new

            replace_fn = _default_replace
        current_flat = self.to_flat()
        for kp, v in flat_mapping(pure_dict).items():
            if kp not in current_flat:
                raise ValueError(f'key in pure_dict not available in state: {kp}')
            current_flat[kp] = replace_fn(current_flat[kp], v)
        self.update(nest_mapping(current_flat))
714
+
715
+
716
+ class FlattedDict(PrettyDict):
717
+ """
718
+ A pytree-like structure that contains a :class:`Mapping` from strings or integers to leaves.
719
+
720
+ A valid leaf type is either :class:`State`, ``jax.Array``, ``numpy.ndarray`` or Python variables.
721
+
722
+ A :class:`NestedDict` can be generated by either calling :func:`states()` or
723
+ :func:`nodes()` on the :class:`Module`.
724
+
725
+ Example usage::
726
+
727
+ >>> import brainstate as brainstate
728
+ >>> import jax.numpy as jnp
729
+ >>>
730
+ >>> class Model(brainstate.nn.Module):
731
+ ... def __init__(self):
732
+ ... super().__init__()
733
+ ... self.batchnorm = brainstate.nn.BatchNorm1d([10, 3])
734
+ ... self.linear = brainstate.nn.Linear(2, 3)
735
+ ... def __call__(self, x):
736
+ ... return self.linear(self.batchnorm(x))
737
+ >>>
738
+ >>> model = Model()
739
+
740
+ >>> # retrieve the states of the model
741
+ >>> model.states() # with the same to the function of ``brainstate.graph.states()``
742
+ FlattedDict({
743
+ ('batchnorm', 'running_mean'): LongTermState(
744
+ value=Array([[0., 0., 0.]], dtype=float32)
745
+ ),
746
+ ('batchnorm', 'running_var'): LongTermState(
747
+ value=Array([[1., 1., 1.]], dtype=float32)
748
+ ),
749
+ ('batchnorm', 'weight'): ParamState(
750
+ value={'bias': Array([[0., 0., 0.]], dtype=float32), 'scale': Array([[1., 1., 1.]], dtype=float32)}
751
+ ),
752
+ ('linear', 'weight'): ParamState(
753
+ value={'weight': Array([[-0.21467684, 0.7621282 , -0.50756454, -0.49047297],
754
+ [-0.90413696, 0.6711 , -0.1254792 , 0.50412565],
755
+ [ 0.23975602, 0.47905368, 1.4851435 , 0.16745673]], dtype=float32), 'bias': Array([0., 0., 0., 0.], dtype=float32)}
756
+ )
757
+ })
758
+
759
+ >>> # retrieve the nodes of the model
760
+ >>> model.nodes() # with the same to the function of ``brainstate.graph.nodes()``
761
+ FlattedDict({
762
+ ('batchnorm',): BatchNorm1d(
763
+ in_size=(10, 3),
764
+ out_size=(10, 3),
765
+ affine=True,
766
+ bias_initializer=Constant(value=0.0, dtype=<class 'numpy.float32'>),
767
+ scale_initializer=Constant(value=1.0, dtype=<class 'numpy.float32'>),
768
+ dtype=<class 'numpy.float32'>,
769
+ track_running_stats=True,
770
+ momentum=Array(shape=(), dtype=float32),
771
+ epsilon=Array(shape=(), dtype=float32),
772
+ feature_axis=(1,),
773
+ axis_name=None,
774
+ axis_index_groups=None,
775
+ running_mean=LongTermState(
776
+ value=Array(shape=(1, 3), dtype=float32)
777
+ ),
778
+ running_var=LongTermState(
779
+ value=Array(shape=(1, 3), dtype=float32)
780
+ ),
781
+ weight=ParamState(
782
+ value={'bias': Array(shape=(1, 3), dtype=float32), 'scale': Array(shape=(1, 3), dtype=float32)}
783
+ )
784
+ ),
785
+ ('linear',): Linear(
786
+ in_size=(10, 3),
787
+ out_size=(10, 4),
788
+ w_mask=None,
789
+ weight=ParamState(
790
+ value={'bias': Array(shape=(4,), dtype=float32), 'weight': Array(shape=(3, 4), dtype=float32)}
791
+ )
792
+ ),
793
+ (): Model(
794
+ batchnorm=BatchNorm1d(...),
795
+ linear=Linear(...)
796
+ )
797
+ })
798
+ """
799
+ __module__ = 'brainstate.util'
800
+
801
+ def __or__(self, other: 'FlattedDict') -> 'FlattedDict':
802
+ if not other:
803
+ return self
804
+ assert isinstance(other, FlattedDict), f'expected NestedDict; got {type(other).__qualname__}'
805
+ return FlattedDict.merge(self, other)
806
+
807
+ def __sub__(self, other: 'FlattedDict') -> 'FlattedDict':
808
+ if not other:
809
+ return self
810
+ assert isinstance(other, FlattedDict), f'expected NestedDict; got {type(other).__qualname__}'
811
+ diff = {k: v for k, v in self.items() if k not in other}
812
+ return FlattedDict(diff)
813
+
814
+ def to_nest(self) -> NestedDict:
815
+ """
816
+ Unflatten the flat mapping into a nested mapping.
817
+
818
+ Returns:
819
+ The nested mapping.
820
+ """
821
+ return nest_mapping(self)
822
+
823
+ @classmethod
824
+ def from_nest(
825
+ cls, nested_dict: abc.Mapping[PathParts, V] | Iterable[tuple[PathParts, V]],
826
+ ) -> 'FlattedDict':
827
+ """
828
+ Create a :class:`NestedDict` from a flat mapping.
829
+
830
+ Args:
831
+ nested_dict: The flat mapping.
832
+
833
+ Returns:
834
+ The :class:`NestedDict`.
835
+ """
836
+ return flat_mapping(nested_dict)
837
+
838
+ def split( # type: ignore[misc]
839
+ self,
840
+ first: Filter,
841
+ /,
842
+ *filters: Filter
843
+ ) -> Union['FlattedDict', tuple['FlattedDict', ...]]:
844
+ """
845
+ Split a :class:`FlattedDict` into one or more :class:`FlattedDict`'s. The
846
+ user must pass at least one `:class:`Filter` (i.e. :class:`State`),
847
+ and the filters must be exhaustive (i.e. they must cover all
848
+ :class:`State` types in the :class:`NestedDict`).
849
+
850
+ Arguments:
851
+ first: The first filter
852
+ *filters: The optional, additional filters to group the state into mutually exclusive substates.
853
+
854
+ Returns:
855
+ One or more ``States`` equal to the number of filters passed.
856
+ """
857
+ filters = (first, *filters)
858
+ *states_, rest = _split_flatted_mapping(self, *filters)
859
+ if rest:
860
+ raise ValueError(f'Non-exhaustive filters, got a non-empty remainder: {rest}.\n'
861
+ f'Use `...` to match all remaining elements.')
862
+
863
+ states: FlattedDict | Tuple[FlattedDict, ...]
864
+ if len(states_) == 1:
865
+ states = states_[0]
866
+ else:
867
+ states = tuple(states_)
868
+ return states # type: ignore[bad-return-type]
869
+
870
+ def filter(
871
+ self,
872
+ first: Filter,
873
+ /,
874
+ *filters: Filter,
875
+ ) -> Union['FlattedDict', Tuple['FlattedDict', ...]]:
876
+ """
877
+ Filter a :class:`FlattedDict` into one or more :class:`FlattedDict`'s. The
878
+ user must pass at least one `:class:`Filter` (i.e. :class:`State`).
879
+ This method is similar to :meth:`split() <flax.nnx.NestedDict.state.split>`,
880
+ except the filters can be non-exhaustive.
881
+
882
+ Arguments:
883
+ first: The first filter
884
+ *filters: The optional, additional filters to group the state into mutually exclusive substates.
885
+
886
+ Returns:
887
+ One or more ``States`` equal to the number of filters passed.
888
+ """
889
+ *states_, _rest = _split_flatted_mapping(self, first, *filters)
890
+ assert len(states_) == len(filters) + 1, f'Expected {len(filters) + 1} states, got {len(states_)}'
891
+ if len(states_) == 1:
892
+ states = states_[0]
893
+ else:
894
+ states = tuple(states_)
895
+ return states # type: ignore[bad-return-type]
896
+
897
+ @staticmethod
898
+ def merge(*states: Union['FlattedDict', 'NestedDict']) -> 'FlattedDict':
899
+ """
900
+ The inverse of :meth:`split()`.
901
+
902
+ ``merge`` takes one or more :class:`FlattedDict`'s and creates a new :class:`FlattedDict`.
903
+
904
+ Args:
905
+ state: A :class:`PrettyDict` object.
906
+ *states: Additional :class:`PrettyDict` objects.
907
+
908
+ Returns:
909
+ The merged :class:`PrettyDict`.
910
+ """
911
+ new_state: FlattedStateMapping[V] = {}
912
+ for state in states:
913
+ if isinstance(state, NestedDict):
914
+ new_state.update(state.to_flat()) # type: ignore[attribute-error] # pytype is wrong here
915
+ elif isinstance(state, FlattedDict):
916
+ new_state.update(state)
917
+ else:
918
+ raise TypeError(f'Expected Nested or Flatted Mapping, got {type(state)} instead.')
919
+ return FlattedDict(new_state)
920
+
921
+ def to_dict_values(self) -> Dict[PathParts, Any]:
922
+ """Convert a FlattedDict containing State objects to a dictionary of raw values.
923
+
924
+ This method extracts the underlying values from any :class:`State` objects in the
925
+ FlattedDict, creating a new dictionary with the same keys but where each State
926
+ object is replaced by its ``value`` attribute. Non-State objects are kept as is.
927
+
928
+ Returns:
929
+ Dict[PathParts, Any]: A dictionary with the same keys as the FlattedDict, but
930
+ where each State object is replaced by its value. Non-State objects remain
931
+ unchanged.
932
+
933
+ Example:
934
+ >>> from brainstate._state import ParamState
935
+ >>> flat_dict = FlattedDict({
936
+ ... ('model', 'layer1', 'weight'): ParamState(value=jnp.ones((10, 5)))
937
+ ... })
938
+ >>> values = flat_dict.to_dict_values()
939
+ >>> values[('model', 'layer1', 'weight')]
940
+ Array([[1., 1., ...]], dtype=float32)
941
+ """
942
+ from brainstate._state import State
943
+ return {
944
+ k: v.value if isinstance(v, State) else v
945
+ for k, v in self.items()
946
+ }
947
+
948
+ def assign_dict_values(self, data: Dict[PathParts, Any]) -> None:
949
+ """Assign values from a dictionary to this FlattedDict.
950
+
951
+ This method updates the values in the FlattedDict with values from the provided
952
+ dictionary. For keys that correspond to :class:`State` objects, the ``value``
953
+ attribute of the State is updated. For other keys, the value in the FlattedDict
954
+ is directly replaced with the new value.
955
+
956
+ Args:
957
+ data: A dictionary containing the values to assign, where keys must match
958
+ those in the FlattedDict.
959
+
960
+ Raises:
961
+ KeyError: If a key in the FlattedDict is not present in the provided dictionary.
962
+
963
+ Example:
964
+ >>> from brainstate._state import ParamState
965
+ >>> flat_dict = FlattedDict({
966
+ ... ('model', 'weight'): ParamState(value=jnp.zeros((5, 5)))
967
+ ... })
968
+ >>> flat_dict.assign_dict_values({('model', 'weight'): jnp.ones((5, 5))})
969
+ # The ParamState's value is now an array of ones
970
+ """
971
+ from brainstate._state import State
972
+ for k in self.keys():
973
+ if k not in data:
974
+ raise KeyError(f'Invalid key: {k!r}')
975
+ val = self[k]
976
+ if isinstance(val, State):
977
+ val.value = data[k]
978
+ else:
979
+ self[k] = data[k]
980
+
981
+
982
+ def _split_nested_mapping(
983
+ mapping: 'NestedDict',
984
+ *filters: Filter,
985
+ ) -> Tuple['NestedDict', ...]:
986
+ """Split a nested mapping into multiple nested mappings based on filters.
987
+
988
+ This internal function partitions a NestedDict into multiple NestedDicts based on
989
+ filter predicates. Items that match each filter are placed in separate mappings,
990
+ with unmatched items going to the final mapping.
991
+
992
+ Args:
993
+ mapping: The NestedDict to split.
994
+ *filters: Filter specifications. The catch-all filters ``...`` or ``True`` can
995
+ only be used as the last filter.
996
+
997
+ Returns:
998
+ Tuple[NestedDict, ...]: A tuple of n+1 NestedDicts, where n is the number
999
+ of filters. The last mapping contains items that didn't match any filter.
1000
+
1001
+ Raises:
1002
+ ValueError: If ``...`` or ``True`` is used before the last filter position.
1003
+ AssertionError: If mapping is not a NestedDict.
1004
+ """
1005
+ # Check if the filters are exhaustive
1006
+ for i, filter_ in enumerate(filters):
1007
+ if filter_ in (..., True) and i != len(filters) - 1:
1008
+ remaining_filters = filters[i + 1:]
1009
+ if not all(f in (..., True) for f in remaining_filters):
1010
+ raise ValueError('`...` or `True` can only be used as the last filters, '
1011
+ f'got {filter_} it at index {i}.')
1012
+
1013
+ # Change the filters to predicates
1014
+ predicates = tuple(map(to_predicate, filters))
1015
+
1016
+ # We have n + 1 state mappings, where n is the number of predicates
1017
+ # The last state mapping is for values that don't match any predicate
1018
+ flat_states: Tuple[FlattedStateMapping[V], ...] = tuple({} for _ in range(len(predicates) + 1))
1019
+
1020
+ assert isinstance(mapping, NestedDict), f'expected NestedDict; got {type(mapping).__qualname__}'
1021
+ flat_state = mapping.to_flat()
1022
+ for path, value in flat_state.items():
1023
+ for i, predicate in enumerate(predicates):
1024
+ if predicate(path, value):
1025
+ flat_states[i][path] = value # type: ignore[index]
1026
+ break
1027
+ else:
1028
+ # If we didn't break, set leaf to last state
1029
+ flat_states[-1][path] = value # type: ignore[index]
1030
+
1031
+ return tuple(NestedDict.from_flat(flat_state) for flat_state in flat_states)
1032
+
1033
+
1034
+ def _split_flatted_mapping(
1035
+ mapping: FlattedDict,
1036
+ *filters: Filter,
1037
+ ) -> Tuple[FlattedDict, ...]:
1038
+ """Split a flattened mapping into multiple flattened mappings based on filters.
1039
+
1040
+ This internal function partitions a FlattedDict into multiple FlattedDicts based on
1041
+ filter predicates. Items that match each filter are placed in separate mappings,
1042
+ with unmatched items going to the final mapping.
1043
+
1044
+ Args:
1045
+ mapping: The FlattedDict to split.
1046
+ *filters: Filter specifications. The catch-all filters ``...`` or ``True`` can
1047
+ only be used as the last filter.
1048
+
1049
+ Returns:
1050
+ Tuple[FlattedDict, ...]: A tuple of n+1 FlattedDicts, where n is the number
1051
+ of filters. The last mapping contains items that didn't match any filter.
1052
+
1053
+ Raises:
1054
+ ValueError: If ``...`` or ``True`` is used before the last filter position.
1055
+ AssertionError: If mapping is not a FlattedDict.
1056
+ """
1057
+ # Check if the filters are exhaustive
1058
+ for i, filter_ in enumerate(filters):
1059
+ if filter_ in (..., True) and i != len(filters) - 1:
1060
+ remaining_filters = filters[i + 1:]
1061
+ if not all(f in (..., True) for f in remaining_filters):
1062
+ raise ValueError('`...` or `True` can only be used as the last filters, '
1063
+ f'got {filter_} it at index {i}.')
1064
+
1065
+ # Change the filters to predicates
1066
+ predicates = tuple(map(to_predicate, filters))
1067
+
1068
+ # We have n + 1 state mappings, where n is the number of predicates
1069
+ # The last state mapping is for values that don't match any predicate
1070
+ flat_states: Tuple[FlattedStateMapping[V], ...] = tuple({} for _ in range(len(predicates) + 1))
1071
+
1072
+ assert isinstance(mapping, FlattedDict), f'expected FlattedDict; got {type(mapping).__qualname__}'
1073
+ for path, value in mapping.items():
1074
+ for i, predicate in enumerate(predicates):
1075
+ if predicate(path, value):
1076
+ flat_states[i][path] = value # type: ignore[index]
1077
+ break
1078
+ else:
1079
+ # If we didn't break, set leaf to last state
1080
+ flat_states[-1][path] = value # type: ignore[index]
1081
+
1082
+ return tuple(FlattedDict(flat_state) for flat_state in flat_states)
1083
+
1084
+
1085
+ # Register :class:`NestedDict` as a pytree
1086
+ def _nest_flatten_with_keys(x: NestedDict) -> Tuple[Tuple[Tuple[jax.tree_util.DictKey, Any], ...], Tuple[K, ...]]:
1087
+ """Flatten a NestedDict for JAX pytree registration with keys.
1088
+
1089
+ Args:
1090
+ x: The NestedDict to flatten.
1091
+
1092
+ Returns:
1093
+ Tuple containing:
1094
+
1095
+ - Tuple of (key, value) pairs where keys are wrapped in DictKey
1096
+ - Tuple of static keys for reconstruction
1097
+ """
1098
+ items = sorted(x.items())
1099
+ children = tuple((jax.tree_util.DictKey(key), value) for key, value in items)
1100
+ return children, tuple(key for key, _ in items)
1101
+
1102
+
1103
+ def _nest_unflatten(
1104
+ static: Tuple[K, ...],
1105
+ leaves: Union[Tuple[V, ...], Tuple[Dict]],
1106
+ ) -> NestedDict:
1107
+ """Unflatten a NestedDict from pytree components.
1108
+
1109
+ Args:
1110
+ static: Tuple of keys for reconstruction.
1111
+ leaves: Tuple of leaf values.
1112
+
1113
+ Returns:
1114
+ NestedDict: Reconstructed NestedDict.
1115
+ """
1116
+ return NestedDict(zip(static, leaves))
1117
+
1118
+
1119
+ jax.tree_util.register_pytree_with_keys(
1120
+ NestedDict,
1121
+ _nest_flatten_with_keys,
1122
+ _nest_unflatten
1123
+ ) # type: ignore[arg-type]
1124
+
1125
+
1126
+ # Register :class:`FlattedDict` as a pytree
1127
+
1128
+ def _flat_unflatten(
1129
+ static: Tuple[K, ...],
1130
+ leaves: Union[Tuple[V, ...], Tuple[Dict]],
1131
+ ) -> FlattedDict:
1132
+ """Unflatten a FlattedDict from pytree components.
1133
+
1134
+ Args:
1135
+ static: Tuple of keys for reconstruction.
1136
+ leaves: Tuple of leaf values.
1137
+
1138
+ Returns:
1139
+ FlattedDict: Reconstructed FlattedDict.
1140
+ """
1141
+ return FlattedDict(zip(static, leaves))
1142
+
1143
+
1144
+ jax.tree_util.register_pytree_with_keys(
1145
+ FlattedDict,
1146
+ _nest_flatten_with_keys,
1147
+ _flat_unflatten
1148
+ ) # type: ignore[arg-type]
1149
+
1150
+
1151
+ @jax.tree_util.register_pytree_node_class
1152
+ class PrettyList(list, PrettyRepr):
1153
+ """A list subclass with pretty representation and JAX pytree compatibility.
1154
+
1155
+ This class extends the built-in list with pretty printing capabilities and
1156
+ registers itself as a JAX pytree for use in JAX transformations.
1157
+
1158
+ Attributes:
1159
+ __module__ (str): Module identifier set to 'brainstate.util'.
1160
+
1161
+ Example:
1162
+ >>> from brainstate.util import PrettyList
1163
+ >>> lst = PrettyList([1, 2, {'a': 3}])
1164
+ >>> print(lst) # Pretty formatted output
1165
+ [1, 2, {'a': 3}]
1166
+ """
1167
+ __module__ = 'brainstate.util'
1168
+
1169
+ def __pretty_repr__(self) -> Generator[Union[PrettyType, PrettyAttr], None, None]:
1170
+ """Generate pretty representation items for this list.
1171
+
1172
+ Yields:
1173
+ Union[PrettyType, PrettyAttr]: Pretty representation items.
1174
+ """
1175
+ yield from yield_unique_pretty_repr_items(self, _list_repr_object, _list_repr_attr)
1176
+
1177
+ def __repr__(self) -> str:
1178
+ """Generate a pretty string representation of the list.
1179
+
1180
+ Returns:
1181
+ str: A formatted string representation using pretty printing.
1182
+ """
1183
+ return pretty_repr_object(self)
1184
+
1185
+ def tree_flatten(self) -> Tuple[list, Tuple]:
1186
+ """Flatten the list for JAX pytree operations.
1187
+
1188
+ Returns:
1189
+ Tuple containing:
1190
+ - The list items as children
1191
+ - Empty tuple as auxiliary data
1192
+ """
1193
+ return list(self), ()
1194
+
1195
+ @classmethod
1196
+ def tree_unflatten(cls, aux_data: Tuple, children: list) -> 'PrettyList':
1197
+ """Reconstruct a PrettyList from pytree components.
1198
+
1199
+ Args:
1200
+ aux_data: Auxiliary data (unused).
1201
+ children: List items to reconstruct from.
1202
+
1203
+ Returns:
1204
+ PrettyList: Reconstructed PrettyList.
1205
+ """
1206
+ return cls(children)
1207
+
1208
+
1209
+ def _list_repr_attr(node: PrettyList) -> Generator[PrettyAttr, None, None]:
1210
+ """Generate attribute representations for PrettyList items.
1211
+
1212
+ This function converts list and dict values to their pretty equivalents
1213
+ and wraps PrettyDict values in NestedStateRepr for compact display.
1214
+
1215
+ Args:
1216
+ node: The PrettyList whose items to represent.
1217
+
1218
+ Yields:
1219
+ PrettyAttr: Pretty attribute representations for each item with empty keys.
1220
+ """
1221
+ for v in node:
1222
+ if isinstance(v, list):
1223
+ v = PrettyList(v)
1224
+ if isinstance(v, dict):
1225
+ v = PrettyDict(v)
1226
+ if isinstance(v, PrettyDict):
1227
+ v = NestedStateRepr(v)
1228
+ yield PrettyAttr('', v)
1229
+
1230
+
1231
+ def _list_repr_object(node: PrettyList) -> Generator[PrettyType, None, None]:
1232
+ """Generate the object representation for PrettyList.
1233
+
1234
+ Args:
1235
+ node: The PrettyList to represent.
1236
+
1237
+ Yields:
1238
+ PrettyType: A type representation with list-like formatting.
1239
+ """
1240
+ yield PrettyType('', value_sep='', start='[', end=']')
1241
+
1242
+
1243
+ def _repr_object_general(node: Any) -> Generator[PrettyType, None, None]:
1244
+ """Generate a general object representation for any PrettyObject.
1245
+
1246
+ This function creates a pretty representation of an object that includes
1247
+ the type of the object with parentheses formatting.
1248
+
1249
+ Args:
1250
+ node: The object to be represented.
1251
+
1252
+ Yields:
1253
+ PrettyType: A PrettyType object representing the type of the node,
1254
+ with specified value separator, start, and end characters.
1255
+ """
1256
+ yield PrettyType(type(node), value_sep='=', start='(', end=')')
1257
+
1258
+
1259
+ def _repr_attribute_general(node: Any) -> Generator[PrettyAttr, None, None]:
1260
+ """Generate a pretty representation of the attributes of a general object.
1261
+
1262
+ This function iterates over the attributes of a given node and generates
1263
+ a pretty representation for each attribute. It handles conversion of lists
1264
+ and dictionaries to their pretty representation counterparts and yields a
1265
+ PrettyAttr object for each visible attribute.
1266
+
1267
+ The function respects the ``__pretty_repr_item__`` method if available on
1268
+ the node, allowing custom filtering or transformation of attribute items.
1269
+
1270
+ Args:
1271
+ node: The object whose attributes are to be represented.
1272
+
1273
+ Yields:
1274
+ PrettyAttr: A PrettyAttr object representing the key and value of
1275
+ each attribute in a pretty format.
1276
+ """
1277
+ for k, v in vars(node).items():
1278
+ try:
1279
+ res = node.__pretty_repr_item__(k, v)
1280
+ if res is None:
1281
+ continue
1282
+ k, v = res
1283
+ except AttributeError:
1284
+ pass
1285
+
1286
+ if k is None:
1287
+ continue
1288
+
1289
+ # Convert list to PrettyList
1290
+ if isinstance(v, list):
1291
+ v = PrettyList(v)
1292
+
1293
+ # Convert dict to PrettyDict
1294
+ if isinstance(v, dict):
1295
+ v = PrettyDict(v)
1296
+
1297
+ # Convert PrettyDict to NestedStateRepr
1298
+ if isinstance(v, PrettyDict):
1299
+ v = NestedStateRepr(v)
1300
+
1301
+ yield PrettyAttr(k, v)