brainstate 0.2.0__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. brainstate/__init__.py +169 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2319 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +1652 -1652
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1624 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1433 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +137 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +633 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +154 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +477 -477
  32. brainstate/nn/_dynamics.py +1267 -1267
  33. brainstate/nn/_dynamics_test.py +67 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +384 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/_rand_funs.py +3938 -3938
  64. brainstate/random/_rand_funs_test.py +640 -640
  65. brainstate/random/_rand_seed.py +675 -675
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1617
  68. brainstate/random/_rand_state_test.py +551 -551
  69. brainstate/transform/__init__.py +59 -59
  70. brainstate/transform/_ad_checkpoint.py +176 -176
  71. brainstate/transform/_ad_checkpoint_test.py +49 -49
  72. brainstate/transform/_autograd.py +1025 -1025
  73. brainstate/transform/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -316
  75. brainstate/transform/_conditions_test.py +220 -220
  76. brainstate/transform/_error_if.py +94 -94
  77. brainstate/transform/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -145
  79. brainstate/transform/_eval_shape_test.py +38 -38
  80. brainstate/transform/_jit.py +399 -399
  81. brainstate/transform/_jit_test.py +143 -143
  82. brainstate/transform/_loop_collect_return.py +675 -675
  83. brainstate/transform/_loop_collect_return_test.py +58 -58
  84. brainstate/transform/_loop_no_collection.py +283 -283
  85. brainstate/transform/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -2016
  87. brainstate/transform/_make_jaxpr_test.py +1510 -1510
  88. brainstate/transform/_mapping.py +529 -529
  89. brainstate/transform/_mapping_test.py +194 -194
  90. brainstate/transform/_progress_bar.py +255 -255
  91. brainstate/transform/_random.py +171 -171
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -108
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate-0.2.0.dist-info/RECORD +0 -111
  111. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  112. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -1,1624 +1,1624 @@
1
- # The file is adapted from the Flax library (https://github.com/google/flax).
2
- # The credit should go to the Flax authors.
3
- #
4
- # Copyright 2024 The Flax Authors.
5
- #
6
- # Licensed under the Apache License, Version 2.0 (the "License");
7
- # you may not use this file except in compliance with the License.
8
- # You may obtain a copy of the License at
9
- #
10
- # http://www.apache.org/licenses/LICENSE-2.0
11
- #
12
- # Unless required by applicable law or agreed to in writing, software
13
- # distributed under the License is distributed on an "AS IS" BASIS,
14
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
- # See the License for the specific language governing permissions and
16
- # limitations under the License.
17
-
18
- from __future__ import annotations
19
-
20
- import dataclasses
21
- from typing import (
22
- Any, Callable, Generic, Iterable, Iterator, Mapping, MutableMapping,
23
- Sequence, Type, TypeVar, Union, Hashable, Tuple, Dict, Optional
24
- )
25
-
26
- import jax
27
- import numpy as np
28
- from typing_extensions import TypeGuard, Unpack
29
-
30
- from brainstate._state import State, TreefyState
31
- from brainstate._utils import set_module_as
32
- from brainstate.typing import PathParts, Filter, Predicate, Key
33
- from brainstate.util._pretty_pytree import NestedDict, FlattedDict, PrettyDict
34
- from brainstate.util._pretty_repr import PrettyRepr, PrettyType, PrettyAttr, PrettyMapping, MappingReprMixin
35
- from brainstate.util.filter import to_predicate
36
- from brainstate.util.struct import FrozenDict
37
-
38
- __all__ = [
39
- 'register_graph_node_type',
40
-
41
- # state management in the given graph or node
42
- 'pop_states',
43
- 'nodes',
44
- 'states',
45
- 'treefy_states',
46
- 'update_states',
47
-
48
- # graph node operations
49
- 'flatten',
50
- 'unflatten',
51
- 'treefy_split',
52
- 'treefy_merge',
53
- 'iter_leaf',
54
- 'iter_node',
55
- 'clone',
56
- 'graphdef',
57
-
58
- # others
59
- 'RefMap',
60
- 'GraphDef',
61
- 'NodeDef',
62
- 'NodeRef',
63
- ]
64
-
65
- MAX_INT = np.iinfo(np.int32).max
66
-
67
- A = TypeVar('A')
68
- B = TypeVar('B')
69
- C = TypeVar('C')
70
- F = TypeVar('F', bound=Callable)
71
-
72
- HA = TypeVar('HA', bound=Hashable)
73
- HB = TypeVar('HB', bound=Hashable)
74
-
75
- Index = int
76
- Names = Sequence[int]
77
- Node = TypeVar('Node')
78
- Leaf = TypeVar('Leaf')
79
- AuxData = TypeVar('AuxData')
80
-
81
- StateLeaf = TreefyState[Any]
82
- NodeLeaf = State[Any]
83
- GraphStateMapping = NestedDict
84
-
85
-
86
- # --------------------------------------------------------
87
-
88
- def _is_state_leaf(x: Any) -> TypeGuard[StateLeaf]:
89
- return isinstance(x, TreefyState)
90
-
91
-
92
- def _is_node_leaf(x: Any) -> TypeGuard[NodeLeaf]:
93
- return isinstance(x, State)
94
-
95
-
96
- class RefMap(MutableMapping[A, B], MappingReprMixin[A, B]):
97
- """
98
- A mapping that uses object id as the hash for the keys.
99
-
100
- This mapping is useful when we want to keep track of objects
101
- that are being referenced by other objects.
102
-
103
- Parameters
104
- ----------
105
- mapping : Mapping[A, B] or Iterable[Tuple[A, B]], optional
106
- A mapping or iterable of key-value pairs.
107
-
108
- Examples
109
- --------
110
- .. code-block:: python
111
-
112
- >>> import brainstate
113
- >>> obj1 = object()
114
- >>> obj2 = object()
115
- >>> ref_map = brainstate.graph.RefMap()
116
- >>> ref_map[obj1] = 'value1'
117
- >>> ref_map[obj2] = 'value2'
118
- >>> print(obj1 in ref_map)
119
- True
120
- >>> print(ref_map[obj1])
121
- value1
122
-
123
- """
124
- __module__ = 'brainstate.graph'
125
-
126
- def __init__(self, mapping: Union[Mapping[A, B], Iterable[Tuple[A, B]]] = ()) -> None:
127
- self._mapping: Dict[int, Tuple[A, B]] = {}
128
- self.update(mapping)
129
-
130
- def __getitem__(self, key: A) -> B:
131
- return self._mapping[id(key)][1]
132
-
133
- def __contains__(self, key: Any) -> bool:
134
- return id(key) in self._mapping
135
-
136
- def __setitem__(self, key: A, value: B) -> None:
137
- self._mapping[id(key)] = (key, value)
138
-
139
- def __delitem__(self, key: A) -> None:
140
- del self._mapping[id(key)]
141
-
142
- def __iter__(self) -> Iterator[A]:
143
- return (key for key, _ in self._mapping.values())
144
-
145
- def __len__(self) -> int:
146
- return len(self._mapping)
147
-
148
- def __str__(self) -> str:
149
- return repr(self)
150
-
151
-
152
- @dataclasses.dataclass(frozen=True)
153
- class NodeImplBase(Generic[Node, Leaf, AuxData]):
154
- type: type
155
- flatten: Callable[[Node], tuple[Sequence[tuple[Key, Leaf]], AuxData]]
156
-
157
- def node_dict(self, node: Node) -> dict[Key, Leaf]:
158
- nodes, _ = self.flatten(node)
159
- return dict(nodes)
160
-
161
-
162
- @dataclasses.dataclass(frozen=True)
163
- class GraphNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
164
- set_key: Callable[[Node, Key, Leaf], None]
165
- pop_key: Callable[[Node, Key], Leaf]
166
- create_empty: Callable[[AuxData], Node]
167
- clear: Callable[[Node], None]
168
-
169
- def init(self, node: Node, items: Tuple[Tuple[Key, Leaf], ...]) -> None:
170
- for key, value in items:
171
- self.set_key(node, key, value)
172
-
173
-
174
- @dataclasses.dataclass(frozen=True)
175
- class PyTreeNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
176
- unflatten: Callable[[tuple[tuple[Key, Leaf], ...], AuxData], Node]
177
-
178
-
179
- NodeImpl = Union[GraphNodeImpl[Node, Leaf, AuxData], PyTreeNodeImpl[Node, Leaf, AuxData]]
180
-
181
- # --------------------------------------------------------
182
- # Graph Node implementation: start
183
- # --------------------------------------------------------
184
-
185
- _node_impl_for_type: dict[type, NodeImpl] = {}
186
-
187
-
188
- def register_graph_node_type(
189
- type: type,
190
- flatten: Callable[[Node], tuple[Sequence[tuple[Key, Leaf]], AuxData]],
191
- set_key: Callable[[Node, Key, Leaf], None],
192
- pop_key: Callable[[Node, Key], Leaf],
193
- create_empty: Callable[[AuxData], Node],
194
- clear: Callable[[Node], None],
195
- ):
196
- """
197
- Register a graph node type.
198
-
199
- Parameters
200
- ----------
201
- type : type
202
- The type of the node.
203
- flatten : Callable[[Node], tuple[Sequence[tuple[Key, Leaf]], AuxData]]
204
- A function that flattens the node into a sequence of key-value pairs.
205
- set_key : Callable[[Node, Key, Leaf], None]
206
- A function that sets a key in the node.
207
- pop_key : Callable[[Node, Key], Leaf]
208
- A function that pops a key from the node.
209
- create_empty : Callable[[AuxData], Node]
210
- A function that creates an empty node.
211
- clear : Callable[[Node], None]
212
- A function that clears the node.
213
-
214
- Examples
215
- --------
216
- .. code-block:: python
217
-
218
- >>> import brainstate
219
- >>> # Custom node type implementation
220
- >>> class CustomNode:
221
- ... def __init__(self):
222
- ... self.data = {}
223
- ...
224
- >>> def flatten_custom(node):
225
- ... return list(node.data.items()), None
226
- ...
227
- >>> def set_key_custom(node, key, value):
228
- ... node.data[key] = value
229
- ...
230
- >>> def pop_key_custom(node, key):
231
- ... return node.data.pop(key)
232
- ...
233
- >>> def create_empty_custom(metadata):
234
- ... return CustomNode()
235
- ...
236
- >>> def clear_custom(node):
237
- ... node.data.clear()
238
- ...
239
- >>> # Register the custom node type
240
- >>> brainstate.graph.register_graph_node_type(
241
- ... CustomNode,
242
- ... flatten_custom,
243
- ... set_key_custom,
244
- ... pop_key_custom,
245
- ... create_empty_custom,
246
- ... clear_custom
247
- ... )
248
-
249
- """
250
- _node_impl_for_type[type] = GraphNodeImpl(
251
- type=type,
252
- flatten=flatten,
253
- set_key=set_key,
254
- pop_key=pop_key,
255
- create_empty=create_empty,
256
- clear=clear,
257
- )
258
-
259
-
260
- # --------------------------------------------------------
261
- # Graph node implementation: end
262
- # --------------------------------------------------------
263
-
264
-
265
- def _is_node(x: Any) -> bool:
266
- return _is_graph_node(x) or _is_pytree_node(x)
267
-
268
-
269
- def _is_pytree_node(x: Any) -> bool:
270
- return not jax.tree_util.all_leaves((x,))
271
-
272
-
273
- def _is_graph_node(x: Any) -> bool:
274
- return type(x) in _node_impl_for_type
275
-
276
-
277
- def _is_node_type(x: Type[Any]) -> bool:
278
- return x in _node_impl_for_type or x is PytreeType
279
-
280
-
281
- def _get_node_impl(x: Any) -> NodeImpl:
282
- if isinstance(x, State):
283
- raise ValueError(f'State is not a node: {x}')
284
-
285
- node_type = type(x)
286
- if node_type not in _node_impl_for_type:
287
- if _is_pytree_node(x):
288
- return PYTREE_NODE_IMPL
289
- else:
290
- raise ValueError(f'Unknown node type: {x}')
291
-
292
- return _node_impl_for_type[node_type]
293
-
294
-
295
- def get_node_impl_for_type(x: Type[Any]) -> NodeImpl:
296
- if x is PytreeType:
297
- return PYTREE_NODE_IMPL
298
- return _node_impl_for_type[x]
299
-
300
-
301
- class HashableMapping(Mapping[HA, HB], Hashable):
302
- def __init__(self, mapping: Union[Mapping[HA, HB], Iterable[tuple[HA, HB]]]) -> None:
303
- self._mapping = dict(mapping)
304
-
305
- def __contains__(self, key: object) -> bool:
306
- return key in self._mapping
307
-
308
- def __getitem__(self, key: HA) -> HB:
309
- return self._mapping[key]
310
-
311
- def __iter__(self) -> Iterator[HA]:
312
- return iter(self._mapping)
313
-
314
- def __len__(self) -> int:
315
- return len(self._mapping)
316
-
317
- def __hash__(self) -> int:
318
- return hash(tuple(sorted(self._mapping.items())))
319
-
320
- def __eq__(self, other: Any) -> bool:
321
- return isinstance(other, HashableMapping) and self._mapping == other._mapping
322
-
323
- def __repr__(self) -> str:
324
- return repr(self._mapping)
325
-
326
-
327
- class GraphDef(Generic[Node]):
328
- """
329
- A base dataclass that denotes the graph structure of a :class:`Node`.
330
-
331
- It contains two main components:
332
- - type: The type of the node.
333
- - index: The index of the node in the graph.
334
-
335
- It has two concrete subclasses:
336
-
337
- - :class:`NodeRef`: A reference to a node in the graph.
338
- - :class:`NodeDef`: A dataclass that denotes the graph structure of a :class:`Node` or a :class:`State`.
339
-
340
- Attributes
341
- ----------
342
- type : Type[Node]
343
- The type of the node.
344
- index : int
345
- The index of the node in the graph.
346
-
347
- """
348
- type: Type[Node]
349
- index: int
350
-
351
-
352
- @dataclasses.dataclass(frozen=True, repr=False)
353
- class NodeDef(GraphDef[Node], PrettyRepr):
354
- """
355
- A dataclass that denotes the tree structure of a node, either :class:`Node` or :class:`State`.
356
-
357
- Attributes
358
- ----------
359
- type : Type[Node]
360
- Type of the node.
361
- index : int
362
- Index of the node in the graph.
363
- attributes : Tuple[Key, ...]
364
- Attributes for the node.
365
- subgraphs : HashableMapping[Key, NodeDef[Any] | NodeRef[Any]]
366
- Mapping of subgraph definitions.
367
- static_fields : HashableMapping
368
- Mapping of static fields.
369
- leaves : HashableMapping[Key, NodeRef[Any] | None]
370
- Mapping of leaf nodes.
371
- metadata : Hashable
372
- Metadata associated with the node.
373
- index_mapping : FrozenDict[Index, Index] | None
374
- Index mapping for node references.
375
-
376
- """
377
-
378
- type: Type[Node] # type of the node
379
- index: int # index of the node in the graph
380
- attributes: Tuple[Key, ...] # attributes for the node
381
- subgraphs: HashableMapping[Key, NodeDef[Any] | NodeRef[Any]]
382
- static_fields: HashableMapping
383
- leaves: HashableMapping[Key, NodeRef[Any] | None]
384
- metadata: Hashable
385
- index_mapping: FrozenDict[Index, Index] | None
386
-
387
- @classmethod
388
- def create(
389
- cls,
390
- type: Type[Node],
391
- index: int,
392
- attributes: tuple[Key, ...],
393
- subgraphs: Iterable[tuple[Key, NodeDef[Any] | NodeRef[Any]]],
394
- static_fields: Iterable[tuple],
395
- leaves: Iterable[tuple[Key, NodeRef[Any] | None]],
396
- metadata: Hashable,
397
- index_mapping: Mapping[Index, Index] | None,
398
- ):
399
- return cls(
400
- type=type,
401
- index=index,
402
- attributes=attributes,
403
- subgraphs=HashableMapping(subgraphs),
404
- static_fields=HashableMapping(static_fields),
405
- leaves=HashableMapping(leaves),
406
- metadata=metadata,
407
- index_mapping=FrozenDict(index_mapping) if index_mapping is not None else None,
408
- )
409
-
410
- def __pretty_repr__(self):
411
- yield PrettyType(type=type(self))
412
-
413
- yield PrettyAttr('type', self.type.__name__)
414
- yield PrettyAttr('index', self.index)
415
- yield PrettyAttr('attributes', self.attributes)
416
- yield PrettyAttr('subgraphs', PrettyMapping(self.subgraphs))
417
- yield PrettyAttr('static_fields', PrettyMapping(self.static_fields))
418
- yield PrettyAttr('leaves', PrettyMapping(self.leaves))
419
- yield PrettyAttr('metadata', self.metadata)
420
- yield PrettyAttr('index_mapping', PrettyMapping(self.index_mapping) if self.index_mapping is not None else None)
421
-
422
-
423
- jax.tree_util.register_static(NodeDef)
424
-
425
-
426
- @dataclasses.dataclass(frozen=True, repr=False)
427
- class NodeRef(GraphDef[Node], PrettyRepr):
428
- """
429
- A reference to a node in the graph.
430
-
431
- The node can be instances of :class:`Node` or :class:`State`.
432
-
433
- Attributes
434
- ----------
435
- type : Type[Node]
436
- The type of the node being referenced.
437
- index : int
438
- The index of the node in the graph.
439
-
440
- """
441
- type: Type[Node]
442
- index: int
443
-
444
- def __pretty_repr__(self):
445
- yield PrettyType(type=type(self))
446
- yield PrettyAttr('type', self.type.__name__)
447
- yield PrettyAttr('index', self.index)
448
-
449
-
450
- jax.tree_util.register_static(NodeRef)
451
-
452
-
453
- # --------------------------------------------------------
454
- # Graph operations: start
455
- # --------------------------------------------------------
456
-
457
-
458
- def _graph_flatten(
459
- path: PathParts,
460
- ref_index: RefMap[Any, Index],
461
- flatted_state_mapping: Dict[PathParts, StateLeaf],
462
- node: Any,
463
- treefy_state: bool = False,
464
- ) -> Union[NodeDef[Any], NodeRef[Any]]:
465
- """
466
- Recursive helper for graph flatten.
467
-
468
- Parameters
469
- ----------
470
- path : PathParts
471
- The path to the node.
472
- ref_index : RefMap[Any, Index]
473
- A mapping from nodes to indexes.
474
- flatted_state_mapping : Dict[PathParts, StateLeaf]
475
- A mapping from paths to state leaves.
476
- node : Node
477
- The node to flatten.
478
- treefy_state : bool, optional
479
- Whether to convert states to TreefyState, by default False.
480
-
481
- Returns
482
- -------
483
- NodeDef or NodeRef
484
- A NodeDef or a NodeRef.
485
-
486
- """
487
- if not _is_node(node):
488
- raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')
489
-
490
- # If the node is already in the cache, return a reference, otherwise
491
- # add it to the cache and continue with the flattening process.
492
- # This is done to avoid infinite recursion when there is a reference cycle.
493
- if node in ref_index:
494
- return NodeRef(type(node), ref_index[node])
495
-
496
- # Get the node implementation for the node type.
497
- # There are two types of node implementations: GraphNodeImpl and PyTreeNodeImpl.
498
- # - ``GraphNodeImpl`` is used for nodes that have a graph structure.
499
- # - ``PyTreeNodeImpl`` is used for nodes that have a tree structure.
500
- node_impl = _get_node_impl(node)
501
-
502
- # There are two types of nodes: Node and State.
503
- # Here we handle the Node case.
504
- if isinstance(node_impl, GraphNodeImpl):
505
- # add the node to the cache
506
- index = len(ref_index)
507
- ref_index[node] = index
508
- else:
509
- index = -1
510
-
511
- subgraphs: list[tuple[Key, Union[NodeDef[Any], NodeRef[Any]]]] = []
512
- static_fields: list[tuple] = []
513
- leaves: list[tuple[Key, Union[NodeRef[Any], None]]] = []
514
-
515
- # Flatten the node into a sequence of key-value pairs.
516
- values, metadata = node_impl.flatten(node)
517
- for key, value in values:
518
- if _is_node(value):
519
- # Recursively flatten the subgraph.
520
- nodedef = _graph_flatten((*path, key), ref_index, flatted_state_mapping, value, treefy_state)
521
- subgraphs.append((key, nodedef))
522
- elif isinstance(value, State):
523
- # If the variable is in the cache, add a reference to it.
524
- if value in ref_index:
525
- leaves.append((key, NodeRef(type(value), ref_index[value])))
526
- else:
527
- # If the variable is not in the cache, add it to the cache.
528
- # This is done to avoid multiple references to the same variable.
529
- flatted_state_mapping[(*path, key)] = (value.to_state_ref() if treefy_state else value)
530
- variable_index = ref_index[value] = len(ref_index)
531
- leaves.append((key, NodeRef(type(value), variable_index)))
532
- elif _is_state_leaf(value):
533
- # The instance of ``TreefyState`` is a leaf.
534
- flatted_state_mapping[(*path, key)] = value
535
- leaves.append((key, None))
536
- else:
537
- # if isinstance(value, (jax.Array, np.ndarray)):
538
- # path_str = '/'.join(map(str, (*path, key)))
539
- # raise ValueError(f'Arrays leaves are not supported, at {path_str!r}: {value}')
540
-
541
- # The value is a static field.
542
- static_fields.append((key, value))
543
-
544
- nodedef = NodeDef.create(
545
- type=node_impl.type,
546
- index=index,
547
- attributes=tuple(key for key, _ in values),
548
- subgraphs=subgraphs,
549
- static_fields=static_fields,
550
- leaves=leaves,
551
- metadata=metadata,
552
- index_mapping=None,
553
- )
554
- return nodedef
555
-
556
-
557
- @set_module_as('brainstate.graph')
558
- def flatten(
559
- node: Any,
560
- /,
561
- ref_index: Optional[RefMap[Any, Index]] = None,
562
- treefy_state: bool = True,
563
- ) -> Tuple[GraphDef[Any], NestedDict]:
564
- """
565
- Flattens a graph node into a (graph_def, state_mapping) pair.
566
-
567
- Parameters
568
- ----------
569
- node : Node
570
- A graph node.
571
- ref_index : RefMap[Any, Index], optional
572
- A mapping from nodes to indexes, defaults to None. If not provided, a new
573
- empty ``RefMap`` is created. This argument can be used to flatten a sequence of graph
574
- nodes that share references.
575
- treefy_state : bool, optional
576
- If True, the :class:`State` objects in the returned mapping are converted to :class:`TreefyState` references.
577
- Default is True.
578
-
579
- Returns
580
- -------
581
- tuple[GraphDef, NestedDict]
582
- A tuple containing the graph definition and state mapping.
583
-
584
- Examples
585
- --------
586
- .. code-block:: python
587
-
588
- >>> import brainstate
589
- >>> node = brainstate.graph.Node()
590
- >>> graph_def, state_mapping = brainstate.graph.flatten(node)
591
- >>> print(graph_def)
592
- >>> print(state_mapping)
593
-
594
- """
595
- ref_index = RefMap() if ref_index is None else ref_index
596
- assert isinstance(ref_index, RefMap), f"ref_index must be a RefMap. But we got: {ref_index}"
597
- flatted_state_mapping: dict[PathParts, StateLeaf] = {}
598
- graph_def = _graph_flatten((), ref_index, flatted_state_mapping, node, treefy_state)
599
- return graph_def, NestedDict.from_flat(flatted_state_mapping)
600
-
601
-
602
- def _get_children(
603
- graph_def: NodeDef[Any],
604
- state_mapping: Mapping,
605
- index_ref: dict[Index, Any],
606
- index_ref_cache: Optional[dict[Index, Any]],
607
- ) -> dict[Key, Union[StateLeaf, Any]]:
608
- children: dict[Key, Union[StateLeaf, Any]] = {}
609
-
610
- # NOTE: we could allow adding new StateLeafs here
611
- # All state keys must be present in the graph definition (the object attributes)
612
- if unknown_keys := set(state_mapping) - set(graph_def.attributes):
613
- raise ValueError(f'Unknown keys: {unknown_keys}')
614
-
615
- # for every key in attributes there are 6 possible cases:
616
- # - (2) the key can either be present in the state or not
617
- # - (3) the key can be a subgraph, a leaf, or a static attribute
618
- for key in graph_def.attributes:
619
- if key not in state_mapping: # static field
620
- # Support unflattening with missing keys for static fields and subgraphs
621
- # This allows partial state restoration and flexible graph reconstruction
622
- if key in graph_def.static_fields:
623
- children[key] = graph_def.static_fields[key]
624
-
625
- elif key in graph_def.subgraphs:
626
- # if the key is a subgraph we create an empty node
627
- subgraphdef = graph_def.subgraphs[key]
628
- if isinstance(subgraphdef, NodeRef):
629
- # subgraph exists, take it from the cache
630
- children[key] = index_ref[subgraphdef.index]
631
-
632
- else:
633
- # create a node from an empty state, reasoning:
634
- # * it is a node with no state
635
- # * it is a node with state but only through references of already
636
- # created nodes
637
- substate = {}
638
- children[key] = _graph_unflatten(subgraphdef, substate, index_ref, index_ref_cache)
639
-
640
- elif key in graph_def.leaves:
641
- noderef = graph_def.leaves[key]
642
- if (noderef is not None) and (noderef.index in index_ref):
643
- # variable exists, take it from the cache
644
- children[key] = index_ref[noderef.index]
645
-
646
- else:
647
- # key for a variable is missing, raise an error
648
- raise ValueError(
649
- f'Expected key {key!r} in state while building node of type '
650
- f'{graph_def.type.__name__}.'
651
- )
652
-
653
- else:
654
- raise RuntimeError(f'Unknown static field: {key!r}')
655
-
656
- else: # state field
657
- value = state_mapping[key]
658
- if isinstance(value, PrettyDict):
659
- value = dict(value)
660
-
661
- if key in graph_def.static_fields:
662
- raise ValueError(f'Got state for static field {key!r}, this is not supported.')
663
-
664
- if key in graph_def.subgraphs:
665
- # if _is_state_leaf(value):
666
- if isinstance(value, (TreefyState, State)):
667
- raise ValueError(
668
- f'Expected value of type {graph_def.subgraphs[key]} '
669
- f'for {key!r}, but got {value!r}'
670
- )
671
-
672
- if not isinstance(value, dict):
673
- raise TypeError(f'Expected a dict for {key!r}, but got {type(value)}.')
674
-
675
- subgraphdef = graph_def.subgraphs[key]
676
- if isinstance(subgraphdef, NodeRef):
677
- children[key] = index_ref[subgraphdef.index]
678
- else:
679
- children[key] = _graph_unflatten(subgraphdef, value, index_ref, index_ref_cache)
680
-
681
- elif key in graph_def.leaves:
682
- # if not _is_state_leaf(value):
683
- if not isinstance(value, (TreefyState, State)):
684
- raise ValueError(f'Expected a leaf for {key!r}, but got {value!r}')
685
-
686
- noderef = graph_def.leaves[key]
687
- if noderef is None:
688
- # if the leaf is None, it means that the value was originally
689
- # a non-TreefyState leaf, however we allow providing a
690
- # TreefyState presumably created by modifying the NestedDict
691
- if isinstance(value, TreefyState):
692
- value = value.to_state()
693
- elif isinstance(value, State):
694
- value = value
695
- children[key] = value
696
-
697
- elif noderef.index in index_ref:
698
- # add an existing variable
699
- children[key] = index_ref[noderef.index]
700
-
701
- else:
702
- # it is an unseen variable, create a new one
703
- if not isinstance(value, (TreefyState, State)):
704
- raise ValueError(
705
- f'Expected a State type for {key!r}, but got {type(value)}.'
706
- )
707
-
708
- # when index_ref_cache is present, check if the variable exists there
709
- # and update existing variables if it does
710
- if index_ref_cache is not None and noderef.index in index_ref_cache:
711
- variable = index_ref_cache[noderef.index]
712
- if not isinstance(variable, State):
713
- raise ValueError(f'Expected a State type for {key!r}, but got {type(variable)}.')
714
- if isinstance(value, TreefyState):
715
- variable.update_from_ref(value)
716
- elif isinstance(value, State):
717
- if value._been_writen:
718
- variable.value = value.value
719
- else:
720
- variable.restore_value(value.value)
721
- else:
722
- raise ValueError(f'Expected a State type for {key!r}, but got {type(value)}.')
723
- else: # if it doesn't, create a new variable
724
- if isinstance(value, TreefyState):
725
- variable = value.to_state()
726
- elif isinstance(value, State):
727
- variable = value
728
- else:
729
- raise ValueError(f'Expected a State type for {key!r}, but got {type(value)}.')
730
- children[key] = variable
731
- index_ref[noderef.index] = variable
732
-
733
- else:
734
- raise RuntimeError(f'Unknown key: {key!r}, this is a bug.')
735
-
736
- return children
737
-
738
-
739
- def _graph_unflatten(
740
- graph_def: Union[NodeDef[Any], NodeRef[Any]],
741
- state_mapping: Mapping[Key, Union[StateLeaf, Mapping]],
742
- index_ref: dict[Index, Any],
743
- index_ref_cache: Optional[dict[Index, Any]],
744
- ) -> Any:
745
- """
746
- Recursive helper for graph unflatten.
747
-
748
- Args:
749
- graph_def: A `GraphDef` instance or an index to a node in the cache.
750
- state_mapping: A state mapping from attribute names to variables or subgraphs.
751
- index_ref: A mapping from indexes to nodes that have been traversed.
752
- If a node is already in the cache, it won't be traversed again.
753
- index_ref_cache: A mapping from indexes to existing nodes that can be reused.
754
- When a reference is reused, ``GraphNodeImpl.clear`` is called to leave the
755
- object in an empty state and then filled by the unflatten process. As a result,
756
- existing graph nodes are mutated to have the new content/topology
757
- specified by the nodedef.
758
-
759
- Returns:
760
- A node instance.
761
- """
762
-
763
- # if the graph_def is a reference, this means that the node has already been created, so
764
- # we return the node from the cache
765
- if isinstance(graph_def, NodeRef):
766
- return index_ref[graph_def.index]
767
- else:
768
- assert isinstance(graph_def, NodeDef), f"graph_def must be a NodeDef. But we got: {graph_def}"
769
-
770
- # graph_def must be a registered node type
771
- if not _is_node_type(graph_def.type):
772
- raise RuntimeError(f'Unsupported type: {graph_def.type}, this is a bug.')
773
-
774
- # check if the index is already in the cache
775
- if graph_def.index in index_ref:
776
- raise RuntimeError(f'GraphDef index {graph_def.index} already used.')
777
-
778
- # get the node implementation for the node type
779
- node_impl = get_node_impl_for_type(graph_def.type)
780
-
781
- if isinstance(node_impl, GraphNodeImpl):
782
- # we create an empty node first and add it to the index
783
- # this avoids infinite recursion when there is a reference cycle
784
-
785
- if (index_ref_cache is not None) and (graph_def.index in index_ref_cache):
786
- # clear the node to leave it in an empty state
787
- node = index_ref_cache[graph_def.index]
788
- if type(node) != graph_def.type:
789
- raise ValueError(f'Expected a node of type {graph_def.type} for index '
790
- f'{graph_def.index}, but got a node of type {type(node)}.')
791
- node_impl.clear(node)
792
- else:
793
- # create an empty node
794
- node = node_impl.create_empty(graph_def.metadata)
795
-
796
- # add the node to the cache
797
- index_ref[graph_def.index] = node
798
-
799
- # get the children (the attributes) of the node
800
- children = _get_children(graph_def, state_mapping, index_ref, index_ref_cache)
801
-
802
- # initialize the node with the children
803
- node_impl.init(node, tuple(children.items()))
804
-
805
- else:
806
- # if the node type does not support the creation of an empty object it means
807
- # that it cannot reference itself, so we can create its children first
808
-
809
- # first, we create the children (attributes)
810
- children = _get_children(graph_def, state_mapping, index_ref, index_ref_cache)
811
- # then, we create the node
812
- node = node_impl.unflatten(tuple(children.items()), graph_def.metadata)
813
-
814
- return node
815
-
816
-
817
- @set_module_as('brainstate.graph')
818
- def unflatten(
819
- graph_def: GraphDef[Any],
820
- state_mapping: NestedDict,
821
- /,
822
- *,
823
- index_ref: Optional[dict[Index, Any]] = None,
824
- index_ref_cache: Optional[dict[Index, Any]] = None,
825
- ) -> Any:
826
- """
827
- Unflattens a graphdef into a node with the given state tree mapping.
828
-
829
- Parameters
830
- ----------
831
- graph_def : GraphDef
832
- A GraphDef instance.
833
- state_mapping : NestedDict
834
- A NestedDict instance containing the state mapping.
835
- index_ref : dict[Index, Any], optional
836
- A mapping from indexes to node references found during the graph
837
- traversal. If not provided, a new empty dictionary is created. This argument
838
- can be used to unflatten a sequence of (graphdef, state_mapping) pairs that
839
- share the same index space.
840
- index_ref_cache : dict[Index, Any], optional
841
- A mapping from indexes to existing nodes that can be reused. When a reference
842
- is reused, ``GraphNodeImpl.clear`` is called to leave the object in an empty
843
- state and then filled by the unflatten process. As a result, existing graph
844
- nodes are mutated to have the new content/topology specified by the graphdef.
845
-
846
- Returns
847
- -------
848
- Node
849
- The reconstructed node.
850
-
851
- Examples
852
- --------
853
- .. code-block:: python
854
-
855
- >>> import brainstate
856
- >>> class MyNode(brainstate.graph.Node):
857
- ... def __init__(self):
858
- ... self.a = brainstate.nn.Linear(2, 3)
859
- ... self.b = brainstate.nn.Linear(3, 4)
860
- ...
861
- >>> # Flatten a node
862
- >>> node = MyNode()
863
- >>> graphdef, statetree = brainstate.graph.flatten(node)
864
- >>>
865
- >>> # Unflatten back to node
866
- >>> reconstructed_node = brainstate.graph.unflatten(graphdef, statetree)
867
- >>> assert isinstance(reconstructed_node, MyNode)
868
- >>> assert isinstance(reconstructed_node.a, brainstate.nn.Linear)
869
- >>> assert isinstance(reconstructed_node.b, brainstate.nn.Linear)
870
- """
871
- index_ref = {} if index_ref is None else index_ref
872
- assert isinstance(graph_def, (NodeDef, NodeRef)), f"graph_def must be a NodeDef or NodeRef. But we got: {graph_def}"
873
- node = _graph_unflatten(graph_def, state_mapping.to_dict(), index_ref, index_ref_cache)
874
- return node
875
-
876
-
877
- def _graph_pop(
878
- node: Any,
879
- id_to_index: dict[int, Index],
880
- path_parts: PathParts,
881
- flatted_state_dicts: tuple[FlattedDict[PathParts, StateLeaf], ...],
882
- predicates: tuple[Predicate, ...],
883
- ) -> None:
884
- if not _is_node(node):
885
- raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')
886
-
887
- if id(node) in id_to_index:
888
- return
889
-
890
- id_to_index[id(node)] = len(id_to_index)
891
- node_impl = _get_node_impl(node)
892
- node_dict = node_impl.node_dict(node)
893
-
894
- for name, value in node_dict.items():
895
- if _is_node(value):
896
- _graph_pop(
897
- node=value,
898
- id_to_index=id_to_index,
899
- path_parts=(*path_parts, name),
900
- flatted_state_dicts=flatted_state_dicts,
901
- predicates=predicates,
902
- )
903
- continue
904
- elif not _is_node_leaf(value):
905
- continue
906
- elif id(value) in id_to_index:
907
- continue
908
-
909
- node_path = (*path_parts, name)
910
- node_impl = _get_node_impl(node)
911
- for state_dicts, predicate in zip(flatted_state_dicts, predicates):
912
- if predicate(node_path, value):
913
- if isinstance(node_impl, PyTreeNodeImpl):
914
- raise ValueError(f'Cannot pop key {name!r} from node of type {type(node).__name__}')
915
- id_to_index[id(value)] = len(id_to_index)
916
- node_impl.pop_key(node, name)
917
- # if isinstance(value, State):
918
- # value = value.to_state_ref()
919
- state_dicts[node_path] = value # type: ignore[index] # mypy is wrong here?
920
- break
921
- else:
922
- # NOTE: should we raise an error here?
923
- pass
924
-
925
-
926
- @set_module_as('brainstate.graph')
927
- def pop_states(
928
- node: Any, *filters: Any
929
- ) -> Union[NestedDict, Tuple[NestedDict, ...]]:
930
- """
931
- Pop one or more :class:`State` types from the graph node.
932
-
933
- Parameters
934
- ----------
935
- node : Node
936
- A graph node object.
937
- *filters
938
- One or more filters (e.g. :class:`State` types or tags) selecting which states to pop.
939
-
940
- Returns
941
- -------
942
- NestedDict or tuple[NestedDict, ...]
943
- The popped :class:`NestedDict` containing the :class:`State`
944
- objects that were filtered for.
945
-
946
- Examples
947
- --------
948
- .. code-block:: python
949
-
950
- >>> import brainstate
951
- >>> import jax.numpy as jnp
952
-
953
- >>> class Model(brainstate.nn.Module):
954
- ... def __init__(self):
955
- ... super().__init__()
956
- ... self.a = brainstate.nn.Linear(2, 3)
957
- ... self.b = brainstate.nn.LIF([10, 2])
958
-
959
- >>> model = Model()
960
- >>> with brainstate.catch_new_states('new'):
961
- ... brainstate.nn.init_all_states(model)
962
-
963
- >>> assert len(model.states()) == 2
964
- >>> model_states = brainstate.graph.pop_states(model, 'new')
965
- >>> model_states # doctest: +SKIP
966
- NestedDict({
967
- 'b': {
968
- 'V': {
969
- 'st': ShortTermState(
970
- value=Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
971
- 0., 0., 0.], dtype=float32),
972
- tag='new'
973
- )
974
- }
975
- }
976
- })
977
- """
978
- if len(filters) == 0:
979
- raise ValueError('Expected at least one filter')
980
-
981
- id_to_index: dict[int, Index] = {}
982
- path_parts: PathParts = ()
983
- predicates = tuple(to_predicate(filter) for filter in filters)
984
- flatted_state_dicts: tuple[FlattedDict[PathParts, StateLeaf], ...] = tuple({} for _ in predicates)
985
- _graph_pop(
986
- node=node,
987
- id_to_index=id_to_index,
988
- path_parts=path_parts,
989
- flatted_state_dicts=flatted_state_dicts,
990
- predicates=predicates,
991
- )
992
- states = tuple(NestedDict.from_flat(flat_state) for flat_state in flatted_state_dicts)
993
-
994
- if len(states) == 1:
995
- return states[0]
996
- else:
997
- return states
998
-
999
-
1000
- def _split_state(
1001
- state: GraphStateMapping,
1002
- filters: tuple[Filter, ...],
1003
- ) -> tuple[GraphStateMapping, Unpack[tuple[GraphStateMapping, ...]]]:
1004
- if not filters:
1005
- return (state,)
1006
- states = state.split(*filters)
1007
- if isinstance(states, NestedDict):
1008
- return (states,)
1009
- assert len(states) > 0
1010
- return states # type: ignore[return-value]
1011
-
1012
-
1013
- @set_module_as('brainstate.graph')
1014
- def treefy_split(
1015
- node: A, *filters: Filter
1016
- ):
1017
- """
1018
- Split a graph node into a :class:`GraphDef` and one or more :class:`NestedDict`s.
1019
-
1020
- NestedDict is a ``Mapping`` from strings or integers to ``States``, Arrays, or nested NestedDicts.
1021
- GraphDef contains all the static information needed to reconstruct a ``Module`` graph; it is
1022
- analogous to JAX's ``PyTreeDef``. :func:`treefy_split` is used in conjunction with :func:`treefy_merge` to
1023
- switch seamlessly between stateful and stateless representations of the graph.
1024
-
1025
- Parameters
1026
- ----------
1027
- node : A
1028
- Graph node to split.
1029
- *filters
1030
- Optional filters to group the state into mutually exclusive substates.
1031
-
1032
- Returns
1033
- -------
1034
- tuple
1035
- ``GraphDef`` and one or more ``States`` equal to the number of filters passed.
1036
- If no filters are passed, a single ``NestedDict`` is returned.
1037
-
1038
- Examples
1039
- --------
1040
- .. code-block:: python
1041
-
1042
- >>> import brainstate
1043
- >>> import jax, jax.numpy as jnp
1044
-
1045
- >>> class Foo(brainstate.graph.Node):
1046
- ... def __init__(self):
1047
- ... self.a = brainstate.nn.BatchNorm1d([10, 2])
1048
- ... self.b = brainstate.nn.Linear(2, 3)
1049
- ...
1050
- >>> node = Foo()
1051
- >>> graphdef, params, others = brainstate.graph.treefy_split(
1052
- ... node, brainstate.ParamState, ...
1053
- ... )
1054
- >>> # params contains ParamState variables
1055
- >>> # others contains all other state variables
1056
- """
1057
- graphdef, state_tree = flatten(node)
1058
- states = tuple(_split_state(state_tree, filters))
1059
- return graphdef, *states
1060
-
1061
-
1062
- @set_module_as('brainstate.graph')
1063
- def treefy_merge(graphdef: GraphDef[A], *state_mappings) -> A:
1064
- """
1065
- The inverse of :func:`treefy_split`.
1066
-
1067
- ``merge`` takes a :class:`GraphDef` and one or more :class:`NestedDict`'s and creates
1068
- a new node with the same structure as the original node.
1069
-
1070
- Parameters
1071
- ----------
1072
- graphdef : GraphDef[A]
1073
- A :class:`GraphDef` object.
1074
- *state_mappings
1075
- Additional :class:`NestedDict` objects.
1076
-
1077
- Returns
1078
- -------
1079
- A
1080
- The merged :class:`Module`.
1081
-
1082
- Examples
1083
- --------
1084
- .. code-block:: python
1085
-
1086
- >>> import brainstate
1087
- >>> import jax, jax.numpy as jnp
1088
-
1089
- >>> class Foo(brainstate.graph.Node):
1090
- ... def __init__(self):
1091
- ... self.a = brainstate.nn.BatchNorm1d([10, 2])
1092
- ... self.b = brainstate.nn.Linear(2, 3)
1093
- ...
1094
- >>> node = Foo()
1095
- >>> graphdef, params, others = brainstate.graph.treefy_split(
1096
- ... node, brainstate.ParamState, ...
1097
- ... )
1098
- >>> new_node = brainstate.graph.treefy_merge(graphdef, params, others)
1099
- >>> assert isinstance(new_node, Foo)
1100
- >>> assert isinstance(new_node.a, brainstate.nn.BatchNorm1d)
1101
- >>> assert isinstance(new_node.b, brainstate.nn.Linear)
1102
- """
1103
- state_mapping = GraphStateMapping.merge(*state_mappings)
1104
- node = unflatten(graphdef, state_mapping)
1105
- return node
1106
-
1107
-
1108
- def _filters_to_predicates(filters: Tuple[Filter, ...]) -> Tuple[Predicate, ...]:
1109
- for i, filter_ in enumerate(filters):
1110
- if filter_ in (..., True) and i != len(filters) - 1:
1111
- remaining_filters = filters[i + 1:]
1112
- if not all(f in (..., True) for f in remaining_filters):
1113
- raise ValueError('`...` or `True` can only be used as the last filters, '
1114
- f'got {filter_} at index {i}.')
1115
- return tuple(map(to_predicate, filters))
1116
-
1117
-
1118
- def _split_flatted(
1119
- flatted: Iterable[tuple[PathParts, Any]],
1120
- filters: tuple[Filter, ...],
1121
- ) -> tuple[list[tuple[PathParts, Any]], ...]:
1122
- predicates = _filters_to_predicates(filters)
1123
-
1124
- # we have one flat state per predicate; a value that does not match
1125
- # any predicate raises an error below
1126
- flat_states: tuple[list[tuple[PathParts, Any]], ...] = tuple([] for _ in predicates)
1127
-
1128
- for path, value in flatted:
1129
- for i, predicate in enumerate(predicates):
1130
- if predicate(path, value):
1131
- flat_states[i].append((path, value))
1132
- break
1133
- else:
1134
- raise ValueError('Non-exhaustive filters, got a non-empty remainder: '
1135
- f'{path} -> {value}.'
1136
- '\nUse `...` to match all remaining elements.')
1137
-
1138
- return flat_states
1139
-
1140
-
1141
- @set_module_as('brainstate.graph')
1142
- def nodes(
1143
- node, *filters: Filter, allowed_hierarchy: Tuple[int, int] = (0, MAX_INT)
1144
- ):
1145
- """
1146
- Similar to :func:`treefy_split` but only returns the graph nodes selected by the filters.
1147
-
1148
- Parameters
1149
- ----------
1150
- node : Node
1151
- The node to get nodes from.
1152
- *filters
1153
- Filters to apply to the nodes.
1154
- allowed_hierarchy : tuple[int, int], optional
1155
- The allowed hierarchy levels, by default (0, MAX_INT).
1156
-
1157
- Returns
1158
- -------
1159
- FlattedDict or tuple[FlattedDict, ...]
1160
- The filtered nodes.
1161
-
1162
- """
1163
- num_filters = len(filters)
1164
- if num_filters == 0:
1165
- filters = (..., ...)
1166
- else:
1167
- filters = (*filters, ...)
1168
-
1169
- nodes_iterable = iter_node(node, allowed_hierarchy=allowed_hierarchy)
1170
- flat_nodes = _split_flatted(nodes_iterable, (*filters, ...))
1171
- node_maps = tuple(FlattedDict(flat_node) for flat_node in flat_nodes)
1172
- if num_filters < 2:
1173
- return node_maps[0]
1174
- return node_maps[:num_filters]
1175
-
1176
-
1177
- def _states_generator(node, allowed_hierarchy) -> Iterable[Tuple[PathParts, State]]:
1178
- for path, value in iter_leaf(node, allowed_hierarchy=allowed_hierarchy):
1179
- if isinstance(value, State):
1180
- yield path, value
1181
-
1182
-
1183
- @set_module_as('brainstate.graph')
1184
- def states(
1185
- node, *filters: Filter, allowed_hierarchy: Tuple[int, int] = (0, MAX_INT)
1186
- ) -> Union[FlattedDict, tuple[FlattedDict, ...]]:
1187
- """
1188
- Similar to :func:`treefy_split` but only returns the :class:`State` objects selected by the filters.
1189
-
1190
- Parameters
1191
- ----------
1192
- node : Node
1193
- The node to get states from.
1194
- *filters
1195
- Filters to apply to the states.
1196
- allowed_hierarchy : tuple[int, int], optional
1197
- The allowed hierarchy levels, by default (0, MAX_INT).
1198
-
1199
- Returns
1200
- -------
1201
- FlattedDict or tuple[FlattedDict, ...]
1202
- The filtered states.
1203
-
1204
- """
1205
- num_filters = len(filters)
1206
- if num_filters == 0:
1207
- filters = (..., ...)
1208
- else:
1209
- filters = (*filters, ...)
1210
-
1211
- states_iterable = _states_generator(node, allowed_hierarchy=allowed_hierarchy)
1212
- flat_states = _split_flatted(states_iterable, (*filters, ...))
1213
- state_maps = tuple(FlattedDict(flat_state) for flat_state in flat_states)
1214
- if num_filters < 2:
1215
- return state_maps[0]
1216
- return state_maps[:num_filters]
1217
-
1218
-
1219
- @set_module_as('brainstate.graph')
1220
- def treefy_states(
1221
- node, *filters,
1222
- ):
1223
- """
1224
- Similar to :func:`treefy_split` but only returns the :class:`NestedDict` state mappings selected by the filters.
1225
-
1226
- Parameters
1227
- ----------
1228
- node : Node
1229
- A graph node object.
1230
- *filters
1231
- One or more filters (e.g. :class:`State` types) selecting which states to return.
1232
-
1233
- Returns
1234
- -------
1235
- NestedDict or tuple of NestedDict
1236
- One or more :class:`NestedDict` mappings.
1237
-
1238
- Examples
1239
- --------
1240
- .. code-block:: python
1241
-
1242
- >>> import brainstate
1243
- >>> class Model(brainstate.nn.Module):
1244
- ... def __init__(self):
1245
- ... super().__init__()
1246
- ... self.l1 = brainstate.nn.Linear(2, 3)
1247
- ... self.l2 = brainstate.nn.Linear(3, 4)
1248
- ... def __call__(self, x):
1249
- ... return self.l2(self.l1(x))
1250
-
1251
- >>> model = Model()
1252
- >>> # Get the learnable parameters
1253
- >>> params = brainstate.graph.treefy_states(model, brainstate.ParamState)
1254
- >>> # Get them separately
1255
- >>> params, others = brainstate.graph.treefy_states(
1256
- ... model, brainstate.ParamState, brainstate.ShortTermState
1257
- ... )
1258
- >>> # Get all states together
1259
- >>> states = brainstate.graph.treefy_states(model)
1260
- """
1261
- _, state_mapping = flatten(node)
1262
- if len(filters) == 0:
1263
- return state_mapping
1264
- else:
1265
- state_mappings = state_mapping.filter(*filters)
1266
- if len(filters) == 1:
1267
- return state_mappings[0]
1268
- else:
1269
- return state_mappings
1270
-
1271
-
1272
- def _graph_update_dynamic(node: Any, state: Mapping) -> None:
1273
- if not _is_node(node):
1274
- raise RuntimeError(f'Unsupported type: {type(node)}')
1275
-
1276
- node_impl = _get_node_impl(node)
1277
- node_dict = node_impl.node_dict(node)
1278
- for key, value in state.items():
1279
- # case 1: new state is being added
1280
- if key not in node_dict:
1281
- if isinstance(node_impl, PyTreeNodeImpl):
1282
- raise ValueError(f'Cannot set key {key!r} on immutable node of '
1283
- f'type {type(node).__name__}')
1284
- if isinstance(value, State):
1285
- # TODO: here maybe error? we should check if the state already belongs to another node?
1286
- value = value.to_state_ref() # Convert to state reference for proper state management
1287
- node_impl.set_key(node, key, value)
1288
- continue
1289
-
1290
- # check values are of the same type
1291
- current_value = node_dict[key]
1292
-
1293
- # case 2: subgraph is being updated
1294
- if _is_node(current_value):
1295
- if _is_state_leaf(value):
1296
- raise ValueError(f'Expected a subgraph for {key!r}, but got: {value!r}')
1297
- _graph_update_dynamic(current_value, value)
1298
- elif isinstance(value, TreefyState):
1299
- # case 3: state leaf is being updated
1300
- if not isinstance(current_value, State):
1301
- raise ValueError(f'Trying to update a non-State attribute {key!r} with a State: '
1302
- f'{value!r}')
1303
- current_value.update_from_ref(value)
1304
- elif _is_state_leaf(value):
1305
- # case 4: state field is being updated
1306
- if isinstance(node_impl, PyTreeNodeImpl):
1307
- raise ValueError(f'Cannot set key {key!r} on immutable node of '
1308
- f'type {type(node).__name__}')
1309
- node_impl.set_key(node, key, value)
1310
- else:
1311
- raise ValueError(f'Unsupported update type: {type(value)} for key {key!r}')
1312
-
1313
-
1314
- def update_states(
1315
- node: Any,
1316
- state_dict: Union[NestedDict, FlattedDict],
1317
- /,
1318
- *state_dicts: Union[NestedDict, FlattedDict]
1319
- ) -> None:
1320
- """
1321
- Update the given graph node with a new :class:`NestedDict` in-place.
1322
-
1323
- Parameters
1324
- ----------
1325
- node : Node
1326
- A graph node to update.
1327
- state_dict : NestedDict | FlattedDict
1328
- A :class:`NestedDict` or :class:`FlattedDict` object.
1329
- *state_dicts : NestedDict | FlattedDict
1330
- Additional :class:`NestedDict` or :class:`FlattedDict` objects.
1331
-
1332
- """
1333
- if state_dicts:
1334
- state_dict = NestedDict.merge(state_dict, *state_dicts)
1335
- _graph_update_dynamic(node, state_dict.to_dict())
1336
-
1337
-
1338
- @set_module_as('brainstate.graph')
1339
- def graphdef(node: Any) -> GraphDef[Any]:
1340
- """
1341
- Get the :class:`GraphDef` of the given graph node.
1342
-
1343
- Parameters
1344
- ----------
1345
- node : Any
1346
- A graph node object.
1347
-
1348
- Returns
1349
- -------
1350
- GraphDef[Any]
1351
- The :class:`GraphDef` of the :class:`Module` object.
1352
-
1353
- Examples
1354
- --------
1355
- .. code-block:: python
1356
-
1357
- >>> import brainstate
1358
-
1359
- >>> model = brainstate.nn.Linear(2, 3)
1360
- >>> graphdef, _ = brainstate.graph.treefy_split(model)
1361
- >>> assert graphdef == brainstate.graph.graphdef(model)
1362
-
1363
- """
1364
- graphdef, _ = flatten(node)
1365
- return graphdef
1366
-
1367
-
1368
- @set_module_as('brainstate.graph')
1369
- def clone(node: A) -> A:
1370
- """
1371
- Create a deep copy of the given graph node.
1372
-
1373
- Parameters
1374
- ----------
1375
- node : Node
1376
- A graph node object.
1377
-
1378
- Returns
1379
- -------
1380
- Node
1381
- A deep copy of the :class:`Module` object.
1382
-
1383
- Examples
1384
- --------
1385
- .. code-block:: python
1386
-
1387
- >>> import brainstate
1388
- >>> model = brainstate.nn.Linear(2, 3)
1389
- >>> cloned_model = brainstate.graph.clone(model)
1390
- >>> model.weight.value['bias'] += 1
1391
- >>> assert (model.weight.value['bias'] != cloned_model.weight.value['bias']).all()
1392
-
1393
- """
1394
- graphdef, state = treefy_split(node)
1395
- return treefy_merge(graphdef, state)
1396
-
1397
-
1398
- @set_module_as('brainstate.graph')
1399
- def iter_leaf(
1400
- node: Any, allowed_hierarchy: Tuple[int, int] = (0, MAX_INT)
1401
- ) -> Iterator[tuple[PathParts, Any]]:
1402
- """
1403
- Iterates over all nested leaves in the given graph node.
1404
-
1405
- ``iter_leaf`` creates a generator that yields path and value pairs, where
1406
- the path is a tuple of strings or integers representing the path to the value from the
1407
- root. Repeated nodes are visited only once. Leaves include static values.
1408
-
1409
- Parameters
1410
- ----------
1411
- node : Any
1412
- The node to iterate over.
1413
- allowed_hierarchy : tuple[int, int], optional
1414
- The allowed hierarchy levels, by default (0, MAX_INT).
1415
-
1416
- Yields
1417
- ------
1418
- Iterator[tuple[PathParts, Any]]
1419
- Path and value pairs.
1420
-
1421
- Examples
1422
- --------
1423
- .. code-block:: python
1424
-
1425
- >>> import brainstate
1426
- >>> import jax.numpy as jnp
1427
-
1428
- >>> class Linear(brainstate.nn.Module):
1429
- ... def __init__(self, din, dout):
1430
- ... super().__init__()
1431
- ... self.weight = brainstate.ParamState(brainstate.random.randn(din, dout))
1432
- ... self.bias = brainstate.ParamState(brainstate.random.randn(dout))
1433
- ... self.a = 1
1434
- ...
1435
- >>> module = Linear(3, 4)
1436
- ...
1437
- >>> for path, value in brainstate.graph.iter_leaf([module, module]):
1438
- ... print(path, type(value).__name__)
1439
- ...
1440
- (0, 'a') int
1441
- (0, 'bias') ParamState
1442
- (0, 'weight') ParamState
1443
-
1444
- """
1445
-
1446
- def _iter_graph_leaf(
1447
- node_: Any,
1448
- visited_: set[int],
1449
- path_parts_: PathParts,
1450
- level_: int,
1451
- ) -> Iterator[tuple[PathParts, Any]]:
1452
- if level_ > allowed_hierarchy[1]:
1453
- return
1454
-
1455
- if _is_node(node_):
1456
- if id(node_) in visited_:
1457
- return
1458
- visited_.add(id(node_))
1459
- node_dict = _get_node_impl(node_).node_dict(node_)
1460
- for key, value in node_dict.items():
1461
- yield from _iter_graph_leaf(
1462
- value,
1463
- visited_,
1464
- (*path_parts_, key),
1465
- level_ + 1 if _is_graph_node(value) else level_
1466
- )
1467
- else:
1468
- if level_ >= allowed_hierarchy[0]:
1469
- yield path_parts_, node_
1470
-
1471
- visited: set[int] = set()
1472
- path_parts: PathParts = ()
1473
- level: int = 0
1474
- yield from _iter_graph_leaf(node, visited, path_parts, level)
1475
-
1476
-
1477
- @set_module_as('brainstate.graph')
1478
- def iter_node(
1479
- node: Any, allowed_hierarchy: Tuple[int, int] = (0, MAX_INT)
1480
- ) -> Iterator[Tuple[PathParts, Any]]:
1481
- """
1482
- Iterates over all nested nodes of the given graph node, including the current node.
1483
-
1484
- ``iter_node`` creates a generator that yields path and node pairs, where
1485
- the path is a tuple of strings or integers representing the path to the node from the
1486
- root. Repeated nodes are visited only once.
1487
-
1488
- Parameters
1489
- ----------
1490
- node : Any
1491
- The node to iterate over.
1492
- allowed_hierarchy : tuple[int, int], optional
1493
- The allowed hierarchy levels, by default (0, MAX_INT).
1494
-
1495
- Yields
1496
- ------
1497
- Iterator[tuple[PathParts, Any]]
1498
- Path and node pairs.
1499
-
1500
- Examples
1501
- --------
1502
- .. code-block:: python
1503
-
1504
- >>> import brainstate
1505
- >>> import jax.numpy as jnp
1506
-
1507
- >>> class Model(brainstate.nn.Module):
1508
- ... def __init__(self):
1509
- ... super().__init__()
1510
- ... self.a = brainstate.nn.Linear(1, 2)
1511
- ... self.b = brainstate.nn.Linear(2, 3)
1512
- ... self.c = [brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5)]
1513
- ... self.d = {'x': brainstate.nn.Linear(5, 6), 'y': brainstate.nn.Linear(6, 7)}
1514
- ... self.b.a = brainstate.nn.LIF(2)
1515
- ...
1516
- >>> model = Model()
1517
- ...
1518
- >>> for path, node in brainstate.graph.iter_node([model, model]):
1519
- ... print(path, node.__class__.__name__)
1520
- ...
1521
- (0, 'a') Linear
1522
- (0, 'b', 'a') LIF
1523
- (0, 'b') Linear
1524
- (0, 'c', 0) Linear
1525
- (0, 'c', 1) Linear
1526
- (0, 'd', 'x') Linear
1527
- (0, 'd', 'y') Linear
1528
- (0,) Model
1529
-
1530
- """
1531
-
1532
- def _iter_graph_node(
1533
- node_: Any,
1534
- visited_: set[int],
1535
- path_parts_: PathParts,
1536
- level_: int,
1537
- ) -> Iterator[tuple[PathParts, Any]]:
1538
- if level_ > allowed_hierarchy[1]:
1539
- return
1540
-
1541
- if _is_node(node_):
1542
- if id(node_) in visited_:
1543
- return
1544
-
1545
- visited_.add(id(node_))
1546
- node_dict = _get_node_impl(node_).node_dict(node_)
1547
- for key, value in node_dict.items():
1548
- yield from _iter_graph_node(value, visited_, (*path_parts_, key),
1549
- level_ + 1 if _is_graph_node(value) else level_)
1550
-
1551
- if _is_graph_node(node_) and level_ >= allowed_hierarchy[0]:
1552
- yield path_parts_, node_
1553
-
1554
- visited: set[int] = set()
1555
- path_parts: PathParts = ()
1556
- level: int = 0
1557
- yield from _iter_graph_node(node, visited, path_parts, level)
1558
-
1559
-
1560
- # --------------------------------------------------------
1561
- # Graph operations: end
1562
- # --------------------------------------------------------
1563
-
1564
-
1565
- @dataclasses.dataclass(frozen=True)
1566
- class Static(Generic[A]):
1567
- """
1568
- An empty pytree node that treats its inner value as static.
1569
-
1570
- ``value`` must define ``__eq__`` and ``__hash__``.
1571
-
1572
- Attributes
1573
- ----------
1574
- value : A
1575
- The static value to wrap.
1576
-
1577
- """
1578
-
1579
- value: A
1580
-
1581
-
1582
- jax.tree_util.register_static(Static)
1583
-
1584
-
1585
- # ---------------------------------------------------------
1586
- # Pytree
1587
- # ---------------------------------------------------------
1588
-
1589
- class PytreeType:
1590
- ...
1591
-
1592
-
1593
- def _key_path_to_key(key: Any) -> Key:
1594
- if isinstance(key, jax.tree_util.SequenceKey):
1595
- return key.idx
1596
- elif isinstance(
1597
- key, (jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey)
1598
- ):
1599
- if not isinstance(key.key, Key):
1600
- raise ValueError(
1601
- f'Invalid key: {key.key}. May be due to its type not being hashable or comparable.'
1602
- )
1603
- return key.key
1604
- elif isinstance(key, jax.tree_util.GetAttrKey):
1605
- return key.name
1606
- else:
1607
- return str(key)
1608
-
1609
-
1610
- def _flatten_pytree(pytree: Any) -> Tuple[Tuple[Tuple, ...], jax.tree_util.PyTreeDef]:
1611
- leaves, treedef = jax.tree_util.tree_flatten_with_path(pytree, is_leaf=lambda x: x is not pytree)
1612
- nodes = tuple((_key_path_to_key(path[0]), value) for path, value in leaves)
1613
- return nodes, treedef
1614
-
1615
-
1616
- def _unflatten_pytree(
1617
- nodes: tuple[tuple, ...],
1618
- treedef: jax.tree_util.PyTreeDef
1619
- ) -> Any:
1620
- pytree = treedef.unflatten(value for _, value in nodes)
1621
- return pytree
1622
-
1623
-
1624
- PYTREE_NODE_IMPL = PyTreeNodeImpl(type=PytreeType, flatten=_flatten_pytree, unflatten=_unflatten_pytree)
1
+ # The file is adapted from the Flax library (https://github.com/google/flax).
2
+ # The credit should go to the Flax authors.
3
+ #
4
+ # Copyright 2024 The Flax Authors.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ from __future__ import annotations
19
+
20
+ import dataclasses
21
+ from typing import (
22
+ Any, Callable, Generic, Iterable, Iterator, Mapping, MutableMapping,
23
+ Sequence, Type, TypeVar, Union, Hashable, Tuple, Dict, Optional
24
+ )
25
+
26
+ import jax
27
+ import numpy as np
28
+ from typing_extensions import TypeGuard, Unpack
29
+
30
+ from brainstate._state import State, TreefyState
31
+ from brainstate._utils import set_module_as
32
+ from brainstate.typing import PathParts, Filter, Predicate, Key
33
+ from brainstate.util._pretty_pytree import NestedDict, FlattedDict, PrettyDict
34
+ from brainstate.util._pretty_repr import PrettyRepr, PrettyType, PrettyAttr, PrettyMapping, MappingReprMixin
35
+ from brainstate.util.filter import to_predicate
36
+ from brainstate.util.struct import FrozenDict
37
+
38
+ __all__ = [
39
+ 'register_graph_node_type',
40
+
41
+ # state management in the given graph or node
42
+ 'pop_states',
43
+ 'nodes',
44
+ 'states',
45
+ 'treefy_states',
46
+ 'update_states',
47
+
48
+ # graph node operations
49
+ 'flatten',
50
+ 'unflatten',
51
+ 'treefy_split',
52
+ 'treefy_merge',
53
+ 'iter_leaf',
54
+ 'iter_node',
55
+ 'clone',
56
+ 'graphdef',
57
+
58
+ # others
59
+ 'RefMap',
60
+ 'GraphDef',
61
+ 'NodeDef',
62
+ 'NodeRef',
63
+ ]
64
+
65
+ MAX_INT = np.iinfo(np.int32).max
66
+
67
+ A = TypeVar('A')
68
+ B = TypeVar('B')
69
+ C = TypeVar('C')
70
+ F = TypeVar('F', bound=Callable)
71
+
72
+ HA = TypeVar('HA', bound=Hashable)
73
+ HB = TypeVar('HB', bound=Hashable)
74
+
75
+ Index = int
76
+ Names = Sequence[int]
77
+ Node = TypeVar('Node')
78
+ Leaf = TypeVar('Leaf')
79
+ AuxData = TypeVar('AuxData')
80
+
81
+ StateLeaf = TreefyState[Any]
82
+ NodeLeaf = State[Any]
83
+ GraphStateMapping = NestedDict
84
+
85
+
86
+ # --------------------------------------------------------
87
+
88
+ def _is_state_leaf(x: Any) -> TypeGuard[StateLeaf]:
89
+ return isinstance(x, TreefyState)
90
+
91
+
92
+ def _is_node_leaf(x: Any) -> TypeGuard[NodeLeaf]:
93
+ return isinstance(x, State)
94
+
95
+
96
+ class RefMap(MutableMapping[A, B], MappingReprMixin[A, B]):
97
+ """
98
+ A mapping that uses object id as the hash for the keys.
99
+
100
+ This mapping is useful when we want to keep track of objects
101
+ that are being referenced by other objects.
102
+
103
+ Parameters
104
+ ----------
105
+ mapping : Mapping[A, B] or Iterable[Tuple[A, B]], optional
106
+ A mapping or iterable of key-value pairs.
107
+
108
+ Examples
109
+ --------
110
+ .. code-block:: python
111
+
112
+ >>> import brainstate
113
+ >>> obj1 = object()
114
+ >>> obj2 = object()
115
+ >>> ref_map = brainstate.graph.RefMap()
116
+ >>> ref_map[obj1] = 'value1'
117
+ >>> ref_map[obj2] = 'value2'
118
+ >>> print(obj1 in ref_map)
119
+ True
120
+ >>> print(ref_map[obj1])
121
+ value1
122
+
123
+ """
124
+ __module__ = 'brainstate.graph'
125
+
126
+ def __init__(self, mapping: Union[Mapping[A, B], Iterable[Tuple[A, B]]] = ()) -> None:
127
+ self._mapping: Dict[int, Tuple[A, B]] = {}
128
+ self.update(mapping)
129
+
130
+ def __getitem__(self, key: A) -> B:
131
+ return self._mapping[id(key)][1]
132
+
133
+ def __contains__(self, key: Any) -> bool:
134
+ return id(key) in self._mapping
135
+
136
+ def __setitem__(self, key: A, value: B) -> None:
137
+ self._mapping[id(key)] = (key, value)
138
+
139
+ def __delitem__(self, key: A) -> None:
140
+ del self._mapping[id(key)]
141
+
142
+ def __iter__(self) -> Iterator[A]:
143
+ return (key for key, _ in self._mapping.values())
144
+
145
+ def __len__(self) -> int:
146
+ return len(self._mapping)
147
+
148
+ def __str__(self) -> str:
149
+ return repr(self)
150
+
151
+
152
+ @dataclasses.dataclass(frozen=True)
153
+ class NodeImplBase(Generic[Node, Leaf, AuxData]):
154
+ type: type
155
+ flatten: Callable[[Node], tuple[Sequence[tuple[Key, Leaf]], AuxData]]
156
+
157
+ def node_dict(self, node: Node) -> dict[Key, Leaf]:
158
+ nodes, _ = self.flatten(node)
159
+ return dict(nodes)
160
+
161
+
162
+ @dataclasses.dataclass(frozen=True)
163
+ class GraphNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
164
+ set_key: Callable[[Node, Key, Leaf], None]
165
+ pop_key: Callable[[Node, Key], Leaf]
166
+ create_empty: Callable[[AuxData], Node]
167
+ clear: Callable[[Node], None]
168
+
169
+ def init(self, node: Node, items: Tuple[Tuple[Key, Leaf], ...]) -> None:
170
+ for key, value in items:
171
+ self.set_key(node, key, value)
172
+
173
+
174
+ @dataclasses.dataclass(frozen=True)
175
+ class PyTreeNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
176
+ unflatten: Callable[[tuple[tuple[Key, Leaf], ...], AuxData], Node]
177
+
178
+
179
+ NodeImpl = Union[GraphNodeImpl[Node, Leaf, AuxData], PyTreeNodeImpl[Node, Leaf, AuxData]]
180
+
181
+ # --------------------------------------------------------
182
+ # Graph Node implementation: start
183
+ # --------------------------------------------------------
184
+
185
+ _node_impl_for_type: dict[type, NodeImpl] = {}
186
+
187
+
188
+ def register_graph_node_type(
189
+ type: type,
190
+ flatten: Callable[[Node], tuple[Sequence[tuple[Key, Leaf]], AuxData]],
191
+ set_key: Callable[[Node, Key, Leaf], None],
192
+ pop_key: Callable[[Node, Key], Leaf],
193
+ create_empty: Callable[[AuxData], Node],
194
+ clear: Callable[[Node], None],
195
+ ):
196
+ """
197
+ Register a graph node type.
198
+
199
+ Parameters
200
+ ----------
201
+ type : type
202
+ The type of the node.
203
+ flatten : Callable[[Node], tuple[Sequence[tuple[Key, Leaf]], AuxData]]
204
+ A function that flattens the node into a sequence of key-value pairs.
205
+ set_key : Callable[[Node, Key, Leaf], None]
206
+ A function that sets a key in the node.
207
+ pop_key : Callable[[Node, Key], Leaf]
208
+ A function that pops a key from the node.
209
+ create_empty : Callable[[AuxData], Node]
210
+ A function that creates an empty node.
211
+ clear : Callable[[Node], None]
212
+ A function that clears the node.
213
+
214
+ Examples
215
+ --------
216
+ .. code-block:: python
217
+
218
+ >>> import brainstate
219
+ >>> # Custom node type implementation
220
+ >>> class CustomNode:
221
+ ... def __init__(self):
222
+ ... self.data = {}
223
+ ...
224
+ >>> def flatten_custom(node):
225
+ ... return list(node.data.items()), None
226
+ ...
227
+ >>> def set_key_custom(node, key, value):
228
+ ... node.data[key] = value
229
+ ...
230
+ >>> def pop_key_custom(node, key):
231
+ ... return node.data.pop(key)
232
+ ...
233
+ >>> def create_empty_custom(metadata):
234
+ ... return CustomNode()
235
+ ...
236
+ >>> def clear_custom(node):
237
+ ... node.data.clear()
238
+ ...
239
+ >>> # Register the custom node type
240
+ >>> brainstate.graph.register_graph_node_type(
241
+ ... CustomNode,
242
+ ... flatten_custom,
243
+ ... set_key_custom,
244
+ ... pop_key_custom,
245
+ ... create_empty_custom,
246
+ ... clear_custom
247
+ ... )
248
+
249
+ """
250
+ _node_impl_for_type[type] = GraphNodeImpl(
251
+ type=type,
252
+ flatten=flatten,
253
+ set_key=set_key,
254
+ pop_key=pop_key,
255
+ create_empty=create_empty,
256
+ clear=clear,
257
+ )
258
+
259
+
260
+ # --------------------------------------------------------
261
+ # Graph node implementation: end
262
+ # --------------------------------------------------------
263
+
264
+
265
+ def _is_node(x: Any) -> bool:
266
+ return _is_graph_node(x) or _is_pytree_node(x)
267
+
268
+
269
+ def _is_pytree_node(x: Any) -> bool:
270
+ return not jax.tree_util.all_leaves((x,))
271
+
272
+
273
+ def _is_graph_node(x: Any) -> bool:
274
+ return type(x) in _node_impl_for_type
275
+
276
+
277
+ def _is_node_type(x: Type[Any]) -> bool:
278
+ return x in _node_impl_for_type or x is PytreeType
279
+
280
+
281
+ def _get_node_impl(x: Any) -> NodeImpl:
282
+ if isinstance(x, State):
283
+ raise ValueError(f'State is not a node: {x}')
284
+
285
+ node_type = type(x)
286
+ if node_type not in _node_impl_for_type:
287
+ if _is_pytree_node(x):
288
+ return PYTREE_NODE_IMPL
289
+ else:
290
+ raise ValueError(f'Unknown node type: {x}')
291
+
292
+ return _node_impl_for_type[node_type]
293
+
294
+
295
+ def get_node_impl_for_type(x: Type[Any]) -> NodeImpl:
296
+ if x is PytreeType:
297
+ return PYTREE_NODE_IMPL
298
+ return _node_impl_for_type[x]
299
+
300
+
301
+ class HashableMapping(Mapping[HA, HB], Hashable):
302
+ def __init__(self, mapping: Union[Mapping[HA, HB], Iterable[tuple[HA, HB]]]) -> None:
303
+ self._mapping = dict(mapping)
304
+
305
+ def __contains__(self, key: object) -> bool:
306
+ return key in self._mapping
307
+
308
+ def __getitem__(self, key: HA) -> HB:
309
+ return self._mapping[key]
310
+
311
+ def __iter__(self) -> Iterator[HA]:
312
+ return iter(self._mapping)
313
+
314
+ def __len__(self) -> int:
315
+ return len(self._mapping)
316
+
317
+ def __hash__(self) -> int:
318
+ return hash(tuple(sorted(self._mapping.items())))
319
+
320
+ def __eq__(self, other: Any) -> bool:
321
+ return isinstance(other, HashableMapping) and self._mapping == other._mapping
322
+
323
+ def __repr__(self) -> str:
324
+ return repr(self._mapping)
325
+
326
+
327
+ class GraphDef(Generic[Node]):
328
+ """
329
+ A base dataclass that denotes the graph structure of a :class:`Node`.
330
+
331
+ It contains two main components:
332
+ - type: The type of the node.
333
+ - index: The index of the node in the graph.
334
+
335
+ It has two concrete subclasses:
336
+
337
+ - :class:`NodeRef`: A reference to a node in the graph.
338
+ - :class:`NodeDef`: A dataclass that denotes the graph structure of a :class:`Node` or a :class:`State`.
339
+
340
+ Attributes
341
+ ----------
342
+ type : Type[Node]
343
+ The type of the node.
344
+ index : int
345
+ The index of the node in the graph.
346
+
347
+ """
348
+ type: Type[Node]
349
+ index: int
350
+
351
+
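A ``GraphDef`` is meant to be a static, hashable description of a node's structure; that is why the concrete ``NodeDef`` and ``NodeRef`` classes below are registered as static pytree nodes. A minimal sketch of what this enables, assuming ``brainstate.nn.Linear`` as the example module (the equality mirrors the ``graphdef`` docstring further down):

.. code-block:: python

    import brainstate

    m = brainstate.nn.Linear(2, 3)
    gd1 = brainstate.graph.graphdef(m)
    gd2, _ = brainstate.graph.treefy_split(m)
    # two splits of the same module describe the same structure
    assert gd1 == gd2
    # NodeDef is a frozen dataclass registered as a static pytree node,
    # so it can travel through jax transforms as static metadata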
352
+ @dataclasses.dataclass(frozen=True, repr=False)
353
+ class NodeDef(GraphDef[Node], PrettyRepr):
354
+ """
355
+ A dataclass that denotes the tree structure of a node, either :class:`Node` or :class:`State`.
356
+
357
+ Attributes
358
+ ----------
359
+ type : Type[Node]
360
+ Type of the node.
361
+ index : int
362
+ Index of the node in the graph.
363
+ attributes : Tuple[Key, ...]
364
+ Attributes for the node.
365
+ subgraphs : HashableMapping[Key, NodeDef[Any] | NodeRef[Any]]
366
+ Mapping of subgraph definitions.
367
+ static_fields : HashableMapping
368
+ Mapping of static fields.
369
+ leaves : HashableMapping[Key, NodeRef[Any] | None]
370
+ Mapping of leaf nodes.
371
+ metadata : Hashable
372
+ Metadata associated with the node.
373
+ index_mapping : FrozenDict[Index, Index] | None
374
+ Index mapping for node references.
375
+
376
+ """
377
+
378
+ type: Type[Node] # type of the node
379
+ index: int # index of the node in the graph
380
+ attributes: Tuple[Key, ...] # attributes for the node
381
+ subgraphs: HashableMapping[Key, NodeDef[Any] | NodeRef[Any]]
382
+ static_fields: HashableMapping
383
+ leaves: HashableMapping[Key, NodeRef[Any] | None]
384
+ metadata: Hashable
385
+ index_mapping: FrozenDict[Index, Index] | None
386
+
387
+ @classmethod
388
+ def create(
389
+ cls,
390
+ type: Type[Node],
391
+ index: int,
392
+ attributes: tuple[Key, ...],
393
+ subgraphs: Iterable[tuple[Key, NodeDef[Any] | NodeRef[Any]]],
394
+ static_fields: Iterable[tuple],
395
+ leaves: Iterable[tuple[Key, NodeRef[Any] | None]],
396
+ metadata: Hashable,
397
+ index_mapping: Mapping[Index, Index] | None,
398
+ ):
399
+ return cls(
400
+ type=type,
401
+ index=index,
402
+ attributes=attributes,
403
+ subgraphs=HashableMapping(subgraphs),
404
+ static_fields=HashableMapping(static_fields),
405
+ leaves=HashableMapping(leaves),
406
+ metadata=metadata,
407
+ index_mapping=FrozenDict(index_mapping) if index_mapping is not None else None,
408
+ )
409
+
410
+ def __pretty_repr__(self):
411
+ yield PrettyType(type=type(self))
412
+
413
+ yield PrettyAttr('type', self.type.__name__)
414
+ yield PrettyAttr('index', self.index)
415
+ yield PrettyAttr('attributes', self.attributes)
416
+ yield PrettyAttr('subgraphs', PrettyMapping(self.subgraphs))
417
+ yield PrettyAttr('static_fields', PrettyMapping(self.static_fields))
418
+ yield PrettyAttr('leaves', PrettyMapping(self.leaves))
419
+ yield PrettyAttr('metadata', self.metadata)
420
+ yield PrettyAttr('index_mapping', PrettyMapping(self.index_mapping) if self.index_mapping is not None else None)
421
+
422
+
423
+ jax.tree_util.register_static(NodeDef)
424
+
425
+
426
+ @dataclasses.dataclass(frozen=True, repr=False)
427
+ class NodeRef(GraphDef[Node], PrettyRepr):
428
+ """
429
+ A reference to a node in the graph.
430
+
431
+ The node can be instances of :class:`Node` or :class:`State`.
432
+
433
+ Attributes
434
+ ----------
435
+ type : Type[Node]
436
+ The type of the node being referenced.
437
+ index : int
438
+ The index of the node in the graph.
439
+
440
+ """
441
+ type: Type[Node]
442
+ index: int
443
+
444
+ def __pretty_repr__(self):
445
+ yield PrettyType(type=type(self))
446
+ yield PrettyAttr('type', self.type.__name__)
447
+ yield PrettyAttr('index', self.index)
448
+
449
+
450
+ jax.tree_util.register_static(NodeRef)
451
+
452
+
453
+ # --------------------------------------------------------
454
+ # Graph operations: start
455
+ # --------------------------------------------------------
456
+
457
+
458
+ def _graph_flatten(
459
+ path: PathParts,
460
+ ref_index: RefMap[Any, Index],
461
+ flatted_state_mapping: Dict[PathParts, StateLeaf],
462
+ node: Any,
463
+ treefy_state: bool = False,
464
+ ) -> Union[NodeDef[Any], NodeRef[Any]]:
465
+ """
466
+ Recursive helper for graph flatten.
467
+
468
+ Parameters
469
+ ----------
470
+ path : PathParts
471
+ The path to the node.
472
+ ref_index : RefMap[Any, Index]
473
+ A mapping from nodes to indexes.
474
+ flatted_state_mapping : Dict[PathParts, StateLeaf]
475
+ A mapping from paths to state leaves.
476
+ node : Node
477
+ The node to flatten.
478
+ treefy_state : bool, optional
479
+ Whether to convert states to TreefyState, by default False.
480
+
481
+ Returns
482
+ -------
483
+ NodeDef or NodeRef
484
+ A NodeDef or a NodeRef.
485
+
486
+ """
487
+ if not _is_node(node):
488
+ raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')
489
+
490
+ # If the node is already in the cache, return a reference, otherwise
491
+ # add it to the cache and continue with the flattening process.
492
+ # This is done to avoid infinite recursion when there is a reference cycle.
493
+ if node in ref_index:
494
+ return NodeRef(type(node), ref_index[node])
495
+
496
+ # Get the node implementation for the node type.
497
+ # There are two types of node implementations: GraphNodeImpl and PyTreeNodeImpl.
498
+ # - ``GraphNodeImpl`` is used for nodes that have a graph structure.
499
+ # - ``PyTreeNodeImpl`` is used for nodes that have a tree structure.
500
+ node_impl = _get_node_impl(node)
501
+
502
+ # There are two types of nodes: Node and State.
503
+ # Here we handle the Node case.
504
+ if isinstance(node_impl, GraphNodeImpl):
505
+ # add the node to the cache
506
+ index = len(ref_index)
507
+ ref_index[node] = index
508
+ else:
509
+ index = -1
510
+
511
+ subgraphs: list[tuple[Key, Union[NodeDef[Any], NodeRef[Any]]]] = []
512
+ static_fields: list[tuple] = []
513
+ leaves: list[tuple[Key, Union[NodeRef[Any], None]]] = []
514
+
515
+ # Flatten the node into a sequence of key-value pairs.
516
+ values, metadata = node_impl.flatten(node)
517
+ for key, value in values:
518
+ if _is_node(value):
519
+ # Recursively flatten the subgraph.
520
+ nodedef = _graph_flatten((*path, key), ref_index, flatted_state_mapping, value, treefy_state)
521
+ subgraphs.append((key, nodedef))
522
+ elif isinstance(value, State):
523
+ # If the variable is in the cache, add a reference to it.
524
+ if value in ref_index:
525
+ leaves.append((key, NodeRef(type(value), ref_index[value])))
526
+ else:
527
+ # If the variable is not in the cache, add it to the cache.
528
+ # This is done to avoid multiple references to the same variable.
529
+ flatted_state_mapping[(*path, key)] = (value.to_state_ref() if treefy_state else value)
530
+ variable_index = ref_index[value] = len(ref_index)
531
+ leaves.append((key, NodeRef(type(value), variable_index)))
532
+ elif _is_state_leaf(value):
533
+ # The instance of ``TreefyState`` is a leaf.
534
+ flatted_state_mapping[(*path, key)] = value
535
+ leaves.append((key, None))
536
+ else:
537
+ # if isinstance(value, (jax.Array, np.ndarray)):
538
+ # path_str = '/'.join(map(str, (*path, key)))
539
+ # raise ValueError(f'Arrays leaves are not supported, at {path_str!r}: {value}')
540
+
541
+ # The value is a static field.
542
+ static_fields.append((key, value))
543
+
544
+ nodedef = NodeDef.create(
545
+ type=node_impl.type,
546
+ index=index,
547
+ attributes=tuple(key for key, _ in values),
548
+ subgraphs=subgraphs,
549
+ static_fields=static_fields,
550
+ leaves=leaves,
551
+ metadata=metadata,
552
+ index_mapping=None,
553
+ )
554
+ return nodedef
555
+
556
+
557
+ @set_module_as('brainstate.graph')
558
+ def flatten(
559
+ node: Any,
560
+ /,
561
+ ref_index: Optional[RefMap[Any, Index]] = None,
562
+ treefy_state: bool = True,
563
+ ) -> Tuple[GraphDef[Any], NestedDict]:
564
+ """
565
+ Flattens a graph node into a (graph_def, state_mapping) pair.
566
+
567
+ Parameters
568
+ ----------
569
+ node : Node
570
+ A graph node.
571
+ ref_index : RefMap[Any, Index], optional
572
+ A mapping from nodes to indexes, defaults to None. If not provided, a new
573
+ empty RefMap is created. This argument can be used to flatten a sequence of graph
574
+ nodes that share references.
575
+ treefy_state : bool, optional
576
+ If True, each :class:`State` found in the node is converted to a :class:`TreefyState` reference in the returned state mapping; otherwise the original :class:`State` objects are kept.
577
+ Default is True.
578
+
579
+ Returns
580
+ -------
581
+ tuple[GraphDef, NestedDict]
582
+ A tuple containing the graph definition and state mapping.
583
+
584
+ Examples
585
+ --------
586
+ .. code-block:: python
587
+
588
+ >>> import brainstate
589
+ >>> node = brainstate.graph.Node()
590
+ >>> graph_def, state_mapping = brainstate.graph.flatten(node)
591
+ >>> print(graph_def)
592
+ >>> print(state_mapping)
593
+
594
+ """
595
+ ref_index = RefMap() if ref_index is None else ref_index
596
+ assert isinstance(ref_index, RefMap), f"ref_index must be a RefMap. But we got: {ref_index}"
597
+ flatted_state_mapping: dict[PathParts, StateLeaf] = {}
598
+ graph_def = _graph_flatten((), ref_index, flatted_state_mapping, node, treefy_state)
599
+ return graph_def, NestedDict.from_flat(flatted_state_mapping)
600
+
601
+
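The ``ref_index`` argument is what lets several flatten calls share one reference space, so that sub-modules reached more than once are recorded as ``NodeRef``s instead of being flattened twice. A hedged sketch of that usage (the container shapes are illustrative):

.. code-block:: python

    import brainstate

    shared = brainstate.nn.Linear(2, 3)
    ref_index = brainstate.graph.RefMap()

    graphdef1, state1 = brainstate.graph.flatten(shared, ref_index=ref_index)
    # `shared` is already registered in ref_index, so this flatten emits a
    # NodeRef for it and does not copy its parameters into state2
    graphdef2, state2 = brainstate.graph.flatten([shared, shared], ref_index=ref_index)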
602
+ def _get_children(
603
+ graph_def: NodeDef[Any],
604
+ state_mapping: Mapping,
605
+ index_ref: dict[Index, Any],
606
+ index_ref_cache: Optional[dict[Index, Any]],
607
+ ) -> dict[Key, Union[StateLeaf, Any]]:
608
+ children: dict[Key, Union[StateLeaf, Any]] = {}
609
+
610
+ # NOTE: we could allow adding new StateLeafs here
611
+ # All state keys must be present in the graph definition (the object attributes)
612
+ if unknown_keys := set(state_mapping) - set(graph_def.attributes):
613
+ raise ValueError(f'Unknown keys: {unknown_keys}')
614
+
615
+ # for every key in attributes there are 6 possible cases:
616
+ # - the key can either be present in the state or not (2 options)
617
+ # - the key can be a subgraph, a leaf, or a static attribute (3 kinds)
618
+ for key in graph_def.attributes:
619
+ if key not in state_mapping: # static field
620
+ # Support unflattening with missing keys for static fields and subgraphs
621
+ # This allows partial state restoration and flexible graph reconstruction
622
+ if key in graph_def.static_fields:
623
+ children[key] = graph_def.static_fields[key]
624
+
625
+ elif key in graph_def.subgraphs:
626
+ # if the key is a subgraph we create an empty node
627
+ subgraphdef = graph_def.subgraphs[key]
628
+ if isinstance(subgraphdef, NodeRef):
629
+ # subgraph exists, take it from the cache
630
+ children[key] = index_ref[subgraphdef.index]
631
+
632
+ else:
633
+ # create a node from an empty state, reasoning:
634
+ # * it is a node with no state
635
+ # * it is a node with state but only through references of already
636
+ # created nodes
637
+ substate = {}
638
+ children[key] = _graph_unflatten(subgraphdef, substate, index_ref, index_ref_cache)
639
+
640
+ elif key in graph_def.leaves:
641
+ noderef = graph_def.leaves[key]
642
+ if (noderef is not None) and (noderef.index in index_ref):
643
+ # variable exists, take it from the cache
644
+ children[key] = index_ref[noderef.index]
645
+
646
+ else:
647
+ # key for a variable is missing, raise an error
648
+ raise ValueError(
649
+ f'Expected key {key!r} in state while building node of type '
650
+ f'{graph_def.type.__name__}.'
651
+ )
652
+
653
+ else:
654
+ raise RuntimeError(f'Unknown static field: {key!r}')
655
+
656
+ else: # state field
657
+ value = state_mapping[key]
658
+ if isinstance(value, PrettyDict):
659
+ value = dict(value)
660
+
661
+ if key in graph_def.static_fields:
662
+ raise ValueError(f'Got state for static field {key!r}, this is not supported.')
663
+
664
+ if key in graph_def.subgraphs:
665
+ # if _is_state_leaf(value):
666
+ if isinstance(value, (TreefyState, State)):
667
+ raise ValueError(
668
+ f'Expected value of type {graph_def.subgraphs[key]} '
669
+ f'for {key!r}, but got {value!r}'
670
+ )
671
+
672
+ if not isinstance(value, dict):
673
+ raise TypeError(f'Expected a dict for {key!r}, but got {type(value)}.')
674
+
675
+ subgraphdef = graph_def.subgraphs[key]
676
+ if isinstance(subgraphdef, NodeRef):
677
+ children[key] = index_ref[subgraphdef.index]
678
+ else:
679
+ children[key] = _graph_unflatten(subgraphdef, value, index_ref, index_ref_cache)
680
+
681
+ elif key in graph_def.leaves:
682
+ # if not _is_state_leaf(value):
683
+ if not isinstance(value, (TreefyState, State)):
684
+ raise ValueError(f'Expected a leaf for {key!r}, but got {value!r}')
685
+
686
+ noderef = graph_def.leaves[key]
687
+ if noderef is None:
688
+ # if the leaf is None, it means that the value was originally
689
+ # a non-TreefyState leaf, however we allow providing a
690
+ # TreefyState presumbly created by modifying the NestedDict
691
+ if isinstance(value, TreefyState):
692
+ value = value.to_state()
693
+ elif isinstance(value, State):
694
+ value = value
695
+ children[key] = value
696
+
697
+ elif noderef.index in index_ref:
698
+ # add an existing variable
699
+ children[key] = index_ref[noderef.index]
700
+
701
+ else:
702
+ # it is an unseen variable, create a new one
703
+ if not isinstance(value, (TreefyState, State)):
704
+ raise ValueError(
705
+ f'Expected a State type for {key!r}, but got {type(value)}.'
706
+ )
707
+
708
+ # when index_ref_cache is present, check if the variable already exists there
709
+ # and update existing variables if it does
710
+ if index_ref_cache is not None and noderef.index in index_ref_cache:
711
+ variable = index_ref_cache[noderef.index]
712
+ if not isinstance(variable, State):
713
+ raise ValueError(f'Expected a State type for {key!r}, but got {type(variable)}.')
714
+ if isinstance(value, TreefyState):
715
+ variable.update_from_ref(value)
716
+ elif isinstance(value, State):
717
+ if value._been_writen:
718
+ variable.value = value.value
719
+ else:
720
+ variable.restore_value(value.value)
721
+ else:
722
+ raise ValueError(f'Expected a State type for {key!r}, but got {type(value)}.')
723
+ else: # if it doesn't, create a new variable
724
+ if isinstance(value, TreefyState):
725
+ variable = value.to_state()
726
+ elif isinstance(value, State):
727
+ variable = value
728
+ else:
729
+ raise ValueError(f'Expected a State type for {key!r}, but got {type(value)}.')
730
+ children[key] = variable
731
+ index_ref[noderef.index] = variable
732
+
733
+ else:
734
+ raise RuntimeError(f'Unknown key: {key!r}, this is a bug.')
735
+
736
+ return children
737
+
738
+
739
+ def _graph_unflatten(
740
+ graph_def: Union[NodeDef[Any], NodeRef[Any]],
741
+ state_mapping: Mapping[Key, Union[StateLeaf, Mapping]],
742
+ index_ref: dict[Index, Any],
743
+ index_ref_cache: Optional[dict[Index, Any]],
744
+ ) -> Any:
745
+ """
746
+ Recursive helper for graph unflatten.
747
+
748
+ Args:
749
+ graph_def: A `GraphDef` instance or an index to a node in the cache.
750
+ state_mapping: A state mapping from attribute names to variables or subgraphs.
751
+ index_ref: A mapping from indexes to nodes that have been traversed.
752
+ If a node is already in the cache, it won't be traversed again.
753
+ index_ref_cache: A mapping from indexes to existing nodes that can be reused.
754
+ When a reference is reused, ``GraphNodeImpl.clear`` is called to leave the
755
+ object in an empty state and then filled by the unflatten process, as a result
756
+ existing graph nodes are mutated to have the new content/topology
757
+ specified by the nodedef.
758
+
759
+ Returns:
760
+ A node instance.
761
+ """
762
+
763
+ # if the graph_def is a reference, this means that the node has already been created, so
764
+ # we return the node from the cache
765
+ if isinstance(graph_def, NodeRef):
766
+ return index_ref[graph_def.index]
767
+ else:
768
+ assert isinstance(graph_def, NodeDef), f"graph_def must be a NodeDef. But we got: {graph_def}"
769
+
770
+ # graph_def must be a registered node type
771
+ if not _is_node_type(graph_def.type):
772
+ raise RuntimeError(f'Unsupported type: {graph_def.type}, this is a bug.')
773
+
774
+ # check if the index is already in the cache
775
+ if graph_def.index in index_ref:
776
+ raise RuntimeError(f'GraphDef index {graph_def.index} already used.')
777
+
778
+ # get the node implementation for the node type
779
+ node_impl = get_node_impl_for_type(graph_def.type)
780
+
781
+ if isinstance(node_impl, GraphNodeImpl):
782
+ # we create an empty node first and add it to the index
783
+ # this avoids infinite recursion when there is a reference cycle
784
+
785
+ if (index_ref_cache is not None) and (graph_def.index in index_ref_cache):
786
+ # clear the node to leave it in an empty state
787
+ node = index_ref_cache[graph_def.index]
788
+ if type(node) != graph_def.type:
789
+ raise ValueError(f'Expected a node of type {graph_def.type} for index '
790
+ f'{graph_def.index}, but got a node of type {type(node)}.')
791
+ node_impl.clear(node)
792
+ else:
793
+ # create an empty node
794
+ node = node_impl.create_empty(graph_def.metadata)
795
+
796
+ # add the node to the cache
797
+ index_ref[graph_def.index] = node
798
+
799
+ # get the children (the attributes) of the node
800
+ children = _get_children(graph_def, state_mapping, index_ref, index_ref_cache)
801
+
802
+ # initialize the node with the children
803
+ node_impl.init(node, tuple(children.items()))
804
+
805
+ else:
806
+ # if the node type does not support the creation of an empty object it means
807
+ # that it cannot reference itself, so we can create its children first
808
+
809
+ # first, we create the children (attributes)
810
+ children = _get_children(graph_def, state_mapping, index_ref, index_ref_cache)
811
+ # then, we create the node
812
+ node = node_impl.unflatten(tuple(children.items()), graph_def.metadata)
813
+
814
+ return node
815
+
816
+
817
+ @set_module_as('brainstate.graph')
818
+ def unflatten(
819
+ graph_def: GraphDef[Any],
820
+ state_mapping: NestedDict,
821
+ /,
822
+ *,
823
+ index_ref: Optional[dict[Index, Any]] = None,
824
+ index_ref_cache: Optional[dict[Index, Any]] = None,
825
+ ) -> Any:
826
+ """
827
+ Unflattens a graphdef into a node with the given state tree mapping.
828
+
829
+ Parameters
830
+ ----------
831
+ graph_def : GraphDef
832
+ A GraphDef instance.
833
+ state_mapping : NestedDict
834
+ A NestedDict instance containing the state mapping.
835
+ index_ref : dict[Index, Any], optional
836
+ A mapping from indexes to nodes references found during the graph
837
+ traversal. If not provided, a new empty dictionary is created. This argument
838
+ can be used to unflatten a sequence of (graphdef, state_mapping) pairs that
839
+ share the same index space.
840
+ index_ref_cache : dict[Index, Any], optional
841
+ A mapping from indexes to existing nodes that can be reused. When a reference
842
+ is reused, ``GraphNodeImpl.clear`` is called to leave the object in an empty
843
+ state and then filled by the unflatten process. As a result, existing graph
844
+ nodes are mutated to have the new content/topology specified by the graphdef.
845
+
846
+ Returns
847
+ -------
848
+ Node
849
+ The reconstructed node.
850
+
851
+ Examples
852
+ --------
853
+ .. code-block:: python
854
+
855
+ >>> import brainstate
856
+ >>> class MyNode(brainstate.graph.Node):
857
+ ... def __init__(self):
858
+ ... self.a = brainstate.nn.Linear(2, 3)
859
+ ... self.b = brainstate.nn.Linear(3, 4)
860
+ ...
861
+ >>> # Flatten a node
862
+ >>> node = MyNode()
863
+ >>> graphdef, statetree = brainstate.graph.flatten(node)
864
+ >>>
865
+ >>> # Unflatten back to node
866
+ >>> reconstructed_node = brainstate.graph.unflatten(graphdef, statetree)
867
+ >>> assert isinstance(reconstructed_node, MyNode)
868
+ >>> assert isinstance(reconstructed_node.a, brainstate.nn.Linear)
869
+ >>> assert isinstance(reconstructed_node.b, brainstate.nn.Linear)
870
+ """
871
+ index_ref = {} if index_ref is None else index_ref
872
+ assert isinstance(graph_def, (NodeDef, NodeRef)), f"graph_def must be a NodeDef or NodeRef. But we got: {graph_def}"
873
+ node = _graph_unflatten(graph_def, state_mapping.to_dict(), index_ref, index_ref_cache)
874
+ return node
875
+
876
+
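``index_ref_cache`` is what allows ``unflatten`` to refill existing objects instead of building new ones. One way to obtain such a cache is to invert the ``RefMap`` filled in by ``flatten``; the following is a sketch under that assumption, not the library's prescribed recipe:

.. code-block:: python

    import brainstate

    node = brainstate.nn.Linear(2, 3)
    ref_index = brainstate.graph.RefMap()
    graphdef, state = brainstate.graph.flatten(node, ref_index=ref_index)

    # invert object -> index into index -> object
    index_ref_cache = {index: obj for obj, index in ref_index.items()}

    # cached objects are cleared and refilled in place, so the returned
    # node is expected to be the very same instance
    same = brainstate.graph.unflatten(graphdef, state, index_ref_cache=index_ref_cache)
    assert same is node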
877
+ def _graph_pop(
878
+ node: Any,
879
+ id_to_index: dict[int, Index],
880
+ path_parts: PathParts,
881
+ flatted_state_dicts: tuple[FlattedDict[PathParts, StateLeaf], ...],
882
+ predicates: tuple[Predicate, ...],
883
+ ) -> None:
884
+ if not _is_node(node):
885
+ raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')
886
+
887
+ if id(node) in id_to_index:
888
+ return
889
+
890
+ id_to_index[id(node)] = len(id_to_index)
891
+ node_impl = _get_node_impl(node)
892
+ node_dict = node_impl.node_dict(node)
893
+
894
+ for name, value in node_dict.items():
895
+ if _is_node(value):
896
+ _graph_pop(
897
+ node=value,
898
+ id_to_index=id_to_index,
899
+ path_parts=(*path_parts, name),
900
+ flatted_state_dicts=flatted_state_dicts,
901
+ predicates=predicates,
902
+ )
903
+ continue
904
+ elif not _is_node_leaf(value):
905
+ continue
906
+ elif id(value) in id_to_index:
907
+ continue
908
+
909
+ node_path = (*path_parts, name)
910
+ node_impl = _get_node_impl(node)
911
+ for state_dicts, predicate in zip(flatted_state_dicts, predicates):
912
+ if predicate(node_path, value):
913
+ if isinstance(node_impl, PyTreeNodeImpl):
914
+ raise ValueError(f'Cannot pop key {name!r} from node of type {type(node).__name__}')
915
+ id_to_index[id(value)] = len(id_to_index)
916
+ node_impl.pop_key(node, name)
917
+ # if isinstance(value, State):
918
+ # value = value.to_state_ref()
919
+ state_dicts[node_path] = value # type: ignore[index] # mypy is wrong here?
920
+ break
921
+ else:
922
+ # NOTE: should we raise an error here?
923
+ pass
924
+
925
+
926
+ @set_module_as('brainstate.graph')
927
+ def pop_states(
928
+ node: Any, *filters: Any
929
+ ) -> Union[NestedDict, Tuple[NestedDict, ...]]:
930
+ """
931
+ Pop one or more :class:`State` types from the graph node.
932
+
933
+ Parameters
934
+ ----------
935
+ node : Node
936
+ A graph node object.
937
+ *filters
938
+ One or more filters (e.g. :class:`State` subclasses or tags) selecting the states to pop.
939
+
940
+ Returns
941
+ -------
942
+ NestedDict or tuple[NestedDict, ...]
943
+ The popped :class:`NestedDict` containing the :class:`State`
944
+ objects that were filtered for.
945
+
946
+ Examples
947
+ --------
948
+ .. code-block:: python
949
+
950
+ >>> import brainstate
951
+ >>> import jax.numpy as jnp
952
+
953
+ >>> class Model(brainstate.nn.Module):
954
+ ... def __init__(self):
955
+ ... super().__init__()
956
+ ... self.a = brainstate.nn.Linear(2, 3)
957
+ ... self.b = brainstate.nn.LIF([10, 2])
958
+
959
+ >>> model = Model()
960
+ >>> with brainstate.catch_new_states('new'):
961
+ ... brainstate.nn.init_all_states(model)
962
+
963
+ >>> assert len(model.states()) == 2
964
+ >>> model_states = brainstate.graph.pop_states(model, 'new')
965
+ >>> model_states # doctest: +SKIP
966
+ NestedDict({
967
+ 'b': {
968
+ 'V': {
969
+ 'st': ShortTermState(
970
+ value=Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
971
+ 0., 0., 0.], dtype=float32),
972
+ tag='new'
973
+ )
974
+ }
975
+ }
976
+ })
977
+ """
978
+ if len(filters) == 0:
979
+ raise ValueError('Expected at least one filter')
980
+
981
+ id_to_index: dict[int, Index] = {}
982
+ path_parts: PathParts = ()
983
+ predicates = tuple(to_predicate(filter) for filter in filters)
984
+ flatted_state_dicts: tuple[FlattedDict[PathParts, StateLeaf], ...] = tuple({} for _ in predicates)
985
+ _graph_pop(
986
+ node=node,
987
+ id_to_index=id_to_index,
988
+ path_parts=path_parts,
989
+ flatted_state_dicts=flatted_state_dicts,
990
+ predicates=predicates,
991
+ )
992
+ states = tuple(NestedDict.from_flat(flat_state) for flat_state in flatted_state_dicts)
993
+
994
+ if len(states) == 1:
995
+ return states[0]
996
+ else:
997
+ return states
998
+
999
+
1000
+ def _split_state(
1001
+ state: GraphStateMapping,
1002
+ filters: tuple[Filter, ...],
1003
+ ) -> tuple[GraphStateMapping, Unpack[tuple[GraphStateMapping, ...]]]:
1004
+ if not filters:
1005
+ return (state,)
1006
+ states = state.split(*filters)
1007
+ if isinstance(states, NestedDict):
1008
+ return (states,)
1009
+ assert len(states) > 0
1010
+ return states # type: ignore[return-value]
1011
+
1012
+
1013
+ @set_module_as('brainstate.graph')
1014
+ def treefy_split(
1015
+ node: A, *filters: Filter
1016
+ ):
1017
+ """
1018
+ Split a graph node into a :class:`GraphDef` and one or more :class:`NestedDict`s.
1019
+
1020
+ NestedDict is a ``Mapping`` from strings or integers to :class:`State` objects, arrays, or nested mappings.
1021
+ GraphDef contains all the static information needed to reconstruct a ``Module`` graph, it is
1022
+ analogous to JAX's ``PyTreeDef``. :func:`treefy_split` is used in conjunction with :func:`treefy_merge` to
1023
+ switch seamlessly between stateful and stateless representations of the graph.
1024
+
1025
+ Parameters
1026
+ ----------
1027
+ node : A
1028
+ Graph node to split.
1029
+ *filters
1030
+ Optional filters to group the state into mutually exclusive substates.
1031
+
1032
+ Returns
1033
+ -------
1034
+ tuple
1035
+ A ``GraphDef`` and one or more ``NestedDict`` objects, one per filter passed.
1036
+ If no filters are passed, a single ``NestedDict`` is returned.
1037
+
1038
+ Examples
1039
+ --------
1040
+ .. code-block:: python
1041
+
1042
+ >>> import brainstate
1043
+ >>> import jax, jax.numpy as jnp
1044
+
1045
+ >>> class Foo(brainstate.graph.Node):
1046
+ ... def __init__(self):
1047
+ ... self.a = brainstate.nn.BatchNorm1d([10, 2])
1048
+ ... self.b = brainstate.nn.Linear(2, 3)
1049
+ ...
1050
+ >>> node = Foo()
1051
+ >>> graphdef, params, others = brainstate.graph.treefy_split(
1052
+ ... node, brainstate.ParamState, ...
1053
+ ... )
1054
+ >>> # params contains ParamState variables
1055
+ >>> # others contains all other state variables
1056
+ """
1057
+ graphdef, state_tree = flatten(node)
1058
+ states = tuple(_split_state(state_tree, filters))
1059
+ return graphdef, *states
1060
+
1061
+
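The split/merge pair is what lets a stateful module cross a ``jax`` transformation boundary as plain pytrees. A hedged sketch of that pattern (not the library's canonical training loop):

.. code-block:: python

    import jax
    import jax.numpy as jnp
    import brainstate

    model = brainstate.nn.Linear(2, 3)
    graphdef, state = brainstate.graph.treefy_split(model)

    @jax.jit
    def forward(state, x):
        # rebuild a working module inside the transform ...
        module = brainstate.graph.treefy_merge(graphdef, state)
        y = module(x)
        # ... and hand the (possibly updated) state back out as a pytree
        return y, brainstate.graph.treefy_states(module)

    y, new_state = forward(state, jnp.ones((1, 2)))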
1062
+ @set_module_as('brainstate.graph')
1063
+ def treefy_merge(graphdef: GraphDef[A], *state_mappings) -> A:
1064
+ """
1065
+ The inverse of :func:`treefy_split`.
1066
+
1067
+ ``merge`` takes a :class:`GraphDef` and one or more :class:`NestedDict`'s and creates
1068
+ a new node with the same structure as the original node.
1069
+
1070
+ Parameters
1071
+ ----------
1072
+ graphdef : GraphDef[A]
1073
+ A :class:`GraphDef` object.
1074
+ *state_mappings
1075
+ Additional :class:`NestedDict` objects.
1076
+
1077
+ Returns
1078
+ -------
1079
+ A
1080
+ The merged :class:`Module`.
1081
+
1082
+ Examples
1083
+ --------
1084
+ .. code-block:: python
1085
+
1086
+ >>> import brainstate
1087
+ >>> import jax, jax.numpy as jnp
1088
+
1089
+ >>> class Foo(brainstate.graph.Node):
1090
+ ... def __init__(self):
1091
+ ... self.a = brainstate.nn.BatchNorm1d([10, 2])
1092
+ ... self.b = brainstate.nn.Linear(2, 3)
1093
+ ...
1094
+ >>> node = Foo()
1095
+ >>> graphdef, params, others = brainstate.graph.treefy_split(
1096
+ ... node, brainstate.ParamState, ...
1097
+ ... )
1098
+ >>> new_node = brainstate.graph.treefy_merge(graphdef, params, others)
1099
+ >>> assert isinstance(new_node, Foo)
1100
+ >>> assert isinstance(new_node.a, brainstate.nn.BatchNorm1d)
1101
+ >>> assert isinstance(new_node.b, brainstate.nn.Linear)
1102
+ """
1103
+ state_mapping = GraphStateMapping.merge(*state_mappings)
1104
+ node = unflatten(graphdef, state_mapping)
1105
+ return node
1106
+
1107
+
1108
+ def _filters_to_predicates(filters: Tuple[Filter, ...]) -> Tuple[Predicate, ...]:
1109
+ for i, filter_ in enumerate(filters):
1110
+ if filter_ in (..., True) and i != len(filters) - 1:
1111
+ remaining_filters = filters[i + 1:]
1112
+ if not all(f in (..., True) for f in remaining_filters):
1113
+ raise ValueError('`...` or `True` can only be used as the last filter, '
1114
+ f'got {filter_} at index {i}.')
1115
+ return tuple(map(to_predicate, filters))
1116
+
1117
+
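The ordering rule enforced here is visible in the public helpers below: ``...`` (or ``True``) may only appear as the trailing filter, because nothing placed after it could ever match. A small sketch using :func:`states` (``brainstate.ParamState`` is just a typical filter):

.. code-block:: python

    import brainstate

    model = brainstate.nn.Linear(2, 3)

    # fine: the ellipsis collects everything the first filter did not match
    params, rest = brainstate.graph.states(model, brainstate.ParamState, ...)

    # raises ValueError: `...` must come last
    # brainstate.graph.states(model, ..., brainstate.ParamState)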
1118
+ def _split_flatted(
1119
+ flatted: Iterable[tuple[PathParts, Any]],
1120
+ filters: tuple[Filter, ...],
1121
+ ) -> tuple[list[tuple[PathParts, Any]], ...]:
1122
+ predicates = _filters_to_predicates(filters)
1123
+
1124
+ # we have one flat state per predicate; the callers append `...` as the
1125
+ # trailing filter, so values matching no earlier predicate land in the last state
1126
+ flat_states: tuple[list[tuple[PathParts, Any]], ...] = tuple([] for _ in predicates)
1127
+
1128
+ for path, value in flatted:
1129
+ for i, predicate in enumerate(predicates):
1130
+ if predicate(path, value):
1131
+ flat_states[i].append((path, value))
1132
+ break
1133
+ else:
1134
+ raise ValueError('Non-exhaustive filters, got a non-empty remainder: '
1135
+ f'{path} -> {value}.'
1136
+ '\nUse `...` to match all remaining elements.')
1137
+
1138
+ return flat_states
1139
+
1140
+
1141
+ @set_module_as('brainstate.graph')
1142
+ def nodes(
1143
+ node, *filters: Filter, allowed_hierarchy: Tuple[int, int] = (0, MAX_INT)
1144
+ ):
1145
+ """
1146
+ Similar to :func:`treefy_split` but returns a :class:`FlattedDict` of the graph nodes selected by the filters, keyed by path.
1147
+
1148
+ Parameters
1149
+ ----------
1150
+ node : Node
1151
+ The node to get nodes from.
1152
+ *filters
1153
+ Filters to apply to the nodes.
1154
+ allowed_hierarchy : tuple[int, int], optional
1155
+ The allowed hierarchy levels, by default (0, MAX_INT).
1156
+
1157
+ Returns
1158
+ -------
1159
+ FlattedDict or tuple[FlattedDict, ...]
1160
+ The filtered nodes.
1161
+
1162
+ """
1163
+ num_filters = len(filters)
1164
+ if num_filters == 0:
1165
+ filters = (..., ...)
1166
+ else:
1167
+ filters = (*filters, ...)
1168
+
1169
+ nodes_iterable = iter_node(node, allowed_hierarchy=allowed_hierarchy)
1170
+ flat_nodes = _split_flatted(nodes_iterable, (*filters, ...))
1171
+ node_maps = tuple(FlattedDict(flat_node) for flat_node in flat_nodes)
1172
+ if num_filters < 2:
1173
+ return node_maps[0]
1174
+ return node_maps[:num_filters]
1175
+
1176
+
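A short sketch of what ``nodes`` returns, a :class:`FlattedDict` keyed by path; the nested ``Net`` class is purely illustrative:

.. code-block:: python

    import brainstate

    class Net(brainstate.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = brainstate.nn.Linear(2, 3)
            self.fc2 = brainstate.nn.Linear(3, 4)

    net = Net()
    # every reachable graph node, keyed by its path from the root,
    # e.g. ('fc1',), ('fc2',) and () for the root itself
    for path, sub in brainstate.graph.nodes(net).items():
        print(path, type(sub).__name__)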
1177
+ def _states_generator(node, allowed_hierarchy) -> Iterable[Tuple[PathParts, State]]:
1178
+ for path, value in iter_leaf(node, allowed_hierarchy=allowed_hierarchy):
1179
+ if isinstance(value, State):
1180
+ yield path, value
1181
+
1182
+
1183
+ @set_module_as('brainstate.graph')
1184
+ def states(
1185
+ node, *filters: Filter, allowed_hierarchy: Tuple[int, int] = (0, MAX_INT)
1186
+ ) -> Union[FlattedDict, tuple[FlattedDict, ...]]:
1187
+ """
1188
+ Similar to :func:`treefy_split` but returns a :class:`FlattedDict` of the live :class:`State` objects selected by the filters, keyed by path.
1189
+
1190
+ Parameters
1191
+ ----------
1192
+ node : Node
1193
+ The node to get states from.
1194
+ *filters
1195
+ Filters to apply to the states.
1196
+ allowed_hierarchy : tuple[int, int], optional
1197
+ The allowed hierarchy levels, by default (0, MAX_INT).
1198
+
1199
+ Returns
1200
+ -------
1201
+ FlattedDict or tuple[FlattedDict, ...]
1202
+ The filtered states.
1203
+
1204
+ """
1205
+ num_filters = len(filters)
1206
+ if num_filters == 0:
1207
+ filters = (..., ...)
1208
+ else:
1209
+ filters = (*filters, ...)
1210
+
1211
+ states_iterable = _states_generator(node, allowed_hierarchy=allowed_hierarchy)
1212
+ flat_states = _split_flatted(states_iterable, (*filters, ...))
1213
+ state_maps = tuple(FlattedDict(flat_state) for flat_state in flat_states)
1214
+ if num_filters < 2:
1215
+ return state_maps[0]
1216
+ return state_maps[:num_filters]
1217
+
1218
+
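``states`` is the flat counterpart of :func:`treefy_states`: it walks the live object and returns the :class:`State` instances themselves, keyed by path. A sketch:

.. code-block:: python

    import brainstate

    model = brainstate.nn.Linear(2, 3)
    flat = brainstate.graph.states(model)
    for path, st in flat.items():
        print(path, type(st).__name__)
    # unlike treefy_states, the values here are the live State objects,
    # so reading `st.value` reflects the current contents of the model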
1219
+ @set_module_as('brainstate.graph')
1220
+ def treefy_states(
1221
+ node, *filters,
1222
+ ):
1223
+ """
1224
+ Similar to :func:`treefy_split` but only returns the :class:`NestedDict`'s indicated by the filters.
1225
+
1226
+ Parameters
1227
+ ----------
1228
+ node : Node
1229
+ A graph node object.
1230
+ *filters
1231
+ One or more filters (e.g. :class:`State` subclasses) selecting the states to return.
1232
+
1233
+ Returns
1234
+ -------
1235
+ NestedDict or tuple of NestedDict
1236
+ One or more :class:`NestedDict` mappings.
1237
+
1238
+ Examples
1239
+ --------
1240
+ .. code-block:: python
1241
+
1242
+ >>> import brainstate
1243
+ >>> class Model(brainstate.nn.Module):
1244
+ ... def __init__(self):
1245
+ ... super().__init__()
1246
+ ... self.l1 = brainstate.nn.Linear(2, 3)
1247
+ ... self.l2 = brainstate.nn.Linear(3, 4)
1248
+ ... def __call__(self, x):
1249
+ ... return self.l2(self.l1(x))
1250
+
1251
+ >>> model = Model()
1252
+ >>> # Get the learnable parameters
1253
+ >>> params = brainstate.graph.treefy_states(model, brainstate.ParamState)
1254
+ >>> # Get them separately
1255
+ >>> params, others = brainstate.graph.treefy_states(
1256
+ ... model, brainstate.ParamState, brainstate.ShortTermState
1257
+ ... )
1258
+ >>> # Get all states together
1259
+ >>> states = brainstate.graph.treefy_states(model)
1260
+ """
1261
+ _, state_mapping = flatten(node)
1262
+ if len(filters) == 0:
1263
+ return state_mapping
1264
+ else:
1265
+ state_mappings = state_mapping.filter(*filters)
1266
+ if len(filters) == 1:
1267
+ return state_mappings[0]
1268
+ else:
1269
+ return state_mappings
1270
+
1271
+
1272
+ def _graph_update_dynamic(node: Any, state: Mapping) -> None:
1273
+ if not _is_node(node):
1274
+ raise RuntimeError(f'Unsupported type: {type(node)}')
1275
+
1276
+ node_impl = _get_node_impl(node)
1277
+ node_dict = node_impl.node_dict(node)
1278
+ for key, value in state.items():
1279
+ # case 1: new state is being added
1280
+ if key not in node_dict:
1281
+ if isinstance(node_impl, PyTreeNodeImpl):
1282
+ raise ValueError(f'Cannot set key {key!r} on immutable node of '
1283
+ f'type {type(node).__name__}')
1284
+ if isinstance(value, State):
1285
+ # TODO: should this raise an error if the state already belongs to another node?
1286
+ value = value.to_state_ref() # Convert to state reference for proper state management
1287
+ node_impl.set_key(node, key, value)
1288
+ continue
1289
+
1290
+ # check values are of the same type
1291
+ current_value = node_dict[key]
1292
+
1293
+ # case 2: subgraph is being updated
1294
+ if _is_node(current_value):
1295
+ if _is_state_leaf(value):
1296
+ raise ValueError(f'Expected a subgraph for {key!r}, but got: {value!r}')
1297
+ _graph_update_dynamic(current_value, value)
1298
+ elif isinstance(value, TreefyState):
1299
+ # case 3: state leaf is being updated
1300
+ if not isinstance(current_value, State):
1301
+ raise ValueError(f'Trying to update a non-State attribute {key!r} with a State: '
1302
+ f'{value!r}')
1303
+ current_value.update_from_ref(value)
1304
+ elif _is_state_leaf(value):
1305
+ # case 4: state field is being updated
1306
+ if isinstance(node_impl, PyTreeNodeImpl):
1307
+ raise ValueError(f'Cannot set key {key!r} on immutable node of '
1308
+ f'type {type(node).__name__}')
1309
+ node_impl.set_key(node, key, value)
1310
+ else:
1311
+ raise ValueError(f'Unsupported update type: {type(value)} for key {key!r}')
1312
+
1313
+
1314
+ def update_states(
1315
+ node: Any,
1316
+ state_dict: Union[NestedDict, FlattedDict],
1317
+ /,
1318
+ *state_dicts: Union[NestedDict, FlattedDict]
1319
+ ) -> None:
1320
+ """
1321
+ Update the given graph node in-place with new state values.
1322
+
1323
+ Parameters
1324
+ ----------
1325
+ node : Node
1326
+ A graph node to update.
1327
+ state_dict : NestedDict | FlattedDict
1328
+ A :class:`NestedDict` or :class:`FlattedDict` of states to write into ``node``.
1329
+ *state_dicts : NestedDict | FlattedDict
1330
+ Additional :class:`NestedDict` or :class:`FlattedDict` objects to merge in.
1331
+
1332
+ """
1333
+ if state_dicts:
1334
+ state_dict = NestedDict.merge(state_dict, *state_dicts)
1335
+ _graph_update_dynamic(node, state_dict.to_dict())
1336
+
1337
+
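A round-trip sketch of the intended use: extract a state tree, transform it elsewhere (for example, load values from a checkpoint), and write it back into the same object graph in place. The transformation step is left as a comment:

.. code-block:: python

    import brainstate

    model = brainstate.nn.Linear(2, 3)

    # NestedDict of TreefyState references
    state_tree = brainstate.graph.treefy_states(model)

    # ... modify the values inside `state_tree` here ...

    # write the (possibly modified) states back into the live model
    brainstate.graph.update_states(model, state_tree)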
1338
+ @set_module_as('brainstate.graph')
1339
+ def graphdef(node: Any) -> GraphDef[Any]:
1340
+ """
1341
+ Get the :class:`GraphDef` of the given graph node.
1342
+
1343
+ Parameters
1344
+ ----------
1345
+ node : Any
1346
+ A graph node object.
1347
+
1348
+ Returns
1349
+ -------
1350
+ GraphDef[Any]
1351
+ The :class:`GraphDef` of the :class:`Module` object.
1352
+
1353
+ Examples
1354
+ --------
1355
+ .. code-block:: python
1356
+
1357
+ >>> import brainstate
1358
+
1359
+ >>> model = brainstate.nn.Linear(2, 3)
1360
+ >>> graphdef, _ = brainstate.graph.treefy_split(model)
1361
+ >>> assert graphdef == brainstate.graph.graphdef(model)
1362
+
1363
+ """
1364
+ graphdef, _ = flatten(node)
1365
+ return graphdef
1366
+
1367
+
1368
+ @set_module_as('brainstate.graph')
1369
+ def clone(node: A) -> A:
1370
+ """
1371
+ Create a deep copy of the given graph node.
1372
+
1373
+ Parameters
1374
+ ----------
1375
+ node : Node
1376
+ A graph node object.
1377
+
1378
+ Returns
1379
+ -------
1380
+ Node
1381
+ A deep copy of the :class:`Module` object.
1382
+
1383
+ Examples
1384
+ --------
1385
+ .. code-block:: python
1386
+
1387
+ >>> import brainstate
1388
+ >>> model = brainstate.nn.Linear(2, 3)
1389
+ >>> cloned_model = brainstate.graph.clone(model)
1390
+ >>> model.weight.value['bias'] += 1
1391
+ >>> assert (model.weight.value['bias'] != cloned_model.weight.value['bias']).all()
1392
+
1393
+ """
1394
+ graphdef, state = treefy_split(node)
1395
+ return treefy_merge(graphdef, state)
1396
+
1397
+
1398
+ @set_module_as('brainstate.graph')
1399
+ def iter_leaf(
1400
+ node: Any, allowed_hierarchy: Tuple[int, int] = (0, MAX_INT)
1401
+ ) -> Iterator[tuple[PathParts, Any]]:
1402
+ """
1403
+ Iterates over all nested leaves in the given graph node, including the current node.
1404
+
1405
+ ``iter_leaf`` creates a generator that yields path and value pairs, where
1406
+ the path is a tuple of strings or integers representing the path to the value from the
1407
+ root. Repeated nodes are visited only once. Leaves include static values.
1408
+
1409
+ Parameters
1410
+ ----------
1411
+ node : Any
1412
+ The node to iterate over.
1413
+ allowed_hierarchy : tuple[int, int], optional
1414
+ The allowed hierarchy levels, by default (0, MAX_INT).
1415
+
1416
+ Yields
1417
+ ------
1418
+ Iterator[tuple[PathParts, Any]]
1419
+ Path and value pairs.
1420
+
1421
+ Examples
1422
+ --------
1423
+ .. code-block:: python
1424
+
1425
+ >>> import brainstate
1426
+ >>> import jax.numpy as jnp
1427
+
1428
+ >>> class Linear(brainstate.nn.Module):
1429
+ ... def __init__(self, din, dout):
1430
+ ... super().__init__()
1431
+ ... self.weight = brainstate.ParamState(brainstate.random.randn(din, dout))
1432
+ ... self.bias = brainstate.ParamState(brainstate.random.randn(dout))
1433
+ ... self.a = 1
1434
+ ...
1435
+ >>> module = Linear(3, 4)
1436
+ ...
1437
+ >>> for path, value in brainstate.graph.iter_leaf([module, module]):
1438
+ ... print(path, type(value).__name__)
1439
+ ...
1440
+ (0, 'a') int
1441
+ (0, 'bias') ParamState
1442
+ (0, 'weight') ParamState
1443
+
1444
+ """
1445
+
1446
+ def _iter_graph_leaf(
1447
+ node_: Any,
1448
+ visited_: set[int],
1449
+ path_parts_: PathParts,
1450
+ level_: int,
1451
+ ) -> Iterator[tuple[PathParts, Any]]:
1452
+ if level_ > allowed_hierarchy[1]:
1453
+ return
1454
+
1455
+ if _is_node(node_):
1456
+ if id(node_) in visited_:
1457
+ return
1458
+ visited_.add(id(node_))
1459
+ node_dict = _get_node_impl(node_).node_dict(node_)
1460
+ for key, value in node_dict.items():
1461
+ yield from _iter_graph_leaf(
1462
+ value,
1463
+ visited_,
1464
+ (*path_parts_, key),
1465
+ level_ + 1 if _is_graph_node(value) else level_
1466
+ )
1467
+ else:
1468
+ if level_ >= allowed_hierarchy[0]:
1469
+ yield path_parts_, node_
1470
+
1471
+ visited: set[int] = set()
1472
+ path_parts: PathParts = ()
1473
+ level: int = 0
1474
+ yield from _iter_graph_leaf(node, visited, path_parts, level)
1475
+
1476
+
1477
+ @set_module_as('brainstate.graph')
1478
+ def iter_node(
1479
+ node: Any, allowed_hierarchy: Tuple[int, int] = (0, MAX_INT)
1480
+ ) -> Iterator[Tuple[PathParts, Any]]:
1481
+ """
1482
+ Iterates over all nested nodes of the given graph node, including the current node.
1483
+
1484
+ ``iter_node`` creates a generator that yields path and node pairs, where
1485
+ the path is a tuple of strings or integers representing the path to the value from the
1486
+ root. Repeated nodes are visited only once.
1487
+
1488
+ Parameters
1489
+ ----------
1490
+ node : Any
1491
+ The node to iterate over.
1492
+ allowed_hierarchy : tuple[int, int], optional
1493
+ The allowed hierarchy levels, by default (0, MAX_INT).
1494
+
1495
+ Yields
1496
+ ------
1497
+ Iterator[tuple[PathParts, Any]]
1498
+ Path and node pairs.
1499
+
1500
+ Examples
1501
+ --------
1502
+ .. code-block:: python
1503
+
1504
+ >>> import brainstate
1505
+ >>> import jax.numpy as jnp
1506
+
1507
+ >>> class Model(brainstate.nn.Module):
1508
+ ... def __init__(self):
1509
+ ... super().__init__()
1510
+ ... self.a = brainstate.nn.Linear(1, 2)
1511
+ ... self.b = brainstate.nn.Linear(2, 3)
1512
+ ... self.c = [brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5)]
1513
+ ... self.d = {'x': brainstate.nn.Linear(5, 6), 'y': brainstate.nn.Linear(6, 7)}
1514
+ ... self.b.a = brainstate.nn.LIF(2)
1515
+ ...
1516
+ >>> model = Model()
1517
+ ...
1518
+ >>> for path, node in brainstate.graph.iter_node([model, model]):
1519
+ ... print(path, node.__class__.__name__)
1520
+ ...
1521
+ (0, 'a') Linear
1522
+ (0, 'b', 'a') LIF
1523
+ (0, 'b') Linear
1524
+ (0, 'c', 0) Linear
1525
+ (0, 'c', 1) Linear
1526
+ (0, 'd', 'x') Linear
1527
+ (0, 'd', 'y') Linear
1528
+ (0,) Model
1529
+
1530
+ """
1531
+
1532
+ def _iter_graph_node(
1533
+ node_: Any,
1534
+ visited_: set[int],
1535
+ path_parts_: PathParts,
1536
+ level_: int,
1537
+ ) -> Iterator[tuple[PathParts, Any]]:
1538
+ if level_ > allowed_hierarchy[1]:
1539
+ return
1540
+
1541
+ if _is_node(node_):
1542
+ if id(node_) in visited_:
1543
+ return
1544
+
1545
+ visited_.add(id(node_))
1546
+ node_dict = _get_node_impl(node_).node_dict(node_)
1547
+ for key, value in node_dict.items():
1548
+ yield from _iter_graph_node(value, visited_, (*path_parts_, key),
1549
+ level_ + 1 if _is_graph_node(value) else level_)
1550
+
1551
+ if _is_graph_node(node_) and level_ >= allowed_hierarchy[0]:
1552
+ yield path_parts_, node_
1553
+
1554
+ visited: set[int] = set()
1555
+ path_parts: PathParts = ()
1556
+ level: int = 0
1557
+ yield from _iter_graph_node(node, visited, path_parts, level)
1558
+
1559
+
1560
+ # --------------------------------------------------------
1561
+ # Graph operations: end
1562
+ # --------------------------------------------------------
1563
+
1564
+
1565
+ @dataclasses.dataclass(frozen=True)
1566
+ class Static(Generic[A]):
1567
+ """
1568
+ An empty pytree node that treats its inner value as static.
1569
+
1570
+ ``value`` must define ``__eq__`` and ``__hash__``.
1571
+
1572
+ Attributes
1573
+ ----------
1574
+ value : A
1575
+ The static value to wrap.
1576
+
1577
+ """
1578
+
1579
+ value: A
1580
+
1581
+
1582
+ jax.tree_util.register_static(Static)
1583
+
1584
+
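``Static`` follows the general ``register_static`` pattern: the wrapped value is carried in the treedef, so it contributes no array leaves. A standalone sketch of that behaviour, written against plain ``jax.tree_util`` because ``Static`` itself is not re-exported in ``__all__``:

.. code-block:: python

    import dataclasses
    import jax

    @dataclasses.dataclass(frozen=True)
    class StaticBox:
        value: object  # must be hashable and comparable

    jax.tree_util.register_static(StaticBox)

    box = StaticBox(('relu', 0.1))
    leaves, treedef = jax.tree_util.tree_flatten(box)
    assert leaves == []  # the value lives in the (hashable) treedef
    assert jax.tree_util.tree_unflatten(treedef, leaves) == box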
1585
+ # ---------------------------------------------------------
1586
+ # Pytree
1587
+ # ---------------------------------------------------------
1588
+
1589
+ class PytreeType:
1590
+ ...
1591
+
1592
+
1593
+ def _key_path_to_key(key: Any) -> Key:
1594
+ if isinstance(key, jax.tree_util.SequenceKey):
1595
+ return key.idx
1596
+ elif isinstance(
1597
+ key, (jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey)
1598
+ ):
1599
+ if not isinstance(key.key, Key):
1600
+ raise ValueError(
1601
+ f'Invalid key: {key.key}. May be due to its type not being hashable or comparable.'
1602
+ )
1603
+ return key.key
1604
+ elif isinstance(key, jax.tree_util.GetAttrKey):
1605
+ return key.name
1606
+ else:
1607
+ return str(key)
1608
+
1609
+
1610
+ def _flatten_pytree(pytree: Any) -> Tuple[Tuple[Tuple, ...], jax.tree_util.PyTreeDef]:
1611
+ leaves, treedef = jax.tree_util.tree_flatten_with_path(pytree, is_leaf=lambda x: x is not pytree)
1612
+ nodes = tuple((_key_path_to_key(path[0]), value) for path, value in leaves)
1613
+ return nodes, treedef
1614
+
1615
+
1616
+ def _unflatten_pytree(
1617
+ nodes: tuple[tuple, ...],
1618
+ treedef: jax.tree_util.PyTreeDef
1619
+ ) -> Any:
1620
+ pytree = treedef.unflatten(value for _, value in nodes)
1621
+ return pytree
1622
+
1623
+
1624
+ PYTREE_NODE_IMPL = PyTreeNodeImpl(type=PytreeType, flatten=_flatten_pytree, unflatten=_unflatten_pytree)
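Because unregistered containers fall back to ``PYTREE_NODE_IMPL``, plain pytrees such as dicts and lists can be mixed freely with graph nodes in the public API. A closing sketch:

.. code-block:: python

    import brainstate

    bundle = {
        'encoder': brainstate.nn.Linear(2, 3),
        'decoder': brainstate.nn.Linear(3, 2),
    }
    graphdef, state = brainstate.graph.flatten(bundle)
    rebuilt = brainstate.graph.unflatten(graphdef, state)
    assert isinstance(rebuilt['encoder'], brainstate.nn.Linear)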