brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184) hide show
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
@@ -0,0 +1,1746 @@
1
+ # The file is adapted from the Flax library (https://github.com/google/flax).
2
+ # The credit should go to the Flax authors.
3
+ #
4
+ # Copyright 2024 The Flax Authors & 2024 BDP Ecosystem.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ from __future__ import annotations
19
+
20
+ import dataclasses
21
+ from typing import (Any, Callable, Generic, Iterable, Iterator, Mapping, MutableMapping,
22
+ Sequence, Type, TypeVar, Union, Hashable, Tuple, Dict, Optional, overload)
23
+
24
+ import jax
25
+ import numpy as np
26
+ from typing_extensions import TypeGuard, Unpack
27
+
28
+ from brainstate._state import State, TreefyState
29
+ from brainstate._utils import set_module_as
30
+ from brainstate.typing import PathParts, Filter, Predicate, Key
31
+ from brainstate.util._caller import ApplyCaller, CallableProxy, DelayedAccessor
32
+ from brainstate.util._dict import NestedDict, FlattedDict, PrettyDict
33
+ from brainstate.util._filter import to_predicate
34
+ from brainstate.util._pretty_repr import PrettyRepr, PrettyType, PrettyAttr, PrettyMapping, MappingReprMixin
35
+ from brainstate.util._struct import FrozenDict
36
+
37
# Maximum value representable by a signed 32-bit integer.
_max_int = np.iinfo(np.int32).max

__all__ = [
    # state management in the given graph or node
    'pop_states', 'nodes', 'states', 'treefy_states', 'update_states',

    # graph node operations
    'flatten', 'unflatten', 'treefy_split', 'treefy_merge', 'iter_leaf', 'iter_node', 'clone', 'graphdef', 'call',

    # others
    'RefMap', 'GraphDef', 'NodeRef', 'NodeDef'
]

# Generic type variables used throughout this module.
A = TypeVar('A')
B = TypeVar('B')
C = TypeVar('C')
F = TypeVar('F', bound=Callable)

# Hashable-bounded type variables (keys/values of HashableMapping).
HA = TypeVar('HA', bound=Hashable)
HB = TypeVar('HB', bound=Hashable)

# Aliases for readability of the graph machinery below.
Index = int
Names = Sequence[int]
Node = TypeVar('Node')
Leaf = TypeVar('Leaf')
AuxData = TypeVar('AuxData')

# A "state leaf" is a TreefyState (the pytree-friendly view of a State);
# a "node leaf" is a live State object.
StateLeaf = TreefyState[Any]
NodeLeaf = State[Any]
GraphStateMapping = NestedDict[Key, StateLeaf]
67
+
68
+
69
+ # --------------------------------------------------------
70
+
71
+
72
def _is_state_leaf(x: Any) -> TypeGuard[StateLeaf]:
    # A state leaf is a ``TreefyState`` instance (see the ``StateLeaf`` alias).
    return isinstance(x, TreefyState)
74
+
75
+
76
def _is_node_leaf(x: Any) -> TypeGuard[NodeLeaf]:
    # A node leaf is a live ``State`` instance (see the ``NodeLeaf`` alias).
    return isinstance(x, State)
78
+
79
+
80
class RefMap(MutableMapping[A, B], MappingReprMixin[A, B]):
    """
    A mutable mapping keyed by object identity (``id``) rather than by
    equality/hash.

    This is useful for tracking graph nodes that may be unhashable, or that
    compare equal while being distinct objects, while other objects still
    hold references to them.

    Args:
        mapping: A mapping or iterable of key-value pairs used to seed the map.
    """
    __module__ = 'brainstate.graph'

    def __init__(self, mapping: Mapping[A, B] | Iterable[Tuple[A, B]] = ()):
        # id(key) -> (key, value); the key object itself is retained so that
        # iteration can yield the original references.
        self._mapping: Dict[int, Tuple[A, B]] = {}
        self.update(mapping)

    def __getitem__(self, key: A) -> B:
        _, value = self._mapping[id(key)]
        return value

    def __contains__(self, key: Any) -> bool:
        return id(key) in self._mapping

    def __setitem__(self, key: A, value: B):
        self._mapping[id(key)] = (key, value)

    def __delitem__(self, key: A):
        self._mapping.pop(id(key))

    def __iter__(self) -> Iterator[A]:
        for key, _ in self._mapping.values():
            yield key

    def __len__(self) -> int:
        return len(self._mapping)

    def __str__(self) -> str:
        return repr(self)
117
+
118
+
119
@dataclasses.dataclass(frozen=True)
class NodeImplBase(Generic[Node, Leaf, AuxData]):
    # Shared base for node implementations: bundles a node's Python type with
    # a ``flatten`` function returning (key/leaf pairs, auxiliary metadata).
    type: type
    flatten: Callable[[Node], tuple[Sequence[tuple[Key, Leaf]], AuxData]]

    def node_dict(self, node: Node) -> dict[Key, Leaf]:
        """Return the node's key/leaf pairs as a dict, discarding the aux data."""
        nodes, _ = self.flatten(node)
        return dict(nodes)
127
+
128
+
129
@dataclasses.dataclass(frozen=True)
class GraphNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
    # Implementation for mutable graph nodes: in addition to ``flatten`` it
    # supports in-place mutation (set/pop/clear) and creation of an empty
    # shell node from the aux metadata.
    set_key: Callable[[Node, Key, Leaf], None]
    pop_key: Callable[[Node, Key], Leaf]
    create_empty: Callable[[AuxData], Node]
    clear: Callable[[Node], None]

    def init(self, node: Node, items: Tuple[Tuple[Key, Leaf], ...]):
        """Populate ``node`` by setting each (key, value) pair in order."""
        for key, value in items:
            self.set_key(node, key, value)
139
+
140
+
141
@dataclasses.dataclass(frozen=True)
class PyTreeNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
    # Implementation for immutable pytree nodes: rebuilt (not mutated) via
    # ``unflatten`` from key/leaf pairs plus the aux metadata.
    unflatten: Callable[[tuple[tuple[Key, Leaf], ...], AuxData], Node]
144
+
145
+
146
# Either kind of node implementation.
NodeImpl = Union[GraphNodeImpl[Node, Leaf, AuxData], PyTreeNodeImpl[Node, Leaf, AuxData]]

# --------------------------------------------------------
# Graph Node implementation: start
# --------------------------------------------------------

# Registry mapping a node's Python type to its graph-node implementation.
_node_impl_for_type: dict[type, NodeImpl[Any, Any, Any]] = {}
153
+
154
+
155
def register_graph_node_type(
    type: type,
    flatten: Callable[[Node], tuple[Sequence[tuple[Key, Leaf]], AuxData]],
    set_key: Callable[[Node, Key, Leaf], None],
    pop_key: Callable[[Node, Key], Leaf],
    create_empty: Callable[[AuxData], Node],
    clear: Callable[[Node], None],
):
    """
    Register a graph node type.

    Args:
        type: The type of the node.
        flatten: A function that flattens the node into a sequence of key-value pairs.
        set_key: A function that sets a key in the node.
        pop_key: A function that pops a key from the node.
        create_empty: A function that creates an empty node.
        clear: A function that clears the node
    """
    # Build the implementation record, then publish it in the global registry.
    impl = GraphNodeImpl(
        type=type,
        flatten=flatten,
        set_key=set_key,
        pop_key=pop_key,
        create_empty=create_empty,
        clear=clear,
    )
    _node_impl_for_type[type] = impl
182
+
183
+
184
+ # --------------------------------------------------------
185
+ # Graph node implementation: end
186
+ # --------------------------------------------------------
187
+
188
+
189
def _is_node(x: Any) -> bool:
    """True when ``x`` is a registered graph node or a non-leaf pytree."""
    if _is_graph_node(x):
        return True
    return _is_pytree_node(x)
191
+
192
+
193
+ def _is_pytree_node(x: Any) -> bool:
194
+ return not jax.tree_util.all_leaves((x,))
195
+
196
+
197
def _is_graph_node(x: Any) -> bool:
    """True when ``x``'s exact type has a registered graph-node implementation."""
    node_type = type(x)
    return node_type in _node_impl_for_type
199
+
200
+
201
def _is_node_type(x: type[Any]) -> bool:
    """True for registered graph-node types and for the ``PytreeType`` sentinel."""
    if x is PytreeType:
        return True
    return x in _node_impl_for_type
203
+
204
+
205
def _get_node_impl(x: Node) -> NodeImpl[Node, Any, Any]:
    """
    Look up the node implementation for ``x``.

    Raises:
        ValueError: if ``x`` is a ``State`` (states are leaves, not nodes),
            or if ``x`` is neither a registered graph node nor a pytree.
    """
    if isinstance(x, State):
        raise ValueError(f'State is not a node: {x}')

    # Prefer an explicitly registered implementation; otherwise fall back to
    # the generic pytree implementation when the value is a non-leaf pytree.
    impl = _node_impl_for_type.get(type(x))
    if impl is not None:
        return impl
    if _is_pytree_node(x):
        return PYTREE_NODE_IMPL
    raise ValueError(f'Unknown node type: {x}')
217
+
218
+
219
def get_node_impl_for_type(x: type[Node]) -> NodeImpl[Node, Any, Any]:
    """Return the node implementation registered for the type ``x``.

    ``PytreeType`` maps to the generic pytree implementation; any other type
    must be present in the registry (KeyError otherwise).
    """
    if x is PytreeType:
        return PYTREE_NODE_IMPL
    return _node_impl_for_type[x]
223
+
224
+
225
class HashableMapping(Mapping[HA, HB], Hashable):
    """An immutable, hashable view over a mapping.

    Hashing sorts the items, so keys must be mutually orderable as well as
    hashable.
    """

    def __init__(self, mapping: Mapping[HA, HB] | Iterable[tuple[HA, HB]]):
        self._mapping = dict(mapping)

    def __contains__(self, key: object) -> bool:
        return key in self._mapping

    def __getitem__(self, key: HA) -> HB:
        return self._mapping[key]

    def __iter__(self) -> Iterator[HA]:
        yield from self._mapping

    def __len__(self) -> int:
        return len(self._mapping)

    def __hash__(self) -> int:
        items = sorted(self._mapping.items())
        return hash(tuple(items))

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, HashableMapping):
            return False
        return self._mapping == other._mapping

    def __repr__(self) -> str:
        return repr(self._mapping)
249
+
250
+
251
class GraphDef(Generic[Node]):
    """
    A base dataclass that denotes the graph structure of a :class:`Node`.

    It contains two main components:
    - type: The type of the node.
    - index: The index of the node in the graph.

    It has two concrete subclasses:
    - :class:`NodeRef`: A reference to a node in the graph.
    - :class:`NodeDef`: A dataclass that denotes the graph structure of a :class:`Node` or a :class:`State`.

    """
    # Declared here for typing only; the concrete frozen dataclass subclasses
    # (NodeRef, NodeDef) provide the actual fields.
    type: type[Node]
    index: int
266
+
267
+
268
@dataclasses.dataclass(frozen=True, repr=False)
class NodeRef(GraphDef[Node], PrettyRepr):
    """
    A reference to a node in the graph.

    The node can be instances of :class:`Node` or :class:`State`.
    """
    type: type[Node]
    index: int

    def __pretty_repr__(self):
        # Pretty-printing protocol (see PrettyRepr): yield the type header
        # followed by each attribute.
        yield PrettyType(type=type(self))
        yield PrettyAttr('type', self.type.__name__)
        yield PrettyAttr('index', self.index)

    def __treescope_repr__(self, path, subtree_renderer):
        """
        Treescope repr for the object.
        """
        import treescope  # type: ignore[import-not-found,import-untyped]
        return treescope.repr_lib.render_object_constructor(
            object_type=type(self),
            attributes={'type': self.type, 'index': self.index},
            path=path,
            subtree_renderer=subtree_renderer,
        )


# Register NodeRef as a static (leaf-free) pytree node so it can live in
# jax-transformed structures.
jax.tree_util.register_static(NodeRef)
297
+
298
+
299
@dataclasses.dataclass(frozen=True, repr=False)
class NodeDef(GraphDef[Node], PrettyRepr):
    """
    A dataclass that denotes the tree structure of a node, either :class:`Node` or :class:`State`.

    """

    type: Type[Node]  # type of the node
    index: int  # index of the node in the graph
    attributes: Tuple[Key, ...]  # attributes for the node
    subgraphs: HashableMapping[Key, NodeDef[Any] | NodeRef[Any]]  # child nodes, keyed by attribute
    static_fields: HashableMapping[Key, Any]  # attribute values that are neither nodes nor states
    leaves: HashableMapping[Key, NodeRef[Any] | None]  # state leaves; None marks a TreefyState leaf
    metadata: Hashable  # aux data from the node implementation's flatten
    index_mapping: FrozenDict[Index, Index] | None  # optional index remapping (NOTE(review): direction not visible here)

    @classmethod
    def create(
        cls,
        type: Type[Node],
        index: int,
        attributes: tuple[Key, ...],
        subgraphs: Iterable[tuple[Key, NodeDef[Any] | NodeRef[Any]]],
        static_fields: Iterable[tuple[Key, Any]],
        leaves: Iterable[tuple[Key, NodeRef[Any] | None]],
        metadata: Hashable,
        index_mapping: Mapping[Index, Index] | None,
    ):
        """Build a NodeDef, freezing the mapping arguments into hashable containers."""
        return cls(
            type=type,
            index=index,
            attributes=attributes,
            subgraphs=HashableMapping(subgraphs),
            static_fields=HashableMapping(static_fields),
            leaves=HashableMapping(leaves),
            metadata=metadata,
            index_mapping=FrozenDict(index_mapping) if index_mapping is not None else None,
        )

    def __pretty_repr__(self):
        # Pretty-printing protocol: type header, then each field.
        yield PrettyType(type=type(self))

        yield PrettyAttr('type', self.type.__name__)
        yield PrettyAttr('index', self.index)
        yield PrettyAttr('attributes', self.attributes)
        yield PrettyAttr('subgraphs', PrettyMapping(self.subgraphs))
        yield PrettyAttr('static_fields', PrettyMapping(self.static_fields))
        yield PrettyAttr('leaves', PrettyMapping(self.leaves))
        yield PrettyAttr('metadata', self.metadata)
        yield PrettyAttr('index_mapping', PrettyMapping(self.index_mapping) if self.index_mapping is not None else None)

    def __treescope_repr__(self, path, subtree_renderer):
        """Treescope repr for the object."""
        import treescope  # type: ignore[import-not-found,import-untyped]
        return treescope.repr_lib.render_object_constructor(
            object_type=type(self),
            attributes={'type': self.type,
                        'index': self.index,
                        'attributes': self.attributes,
                        'subgraphs': dict(self.subgraphs),
                        'static_fields': dict(self.static_fields),
                        'leaves': dict(self.leaves),
                        'metadata': self.metadata, },
            path=path,
            subtree_renderer=subtree_renderer,
        )

    def apply(
        self,
        state_map: GraphStateMapping,
        *state_maps: GraphStateMapping
    ) -> ApplyCaller[tuple[GraphDef[Node], GraphStateMapping]]:
        """Return a proxy that merges this graphdef with the given state maps,
        applies the accessed method, and returns ``(output, flatten(module))``.
        """
        accessor = DelayedAccessor()

        def _apply(accessor: DelayedAccessor, *args, **kwargs) -> tuple[
            Any, tuple[GraphDef[Node], GraphStateMapping]]:
            # Rebuild the module, call the (delayed-access) method, then
            # re-flatten to capture any state changes.
            module = treefy_merge(self, state_map, *state_maps)
            fn = accessor(module)
            out = fn(*args, **kwargs)
            return out, flatten(module)

        return CallableProxy(_apply, accessor)  # type: ignore


# Register NodeDef as a static (leaf-free) pytree node so it can live in
# jax-transformed structures.
jax.tree_util.register_static(NodeDef)
383
+
384
+
385
+ # --------------------------------------------------------
386
+ # Graph operations: start
387
+ # --------------------------------------------------------
388
+
389
+
390
def _graph_flatten(
    path: PathParts,
    ref_index: RefMap[Any, Index],
    flatted_state_mapping: Dict[PathParts, StateLeaf],
    node: Node,
    treefy_state: bool = False,
):
    """
    Recursive helper for graph flatten.

    Args:
        path: The path to the node.
        ref_index: A mapping from nodes to indexes. Mutated in place as new
            nodes/states are encountered, so shared references across calls
            resolve to NodeRefs.
        flatted_state_mapping: A mapping from paths to state leaves. Mutated
            in place.
        node: The node to flatten.
        treefy_state: If True, store ``State.to_state_ref()`` views instead of
            the live ``State`` objects.

    Returns:
        A NodeDef or a NodeRef.
    """
    if not _is_node(node):
        raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')

    # If the node is already in the cache, return a reference, otherwise
    # add it to the cache and continue with the flattening process.
    # This is done to avoid infinite recursion when there is a reference cycle.
    if node in ref_index:
        return NodeRef(type(node), ref_index[node])

    # Get the node implementation for the node type.
    # There are two types of node implementations: GraphNodeImpl and PyTreeNodeImpl.
    # - ``GraphNodeImpl`` is used for nodes that have a graph structure.
    # - ``PyTreeNodeImpl`` is used for nodes that have a tree structure.
    node_impl = _get_node_impl(node)

    # Only graph nodes get cached; pytree nodes are value-like and are marked
    # with index -1.
    if isinstance(node_impl, GraphNodeImpl):
        # add the node to the cache
        index = len(ref_index)
        ref_index[node] = index
    else:
        index = -1

    subgraphs: list[tuple[Key, NodeDef[Node] | NodeRef]] = []
    static_fields: list[tuple[Key, Any]] = []
    leaves: list[tuple[Key, NodeRef | None]] = []

    # Flatten the node into a sequence of key-value pairs.
    values, metadata = node_impl.flatten(node)
    for key, value in values:
        if _is_node(value):
            # Recursively flatten the subgraph.
            nodedef = _graph_flatten((*path, key), ref_index, flatted_state_mapping, value, treefy_state)
            subgraphs.append((key, nodedef))
        elif isinstance(value, State):
            # If the variable is in the cache, add a reference to it.
            if value in ref_index:
                leaves.append((key, NodeRef(type(value), ref_index[value])))
            else:
                # If the variable is not in the cache, add it to the cache.
                # This is done to avoid multiple references to the same variable.
                flatted_state_mapping[(*path, key)] = (value.to_state_ref() if treefy_state else value)
                variable_index = ref_index[value] = len(ref_index)
                leaves.append((key, NodeRef(type(value), variable_index)))
        elif _is_state_leaf(value):
            # The instance of ``TreefyState`` is a leaf.
            flatted_state_mapping[(*path, key)] = value
            leaves.append((key, None))
        else:
            # if isinstance(value, (jax.Array, np.ndarray)):
            #   path_str = '/'.join(map(str, (*path, key)))
            #   raise ValueError(f'Arrays leaves are not supported, at {path_str!r}: {value}')

            # The value is a static field.
            static_fields.append((key, value))

    nodedef = NodeDef.create(type=node_impl.type,
                             index=index,
                             attributes=tuple(key for key, _ in values),
                             subgraphs=subgraphs,
                             static_fields=static_fields,
                             leaves=leaves,
                             metadata=metadata,
                             index_mapping=None, )
    return nodedef
475
+
476
+
477
@set_module_as('brainstate.graph')
def flatten(
    node: Node,
    /,
    ref_index: Optional[RefMap[Any, Index]] = None,
    treefy_state: bool = True,
) -> Tuple[GraphDef, NestedDict]:
    """
    Flattens a graph node into a (graph_def, state_mapping) pair.

    Example::

        >>> import brainstate as bst
        >>> node = bst.graph.Node()
        >>> graph_def, state_mapping = flatten(node)
        >>> print(graph_def)
        >>> print(state_mapping)

    Args:
        node: A graph node.
        ref_index: A mapping from nodes to indexes, defaults to None. If not provided, a new
            empty dictionary is created. This argument can be used to flatten a sequence of graph
            nodes that share references.
        treefy_state: If True, the state mapping will be a NestedDict instead of a flat dictionary.
    """
    if ref_index is None:
        ref_index = RefMap()
    assert isinstance(ref_index, RefMap), f"ref_index must be a RefMap. But we got: {ref_index}"

    # Collect path -> leaf entries into a flat dict, then nest them.
    flatted: dict[PathParts, StateLeaf] = {}
    graph_def = _graph_flatten((), ref_index, flatted, node, treefy_state)
    return graph_def, NestedDict.from_flat(flatted)
507
+
508
+
509
def _get_children(graph_def, state_mapping, index_ref, index_ref_cache):
    """Build the attribute dict for a node being unflattened.

    For each attribute declared in ``graph_def.attributes``, resolve its value
    from the static fields, the subgraph definitions, the leaf definitions,
    and/or the provided ``state_mapping``, reusing already-created objects via
    ``index_ref`` (and, when given, updating existing ones via
    ``index_ref_cache``).
    """
    children: dict[Key, StateLeaf | Node] = {}

    # NOTE: we could allow adding new StateLeafs here
    # All state keys must be present in the graph definition (the object attributes)
    if unknown_keys := set(state_mapping) - set(graph_def.attributes):
        raise ValueError(f'Unknown keys: {unknown_keys}')

    # for every key in attributes there are 6 possible cases:
    #  - (2) the key can either be present in the state or not
    #  - (3) the key can be a subgraph, a leaf, or a static attribute
    for key in graph_def.attributes:
        if key not in state_mapping:  # static field
            # TODO(cgarcia): maybe we shouldn't support unflattening with missing keys?
            # if key is not present, create an empty types
            if key in graph_def.static_fields:
                children[key] = graph_def.static_fields[key]

            elif key in graph_def.subgraphs:
                # if the key is a subgraph we create an empty node
                subgraphdef = graph_def.subgraphs[key]
                if isinstance(subgraphdef, NodeRef):
                    # subgraph exists, take it from the cache
                    children[key] = index_ref[subgraphdef.index]

                else:
                    # create a node from an empty state, reasoning:
                    # * it is a node with no state
                    # * it is a node with state but only through references of already
                    #   created nodes
                    substate = {}
                    children[key] = _graph_unflatten(subgraphdef, substate, index_ref, index_ref_cache)

            elif key in graph_def.leaves:
                noderef = graph_def.leaves[key]
                if (noderef is not None) and (noderef.index in index_ref):
                    # variable exists, take it from the cache
                    children[key] = index_ref[noderef.index]

                else:
                    # key for a variable is missing, raise an error
                    raise ValueError(f'Expected key {key!r} in state while building node of type '
                                     f'{graph_def.type.__name__}.')

            else:
                raise RuntimeError(f'Unknown static field: {key!r}')

        else:  # state field
            value = state_mapping[key]
            # Unwrap PrettyDict containers so nested dicts compare/type-check
            # as plain dicts below.
            if isinstance(value, PrettyDict):
                value = dict(value)

            if key in graph_def.static_fields:
                raise ValueError(f'Got state for static field {key!r}, this is not supported.')

            if key in graph_def.subgraphs:
                # if _is_state_leaf(value):
                if isinstance(value, (TreefyState, State)):
                    raise ValueError(f'Expected value of type {graph_def.subgraphs[key]} '
                                     f'for {key!r}, but got {value!r}')
                if not isinstance(value, dict):
                    raise TypeError(f'Expected a dict for {key!r}, but got {type(value)}.')

                subgraphdef = graph_def.subgraphs[key]
                if isinstance(subgraphdef, NodeRef):
                    children[key] = index_ref[subgraphdef.index]
                else:
                    children[key] = _graph_unflatten(subgraphdef, value, index_ref, index_ref_cache)

            elif key in graph_def.leaves:
                # if not _is_state_leaf(value):
                if not isinstance(value, (TreefyState, State)):
                    raise ValueError(f'Expected a leaf for {key!r}, but got {value!r}')

                noderef = graph_def.leaves[key]
                if noderef is None:
                    # if the leaf is None, it means that the value was originally
                    # a non-TreefyState leaf, however we allow providing a
                    # TreefyState presumbly created by modifying the NestedDict
                    if isinstance(value, TreefyState):
                        value = value.to_state()
                    # elif isinstance(value, State):
                    #   value = value
                    children[key] = value

                elif noderef.index in index_ref:
                    # add an existing variable
                    children[key] = index_ref[noderef.index]

                else:
                    # it is an unseen variable, create a new one
                    if not isinstance(value, (TreefyState, State)):
                        raise ValueError(f'Expected a State type for {key!r}, but got {type(value)}.')
                    # when idxmap is present, check if the Varable exists there
                    # and update existing variables if it does
                    if index_ref_cache is not None and noderef.index in index_ref_cache:
                        variable = index_ref_cache[noderef.index]
                        if not isinstance(variable, State):
                            raise ValueError(f'Expected a State type for {key!r}, but got {type(variable)}.')
                        if isinstance(value, TreefyState):
                            variable.update_from_ref(value)
                        elif isinstance(value, State):
                            variable.restore_value(value.value)
                        else:
                            raise ValueError(f'Expected a State type for {key!r}, but got {type(value)}.')
                    else:  # if it doesn't, create a new variable
                        if isinstance(value, TreefyState):
                            variable = value.to_state()
                        elif isinstance(value, State):
                            variable = value
                        else:
                            raise ValueError(f'Expected a State type for {key!r}, but got {type(value)}.')
                    children[key] = variable
                    index_ref[noderef.index] = variable

            else:
                raise RuntimeError(f'Unknown key: {key!r}, this is a bug.')

    return children
628
+
629
+
630
def _graph_unflatten(
    graph_def: NodeDef[Node] | NodeRef[Node],
    state_mapping: Mapping[Key, StateLeaf | Mapping[Key, Any]],
    index_ref: dict[Index, Any],
    index_ref_cache: dict[Index, Any] | None,
) -> Node:
    """
    Recursive helper for graph unflatten.

    Args:
        graph_def: A ``NodeDef`` instance, or a ``NodeRef`` pointing at a node
            that was already materialized in ``index_ref``.
        state_mapping: A state mapping from attribute names to variables or subgraphs.
        index_ref: A mapping from indexes to nodes that have been traversed.
            If a node is already in the cache, it won't be traversed again.
        index_ref_cache: A mapping from indexes to existing nodes that can be reused.
            When a reference is reused, ``GraphNodeImpl.clear`` is called to leave the
            object in an empty state which is then filled by the unflatten process;
            as a result, existing graph nodes are mutated to have the new
            content/topology specified by the nodedef.

    Returns:
        A node instance.

    Raises:
        RuntimeError: If the node type is unregistered or the graphdef index
            was already materialized.
        ValueError: If a cached node's type does not match the graphdef's type.
    """
    # if the graph_def is a reference, the node has already been created,
    # so we return it from the cache
    if isinstance(graph_def, NodeRef):
        return index_ref[graph_def.index]
    assert isinstance(graph_def, NodeDef), f"graph_def must be a NodeDef. But we got: {graph_def}"

    # graph_def must be a registered node type
    if not _is_node_type(graph_def.type):
        raise RuntimeError(f'Unsupported type: {graph_def.type}, this is a bug.')

    # each graphdef index may only be materialized once
    if graph_def.index in index_ref:
        raise RuntimeError(f'GraphDef index {graph_def.index} already used.')

    # get the node implementation for the node type
    node_impl = get_node_impl_for_type(graph_def.type)

    if isinstance(node_impl, GraphNodeImpl):
        # we create an empty node first and add it to the index;
        # this avoids infinite recursion when there is a reference cycle
        if (index_ref_cache is not None) and (graph_def.index in index_ref_cache):
            node = index_ref_cache[graph_def.index]
            # exact type check on purpose (`is not`, not isinstance): a subclass
            # has a different layout and must not be silently reused
            if type(node) is not graph_def.type:
                raise ValueError(f'Expected a node of type {graph_def.type} for index '
                                 f'{graph_def.index}, but got a node of type {type(node)}.')
            # clear the node to leave it in an empty state
            node_impl.clear(node)
        else:
            # create an empty node
            node = node_impl.create_empty(graph_def.metadata)

        # register the node before recursing into its children
        index_ref[graph_def.index] = node

        # get the children (the attributes) of the node
        children = _get_children(graph_def, state_mapping, index_ref, index_ref_cache)

        # initialize the node with the children
        node_impl.init(node, tuple(children.items()))

    else:
        # if the node type does not support the creation of an empty object it means
        # that it cannot reference itself, so we can safely create its children first
        children = _get_children(graph_def, state_mapping, index_ref, index_ref_cache)
        # then, we create the node from the fully-built children
        node = node_impl.unflatten(tuple(children.items()), graph_def.metadata)

    return node
706
+
707
+
708
@set_module_as('brainstate.graph')
def unflatten(
    graph_def: GraphDef,
    state_mapping: NestedDict[Key, StateLeaf],
    /,
    *,
    index_ref: dict[Index, Any] | None = None,
    index_ref_cache: dict[Index, Any] | None = None,
) -> Node:
    """
    Unflattens a graphdef into a node with the given state tree mapping.

    This is the inverse of :func:`flatten`: given the static structure
    (``graph_def``) and the dynamic leaves (``state_mapping``), it rebuilds
    the original graph node.

    Example::

        >>> import brainstate as bst
        >>> class MyNode(bst.graph.Node):
        ...     def __init__(self):
        ...         self.a = bst.nn.Linear(2, 3)
        ...         self.b = bst.nn.Linear(3, 4)
        ...         self.c = [bst.nn.Linear(4, 5), bst.nn.Linear(5, 6)]
        ...         self.d = {'x': bst.nn.Linear(6, 7), 'y': bst.nn.Linear(7, 8)}
        ...
        >>> graphdef, statetree = bst.graph.flatten(MyNode())
        >>> node = bst.graph.unflatten(graphdef, statetree)
        >>> isinstance(node, MyNode)
        True

    Args:
        graph_def: A GraphDef instance.
        state_mapping: A NestedDict instance.
        index_ref: A mapping from indexes to node references found during the graph
            traversal, defaults to None. If not provided, a new empty dictionary is
            created. This argument can be used to unflatten a sequence of
            (graphdef, state_mapping) pairs that share the same index space.
        index_ref_cache: A mapping from indexes to existing nodes that can be reused.
            When a reference is reused, ``GraphNodeImpl.clear`` is called to leave the
            object in an empty state which is then filled by the unflatten process;
            as a result, existing graph nodes are mutated to have the new
            content/topology specified by the graphdef.

    Returns:
        The reconstructed node instance.
    """
    assert isinstance(graph_def, (NodeDef, NodeRef)), f"graph_def must be a NodeDef or NodeRef. But we got: {graph_def}"
    ref_table = {} if index_ref is None else index_ref
    return _graph_unflatten(graph_def, state_mapping.to_dict(), ref_table, index_ref_cache)
884
+
885
+
886
def _graph_pop(
    node: Node,
    id_to_index: dict[int, Index],
    path_parts: PathParts,
    flatted_state_dicts: tuple[FlattedDict[PathParts, StateLeaf], ...],
    predicates: tuple[Predicate, ...],
) -> None:
    """
    Recursively remove the leaves of ``node`` that match one of ``predicates``.

    Matched leaves are popped off the node (mutating it in place) and recorded,
    keyed by their path, into the state dict paired with the first matching
    predicate. ``id_to_index`` tracks already-visited nodes/leaves so shared
    references are processed only once.
    """
    if not _is_node(node):
        raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')

    # a node already seen through another reference: nothing to do
    if id(node) in id_to_index:
        return

    id_to_index[id(node)] = len(id_to_index)
    node_impl = _get_node_impl(node)
    node_dict = node_impl.node_dict(node)

    for name, value in node_dict.items():
        if _is_node(value):
            # recurse into subgraphs, extending the path
            _graph_pop(
                node=value,
                id_to_index=id_to_index,
                path_parts=(*path_parts, name),
                flatted_state_dicts=flatted_state_dicts,
                predicates=predicates,
            )
            continue
        elif not _is_node_leaf(value):
            continue
        elif id(value) in id_to_index:
            # shared leaf already popped elsewhere
            continue

        node_path = (*path_parts, name)
        # NOTE: `node_impl` was already computed above; the original recomputed
        # it per attribute, which was redundant.
        for state_dicts, predicate in zip(flatted_state_dicts, predicates):
            if predicate(node_path, value):
                if isinstance(node_impl, PyTreeNodeImpl):
                    raise ValueError(f'Cannot pop key {name!r} from node of type {type(node).__name__}')
                id_to_index[id(value)] = len(id_to_index)
                node_impl.pop_key(node, name)
                # if isinstance(value, State):
                #   value = value.to_state_ref()
                state_dicts[node_path] = value  # type: ignore[index] # mypy is wrong here?
                break
        else:
            # NOTE: should we raise an error here?
            pass
933
+
934
+
935
@overload
def pop_states(node, filter1: Filter, /) -> NestedDict:
    ...


@overload
def pop_states(node, filter1: Filter, filter2: Filter, /, *filters: Filter) -> tuple[NestedDict, ...]:
    ...


@set_module_as('brainstate.graph')
def pop_states(
    node: Node,
    *filters: Any
) -> Union[NestedDict[Key, State], Tuple[NestedDict[Key, State], ...]]:
    """
    Pop one or more :class:`State` types from the graph node.

    The matched states are removed from ``node`` in place and returned,
    grouped per filter.

    Example usage::

        >>> import brainstate as bst
        >>> import jax.numpy as jnp

        >>> class Model(bst.nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.a = bst.nn.Linear(2, 3)
        ...         self.b = bst.nn.LIF([10, 2])

        >>> model = Model()
        >>> with bst.catch_new_states('new'):
        ...     bst.nn.init_all_states(model)

        >>> assert len(model.states()) == 2
        >>> model_states = bst.graph.pop_states(model, 'new')

    Args:
        node: A graph node object.
        *filters: One or more :class:`State` objects to filter by.

    Returns:
        The popped :class:`NestedDict` containing the :class:`State`
        objects that were filtered for. A tuple of nested dicts when
        several filters were given.
    """
    if not filters:
        raise ValueError('Expected at least one filter')

    predicates = tuple(to_predicate(f) for f in filters)
    collected: tuple[FlattedDict[PathParts, StateLeaf], ...] = tuple({} for _ in predicates)
    _graph_pop(
        node=node,
        id_to_index={},
        path_parts=(),
        flatted_state_dicts=collected,
        predicates=predicates,
    )
    nested = tuple(NestedDict.from_flat(flat) for flat in collected)

    # a single filter yields a single NestedDict, not a 1-tuple
    return nested[0] if len(nested) == 1 else nested
1009
+
1010
+
1011
def _split_state(
    state: GraphStateMapping,
    filters: tuple[Filter, ...],
) -> tuple[GraphStateMapping, Unpack[tuple[GraphStateMapping, ...]]]:
    """Split ``state`` by ``filters``; always return a non-empty tuple."""
    # with no filters there is nothing to split
    if not filters:
        return (state,)
    split_result = state.split(*filters)
    # a lone NestedDict means only one group was produced; normalize to a tuple
    if isinstance(split_result, NestedDict):
        return (split_result,)
    assert len(split_result) > 0
    return split_result  # type: ignore[return-value]
1022
+
1023
+
1024
@overload
def treefy_split(node: A, /) -> Tuple[GraphDef, NestedDict]:
    ...


@overload
def treefy_split(node: A, first: Filter, /) -> Tuple[GraphDef, NestedDict]:
    ...


@overload
def treefy_split(node: A, first: Filter, second: Filter, /) -> Tuple[GraphDef, NestedDict, NestedDict]:
    ...


@overload
def treefy_split(
    node: A, first: Filter, second: Filter, /, *filters: Filter,
) -> Tuple[GraphDef, NestedDict, Unpack[Tuple[NestedDict, ...]]]:
    ...


@set_module_as('brainstate.graph')
def treefy_split(
    node: A,
    *filters: Filter
) -> Tuple[GraphDef[A], NestedDict, Unpack[Tuple[NestedDict, ...]]]:
    """Split a graph node into a :class:`GraphDef` and one or more :class:`NestedDict`s. NestedDict is
    a ``Mapping`` from strings or integers to ``Variables``, Arrays or nested States. GraphDef
    contains all the static information needed to reconstruct a ``Module`` graph, it is analogous
    to JAX's ``PyTreeDef``. :func:`split` is used in conjunction with :func:`merge` to switch
    seamlessly between stateful and stateless representations of the graph.

    Example usage::

        >>> import brainstate as bst
        >>> import jax, jax.numpy as jnp
        ...
        >>> class Foo(bst.graph.Node):
        ...     def __init__(self):
        ...         self.a = bst.nn.BatchNorm1d([10, 2])
        ...         self.b = bst.nn.Linear(2, 3)
        ...
        >>> node = Foo()
        >>> graphdef, params, others = bst.graph.treefy_split(node, bst.ParamState, ...)

    :func:`split` and :func:`merge` are primarily used to interact directly with JAX
    transformations, see
    `Functional API <https://flax.readthedocs.io/en/latest/nnx/nnx_basics.html#the-functional-api>`__
    for more information.

    Arguments:
        node: graph node to split.
        *filters: some optional filters to group the state into mutually exclusive substates.

    Returns:
        ``GraphDef`` and one or more ``States`` equal to the number of filters passed. If no
        filters are passed, a single ``NestedDict`` is returned.
    """
    graphdef, state_tree = flatten(node)
    # _split_state already returns a tuple; no extra tuple() wrapping needed
    states = _split_state(state_tree, filters)
    return graphdef, *states
1116
+
1117
+
1118
@set_module_as('brainstate.graph')
def treefy_merge(
    graphdef: GraphDef[A],
    state_mapping: GraphStateMapping,
    /,
    *state_mappings: GraphStateMapping,
) -> A:
    """The inverse of :func:`split`.

    ``merge`` takes a :class:`GraphDef` and one or more :class:`NestedDict`'s and creates
    a new node with the same structure as the original node.

    Example usage::

        >>> import brainstate as bst
        >>> import jax, jax.numpy as jnp
        ...
        >>> class Foo(bst.graph.Node):
        ...     def __init__(self):
        ...         self.a = bst.nn.BatchNorm1d([10, 2])
        ...         self.b = bst.nn.Linear(2, 3)
        ...
        >>> node = Foo()
        >>> graphdef, params, others = bst.graph.treefy_split(node, bst.ParamState, ...)
        ...
        >>> new_node = bst.graph.treefy_merge(graphdef, params, others)
        >>> assert isinstance(new_node, Foo)
        >>> assert isinstance(new_node.a, bst.nn.BatchNorm1d)
        >>> assert isinstance(new_node.b, bst.nn.Linear)

    :func:`split` and :func:`merge` are primarily used to interact directly with JAX
    transformations, see
    `Functional API <https://flax.readthedocs.io/en/latest/nnx/nnx_basics.html#the-functional-api>`__
    for more information.

    Args:
        graphdef: A :class:`GraphDef` object.
        state_mapping: A :class:`NestedDict` object.
        *state_mappings: Additional :class:`NestedDict` objects.

    Returns:
        The merged :class:`Module`.
    """
    # collapse all provided state mappings into one before unflattening
    state_mapping = GraphStateMapping.merge(state_mapping, *state_mappings)
    node = unflatten(graphdef, state_mapping)
    return node
1164
+
1165
+
1166
def _filters_to_predicates(filters: Tuple[Filter, ...]) -> Tuple[Predicate, ...]:
    """
    Convert a tuple of filters into predicates, validating that the
    catch-all filters (``...`` or ``True``) only appear at the tail.
    """
    for i, filter_ in enumerate(filters):
        if filter_ in (..., True) and i != len(filters) - 1:
            remaining_filters = filters[i + 1:]
            # a catch-all anywhere but last would shadow the filters after it
            if not all(f in (..., True) for f in remaining_filters):
                # fixed grammar of the original message ("got ... it at index")
                raise ValueError('`...` or `True` can only be used as the last filters, '
                                 f'but got {filter_} at index {i}.')
    return tuple(map(to_predicate, filters))
1174
+
1175
+
1176
def _split_flatted(
    flatted: Iterable[tuple[PathParts, Any]],
    filters: tuple[Filter, ...],
) -> tuple[list[tuple[PathParts, Any]], ...]:
    """
    Partition ``flatted`` path/value pairs into one list per filter.

    Each pair goes into the list of the first predicate that matches it.
    Filters must be exhaustive: a value matched by no predicate raises
    ``ValueError`` (callers typically append ``...`` as a catch-all).
    """
    predicates = _filters_to_predicates(filters)

    # one output list per predicate; the original comment claimed an extra
    # (n + 1)-th remainder list, but non-matching values actually raise below
    flat_states: tuple[list[tuple[PathParts, Any]], ...] = tuple([] for _ in predicates)

    for path, value in flatted:
        for i, predicate in enumerate(predicates):
            if predicate(path, value):
                flat_states[i].append((path, value))
                break
        else:
            raise ValueError('Non-exhaustive filters, got a non-empty remainder: '
                             f'{path} -> {value}.'
                             '\nUse `...` to match all remaining elements.')

    return flat_states
1197
+
1198
+
1199
@overload
def nodes(node, /, allowed_hierarchy=(0, _max_int)) -> FlattedDict[Key, Node]:
    ...


@overload
def nodes(node, first: Filter, /, allowed_hierarchy=(0, _max_int)) -> FlattedDict[Key, Node]:
    ...


@overload
def nodes(
    node, first: Filter, second: Filter, /, *filters: Filter, allowed_hierarchy=(0, _max_int)
) -> Tuple[FlattedDict[Key, Node], ...]:
    ...


@set_module_as('brainstate.graph')
def nodes(
    node,
    *filters: Filter,
    allowed_hierarchy: Tuple[int, int] = (0, _max_int)
) -> Union[FlattedDict[Key, Node], Tuple[FlattedDict[Key, Node], ...]]:
    """
    Similar to :func:`split` but only returns the :class:`NestedDict`'s indicated by the filters.

    Args:
        node: A graph node object.
        *filters: Optional filters grouping the nodes into mutually exclusive maps.
        allowed_hierarchy: ``(min, max)`` depth range of nodes to visit.

    Returns:
        One flattened node map per filter; a single map when at most one
        filter was given.
    """
    num_filters = len(filters)
    # append a single catch-all `...` so _split_flatted is exhaustive.
    # (the original appended `...` twice, creating useless empty trailing
    # groups that were then sliced away — same result, wasted work)
    filters = (*filters, ...)

    nodes_iterable = iter_node(node, allowed_hierarchy=allowed_hierarchy)
    flat_nodes = _split_flatted(nodes_iterable, filters)
    node_maps = tuple(FlattedDict(flat_node) for flat_node in flat_nodes)
    if num_filters < 2:
        return node_maps[0]
    return node_maps[:num_filters]
1237
+
1238
+
1239
def _states_generator(node, allowed_hierarchy) -> Iterable[Tuple[PathParts, State]]:
    """Yield ``(path, state)`` for every :class:`State` leaf reachable from ``node``."""
    for path, leaf in iter_leaf(node, allowed_hierarchy=allowed_hierarchy):
        if isinstance(leaf, State):
            yield path, leaf
1243
+
1244
+
1245
@overload
def states(node, /, allowed_hierarchy=(0, _max_int)) -> FlattedDict[Key, State]:
    ...


@overload
def states(node, first: Filter, /, allowed_hierarchy=(0, _max_int)) -> FlattedDict[Key, State]:
    ...


@overload
def states(
    node, first: Filter, second: Filter, /, *filters: Filter, allowed_hierarchy=(0, _max_int)
) -> tuple[FlattedDict[Key, State], ...]:
    ...


@set_module_as('brainstate.graph')
def states(
    node,
    *filters: Filter,
    allowed_hierarchy: Tuple[int, int] = (0, _max_int)
) -> Union[FlattedDict[Key, State], tuple[FlattedDict[Key, State], ...]]:
    """
    Similar to :func:`split` but only returns the :class:`NestedDict`'s indicated by the filters.

    Args:
        node: A graph node object.
        *filters: Optional filters grouping the states into mutually exclusive maps.
        allowed_hierarchy: ``(min, max)`` depth range of leaves to visit.

    Returns:
        One flattened state map per filter; a single map when at most one
        filter was given.
    """
    num_filters = len(filters)
    # append a single catch-all `...` so _split_flatted is exhaustive.
    # (the original appended `...` twice, creating useless empty trailing
    # groups that were then sliced away — same result, wasted work)
    filters = (*filters, ...)

    states_iterable = _states_generator(node, allowed_hierarchy=allowed_hierarchy)
    flat_states = _split_flatted(states_iterable, filters)
    state_maps = tuple(FlattedDict(flat_state) for flat_state in flat_states)
    if num_filters < 2:
        return state_maps[0]
    return state_maps[:num_filters]
1283
+
1284
+
1285
@overload
def treefy_states(
    node, /
) -> NestedDict[Key, TreefyState]:
    ...


@overload
def treefy_states(
    node, first: Filter, /
) -> NestedDict[Key, TreefyState]:
    ...


@overload
def treefy_states(
    node, first: Filter, second: Filter, /, *filters: Filter
) -> Tuple[NestedDict[Key, TreefyState], ...]:
    ...


@set_module_as('brainstate.graph')
def treefy_states(
    node, *filters,
) -> NestedDict[Key, TreefyState] | tuple[NestedDict[Key, TreefyState], ...]:
    """
    Similar to :func:`split` but only returns the :class:`NestedDict`'s indicated by the filters.

    NOTE: the original overload stubs advertised a ``flatted: bool = False``
    keyword that the implementation never accepted (passing it raised a
    ``TypeError`` at runtime); the stubs now match the real signature.

    Example usage::

        >>> import brainstate as bst
        >>> class Model(bst.nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.l1 = bst.nn.Linear(2, 3)
        ...         self.l2 = bst.nn.Linear(3, 4)
        ...     def __call__(self, x):
        ...         return self.l2(self.l1(x))

        >>> model = Model()
        >>> # get the learnable parameters from the batch norm and linear layer
        >>> params = bst.graph.treefy_states(model, bst.ParamState)
        >>> # get them separately
        >>> params, others = bst.graph.treefy_states(model, bst.ParamState, bst.ShortTermState)
        >>> # get them together
        >>> states = bst.graph.treefy_states(model)

    Args:
        node: A graph node object.
        *filters: One or more :class:`State` objects to filter by.

    Returns:
        One or more :class:`NestedDict` mappings.
    """
    _, state_mapping = flatten(node)
    state_mappings: GraphStateMapping | tuple[GraphStateMapping, ...]
    if len(filters) == 0:
        # no filters: return the full mapping unchanged
        state_mappings = state_mapping
    else:
        # one result per filter; a single filter yields a single mapping
        state_mappings = state_mapping.filter(*filters)
    return state_mappings
1348
+
1349
+
1350
def _graph_update_dynamic(node: Any, state: Mapping[Key, Any]):
    """
    Recursively write the entries of ``state`` into ``node`` in place.

    Four cases are handled per key, in order: a new attribute is added,
    a subgraph is updated recursively, an existing ``State`` is updated
    from a ``TreefyState`` reference, or a state field is replaced.
    """
    if not _is_node(node):
        raise RuntimeError(f'Unsupported type: {type(node)}')

    impl = _get_node_impl(node)
    current_attrs = impl.node_dict(node)
    for key, value in state.items():
        # case 1: a brand-new attribute is being added
        if key not in current_attrs:
            if isinstance(impl, PyTreeNodeImpl):
                raise ValueError(f'Cannot set key {key!r} on immutable node of '
                                 f'type {type(node).__name__}')
            if isinstance(value, State):
                value = value.copy()  # TODO: change it to state_ref
            impl.set_key(node, key, value)
            continue

        # the attribute exists; dispatch on its current and incoming types
        existing = current_attrs[key]

        if _is_node(existing):
            # case 2: a subgraph is being updated
            if _is_state_leaf(value):
                raise ValueError(f'Expected a subgraph for {key!r}, but got: {value!r}')
            _graph_update_dynamic(existing, value)
        elif isinstance(value, TreefyState):
            # case 3: a state leaf is being updated from a reference
            if not isinstance(existing, State):
                raise ValueError(f'Trying to update a non-State attribute {key!r} with a State: '
                                 f'{value!r}')
            existing.update_from_ref(value)
        elif _is_state_leaf(value):
            # case 4: a state field is being replaced wholesale
            if isinstance(impl, PyTreeNodeImpl):
                raise ValueError(f'Cannot set key {key!r} on immutable node of '
                                 f'type {type(node).__name__}')
            impl.set_key(node, key, value)
        else:
            raise ValueError(f'Unsupported update type: {type(value)} for key {key!r}')
1389
+
1390
+
1391
@set_module_as('brainstate.graph')
def update_states(
    node: Node,
    state_dict: NestedDict | FlattedDict,
    /,
    *state_dicts: NestedDict | FlattedDict
) -> None:
    """
    Update the given graph node with a new :class:`NestedMapping` in-place.

    Note: the ``@set_module_as`` decorator was added for consistency with the
    other public functions in this module.

    Args:
        node: A graph node to update.
        state_dict: A :class:`NestedMapping` object.
        *state_dicts: Additional :class:`NestedMapping` objects.
    """
    if state_dicts:
        # fold all mappings into one before applying
        state_dict = NestedDict.merge(state_dict, *state_dicts)
    _graph_update_dynamic(node, state_dict.to_dict())
1408
+
1409
+
1410
@set_module_as('brainstate.graph')
def graphdef(node: Any, /) -> GraphDef[Any]:
    """Get the :class:`GraphDef` of the given graph node.

    Example usage::

        >>> import brainstate as bst

        >>> model = bst.nn.Linear(2, 3)
        >>> graphdef, _ = bst.graph.treefy_split(model)
        >>> assert graphdef == bst.graph.graphdef(model)

    Args:
        node: A graph node object.

    Returns:
        The :class:`GraphDef` of the :class:`Module` object.
    """
    # flatten produces (graphdef, state); the state part is discarded here
    static_def, _ = flatten(node)
    return static_def
1430
+
1431
+
1432
@set_module_as('brainstate.graph')
def clone(node: Node) -> Node:
    """
    Create a deep copy of the given graph node.

    Example usage::

        >>> import brainstate as bst
        >>> model = bst.nn.Linear(2, 3)
        >>> cloned_model = bst.graph.clone(model)
        >>> model.weight.value['bias'] += 1
        >>> assert (model.weight.value['bias'] != cloned_model.weight.value['bias']).all()

    Args:
        node: A graph node object.

    Returns:
        A deep copy of the :class:`Module` object.
    """
    # split then merge round-trips the node, producing fresh states
    graphdef, state = treefy_split(node)
    return treefy_merge(graphdef, state)
1453
+
1454
+
1455
@set_module_as('brainstate.graph')
def call(
    graphdef_state: Tuple[GraphDef[A], GraphStateMapping],
) -> ApplyCaller[Tuple[GraphDef[A], GraphStateMapping]]:
    """Calls a method underlying graph node defined by a (GraphDef, NestedDict) pair.

    ``call`` takes a ``(GraphDef, NestedDict)`` pair and creates a proxy object that can be
    used to call methods on the underlying graph node. When a method is called, the
    output is returned along with a new (GraphDef, NestedDict) pair that represents the
    updated state of the graph node. ``call`` is equivalent to :func:`merge` > ``method``
    > :func:`split` but is more convenient to use in pure JAX functions.

    Example::

        >>> import brainstate as bst
        >>> import jax
        >>> import jax.numpy as jnp
        ...
        >>> class StatefulLinear(bst.graph.Node):
        ...     def __init__(self, din, dout):
        ...         self.w = bst.ParamState(bst.random.rand(din, dout))
        ...         self.b = bst.ParamState(jnp.zeros((dout,)))
        ...         self.count = bst.State(jnp.array(0, dtype=jnp.uint32))
        ...
        ...     def increment(self):
        ...         self.count.value += 1
        ...
        ...     def __call__(self, x):
        ...         self.increment()
        ...         return x @ self.w.value + self.b.value
        ...
        >>> linear = StatefulLinear(3, 2)
        >>> linear_state = bst.graph.treefy_split(linear)
        ...
        >>> @jax.jit
        ... def forward(x, linear_state):
        ...     y, linear_state = bst.graph.call(linear_state)(x)
        ...     return y, linear_state
        ...
        >>> x = jnp.ones((1, 3))
        >>> y, linear_state = forward(x, linear_state)
        >>> y, linear_state = forward(x, linear_state)
        ...
        >>> linear = bst.graph.treefy_merge(*linear_state)
        >>> linear.count.value
        Array(2, dtype=uint32)

    The proxy object returned by ``call`` also supports indexing and attribute
    access to reach nested methods, e.g.
    ``_, node_state = bst.graph.call(node_state)['b'].increment()`` calls
    ``increment`` on the node stored under key ``'b'``.
    """

    def _stateful_call(accessor: DelayedAccessor, *args, **kwargs):
        # rebuild the node, resolve and invoke the requested method,
        # then re-split to capture the updated state
        module = treefy_merge(*graphdef_state)
        result = accessor(module)(*args, **kwargs)
        return result, treefy_split(module)

    return CallableProxy(_stateful_call)  # type: ignore
1543
+
1544
+
1545
@set_module_as('brainstate.graph')
def iter_leaf(
    node: Any,
    allowed_hierarchy: Tuple[int, int] = (0, _max_int)
) -> Iterator[tuple[PathParts, Any]]:
    """Iterate over every nested leaf reachable from the given graph node.

    Produces ``(path, value)`` pairs, where each path is a tuple of keys
    (strings or integers) leading from the root to the leaf. Graph nodes that
    appear more than once are visited a single time, and plain static values
    count as leaves too.

    Example::
        >>> import brainstate as bst
        >>> import jax.numpy as jnp
        ...
        >>> class Linear(bst.nn.Module):
        ...   def __init__(self, din, dout):
        ...     super().__init__()
        ...     self.weight = bst.ParamState(bst.random.randn(din, dout))
        ...     self.bias = bst.ParamState(bst.random.randn(dout))
        ...     self.a = 1
        ...
        >>> module = Linear(3, 4)
        ...
        >>> for path, value in bst.graph.iter_leaf([module, module]):
        ...   print(path, type(value).__name__)
        ...
        (0, 'a') int
        (0, 'bias') ParamState
        (0, 'weight') ParamState

    Parameters
    ----------
    node: Node
      The node to iterate over.
    allowed_hierarchy: tuple of int
      Inclusive ``(min, max)`` depth range; only leaves whose hierarchy level
      falls inside this range are yielded.

    """
    min_level, max_level = allowed_hierarchy
    seen: set[int] = set()

    def _walk(current: Any, path: PathParts, depth: int) -> Iterator[tuple[PathParts, Any]]:
        # Prune the whole subtree once we are past the maximum allowed depth.
        if depth > max_level:
            return

        if _is_node(current):
            # Visit each shared node only once (cycle/diamond protection).
            ident = id(current)
            if ident in seen:
                return
            seen.add(ident)
            children = _get_node_impl(current).node_dict(current)
            for key, child in children.items():
                # Only crossing into another graph node increases the level;
                # descending into plain containers keeps the same level.
                next_depth = depth + 1 if _is_graph_node(child) else depth
                yield from _walk(child, (*path, key), next_depth)
        elif depth >= min_level:
            yield path, current

    yield from _walk(node, (), 0)
1612
+
1613
+
1614
@set_module_as('brainstate.graph')
def iter_node(
    node: Any,
    allowed_hierarchy: Tuple[int, int] = (0, _max_int)
) -> Iterator[Tuple[PathParts, Any]]:
    """
    Iterate over all nested graph nodes of the given node, including itself.

    Produces ``(path, node)`` pairs, where each path is a tuple of keys
    (strings or integers) leading from the root to that node. Children are
    emitted before their parent (post-order), and nodes that appear more than
    once are visited a single time.

    Example::
        >>> import brainstate as bst
        >>> import jax.numpy as jnp
        ...
        >>> class Model(bst.nn.Module):
        ...   def __init__(self):
        ...     super().__init__()
        ...     self.a = bst.nn.Linear(1, 2)
        ...     self.b = bst.nn.Linear(2, 3)
        ...     self.c = [bst.nn.Linear(3, 4), bst.nn.Linear(4, 5)]
        ...     self.d = {'x': bst.nn.Linear(5, 6), 'y': bst.nn.Linear(6, 7)}
        ...     self.b.a = bst.nn.LIF(2)
        ...
        >>> model = Model()
        ...
        >>> for path, node in bst.graph.iter_node([model, model]):
        ...   print(path, node.__class__.__name__)
        ...
        (0, 'a') Linear
        (0, 'b', 'a') LIF
        (0, 'b') Linear
        (0, 'c', 0) Linear
        (0, 'c', 1) Linear
        (0, 'd', 'x') Linear
        (0, 'd', 'y') Linear
        (0,) Model

    Parameters
    ----------
    node: Node
      The node to iterate over.
    allowed_hierarchy: tuple of int
      Inclusive ``(min, max)`` depth range; only nodes whose hierarchy level
      falls inside this range are yielded.

    """
    min_level, max_level = allowed_hierarchy
    seen: set[int] = set()

    def _walk(current: Any, path: PathParts, depth: int) -> Iterator[tuple[PathParts, Any]]:
        # Prune the whole subtree once past the maximum allowed depth.
        if depth > max_level:
            return

        if _is_node(current):
            # Visit each shared node only once (cycle/diamond protection).
            ident = id(current)
            if ident in seen:
                return
            seen.add(ident)
            children = _get_node_impl(current).node_dict(current)
            for key, child in children.items():
                # Only crossing into another graph node increases the level.
                next_depth = depth + 1 if _is_graph_node(child) else depth
                yield from _walk(child, (*path, key), next_depth)

        # Post-order: a graph node is reported after all of its children.
        if _is_graph_node(current) and depth >= min_level:
            yield path, current

    yield from _walk(node, (), 0)
1688
+
1689
+
1690
+ # --------------------------------------------------------
1691
+ # Graph operations: end
1692
+ # --------------------------------------------------------
1693
+
1694
+
1695
@dataclasses.dataclass(frozen=True)
class Static(Generic[A]):
    """An empty pytree node that treats its inner value as static.

    Wrapping a value in ``Static`` keeps it out of the pytree leaves, so JAX
    transformations treat it as part of the (hashable) structure instead of
    tracing it. ``value`` must define ``__eq__`` and ``__hash__``.
    """

    # The wrapped static payload; must be hashable and comparable because it
    # becomes part of the pytree treedef used as a cache/equality key.
    value: A


# Register Static with JAX so instances flatten to zero leaves and the whole
# object lives in the static treedef.
jax.tree_util.register_static(Static)
1705
+
1706
+
1707
+ # ---------------------------------------------------------
1708
+ # Pytree
1709
+ # ---------------------------------------------------------
1710
+
1711
class PytreeType:
    """Marker type used to tag the generic-pytree node implementation
    (see ``PYTREE_NODE_IMPL`` below)."""
1713
+
1714
+
1715
def _key_path_to_key(key: Any) -> Key:
    """Translate a ``jax.tree_util`` key-path entry into a plain mapping key.

    Sequence entries become integer indices, dict / flattened-index entries
    expose their underlying key (which must be a valid ``Key``), attribute
    entries become the attribute name, and anything else falls back to its
    string form.
    """
    if isinstance(key, jax.tree_util.SequenceKey):
        return key.idx

    if isinstance(key, (jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey)):
        inner = key.key
        if not isinstance(inner, Key):
            raise ValueError(
                f'Invalid key: {inner}. May be due to its type not being hashable or comparable.'
            )
        return inner

    if isinstance(key, jax.tree_util.GetAttrKey):
        return key.name

    return str(key)
1730
+
1731
+
1732
def _flatten_pytree(pytree: Any):
    """Flatten a pytree exactly one level deep.

    Returns a tuple of ``(key, child)`` pairs for the immediate children of
    ``pytree`` together with the treedef needed to rebuild it. The
    ``is_leaf`` predicate stops recursion at every object other than the
    root itself.
    """
    leaves, treedef = jax.tree_util.tree_flatten_with_path(
        pytree, is_leaf=lambda x: x is not pytree
    )
    # Each path has a single entry because flattening stopped at depth one.
    children = tuple((_key_path_to_key(path[0]), value) for path, value in leaves)
    return children, treedef
1736
+
1737
+
1738
def _unflatten_pytree(
    nodes: tuple[tuple[Key, Any], ...],
    treedef: jax.tree_util.PyTreeDef
):
    """Rebuild a pytree from ``(key, value)`` pairs and its treedef.

    Inverse of ``_flatten_pytree``: the keys are discarded because the
    treedef already encodes the structure; only the values are fed back in
    flattening order.
    """
    values = [value for _, value in nodes]
    return treedef.unflatten(values)
1744
+
1745
+
1746
# Registry entry that lets arbitrary JAX pytrees participate in graph
# traversal: tagged by PytreeType, flattened one level at a time via
# _flatten_pytree and rebuilt via _unflatten_pytree.
PYTREE_NODE_IMPL = PyTreeNodeImpl(type=PytreeType, flatten=_flatten_pytree, unflatten=_unflatten_pytree)