brainstate 0.1.7__py2.py3-none-any.whl → 0.1.9__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133) hide show
  1. brainstate/__init__.py +58 -51
  2. brainstate/_compatible_import.py +148 -148
  3. brainstate/_state.py +1605 -1663
  4. brainstate/_state_test.py +52 -52
  5. brainstate/_utils.py +47 -47
  6. brainstate/augment/__init__.py +30 -30
  7. brainstate/augment/_autograd.py +778 -778
  8. brainstate/augment/_autograd_test.py +1289 -1289
  9. brainstate/augment/_eval_shape.py +99 -99
  10. brainstate/augment/_eval_shape_test.py +38 -38
  11. brainstate/augment/_mapping.py +1060 -1060
  12. brainstate/augment/_mapping_test.py +597 -597
  13. brainstate/augment/_random.py +151 -151
  14. brainstate/compile/__init__.py +38 -38
  15. brainstate/compile/_ad_checkpoint.py +204 -204
  16. brainstate/compile/_ad_checkpoint_test.py +49 -49
  17. brainstate/compile/_conditions.py +256 -256
  18. brainstate/compile/_conditions_test.py +220 -220
  19. brainstate/compile/_error_if.py +92 -92
  20. brainstate/compile/_error_if_test.py +52 -52
  21. brainstate/compile/_jit.py +346 -346
  22. brainstate/compile/_jit_test.py +143 -143
  23. brainstate/compile/_loop_collect_return.py +536 -536
  24. brainstate/compile/_loop_collect_return_test.py +58 -58
  25. brainstate/compile/_loop_no_collection.py +184 -184
  26. brainstate/compile/_loop_no_collection_test.py +50 -50
  27. brainstate/compile/_make_jaxpr.py +888 -888
  28. brainstate/compile/_make_jaxpr_test.py +156 -146
  29. brainstate/compile/_progress_bar.py +202 -202
  30. brainstate/compile/_unvmap.py +159 -159
  31. brainstate/compile/_util.py +147 -147
  32. brainstate/environ.py +563 -563
  33. brainstate/environ_test.py +62 -62
  34. brainstate/functional/__init__.py +27 -26
  35. brainstate/graph/__init__.py +29 -29
  36. brainstate/graph/_graph_node.py +244 -244
  37. brainstate/graph/_graph_node_test.py +73 -73
  38. brainstate/graph/_graph_operation.py +1738 -1738
  39. brainstate/graph/_graph_operation_test.py +563 -563
  40. brainstate/init/__init__.py +26 -26
  41. brainstate/init/_base.py +52 -52
  42. brainstate/init/_generic.py +244 -244
  43. brainstate/init/_random_inits.py +553 -553
  44. brainstate/init/_random_inits_test.py +149 -149
  45. brainstate/init/_regular_inits.py +105 -105
  46. brainstate/init/_regular_inits_test.py +50 -50
  47. brainstate/mixin.py +365 -363
  48. brainstate/mixin_test.py +77 -73
  49. brainstate/nn/__init__.py +135 -131
  50. brainstate/{functional → nn}/_activations.py +808 -813
  51. brainstate/{functional → nn}/_activations_test.py +331 -331
  52. brainstate/nn/_collective_ops.py +514 -514
  53. brainstate/nn/_collective_ops_test.py +43 -43
  54. brainstate/nn/_common.py +178 -178
  55. brainstate/nn/_conv.py +501 -501
  56. brainstate/nn/_conv_test.py +238 -238
  57. brainstate/nn/_delay.py +509 -470
  58. brainstate/nn/_delay_test.py +238 -0
  59. brainstate/nn/_dropout.py +426 -426
  60. brainstate/nn/_dropout_test.py +100 -100
  61. brainstate/nn/_dynamics.py +1343 -1361
  62. brainstate/nn/_dynamics_test.py +78 -78
  63. brainstate/nn/_elementwise.py +1119 -1120
  64. brainstate/nn/_elementwise_test.py +169 -169
  65. brainstate/nn/_embedding.py +58 -58
  66. brainstate/nn/_exp_euler.py +92 -92
  67. brainstate/nn/_exp_euler_test.py +35 -35
  68. brainstate/nn/_fixedprob.py +239 -239
  69. brainstate/nn/_fixedprob_test.py +114 -114
  70. brainstate/nn/_inputs.py +608 -608
  71. brainstate/nn/_linear.py +424 -424
  72. brainstate/nn/_linear_mv.py +83 -83
  73. brainstate/nn/_linear_mv_test.py +120 -120
  74. brainstate/nn/_linear_test.py +107 -107
  75. brainstate/nn/_ltp.py +28 -28
  76. brainstate/nn/_module.py +377 -377
  77. brainstate/nn/_module_test.py +40 -208
  78. brainstate/nn/_neuron.py +705 -705
  79. brainstate/nn/_neuron_test.py +161 -161
  80. brainstate/nn/_normalizations.py +975 -918
  81. brainstate/nn/_normalizations_test.py +73 -73
  82. brainstate/{functional → nn}/_others.py +46 -46
  83. brainstate/nn/_poolings.py +1177 -1177
  84. brainstate/nn/_poolings_test.py +217 -217
  85. brainstate/nn/_projection.py +486 -486
  86. brainstate/nn/_rate_rnns.py +554 -554
  87. brainstate/nn/_rate_rnns_test.py +63 -63
  88. brainstate/nn/_readout.py +209 -209
  89. brainstate/nn/_readout_test.py +53 -53
  90. brainstate/nn/_stp.py +236 -236
  91. brainstate/nn/_synapse.py +505 -505
  92. brainstate/nn/_synapse_test.py +131 -131
  93. brainstate/nn/_synaptic_projection.py +423 -423
  94. brainstate/nn/_synouts.py +162 -162
  95. brainstate/nn/_synouts_test.py +57 -57
  96. brainstate/nn/_utils.py +89 -89
  97. brainstate/nn/metrics.py +388 -388
  98. brainstate/optim/__init__.py +38 -38
  99. brainstate/optim/_base.py +64 -64
  100. brainstate/optim/_lr_scheduler.py +448 -448
  101. brainstate/optim/_lr_scheduler_test.py +50 -50
  102. brainstate/optim/_optax_optimizer.py +152 -152
  103. brainstate/optim/_optax_optimizer_test.py +53 -53
  104. brainstate/optim/_sgd_optimizer.py +1104 -1104
  105. brainstate/random/__init__.py +24 -24
  106. brainstate/random/_rand_funs.py +3616 -3616
  107. brainstate/random/_rand_funs_test.py +567 -567
  108. brainstate/random/_rand_seed.py +210 -210
  109. brainstate/random/_rand_seed_test.py +48 -48
  110. brainstate/random/_rand_state.py +1409 -1409
  111. brainstate/random/_random_for_unit.py +52 -52
  112. brainstate/surrogate.py +1957 -1957
  113. brainstate/transform.py +23 -23
  114. brainstate/typing.py +304 -304
  115. brainstate/util/__init__.py +50 -50
  116. brainstate/util/caller.py +98 -98
  117. brainstate/util/error.py +55 -55
  118. brainstate/util/filter.py +469 -469
  119. brainstate/util/others.py +540 -540
  120. brainstate/util/pretty_pytree.py +945 -945
  121. brainstate/util/pretty_pytree_test.py +159 -159
  122. brainstate/util/pretty_repr.py +328 -328
  123. brainstate/util/pretty_table.py +2954 -2954
  124. brainstate/util/scaling.py +258 -258
  125. brainstate/util/struct.py +523 -523
  126. {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/METADATA +91 -99
  127. brainstate-0.1.9.dist-info/RECORD +130 -0
  128. {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/WHEEL +1 -1
  129. {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info/licenses}/LICENSE +202 -202
  130. brainstate/functional/_normalization.py +0 -81
  131. brainstate/functional/_spikes.py +0 -204
  132. brainstate-0.1.7.dist-info/RECORD +0 -131
  133. {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/top_level.txt +0 -0
@@ -1,1738 +1,1738 @@
1
- # The file is adapted from the Flax library (https://github.com/google/flax).
2
- # The credit should go to the Flax authors.
3
- #
4
- # Copyright 2024 The Flax Authors.
5
- #
6
- # Licensed under the Apache License, Version 2.0 (the "License");
7
- # you may not use this file except in compliance with the License.
8
- # You may obtain a copy of the License at
9
- #
10
- # http://www.apache.org/licenses/LICENSE-2.0
11
- #
12
- # Unless required by applicable law or agreed to in writing, software
13
- # distributed under the License is distributed on an "AS IS" BASIS,
14
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
- # See the License for the specific language governing permissions and
16
- # limitations under the License.
17
-
18
- from __future__ import annotations
19
-
20
- import dataclasses
21
- from typing import (
22
- Any, Callable, Generic, Iterable, Iterator, Mapping, MutableMapping,
23
- Sequence, Type, TypeVar, Union, Hashable, Tuple, Dict, Optional, overload
24
- )
25
-
26
- import jax
27
- import numpy as np
28
- from typing_extensions import TypeGuard, Unpack
29
-
30
- from brainstate._state import State, TreefyState
31
- from brainstate._utils import set_module_as
32
- from brainstate.typing import PathParts, Filter, Predicate, Key
33
- from brainstate.util.caller import ApplyCaller, CallableProxy, DelayedAccessor
34
- from brainstate.util.pretty_pytree import NestedDict, FlattedDict, PrettyDict
35
- from brainstate.util.pretty_repr import PrettyRepr, PrettyType, PrettyAttr, PrettyMapping, MappingReprMixin
36
- from brainstate.util.struct import FrozenDict
37
- from brainstate.util.filter import to_predicate
38
-
39
- _max_int = np.iinfo(np.int32).max
40
-
41
- __all__ = [
42
- # state management in the given graph or node
43
- 'pop_states', 'nodes', 'states', 'treefy_states', 'update_states',
44
-
45
- # graph node operations
46
- 'flatten', 'unflatten', 'treefy_split', 'treefy_merge', 'iter_leaf', 'iter_node', 'clone', 'graphdef', 'call',
47
-
48
- # others
49
- 'RefMap', 'GraphDef', 'NodeRef', 'NodeDef'
50
- ]
51
-
52
- A = TypeVar('A')
53
- B = TypeVar('B')
54
- C = TypeVar('C')
55
- F = TypeVar('F', bound=Callable)
56
-
57
- HA = TypeVar('HA', bound=Hashable)
58
- HB = TypeVar('HB', bound=Hashable)
59
-
60
- Index = int
61
- Names = Sequence[int]
62
- Node = TypeVar('Node')
63
- Leaf = TypeVar('Leaf')
64
- AuxData = TypeVar('AuxData')
65
-
66
- StateLeaf = TreefyState[Any]
67
- NodeLeaf = State[Any]
68
- GraphStateMapping = NestedDict[Key, StateLeaf]
69
-
70
-
71
- # --------------------------------------------------------
72
-
73
-
74
- def _is_state_leaf(x: Any) -> TypeGuard[StateLeaf]:
75
- return isinstance(x, TreefyState)
76
-
77
-
78
- def _is_node_leaf(x: Any) -> TypeGuard[NodeLeaf]:
79
- return isinstance(x, State)
80
-
81
-
82
- class RefMap(MutableMapping[A, B], MappingReprMixin[A, B]):
83
- """
84
- A mapping that uses object id as the hash for the keys.
85
-
86
- This mapping is useful when we want to keep track of objects
87
- that are being referenced by other objects.
88
-
89
- Args:
90
- mapping: A mapping or iterable of key-value pairs.
91
-
92
- """
93
- __module__ = 'brainstate.graph'
94
-
95
- def __init__(self, mapping: Mapping[A, B] | Iterable[Tuple[A, B]] = ()):
96
- self._mapping: Dict[int, Tuple[A, B]] = {}
97
- self.update(mapping)
98
-
99
- def __getitem__(self, key: A) -> B:
100
- return self._mapping[id(key)][1]
101
-
102
- def __contains__(self, key: Any) -> bool:
103
- return id(key) in self._mapping
104
-
105
- def __setitem__(self, key: A, value: B):
106
- self._mapping[id(key)] = (key, value)
107
-
108
- def __delitem__(self, key: A):
109
- del self._mapping[id(key)]
110
-
111
- def __iter__(self) -> Iterator[A]:
112
- return (key for key, _ in self._mapping.values())
113
-
114
- def __len__(self) -> int:
115
- return len(self._mapping)
116
-
117
- def __str__(self) -> str:
118
- return repr(self)
119
-
120
-
121
- @dataclasses.dataclass(frozen=True)
122
- class NodeImplBase(Generic[Node, Leaf, AuxData]):
123
- type: type
124
- flatten: Callable[[Node], tuple[Sequence[tuple[Key, Leaf]], AuxData]]
125
-
126
- def node_dict(self, node: Node) -> dict[Key, Leaf]:
127
- nodes, _ = self.flatten(node)
128
- return dict(nodes)
129
-
130
-
131
- @dataclasses.dataclass(frozen=True)
132
- class GraphNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
133
- set_key: Callable[[Node, Key, Leaf], None]
134
- pop_key: Callable[[Node, Key], Leaf]
135
- create_empty: Callable[[AuxData], Node]
136
- clear: Callable[[Node], None]
137
-
138
- def init(self, node: Node, items: Tuple[Tuple[Key, Leaf], ...]):
139
- for key, value in items:
140
- self.set_key(node, key, value)
141
-
142
-
143
- @dataclasses.dataclass(frozen=True)
144
- class PyTreeNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
145
- unflatten: Callable[[tuple[tuple[Key, Leaf], ...], AuxData], Node]
146
-
147
-
148
- NodeImpl = Union[GraphNodeImpl[Node, Leaf, AuxData], PyTreeNodeImpl[Node, Leaf, AuxData]]
149
-
150
- # --------------------------------------------------------
151
- # Graph Node implementation: start
152
- # --------------------------------------------------------
153
-
154
- _node_impl_for_type: dict[type, NodeImpl[Any, Any, Any]] = {}
155
-
156
-
157
- def register_graph_node_type(
158
- type: type,
159
- flatten: Callable[[Node], tuple[Sequence[tuple[Key, Leaf]], AuxData]],
160
- set_key: Callable[[Node, Key, Leaf], None],
161
- pop_key: Callable[[Node, Key], Leaf],
162
- create_empty: Callable[[AuxData], Node],
163
- clear: Callable[[Node], None],
164
- ):
165
- """
166
- Register a graph node type.
167
-
168
- Args:
169
- type: The type of the node.
170
- flatten: A function that flattens the node into a sequence of key-value pairs.
171
- set_key: A function that sets a key in the node.
172
- pop_key: A function that pops a key from the node.
173
- create_empty: A function that creates an empty node.
174
- clear: A function that clears the node
175
- """
176
- _node_impl_for_type[type] = GraphNodeImpl(
177
- type=type,
178
- flatten=flatten,
179
- set_key=set_key,
180
- pop_key=pop_key,
181
- create_empty=create_empty,
182
- clear=clear,
183
- )
184
-
185
-
186
- # --------------------------------------------------------
187
- # Graph node implementation: end
188
- # --------------------------------------------------------
189
-
190
-
191
- def _is_node(x: Any) -> bool:
192
- return _is_graph_node(x) or _is_pytree_node(x)
193
-
194
-
195
- def _is_pytree_node(x: Any) -> bool:
196
- return not jax.tree_util.all_leaves((x,))
197
-
198
-
199
- def _is_graph_node(x: Any) -> bool:
200
- return type(x) in _node_impl_for_type
201
-
202
-
203
- def _is_node_type(x: type[Any]) -> bool:
204
- return x in _node_impl_for_type or x is PytreeType
205
-
206
-
207
- def _get_node_impl(x: Node) -> NodeImpl[Node, Any, Any]:
208
- if isinstance(x, State):
209
- raise ValueError(f'State is not a node: {x}')
210
-
211
- node_type = type(x)
212
- if node_type not in _node_impl_for_type:
213
- if _is_pytree_node(x):
214
- return PYTREE_NODE_IMPL
215
- else:
216
- raise ValueError(f'Unknown node type: {x}')
217
-
218
- return _node_impl_for_type[node_type]
219
-
220
-
221
- def get_node_impl_for_type(x: type[Node]) -> NodeImpl[Node, Any, Any]:
222
- if x is PytreeType:
223
- return PYTREE_NODE_IMPL
224
- return _node_impl_for_type[x]
225
-
226
-
227
- class HashableMapping(Mapping[HA, HB], Hashable):
228
- def __init__(self, mapping: Mapping[HA, HB] | Iterable[tuple[HA, HB]]):
229
- self._mapping = dict(mapping)
230
-
231
- def __contains__(self, key: object) -> bool:
232
- return key in self._mapping
233
-
234
- def __getitem__(self, key: HA) -> HB:
235
- return self._mapping[key]
236
-
237
- def __iter__(self) -> Iterator[HA]:
238
- return iter(self._mapping)
239
-
240
- def __len__(self) -> int:
241
- return len(self._mapping)
242
-
243
- def __hash__(self) -> int:
244
- return hash(tuple(sorted(self._mapping.items())))
245
-
246
- def __eq__(self, other: Any) -> bool:
247
- return isinstance(other, HashableMapping) and self._mapping == other._mapping
248
-
249
- def __repr__(self) -> str:
250
- return repr(self._mapping)
251
-
252
-
253
- class GraphDef(Generic[Node]):
254
- """
255
- A base dataclass that denotes the graph structure of a :class:`Node`.
256
-
257
- It contains two main components:
258
- - type: The type of the node.
259
- - index: The index of the node in the graph.
260
-
261
- It has two concrete subclasses:
262
- - :class:`NodeRef`: A reference to a node in the graph.
263
- - :class:`NodeDef`: A dataclass that denotes the graph structure of a :class:`Node` or a :class:`State`.
264
-
265
- """
266
- type: type[Node]
267
- index: int
268
-
269
-
270
- @dataclasses.dataclass(frozen=True, repr=False)
271
- class NodeRef(GraphDef[Node], PrettyRepr):
272
- """
273
- A reference to a node in the graph.
274
-
275
- The node can be instances of :class:`Node` or :class:`State`.
276
- """
277
- type: type[Node]
278
- index: int
279
-
280
- def __pretty_repr__(self):
281
- yield PrettyType(type=type(self))
282
- yield PrettyAttr('type', self.type.__name__)
283
- yield PrettyAttr('index', self.index)
284
-
285
- def __treescope_repr__(self, path, subtree_renderer):
286
- """
287
- Treescope repr for the object.
288
- """
289
- import treescope # type: ignore[import-not-found,import-untyped]
290
- return treescope.repr_lib.render_object_constructor(
291
- object_type=type(self),
292
- attributes={'type': self.type, 'index': self.index},
293
- path=path,
294
- subtree_renderer=subtree_renderer,
295
- )
296
-
297
-
298
- jax.tree_util.register_static(NodeRef)
299
-
300
-
301
- @dataclasses.dataclass(frozen=True, repr=False)
302
- class NodeDef(GraphDef[Node], PrettyRepr):
303
- """
304
- A dataclass that denotes the tree structure of a node, either :class:`Node` or :class:`State`.
305
-
306
- """
307
-
308
- type: Type[Node] # type of the node
309
- index: int # index of the node in the graph
310
- attributes: Tuple[Key, ...] # attributes for the node
311
- subgraphs: HashableMapping[Key, NodeDef[Any] | NodeRef[Any]]
312
- static_fields: HashableMapping[Key, Any]
313
- leaves: HashableMapping[Key, NodeRef[Any] | None]
314
- metadata: Hashable
315
- index_mapping: FrozenDict[Index, Index] | None
316
-
317
- @classmethod
318
- def create(
319
- cls,
320
- type: Type[Node],
321
- index: int,
322
- attributes: tuple[Key, ...],
323
- subgraphs: Iterable[tuple[Key, NodeDef[Any] | NodeRef[Any]]],
324
- static_fields: Iterable[tuple[Key, Any]],
325
- leaves: Iterable[tuple[Key, NodeRef[Any] | None]],
326
- metadata: Hashable,
327
- index_mapping: Mapping[Index, Index] | None,
328
- ):
329
- return cls(
330
- type=type,
331
- index=index,
332
- attributes=attributes,
333
- subgraphs=HashableMapping(subgraphs),
334
- static_fields=HashableMapping(static_fields),
335
- leaves=HashableMapping(leaves),
336
- metadata=metadata,
337
- index_mapping=FrozenDict(index_mapping) if index_mapping is not None else None,
338
- )
339
-
340
- def __pretty_repr__(self):
341
- yield PrettyType(type=type(self))
342
-
343
- yield PrettyAttr('type', self.type.__name__)
344
- yield PrettyAttr('index', self.index)
345
- yield PrettyAttr('attributes', self.attributes)
346
- yield PrettyAttr('subgraphs', PrettyMapping(self.subgraphs))
347
- yield PrettyAttr('static_fields', PrettyMapping(self.static_fields))
348
- yield PrettyAttr('leaves', PrettyMapping(self.leaves))
349
- yield PrettyAttr('metadata', self.metadata)
350
- yield PrettyAttr('index_mapping', PrettyMapping(self.index_mapping) if self.index_mapping is not None else None)
351
-
352
- def apply(
353
- self,
354
- state_map: GraphStateMapping,
355
- *state_maps: GraphStateMapping
356
- ) -> ApplyCaller[tuple[GraphDef[Node], GraphStateMapping]]:
357
- accessor = DelayedAccessor()
358
-
359
- def _apply(accessor: DelayedAccessor, *args, **kwargs) -> tuple[
360
- Any, tuple[GraphDef[Node], GraphStateMapping]]:
361
- module = treefy_merge(self, state_map, *state_maps)
362
- fn = accessor(module)
363
- out = fn(*args, **kwargs)
364
- return out, flatten(module)
365
-
366
- return CallableProxy(_apply, accessor) # type: ignore
367
-
368
-
369
- jax.tree_util.register_static(NodeDef)
370
-
371
-
372
- # --------------------------------------------------------
373
- # Graph operations: start
374
- # --------------------------------------------------------
375
-
376
-
377
- def _graph_flatten(
378
- path: PathParts,
379
- ref_index: RefMap[Any, Index],
380
- flatted_state_mapping: Dict[PathParts, StateLeaf],
381
- node: Node,
382
- treefy_state: bool = False,
383
- ):
384
- """
385
- Recursive helper for graph flatten.
386
-
387
- Args:
388
- path: The path to the node.
389
- ref_index: A mapping from nodes to indexes.
390
- flatted_state_mapping: A mapping from paths to state leaves.
391
- node: The node to flatten.
392
-
393
- Returns:
394
- A NodeDef or a NodeRef.
395
- """
396
- if not _is_node(node):
397
- raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')
398
-
399
- # If the node is already in the cache, return a reference, otherwise
400
- # add it to the cache and continue with the flattening process.
401
- # This is done to avoid infinite recursion when there is a reference cycle.
402
- if node in ref_index:
403
- return NodeRef(type(node), ref_index[node])
404
-
405
- # Get the node implementation for the node type.
406
- # There are two types of node implementations: GraphNodeImpl and PyTreeNodeImpl.
407
- # - ``GraphNodeImpl`` is used for nodes that have a graph structure.
408
- # - ``PyTreeNodeImpl`` is used for nodes that have a tree structure.
409
- node_impl = _get_node_impl(node)
410
-
411
- # There are two types of nodes: Node and State.
412
- # Here we handle the Node case.
413
- if isinstance(node_impl, GraphNodeImpl):
414
- # add the node to the cache
415
- index = len(ref_index)
416
- ref_index[node] = index
417
- else:
418
- index = -1
419
-
420
- subgraphs: list[tuple[Key, NodeDef[Node] | NodeRef]] = []
421
- static_fields: list[tuple[Key, Any]] = []
422
- leaves: list[tuple[Key, NodeRef | None]] = []
423
-
424
- # Flatten the node into a sequence of key-value pairs.
425
- values, metadata = node_impl.flatten(node)
426
- for key, value in values:
427
- if _is_node(value):
428
- # Recursively flatten the subgraph.
429
- nodedef = _graph_flatten((*path, key), ref_index, flatted_state_mapping, value, treefy_state)
430
- subgraphs.append((key, nodedef))
431
- elif isinstance(value, State):
432
- # If the variable is in the cache, add a reference to it.
433
- if value in ref_index:
434
- leaves.append((key, NodeRef(type(value), ref_index[value])))
435
- else:
436
- # If the variable is not in the cache, add it to the cache.
437
- # This is done to avoid multiple references to the same variable.
438
- flatted_state_mapping[(*path, key)] = (value.to_state_ref() if treefy_state else value)
439
- variable_index = ref_index[value] = len(ref_index)
440
- leaves.append((key, NodeRef(type(value), variable_index)))
441
- elif _is_state_leaf(value):
442
- # The instance of ``TreefyState`` is a leaf.
443
- flatted_state_mapping[(*path, key)] = value
444
- leaves.append((key, None))
445
- else:
446
- # if isinstance(value, (jax.Array, np.ndarray)):
447
- # path_str = '/'.join(map(str, (*path, key)))
448
- # raise ValueError(f'Arrays leaves are not supported, at {path_str!r}: {value}')
449
-
450
- # The value is a static field.
451
- static_fields.append((key, value))
452
-
453
- nodedef = NodeDef.create(type=node_impl.type,
454
- index=index,
455
- attributes=tuple(key for key, _ in values),
456
- subgraphs=subgraphs,
457
- static_fields=static_fields,
458
- leaves=leaves,
459
- metadata=metadata,
460
- index_mapping=None, )
461
- return nodedef
462
-
463
-
464
- @set_module_as('brainstate.graph')
465
- def flatten(
466
- node: Node,
467
- /,
468
- ref_index: Optional[RefMap[Any, Index]] = None,
469
- treefy_state: bool = True,
470
- ) -> Tuple[GraphDef, NestedDict]:
471
- """
472
- Flattens a graph node into a (graph_def, state_mapping) pair.
473
-
474
- Example::
475
-
476
- >>> import brainstate as brainstate
477
- >>> node = brainstate.graph.Node()
478
- >>> graph_def, state_mapping = flatten(node)
479
- >>> print(graph_def)
480
- >>> print(state_mapping)
481
-
482
- Args:
483
- node: A graph node.
484
- ref_index: A mapping from nodes to indexes, defaults to None. If not provided, a new
485
- empty dictionary is created. This argument can be used to flatten a sequence of graph
486
- nodes that share references.
487
- treefy_state: If True, the state mapping will be a NestedDict instead of a flat dictionary.
488
- """
489
- ref_index = RefMap() if ref_index is None else ref_index
490
- assert isinstance(ref_index, RefMap), f"ref_index must be a RefMap. But we got: {ref_index}"
491
- flatted_state_mapping: dict[PathParts, StateLeaf] = {}
492
- graph_def = _graph_flatten((), ref_index, flatted_state_mapping, node, treefy_state)
493
- return graph_def, NestedDict.from_flat(flatted_state_mapping)
494
-
495
-
496
- def _get_children(graph_def, state_mapping, index_ref, index_ref_cache):
497
- children: dict[Key, StateLeaf | Node] = {}
498
-
499
- # NOTE: we could allow adding new StateLeafs here
500
- # All state keys must be present in the graph definition (the object attributes)
501
- if unknown_keys := set(state_mapping) - set(graph_def.attributes):
502
- raise ValueError(f'Unknown keys: {unknown_keys}')
503
-
504
- # for every key in attributes there are 6 possible cases:
505
- # - (2) the key can either be present in the state or not
506
- # - (3) the key can be a subgraph, a leaf, or a static attribute
507
- for key in graph_def.attributes:
508
- if key not in state_mapping: # static field
509
- # TODO(cgarcia): maybe we shouldn't support unflattening with missing keys?
510
- # if key is not present, create an empty types
511
- if key in graph_def.static_fields:
512
- children[key] = graph_def.static_fields[key]
513
-
514
- elif key in graph_def.subgraphs:
515
- # if the key is a subgraph we create an empty node
516
- subgraphdef = graph_def.subgraphs[key]
517
- if isinstance(subgraphdef, NodeRef):
518
- # subgraph exists, take it from the cache
519
- children[key] = index_ref[subgraphdef.index]
520
-
521
- else:
522
- # create a node from an empty state, reasoning:
523
- # * it is a node with no state
524
- # * it is a node with state but only through references of already
525
- # created nodes
526
- substate = {}
527
- children[key] = _graph_unflatten(subgraphdef, substate, index_ref, index_ref_cache)
528
-
529
- elif key in graph_def.leaves:
530
- noderef = graph_def.leaves[key]
531
- if (noderef is not None) and (noderef.index in index_ref):
532
- # variable exists, take it from the cache
533
- children[key] = index_ref[noderef.index]
534
-
535
- else:
536
- # key for a variable is missing, raise an error
537
- raise ValueError(f'Expected key {key!r} in state while building node of type '
538
- f'{graph_def.type.__name__}.')
539
-
540
- else:
541
- raise RuntimeError(f'Unknown static field: {key!r}')
542
-
543
- else: # state field
544
- value = state_mapping[key]
545
- if isinstance(value, PrettyDict):
546
- value = dict(value)
547
-
548
- if key in graph_def.static_fields:
549
- raise ValueError(f'Got state for static field {key!r}, this is not supported.')
550
-
551
- if key in graph_def.subgraphs:
552
- # if _is_state_leaf(value):
553
- if isinstance(value, (TreefyState, State)):
554
- raise ValueError(f'Expected value of type {graph_def.subgraphs[key]} '
555
- f'for {key!r}, but got {value!r}')
556
- if not isinstance(value, dict):
557
- raise TypeError(f'Expected a dict for {key!r}, but got {type(value)}.')
558
-
559
- subgraphdef = graph_def.subgraphs[key]
560
- if isinstance(subgraphdef, NodeRef):
561
- children[key] = index_ref[subgraphdef.index]
562
- else:
563
- children[key] = _graph_unflatten(subgraphdef, value, index_ref, index_ref_cache)
564
-
565
- elif key in graph_def.leaves:
566
- # if not _is_state_leaf(value):
567
- if not isinstance(value, (TreefyState, State)):
568
- raise ValueError(f'Expected a leaf for {key!r}, but got {value!r}')
569
-
570
- noderef = graph_def.leaves[key]
571
- if noderef is None:
572
- # if the leaf is None, it means that the value was originally
573
- # a non-TreefyState leaf, however we allow providing a
574
- # TreefyState presumbly created by modifying the NestedDict
575
- if isinstance(value, TreefyState):
576
- value = value.to_state()
577
- # elif isinstance(value, State):
578
- # value = value
579
- children[key] = value
580
-
581
- elif noderef.index in index_ref:
582
- # add an existing variable
583
- children[key] = index_ref[noderef.index]
584
-
585
- else:
586
- # it is an unseen variable, create a new one
587
- if not isinstance(value, (TreefyState, State)):
588
- raise ValueError(f'Expected a State type for {key!r}, but got {type(value)}.')
589
- # when idxmap is present, check if the Varable exists there
590
- # and update existing variables if it does
591
- if index_ref_cache is not None and noderef.index in index_ref_cache:
592
- variable = index_ref_cache[noderef.index]
593
- if not isinstance(variable, State):
594
- raise ValueError(f'Expected a State type for {key!r}, but got {type(variable)}.')
595
- if isinstance(value, TreefyState):
596
- variable.update_from_ref(value)
597
- elif isinstance(value, State):
598
- if value._been_writen:
599
- variable.value = value.value
600
- else:
601
- variable.restore_value(value.value)
602
- else:
603
- raise ValueError(f'Expected a State type for {key!r}, but got {type(value)}.')
604
- else: # if it doesn't, create a new variable
605
- if isinstance(value, TreefyState):
606
- variable = value.to_state()
607
- elif isinstance(value, State):
608
- variable = value
609
- else:
610
- raise ValueError(f'Expected a State type for {key!r}, but got {type(value)}.')
611
- children[key] = variable
612
- index_ref[noderef.index] = variable
613
-
614
- else:
615
- raise RuntimeError(f'Unknown key: {key!r}, this is a bug.')
616
-
617
- return children
618
-
619
-
620
- def _graph_unflatten(
621
- graph_def: NodeDef[Node] | NodeRef[Node],
622
- state_mapping: Mapping[Key, StateLeaf | Mapping[Key, Any]],
623
- index_ref: dict[Index, Any],
624
- index_ref_cache: dict[Index, Any] | None,
625
- ) -> Node:
626
- """
627
- Recursive helper for graph unflatten.
628
-
629
- Args:
630
- graph_def: A `GraphDef` instance or an index to a node in the cache.
631
- state_mapping: A state mapping from attribute names to variables or subgraphs.
632
- index_ref: A mapping from indexes to nodes that have been traversed.
633
- If a node is already in the cache, it won't be traversed again.
634
- index_ref_cache: A mapping from indexes to existing nodes that can be reused.
635
- When an reference is reused, ``GraphNodeImpl.clear`` is called to leave the
636
- object in an empty state and then filled by the unflatten process, as a result
637
- existing graph nodes are mutated to have the new content/topology
638
- specified by the nodedef.
639
-
640
- Returns:
641
- A node instance.
642
- """
643
-
644
- # if the graph_def is a reference, this means that the node has already been created, so
645
- # we return the node from the cache
646
- if isinstance(graph_def, NodeRef):
647
- return index_ref[graph_def.index]
648
- else:
649
- assert isinstance(graph_def, NodeDef), f"graph_def must be a NodeDef. But we got: {graph_def}"
650
-
651
- # graph_def must be a registered node type
652
- if not _is_node_type(graph_def.type):
653
- raise RuntimeError(f'Unsupported type: {graph_def.type}, this is a bug.')
654
-
655
- # check if the index is already in the cache
656
- if graph_def.index in index_ref:
657
- raise RuntimeError(f'GraphDef index {graph_def.index} already used.')
658
-
659
- # get the node implementation for the node type
660
- node_impl = get_node_impl_for_type(graph_def.type)
661
-
662
- if isinstance(node_impl, GraphNodeImpl):
663
- # we create an empty node first and add it to the index
664
- # this avoids infinite recursion when there is a reference cycle
665
-
666
- if (index_ref_cache is not None) and (graph_def.index in index_ref_cache):
667
- # clear the node to leave it in an empty state
668
- node = index_ref_cache[graph_def.index]
669
- if type(node) != graph_def.type:
670
- raise ValueError(f'Expected a node of type {graph_def.type} for index '
671
- f'{graph_def.index}, but got a node of type {type(node)}.')
672
- node_impl.clear(node)
673
- else:
674
- # create an empty node
675
- node = node_impl.create_empty(graph_def.metadata)
676
-
677
- # add the node to the cache
678
- index_ref[graph_def.index] = node
679
-
680
- # get the children (the attributes) of the node
681
- children = _get_children(graph_def, state_mapping, index_ref, index_ref_cache)
682
-
683
- # initialize the node with the children
684
- node_impl.init(node, tuple(children.items()))
685
-
686
- else:
687
- # if the node type does not support the creation of an empty object it means
688
- # that it cannot reference itself, so we can create its children first
689
-
690
- # first, we create the children (attributes)
691
- children = _get_children(graph_def, state_mapping, index_ref, index_ref_cache)
692
- # then, we create the node
693
- node = node_impl.unflatten(tuple(children.items()), graph_def.metadata)
694
-
695
- return node
696
-
697
-
698
- @set_module_as('brainstate.graph')
699
- def unflatten(
700
- graph_def: GraphDef,
701
- state_mapping: NestedDict[Key, StateLeaf],
702
- /,
703
- *,
704
- index_ref: dict[Index, Any] | None = None,
705
- index_ref_cache: dict[Index, Any] | None = None,
706
- ) -> Node:
707
- """
708
- Unflattens a graphdef into a node with the given state tree mapping.
709
-
710
- Example::
711
-
712
- >>> import brainstate as brainstate
713
- >>> class MyNode(brainstate.graph.Node):
714
- ... def __init__(self):
715
- ... self.a = brainstate.nn.Linear(2, 3)
716
- ... self.b = brainstate.nn.Linear(3, 4)
717
- ... self.c = [brainstate.nn.Linear(4, 5), brainstate.nn.Linear(5, 6)]
718
- ... self.d = {'x': brainstate.nn.Linear(6, 7), 'y': brainstate.nn.Linear(7, 8)}
719
- ...
720
- >>> graphdef, statetree = brainstate.graph.flatten(MyNode())
721
- >>> statetree
722
- NestedDict({
723
- 'a': {
724
- 'weight': TreefyState(
725
- type=ParamState,
726
- value={'weight': Array([[-0.8466386 , -2.0294454 , -0.6911647 ],
727
- [ 0.60034966, -1.1869028 , 0.84003365]], dtype=float32), 'bias': Array([0., 0., 0.], dtype=float32)}
728
- )
729
- },
730
- 'b': {
731
- 'weight': TreefyState(
732
- type=ParamState,
733
- value={'weight': Array([[ 0.8565106 , -0.10337489],
734
- [ 1.7449658 , 0.29128835],
735
- [ 0.11441387, 1.0012752 ]], dtype=float32), 'bias': Array([0., 0.], dtype=float32)}
736
- )
737
- },
738
- 'c': {
739
- 0: {
740
- 'weight': TreefyState(
741
- type=ParamState,
742
- value={'weight': Array([[ 2.4465137, -0.5711426]], dtype=float32), 'bias': Array([0., 0.], dtype=float32)}
743
- )
744
- },
745
- 1: {
746
- 'weight': TreefyState(
747
- type=ParamState,
748
- value={'weight': Array([[ 0.14321847, -2.4154725 , -0.6322363 ]], dtype=float32), 'bias': Array([0., 0., 0.], dtype=float32)}
749
- )
750
- }
751
- },
752
- 'd': {
753
- 'x': {
754
- 'weight': TreefyState(
755
- type=ParamState,
756
- value={'weight': Array([[ 0.9647322, -0.8958757, 1.585352 ]], dtype=float32), 'bias': Array([0., 0., 0.], dtype=float32)}
757
- )
758
- },
759
- 'y': {
760
- 'weight': TreefyState(
761
- type=ParamState,
762
- value={'weight': Array([[-1.2904786 , 0.5695903 , 0.40079263, 0.8769669 ]], dtype=float32), 'bias': Array([0., 0., 0., 0.], dtype=float32)}
763
- )
764
- }
765
- }
766
- })
767
- >>> node = brainstate.graph.unflatten(graphdef, statetree)
768
- >>> node
769
- MyNode(
770
- a=Linear(
771
- in_size=(2,),
772
- out_size=(3,),
773
- w_mask=None,
774
- weight=ParamState(
775
- value={'weight': Array([[ 0.55600464, -1.6276929 , 0.26805446],
776
- [ 1.175099 , 1.0077754 , 0.37592274]], dtype=float32), 'bias': Array([0., 0., 0.], dtype=float32)},
777
- )
778
- ),
779
- b=Linear(
780
- in_size=(3,),
781
- out_size=(4,),
782
- w_mask=None,
783
- weight=ParamState(
784
- value={'weight': Array([[-0.24753566, 0.18456966, -0.29438975, 0.16891003],
785
- [-0.803741 , -0.46037054, -0.21617596, 0.1260884 ],
786
- [-0.43074366, -0.24757433, 1.2237076 , -0.07842704]], dtype=float32), 'bias': Array([0., 0., 0., 0.], dtype=float32)},
787
- )
788
- ),
789
- c=[Linear(
790
- in_size=(4,),
791
- out_size=(5,),
792
- w_mask=None,
793
- weight=ParamState(
794
- value={'weight': Array([[-0.22384474, 0.79441446, -0.658726 , 0.05991402, 0.3014344 ],
795
- [-1.4755846 , -0.42272082, -0.07692316, 0.03077666, 0.34513143],
796
- [-0.69395834, 0.48617035, 1.1042316 , 0.13105175, -0.25620162],
797
- [ 0.50389856, 0.6998943 , 0.43716812, 1.2168779 , -0.47325954]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0.], dtype=float32)},
798
- )
799
- ), Linear(
800
- in_size=(5,),
801
- out_size=(6,),
802
- w_mask=None,
803
- weight=ParamState(
804
- value={'weight': Array([[ 0.07714394, 0.78213537, 0.6745718 , -0.22881542, 0.5523547 ,
805
- -0.6399196 ],
806
- [-0.22626828, -0.54522336, 0.07448788, -0.00464636, 1.1483842 ,
807
- -0.57049096],
808
- [-0.86659616, 0.5683135 , -0.7449975 , 1.1862832 , 0.15047254,
809
- 0.68890226],
810
- [-1.0325443 , 0.2658072 , -0.10083053, -0.66915905, 0.11258496,
811
- 0.5440655 ],
812
- [ 0.27917263, 0.05717273, -0.5682605 , -0.88345915, 0.01314917,
813
- 0.780759 ]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0.], dtype=float32)},
814
- )
815
- )],
816
- d={'x': Linear(
817
- in_size=(6,),
818
- out_size=(7,),
819
- w_mask=None,
820
- weight=ParamState(
821
- value={'weight': Array([[-0.24238771, -0.23202638, 0.13663477, -0.48858666, 0.80871904,
822
- 0.00593298, 0.7595096 ],
823
- [ 0.50457454, 0.24180941, 0.25048748, 0.8937061 , 0.25398138,
824
- -1.2400566 , 0.00151599],
825
- [-0.19136038, 0.34470603, -0.11892717, -0.12514868, -0.5871703 ,
826
- 0.13572927, -1.1859009 ],
827
- [-0.01580911, 0.9301295 , -1.1246226 , -0.137708 , -0.4952151 ,
828
- 0.17537868, 0.98440856],
829
- [ 0.6399284 , 0.01739843, 0.61856824, 0.93258303, 0.64012206,
830
- 0.22780116, -0.5763679 ],
831
- [ 0.14077143, -1.0359222 , 0.28072503, 0.2557584 , -0.50622064,
832
- 0.4388198 , -0.26106128]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0., 0.], dtype=float32)},
833
- )
834
- ), 'y': Linear(
835
- in_size=(7,),
836
- out_size=(8,),
837
- w_mask=None,
838
- weight=ParamState(
839
- value={'weight': Array([[-0.23334591, -0.2893582 , 0.8071877 , -0.49038902, -0.29646504,
840
- 0.13624157, 0.22763114, 0.01906361],
841
- [-0.26742765, 0.20136863, 0.35148615, 0.42135832, 0.06401154,
842
- -0.78036404, 0.6616062 , 0.19437549],
843
- [ 0.9229799 , -0.1205209 , 0.69602865, 0.9685676 , -0.99886954,
844
- -0.12649904, -0.15393028, 0.65067965],
845
- [ 0.7020109 , -0.5452006 , 0.3649151 , -0.42368713, 0.24738027,
846
- 0.29290223, -0.63721114, 0.6007214 ],
847
- [-0.45045808, -0.08538888, -0.01338054, -0.39983988, 0.4028439 ,
848
- 1.0498686 , -0.24730456, 0.37612835],
849
- [ 0.16273966, 0.9001257 , 0.15190877, -1.1129239 , -0.29441378,
850
- 0.5168159 , -0.4205143 , 0.45700482],
851
- [ 0.08611429, -0.9271384 , -0.562362 , -0.586757 , 1.1611121 ,
852
- 0.5137503 , -0.46277294, 0.84642583]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)},
853
- )
854
- )}
855
- )
856
-
857
- Args:
858
- graph_def: A GraphDef instance.
859
- state_mapping: A NestedDict instance.
860
- index_ref: A mapping from indexes to nodes references found during the graph
861
- traversal, defaults to None. If not provided, a new empty dictionary is
862
- created. This argument can be used to unflatten a sequence of (graphdef, state_mapping)
863
- pairs that share the same index space.
864
- index_ref_cache: A mapping from indexes to existing nodes that can be reused.
865
- When a reference is reused, ``GraphNodeImpl.clear`` is called to leave the
866
- object in an empty state and then filled by the unflatten process, as a result
867
- existing graph nodes are mutated to have the new content/topology
868
- specified by the graphdef.
869
- """
870
- index_ref = {} if index_ref is None else index_ref
871
- assert isinstance(graph_def, (NodeDef, NodeRef)), f"graph_def must be a NodeDef or NodeRef. But we got: {graph_def}"
872
- node = _graph_unflatten(graph_def, state_mapping.to_dict(), index_ref, index_ref_cache)
873
- return node
874
-
875
-
876
- def _graph_pop(
877
- node: Node,
878
- id_to_index: dict[int, Index],
879
- path_parts: PathParts,
880
- flatted_state_dicts: tuple[FlattedDict[PathParts, StateLeaf], ...],
881
- predicates: tuple[Predicate, ...],
882
- ) -> None:
883
- if not _is_node(node):
884
- raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')
885
-
886
- if id(node) in id_to_index:
887
- return
888
-
889
- id_to_index[id(node)] = len(id_to_index)
890
- node_impl = _get_node_impl(node)
891
- node_dict = node_impl.node_dict(node)
892
-
893
- for name, value in node_dict.items():
894
- if _is_node(value):
895
- _graph_pop(
896
- node=value,
897
- id_to_index=id_to_index,
898
- path_parts=(*path_parts, name),
899
- flatted_state_dicts=flatted_state_dicts,
900
- predicates=predicates,
901
- )
902
- continue
903
- elif not _is_node_leaf(value):
904
- continue
905
- elif id(value) in id_to_index:
906
- continue
907
-
908
- node_path = (*path_parts, name)
909
- node_impl = _get_node_impl(node)
910
- for state_dicts, predicate in zip(flatted_state_dicts, predicates):
911
- if predicate(node_path, value):
912
- if isinstance(node_impl, PyTreeNodeImpl):
913
- raise ValueError(f'Cannot pop key {name!r} from node of type {type(node).__name__}')
914
- id_to_index[id(value)] = len(id_to_index)
915
- node_impl.pop_key(node, name)
916
- # if isinstance(value, State):
917
- # value = value.to_state_ref()
918
- state_dicts[node_path] = value # type: ignore[index] # mypy is wrong here?
919
- break
920
- else:
921
- # NOTE: should we raise an error here?
922
- pass
923
-
924
-
925
- @overload
926
- def pop_states(node, filter1: Filter, /) -> NestedDict:
927
- ...
928
-
929
-
930
- @overload
931
- def pop_states(node, filter1: Filter, filter2: Filter, /, *filters: Filter) -> tuple[NestedDict, ...]:
932
- ...
933
-
934
-
935
- @set_module_as('brainstate.graph')
936
- def pop_states(
937
- node: Node,
938
- *filters: Any
939
- ) -> Union[NestedDict[Key, State], Tuple[NestedDict[Key, State], ...]]:
940
- """
941
- Pop one or more :class:`State` types from the graph node.
942
-
943
- Example usage::
944
-
945
- >>> import brainstate as brainstate
946
- >>> import jax.numpy as jnp
947
-
948
- >>> class Model(brainstate.nn.Module):
949
- ... def __init__(self):
950
- ... super().__init__()
951
- ... self.a = brainstate.nn.Linear(2, 3)
952
- ... self.b = brainstate.nn.LIF([10, 2])
953
-
954
- >>> model = Model()
955
- >>> with brainstate.catch_new_states('new'):
956
- ... brainstate.nn.init_all_states(model)
957
-
958
- >>> assert len(model.states()) == 2
959
- >>> model_states = brainstate.graph.pop_states(model, 'new')
960
- >>> model_states
961
- NestedDict({
962
- 'b': {
963
- 'V': {
964
- 'st': ShortTermState(
965
- value=Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
966
- 0., 0., 0.], dtype=float32),
967
- tag='new'
968
- )
969
- }
970
- }
971
- })
972
-
973
- Args:
974
- node: A graph node object.
975
- *filters: One or more :class:`State` objects to filter by.
976
-
977
- Returns:
978
- The popped :class:`NestedDict` containing the :class:`State`
979
- objects that were filtered for.
980
- """
981
- if len(filters) == 0:
982
- raise ValueError('Expected at least one filter')
983
-
984
- id_to_index: dict[int, Index] = {}
985
- path_parts: PathParts = ()
986
- predicates = tuple(to_predicate(filter) for filter in filters)
987
- flatted_state_dicts: tuple[FlattedDict[PathParts, StateLeaf], ...] = tuple({} for _ in predicates)
988
- _graph_pop(node=node,
989
- id_to_index=id_to_index,
990
- path_parts=path_parts,
991
- flatted_state_dicts=flatted_state_dicts,
992
- predicates=predicates, )
993
- states = tuple(NestedDict.from_flat(flat_state) for flat_state in flatted_state_dicts)
994
-
995
- if len(states) == 1:
996
- return states[0]
997
- else:
998
- return states
999
-
1000
-
1001
- def _split_state(
1002
- state: GraphStateMapping,
1003
- filters: tuple[Filter, ...],
1004
- ) -> tuple[GraphStateMapping, Unpack[tuple[GraphStateMapping, ...]]]:
1005
- if not filters:
1006
- return (state,)
1007
- states = state.split(*filters)
1008
- if isinstance(states, NestedDict):
1009
- return (states,)
1010
- assert len(states) > 0
1011
- return states # type: ignore[return-value]
1012
-
1013
-
1014
- @overload
1015
- def treefy_split(node: A, /) -> Tuple[GraphDef, NestedDict]:
1016
- ...
1017
-
1018
-
1019
- @overload
1020
- def treefy_split(node: A, first: Filter, /) -> Tuple[GraphDef, NestedDict]:
1021
- ...
1022
-
1023
-
1024
- @overload
1025
- def treefy_split(node: A, first: Filter, second: Filter, /) -> Tuple[GraphDef, NestedDict, NestedDict]:
1026
- ...
1027
-
1028
-
1029
- @overload
1030
- def treefy_split(
1031
- node: A, first: Filter, second: Filter, /, *filters: Filter,
1032
- ) -> Tuple[GraphDef, NestedDict, Unpack[Tuple[NestedDict, ...]]]:
1033
- ...
1034
-
1035
-
1036
- @set_module_as('brainstate.graph')
1037
- def treefy_split(
1038
- node: A,
1039
- *filters: Filter
1040
- ) -> Tuple[GraphDef[A], NestedDict, Unpack[Tuple[NestedDict, ...]]]:
1041
- """Split a graph node into a :class:`GraphDef` and one or more :class:`NestedDict`s. NestedDict is
1042
- a ``Mapping`` from strings or integers to ``Variables``, Arrays or nested States. GraphDef
1043
- contains all the static information needed to reconstruct a ``Module`` graph, it is analogous
1044
- to JAX’s ``PyTreeDef``. :func:`split` is used in conjunction with :func:`merge` to switch
1045
- seamlessly between stateful and stateless representations of the graph.
1046
-
1047
- Example usage::
1048
-
1049
- >>> from joblib.testing import param >>> import brainstate as brainstate
1050
- >>> import jax, jax.numpy as jnp
1051
- ...
1052
- >>> class Foo(brainstate.graph.Node):
1053
- ... def __init__(self):
1054
- ... self.a = brainstate.nn.BatchNorm1d([10, 2])
1055
- ... self.b = brainstate.nn.Linear(2, 3)
1056
- ...
1057
- >>> node = Foo()
1058
- >>> graphdef, params, others = brainstate.graph.treefy_split(node, brainstate.ParamState, ...)
1059
- ...
1060
- >>> params
1061
- NestedDict({
1062
- 'a': {
1063
- 'weight': TreefyState(
1064
- type=ParamState,
1065
- value={'weight': Array([[-1.0013659, 1.5763807],
1066
- [ 1.7149199, 2.0140953]], dtype=float32), 'bias': Array([0., 0.], dtype=float32)}
1067
- )
1068
- },
1069
- 'b': {
1070
- 'weight': TreefyState(
1071
- type=ParamState,
1072
- value={'bias': Array([[0., 0.]], dtype=float32), 'scale': Array([[1., 1.]], dtype=float32)}
1073
- )
1074
- }
1075
- })
1076
- >>> jax.tree.map(jnp.shape, others)
1077
- NestedDict({
1078
- 'b': {
1079
- 'running_mean': TreefyState(
1080
- type=LongTermState,
1081
- value=(1, 2)
1082
- ),
1083
- 'running_var': TreefyState(
1084
- type=LongTermState,
1085
- value=(1, 2)
1086
- )
1087
- }
1088
- })
1089
-
1090
- :func:`split` and :func:`merge` are primarily used to interact directly with JAX
1091
- transformations, see
1092
- `Functional API <https://flax.readthedocs.io/en/latest/nnx/nnx_basics.html#the-functional-api>`__
1093
- for more information.
1094
-
1095
- Arguments:
1096
- node: graph node to split.
1097
- *filters: some optional filters to group the state into mutually exclusive substates.
1098
-
1099
- Returns:
1100
- ``GraphDef`` and one or more ``States`` equal to the number of filters passed. If no
1101
- filters are passed, a single ``NestedDict`` is returned.
1102
- """
1103
- graphdef, state_tree = flatten(node)
1104
- states = tuple(_split_state(state_tree, filters))
1105
- return graphdef, *states
1106
-
1107
-
1108
- @set_module_as('brainstate.graph')
1109
- def treefy_merge(
1110
- graphdef: GraphDef[A],
1111
- state_mapping: GraphStateMapping,
1112
- /,
1113
- *state_mappings: GraphStateMapping,
1114
- ) -> A:
1115
- """The inverse of :func:`split`.
1116
-
1117
- ``merge`` takes a :class:`GraphDef` and one or more :class:`NestedDict`'s and creates
1118
- a new node with the same structure as the original node.
1119
-
1120
- Example usage::
1121
-
1122
- >>> import brainstate as brainstate
1123
- >>> import jax, jax.numpy as jnp
1124
- ...
1125
- >>> class Foo(brainstate.graph.Node):
1126
- ... def __init__(self):
1127
- ... self.a = brainstate.nn.BatchNorm1d([10, 2])
1128
- ... self.b = brainstate.nn.Linear(2, 3)
1129
- ...
1130
- >>> node = Foo()
1131
- >>> graphdef, params, others = brainstate.graph.treefy_split(node, brainstate.ParamState, ...)
1132
- ...
1133
- >>> new_node = brainstate.graph.treefy_merge(graphdef, params, others)
1134
- >>> assert isinstance(new_node, Foo)
1135
- >>> assert isinstance(new_node.b, brainstate.nn.BatchNorm1d)
1136
- >>> assert isinstance(new_node.a, brainstate.nn.Linear)
1137
-
1138
- :func:`split` and :func:`merge` are primarily used to interact directly with JAX
1139
- transformations, see
1140
- `Functional API <https://flax.readthedocs.io/en/latest/nnx/nnx_basics.html#the-functional-api>`__
1141
- for more information.
1142
-
1143
- Args:
1144
- graphdef: A :class:`GraphDef` object.
1145
- state_mapping: A :class:`NestedDict` object.
1146
- *state_mappings: Additional :class:`NestedDict` objects.
1147
-
1148
- Returns:
1149
- The merged :class:`Module`.
1150
- """
1151
- state_mapping = GraphStateMapping.merge(state_mapping, *state_mappings)
1152
- node = unflatten(graphdef, state_mapping)
1153
- return node
1154
-
1155
-
1156
- def _filters_to_predicates(filters: Tuple[Filter, ...]) -> Tuple[Predicate, ...]:
1157
- for i, filter_ in enumerate(filters):
1158
- if filter_ in (..., True) and i != len(filters) - 1:
1159
- remaining_filters = filters[i + 1:]
1160
- if not all(f in (..., True) for f in remaining_filters):
1161
- raise ValueError('`...` or `True` can only be used as the last filters, '
1162
- f'got {filter_} it at index {i}.')
1163
- return tuple(map(to_predicate, filters))
1164
-
1165
-
1166
- def _split_flatted(
1167
- flatted: Iterable[tuple[PathParts, Any]],
1168
- filters: tuple[Filter, ...],
1169
- ) -> tuple[list[tuple[PathParts, Any]], ...]:
1170
- predicates = _filters_to_predicates(filters)
1171
-
1172
- # we have n + 1 states, where n is the number of predicates
1173
- # the last state is for values that don't match any predicate
1174
- flat_states: tuple[list[tuple[PathParts, Any]], ...] = tuple([] for _ in predicates)
1175
-
1176
- for path, value in flatted:
1177
- for i, predicate in enumerate(predicates):
1178
- if predicate(path, value):
1179
- flat_states[i].append((path, value))
1180
- break
1181
- else:
1182
- raise ValueError('Non-exhaustive filters, got a non-empty remainder: '
1183
- f'{path} -> {value}.'
1184
- '\nUse `...` to match all remaining elements.')
1185
-
1186
- return flat_states
1187
-
1188
-
1189
- @overload
1190
- def nodes(node, /, allowed_hierarchy=(0, _max_int)) -> FlattedDict[Key, Node]:
1191
- ...
1192
-
1193
-
1194
- @overload
1195
- def nodes(node, first: Filter, /, allowed_hierarchy=(0, _max_int)) -> FlattedDict[Key, Node]:
1196
- ...
1197
-
1198
-
1199
- @overload
1200
- def nodes(
1201
- node, first: Filter, second: Filter, /, *filters: Filter, allowed_hierarchy=(0, _max_int)
1202
- ) -> Tuple[FlattedDict[Key, Node], ...]:
1203
- ...
1204
-
1205
-
1206
- @set_module_as('brainstate.graph')
1207
- def nodes(
1208
- node,
1209
- *filters: Filter,
1210
- allowed_hierarchy: Tuple[int, int] = (0, _max_int)
1211
- ) -> Union[FlattedDict[Key, Node], Tuple[FlattedDict[Key, Node], ...]]:
1212
- """
1213
- Similar to :func:`split` but only returns the :class:`NestedDict`'s indicated by the filters.
1214
- """
1215
- num_filters = len(filters)
1216
- if num_filters == 0:
1217
- filters = (..., ...)
1218
- else:
1219
- filters = (*filters, ...)
1220
-
1221
- nodes_iterable = iter_node(node, allowed_hierarchy=allowed_hierarchy)
1222
- flat_nodes = _split_flatted(nodes_iterable, (*filters, ...))
1223
- node_maps = tuple(FlattedDict(flat_node) for flat_node in flat_nodes)
1224
- if num_filters < 2:
1225
- return node_maps[0]
1226
- return node_maps[:num_filters]
1227
-
1228
-
1229
- def _states_generator(node, allowed_hierarchy) -> Iterable[Tuple[PathParts, State]]:
1230
- for path, value in iter_leaf(node, allowed_hierarchy=allowed_hierarchy):
1231
- if isinstance(value, State):
1232
- yield path, value
1233
-
1234
-
1235
- @overload
1236
- def states(node, /, allowed_hierarchy=(0, _max_int)) -> FlattedDict[Key, State]:
1237
- ...
1238
-
1239
-
1240
- @overload
1241
- def states(node, first: Filter, /, allowed_hierarchy=(0, _max_int)) -> FlattedDict[Key, State]:
1242
- ...
1243
-
1244
-
1245
- @overload
1246
- def states(
1247
- node, first: Filter, second: Filter, /, *filters: Filter, allowed_hierarchy=(0, _max_int)
1248
- ) -> tuple[FlattedDict[Key, State], ...]:
1249
- ...
1250
-
1251
-
1252
- @set_module_as('brainstate.graph')
1253
- def states(
1254
- node,
1255
- *filters: Filter,
1256
- allowed_hierarchy: Tuple[int, int] = (0, _max_int)
1257
- ) -> Union[FlattedDict[Key, State], tuple[FlattedDict[Key, State], ...]]:
1258
- """
1259
- Similar to :func:`split` but only returns the :class:`NestedDict`'s indicated by the filters.
1260
- """
1261
- num_filters = len(filters)
1262
- if num_filters == 0:
1263
- filters = (..., ...)
1264
- else:
1265
- filters = (*filters, ...)
1266
-
1267
- states_iterable = _states_generator(node, allowed_hierarchy=allowed_hierarchy)
1268
- flat_states = _split_flatted(states_iterable, (*filters, ...))
1269
- state_maps = tuple(FlattedDict(flat_state) for flat_state in flat_states)
1270
- if num_filters < 2:
1271
- return state_maps[0]
1272
- return state_maps[:num_filters]
1273
-
1274
-
1275
- @overload
1276
- def treefy_states(
1277
- node, /, flatted: bool = False
1278
- ) -> NestedDict[Key, TreefyState]:
1279
- ...
1280
-
1281
-
1282
- @overload
1283
- def treefy_states(
1284
- node, first: Filter, /, flatted: bool = False
1285
- ) -> NestedDict[Key, TreefyState]:
1286
- ...
1287
-
1288
-
1289
- @overload
1290
- def treefy_states(
1291
- node, first: Filter, second: Filter, /, *filters: Filter, flatted: bool = False
1292
- ) -> Tuple[NestedDict[Key, TreefyState], ...]:
1293
- ...
1294
-
1295
-
1296
- @set_module_as('brainstate.graph')
1297
- def treefy_states(
1298
- node, *filters,
1299
- ) -> NestedDict[Key, TreefyState] | tuple[NestedDict[Key, TreefyState], ...]:
1300
- """
1301
- Similar to :func:`split` but only returns the :class:`NestedDict`'s indicated by the filters.
1302
-
1303
- Example usage::
1304
-
1305
- >>> import brainstate as brainstate
1306
- >>> class Model(brainstate.nn.Module):
1307
- ... def __init__(self):
1308
- ... super().__init__()
1309
- ... self.l1 = brainstate.nn.Linear(2, 3)
1310
- ... self.l2 = brainstate.nn.Linear(3, 4)
1311
- ... def __call__(self, x):
1312
- ... return self.l2(self.l1(x))
1313
-
1314
- >>> model = Model()
1315
- >>> # get the learnable parameters from the batch norm and linear layer
1316
- >>> params = brainstate.graph.treefy_states(model, brainstate.ParamState)
1317
- >>> # get them separately
1318
- >>> params, others = brainstate.graph.treefy_states(model, brainstate.ParamState, brainstate.ShortTermState)
1319
- >>> # get them together
1320
- >>> states = brainstate.graph.treefy_states(model)
1321
-
1322
- Args:
1323
- node: A graph node object.
1324
- *filters: One or more :class:`State` objects to filter by.
1325
-
1326
- Returns:
1327
- One or more :class:`NestedDict` mappings.
1328
- """
1329
- _, state_mapping = flatten(node)
1330
- state_mappings: GraphStateMapping | tuple[GraphStateMapping, ...]
1331
- if len(filters) == 0:
1332
- state_mappings = state_mapping
1333
- elif len(filters) == 1:
1334
- state_mappings = state_mapping.filter(filters[0])
1335
- else:
1336
- state_mappings = state_mapping.filter(filters[0], filters[1], *filters[2:])
1337
- return state_mappings
1338
-
1339
-
1340
- def _graph_update_dynamic(node: Any, state: Mapping[Key, Any]):
1341
- if not _is_node(node):
1342
- raise RuntimeError(f'Unsupported type: {type(node)}')
1343
-
1344
- node_impl = _get_node_impl(node)
1345
- node_dict = node_impl.node_dict(node)
1346
- for key, value in state.items():
1347
- # case 1: new state is being added
1348
- if key not in node_dict:
1349
- if isinstance(node_impl, PyTreeNodeImpl):
1350
- raise ValueError(f'Cannot set key {key!r} on immutable node of '
1351
- f'type {type(node).__name__}')
1352
- if isinstance(value, State):
1353
- value = value.copy() # TODO: chenge it to state_ref
1354
- node_impl.set_key(node, key, value)
1355
- continue
1356
-
1357
- # check values are of the same type
1358
- current_value = node_dict[key]
1359
-
1360
- # case 2: subgraph is being updated
1361
- if _is_node(current_value):
1362
- if _is_state_leaf(value):
1363
- raise ValueError(f'Expected a subgraph for {key!r}, but got: {value!r}')
1364
- _graph_update_dynamic(current_value, value)
1365
- elif isinstance(value, TreefyState):
1366
- # case 3: state leaf is being updated
1367
- if not isinstance(current_value, State):
1368
- raise ValueError(f'Trying to update a non-State attribute {key!r} with a State: '
1369
- f'{value!r}')
1370
- current_value.update_from_ref(value)
1371
- elif _is_state_leaf(value):
1372
- # case 4: state field is being updated
1373
- if isinstance(node_impl, PyTreeNodeImpl):
1374
- raise ValueError(f'Cannot set key {key!r} on immutable node of '
1375
- f'type {type(node).__name__}')
1376
- node_impl.set_key(node, key, value)
1377
- else:
1378
- raise ValueError(f'Unsupported update type: {type(value)} for key {key!r}')
1379
-
1380
-
1381
- def update_states(
1382
- node: Node,
1383
- state_dict: NestedDict | FlattedDict,
1384
- /,
1385
- *state_dicts: NestedDict | FlattedDict
1386
- ) -> None:
1387
- """
1388
- Update the given graph node with a new :class:`NestedMapping` in-place.
1389
-
1390
- Args:
1391
- node: A graph node to update.
1392
- state_dict: A :class:`NestedMapping` object.
1393
- *state_dicts: Additional :class:`NestedMapping` objects.
1394
- """
1395
- if state_dicts:
1396
- state_dict = NestedDict.merge(state_dict, *state_dicts)
1397
- _graph_update_dynamic(node, state_dict.to_dict())
1398
-
1399
-
1400
- @set_module_as('brainstate.graph')
1401
- def graphdef(node: Any, /) -> GraphDef[Any]:
1402
- """Get the :class:`GraphDef` of the given graph node.
1403
-
1404
- Example usage::
1405
-
1406
- >>> import brainstate as brainstate
1407
-
1408
- >>> model = brainstate.nn.Linear(2, 3)
1409
- >>> graphdef, _ = brainstate.graph.treefy_split(model)
1410
- >>> assert graphdef == brainstate.graph.graphdef(model)
1411
-
1412
- Args:
1413
- node: A graph node object.
1414
-
1415
- Returns:
1416
- The :class:`GraphDef` of the :class:`Module` object.
1417
- """
1418
- graphdef, _ = flatten(node)
1419
- return graphdef
1420
-
1421
-
1422
- @set_module_as('brainstate.graph')
1423
- def clone(node: Node) -> Node:
1424
- """
1425
- Create a deep copy of the given graph node.
1426
-
1427
- Example usage::
1428
-
1429
- >>> import brainstate as brainstate
1430
- >>> model = brainstate.nn.Linear(2, 3)
1431
- >>> cloned_model = clone(model)
1432
- >>> model.weight.value['bias'] += 1
1433
- >>> assert (model.weight.value['bias'] != cloned_model.weight.value['bias']).all()
1434
-
1435
- Args:
1436
- node: A graph node object.
1437
-
1438
- Returns:
1439
- A deep copy of the :class:`Module` object.
1440
- """
1441
- graphdef, state = treefy_split(node)
1442
- return treefy_merge(graphdef, state)
1443
-
1444
-
1445
- @set_module_as('brainstate.graph')
1446
- def call(
1447
- graphdef_state: Tuple[GraphDef[A], GraphStateMapping],
1448
- ) -> ApplyCaller[Tuple[GraphDef[A], GraphStateMapping]]:
1449
- """Calls a method underlying graph node defined by a (GraphDef, NestedDict) pair.
1450
-
1451
- ``call`` takes a ``(GraphDef, NestedDict)`` pair and creates a proxy object that can be
1452
- used to call methods on the underlying graph node. When a method is called, the
1453
- output is returned along with a new (GraphDef, NestedDict) pair that represents the
1454
- updated state of the graph node. ``call`` is equivalent to :func:`merge` > ``method``
1455
- > :func:`split`` but is more convenient to use in pure JAX functions.
1456
-
1457
- Example::
1458
-
1459
- >>> import brainstate as brainstate
1460
- >>> import jax
1461
- >>> import jax.numpy as jnp
1462
- ...
1463
- >>> class StatefulLinear(brainstate.graph.Node):
1464
- ... def __init__(self, din, dout):
1465
- ... self.w = brainstate.ParamState(brainstate.random.rand(din, dout))
1466
- ... self.b = brainstate.ParamState(jnp.zeros((dout,)))
1467
- ... self.count = brainstate.State(jnp.array(0, dtype=jnp.uint32))
1468
- ...
1469
- ... def increment(self):
1470
- ... self.count.value += 1
1471
- ...
1472
- ... def __call__(self, x):
1473
- ... self.increment()
1474
- ... return x @ self.w.value + self.b.value
1475
- ...
1476
- >>> linear = StatefulLinear(3, 2)
1477
- >>> linear_state = brainstate.graph.treefy_split(linear)
1478
- ...
1479
- >>> @jax.jit
1480
- ... def forward(x, linear_state):
1481
- ... y, linear_state = brainstate.graph.call(linear_state)(x)
1482
- ... return y, linear_state
1483
- ...
1484
- >>> x = jnp.ones((1, 3))
1485
- >>> y, linear_state = forward(x, linear_state)
1486
- >>> y, linear_state = forward(x, linear_state)
1487
- ...
1488
- >>> linear = brainstate.graph.treefy_merge(*linear_state)
1489
- >>> linear.count.value
1490
- Array(2, dtype=uint32)
1491
-
1492
- The proxy object returned by ``call`` supports indexing and attribute access
1493
- to access nested methods. In the example below, the ``increment`` method indexing
1494
- is used to call the ``increment`` method of the ``StatefulLinear`` module
1495
- at the ``b`` key of a ``nodes`` dictionary.
1496
-
1497
- >>> class StatefulLinear(brainstate.graph.Node):
1498
- ... def __init__(self, din, dout):
1499
- ... self.w = brainstate.ParamState(brainstate.random.rand(din, dout))
1500
- ... self.b = brainstate.ParamState(jnp.zeros((dout,)))
1501
- ... self.count = brainstate.State(jnp.array(0, dtype=jnp.uint32))
1502
- ...
1503
- ... def increment(self):
1504
- ... self.count.value += 1
1505
- ...
1506
- ... def __call__(self, x):
1507
- ... self.increment()
1508
- ... return x @ self.w.value + self.b.value
1509
- ...
1510
- >>> nodes = dict(
1511
- ... a=StatefulLinear(3, 2),
1512
- ... b=StatefulLinear(2, 1),
1513
- ... )
1514
- ...
1515
- >>> node_state = treefy_split(nodes)
1516
- >>> # use attribute access
1517
- >>> _, node_state = brainstate.graph.call(node_state)['b'].increment()
1518
- ...
1519
- >>> nodes = treefy_merge(*node_state)
1520
- >>> nodes['a'].count.value
1521
- Array(0, dtype=uint32)
1522
- >>> nodes['b'].count.value
1523
- Array(1, dtype=uint32)
1524
- """
1525
-
1526
- def pure_caller(accessor: DelayedAccessor, *args, **kwargs):
1527
- node = treefy_merge(*graphdef_state)
1528
- method = accessor(node)
1529
- out = method(*args, **kwargs)
1530
- return out, treefy_split(node)
1531
-
1532
- return CallableProxy(pure_caller) # type: ignore
1533
-
1534
-
1535
- @set_module_as('brainstate.graph')
1536
- def iter_leaf(
1537
- node: Any,
1538
- allowed_hierarchy: Tuple[int, int] = (0, _max_int)
1539
- ) -> Iterator[tuple[PathParts, Any]]:
1540
- """Iterates over all nested leaves in the given graph node, including the current node.
1541
-
1542
- ``iter_graph`` creates a generator that yields path and value pairs, where
1543
- the path is a tuple of strings or integers representing the path to the value from the
1544
- root. Repeated nodes are visited only once. Leaves include static values.
1545
-
1546
- Example::
1547
- >>> import brainstate as brainstate
1548
- >>> import jax.numpy as jnp
1549
- ...
1550
- >>> class Linear(brainstate.nn.Module):
1551
- ... def __init__(self, din, dout):
1552
- ... super().__init__()
1553
- ... self.weight = brainstate.ParamState(brainstate.random.randn(din, dout))
1554
- ... self.bias = brainstate.ParamState(brainstate.random.randn(dout))
1555
- ... self.a = 1
1556
- ...
1557
- >>> module = Linear(3, 4)
1558
- ...
1559
- >>> for path, value in brainstate.graph.iter_leaf([module, module]):
1560
- ... print(path, type(value).__name__)
1561
- ...
1562
- (0, 'a') int
1563
- (0, 'bias') ParamState
1564
- (0, 'weight') ParamState
1565
-
1566
- Parameters
1567
- ----------
1568
- node: Node
1569
- The node to iterate over.
1570
- allowed_hierarchy: tuple of int
1571
- The allowed hierarchy.
1572
-
1573
- """
1574
-
1575
- def _iter_graph_leaf(
1576
- node_: Any,
1577
- visited_: set[int],
1578
- path_parts_: PathParts,
1579
- level_: int,
1580
- ) -> Iterator[tuple[PathParts, Any]]:
1581
- if level_ > allowed_hierarchy[1]:
1582
- return
1583
-
1584
- if _is_node(node_):
1585
- if id(node_) in visited_:
1586
- return
1587
- visited_.add(id(node_))
1588
- node_dict = _get_node_impl(node_).node_dict(node_)
1589
- for key, value in node_dict.items():
1590
- yield from _iter_graph_leaf(
1591
- value,
1592
- visited_,
1593
- (*path_parts_, key),
1594
- level_ + 1 if _is_graph_node(value) else level_
1595
- )
1596
- else:
1597
- if level_ >= allowed_hierarchy[0]:
1598
- yield path_parts_, node_
1599
-
1600
- visited: set[int] = set()
1601
- path_parts: PathParts = ()
1602
- level: int = 0
1603
- yield from _iter_graph_leaf(node, visited, path_parts, level)
1604
-
1605
-
1606
- @set_module_as('brainstate.graph')
1607
- def iter_node(
1608
- node: Any,
1609
- allowed_hierarchy: Tuple[int, int] = (0, _max_int)
1610
- ) -> Iterator[Tuple[PathParts, Any]]:
1611
- """
1612
- Iterates over all nested nodes of the given graph node, including the current node.
1613
-
1614
- ``iter_graph`` creates a generator that yields path and value pairs, where
1615
- the path is a tuple of strings or integers representing the path to the value from the
1616
- root. Repeated nodes are visited only once. Leaves include static values.
1617
-
1618
- Example::
1619
- >>> import brainstate as brainstate
1620
- >>> import jax.numpy as jnp
1621
- ...
1622
- >>> class Model(brainstate.nn.Module):
1623
- ... def __init__(self):
1624
- ... super().__init__()
1625
- ... self.a = brainstate.nn.Linear(1, 2)
1626
- ... self.b = brainstate.nn.Linear(2, 3)
1627
- ... self.c = [brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5)]
1628
- ... self.d = {'x': brainstate.nn.Linear(5, 6), 'y': brainstate.nn.Linear(6, 7)}
1629
- ... self.b.a = brainstate.nn.LIF(2)
1630
- ...
1631
- >>> model = Model()
1632
- ...
1633
- >>> for path, node in brainstate.graph.iter_node([model, model]):
1634
- ... print(path, node.__class__.__name__)
1635
- ...
1636
- (0, 'a') Linear
1637
- (0, 'b', 'a') LIF
1638
- (0, 'b') Linear
1639
- (0, 'c', 0) Linear
1640
- (0, 'c', 1) Linear
1641
- (0, 'd', 'x') Linear
1642
- (0, 'd', 'y') Linear
1643
- (0,) Model
1644
-
1645
- Parameters
1646
- ----------
1647
- node: Node
1648
- The node to iterate over.
1649
- allowed_hierarchy: tuple of int
1650
- The allowed hierarchy.
1651
-
1652
- """
1653
-
1654
- def _iter_graph_node(
1655
- node_: Any,
1656
- visited_: set[int],
1657
- path_parts_: PathParts,
1658
- level_: int,
1659
- ) -> Iterator[tuple[PathParts, Any]]:
1660
- if level_ > allowed_hierarchy[1]:
1661
- return
1662
-
1663
- if _is_node(node_):
1664
- if id(node_) in visited_:
1665
- return
1666
-
1667
- visited_.add(id(node_))
1668
- node_dict = _get_node_impl(node_).node_dict(node_)
1669
- for key, value in node_dict.items():
1670
- yield from _iter_graph_node(value, visited_, (*path_parts_, key),
1671
- level_ + 1 if _is_graph_node(value) else level_)
1672
-
1673
- if _is_graph_node(node_) and level_ >= allowed_hierarchy[0]:
1674
- yield path_parts_, node_
1675
-
1676
- visited: set[int] = set()
1677
- path_parts: PathParts = ()
1678
- level: int = 0
1679
- yield from _iter_graph_node(node, visited, path_parts, level)
1680
-
1681
-
1682
- # --------------------------------------------------------
1683
- # Graph operations: end
1684
- # --------------------------------------------------------
1685
-
1686
-
1687
- @dataclasses.dataclass(frozen=True)
1688
- class Static(Generic[A]):
1689
- """An empty pytree node that treats its inner value as static.
1690
- ``value`` must define ``__eq__`` and ``__hash__``.
1691
- """
1692
-
1693
- value: A
1694
-
1695
-
1696
- jax.tree_util.register_static(Static)
1697
-
1698
-
1699
- # ---------------------------------------------------------
1700
- # Pytree
1701
- # ---------------------------------------------------------
1702
-
1703
- class PytreeType:
1704
- ...
1705
-
1706
-
1707
- def _key_path_to_key(key: Any) -> Key:
1708
- if isinstance(key, jax.tree_util.SequenceKey):
1709
- return key.idx
1710
- elif isinstance(
1711
- key, (jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey)
1712
- ):
1713
- if not isinstance(key.key, Key):
1714
- raise ValueError(
1715
- f'Invalid key: {key.key}. May be due to its type not being hashable or comparable.'
1716
- )
1717
- return key.key
1718
- elif isinstance(key, jax.tree_util.GetAttrKey):
1719
- return key.name
1720
- else:
1721
- return str(key)
1722
-
1723
-
1724
- def _flatten_pytree(pytree: Any):
1725
- leaves, treedef = jax.tree_util.tree_flatten_with_path(pytree, is_leaf=lambda x: x is not pytree)
1726
- nodes = tuple((_key_path_to_key(path[0]), value) for path, value in leaves)
1727
- return nodes, treedef
1728
-
1729
-
1730
- def _unflatten_pytree(
1731
- nodes: tuple[tuple[Key, Any], ...],
1732
- treedef: jax.tree_util.PyTreeDef
1733
- ):
1734
- pytree = treedef.unflatten(value for _, value in nodes)
1735
- return pytree
1736
-
1737
-
1738
- PYTREE_NODE_IMPL = PyTreeNodeImpl(type=PytreeType, flatten=_flatten_pytree, unflatten=_unflatten_pytree)
1
+ # The file is adapted from the Flax library (https://github.com/google/flax).
2
+ # The credit should go to the Flax authors.
3
+ #
4
+ # Copyright 2024 The Flax Authors.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ from __future__ import annotations
19
+
20
+ import dataclasses
21
+ from typing import (
22
+ Any, Callable, Generic, Iterable, Iterator, Mapping, MutableMapping,
23
+ Sequence, Type, TypeVar, Union, Hashable, Tuple, Dict, Optional, overload
24
+ )
25
+
26
+ import jax
27
+ import numpy as np
28
+ from typing_extensions import TypeGuard, Unpack
29
+
30
+ from brainstate._state import State, TreefyState
31
+ from brainstate._utils import set_module_as
32
+ from brainstate.typing import PathParts, Filter, Predicate, Key
33
+ from brainstate.util.caller import ApplyCaller, CallableProxy, DelayedAccessor
34
+ from brainstate.util.pretty_pytree import NestedDict, FlattedDict, PrettyDict
35
+ from brainstate.util.pretty_repr import PrettyRepr, PrettyType, PrettyAttr, PrettyMapping, MappingReprMixin
36
+ from brainstate.util.struct import FrozenDict
37
+ from brainstate.util.filter import to_predicate
38
+
39
+ _max_int = np.iinfo(np.int32).max
40
+
41
+ __all__ = [
42
+ # state management in the given graph or node
43
+ 'pop_states', 'nodes', 'states', 'treefy_states', 'update_states',
44
+
45
+ # graph node operations
46
+ 'flatten', 'unflatten', 'treefy_split', 'treefy_merge', 'iter_leaf', 'iter_node', 'clone', 'graphdef', 'call',
47
+
48
+ # others
49
+ 'RefMap', 'GraphDef', 'NodeRef', 'NodeDef'
50
+ ]
51
+
52
+ A = TypeVar('A')
53
+ B = TypeVar('B')
54
+ C = TypeVar('C')
55
+ F = TypeVar('F', bound=Callable)
56
+
57
+ HA = TypeVar('HA', bound=Hashable)
58
+ HB = TypeVar('HB', bound=Hashable)
59
+
60
+ Index = int
61
+ Names = Sequence[int]
62
+ Node = TypeVar('Node')
63
+ Leaf = TypeVar('Leaf')
64
+ AuxData = TypeVar('AuxData')
65
+
66
+ StateLeaf = TreefyState[Any]
67
+ NodeLeaf = State[Any]
68
+ GraphStateMapping = NestedDict[Key, StateLeaf]
69
+
70
+
71
+ # --------------------------------------------------------
72
+
73
+
74
+ def _is_state_leaf(x: Any) -> TypeGuard[StateLeaf]:
75
+ return isinstance(x, TreefyState)
76
+
77
+
78
+ def _is_node_leaf(x: Any) -> TypeGuard[NodeLeaf]:
79
+ return isinstance(x, State)
80
+
81
+
82
+ class RefMap(MutableMapping[A, B], MappingReprMixin[A, B]):
83
+ """
84
+ A mapping that uses object id as the hash for the keys.
85
+
86
+ This mapping is useful when we want to keep track of objects
87
+ that are being referenced by other objects.
88
+
89
+ Args:
90
+ mapping: A mapping or iterable of key-value pairs.
91
+
92
+ """
93
+ __module__ = 'brainstate.graph'
94
+
95
+ def __init__(self, mapping: Mapping[A, B] | Iterable[Tuple[A, B]] = ()):
96
+ self._mapping: Dict[int, Tuple[A, B]] = {}
97
+ self.update(mapping)
98
+
99
+ def __getitem__(self, key: A) -> B:
100
+ return self._mapping[id(key)][1]
101
+
102
+ def __contains__(self, key: Any) -> bool:
103
+ return id(key) in self._mapping
104
+
105
+ def __setitem__(self, key: A, value: B):
106
+ self._mapping[id(key)] = (key, value)
107
+
108
+ def __delitem__(self, key: A):
109
+ del self._mapping[id(key)]
110
+
111
+ def __iter__(self) -> Iterator[A]:
112
+ return (key for key, _ in self._mapping.values())
113
+
114
+ def __len__(self) -> int:
115
+ return len(self._mapping)
116
+
117
+ def __str__(self) -> str:
118
+ return repr(self)
119
+
120
+
121
+ @dataclasses.dataclass(frozen=True)
122
+ class NodeImplBase(Generic[Node, Leaf, AuxData]):
123
+ type: type
124
+ flatten: Callable[[Node], tuple[Sequence[tuple[Key, Leaf]], AuxData]]
125
+
126
+ def node_dict(self, node: Node) -> dict[Key, Leaf]:
127
+ nodes, _ = self.flatten(node)
128
+ return dict(nodes)
129
+
130
+
131
+ @dataclasses.dataclass(frozen=True)
132
+ class GraphNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
133
+ set_key: Callable[[Node, Key, Leaf], None]
134
+ pop_key: Callable[[Node, Key], Leaf]
135
+ create_empty: Callable[[AuxData], Node]
136
+ clear: Callable[[Node], None]
137
+
138
+ def init(self, node: Node, items: Tuple[Tuple[Key, Leaf], ...]):
139
+ for key, value in items:
140
+ self.set_key(node, key, value)
141
+
142
+
143
+ @dataclasses.dataclass(frozen=True)
144
+ class PyTreeNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
145
+ unflatten: Callable[[tuple[tuple[Key, Leaf], ...], AuxData], Node]
146
+
147
+
148
+ NodeImpl = Union[GraphNodeImpl[Node, Leaf, AuxData], PyTreeNodeImpl[Node, Leaf, AuxData]]
149
+
150
+ # --------------------------------------------------------
151
+ # Graph Node implementation: start
152
+ # --------------------------------------------------------
153
+
154
+ _node_impl_for_type: dict[type, NodeImpl[Any, Any, Any]] = {}
155
+
156
+
157
+ def register_graph_node_type(
158
+ type: type,
159
+ flatten: Callable[[Node], tuple[Sequence[tuple[Key, Leaf]], AuxData]],
160
+ set_key: Callable[[Node, Key, Leaf], None],
161
+ pop_key: Callable[[Node, Key], Leaf],
162
+ create_empty: Callable[[AuxData], Node],
163
+ clear: Callable[[Node], None],
164
+ ):
165
+ """
166
+ Register a graph node type.
167
+
168
+ Args:
169
+ type: The type of the node.
170
+ flatten: A function that flattens the node into a sequence of key-value pairs.
171
+ set_key: A function that sets a key in the node.
172
+ pop_key: A function that pops a key from the node.
173
+ create_empty: A function that creates an empty node.
174
+ clear: A function that clears the node
175
+ """
176
+ _node_impl_for_type[type] = GraphNodeImpl(
177
+ type=type,
178
+ flatten=flatten,
179
+ set_key=set_key,
180
+ pop_key=pop_key,
181
+ create_empty=create_empty,
182
+ clear=clear,
183
+ )
184
+
185
+
186
+ # --------------------------------------------------------
187
+ # Graph node implementation: end
188
+ # --------------------------------------------------------
189
+
190
+
191
+ def _is_node(x: Any) -> bool:
192
+ return _is_graph_node(x) or _is_pytree_node(x)
193
+
194
+
195
+ def _is_pytree_node(x: Any) -> bool:
196
+ return not jax.tree_util.all_leaves((x,))
197
+
198
+
199
+ def _is_graph_node(x: Any) -> bool:
200
+ return type(x) in _node_impl_for_type
201
+
202
+
203
+ def _is_node_type(x: type[Any]) -> bool:
204
+ return x in _node_impl_for_type or x is PytreeType
205
+
206
+
207
+ def _get_node_impl(x: Node) -> NodeImpl[Node, Any, Any]:
208
+ if isinstance(x, State):
209
+ raise ValueError(f'State is not a node: {x}')
210
+
211
+ node_type = type(x)
212
+ if node_type not in _node_impl_for_type:
213
+ if _is_pytree_node(x):
214
+ return PYTREE_NODE_IMPL
215
+ else:
216
+ raise ValueError(f'Unknown node type: {x}')
217
+
218
+ return _node_impl_for_type[node_type]
219
+
220
+
221
+ def get_node_impl_for_type(x: type[Node]) -> NodeImpl[Node, Any, Any]:
222
+ if x is PytreeType:
223
+ return PYTREE_NODE_IMPL
224
+ return _node_impl_for_type[x]
225
+
226
+
227
+ class HashableMapping(Mapping[HA, HB], Hashable):
228
+ def __init__(self, mapping: Mapping[HA, HB] | Iterable[tuple[HA, HB]]):
229
+ self._mapping = dict(mapping)
230
+
231
+ def __contains__(self, key: object) -> bool:
232
+ return key in self._mapping
233
+
234
+ def __getitem__(self, key: HA) -> HB:
235
+ return self._mapping[key]
236
+
237
+ def __iter__(self) -> Iterator[HA]:
238
+ return iter(self._mapping)
239
+
240
+ def __len__(self) -> int:
241
+ return len(self._mapping)
242
+
243
+ def __hash__(self) -> int:
244
+ return hash(tuple(sorted(self._mapping.items())))
245
+
246
+ def __eq__(self, other: Any) -> bool:
247
+ return isinstance(other, HashableMapping) and self._mapping == other._mapping
248
+
249
+ def __repr__(self) -> str:
250
+ return repr(self._mapping)
251
+
252
+
253
+ class GraphDef(Generic[Node]):
254
+ """
255
+ A base dataclass that denotes the graph structure of a :class:`Node`.
256
+
257
+ It contains two main components:
258
+ - type: The type of the node.
259
+ - index: The index of the node in the graph.
260
+
261
+ It has two concrete subclasses:
262
+ - :class:`NodeRef`: A reference to a node in the graph.
263
+ - :class:`NodeDef`: A dataclass that denotes the graph structure of a :class:`Node` or a :class:`State`.
264
+
265
+ """
266
+ type: type[Node]
267
+ index: int
268
+
269
+
270
+ @dataclasses.dataclass(frozen=True, repr=False)
271
+ class NodeRef(GraphDef[Node], PrettyRepr):
272
+ """
273
+ A reference to a node in the graph.
274
+
275
+ The node can be instances of :class:`Node` or :class:`State`.
276
+ """
277
+ type: type[Node]
278
+ index: int
279
+
280
+ def __pretty_repr__(self):
281
+ yield PrettyType(type=type(self))
282
+ yield PrettyAttr('type', self.type.__name__)
283
+ yield PrettyAttr('index', self.index)
284
+
285
+ def __treescope_repr__(self, path, subtree_renderer):
286
+ """
287
+ Treescope repr for the object.
288
+ """
289
+ import treescope # type: ignore[import-not-found,import-untyped]
290
+ return treescope.repr_lib.render_object_constructor(
291
+ object_type=type(self),
292
+ attributes={'type': self.type, 'index': self.index},
293
+ path=path,
294
+ subtree_renderer=subtree_renderer,
295
+ )
296
+
297
+
298
+ jax.tree_util.register_static(NodeRef)
299
+
300
+
301
+ @dataclasses.dataclass(frozen=True, repr=False)
302
+ class NodeDef(GraphDef[Node], PrettyRepr):
303
+ """
304
+ A dataclass that denotes the tree structure of a node, either :class:`Node` or :class:`State`.
305
+
306
+ """
307
+
308
+ type: Type[Node] # type of the node
309
+ index: int # index of the node in the graph
310
+ attributes: Tuple[Key, ...] # attributes for the node
311
+ subgraphs: HashableMapping[Key, NodeDef[Any] | NodeRef[Any]]
312
+ static_fields: HashableMapping[Key, Any]
313
+ leaves: HashableMapping[Key, NodeRef[Any] | None]
314
+ metadata: Hashable
315
+ index_mapping: FrozenDict[Index, Index] | None
316
+
317
+ @classmethod
318
+ def create(
319
+ cls,
320
+ type: Type[Node],
321
+ index: int,
322
+ attributes: tuple[Key, ...],
323
+ subgraphs: Iterable[tuple[Key, NodeDef[Any] | NodeRef[Any]]],
324
+ static_fields: Iterable[tuple[Key, Any]],
325
+ leaves: Iterable[tuple[Key, NodeRef[Any] | None]],
326
+ metadata: Hashable,
327
+ index_mapping: Mapping[Index, Index] | None,
328
+ ):
329
+ return cls(
330
+ type=type,
331
+ index=index,
332
+ attributes=attributes,
333
+ subgraphs=HashableMapping(subgraphs),
334
+ static_fields=HashableMapping(static_fields),
335
+ leaves=HashableMapping(leaves),
336
+ metadata=metadata,
337
+ index_mapping=FrozenDict(index_mapping) if index_mapping is not None else None,
338
+ )
339
+
340
+ def __pretty_repr__(self):
341
+ yield PrettyType(type=type(self))
342
+
343
+ yield PrettyAttr('type', self.type.__name__)
344
+ yield PrettyAttr('index', self.index)
345
+ yield PrettyAttr('attributes', self.attributes)
346
+ yield PrettyAttr('subgraphs', PrettyMapping(self.subgraphs))
347
+ yield PrettyAttr('static_fields', PrettyMapping(self.static_fields))
348
+ yield PrettyAttr('leaves', PrettyMapping(self.leaves))
349
+ yield PrettyAttr('metadata', self.metadata)
350
+ yield PrettyAttr('index_mapping', PrettyMapping(self.index_mapping) if self.index_mapping is not None else None)
351
+
352
+ def apply(
353
+ self,
354
+ state_map: GraphStateMapping,
355
+ *state_maps: GraphStateMapping
356
+ ) -> ApplyCaller[tuple[GraphDef[Node], GraphStateMapping]]:
357
+ accessor = DelayedAccessor()
358
+
359
+ def _apply(accessor: DelayedAccessor, *args, **kwargs) -> tuple[
360
+ Any, tuple[GraphDef[Node], GraphStateMapping]]:
361
+ module = treefy_merge(self, state_map, *state_maps)
362
+ fn = accessor(module)
363
+ out = fn(*args, **kwargs)
364
+ return out, flatten(module)
365
+
366
+ return CallableProxy(_apply, accessor) # type: ignore
367
+
368
+
369
+ jax.tree_util.register_static(NodeDef)
370
+
371
+
372
+ # --------------------------------------------------------
373
+ # Graph operations: start
374
+ # --------------------------------------------------------
375
+
376
+
377
+ def _graph_flatten(
378
+ path: PathParts,
379
+ ref_index: RefMap[Any, Index],
380
+ flatted_state_mapping: Dict[PathParts, StateLeaf],
381
+ node: Node,
382
+ treefy_state: bool = False,
383
+ ):
384
+ """
385
+ Recursive helper for graph flatten.
386
+
387
+ Args:
388
+ path: The path to the node.
389
+ ref_index: A mapping from nodes to indexes.
390
+ flatted_state_mapping: A mapping from paths to state leaves.
391
+ node: The node to flatten.
392
+
393
+ Returns:
394
+ A NodeDef or a NodeRef.
395
+ """
396
+ if not _is_node(node):
397
+ raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')
398
+
399
+ # If the node is already in the cache, return a reference, otherwise
400
+ # add it to the cache and continue with the flattening process.
401
+ # This is done to avoid infinite recursion when there is a reference cycle.
402
+ if node in ref_index:
403
+ return NodeRef(type(node), ref_index[node])
404
+
405
+ # Get the node implementation for the node type.
406
+ # There are two types of node implementations: GraphNodeImpl and PyTreeNodeImpl.
407
+ # - ``GraphNodeImpl`` is used for nodes that have a graph structure.
408
+ # - ``PyTreeNodeImpl`` is used for nodes that have a tree structure.
409
+ node_impl = _get_node_impl(node)
410
+
411
+ # There are two types of nodes: Node and State.
412
+ # Here we handle the Node case.
413
+ if isinstance(node_impl, GraphNodeImpl):
414
+ # add the node to the cache
415
+ index = len(ref_index)
416
+ ref_index[node] = index
417
+ else:
418
+ index = -1
419
+
420
+ subgraphs: list[tuple[Key, NodeDef[Node] | NodeRef]] = []
421
+ static_fields: list[tuple[Key, Any]] = []
422
+ leaves: list[tuple[Key, NodeRef | None]] = []
423
+
424
+ # Flatten the node into a sequence of key-value pairs.
425
+ values, metadata = node_impl.flatten(node)
426
+ for key, value in values:
427
+ if _is_node(value):
428
+ # Recursively flatten the subgraph.
429
+ nodedef = _graph_flatten((*path, key), ref_index, flatted_state_mapping, value, treefy_state)
430
+ subgraphs.append((key, nodedef))
431
+ elif isinstance(value, State):
432
+ # If the variable is in the cache, add a reference to it.
433
+ if value in ref_index:
434
+ leaves.append((key, NodeRef(type(value), ref_index[value])))
435
+ else:
436
+ # If the variable is not in the cache, add it to the cache.
437
+ # This is done to avoid multiple references to the same variable.
438
+ flatted_state_mapping[(*path, key)] = (value.to_state_ref() if treefy_state else value)
439
+ variable_index = ref_index[value] = len(ref_index)
440
+ leaves.append((key, NodeRef(type(value), variable_index)))
441
+ elif _is_state_leaf(value):
442
+ # The instance of ``TreefyState`` is a leaf.
443
+ flatted_state_mapping[(*path, key)] = value
444
+ leaves.append((key, None))
445
+ else:
446
+ # if isinstance(value, (jax.Array, np.ndarray)):
447
+ # path_str = '/'.join(map(str, (*path, key)))
448
+ # raise ValueError(f'Arrays leaves are not supported, at {path_str!r}: {value}')
449
+
450
+ # The value is a static field.
451
+ static_fields.append((key, value))
452
+
453
+ nodedef = NodeDef.create(type=node_impl.type,
454
+ index=index,
455
+ attributes=tuple(key for key, _ in values),
456
+ subgraphs=subgraphs,
457
+ static_fields=static_fields,
458
+ leaves=leaves,
459
+ metadata=metadata,
460
+ index_mapping=None, )
461
+ return nodedef
462
+
463
+
464
+ @set_module_as('brainstate.graph')
465
+ def flatten(
466
+ node: Node,
467
+ /,
468
+ ref_index: Optional[RefMap[Any, Index]] = None,
469
+ treefy_state: bool = True,
470
+ ) -> Tuple[GraphDef, NestedDict]:
471
+ """
472
+ Flattens a graph node into a (graph_def, state_mapping) pair.
473
+
474
+ Example::
475
+
476
+ >>> import brainstate as brainstate
477
+ >>> node = brainstate.graph.Node()
478
+ >>> graph_def, state_mapping = flatten(node)
479
+ >>> print(graph_def)
480
+ >>> print(state_mapping)
481
+
482
+ Args:
483
+ node: A graph node.
484
+ ref_index: A mapping from nodes to indexes, defaults to None. If not provided, a new
485
+ empty dictionary is created. This argument can be used to flatten a sequence of graph
486
+ nodes that share references.
487
+ treefy_state: If True, the state mapping will be a NestedDict instead of a flat dictionary.
488
+ """
489
+ ref_index = RefMap() if ref_index is None else ref_index
490
+ assert isinstance(ref_index, RefMap), f"ref_index must be a RefMap. But we got: {ref_index}"
491
+ flatted_state_mapping: dict[PathParts, StateLeaf] = {}
492
+ graph_def = _graph_flatten((), ref_index, flatted_state_mapping, node, treefy_state)
493
+ return graph_def, NestedDict.from_flat(flatted_state_mapping)
494
+
495
+
496
+ def _get_children(graph_def, state_mapping, index_ref, index_ref_cache):
497
+ children: dict[Key, StateLeaf | Node] = {}
498
+
499
+ # NOTE: we could allow adding new StateLeafs here
500
+ # All state keys must be present in the graph definition (the object attributes)
501
+ if unknown_keys := set(state_mapping) - set(graph_def.attributes):
502
+ raise ValueError(f'Unknown keys: {unknown_keys}')
503
+
504
+ # for every key in attributes there are 6 possible cases:
505
+ # - (2) the key can either be present in the state or not
506
+ # - (3) the key can be a subgraph, a leaf, or a static attribute
507
+ for key in graph_def.attributes:
508
+ if key not in state_mapping: # static field
509
+ # TODO(cgarcia): maybe we shouldn't support unflattening with missing keys?
510
+ # if key is not present, create an empty types
511
+ if key in graph_def.static_fields:
512
+ children[key] = graph_def.static_fields[key]
513
+
514
+ elif key in graph_def.subgraphs:
515
+ # if the key is a subgraph we create an empty node
516
+ subgraphdef = graph_def.subgraphs[key]
517
+ if isinstance(subgraphdef, NodeRef):
518
+ # subgraph exists, take it from the cache
519
+ children[key] = index_ref[subgraphdef.index]
520
+
521
+ else:
522
+ # create a node from an empty state, reasoning:
523
+ # * it is a node with no state
524
+ # * it is a node with state but only through references of already
525
+ # created nodes
526
+ substate = {}
527
+ children[key] = _graph_unflatten(subgraphdef, substate, index_ref, index_ref_cache)
528
+
529
+ elif key in graph_def.leaves:
530
+ noderef = graph_def.leaves[key]
531
+ if (noderef is not None) and (noderef.index in index_ref):
532
+ # variable exists, take it from the cache
533
+ children[key] = index_ref[noderef.index]
534
+
535
+ else:
536
+ # key for a variable is missing, raise an error
537
+ raise ValueError(f'Expected key {key!r} in state while building node of type '
538
+ f'{graph_def.type.__name__}.')
539
+
540
+ else:
541
+ raise RuntimeError(f'Unknown static field: {key!r}')
542
+
543
+ else: # state field
544
+ value = state_mapping[key]
545
+ if isinstance(value, PrettyDict):
546
+ value = dict(value)
547
+
548
+ if key in graph_def.static_fields:
549
+ raise ValueError(f'Got state for static field {key!r}, this is not supported.')
550
+
551
+ if key in graph_def.subgraphs:
552
+ # if _is_state_leaf(value):
553
+ if isinstance(value, (TreefyState, State)):
554
+ raise ValueError(f'Expected value of type {graph_def.subgraphs[key]} '
555
+ f'for {key!r}, but got {value!r}')
556
+ if not isinstance(value, dict):
557
+ raise TypeError(f'Expected a dict for {key!r}, but got {type(value)}.')
558
+
559
+ subgraphdef = graph_def.subgraphs[key]
560
+ if isinstance(subgraphdef, NodeRef):
561
+ children[key] = index_ref[subgraphdef.index]
562
+ else:
563
+ children[key] = _graph_unflatten(subgraphdef, value, index_ref, index_ref_cache)
564
+
565
+ elif key in graph_def.leaves:
566
+ # if not _is_state_leaf(value):
567
+ if not isinstance(value, (TreefyState, State)):
568
+ raise ValueError(f'Expected a leaf for {key!r}, but got {value!r}')
569
+
570
+ noderef = graph_def.leaves[key]
571
+ if noderef is None:
572
+ # if the leaf is None, it means that the value was originally
573
+ # a non-TreefyState leaf, however we allow providing a
574
+ # TreefyState presumbly created by modifying the NestedDict
575
+ if isinstance(value, TreefyState):
576
+ value = value.to_state()
577
+ # elif isinstance(value, State):
578
+ # value = value
579
+ children[key] = value
580
+
581
+ elif noderef.index in index_ref:
582
+ # add an existing variable
583
+ children[key] = index_ref[noderef.index]
584
+
585
+ else:
586
+ # it is an unseen variable, create a new one
587
+ if not isinstance(value, (TreefyState, State)):
588
+ raise ValueError(f'Expected a State type for {key!r}, but got {type(value)}.')
589
+ # when idxmap is present, check if the Varable exists there
590
+ # and update existing variables if it does
591
+ if index_ref_cache is not None and noderef.index in index_ref_cache:
592
+ variable = index_ref_cache[noderef.index]
593
+ if not isinstance(variable, State):
594
+ raise ValueError(f'Expected a State type for {key!r}, but got {type(variable)}.')
595
+ if isinstance(value, TreefyState):
596
+ variable.update_from_ref(value)
597
+ elif isinstance(value, State):
598
+ if value._been_writen:
599
+ variable.value = value.value
600
+ else:
601
+ variable.restore_value(value.value)
602
+ else:
603
+ raise ValueError(f'Expected a State type for {key!r}, but got {type(value)}.')
604
+ else: # if it doesn't, create a new variable
605
+ if isinstance(value, TreefyState):
606
+ variable = value.to_state()
607
+ elif isinstance(value, State):
608
+ variable = value
609
+ else:
610
+ raise ValueError(f'Expected a State type for {key!r}, but got {type(value)}.')
611
+ children[key] = variable
612
+ index_ref[noderef.index] = variable
613
+
614
+ else:
615
+ raise RuntimeError(f'Unknown key: {key!r}, this is a bug.')
616
+
617
+ return children
618
+
619
+
620
+ def _graph_unflatten(
621
+ graph_def: NodeDef[Node] | NodeRef[Node],
622
+ state_mapping: Mapping[Key, StateLeaf | Mapping[Key, Any]],
623
+ index_ref: dict[Index, Any],
624
+ index_ref_cache: dict[Index, Any] | None,
625
+ ) -> Node:
626
+ """
627
+ Recursive helper for graph unflatten.
628
+
629
+ Args:
630
+ graph_def: A `GraphDef` instance or an index to a node in the cache.
631
+ state_mapping: A state mapping from attribute names to variables or subgraphs.
632
+ index_ref: A mapping from indexes to nodes that have been traversed.
633
+ If a node is already in the cache, it won't be traversed again.
634
+ index_ref_cache: A mapping from indexes to existing nodes that can be reused.
635
+ When an reference is reused, ``GraphNodeImpl.clear`` is called to leave the
636
+ object in an empty state and then filled by the unflatten process, as a result
637
+ existing graph nodes are mutated to have the new content/topology
638
+ specified by the nodedef.
639
+
640
+ Returns:
641
+ A node instance.
642
+ """
643
+
644
+ # if the graph_def is a reference, this means that the node has already been created, so
645
+ # we return the node from the cache
646
+ if isinstance(graph_def, NodeRef):
647
+ return index_ref[graph_def.index]
648
+ else:
649
+ assert isinstance(graph_def, NodeDef), f"graph_def must be a NodeDef. But we got: {graph_def}"
650
+
651
+ # graph_def must be a registered node type
652
+ if not _is_node_type(graph_def.type):
653
+ raise RuntimeError(f'Unsupported type: {graph_def.type}, this is a bug.')
654
+
655
+ # check if the index is already in the cache
656
+ if graph_def.index in index_ref:
657
+ raise RuntimeError(f'GraphDef index {graph_def.index} already used.')
658
+
659
+ # get the node implementation for the node type
660
+ node_impl = get_node_impl_for_type(graph_def.type)
661
+
662
+ if isinstance(node_impl, GraphNodeImpl):
663
+ # we create an empty node first and add it to the index
664
+ # this avoids infinite recursion when there is a reference cycle
665
+
666
+ if (index_ref_cache is not None) and (graph_def.index in index_ref_cache):
667
+ # clear the node to leave it in an empty state
668
+ node = index_ref_cache[graph_def.index]
669
+ if type(node) != graph_def.type:
670
+ raise ValueError(f'Expected a node of type {graph_def.type} for index '
671
+ f'{graph_def.index}, but got a node of type {type(node)}.')
672
+ node_impl.clear(node)
673
+ else:
674
+ # create an empty node
675
+ node = node_impl.create_empty(graph_def.metadata)
676
+
677
+ # add the node to the cache
678
+ index_ref[graph_def.index] = node
679
+
680
+ # get the children (the attributes) of the node
681
+ children = _get_children(graph_def, state_mapping, index_ref, index_ref_cache)
682
+
683
+ # initialize the node with the children
684
+ node_impl.init(node, tuple(children.items()))
685
+
686
+ else:
687
+ # if the node type does not support the creation of an empty object it means
688
+ # that it cannot reference itself, so we can create its children first
689
+
690
+ # first, we create the children (attributes)
691
+ children = _get_children(graph_def, state_mapping, index_ref, index_ref_cache)
692
+ # then, we create the node
693
+ node = node_impl.unflatten(tuple(children.items()), graph_def.metadata)
694
+
695
+ return node
696
+
697
+
698
@set_module_as('brainstate.graph')
def unflatten(
    graph_def: GraphDef,
    state_mapping: NestedDict[Key, StateLeaf],
    /,
    *,
    index_ref: dict[Index, Any] | None = None,
    index_ref_cache: dict[Index, Any] | None = None,
) -> Node:
    """Unflatten a graphdef into a node with the given state tree mapping.

    This is the inverse of :func:`flatten`: the static structure described by
    ``graph_def`` is recombined with the values in ``state_mapping`` to
    reconstruct a node.

    Example::

        >>> import brainstate
        >>> class MyNode(brainstate.graph.Node):
        ...     def __init__(self):
        ...         self.a = brainstate.nn.Linear(2, 3)
        ...         self.b = brainstate.nn.Linear(3, 4)
        ...         self.c = [brainstate.nn.Linear(4, 5), brainstate.nn.Linear(5, 6)]
        ...         self.d = {'x': brainstate.nn.Linear(6, 7), 'y': brainstate.nn.Linear(7, 8)}
        >>> graphdef, statetree = brainstate.graph.flatten(MyNode())
        >>> node = brainstate.graph.unflatten(graphdef, statetree)

    Args:
        graph_def: A GraphDef instance.
        state_mapping: A NestedDict instance.
        index_ref: A mapping from indexes to node references found during the
            graph traversal, defaults to None. If not provided, a new empty
            dictionary is created. This argument can be used to unflatten a
            sequence of (graphdef, state_mapping) pairs that share the same
            index space.
        index_ref_cache: A mapping from indexes to existing nodes that can be
            reused. When a reference is reused, ``GraphNodeImpl.clear`` is
            called to leave the object in an empty state and then filled by the
            unflatten process; as a result, existing graph nodes are mutated to
            have the new content/topology specified by the graphdef.

    Returns:
        The reconstructed node.
    """
    if index_ref is None:
        index_ref = {}
    assert isinstance(graph_def, (NodeDef, NodeRef)), (
        f"graph_def must be a NodeDef or NodeRef. But we got: {graph_def}"
    )
    return _graph_unflatten(graph_def, state_mapping.to_dict(), index_ref, index_ref_cache)
874
+
875
+
876
def _graph_pop(
    node: Node,
    id_to_index: dict[int, Index],
    path_parts: PathParts,
    flatted_state_dicts: tuple[FlattedDict[PathParts, StateLeaf], ...],
    predicates: tuple[Predicate, ...],
) -> None:
    """Recursively remove matching state leaves from ``node`` in place.

    The graph is walked depth-first; every leaf matched by one of
    ``predicates`` is popped from its parent node and recorded (keyed by its
    full path) in the ``flatted_state_dicts`` entry paired with the first
    matching predicate. ``id_to_index`` doubles as a visited set, so shared
    nodes and leaves are processed only once.
    """
    if not _is_node(node):
        raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')

    # Already visited: nothing to do.
    if id(node) in id_to_index:
        return

    id_to_index[id(node)] = len(id_to_index)
    node_impl = _get_node_impl(node)
    node_dict = node_impl.node_dict(node)

    for name, value in node_dict.items():
        if _is_node(value):
            # Recurse into sub-graphs before looking at leaves.
            _graph_pop(
                node=value,
                id_to_index=id_to_index,
                path_parts=(*path_parts, name),
                flatted_state_dicts=flatted_state_dicts,
                predicates=predicates,
            )
            continue
        elif not _is_node_leaf(value):
            continue
        elif id(value) in id_to_index:
            continue

        node_path = (*path_parts, name)
        node_impl = _get_node_impl(node)
        # First matching predicate claims the leaf; order of filters matters.
        for state_dicts, predicate in zip(flatted_state_dicts, predicates):
            if predicate(node_path, value):
                # PyTree-backed nodes are immutable; their keys cannot be popped.
                if isinstance(node_impl, PyTreeNodeImpl):
                    raise ValueError(f'Cannot pop key {name!r} from node of type {type(node).__name__}')
                id_to_index[id(value)] = len(id_to_index)
                node_impl.pop_key(node, name)
                # if isinstance(value, State):
                #     value = value.to_state_ref()
                state_dicts[node_path] = value  # type: ignore[index] # mypy is wrong here?
                break
        else:
            # Leaf matched no predicate: leave it on the node untouched.
            # NOTE: should we raise an error here?
            pass
923
+
924
+
925
# Typing overloads for ``pop_states``: a single filter returns one
# NestedDict; two or more filters return a tuple of NestedDicts.
@overload
def pop_states(node, filter1: Filter, /) -> NestedDict:
    ...


@overload
def pop_states(node, filter1: Filter, filter2: Filter, /, *filters: Filter) -> tuple[NestedDict, ...]:
    ...
933
+
934
+
935
@set_module_as('brainstate.graph')
def pop_states(
    node: Node,
    *filters: Any
) -> Union[NestedDict[Key, State], Tuple[NestedDict[Key, State], ...]]:
    """
    Pop one or more :class:`State` types from the graph node.

    Example usage::

        >>> import brainstate
        >>> class Model(brainstate.nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.a = brainstate.nn.Linear(2, 3)
        ...         self.b = brainstate.nn.LIF([10, 2])
        >>> model = Model()
        >>> with brainstate.catch_new_states('new'):
        ...     brainstate.nn.init_all_states(model)
        >>> model_states = brainstate.graph.pop_states(model, 'new')

    Args:
        node: A graph node object.
        *filters: One or more :class:`State` objects to filter by.

    Returns:
        The popped :class:`NestedDict` containing the :class:`State`
        objects that were filtered for.
    """
    if not filters:
        raise ValueError('Expected at least one filter')

    predicates = tuple(to_predicate(f) for f in filters)
    # One flat output dict per predicate; _graph_pop fills them in place.
    popped: tuple[FlattedDict[PathParts, StateLeaf], ...] = tuple({} for _ in predicates)
    _graph_pop(
        node=node,
        id_to_index={},
        path_parts=(),
        flatted_state_dicts=popped,
        predicates=predicates,
    )
    results = tuple(NestedDict.from_flat(flat) for flat in popped)
    return results[0] if len(results) == 1 else results
999
+
1000
+
1001
def _split_state(
    state: GraphStateMapping,
    filters: tuple[Filter, ...],
) -> tuple[GraphStateMapping, Unpack[tuple[GraphStateMapping, ...]]]:
    """Split ``state`` by ``filters``, always returning a non-empty tuple."""
    # No filters: the whole mapping is one group.
    if not filters:
        return (state,)
    split_result = state.split(*filters)
    # ``split`` may yield a single NestedDict (one filter) or a tuple of them.
    if isinstance(split_result, NestedDict):
        return (split_result,)
    assert len(split_result) > 0
    return split_result  # type: ignore[return-value]
1012
+
1013
+
1014
# Typing overloads for ``treefy_split``: the return always starts with a
# GraphDef, followed by one NestedDict per filter (or a single NestedDict
# when no filters are given).
@overload
def treefy_split(node: A, /) -> Tuple[GraphDef, NestedDict]:
    ...


@overload
def treefy_split(node: A, first: Filter, /) -> Tuple[GraphDef, NestedDict]:
    ...


@overload
def treefy_split(node: A, first: Filter, second: Filter, /) -> Tuple[GraphDef, NestedDict, NestedDict]:
    ...


@overload
def treefy_split(
    node: A, first: Filter, second: Filter, /, *filters: Filter,
) -> Tuple[GraphDef, NestedDict, Unpack[Tuple[NestedDict, ...]]]:
    ...
1034
+
1035
+
1036
@set_module_as('brainstate.graph')
def treefy_split(
    node: A,
    *filters: Filter
) -> Tuple[GraphDef[A], NestedDict, Unpack[Tuple[NestedDict, ...]]]:
    """Split a graph node into a :class:`GraphDef` and one or more :class:`NestedDict`s.

    A ``NestedDict`` is a ``Mapping`` from strings or integers to ``Variables``,
    Arrays or nested States. A ``GraphDef`` contains all the static information
    needed to reconstruct a ``Module`` graph; it is analogous to JAX's
    ``PyTreeDef``. :func:`treefy_split` is used in conjunction with
    :func:`treefy_merge` to switch seamlessly between stateful and stateless
    representations of the graph.

    Example usage::

        >>> import brainstate
        >>> import jax, jax.numpy as jnp
        ...
        >>> class Foo(brainstate.graph.Node):
        ...     def __init__(self):
        ...         self.a = brainstate.nn.BatchNorm1d([10, 2])
        ...         self.b = brainstate.nn.Linear(2, 3)
        ...
        >>> node = Foo()
        >>> graphdef, params, others = brainstate.graph.treefy_split(node, brainstate.ParamState, ...)

    :func:`treefy_split` and :func:`treefy_merge` are primarily used to interact
    directly with JAX transformations, see
    `Functional API <https://flax.readthedocs.io/en/latest/nnx/nnx_basics.html#the-functional-api>`__
    for more information.

    Arguments:
        node: graph node to split.
        *filters: some optional filters to group the state into mutually exclusive substates.

    Returns:
        ``GraphDef`` and one or more ``States`` equal to the number of filters passed. If no
        filters are passed, a single ``NestedDict`` is returned.
    """
    graph_def, state_tree = flatten(node)
    # _split_state already returns a tuple, so it can be unpacked directly.
    return graph_def, *_split_state(state_tree, filters)
1106
+
1107
+
1108
@set_module_as('brainstate.graph')
def treefy_merge(
    graphdef: GraphDef[A],
    state_mapping: GraphStateMapping,
    /,
    *state_mappings: GraphStateMapping,
) -> A:
    """The inverse of :func:`treefy_split`.

    ``treefy_merge`` takes a :class:`GraphDef` and one or more :class:`NestedDict`'s
    and creates a new node with the same structure as the original node.

    Example usage::

        >>> import brainstate
        >>> import jax, jax.numpy as jnp
        ...
        >>> class Foo(brainstate.graph.Node):
        ...     def __init__(self):
        ...         self.a = brainstate.nn.BatchNorm1d([10, 2])
        ...         self.b = brainstate.nn.Linear(2, 3)
        ...
        >>> node = Foo()
        >>> graphdef, params, others = brainstate.graph.treefy_split(node, brainstate.ParamState, ...)
        ...
        >>> new_node = brainstate.graph.treefy_merge(graphdef, params, others)
        >>> assert isinstance(new_node, Foo)
        >>> assert isinstance(new_node.a, brainstate.nn.BatchNorm1d)
        >>> assert isinstance(new_node.b, brainstate.nn.Linear)

    :func:`treefy_split` and :func:`treefy_merge` are primarily used to interact
    directly with JAX transformations, see
    `Functional API <https://flax.readthedocs.io/en/latest/nnx/nnx_basics.html#the-functional-api>`__
    for more information.

    Args:
        graphdef: A :class:`GraphDef` object.
        state_mapping: A :class:`NestedDict` object.
        *state_mappings: Additional :class:`NestedDict` objects.

    Returns:
        The merged :class:`Module`.
    """
    # Collapse all the state mappings into one before reconstruction.
    merged = GraphStateMapping.merge(state_mapping, *state_mappings)
    return unflatten(graphdef, merged)
1154
+
1155
+
1156
def _filters_to_predicates(filters: Tuple[Filter, ...]) -> Tuple[Predicate, ...]:
    """Convert filters to predicates, validating that a catch-all
    (``...``/``True``) only appears in the trailing positions."""
    last = len(filters) - 1
    for i, f in enumerate(filters):
        if f in (..., True) and i != last:
            # A catch-all in the middle is only legal if everything after it
            # is also a catch-all.
            if not all(rest in (..., True) for rest in filters[i + 1:]):
                raise ValueError('`...` or `True` can only be used as the last filters, '
                                 f'got {f} it at index {i}.')
    return tuple(map(to_predicate, filters))
1164
+
1165
+
1166
def _split_flatted(
    flatted: Iterable[tuple[PathParts, Any]],
    filters: tuple[Filter, ...],
) -> tuple[list[tuple[PathParts, Any]], ...]:
    """Partition ``flatted`` (path, value) pairs into one bucket per filter.

    Each pair goes into the bucket of the first predicate it matches; a pair
    that matches no predicate raises, so callers typically append ``...`` as a
    final catch-all filter.
    """
    predicates = _filters_to_predicates(filters)

    buckets: tuple[list[tuple[PathParts, Any]], ...] = tuple([] for _ in predicates)

    for path, value in flatted:
        for bucket, predicate in zip(buckets, predicates):
            if predicate(path, value):
                bucket.append((path, value))
                break
        else:
            raise ValueError('Non-exhaustive filters, got a non-empty remainder: '
                             f'{path} -> {value}.'
                             '\nUse `...` to match all remaining elements.')

    return buckets
1187
+
1188
+
1189
# Typing overloads for ``nodes``: zero or one filter returns a single
# FlattedDict; two or more filters return a tuple of FlattedDicts.
@overload
def nodes(node, /, allowed_hierarchy=(0, _max_int)) -> FlattedDict[Key, Node]:
    ...


@overload
def nodes(node, first: Filter, /, allowed_hierarchy=(0, _max_int)) -> FlattedDict[Key, Node]:
    ...


@overload
def nodes(
    node, first: Filter, second: Filter, /, *filters: Filter, allowed_hierarchy=(0, _max_int)
) -> Tuple[FlattedDict[Key, Node], ...]:
    ...
1204
+
1205
+
1206
@set_module_as('brainstate.graph')
def nodes(
    node,
    *filters: Filter,
    allowed_hierarchy: Tuple[int, int] = (0, _max_int)
) -> Union[FlattedDict[Key, Node], Tuple[FlattedDict[Key, Node], ...]]:
    """
    Similar to :func:`split` but only returns the :class:`NestedDict`'s indicated by the filters.

    Args:
        node: A graph node object.
        *filters: Optional filters; each selected group of nodes is returned
            as its own :class:`FlattedDict`.
        allowed_hierarchy: Inclusive (min, max) depth range of nodes to visit.

    Returns:
        A single :class:`FlattedDict` when zero or one filter is given,
        otherwise a tuple with one :class:`FlattedDict` per filter.
    """
    num_filters = len(filters)
    # Append a single trailing ``...`` catch-all so _split_flatted never
    # raises on unmatched nodes. (Previously ``...`` was appended twice —
    # once when building ``filters`` and again at the call site — creating
    # redundant, discarded buckets.)
    all_filters = (*filters, ...)

    nodes_iterable = iter_node(node, allowed_hierarchy=allowed_hierarchy)
    flat_nodes = _split_flatted(nodes_iterable, all_filters)
    node_maps = tuple(FlattedDict(flat_node) for flat_node in flat_nodes)
    if num_filters < 2:
        return node_maps[0]
    return node_maps[:num_filters]
1227
+
1228
+
1229
def _states_generator(node, allowed_hierarchy) -> Iterable[Tuple[PathParts, State]]:
    """Yield only the (path, leaf) pairs whose leaf is a State instance."""
    for path_parts, leaf in iter_leaf(node, allowed_hierarchy=allowed_hierarchy):
        if isinstance(leaf, State):
            yield path_parts, leaf
1233
+
1234
+
1235
# Typing overloads for ``states``: zero or one filter returns a single
# FlattedDict; two or more filters return a tuple of FlattedDicts.
@overload
def states(node, /, allowed_hierarchy=(0, _max_int)) -> FlattedDict[Key, State]:
    ...


@overload
def states(node, first: Filter, /, allowed_hierarchy=(0, _max_int)) -> FlattedDict[Key, State]:
    ...


@overload
def states(
    node, first: Filter, second: Filter, /, *filters: Filter, allowed_hierarchy=(0, _max_int)
) -> tuple[FlattedDict[Key, State], ...]:
    ...
1250
+
1251
+
1252
@set_module_as('brainstate.graph')
def states(
    node,
    *filters: Filter,
    allowed_hierarchy: Tuple[int, int] = (0, _max_int)
) -> Union[FlattedDict[Key, State], tuple[FlattedDict[Key, State], ...]]:
    """
    Similar to :func:`split` but only returns the :class:`NestedDict`'s indicated by the filters.

    Args:
        node: A graph node object.
        *filters: Optional filters; each selected group of states is returned
            as its own :class:`FlattedDict`.
        allowed_hierarchy: Inclusive (min, max) depth range of leaves to visit.

    Returns:
        A single :class:`FlattedDict` when zero or one filter is given,
        otherwise a tuple with one :class:`FlattedDict` per filter.
    """
    num_filters = len(filters)
    # Append a single trailing ``...`` catch-all so _split_flatted never
    # raises on unmatched states. (Previously ``...`` was appended twice —
    # once when building ``filters`` and again at the call site — creating
    # redundant, discarded buckets.)
    all_filters = (*filters, ...)

    states_iterable = _states_generator(node, allowed_hierarchy=allowed_hierarchy)
    flat_states = _split_flatted(states_iterable, all_filters)
    state_maps = tuple(FlattedDict(flat_state) for flat_state in flat_states)
    if num_filters < 2:
        return state_maps[0]
    return state_maps[:num_filters]
1273
+
1274
+
1275
# Typing overloads for ``treefy_states``.
# NOTE(review): these overloads declare a ``flatted`` keyword argument that
# the implementation does not accept — confirm whether ``flatted`` was
# removed intentionally or the overloads are stale.
@overload
def treefy_states(
    node, /, flatted: bool = False
) -> NestedDict[Key, TreefyState]:
    ...


@overload
def treefy_states(
    node, first: Filter, /, flatted: bool = False
) -> NestedDict[Key, TreefyState]:
    ...


@overload
def treefy_states(
    node, first: Filter, second: Filter, /, *filters: Filter, flatted: bool = False
) -> Tuple[NestedDict[Key, TreefyState], ...]:
    ...
1294
+
1295
+
1296
@set_module_as('brainstate.graph')
def treefy_states(
    node, *filters,
) -> NestedDict[Key, TreefyState] | tuple[NestedDict[Key, TreefyState], ...]:
    """
    Similar to :func:`split` but only returns the :class:`NestedDict`'s indicated by the filters.

    Example usage::

        >>> import brainstate
        >>> class Model(brainstate.nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.l1 = brainstate.nn.Linear(2, 3)
        ...         self.l2 = brainstate.nn.Linear(3, 4)
        ...     def __call__(self, x):
        ...         return self.l2(self.l1(x))
        >>> model = Model()
        >>> # get the learnable parameters
        >>> params = brainstate.graph.treefy_states(model, brainstate.ParamState)
        >>> # get them separately
        >>> params, others = brainstate.graph.treefy_states(model, brainstate.ParamState, brainstate.ShortTermState)
        >>> # get them together
        >>> states = brainstate.graph.treefy_states(model)

    Args:
        node: A graph node object.
        *filters: One or more :class:`State` objects to filter by.

    Returns:
        One or more :class:`NestedDict` mappings.
    """
    _, state_mapping = flatten(node)
    # No filters: return the full mapping; otherwise ``filter`` already
    # handles both the single- and multi-filter cases.
    if not filters:
        return state_mapping
    return state_mapping.filter(*filters)
1338
+
1339
+
1340
def _graph_update_dynamic(node: Any, state: Mapping[Key, Any]):
    """Recursively update ``node``'s states in place from ``state``.

    Four cases are handled per key: (1) a brand-new attribute is added,
    (2) a sub-graph is updated recursively, (3) an existing State is updated
    from a TreefyState reference, (4) a raw state leaf replaces the current
    attribute value.
    """
    if not _is_node(node):
        raise RuntimeError(f'Unsupported type: {type(node)}')

    node_impl = _get_node_impl(node)
    node_dict = node_impl.node_dict(node)
    for key, value in state.items():
        # case 1: new state is being added
        if key not in node_dict:
            # PyTree-backed nodes cannot grow new keys.
            if isinstance(node_impl, PyTreeNodeImpl):
                raise ValueError(f'Cannot set key {key!r} on immutable node of '
                                 f'type {type(node).__name__}')
            if isinstance(value, State):
                value = value.copy()  # TODO: change it to state_ref
            node_impl.set_key(node, key, value)
            continue

        # check values are of the same type
        current_value = node_dict[key]

        # case 2: subgraph is being updated
        if _is_node(current_value):
            if _is_state_leaf(value):
                raise ValueError(f'Expected a subgraph for {key!r}, but got: {value!r}')
            _graph_update_dynamic(current_value, value)
        elif isinstance(value, TreefyState):
            # case 3: state leaf is being updated
            if not isinstance(current_value, State):
                raise ValueError(f'Trying to update a non-State attribute {key!r} with a State: '
                                 f'{value!r}')
            current_value.update_from_ref(value)
        elif _is_state_leaf(value):
            # case 4: state field is being updated
            if isinstance(node_impl, PyTreeNodeImpl):
                raise ValueError(f'Cannot set key {key!r} on immutable node of '
                                 f'type {type(node).__name__}')
            node_impl.set_key(node, key, value)
        else:
            raise ValueError(f'Unsupported update type: {type(value)} for key {key!r}')
1379
+
1380
+
1381
@set_module_as('brainstate.graph')
def update_states(
    node: Node,
    state_dict: NestedDict | FlattedDict,
    /,
    *state_dicts: NestedDict | FlattedDict
) -> None:
    """
    Update the given graph node with a new :class:`NestedMapping` in-place.

    Args:
        node: A graph node to update.
        state_dict: A :class:`NestedMapping` object.
        *state_dicts: Additional :class:`NestedMapping` objects.
    """
    # Merge all mappings into one before applying the update.
    if state_dicts:
        state_dict = NestedDict.merge(state_dict, *state_dicts)
    _graph_update_dynamic(node, state_dict.to_dict())
1398
+
1399
+
1400
@set_module_as('brainstate.graph')
def graphdef(node: Any, /) -> GraphDef[Any]:
    """Get the :class:`GraphDef` of the given graph node.

    Example usage::

        >>> import brainstate
        >>> model = brainstate.nn.Linear(2, 3)
        >>> graphdef, _ = brainstate.graph.treefy_split(model)
        >>> assert graphdef == brainstate.graph.graphdef(model)

    Args:
        node: A graph node object.

    Returns:
        The :class:`GraphDef` of the :class:`Module` object.
    """
    # Only the static structure is needed; the state mapping is discarded.
    graph_def, _ = flatten(node)
    return graph_def
1420
+
1421
+
1422
@set_module_as('brainstate.graph')
def clone(node: Node) -> Node:
    """
    Create a deep copy of the given graph node.

    Example usage::

        >>> import brainstate
        >>> model = brainstate.nn.Linear(2, 3)
        >>> cloned_model = clone(model)
        >>> model.weight.value['bias'] += 1
        >>> assert (model.weight.value['bias'] != cloned_model.weight.value['bias']).all()

    Args:
        node: A graph node object.

    Returns:
        A deep copy of the :class:`Module` object.
    """
    # Split into (static structure, state) and merge back: the round trip
    # produces an independent copy of the node.
    graph_def, state_tree = treefy_split(node)
    return treefy_merge(graph_def, state_tree)
1443
+
1444
+
1445
@set_module_as('brainstate.graph')
def call(
    graphdef_state: Tuple[GraphDef[A], GraphStateMapping],
) -> ApplyCaller[Tuple[GraphDef[A], GraphStateMapping]]:
    """Calls a method underlying graph node defined by a (GraphDef, NestedDict) pair.

    ``call`` takes a ``(GraphDef, NestedDict)`` pair and creates a proxy object
    that can be used to call methods on the underlying graph node. When a
    method is called, the output is returned along with a new
    (GraphDef, NestedDict) pair that represents the updated state of the graph
    node. ``call`` is equivalent to :func:`merge` > ``method`` > :func:`split`
    but is more convenient to use in pure JAX functions.

    Example::

        >>> import brainstate
        >>> import jax
        >>> import jax.numpy as jnp
        ...
        >>> class StatefulLinear(brainstate.graph.Node):
        ...     def __init__(self, din, dout):
        ...         self.w = brainstate.ParamState(brainstate.random.rand(din, dout))
        ...         self.b = brainstate.ParamState(jnp.zeros((dout,)))
        ...         self.count = brainstate.State(jnp.array(0, dtype=jnp.uint32))
        ...
        ...     def increment(self):
        ...         self.count.value += 1
        ...
        ...     def __call__(self, x):
        ...         self.increment()
        ...         return x @ self.w.value + self.b.value
        ...
        >>> linear = StatefulLinear(3, 2)
        >>> linear_state = brainstate.graph.treefy_split(linear)
        ...
        >>> @jax.jit
        ... def forward(x, linear_state):
        ...     y, linear_state = brainstate.graph.call(linear_state)(x)
        ...     return y, linear_state
        ...
        >>> x = jnp.ones((1, 3))
        >>> y, linear_state = forward(x, linear_state)
        >>> y, linear_state = forward(x, linear_state)
        ...
        >>> linear = brainstate.graph.treefy_merge(*linear_state)
        >>> linear.count.value
        Array(2, dtype=uint32)

    The proxy object also supports indexing and attribute access to reach
    nested methods, e.g. ``brainstate.graph.call(node_state)['b'].increment()``
    calls ``increment`` on the node stored under key ``'b'``.

    Args:
        graphdef_state: a ``(GraphDef, NestedDict)`` pair describing the node.

    Returns:
        A callable proxy; each invocation returns
        ``(output, (graphdef, state_mapping))``.
    """

    def pure_caller(accessor: DelayedAccessor, *args, **kwargs):
        # Rebuild the stateful node, run the requested method, then split the
        # (possibly mutated) node back into a pure (graphdef, state) pair.
        target = treefy_merge(*graphdef_state)
        bound_method = accessor(target)
        result = bound_method(*args, **kwargs)
        return result, treefy_split(target)

    return CallableProxy(pure_caller)  # type: ignore
1533
+
1534
+
1535
@set_module_as('brainstate.graph')
def iter_leaf(
    node: Any,
    allowed_hierarchy: Tuple[int, int] = (0, _max_int)
) -> Iterator[tuple[PathParts, Any]]:
    """Iterates over all nested leaves in the given graph node, including the current node.

    Creates a generator that yields (path, value) pairs, where the path is a
    tuple of strings or integers representing the route from the root.
    Repeated nodes are visited only once. Leaves include static values.

    Example::
        >>> import brainstate
        >>> import jax.numpy as jnp
        ...
        >>> class Linear(brainstate.nn.Module):
        ...     def __init__(self, din, dout):
        ...         super().__init__()
        ...         self.weight = brainstate.ParamState(brainstate.random.randn(din, dout))
        ...         self.bias = brainstate.ParamState(brainstate.random.randn(dout))
        ...         self.a = 1
        ...
        >>> module = Linear(3, 4)
        ...
        >>> for path, value in brainstate.graph.iter_leaf([module, module]):
        ...     print(path, type(value).__name__)
        ...
        (0, 'a') int
        (0, 'bias') ParamState
        (0, 'weight') ParamState

    Parameters
    ----------
    node: Node
        The node to iterate over.
    allowed_hierarchy: tuple of int
        The allowed hierarchy.

    """
    min_level, max_level = allowed_hierarchy

    def _walk(current: Any, seen: set[int], path: PathParts, depth: int) -> Iterator[tuple[PathParts, Any]]:
        # Stop descending past the maximum allowed depth.
        if depth > max_level:
            return
        if not _is_node(current):
            # A leaf: only yielded when at least the minimum depth.
            if depth >= min_level:
                yield path, current
            return
        # Shared nodes are visited once.
        if id(current) in seen:
            return
        seen.add(id(current))
        for key, child in _get_node_impl(current).node_dict(current).items():
            # Depth only increases when crossing into a graph node.
            child_depth = depth + 1 if _is_graph_node(child) else depth
            yield from _walk(child, seen, (*path, key), child_depth)

    yield from _walk(node, set(), (), 0)
1604
+
1605
+
1606
@set_module_as('brainstate.graph')
def iter_node(
    node: Any,
    allowed_hierarchy: Tuple[int, int] = (0, _max_int)
) -> Iterator[Tuple[PathParts, Any]]:
    """
    Iterates over all nested nodes of the given graph node, including the current node.

    Creates a generator that yields (path, node) pairs in post-order, where
    the path is a tuple of strings or integers representing the route from
    the root. Repeated nodes are visited only once.

    Example::
        >>> import brainstate
        >>> import jax.numpy as jnp
        ...
        >>> class Model(brainstate.nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.a = brainstate.nn.Linear(1, 2)
        ...         self.b = brainstate.nn.Linear(2, 3)
        ...         self.c = [brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5)]
        ...         self.d = {'x': brainstate.nn.Linear(5, 6), 'y': brainstate.nn.Linear(6, 7)}
        ...         self.b.a = brainstate.nn.LIF(2)
        ...
        >>> model = Model()
        ...
        >>> for path, node in brainstate.graph.iter_node([model, model]):
        ...     print(path, node.__class__.__name__)
        ...
        (0, 'a') Linear
        (0, 'b', 'a') LIF
        (0, 'b') Linear
        (0, 'c', 0) Linear
        (0, 'c', 1) Linear
        (0, 'd', 'x') Linear
        (0, 'd', 'y') Linear
        (0,) Model

    Parameters
    ----------
    node: Node
        The node to iterate over.
    allowed_hierarchy: tuple of int
        The allowed hierarchy.

    """
    min_level, max_level = allowed_hierarchy

    def _walk(current: Any, seen: set[int], path: PathParts, depth: int) -> Iterator[tuple[PathParts, Any]]:
        # Stop descending past the maximum allowed depth.
        if depth > max_level:
            return

        if _is_node(current):
            # Shared nodes are visited once.
            if id(current) in seen:
                return
            seen.add(id(current))
            for key, child in _get_node_impl(current).node_dict(current).items():
                # Depth only increases when crossing into a graph node.
                yield from _walk(child, seen, (*path, key),
                                 depth + 1 if _is_graph_node(child) else depth)

        # Post-order: a graph node is yielded after its children.
        if _is_graph_node(current) and depth >= min_level:
            yield path, current

    yield from _walk(node, set(), (), 0)
1680
+
1681
+
1682
+ # --------------------------------------------------------
1683
+ # Graph operations: end
1684
+ # --------------------------------------------------------
1685
+
1686
+
1687
@dataclasses.dataclass(frozen=True)
class Static(Generic[A]):
    """An empty pytree node that treats its inner value as static.
    ``value`` must define ``__eq__`` and ``__hash__``.
    """

    # The wrapped static payload; compared/hashed when jax compares treedefs.
    value: A


# Registered as a static pytree: jax folds the whole Static instance into the
# treedef (metadata), so ``value`` is never traced as an array leaf.
jax.tree_util.register_static(Static)
1697
+
1698
+
1699
+ # ---------------------------------------------------------
1700
+ # Pytree
1701
+ # ---------------------------------------------------------
1702
+
1703
class PytreeType:
    """Marker type tagging generic (unregistered) pytree containers.

    Used only as the ``type`` field of the pytree node implementation below;
    it is never instantiated.
    """
    ...
1705
+
1706
+
1707
def _key_path_to_key(key: Any) -> Key:
    """Translate one jax key-path entry into a plain graph key.

    Sequence entries map to their integer index, dict/flattened-index entries
    to their stored key (validated to be a usable ``Key``), attribute entries
    to the attribute name; anything unrecognized is stringified.
    """
    if isinstance(key, jax.tree_util.SequenceKey):
        return key.idx

    if isinstance(key, (jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey)):
        if not isinstance(key.key, Key):
            raise ValueError(
                f'Invalid key: {key.key}. May be due to its type not being hashable or comparable.'
            )
        return key.key

    if isinstance(key, jax.tree_util.GetAttrKey):
        return key.name

    # Last resort: fall back to the entry's string form.
    return str(key)
1722
+
1723
+
1724
def _flatten_pytree(pytree: Any):
    """Flatten ``pytree`` one level deep into ``(key, child)`` pairs plus its treedef.

    The ``is_leaf`` predicate stops recursion below the immediate children:
    every object other than the root itself is treated as a leaf.
    """
    leaves, treedef = jax.tree_util.tree_flatten_with_path(
        pytree, is_leaf=lambda x: x is not pytree
    )
    entries = []
    for path, child in leaves:
        # ``path`` has exactly one entry because recursion stops at depth 1.
        entries.append((_key_path_to_key(path[0]), child))
    return tuple(entries), treedef
1728
+
1729
+
1730
def _unflatten_pytree(
    nodes: tuple[tuple[Key, Any], ...],
    treedef: jax.tree_util.PyTreeDef
):
    """Rebuild a pytree from its ``(key, value)`` pairs and its treedef.

    The keys are ignored here: ``treedef`` already encodes the structure, so
    only the values (in order) are needed.
    """
    values = [value for _, value in nodes]
    return treedef.unflatten(values)
1736
+
1737
+
1738
# Node implementation for plain pytree containers, wired to jax's generic
# flatten/unflatten path and keyed by the PytreeType marker.
# NOTE(review): presumably the fallback used when a node type has no dedicated
# graph-node registration — confirm against the registry lookup code.
PYTREE_NODE_IMPL = PyTreeNodeImpl(type=PytreeType, flatten=_flatten_pytree, unflatten=_unflatten_pytree)