brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +15 -28
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -20,7 +20,7 @@ from __future__ import annotations
|
|
20
20
|
import dataclasses
|
21
21
|
from typing import (
|
22
22
|
Any, Callable, Generic, Iterable, Iterator, Mapping, MutableMapping,
|
23
|
-
Sequence, Type, TypeVar, Union, Hashable, Tuple, Dict, Optional
|
23
|
+
Sequence, Type, TypeVar, Union, Hashable, Tuple, Dict, Optional
|
24
24
|
)
|
25
25
|
|
26
26
|
import jax
|
@@ -30,25 +30,40 @@ from typing_extensions import TypeGuard, Unpack
|
|
30
30
|
from brainstate._state import State, TreefyState
|
31
31
|
from brainstate._utils import set_module_as
|
32
32
|
from brainstate.typing import PathParts, Filter, Predicate, Key
|
33
|
-
from brainstate.util.
|
34
|
-
from brainstate.util.
|
35
|
-
from brainstate.util.pretty_repr import PrettyRepr, PrettyType, PrettyAttr, PrettyMapping, MappingReprMixin
|
36
|
-
from brainstate.util.struct import FrozenDict
|
33
|
+
from brainstate.util._pretty_pytree import NestedDict, FlattedDict, PrettyDict
|
34
|
+
from brainstate.util._pretty_repr import PrettyRepr, PrettyType, PrettyAttr, PrettyMapping, MappingReprMixin
|
37
35
|
from brainstate.util.filter import to_predicate
|
38
|
-
|
39
|
-
_max_int = np.iinfo(np.int32).max
|
36
|
+
from brainstate.util.struct import FrozenDict
|
40
37
|
|
41
38
|
__all__ = [
|
39
|
+
'register_graph_node_type',
|
40
|
+
|
42
41
|
# state management in the given graph or node
|
43
|
-
'pop_states',
|
42
|
+
'pop_states',
|
43
|
+
'nodes',
|
44
|
+
'states',
|
45
|
+
'treefy_states',
|
46
|
+
'update_states',
|
44
47
|
|
45
48
|
# graph node operations
|
46
|
-
'flatten',
|
49
|
+
'flatten',
|
50
|
+
'unflatten',
|
51
|
+
'treefy_split',
|
52
|
+
'treefy_merge',
|
53
|
+
'iter_leaf',
|
54
|
+
'iter_node',
|
55
|
+
'clone',
|
56
|
+
'graphdef',
|
47
57
|
|
48
58
|
# others
|
49
|
-
'RefMap',
|
59
|
+
'RefMap',
|
60
|
+
'GraphDef',
|
61
|
+
'NodeDef',
|
62
|
+
'NodeRef',
|
50
63
|
]
|
51
64
|
|
65
|
+
MAX_INT = np.iinfo(np.int32).max
|
66
|
+
|
52
67
|
A = TypeVar('A')
|
53
68
|
B = TypeVar('B')
|
54
69
|
C = TypeVar('C')
|
@@ -65,12 +80,11 @@ AuxData = TypeVar('AuxData')
|
|
65
80
|
|
66
81
|
StateLeaf = TreefyState[Any]
|
67
82
|
NodeLeaf = State[Any]
|
68
|
-
GraphStateMapping = NestedDict
|
83
|
+
GraphStateMapping = NestedDict
|
69
84
|
|
70
85
|
|
71
86
|
# --------------------------------------------------------
|
72
87
|
|
73
|
-
|
74
88
|
def _is_state_leaf(x: Any) -> TypeGuard[StateLeaf]:
|
75
89
|
return isinstance(x, TreefyState)
|
76
90
|
|
@@ -86,13 +100,30 @@ class RefMap(MutableMapping[A, B], MappingReprMixin[A, B]):
|
|
86
100
|
This mapping is useful when we want to keep track of objects
|
87
101
|
that are being referenced by other objects.
|
88
102
|
|
89
|
-
|
90
|
-
|
103
|
+
Parameters
|
104
|
+
----------
|
105
|
+
mapping : Mapping[A, B] or Iterable[Tuple[A, B]], optional
|
106
|
+
A mapping or iterable of key-value pairs.
|
107
|
+
|
108
|
+
Examples
|
109
|
+
--------
|
110
|
+
.. code-block:: python
|
111
|
+
|
112
|
+
>>> import brainstate
|
113
|
+
>>> obj1 = object()
|
114
|
+
>>> obj2 = object()
|
115
|
+
>>> ref_map = brainstate.graph.RefMap()
|
116
|
+
>>> ref_map[obj1] = 'value1'
|
117
|
+
>>> ref_map[obj2] = 'value2'
|
118
|
+
>>> print(obj1 in ref_map)
|
119
|
+
True
|
120
|
+
>>> print(ref_map[obj1])
|
121
|
+
value1
|
91
122
|
|
92
123
|
"""
|
93
124
|
__module__ = 'brainstate.graph'
|
94
125
|
|
95
|
-
def __init__(self, mapping: Mapping[A, B]
|
126
|
+
def __init__(self, mapping: Union[Mapping[A, B], Iterable[Tuple[A, B]]] = ()) -> None:
|
96
127
|
self._mapping: Dict[int, Tuple[A, B]] = {}
|
97
128
|
self.update(mapping)
|
98
129
|
|
@@ -102,10 +133,10 @@ class RefMap(MutableMapping[A, B], MappingReprMixin[A, B]):
|
|
102
133
|
def __contains__(self, key: Any) -> bool:
|
103
134
|
return id(key) in self._mapping
|
104
135
|
|
105
|
-
def __setitem__(self, key: A, value: B):
|
136
|
+
def __setitem__(self, key: A, value: B) -> None:
|
106
137
|
self._mapping[id(key)] = (key, value)
|
107
138
|
|
108
|
-
def __delitem__(self, key: A):
|
139
|
+
def __delitem__(self, key: A) -> None:
|
109
140
|
del self._mapping[id(key)]
|
110
141
|
|
111
142
|
def __iter__(self) -> Iterator[A]:
|
@@ -135,7 +166,7 @@ class GraphNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
|
|
135
166
|
create_empty: Callable[[AuxData], Node]
|
136
167
|
clear: Callable[[Node], None]
|
137
168
|
|
138
|
-
def init(self, node: Node, items: Tuple[Tuple[Key, Leaf], ...]):
|
169
|
+
def init(self, node: Node, items: Tuple[Tuple[Key, Leaf], ...]) -> None:
|
139
170
|
for key, value in items:
|
140
171
|
self.set_key(node, key, value)
|
141
172
|
|
@@ -151,7 +182,7 @@ NodeImpl = Union[GraphNodeImpl[Node, Leaf, AuxData], PyTreeNodeImpl[Node, Leaf,
|
|
151
182
|
# Graph Node implementation: start
|
152
183
|
# --------------------------------------------------------
|
153
184
|
|
154
|
-
_node_impl_for_type: dict[type, NodeImpl
|
185
|
+
_node_impl_for_type: dict[type, NodeImpl] = {}
|
155
186
|
|
156
187
|
|
157
188
|
def register_graph_node_type(
|
@@ -165,13 +196,56 @@ def register_graph_node_type(
|
|
165
196
|
"""
|
166
197
|
Register a graph node type.
|
167
198
|
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
199
|
+
Parameters
|
200
|
+
----------
|
201
|
+
type : type
|
202
|
+
The type of the node.
|
203
|
+
flatten : Callable[[Node], tuple[Sequence[tuple[Key, Leaf]], AuxData]]
|
204
|
+
A function that flattens the node into a sequence of key-value pairs.
|
205
|
+
set_key : Callable[[Node, Key, Leaf], None]
|
206
|
+
A function that sets a key in the node.
|
207
|
+
pop_key : Callable[[Node, Key], Leaf]
|
208
|
+
A function that pops a key from the node.
|
209
|
+
create_empty : Callable[[AuxData], Node]
|
210
|
+
A function that creates an empty node.
|
211
|
+
clear : Callable[[Node], None]
|
212
|
+
A function that clears the node.
|
213
|
+
|
214
|
+
Examples
|
215
|
+
--------
|
216
|
+
.. code-block:: python
|
217
|
+
|
218
|
+
>>> import brainstate
|
219
|
+
>>> # Custom node type implementation
|
220
|
+
>>> class CustomNode:
|
221
|
+
... def __init__(self):
|
222
|
+
... self.data = {}
|
223
|
+
...
|
224
|
+
>>> def flatten_custom(node):
|
225
|
+
... return list(node.data.items()), None
|
226
|
+
...
|
227
|
+
>>> def set_key_custom(node, key, value):
|
228
|
+
... node.data[key] = value
|
229
|
+
...
|
230
|
+
>>> def pop_key_custom(node, key):
|
231
|
+
... return node.data.pop(key)
|
232
|
+
...
|
233
|
+
>>> def create_empty_custom(metadata):
|
234
|
+
... return CustomNode()
|
235
|
+
...
|
236
|
+
>>> def clear_custom(node):
|
237
|
+
... node.data.clear()
|
238
|
+
...
|
239
|
+
>>> # Register the custom node type
|
240
|
+
>>> brainstate.graph.register_graph_node_type(
|
241
|
+
... CustomNode,
|
242
|
+
... flatten_custom,
|
243
|
+
... set_key_custom,
|
244
|
+
... pop_key_custom,
|
245
|
+
... create_empty_custom,
|
246
|
+
... clear_custom
|
247
|
+
... )
|
248
|
+
|
175
249
|
"""
|
176
250
|
_node_impl_for_type[type] = GraphNodeImpl(
|
177
251
|
type=type,
|
@@ -200,11 +274,11 @@ def _is_graph_node(x: Any) -> bool:
|
|
200
274
|
return type(x) in _node_impl_for_type
|
201
275
|
|
202
276
|
|
203
|
-
def _is_node_type(x:
|
277
|
+
def _is_node_type(x: Type[Any]) -> bool:
|
204
278
|
return x in _node_impl_for_type or x is PytreeType
|
205
279
|
|
206
280
|
|
207
|
-
def _get_node_impl(x:
|
281
|
+
def _get_node_impl(x: Any) -> NodeImpl:
|
208
282
|
if isinstance(x, State):
|
209
283
|
raise ValueError(f'State is not a node: {x}')
|
210
284
|
|
@@ -218,14 +292,14 @@ def _get_node_impl(x: Node) -> NodeImpl[Node, Any, Any]:
|
|
218
292
|
return _node_impl_for_type[node_type]
|
219
293
|
|
220
294
|
|
221
|
-
def get_node_impl_for_type(x:
|
295
|
+
def get_node_impl_for_type(x: Type[Any]) -> NodeImpl:
|
222
296
|
if x is PytreeType:
|
223
297
|
return PYTREE_NODE_IMPL
|
224
298
|
return _node_impl_for_type[x]
|
225
299
|
|
226
300
|
|
227
301
|
class HashableMapping(Mapping[HA, HB], Hashable):
|
228
|
-
def __init__(self, mapping: Mapping[HA, HB]
|
302
|
+
def __init__(self, mapping: Union[Mapping[HA, HB], Iterable[tuple[HA, HB]]]) -> None:
|
229
303
|
self._mapping = dict(mapping)
|
230
304
|
|
231
305
|
def __contains__(self, key: object) -> bool:
|
@@ -259,57 +333,53 @@ class GraphDef(Generic[Node]):
|
|
259
333
|
- index: The index of the node in the graph.
|
260
334
|
|
261
335
|
It has two concrete subclasses:
|
336
|
+
|
262
337
|
- :class:`NodeRef`: A reference to a node in the graph.
|
263
338
|
- :class:`NodeDef`: A dataclass that denotes the graph structure of a :class:`Node` or a :class:`State`.
|
264
339
|
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
class NodeRef(GraphDef[Node], PrettyRepr):
|
272
|
-
"""
|
273
|
-
A reference to a node in the graph.
|
340
|
+
Attributes
|
341
|
+
----------
|
342
|
+
type : Type[Node]
|
343
|
+
The type of the node.
|
344
|
+
index : int
|
345
|
+
The index of the node in the graph.
|
274
346
|
|
275
|
-
The node can be instances of :class:`Node` or :class:`State`.
|
276
347
|
"""
|
277
|
-
type:
|
348
|
+
type: Type[Node]
|
278
349
|
index: int
|
279
350
|
|
280
|
-
def __pretty_repr__(self):
|
281
|
-
yield PrettyType(type=type(self))
|
282
|
-
yield PrettyAttr('type', self.type.__name__)
|
283
|
-
yield PrettyAttr('index', self.index)
|
284
|
-
|
285
|
-
def __treescope_repr__(self, path, subtree_renderer):
|
286
|
-
"""
|
287
|
-
Treescope repr for the object.
|
288
|
-
"""
|
289
|
-
import treescope # type: ignore[import-not-found,import-untyped]
|
290
|
-
return treescope.repr_lib.render_object_constructor(
|
291
|
-
object_type=type(self),
|
292
|
-
attributes={'type': self.type, 'index': self.index},
|
293
|
-
path=path,
|
294
|
-
subtree_renderer=subtree_renderer,
|
295
|
-
)
|
296
|
-
|
297
|
-
|
298
|
-
jax.tree_util.register_static(NodeRef)
|
299
|
-
|
300
351
|
|
301
352
|
@dataclasses.dataclass(frozen=True, repr=False)
|
302
353
|
class NodeDef(GraphDef[Node], PrettyRepr):
|
303
354
|
"""
|
304
355
|
A dataclass that denotes the tree structure of a node, either :class:`Node` or :class:`State`.
|
305
356
|
|
357
|
+
Attributes
|
358
|
+
----------
|
359
|
+
type : Type[Node]
|
360
|
+
Type of the node.
|
361
|
+
index : int
|
362
|
+
Index of the node in the graph.
|
363
|
+
attributes : Tuple[Key, ...]
|
364
|
+
Attributes for the node.
|
365
|
+
subgraphs : HashableMapping[Key, NodeDef[Any] | NodeRef[Any]]
|
366
|
+
Mapping of subgraph definitions.
|
367
|
+
static_fields : HashableMapping
|
368
|
+
Mapping of static fields.
|
369
|
+
leaves : HashableMapping[Key, NodeRef[Any] | None]
|
370
|
+
Mapping of leaf nodes.
|
371
|
+
metadata : Hashable
|
372
|
+
Metadata associated with the node.
|
373
|
+
index_mapping : FrozenDict[Index, Index] | None
|
374
|
+
Index mapping for node references.
|
375
|
+
|
306
376
|
"""
|
307
377
|
|
308
378
|
type: Type[Node] # type of the node
|
309
379
|
index: int # index of the node in the graph
|
310
380
|
attributes: Tuple[Key, ...] # attributes for the node
|
311
381
|
subgraphs: HashableMapping[Key, NodeDef[Any] | NodeRef[Any]]
|
312
|
-
static_fields: HashableMapping
|
382
|
+
static_fields: HashableMapping
|
313
383
|
leaves: HashableMapping[Key, NodeRef[Any] | None]
|
314
384
|
metadata: Hashable
|
315
385
|
index_mapping: FrozenDict[Index, Index] | None
|
@@ -321,7 +391,7 @@ class NodeDef(GraphDef[Node], PrettyRepr):
|
|
321
391
|
index: int,
|
322
392
|
attributes: tuple[Key, ...],
|
323
393
|
subgraphs: Iterable[tuple[Key, NodeDef[Any] | NodeRef[Any]]],
|
324
|
-
static_fields: Iterable[tuple
|
394
|
+
static_fields: Iterable[tuple],
|
325
395
|
leaves: Iterable[tuple[Key, NodeRef[Any] | None]],
|
326
396
|
metadata: Hashable,
|
327
397
|
index_mapping: Mapping[Index, Index] | None,
|
@@ -349,24 +419,35 @@ class NodeDef(GraphDef[Node], PrettyRepr):
|
|
349
419
|
yield PrettyAttr('metadata', self.metadata)
|
350
420
|
yield PrettyAttr('index_mapping', PrettyMapping(self.index_mapping) if self.index_mapping is not None else None)
|
351
421
|
|
352
|
-
def apply(
|
353
|
-
self,
|
354
|
-
state_map: GraphStateMapping,
|
355
|
-
*state_maps: GraphStateMapping
|
356
|
-
) -> ApplyCaller[tuple[GraphDef[Node], GraphStateMapping]]:
|
357
|
-
accessor = DelayedAccessor()
|
358
422
|
|
359
|
-
|
360
|
-
|
361
|
-
|
362
|
-
|
363
|
-
|
364
|
-
|
423
|
+
jax.tree_util.register_static(NodeDef)
|
424
|
+
|
425
|
+
|
426
|
+
@dataclasses.dataclass(frozen=True, repr=False)
|
427
|
+
class NodeRef(GraphDef[Node], PrettyRepr):
|
428
|
+
"""
|
429
|
+
A reference to a node in the graph.
|
430
|
+
|
431
|
+
The node can be instances of :class:`Node` or :class:`State`.
|
365
432
|
|
366
|
-
|
433
|
+
Attributes
|
434
|
+
----------
|
435
|
+
type : Type[Node]
|
436
|
+
The type of the node being referenced.
|
437
|
+
index : int
|
438
|
+
The index of the node in the graph.
|
367
439
|
|
440
|
+
"""
|
441
|
+
type: Type[Node]
|
442
|
+
index: int
|
368
443
|
|
369
|
-
|
444
|
+
def __pretty_repr__(self):
|
445
|
+
yield PrettyType(type=type(self))
|
446
|
+
yield PrettyAttr('type', self.type.__name__)
|
447
|
+
yield PrettyAttr('index', self.index)
|
448
|
+
|
449
|
+
|
450
|
+
jax.tree_util.register_static(NodeRef)
|
370
451
|
|
371
452
|
|
372
453
|
# --------------------------------------------------------
|
@@ -378,20 +459,30 @@ def _graph_flatten(
|
|
378
459
|
path: PathParts,
|
379
460
|
ref_index: RefMap[Any, Index],
|
380
461
|
flatted_state_mapping: Dict[PathParts, StateLeaf],
|
381
|
-
node:
|
462
|
+
node: Any,
|
382
463
|
treefy_state: bool = False,
|
383
|
-
):
|
464
|
+
) -> Union[NodeDef[Any], NodeRef[Any]]:
|
384
465
|
"""
|
385
466
|
Recursive helper for graph flatten.
|
386
467
|
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
468
|
+
Parameters
|
469
|
+
----------
|
470
|
+
path : PathParts
|
471
|
+
The path to the node.
|
472
|
+
ref_index : RefMap[Any, Index]
|
473
|
+
A mapping from nodes to indexes.
|
474
|
+
flatted_state_mapping : Dict[PathParts, StateLeaf]
|
475
|
+
A mapping from paths to state leaves.
|
476
|
+
node : Node
|
477
|
+
The node to flatten.
|
478
|
+
treefy_state : bool, optional
|
479
|
+
Whether to convert states to TreefyState, by default False.
|
480
|
+
|
481
|
+
Returns
|
482
|
+
-------
|
483
|
+
NodeDef or NodeRef
|
484
|
+
A NodeDef or a NodeRef.
|
392
485
|
|
393
|
-
Returns:
|
394
|
-
A NodeDef or a NodeRef.
|
395
486
|
"""
|
396
487
|
if not _is_node(node):
|
397
488
|
raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')
|
@@ -417,9 +508,9 @@ def _graph_flatten(
|
|
417
508
|
else:
|
418
509
|
index = -1
|
419
510
|
|
420
|
-
subgraphs: list[tuple[Key, NodeDef[
|
421
|
-
static_fields: list[tuple
|
422
|
-
leaves: list[tuple[Key, NodeRef
|
511
|
+
subgraphs: list[tuple[Key, Union[NodeDef[Any], NodeRef[Any]]]] = []
|
512
|
+
static_fields: list[tuple] = []
|
513
|
+
leaves: list[tuple[Key, Union[NodeRef[Any], None]]] = []
|
423
514
|
|
424
515
|
# Flatten the node into a sequence of key-value pairs.
|
425
516
|
values, metadata = node_impl.flatten(node)
|
@@ -450,41 +541,56 @@ def _graph_flatten(
|
|
450
541
|
# The value is a static field.
|
451
542
|
static_fields.append((key, value))
|
452
543
|
|
453
|
-
nodedef = NodeDef.create(
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
544
|
+
nodedef = NodeDef.create(
|
545
|
+
type=node_impl.type,
|
546
|
+
index=index,
|
547
|
+
attributes=tuple(key for key, _ in values),
|
548
|
+
subgraphs=subgraphs,
|
549
|
+
static_fields=static_fields,
|
550
|
+
leaves=leaves,
|
551
|
+
metadata=metadata,
|
552
|
+
index_mapping=None,
|
553
|
+
)
|
461
554
|
return nodedef
|
462
555
|
|
463
556
|
|
464
557
|
@set_module_as('brainstate.graph')
|
465
558
|
def flatten(
|
466
|
-
node:
|
559
|
+
node: Any,
|
467
560
|
/,
|
468
561
|
ref_index: Optional[RefMap[Any, Index]] = None,
|
469
562
|
treefy_state: bool = True,
|
470
|
-
) -> Tuple[GraphDef, NestedDict]:
|
563
|
+
) -> Tuple[GraphDef[Any], NestedDict]:
|
471
564
|
"""
|
472
565
|
Flattens a graph node into a (graph_def, state_mapping) pair.
|
473
566
|
|
474
|
-
|
475
|
-
|
476
|
-
|
567
|
+
Parameters
|
568
|
+
----------
|
569
|
+
node : Node
|
570
|
+
A graph node.
|
571
|
+
ref_index : RefMap[Any, Index], optional
|
572
|
+
A mapping from nodes to indexes, defaults to None. If not provided, a new
|
573
|
+
empty dictionary is created. This argument can be used to flatten a sequence of graph
|
574
|
+
nodes that share references.
|
575
|
+
treefy_state : bool, optional
|
576
|
+
If True, the state mapping will be a NestedDict instead of a flat dictionary.
|
577
|
+
Default is True.
|
578
|
+
|
579
|
+
Returns
|
580
|
+
-------
|
581
|
+
tuple[GraphDef, NestedDict]
|
582
|
+
A tuple containing the graph definition and state mapping.
|
583
|
+
|
584
|
+
Examples
|
585
|
+
--------
|
586
|
+
.. code-block:: python
|
587
|
+
|
588
|
+
>>> import brainstate
|
477
589
|
>>> node = brainstate.graph.Node()
|
478
|
-
>>> graph_def, state_mapping = flatten(node)
|
590
|
+
>>> graph_def, state_mapping = brainstate.graph.flatten(node)
|
479
591
|
>>> print(graph_def)
|
480
592
|
>>> print(state_mapping)
|
481
593
|
|
482
|
-
Args:
|
483
|
-
node: A graph node.
|
484
|
-
ref_index: A mapping from nodes to indexes, defaults to None. If not provided, a new
|
485
|
-
empty dictionary is created. This argument can be used to flatten a sequence of graph
|
486
|
-
nodes that share references.
|
487
|
-
treefy_state: If True, the state mapping will be a NestedDict instead of a flat dictionary.
|
488
594
|
"""
|
489
595
|
ref_index = RefMap() if ref_index is None else ref_index
|
490
596
|
assert isinstance(ref_index, RefMap), f"ref_index must be a RefMap. But we got: {ref_index}"
|
@@ -493,8 +599,13 @@ def flatten(
|
|
493
599
|
return graph_def, NestedDict.from_flat(flatted_state_mapping)
|
494
600
|
|
495
601
|
|
496
|
-
def _get_children(
|
497
|
-
|
602
|
+
def _get_children(
|
603
|
+
graph_def: NodeDef[Any],
|
604
|
+
state_mapping: Mapping,
|
605
|
+
index_ref: dict[Index, Any],
|
606
|
+
index_ref_cache: Optional[dict[Index, Any]],
|
607
|
+
) -> dict[Key, Union[StateLeaf, Any]]:
|
608
|
+
children: dict[Key, Union[StateLeaf, Any]] = {}
|
498
609
|
|
499
610
|
# NOTE: we could allow adding new StateLeafs here
|
500
611
|
# All state keys must be present in the graph definition (the object attributes)
|
@@ -506,8 +617,8 @@ def _get_children(graph_def, state_mapping, index_ref, index_ref_cache):
|
|
506
617
|
# - (3) the key can be a subgraph, a leaf, or a static attribute
|
507
618
|
for key in graph_def.attributes:
|
508
619
|
if key not in state_mapping: # static field
|
509
|
-
#
|
510
|
-
#
|
620
|
+
# Support unflattening with missing keys for static fields and subgraphs
|
621
|
+
# This allows partial state restoration and flexible graph reconstruction
|
511
622
|
if key in graph_def.static_fields:
|
512
623
|
children[key] = graph_def.static_fields[key]
|
513
624
|
|
@@ -534,8 +645,10 @@ def _get_children(graph_def, state_mapping, index_ref, index_ref_cache):
|
|
534
645
|
|
535
646
|
else:
|
536
647
|
# key for a variable is missing, raise an error
|
537
|
-
raise ValueError(
|
538
|
-
|
648
|
+
raise ValueError(
|
649
|
+
f'Expected key {key!r} in state while building node of type '
|
650
|
+
f'{graph_def.type.__name__}.'
|
651
|
+
)
|
539
652
|
|
540
653
|
else:
|
541
654
|
raise RuntimeError(f'Unknown static field: {key!r}')
|
@@ -551,8 +664,11 @@ def _get_children(graph_def, state_mapping, index_ref, index_ref_cache):
|
|
551
664
|
if key in graph_def.subgraphs:
|
552
665
|
# if _is_state_leaf(value):
|
553
666
|
if isinstance(value, (TreefyState, State)):
|
554
|
-
raise ValueError(
|
555
|
-
|
667
|
+
raise ValueError(
|
668
|
+
f'Expected value of type {graph_def.subgraphs[key]} '
|
669
|
+
f'for {key!r}, but got {value!r}'
|
670
|
+
)
|
671
|
+
|
556
672
|
if not isinstance(value, dict):
|
557
673
|
raise TypeError(f'Expected a dict for {key!r}, but got {type(value)}.')
|
558
674
|
|
@@ -574,8 +690,8 @@ def _get_children(graph_def, state_mapping, index_ref, index_ref_cache):
|
|
574
690
|
# TreefyState presumbly created by modifying the NestedDict
|
575
691
|
if isinstance(value, TreefyState):
|
576
692
|
value = value.to_state()
|
577
|
-
|
578
|
-
|
693
|
+
elif isinstance(value, State):
|
694
|
+
value = value
|
579
695
|
children[key] = value
|
580
696
|
|
581
697
|
elif noderef.index in index_ref:
|
@@ -585,7 +701,10 @@ def _get_children(graph_def, state_mapping, index_ref, index_ref_cache):
|
|
585
701
|
else:
|
586
702
|
# it is an unseen variable, create a new one
|
587
703
|
if not isinstance(value, (TreefyState, State)):
|
588
|
-
raise ValueError(
|
704
|
+
raise ValueError(
|
705
|
+
f'Expected a State type for {key!r}, but got {type(value)}.'
|
706
|
+
)
|
707
|
+
|
589
708
|
# when idxmap is present, check if the Varable exists there
|
590
709
|
# and update existing variables if it does
|
591
710
|
if index_ref_cache is not None and noderef.index in index_ref_cache:
|
@@ -618,11 +737,11 @@ def _get_children(graph_def, state_mapping, index_ref, index_ref_cache):
|
|
618
737
|
|
619
738
|
|
620
739
|
def _graph_unflatten(
|
621
|
-
graph_def: NodeDef[
|
622
|
-
state_mapping: Mapping[Key, StateLeaf
|
740
|
+
graph_def: Union[NodeDef[Any], NodeRef[Any]],
|
741
|
+
state_mapping: Mapping[Key, Union[StateLeaf, Mapping]],
|
623
742
|
index_ref: dict[Index, Any],
|
624
|
-
index_ref_cache: dict[Index, Any]
|
625
|
-
) ->
|
743
|
+
index_ref_cache: Optional[dict[Index, Any]],
|
744
|
+
) -> Any:
|
626
745
|
"""
|
627
746
|
Recursive helper for graph unflatten.
|
628
747
|
|
@@ -697,175 +816,57 @@ def _graph_unflatten(
|
|
697
816
|
|
698
817
|
@set_module_as('brainstate.graph')
|
699
818
|
def unflatten(
|
700
|
-
graph_def: GraphDef,
|
701
|
-
state_mapping: NestedDict
|
819
|
+
graph_def: GraphDef[Any],
|
820
|
+
state_mapping: NestedDict,
|
702
821
|
/,
|
703
822
|
*,
|
704
|
-
index_ref: dict[Index, Any]
|
705
|
-
index_ref_cache: dict[Index, Any]
|
706
|
-
) ->
|
823
|
+
index_ref: Optional[dict[Index, Any]] = None,
|
824
|
+
index_ref_cache: Optional[dict[Index, Any]] = None,
|
825
|
+
) -> Any:
|
707
826
|
"""
|
708
827
|
Unflattens a graphdef into a node with the given state tree mapping.
|
709
828
|
|
710
|
-
|
711
|
-
|
712
|
-
|
713
|
-
|
714
|
-
|
715
|
-
|
716
|
-
|
717
|
-
|
718
|
-
|
719
|
-
|
720
|
-
|
721
|
-
|
722
|
-
|
723
|
-
|
724
|
-
|
725
|
-
|
726
|
-
|
727
|
-
|
728
|
-
|
729
|
-
|
730
|
-
|
731
|
-
|
732
|
-
|
733
|
-
|
734
|
-
|
735
|
-
|
736
|
-
|
737
|
-
|
738
|
-
|
739
|
-
|
740
|
-
|
741
|
-
|
742
|
-
|
743
|
-
|
744
|
-
|
745
|
-
|
746
|
-
|
747
|
-
|
748
|
-
|
749
|
-
|
750
|
-
|
751
|
-
},
|
752
|
-
'd': {
|
753
|
-
'x': {
|
754
|
-
'weight': TreefyState(
|
755
|
-
type=ParamState,
|
756
|
-
value={'weight': Array([[ 0.9647322, -0.8958757, 1.585352 ]], dtype=float32), 'bias': Array([0., 0., 0.], dtype=float32)}
|
757
|
-
)
|
758
|
-
},
|
759
|
-
'y': {
|
760
|
-
'weight': TreefyState(
|
761
|
-
type=ParamState,
|
762
|
-
value={'weight': Array([[-1.2904786 , 0.5695903 , 0.40079263, 0.8769669 ]], dtype=float32), 'bias': Array([0., 0., 0., 0.], dtype=float32)}
|
763
|
-
)
|
764
|
-
}
|
765
|
-
}
|
766
|
-
})
|
767
|
-
>>> node = brainstate.graph.unflatten(graphdef, statetree)
|
768
|
-
>>> node
|
769
|
-
MyNode(
|
770
|
-
a=Linear(
|
771
|
-
in_size=(2,),
|
772
|
-
out_size=(3,),
|
773
|
-
w_mask=None,
|
774
|
-
weight=ParamState(
|
775
|
-
value={'weight': Array([[ 0.55600464, -1.6276929 , 0.26805446],
|
776
|
-
[ 1.175099 , 1.0077754 , 0.37592274]], dtype=float32), 'bias': Array([0., 0., 0.], dtype=float32)},
|
777
|
-
)
|
778
|
-
),
|
779
|
-
b=Linear(
|
780
|
-
in_size=(3,),
|
781
|
-
out_size=(4,),
|
782
|
-
w_mask=None,
|
783
|
-
weight=ParamState(
|
784
|
-
value={'weight': Array([[-0.24753566, 0.18456966, -0.29438975, 0.16891003],
|
785
|
-
[-0.803741 , -0.46037054, -0.21617596, 0.1260884 ],
|
786
|
-
[-0.43074366, -0.24757433, 1.2237076 , -0.07842704]], dtype=float32), 'bias': Array([0., 0., 0., 0.], dtype=float32)},
|
787
|
-
)
|
788
|
-
),
|
789
|
-
c=[Linear(
|
790
|
-
in_size=(4,),
|
791
|
-
out_size=(5,),
|
792
|
-
w_mask=None,
|
793
|
-
weight=ParamState(
|
794
|
-
value={'weight': Array([[-0.22384474, 0.79441446, -0.658726 , 0.05991402, 0.3014344 ],
|
795
|
-
[-1.4755846 , -0.42272082, -0.07692316, 0.03077666, 0.34513143],
|
796
|
-
[-0.69395834, 0.48617035, 1.1042316 , 0.13105175, -0.25620162],
|
797
|
-
[ 0.50389856, 0.6998943 , 0.43716812, 1.2168779 , -0.47325954]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0.], dtype=float32)},
|
798
|
-
)
|
799
|
-
), Linear(
|
800
|
-
in_size=(5,),
|
801
|
-
out_size=(6,),
|
802
|
-
w_mask=None,
|
803
|
-
weight=ParamState(
|
804
|
-
value={'weight': Array([[ 0.07714394, 0.78213537, 0.6745718 , -0.22881542, 0.5523547 ,
|
805
|
-
-0.6399196 ],
|
806
|
-
[-0.22626828, -0.54522336, 0.07448788, -0.00464636, 1.1483842 ,
|
807
|
-
-0.57049096],
|
808
|
-
[-0.86659616, 0.5683135 , -0.7449975 , 1.1862832 , 0.15047254,
|
809
|
-
0.68890226],
|
810
|
-
[-1.0325443 , 0.2658072 , -0.10083053, -0.66915905, 0.11258496,
|
811
|
-
0.5440655 ],
|
812
|
-
[ 0.27917263, 0.05717273, -0.5682605 , -0.88345915, 0.01314917,
|
813
|
-
0.780759 ]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0.], dtype=float32)},
|
814
|
-
)
|
815
|
-
)],
|
816
|
-
d={'x': Linear(
|
817
|
-
in_size=(6,),
|
818
|
-
out_size=(7,),
|
819
|
-
w_mask=None,
|
820
|
-
weight=ParamState(
|
821
|
-
value={'weight': Array([[-0.24238771, -0.23202638, 0.13663477, -0.48858666, 0.80871904,
|
822
|
-
0.00593298, 0.7595096 ],
|
823
|
-
[ 0.50457454, 0.24180941, 0.25048748, 0.8937061 , 0.25398138,
|
824
|
-
-1.2400566 , 0.00151599],
|
825
|
-
[-0.19136038, 0.34470603, -0.11892717, -0.12514868, -0.5871703 ,
|
826
|
-
0.13572927, -1.1859009 ],
|
827
|
-
[-0.01580911, 0.9301295 , -1.1246226 , -0.137708 , -0.4952151 ,
|
828
|
-
0.17537868, 0.98440856],
|
829
|
-
[ 0.6399284 , 0.01739843, 0.61856824, 0.93258303, 0.64012206,
|
830
|
-
0.22780116, -0.5763679 ],
|
831
|
-
[ 0.14077143, -1.0359222 , 0.28072503, 0.2557584 , -0.50622064,
|
832
|
-
0.4388198 , -0.26106128]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0., 0.], dtype=float32)},
|
833
|
-
)
|
834
|
-
), 'y': Linear(
|
835
|
-
in_size=(7,),
|
836
|
-
out_size=(8,),
|
837
|
-
w_mask=None,
|
838
|
-
weight=ParamState(
|
839
|
-
value={'weight': Array([[-0.23334591, -0.2893582 , 0.8071877 , -0.49038902, -0.29646504,
|
840
|
-
0.13624157, 0.22763114, 0.01906361],
|
841
|
-
[-0.26742765, 0.20136863, 0.35148615, 0.42135832, 0.06401154,
|
842
|
-
-0.78036404, 0.6616062 , 0.19437549],
|
843
|
-
[ 0.9229799 , -0.1205209 , 0.69602865, 0.9685676 , -0.99886954,
|
844
|
-
-0.12649904, -0.15393028, 0.65067965],
|
845
|
-
[ 0.7020109 , -0.5452006 , 0.3649151 , -0.42368713, 0.24738027,
|
846
|
-
0.29290223, -0.63721114, 0.6007214 ],
|
847
|
-
[-0.45045808, -0.08538888, -0.01338054, -0.39983988, 0.4028439 ,
|
848
|
-
1.0498686 , -0.24730456, 0.37612835],
|
849
|
-
[ 0.16273966, 0.9001257 , 0.15190877, -1.1129239 , -0.29441378,
|
850
|
-
0.5168159 , -0.4205143 , 0.45700482],
|
851
|
-
[ 0.08611429, -0.9271384 , -0.562362 , -0.586757 , 1.1611121 ,
|
852
|
-
0.5137503 , -0.46277294, 0.84642583]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)},
|
853
|
-
)
|
854
|
-
)}
|
855
|
-
)
|
856
|
-
|
857
|
-
Args:
|
858
|
-
graph_def: A GraphDef instance.
|
859
|
-
state_mapping: A NestedDict instance.
|
860
|
-
index_ref: A mapping from indexes to nodes references found during the graph
|
861
|
-
traversal, defaults to None. If not provided, a new empty dictionary is
|
862
|
-
created. This argument can be used to unflatten a sequence of (graphdef, state_mapping)
|
863
|
-
pairs that share the same index space.
|
864
|
-
index_ref_cache: A mapping from indexes to existing nodes that can be reused.
|
865
|
-
When a reference is reused, ``GraphNodeImpl.clear`` is called to leave the
|
866
|
-
object in an empty state and then filled by the unflatten process, as a result
|
867
|
-
existing graph nodes are mutated to have the new content/topology
|
868
|
-
specified by the graphdef.
|
829
|
+
Parameters
|
830
|
+
----------
|
831
|
+
graph_def : GraphDef
|
832
|
+
A GraphDef instance.
|
833
|
+
state_mapping : NestedDict
|
834
|
+
A NestedDict instance containing the state mapping.
|
835
|
+
index_ref : dict[Index, Any], optional
|
836
|
+
A mapping from indexes to nodes references found during the graph
|
837
|
+
traversal. If not provided, a new empty dictionary is created. This argument
|
838
|
+
can be used to unflatten a sequence of (graphdef, state_mapping) pairs that
|
839
|
+
share the same index space.
|
840
|
+
index_ref_cache : dict[Index, Any], optional
|
841
|
+
A mapping from indexes to existing nodes that can be reused. When a reference
|
842
|
+
is reused, ``GraphNodeImpl.clear`` is called to leave the object in an empty
|
843
|
+
state and then filled by the unflatten process. As a result, existing graph
|
844
|
+
nodes are mutated to have the new content/topology specified by the graphdef.
|
845
|
+
|
846
|
+
Returns
|
847
|
+
-------
|
848
|
+
Node
|
849
|
+
The reconstructed node.
|
850
|
+
|
851
|
+
Examples
|
852
|
+
--------
|
853
|
+
.. code-block:: python
|
854
|
+
|
855
|
+
>>> import brainstate
|
856
|
+
>>> class MyNode(brainstate.graph.Node):
|
857
|
+
... def __init__(self):
|
858
|
+
... self.a = brainstate.nn.Linear(2, 3)
|
859
|
+
... self.b = brainstate.nn.Linear(3, 4)
|
860
|
+
...
|
861
|
+
>>> # Flatten a node
|
862
|
+
>>> node = MyNode()
|
863
|
+
>>> graphdef, statetree = brainstate.graph.flatten(node)
|
864
|
+
>>>
|
865
|
+
>>> # Unflatten back to node
|
866
|
+
>>> reconstructed_node = brainstate.graph.unflatten(graphdef, statetree)
|
867
|
+
>>> assert isinstance(reconstructed_node, MyNode)
|
868
|
+
>>> assert isinstance(reconstructed_node.a, brainstate.nn.Linear)
|
869
|
+
>>> assert isinstance(reconstructed_node.b, brainstate.nn.Linear)
|
869
870
|
"""
|
870
871
|
index_ref = {} if index_ref is None else index_ref
|
871
872
|
assert isinstance(graph_def, (NodeDef, NodeRef)), f"graph_def must be a NodeDef or NodeRef. But we got: {graph_def}"
|
@@ -874,7 +875,7 @@ def unflatten(
|
|
874
875
|
|
875
876
|
|
876
877
|
def _graph_pop(
|
877
|
-
node:
|
878
|
+
node: Any,
|
878
879
|
id_to_index: dict[int, Index],
|
879
880
|
path_parts: PathParts,
|
880
881
|
flatted_state_dicts: tuple[FlattedDict[PathParts, StateLeaf], ...],
|
@@ -922,61 +923,57 @@ def _graph_pop(
|
|
922
923
|
pass
|
923
924
|
|
924
925
|
|
925
|
-
@overload
|
926
|
-
def pop_states(node, filter1: Filter, /) -> NestedDict:
|
927
|
-
...
|
928
|
-
|
929
|
-
|
930
|
-
@overload
|
931
|
-
def pop_states(node, filter1: Filter, filter2: Filter, /, *filters: Filter) -> tuple[NestedDict, ...]:
|
932
|
-
...
|
933
|
-
|
934
|
-
|
935
926
|
@set_module_as('brainstate.graph')
|
936
927
|
def pop_states(
|
937
|
-
node:
|
938
|
-
|
939
|
-
) -> Union[NestedDict[Key, State], Tuple[NestedDict[Key, State], ...]]:
|
928
|
+
node: Any, *filters: Any
|
929
|
+
) -> Union[NestedDict, Tuple[NestedDict, ...]]:
|
940
930
|
"""
|
941
931
|
Pop one or more :class:`State` types from the graph node.
|
942
932
|
|
943
|
-
|
944
|
-
|
945
|
-
|
946
|
-
|
947
|
-
|
948
|
-
|
949
|
-
|
950
|
-
|
951
|
-
|
952
|
-
|
953
|
-
|
954
|
-
|
955
|
-
|
956
|
-
|
957
|
-
|
958
|
-
|
959
|
-
|
960
|
-
|
961
|
-
|
962
|
-
|
963
|
-
|
964
|
-
|
965
|
-
|
966
|
-
|
967
|
-
|
968
|
-
|
933
|
+
Parameters
|
934
|
+
----------
|
935
|
+
node : Node
|
936
|
+
A graph node object.
|
937
|
+
*filters
|
938
|
+
One or more :class:`State` objects to filter by.
|
939
|
+
|
940
|
+
Returns
|
941
|
+
-------
|
942
|
+
NestedDict or tuple[NestedDict, ...]
|
943
|
+
The popped :class:`NestedDict` containing the :class:`State`
|
944
|
+
objects that were filtered for.
|
945
|
+
|
946
|
+
Examples
|
947
|
+
--------
|
948
|
+
.. code-block:: python
|
949
|
+
|
950
|
+
>>> import brainstate
|
951
|
+
>>> import jax.numpy as jnp
|
952
|
+
|
953
|
+
>>> class Model(brainstate.nn.Module):
|
954
|
+
... def __init__(self):
|
955
|
+
... super().__init__()
|
956
|
+
... self.a = brainstate.nn.Linear(2, 3)
|
957
|
+
... self.b = brainstate.nn.LIF([10, 2])
|
958
|
+
|
959
|
+
>>> model = Model()
|
960
|
+
>>> with brainstate.catch_new_states('new'):
|
961
|
+
... brainstate.nn.init_all_states(model)
|
962
|
+
|
963
|
+
>>> assert len(model.states()) == 2
|
964
|
+
>>> model_states = brainstate.graph.pop_states(model, 'new')
|
965
|
+
>>> model_states # doctest: +SKIP
|
966
|
+
NestedDict({
|
967
|
+
'b': {
|
968
|
+
'V': {
|
969
|
+
'st': ShortTermState(
|
970
|
+
value=Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
|
971
|
+
0., 0., 0.], dtype=float32),
|
972
|
+
tag='new'
|
973
|
+
)
|
974
|
+
}
|
969
975
|
}
|
970
|
-
}
|
971
|
-
})
|
972
|
-
|
973
|
-
Args:
|
974
|
-
node: A graph node object.
|
975
|
-
*filters: One or more :class:`State` objects to filter by.
|
976
|
-
|
977
|
-
Returns:
|
978
|
-
The popped :class:`NestedDict` containing the :class:`State`
|
979
|
-
objects that were filtered for.
|
976
|
+
})
|
980
977
|
"""
|
981
978
|
if len(filters) == 0:
|
982
979
|
raise ValueError('Expected at least one filter')
|
@@ -985,11 +982,13 @@ def pop_states(
|
|
985
982
|
path_parts: PathParts = ()
|
986
983
|
predicates = tuple(to_predicate(filter) for filter in filters)
|
987
984
|
flatted_state_dicts: tuple[FlattedDict[PathParts, StateLeaf], ...] = tuple({} for _ in predicates)
|
988
|
-
_graph_pop(
|
989
|
-
|
990
|
-
|
991
|
-
|
992
|
-
|
985
|
+
_graph_pop(
|
986
|
+
node=node,
|
987
|
+
id_to_index=id_to_index,
|
988
|
+
path_parts=path_parts,
|
989
|
+
flatted_state_dicts=flatted_state_dicts,
|
990
|
+
predicates=predicates,
|
991
|
+
)
|
993
992
|
states = tuple(NestedDict.from_flat(flat_state) for flat_state in flatted_state_dicts)
|
994
993
|
|
995
994
|
if len(states) == 1:
|
@@ -1011,94 +1010,49 @@ def _split_state(
|
|
1011
1010
|
return states # type: ignore[return-value]
|
1012
1011
|
|
1013
1012
|
|
1014
|
-
@overload
|
1015
|
-
def treefy_split(node: A, /) -> Tuple[GraphDef, NestedDict]:
|
1016
|
-
...
|
1017
|
-
|
1018
|
-
|
1019
|
-
@overload
|
1020
|
-
def treefy_split(node: A, first: Filter, /) -> Tuple[GraphDef, NestedDict]:
|
1021
|
-
...
|
1022
|
-
|
1023
|
-
|
1024
|
-
@overload
|
1025
|
-
def treefy_split(node: A, first: Filter, second: Filter, /) -> Tuple[GraphDef, NestedDict, NestedDict]:
|
1026
|
-
...
|
1027
|
-
|
1028
|
-
|
1029
|
-
@overload
|
1030
|
-
def treefy_split(
|
1031
|
-
node: A, first: Filter, second: Filter, /, *filters: Filter,
|
1032
|
-
) -> Tuple[GraphDef, NestedDict, Unpack[Tuple[NestedDict, ...]]]:
|
1033
|
-
...
|
1034
|
-
|
1035
|
-
|
1036
1013
|
@set_module_as('brainstate.graph')
|
1037
1014
|
def treefy_split(
|
1038
|
-
node: A,
|
1039
|
-
|
1040
|
-
|
1041
|
-
|
1042
|
-
a ``Mapping`` from strings or integers to ``Variables``, Arrays or nested States. GraphDef
|
1043
|
-
contains all the static information needed to reconstruct a ``Module`` graph, it is analogous
|
1044
|
-
to JAX’s ``PyTreeDef``. :func:`split` is used in conjunction with :func:`merge` to switch
|
1045
|
-
seamlessly between stateful and stateless representations of the graph.
|
1046
|
-
|
1047
|
-
Example usage::
|
1048
|
-
|
1049
|
-
>>> from joblib.testing import param >>> import brainstate as brainstate
|
1050
|
-
>>> import jax, jax.numpy as jnp
|
1051
|
-
...
|
1052
|
-
>>> class Foo(brainstate.graph.Node):
|
1053
|
-
... def __init__(self):
|
1054
|
-
... self.a = brainstate.nn.BatchNorm1d([10, 2])
|
1055
|
-
... self.b = brainstate.nn.Linear(2, 3)
|
1056
|
-
...
|
1057
|
-
>>> node = Foo()
|
1058
|
-
>>> graphdef, params, others = brainstate.graph.treefy_split(node, brainstate.ParamState, ...)
|
1059
|
-
...
|
1060
|
-
>>> params
|
1061
|
-
NestedDict({
|
1062
|
-
'a': {
|
1063
|
-
'weight': TreefyState(
|
1064
|
-
type=ParamState,
|
1065
|
-
value={'weight': Array([[-1.0013659, 1.5763807],
|
1066
|
-
[ 1.7149199, 2.0140953]], dtype=float32), 'bias': Array([0., 0.], dtype=float32)}
|
1067
|
-
)
|
1068
|
-
},
|
1069
|
-
'b': {
|
1070
|
-
'weight': TreefyState(
|
1071
|
-
type=ParamState,
|
1072
|
-
value={'bias': Array([[0., 0.]], dtype=float32), 'scale': Array([[1., 1.]], dtype=float32)}
|
1073
|
-
)
|
1074
|
-
}
|
1075
|
-
})
|
1076
|
-
>>> jax.tree.map(jnp.shape, others)
|
1077
|
-
NestedDict({
|
1078
|
-
'b': {
|
1079
|
-
'running_mean': TreefyState(
|
1080
|
-
type=LongTermState,
|
1081
|
-
value=(1, 2)
|
1082
|
-
),
|
1083
|
-
'running_var': TreefyState(
|
1084
|
-
type=LongTermState,
|
1085
|
-
value=(1, 2)
|
1086
|
-
)
|
1087
|
-
}
|
1088
|
-
})
|
1089
|
-
|
1090
|
-
:func:`split` and :func:`merge` are primarily used to interact directly with JAX
|
1091
|
-
transformations, see
|
1092
|
-
`Functional API <https://flax.readthedocs.io/en/latest/nnx/nnx_basics.html#the-functional-api>`__
|
1093
|
-
for more information.
|
1094
|
-
|
1095
|
-
Arguments:
|
1096
|
-
node: graph node to split.
|
1097
|
-
*filters: some optional filters to group the state into mutually exclusive substates.
|
1015
|
+
node: A, *filters: Filter
|
1016
|
+
):
|
1017
|
+
"""
|
1018
|
+
Split a graph node into a :class:`GraphDef` and one or more :class:`NestedDict`s.
|
1098
1019
|
|
1099
|
-
|
1100
|
-
|
1101
|
-
|
1020
|
+
NestedDict is a ``Mapping`` from strings or integers to ``Variables``, Arrays or nested States.
|
1021
|
+
GraphDef contains all the static information needed to reconstruct a ``Module`` graph, it is
|
1022
|
+
analogous to JAX's ``PyTreeDef``. :func:`split` is used in conjunction with :func:`merge` to
|
1023
|
+
switch seamlessly between stateful and stateless representations of the graph.
|
1024
|
+
|
1025
|
+
Parameters
|
1026
|
+
----------
|
1027
|
+
node : A
|
1028
|
+
Graph node to split.
|
1029
|
+
*filters
|
1030
|
+
Optional filters to group the state into mutually exclusive substates.
|
1031
|
+
|
1032
|
+
Returns
|
1033
|
+
-------
|
1034
|
+
tuple
|
1035
|
+
``GraphDef`` and one or more ``States`` equal to the number of filters passed.
|
1036
|
+
If no filters are passed, a single ``NestedDict`` is returned.
|
1037
|
+
|
1038
|
+
Examples
|
1039
|
+
--------
|
1040
|
+
.. code-block:: python
|
1041
|
+
|
1042
|
+
>>> import brainstate
|
1043
|
+
>>> import jax, jax.numpy as jnp
|
1044
|
+
|
1045
|
+
>>> class Foo(brainstate.graph.Node):
|
1046
|
+
... def __init__(self):
|
1047
|
+
... self.a = brainstate.nn.BatchNorm1d([10, 2])
|
1048
|
+
... self.b = brainstate.nn.Linear(2, 3)
|
1049
|
+
...
|
1050
|
+
>>> node = Foo()
|
1051
|
+
>>> graphdef, params, others = brainstate.graph.treefy_split(
|
1052
|
+
... node, brainstate.ParamState, ...
|
1053
|
+
... )
|
1054
|
+
>>> # params contains ParamState variables
|
1055
|
+
>>> # others contains all other state variables
|
1102
1056
|
"""
|
1103
1057
|
graphdef, state_tree = flatten(node)
|
1104
1058
|
states = tuple(_split_state(state_tree, filters))
|
@@ -1106,49 +1060,47 @@ def treefy_split(
|
|
1106
1060
|
|
1107
1061
|
|
1108
1062
|
@set_module_as('brainstate.graph')
|
1109
|
-
def treefy_merge(
|
1110
|
-
|
1111
|
-
|
1112
|
-
/,
|
1113
|
-
*state_mappings: GraphStateMapping,
|
1114
|
-
) -> A:
|
1115
|
-
"""The inverse of :func:`split`.
|
1063
|
+
def treefy_merge(graphdef: GraphDef[A], *state_mappings) -> A:
|
1064
|
+
"""
|
1065
|
+
The inverse of :func:`split`.
|
1116
1066
|
|
1117
1067
|
``merge`` takes a :class:`GraphDef` and one or more :class:`NestedDict`'s and creates
|
1118
1068
|
a new node with the same structure as the original node.
|
1119
1069
|
|
1120
|
-
|
1121
|
-
|
1122
|
-
|
1123
|
-
|
1124
|
-
|
1125
|
-
|
1126
|
-
|
1127
|
-
|
1128
|
-
|
1129
|
-
|
1130
|
-
|
1131
|
-
|
1132
|
-
|
1133
|
-
|
1134
|
-
|
1135
|
-
|
1136
|
-
|
1137
|
-
|
1138
|
-
|
1139
|
-
|
1140
|
-
|
1141
|
-
|
1142
|
-
|
1143
|
-
|
1144
|
-
|
1145
|
-
|
1146
|
-
|
1147
|
-
|
1148
|
-
|
1149
|
-
|
1070
|
+
Parameters
|
1071
|
+
----------
|
1072
|
+
graphdef : GraphDef[A]
|
1073
|
+
A :class:`GraphDef` object.
|
1074
|
+
*state_mappings
|
1075
|
+
Additional :class:`NestedDict` objects.
|
1076
|
+
|
1077
|
+
Returns
|
1078
|
+
-------
|
1079
|
+
A
|
1080
|
+
The merged :class:`Module`.
|
1081
|
+
|
1082
|
+
Examples
|
1083
|
+
--------
|
1084
|
+
.. code-block:: python
|
1085
|
+
|
1086
|
+
>>> import brainstate
|
1087
|
+
>>> import jax, jax.numpy as jnp
|
1088
|
+
|
1089
|
+
>>> class Foo(brainstate.graph.Node):
|
1090
|
+
... def __init__(self):
|
1091
|
+
... self.a = brainstate.nn.BatchNorm1d([10, 2])
|
1092
|
+
... self.b = brainstate.nn.Linear(2, 3)
|
1093
|
+
...
|
1094
|
+
>>> node = Foo()
|
1095
|
+
>>> graphdef, params, others = brainstate.graph.treefy_split(
|
1096
|
+
... node, brainstate.ParamState, ...
|
1097
|
+
... )
|
1098
|
+
>>> new_node = brainstate.graph.treefy_merge(graphdef, params, others)
|
1099
|
+
>>> assert isinstance(new_node, Foo)
|
1100
|
+
>>> assert isinstance(new_node.a, brainstate.nn.BatchNorm1d)
|
1101
|
+
>>> assert isinstance(new_node.b, brainstate.nn.Linear)
|
1150
1102
|
"""
|
1151
|
-
state_mapping = GraphStateMapping.merge(
|
1103
|
+
state_mapping = GraphStateMapping.merge(*state_mappings)
|
1152
1104
|
node = unflatten(graphdef, state_mapping)
|
1153
1105
|
return node
|
1154
1106
|
|
@@ -1186,31 +1138,27 @@ def _split_flatted(
|
|
1186
1138
|
return flat_states
|
1187
1139
|
|
1188
1140
|
|
1189
|
-
@overload
|
1190
|
-
def nodes(node, /, allowed_hierarchy=(0, _max_int)) -> FlattedDict[Key, Node]:
|
1191
|
-
...
|
1192
|
-
|
1193
|
-
|
1194
|
-
@overload
|
1195
|
-
def nodes(node, first: Filter, /, allowed_hierarchy=(0, _max_int)) -> FlattedDict[Key, Node]:
|
1196
|
-
...
|
1197
|
-
|
1198
|
-
|
1199
|
-
@overload
|
1200
|
-
def nodes(
|
1201
|
-
node, first: Filter, second: Filter, /, *filters: Filter, allowed_hierarchy=(0, _max_int)
|
1202
|
-
) -> Tuple[FlattedDict[Key, Node], ...]:
|
1203
|
-
...
|
1204
|
-
|
1205
|
-
|
1206
1141
|
@set_module_as('brainstate.graph')
|
1207
1142
|
def nodes(
|
1208
|
-
node,
|
1209
|
-
|
1210
|
-
allowed_hierarchy: Tuple[int, int] = (0, _max_int)
|
1211
|
-
) -> Union[FlattedDict[Key, Node], Tuple[FlattedDict[Key, Node], ...]]:
|
1143
|
+
node, *filters: Filter, allowed_hierarchy: Tuple[int, int] = (0, MAX_INT)
|
1144
|
+
):
|
1212
1145
|
"""
|
1213
1146
|
Similar to :func:`split` but only returns the :class:`FlattedDict` objects indicated by the filters.
|
1147
|
+
|
1148
|
+
Parameters
|
1149
|
+
----------
|
1150
|
+
node : Node
|
1151
|
+
The node to get nodes from.
|
1152
|
+
*filters
|
1153
|
+
Filters to apply to the nodes.
|
1154
|
+
allowed_hierarchy : tuple[int, int], optional
|
1155
|
+
The allowed hierarchy levels, by default (0, MAX_INT).
|
1156
|
+
|
1157
|
+
Returns
|
1158
|
+
-------
|
1159
|
+
FlattedDict or tuple[FlattedDict, ...]
|
1160
|
+
The filtered nodes.
|
1161
|
+
|
1214
1162
|
"""
|
1215
1163
|
num_filters = len(filters)
|
1216
1164
|
if num_filters == 0:
|
@@ -1232,31 +1180,27 @@ def _states_generator(node, allowed_hierarchy) -> Iterable[Tuple[PathParts, Stat
|
|
1232
1180
|
yield path, value
|
1233
1181
|
|
1234
1182
|
|
1235
|
-
@overload
|
1236
|
-
def states(node, /, allowed_hierarchy=(0, _max_int)) -> FlattedDict[Key, State]:
|
1237
|
-
...
|
1238
|
-
|
1239
|
-
|
1240
|
-
@overload
|
1241
|
-
def states(node, first: Filter, /, allowed_hierarchy=(0, _max_int)) -> FlattedDict[Key, State]:
|
1242
|
-
...
|
1243
|
-
|
1244
|
-
|
1245
|
-
@overload
|
1246
|
-
def states(
|
1247
|
-
node, first: Filter, second: Filter, /, *filters: Filter, allowed_hierarchy=(0, _max_int)
|
1248
|
-
) -> tuple[FlattedDict[Key, State], ...]:
|
1249
|
-
...
|
1250
|
-
|
1251
|
-
|
1252
1183
|
@set_module_as('brainstate.graph')
|
1253
1184
|
def states(
|
1254
|
-
node,
|
1255
|
-
|
1256
|
-
allowed_hierarchy: Tuple[int, int] = (0, _max_int)
|
1257
|
-
) -> Union[FlattedDict[Key, State], tuple[FlattedDict[Key, State], ...]]:
|
1185
|
+
node, *filters: Filter, allowed_hierarchy: Tuple[int, int] = (0, MAX_INT)
|
1186
|
+
) -> Union[FlattedDict, tuple[FlattedDict, ...]]:
|
1258
1187
|
"""
|
1259
1188
|
Similar to :func:`split` but only returns the :class:`FlattedDict` objects indicated by the filters.
|
1189
|
+
|
1190
|
+
Parameters
|
1191
|
+
----------
|
1192
|
+
node : Node
|
1193
|
+
The node to get states from.
|
1194
|
+
*filters
|
1195
|
+
Filters to apply to the states.
|
1196
|
+
allowed_hierarchy : tuple[int, int], optional
|
1197
|
+
The allowed hierarchy levels, by default (0, MAX_INT).
|
1198
|
+
|
1199
|
+
Returns
|
1200
|
+
-------
|
1201
|
+
FlattedDict or tuple[FlattedDict, ...]
|
1202
|
+
The filtered states.
|
1203
|
+
|
1260
1204
|
"""
|
1261
1205
|
num_filters = len(filters)
|
1262
1206
|
if num_filters == 0:
|
@@ -1272,72 +1216,60 @@ def states(
|
|
1272
1216
|
return state_maps[:num_filters]
|
1273
1217
|
|
1274
1218
|
|
1275
|
-
@overload
|
1276
|
-
def treefy_states(
|
1277
|
-
node, /, flatted: bool = False
|
1278
|
-
) -> NestedDict[Key, TreefyState]:
|
1279
|
-
...
|
1280
|
-
|
1281
|
-
|
1282
|
-
@overload
|
1283
|
-
def treefy_states(
|
1284
|
-
node, first: Filter, /, flatted: bool = False
|
1285
|
-
) -> NestedDict[Key, TreefyState]:
|
1286
|
-
...
|
1287
|
-
|
1288
|
-
|
1289
|
-
@overload
|
1290
|
-
def treefy_states(
|
1291
|
-
node, first: Filter, second: Filter, /, *filters: Filter, flatted: bool = False
|
1292
|
-
) -> Tuple[NestedDict[Key, TreefyState], ...]:
|
1293
|
-
...
|
1294
|
-
|
1295
|
-
|
1296
1219
|
@set_module_as('brainstate.graph')
|
1297
1220
|
def treefy_states(
|
1298
1221
|
node, *filters,
|
1299
|
-
)
|
1222
|
+
):
|
1300
1223
|
"""
|
1301
1224
|
Similar to :func:`split` but only returns the :class:`NestedDict`'s indicated by the filters.
|
1302
1225
|
|
1303
|
-
|
1304
|
-
|
1305
|
-
|
1306
|
-
|
1307
|
-
|
1308
|
-
|
1309
|
-
|
1310
|
-
|
1311
|
-
|
1312
|
-
|
1313
|
-
|
1314
|
-
|
1315
|
-
|
1316
|
-
|
1317
|
-
|
1318
|
-
|
1319
|
-
|
1320
|
-
|
1321
|
-
|
1322
|
-
|
1323
|
-
|
1324
|
-
|
1325
|
-
|
1326
|
-
|
1327
|
-
|
1226
|
+
Parameters
|
1227
|
+
----------
|
1228
|
+
node : Node
|
1229
|
+
A graph node object.
|
1230
|
+
*filters
|
1231
|
+
One or more :class:`State` objects to filter by.
|
1232
|
+
|
1233
|
+
Returns
|
1234
|
+
-------
|
1235
|
+
NestedDict or tuple of NestedDict
|
1236
|
+
One or more :class:`NestedDict` mappings.
|
1237
|
+
|
1238
|
+
Examples
|
1239
|
+
--------
|
1240
|
+
.. code-block:: python
|
1241
|
+
|
1242
|
+
>>> import brainstate
|
1243
|
+
>>> class Model(brainstate.nn.Module):
|
1244
|
+
... def __init__(self):
|
1245
|
+
... super().__init__()
|
1246
|
+
... self.l1 = brainstate.nn.Linear(2, 3)
|
1247
|
+
... self.l2 = brainstate.nn.Linear(3, 4)
|
1248
|
+
... def __call__(self, x):
|
1249
|
+
... return self.l2(self.l1(x))
|
1250
|
+
|
1251
|
+
>>> model = Model()
|
1252
|
+
>>> # Get the learnable parameters
|
1253
|
+
>>> params = brainstate.graph.treefy_states(model, brainstate.ParamState)
|
1254
|
+
>>> # Get them separately
|
1255
|
+
>>> params, others = brainstate.graph.treefy_states(
|
1256
|
+
... model, brainstate.ParamState, brainstate.ShortTermState
|
1257
|
+
... )
|
1258
|
+
>>> # Get all states together
|
1259
|
+
>>> states = brainstate.graph.treefy_states(model)
|
1328
1260
|
"""
|
1329
1261
|
_, state_mapping = flatten(node)
|
1330
|
-
state_mappings: GraphStateMapping | tuple[GraphStateMapping, ...]
|
1331
1262
|
if len(filters) == 0:
|
1332
|
-
|
1333
|
-
elif len(filters) == 1:
|
1334
|
-
state_mappings = state_mapping.filter(filters[0])
|
1263
|
+
return state_mapping
|
1335
1264
|
else:
|
1336
|
-
state_mappings = state_mapping.filter(
|
1337
|
-
|
1265
|
+
state_mappings = state_mapping.filter(*filters)
|
1266
|
+
if len(filters) == 1:
|
1267
|
+
return state_mappings[0]
|
1268
|
+
else:
|
1269
|
+
return state_mappings
|
1338
1270
|
|
1339
1271
|
|
1340
|
-
def _graph_update_dynamic(node: Any, state: Mapping
|
1272
|
+
def _graph_update_dynamic(node: Any, state: Mapping) -> None:
|
1341
1273
|
if not _is_node(node):
|
1342
1274
|
raise RuntimeError(f'Unsupported type: {type(node)}')
|
1343
1275
|
|
@@ -1350,7 +1282,8 @@ def _graph_update_dynamic(node: Any, state: Mapping[Key, Any]):
|
|
1350
1282
|
raise ValueError(f'Cannot set key {key!r} on immutable node of '
|
1351
1283
|
f'type {type(node).__name__}')
|
1352
1284
|
if isinstance(value, State):
|
1353
|
-
|
1285
|
+
# TODO: should this raise an error? We should check whether the state already belongs to another node.
|
1286
|
+
value = value.to_state_ref() # Convert to state reference for proper state management
|
1354
1287
|
node_impl.set_key(node, key, value)
|
1355
1288
|
continue
|
1356
1289
|
|
@@ -1379,18 +1312,23 @@ def _graph_update_dynamic(node: Any, state: Mapping[Key, Any]):
|
|
1379
1312
|
|
1380
1313
|
|
1381
1314
|
def update_states(
|
1382
|
-
node:
|
1383
|
-
state_dict: NestedDict
|
1315
|
+
node: Any,
|
1316
|
+
state_dict: Union[NestedDict, FlattedDict],
|
1384
1317
|
/,
|
1385
|
-
*state_dicts: NestedDict
|
1318
|
+
*state_dicts: Union[NestedDict, FlattedDict]
|
1386
1319
|
) -> None:
|
1387
1320
|
"""
|
1388
1321
|
Update the given graph node with a new :class:`NestedMapping` in-place.
|
1389
1322
|
|
1390
|
-
|
1391
|
-
|
1392
|
-
|
1393
|
-
|
1323
|
+
Parameters
|
1324
|
+
----------
|
1325
|
+
node : Node
|
1326
|
+
A graph node to update.
|
1327
|
+
state_dict : NestedDict | FlattedDict
|
1328
|
+
A :class:`NestedMapping` object.
|
1329
|
+
*state_dicts : NestedDict | FlattedDict
|
1330
|
+
Additional :class:`NestedMapping` objects.
|
1331
|
+
|
1394
1332
|
"""
|
1395
1333
|
if state_dicts:
|
1396
1334
|
state_dict = NestedDict.merge(state_dict, *state_dicts)
|
@@ -1398,177 +1336,110 @@ def update_states(
|
|
1398
1336
|
|
1399
1337
|
|
1400
1338
|
@set_module_as('brainstate.graph')
|
1401
|
-
def graphdef(node: Any
|
1402
|
-
"""
|
1339
|
+
def graphdef(node: Any) -> GraphDef[Any]:
|
1340
|
+
"""
|
1341
|
+
Get the :class:`GraphDef` of the given graph node.
|
1342
|
+
|
1343
|
+
Parameters
|
1344
|
+
----------
|
1345
|
+
node : Any
|
1346
|
+
A graph node object.
|
1403
1347
|
|
1404
|
-
|
1348
|
+
Returns
|
1349
|
+
-------
|
1350
|
+
GraphDef[Any]
|
1351
|
+
The :class:`GraphDef` of the :class:`Module` object.
|
1405
1352
|
|
1406
|
-
|
1353
|
+
Examples
|
1354
|
+
--------
|
1355
|
+
.. code-block:: python
|
1407
1356
|
|
1408
|
-
|
1409
|
-
>>> graphdef, _ = brainstate.graph.treefy_split(model)
|
1410
|
-
>>> assert graphdef == brainstate.graph.graphdef(model)
|
1357
|
+
>>> import brainstate
|
1411
1358
|
|
1412
|
-
|
1413
|
-
|
1359
|
+
>>> model = brainstate.nn.Linear(2, 3)
|
1360
|
+
>>> graphdef, _ = brainstate.graph.treefy_split(model)
|
1361
|
+
>>> assert graphdef == brainstate.graph.graphdef(model)
|
1414
1362
|
|
1415
|
-
Returns:
|
1416
|
-
The :class:`GraphDef` of the :class:`Module` object.
|
1417
1363
|
"""
|
1418
1364
|
graphdef, _ = flatten(node)
|
1419
1365
|
return graphdef
|
1420
1366
|
|
1421
1367
|
|
1422
1368
|
@set_module_as('brainstate.graph')
|
1423
|
-
def clone(node:
|
1369
|
+
def clone(node: A) -> A:
|
1424
1370
|
"""
|
1425
1371
|
Create a deep copy of the given graph node.
|
1426
1372
|
|
1427
|
-
|
1373
|
+
Parameters
|
1374
|
+
----------
|
1375
|
+
node : Node
|
1376
|
+
A graph node object.
|
1428
1377
|
|
1429
|
-
|
1430
|
-
|
1431
|
-
|
1432
|
-
|
1433
|
-
>>> assert (model.weight.value['bias'] != cloned_model.weight.value['bias']).all()
|
1378
|
+
Returns
|
1379
|
+
-------
|
1380
|
+
Node
|
1381
|
+
A deep copy of the :class:`Module` object.
|
1434
1382
|
|
1435
|
-
|
1436
|
-
|
1383
|
+
Examples
|
1384
|
+
--------
|
1385
|
+
.. code-block:: python
|
1386
|
+
|
1387
|
+
>>> import brainstate
|
1388
|
+
>>> model = brainstate.nn.Linear(2, 3)
|
1389
|
+
>>> cloned_model = brainstate.graph.clone(model)
|
1390
|
+
>>> model.weight.value['bias'] += 1
|
1391
|
+
>>> assert (model.weight.value['bias'] != cloned_model.weight.value['bias']).all()
|
1437
1392
|
|
1438
|
-
Returns:
|
1439
|
-
A deep copy of the :class:`Module` object.
|
1440
1393
|
"""
|
1441
1394
|
graphdef, state = treefy_split(node)
|
1442
1395
|
return treefy_merge(graphdef, state)
|
1443
1396
|
|
1444
1397
|
|
1445
|
-
@set_module_as('brainstate.graph')
|
1446
|
-
def call(
|
1447
|
-
graphdef_state: Tuple[GraphDef[A], GraphStateMapping],
|
1448
|
-
) -> ApplyCaller[Tuple[GraphDef[A], GraphStateMapping]]:
|
1449
|
-
"""Calls a method underlying graph node defined by a (GraphDef, NestedDict) pair.
|
1450
|
-
|
1451
|
-
``call`` takes a ``(GraphDef, NestedDict)`` pair and creates a proxy object that can be
|
1452
|
-
used to call methods on the underlying graph node. When a method is called, the
|
1453
|
-
output is returned along with a new (GraphDef, NestedDict) pair that represents the
|
1454
|
-
updated state of the graph node. ``call`` is equivalent to :func:`merge` > ``method``
|
1455
|
-
> :func:`split`` but is more convenient to use in pure JAX functions.
|
1456
|
-
|
1457
|
-
Example::
|
1458
|
-
|
1459
|
-
>>> import brainstate as brainstate
|
1460
|
-
>>> import jax
|
1461
|
-
>>> import jax.numpy as jnp
|
1462
|
-
...
|
1463
|
-
>>> class StatefulLinear(brainstate.graph.Node):
|
1464
|
-
... def __init__(self, din, dout):
|
1465
|
-
... self.w = brainstate.ParamState(brainstate.random.rand(din, dout))
|
1466
|
-
... self.b = brainstate.ParamState(jnp.zeros((dout,)))
|
1467
|
-
... self.count = brainstate.State(jnp.array(0, dtype=jnp.uint32))
|
1468
|
-
...
|
1469
|
-
... def increment(self):
|
1470
|
-
... self.count.value += 1
|
1471
|
-
...
|
1472
|
-
... def __call__(self, x):
|
1473
|
-
... self.increment()
|
1474
|
-
... return x @ self.w.value + self.b.value
|
1475
|
-
...
|
1476
|
-
>>> linear = StatefulLinear(3, 2)
|
1477
|
-
>>> linear_state = brainstate.graph.treefy_split(linear)
|
1478
|
-
...
|
1479
|
-
>>> @jax.jit
|
1480
|
-
... def forward(x, linear_state):
|
1481
|
-
... y, linear_state = brainstate.graph.call(linear_state)(x)
|
1482
|
-
... return y, linear_state
|
1483
|
-
...
|
1484
|
-
>>> x = jnp.ones((1, 3))
|
1485
|
-
>>> y, linear_state = forward(x, linear_state)
|
1486
|
-
>>> y, linear_state = forward(x, linear_state)
|
1487
|
-
...
|
1488
|
-
>>> linear = brainstate.graph.treefy_merge(*linear_state)
|
1489
|
-
>>> linear.count.value
|
1490
|
-
Array(2, dtype=uint32)
|
1491
|
-
|
1492
|
-
The proxy object returned by ``call`` supports indexing and attribute access
|
1493
|
-
to access nested methods. In the example below, the ``increment`` method indexing
|
1494
|
-
is used to call the ``increment`` method of the ``StatefulLinear`` module
|
1495
|
-
at the ``b`` key of a ``nodes`` dictionary.
|
1496
|
-
|
1497
|
-
>>> class StatefulLinear(brainstate.graph.Node):
|
1498
|
-
... def __init__(self, din, dout):
|
1499
|
-
... self.w = brainstate.ParamState(brainstate.random.rand(din, dout))
|
1500
|
-
... self.b = brainstate.ParamState(jnp.zeros((dout,)))
|
1501
|
-
... self.count = brainstate.State(jnp.array(0, dtype=jnp.uint32))
|
1502
|
-
...
|
1503
|
-
... def increment(self):
|
1504
|
-
... self.count.value += 1
|
1505
|
-
...
|
1506
|
-
... def __call__(self, x):
|
1507
|
-
... self.increment()
|
1508
|
-
... return x @ self.w.value + self.b.value
|
1509
|
-
...
|
1510
|
-
>>> nodes = dict(
|
1511
|
-
... a=StatefulLinear(3, 2),
|
1512
|
-
... b=StatefulLinear(2, 1),
|
1513
|
-
... )
|
1514
|
-
...
|
1515
|
-
>>> node_state = treefy_split(nodes)
|
1516
|
-
>>> # use attribute access
|
1517
|
-
>>> _, node_state = brainstate.graph.call(node_state)['b'].increment()
|
1518
|
-
...
|
1519
|
-
>>> nodes = treefy_merge(*node_state)
|
1520
|
-
>>> nodes['a'].count.value
|
1521
|
-
Array(0, dtype=uint32)
|
1522
|
-
>>> nodes['b'].count.value
|
1523
|
-
Array(1, dtype=uint32)
|
1524
|
-
"""
|
1525
|
-
|
1526
|
-
def pure_caller(accessor: DelayedAccessor, *args, **kwargs):
|
1527
|
-
node = treefy_merge(*graphdef_state)
|
1528
|
-
method = accessor(node)
|
1529
|
-
out = method(*args, **kwargs)
|
1530
|
-
return out, treefy_split(node)
|
1531
|
-
|
1532
|
-
return CallableProxy(pure_caller) # type: ignore
|
1533
|
-
|
1534
|
-
|
1535
1398
|
@set_module_as('brainstate.graph')
|
1536
1399
|
def iter_leaf(
|
1537
|
-
node: Any,
|
1538
|
-
allowed_hierarchy: Tuple[int, int] = (0, _max_int)
|
1400
|
+
node: Any, allowed_hierarchy: Tuple[int, int] = (0, MAX_INT)
|
1539
1401
|
) -> Iterator[tuple[PathParts, Any]]:
|
1540
|
-
"""
|
1402
|
+
"""
|
1403
|
+
Iterates over all nested leaves in the given graph node, including the current node.
|
1541
1404
|
|
1542
1405
|
``iter_graph`` creates a generator that yields path and value pairs, where
|
1543
1406
|
the path is a tuple of strings or integers representing the path to the value from the
|
1544
1407
|
root. Repeated nodes are visited only once. Leaves include static values.
|
1545
1408
|
|
1546
|
-
Example::
|
1547
|
-
>>> import brainstate as brainstate
|
1548
|
-
>>> import jax.numpy as jnp
|
1549
|
-
...
|
1550
|
-
>>> class Linear(brainstate.nn.Module):
|
1551
|
-
... def __init__(self, din, dout):
|
1552
|
-
... super().__init__()
|
1553
|
-
... self.weight = brainstate.ParamState(brainstate.random.randn(din, dout))
|
1554
|
-
... self.bias = brainstate.ParamState(brainstate.random.randn(dout))
|
1555
|
-
... self.a = 1
|
1556
|
-
...
|
1557
|
-
>>> module = Linear(3, 4)
|
1558
|
-
...
|
1559
|
-
>>> for path, value in brainstate.graph.iter_leaf([module, module]):
|
1560
|
-
... print(path, type(value).__name__)
|
1561
|
-
...
|
1562
|
-
(0, 'a') int
|
1563
|
-
(0, 'bias') ParamState
|
1564
|
-
(0, 'weight') ParamState
|
1565
|
-
|
1566
1409
|
Parameters
|
1567
1410
|
----------
|
1568
|
-
node:
|
1569
|
-
|
1570
|
-
allowed_hierarchy: tuple
|
1571
|
-
|
1411
|
+
node : Any
|
1412
|
+
The node to iterate over.
|
1413
|
+
allowed_hierarchy : tuple[int, int], optional
|
1414
|
+
The allowed hierarchy levels, by default (0, MAX_INT).
|
1415
|
+
|
1416
|
+
Yields
|
1417
|
+
------
|
1418
|
+
Iterator[tuple[PathParts, Any]]
|
1419
|
+
Path and value pairs.
|
1420
|
+
|
1421
|
+
Examples
|
1422
|
+
--------
|
1423
|
+
.. code-block:: python
|
1424
|
+
|
1425
|
+
>>> import brainstate
|
1426
|
+
>>> import jax.numpy as jnp
|
1427
|
+
|
1428
|
+
>>> class Linear(brainstate.nn.Module):
|
1429
|
+
... def __init__(self, din, dout):
|
1430
|
+
... super().__init__()
|
1431
|
+
... self.weight = brainstate.ParamState(brainstate.random.randn(din, dout))
|
1432
|
+
... self.bias = brainstate.ParamState(brainstate.random.randn(dout))
|
1433
|
+
... self.a = 1
|
1434
|
+
...
|
1435
|
+
>>> module = Linear(3, 4)
|
1436
|
+
...
|
1437
|
+
>>> for path, value in brainstate.graph.iter_leaf([module, module]):
|
1438
|
+
... print(path, type(value).__name__)
|
1439
|
+
...
|
1440
|
+
(0, 'a') int
|
1441
|
+
(0, 'bias') ParamState
|
1442
|
+
(0, 'weight') ParamState
|
1572
1443
|
|
1573
1444
|
"""
|
1574
1445
|
|
@@ -1605,8 +1476,7 @@ def iter_leaf(
|
|
1605
1476
|
|
1606
1477
|
@set_module_as('brainstate.graph')
|
1607
1478
|
def iter_node(
|
1608
|
-
node: Any,
|
1609
|
-
allowed_hierarchy: Tuple[int, int] = (0, _max_int)
|
1479
|
+
node: Any, allowed_hierarchy: Tuple[int, int] = (0, MAX_INT)
|
1610
1480
|
) -> Iterator[Tuple[PathParts, Any]]:
|
1611
1481
|
"""
|
1612
1482
|
Iterates over all nested nodes of the given graph node, including the current node.
|
@@ -1615,39 +1485,47 @@ def iter_node(
|
|
1615
1485
|
the path is a tuple of strings or integers representing the path to the value from the
|
1616
1486
|
root. Repeated nodes are visited only once. Leaves include static values.
|
1617
1487
|
|
1618
|
-
Example::
|
1619
|
-
>>> import brainstate as brainstate
|
1620
|
-
>>> import jax.numpy as jnp
|
1621
|
-
...
|
1622
|
-
>>> class Model(brainstate.nn.Module):
|
1623
|
-
... def __init__(self):
|
1624
|
-
... super().__init__()
|
1625
|
-
... self.a = brainstate.nn.Linear(1, 2)
|
1626
|
-
... self.b = brainstate.nn.Linear(2, 3)
|
1627
|
-
... self.c = [brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5)]
|
1628
|
-
... self.d = {'x': brainstate.nn.Linear(5, 6), 'y': brainstate.nn.Linear(6, 7)}
|
1629
|
-
... self.b.a = brainstate.nn.LIF(2)
|
1630
|
-
...
|
1631
|
-
>>> model = Model()
|
1632
|
-
...
|
1633
|
-
>>> for path, node in brainstate.graph.iter_node([model, model]):
|
1634
|
-
... print(path, node.__class__.__name__)
|
1635
|
-
...
|
1636
|
-
(0, 'a') Linear
|
1637
|
-
(0, 'b', 'a') LIF
|
1638
|
-
(0, 'b') Linear
|
1639
|
-
(0, 'c', 0) Linear
|
1640
|
-
(0, 'c', 1) Linear
|
1641
|
-
(0, 'd', 'x') Linear
|
1642
|
-
(0, 'd', 'y') Linear
|
1643
|
-
(0,) Model
|
1644
|
-
|
1645
1488
|
Parameters
|
1646
1489
|
----------
|
1647
|
-
node:
|
1648
|
-
|
1649
|
-
allowed_hierarchy: tuple
|
1650
|
-
|
1490
|
+
node : Any
|
1491
|
+
The node to iterate over.
|
1492
|
+
allowed_hierarchy : tuple[int, int], optional
|
1493
|
+
The allowed hierarchy levels, by default (0, MAX_INT).
|
1494
|
+
|
1495
|
+
Yields
|
1496
|
+
------
|
1497
|
+
Iterator[tuple[PathParts, Any]]
|
1498
|
+
Path and node pairs.
|
1499
|
+
|
1500
|
+
Examples
|
1501
|
+
--------
|
1502
|
+
.. code-block:: python
|
1503
|
+
|
1504
|
+
>>> import brainstate
|
1505
|
+
>>> import jax.numpy as jnp
|
1506
|
+
|
1507
|
+
>>> class Model(brainstate.nn.Module):
|
1508
|
+
... def __init__(self):
|
1509
|
+
... super().__init__()
|
1510
|
+
... self.a = brainstate.nn.Linear(1, 2)
|
1511
|
+
... self.b = brainstate.nn.Linear(2, 3)
|
1512
|
+
... self.c = [brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5)]
|
1513
|
+
... self.d = {'x': brainstate.nn.Linear(5, 6), 'y': brainstate.nn.Linear(6, 7)}
|
1514
|
+
... self.b.a = brainstate.nn.LIF(2)
|
1515
|
+
...
|
1516
|
+
>>> model = Model()
|
1517
|
+
...
|
1518
|
+
>>> for path, node in brainstate.graph.iter_node([model, model]):
|
1519
|
+
... print(path, node.__class__.__name__)
|
1520
|
+
...
|
1521
|
+
(0, 'a') Linear
|
1522
|
+
(0, 'b', 'a') LIF
|
1523
|
+
(0, 'b') Linear
|
1524
|
+
(0, 'c', 0) Linear
|
1525
|
+
(0, 'c', 1) Linear
|
1526
|
+
(0, 'd', 'x') Linear
|
1527
|
+
(0, 'd', 'y') Linear
|
1528
|
+
(0,) Model
|
1651
1529
|
|
1652
1530
|
"""
|
1653
1531
|
|
@@ -1686,8 +1564,16 @@ def iter_node(
|
|
1686
1564
|
|
1687
1565
|
@dataclasses.dataclass(frozen=True)
|
1688
1566
|
class Static(Generic[A]):
|
1689
|
-
"""
|
1567
|
+
"""
|
1568
|
+
An empty pytree node that treats its inner value as static.
|
1569
|
+
|
1690
1570
|
``value`` must define ``__eq__`` and ``__hash__``.
|
1571
|
+
|
1572
|
+
Attributes
|
1573
|
+
----------
|
1574
|
+
value : A
|
1575
|
+
The static value to wrap.
|
1576
|
+
|
1691
1577
|
"""
|
1692
1578
|
|
1693
1579
|
value: A
|
@@ -1721,16 +1607,16 @@ def _key_path_to_key(key: Any) -> Key:
|
|
1721
1607
|
return str(key)
|
1722
1608
|
|
1723
1609
|
|
1724
|
-
def _flatten_pytree(pytree: Any):
|
1610
|
+
def _flatten_pytree(pytree: Any) -> Tuple[Tuple[Tuple, ...], jax.tree_util.PyTreeDef]:
|
1725
1611
|
leaves, treedef = jax.tree_util.tree_flatten_with_path(pytree, is_leaf=lambda x: x is not pytree)
|
1726
1612
|
nodes = tuple((_key_path_to_key(path[0]), value) for path, value in leaves)
|
1727
1613
|
return nodes, treedef
|
1728
1614
|
|
1729
1615
|
|
1730
1616
|
def _unflatten_pytree(
|
1731
|
-
nodes: tuple[tuple
|
1617
|
+
nodes: tuple[tuple, ...],
|
1732
1618
|
treedef: jax.tree_util.PyTreeDef
|
1733
|
-
):
|
1619
|
+
) -> Any:
|
1734
1620
|
pytree = treedef.unflatten(value for _, value in nodes)
|
1735
1621
|
return pytree
|
1736
1622
|
|