brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163) hide show
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +95 -29
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.9.dist-info/RECORD +0 -130
  161. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -20,7 +20,7 @@ from __future__ import annotations
20
20
  import dataclasses
21
21
  from typing import (
22
22
  Any, Callable, Generic, Iterable, Iterator, Mapping, MutableMapping,
23
- Sequence, Type, TypeVar, Union, Hashable, Tuple, Dict, Optional, overload
23
+ Sequence, Type, TypeVar, Union, Hashable, Tuple, Dict, Optional
24
24
  )
25
25
 
26
26
  import jax
@@ -30,25 +30,40 @@ from typing_extensions import TypeGuard, Unpack
30
30
  from brainstate._state import State, TreefyState
31
31
  from brainstate._utils import set_module_as
32
32
  from brainstate.typing import PathParts, Filter, Predicate, Key
33
- from brainstate.util.caller import ApplyCaller, CallableProxy, DelayedAccessor
34
- from brainstate.util.pretty_pytree import NestedDict, FlattedDict, PrettyDict
35
- from brainstate.util.pretty_repr import PrettyRepr, PrettyType, PrettyAttr, PrettyMapping, MappingReprMixin
36
- from brainstate.util.struct import FrozenDict
33
+ from brainstate.util._pretty_pytree import NestedDict, FlattedDict, PrettyDict
34
+ from brainstate.util._pretty_repr import PrettyRepr, PrettyType, PrettyAttr, PrettyMapping, MappingReprMixin
37
35
  from brainstate.util.filter import to_predicate
38
-
39
- _max_int = np.iinfo(np.int32).max
36
+ from brainstate.util.struct import FrozenDict
40
37
 
41
38
  __all__ = [
39
+ 'register_graph_node_type',
40
+
42
41
  # state management in the given graph or node
43
- 'pop_states', 'nodes', 'states', 'treefy_states', 'update_states',
42
+ 'pop_states',
43
+ 'nodes',
44
+ 'states',
45
+ 'treefy_states',
46
+ 'update_states',
44
47
 
45
48
  # graph node operations
46
- 'flatten', 'unflatten', 'treefy_split', 'treefy_merge', 'iter_leaf', 'iter_node', 'clone', 'graphdef', 'call',
49
+ 'flatten',
50
+ 'unflatten',
51
+ 'treefy_split',
52
+ 'treefy_merge',
53
+ 'iter_leaf',
54
+ 'iter_node',
55
+ 'clone',
56
+ 'graphdef',
47
57
 
48
58
  # others
49
- 'RefMap', 'GraphDef', 'NodeRef', 'NodeDef'
59
+ 'RefMap',
60
+ 'GraphDef',
61
+ 'NodeDef',
62
+ 'NodeRef',
50
63
  ]
51
64
 
65
+ MAX_INT = np.iinfo(np.int32).max
66
+
52
67
  A = TypeVar('A')
53
68
  B = TypeVar('B')
54
69
  C = TypeVar('C')
@@ -65,12 +80,11 @@ AuxData = TypeVar('AuxData')
65
80
 
66
81
  StateLeaf = TreefyState[Any]
67
82
  NodeLeaf = State[Any]
68
- GraphStateMapping = NestedDict[Key, StateLeaf]
83
+ GraphStateMapping = NestedDict
69
84
 
70
85
 
71
86
  # --------------------------------------------------------
72
87
 
73
-
74
88
  def _is_state_leaf(x: Any) -> TypeGuard[StateLeaf]:
75
89
  return isinstance(x, TreefyState)
76
90
 
@@ -86,13 +100,30 @@ class RefMap(MutableMapping[A, B], MappingReprMixin[A, B]):
86
100
  This mapping is useful when we want to keep track of objects
87
101
  that are being referenced by other objects.
88
102
 
89
- Args:
90
- mapping: A mapping or iterable of key-value pairs.
103
+ Parameters
104
+ ----------
105
+ mapping : Mapping[A, B] or Iterable[Tuple[A, B]], optional
106
+ A mapping or iterable of key-value pairs.
107
+
108
+ Examples
109
+ --------
110
+ .. code-block:: python
111
+
112
+ >>> import brainstate
113
+ >>> obj1 = object()
114
+ >>> obj2 = object()
115
+ >>> ref_map = brainstate.graph.RefMap()
116
+ >>> ref_map[obj1] = 'value1'
117
+ >>> ref_map[obj2] = 'value2'
118
+ >>> print(obj1 in ref_map)
119
+ True
120
+ >>> print(ref_map[obj1])
121
+ value1
91
122
 
92
123
  """
93
124
  __module__ = 'brainstate.graph'
94
125
 
95
- def __init__(self, mapping: Mapping[A, B] | Iterable[Tuple[A, B]] = ()):
126
+ def __init__(self, mapping: Union[Mapping[A, B], Iterable[Tuple[A, B]]] = ()) -> None:
96
127
  self._mapping: Dict[int, Tuple[A, B]] = {}
97
128
  self.update(mapping)
98
129
 
@@ -102,10 +133,10 @@ class RefMap(MutableMapping[A, B], MappingReprMixin[A, B]):
102
133
  def __contains__(self, key: Any) -> bool:
103
134
  return id(key) in self._mapping
104
135
 
105
- def __setitem__(self, key: A, value: B):
136
+ def __setitem__(self, key: A, value: B) -> None:
106
137
  self._mapping[id(key)] = (key, value)
107
138
 
108
- def __delitem__(self, key: A):
139
+ def __delitem__(self, key: A) -> None:
109
140
  del self._mapping[id(key)]
110
141
 
111
142
  def __iter__(self) -> Iterator[A]:
@@ -135,7 +166,7 @@ class GraphNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
135
166
  create_empty: Callable[[AuxData], Node]
136
167
  clear: Callable[[Node], None]
137
168
 
138
- def init(self, node: Node, items: Tuple[Tuple[Key, Leaf], ...]):
169
+ def init(self, node: Node, items: Tuple[Tuple[Key, Leaf], ...]) -> None:
139
170
  for key, value in items:
140
171
  self.set_key(node, key, value)
141
172
 
@@ -151,7 +182,7 @@ NodeImpl = Union[GraphNodeImpl[Node, Leaf, AuxData], PyTreeNodeImpl[Node, Leaf,
151
182
  # Graph Node implementation: start
152
183
  # --------------------------------------------------------
153
184
 
154
- _node_impl_for_type: dict[type, NodeImpl[Any, Any, Any]] = {}
185
+ _node_impl_for_type: dict[type, NodeImpl] = {}
155
186
 
156
187
 
157
188
  def register_graph_node_type(
@@ -165,13 +196,56 @@ def register_graph_node_type(
165
196
  """
166
197
  Register a graph node type.
167
198
 
168
- Args:
169
- type: The type of the node.
170
- flatten: A function that flattens the node into a sequence of key-value pairs.
171
- set_key: A function that sets a key in the node.
172
- pop_key: A function that pops a key from the node.
173
- create_empty: A function that creates an empty node.
174
- clear: A function that clears the node
199
+ Parameters
200
+ ----------
201
+ type : type
202
+ The type of the node.
203
+ flatten : Callable[[Node], tuple[Sequence[tuple[Key, Leaf]], AuxData]]
204
+ A function that flattens the node into a sequence of key-value pairs.
205
+ set_key : Callable[[Node, Key, Leaf], None]
206
+ A function that sets a key in the node.
207
+ pop_key : Callable[[Node, Key], Leaf]
208
+ A function that pops a key from the node.
209
+ create_empty : Callable[[AuxData], Node]
210
+ A function that creates an empty node.
211
+ clear : Callable[[Node], None]
212
+ A function that clears the node.
213
+
214
+ Examples
215
+ --------
216
+ .. code-block:: python
217
+
218
+ >>> import brainstate
219
+ >>> # Custom node type implementation
220
+ >>> class CustomNode:
221
+ ... def __init__(self):
222
+ ... self.data = {}
223
+ ...
224
+ >>> def flatten_custom(node):
225
+ ... return list(node.data.items()), None
226
+ ...
227
+ >>> def set_key_custom(node, key, value):
228
+ ... node.data[key] = value
229
+ ...
230
+ >>> def pop_key_custom(node, key):
231
+ ... return node.data.pop(key)
232
+ ...
233
+ >>> def create_empty_custom(metadata):
234
+ ... return CustomNode()
235
+ ...
236
+ >>> def clear_custom(node):
237
+ ... node.data.clear()
238
+ ...
239
+ >>> # Register the custom node type
240
+ >>> brainstate.graph.register_graph_node_type(
241
+ ... CustomNode,
242
+ ... flatten_custom,
243
+ ... set_key_custom,
244
+ ... pop_key_custom,
245
+ ... create_empty_custom,
246
+ ... clear_custom
247
+ ... )
248
+
175
249
  """
176
250
  _node_impl_for_type[type] = GraphNodeImpl(
177
251
  type=type,
@@ -200,11 +274,11 @@ def _is_graph_node(x: Any) -> bool:
200
274
  return type(x) in _node_impl_for_type
201
275
 
202
276
 
203
- def _is_node_type(x: type[Any]) -> bool:
277
+ def _is_node_type(x: Type[Any]) -> bool:
204
278
  return x in _node_impl_for_type or x is PytreeType
205
279
 
206
280
 
207
- def _get_node_impl(x: Node) -> NodeImpl[Node, Any, Any]:
281
+ def _get_node_impl(x: Any) -> NodeImpl:
208
282
  if isinstance(x, State):
209
283
  raise ValueError(f'State is not a node: {x}')
210
284
 
@@ -218,14 +292,14 @@ def _get_node_impl(x: Node) -> NodeImpl[Node, Any, Any]:
218
292
  return _node_impl_for_type[node_type]
219
293
 
220
294
 
221
- def get_node_impl_for_type(x: type[Node]) -> NodeImpl[Node, Any, Any]:
295
+ def get_node_impl_for_type(x: Type[Any]) -> NodeImpl:
222
296
  if x is PytreeType:
223
297
  return PYTREE_NODE_IMPL
224
298
  return _node_impl_for_type[x]
225
299
 
226
300
 
227
301
  class HashableMapping(Mapping[HA, HB], Hashable):
228
- def __init__(self, mapping: Mapping[HA, HB] | Iterable[tuple[HA, HB]]):
302
+ def __init__(self, mapping: Union[Mapping[HA, HB], Iterable[tuple[HA, HB]]]) -> None:
229
303
  self._mapping = dict(mapping)
230
304
 
231
305
  def __contains__(self, key: object) -> bool:
@@ -259,57 +333,53 @@ class GraphDef(Generic[Node]):
259
333
  - index: The index of the node in the graph.
260
334
 
261
335
  It has two concrete subclasses:
336
+
262
337
  - :class:`NodeRef`: A reference to a node in the graph.
263
338
  - :class:`NodeDef`: A dataclass that denotes the graph structure of a :class:`Node` or a :class:`State`.
264
339
 
265
- """
266
- type: type[Node]
267
- index: int
268
-
269
-
270
- @dataclasses.dataclass(frozen=True, repr=False)
271
- class NodeRef(GraphDef[Node], PrettyRepr):
272
- """
273
- A reference to a node in the graph.
340
+ Attributes
341
+ ----------
342
+ type : Type[Node]
343
+ The type of the node.
344
+ index : int
345
+ The index of the node in the graph.
274
346
 
275
- The node can be instances of :class:`Node` or :class:`State`.
276
347
  """
277
- type: type[Node]
348
+ type: Type[Node]
278
349
  index: int
279
350
 
280
- def __pretty_repr__(self):
281
- yield PrettyType(type=type(self))
282
- yield PrettyAttr('type', self.type.__name__)
283
- yield PrettyAttr('index', self.index)
284
-
285
- def __treescope_repr__(self, path, subtree_renderer):
286
- """
287
- Treescope repr for the object.
288
- """
289
- import treescope # type: ignore[import-not-found,import-untyped]
290
- return treescope.repr_lib.render_object_constructor(
291
- object_type=type(self),
292
- attributes={'type': self.type, 'index': self.index},
293
- path=path,
294
- subtree_renderer=subtree_renderer,
295
- )
296
-
297
-
298
- jax.tree_util.register_static(NodeRef)
299
-
300
351
 
301
352
  @dataclasses.dataclass(frozen=True, repr=False)
302
353
  class NodeDef(GraphDef[Node], PrettyRepr):
303
354
  """
304
355
  A dataclass that denotes the tree structure of a node, either :class:`Node` or :class:`State`.
305
356
 
357
+ Attributes
358
+ ----------
359
+ type : Type[Node]
360
+ Type of the node.
361
+ index : int
362
+ Index of the node in the graph.
363
+ attributes : Tuple[Key, ...]
364
+ Attributes for the node.
365
+ subgraphs : HashableMapping[Key, NodeDef[Any] | NodeRef[Any]]
366
+ Mapping of subgraph definitions.
367
+ static_fields : HashableMapping
368
+ Mapping of static fields.
369
+ leaves : HashableMapping[Key, NodeRef[Any] | None]
370
+ Mapping of leaf nodes.
371
+ metadata : Hashable
372
+ Metadata associated with the node.
373
+ index_mapping : FrozenDict[Index, Index] | None
374
+ Index mapping for node references.
375
+
306
376
  """
307
377
 
308
378
  type: Type[Node] # type of the node
309
379
  index: int # index of the node in the graph
310
380
  attributes: Tuple[Key, ...] # attributes for the node
311
381
  subgraphs: HashableMapping[Key, NodeDef[Any] | NodeRef[Any]]
312
- static_fields: HashableMapping[Key, Any]
382
+ static_fields: HashableMapping
313
383
  leaves: HashableMapping[Key, NodeRef[Any] | None]
314
384
  metadata: Hashable
315
385
  index_mapping: FrozenDict[Index, Index] | None
@@ -321,7 +391,7 @@ class NodeDef(GraphDef[Node], PrettyRepr):
321
391
  index: int,
322
392
  attributes: tuple[Key, ...],
323
393
  subgraphs: Iterable[tuple[Key, NodeDef[Any] | NodeRef[Any]]],
324
- static_fields: Iterable[tuple[Key, Any]],
394
+ static_fields: Iterable[tuple],
325
395
  leaves: Iterable[tuple[Key, NodeRef[Any] | None]],
326
396
  metadata: Hashable,
327
397
  index_mapping: Mapping[Index, Index] | None,
@@ -349,24 +419,35 @@ class NodeDef(GraphDef[Node], PrettyRepr):
349
419
  yield PrettyAttr('metadata', self.metadata)
350
420
  yield PrettyAttr('index_mapping', PrettyMapping(self.index_mapping) if self.index_mapping is not None else None)
351
421
 
352
- def apply(
353
- self,
354
- state_map: GraphStateMapping,
355
- *state_maps: GraphStateMapping
356
- ) -> ApplyCaller[tuple[GraphDef[Node], GraphStateMapping]]:
357
- accessor = DelayedAccessor()
358
422
 
359
- def _apply(accessor: DelayedAccessor, *args, **kwargs) -> tuple[
360
- Any, tuple[GraphDef[Node], GraphStateMapping]]:
361
- module = treefy_merge(self, state_map, *state_maps)
362
- fn = accessor(module)
363
- out = fn(*args, **kwargs)
364
- return out, flatten(module)
423
+ jax.tree_util.register_static(NodeDef)
424
+
425
+
426
+ @dataclasses.dataclass(frozen=True, repr=False)
427
+ class NodeRef(GraphDef[Node], PrettyRepr):
428
+ """
429
+ A reference to a node in the graph.
430
+
431
+ The node can be instances of :class:`Node` or :class:`State`.
365
432
 
366
- return CallableProxy(_apply, accessor) # type: ignore
433
+ Attributes
434
+ ----------
435
+ type : Type[Node]
436
+ The type of the node being referenced.
437
+ index : int
438
+ The index of the node in the graph.
367
439
 
440
+ """
441
+ type: Type[Node]
442
+ index: int
368
443
 
369
- jax.tree_util.register_static(NodeDef)
444
+ def __pretty_repr__(self):
445
+ yield PrettyType(type=type(self))
446
+ yield PrettyAttr('type', self.type.__name__)
447
+ yield PrettyAttr('index', self.index)
448
+
449
+
450
+ jax.tree_util.register_static(NodeRef)
370
451
 
371
452
 
372
453
  # --------------------------------------------------------
@@ -378,20 +459,30 @@ def _graph_flatten(
378
459
  path: PathParts,
379
460
  ref_index: RefMap[Any, Index],
380
461
  flatted_state_mapping: Dict[PathParts, StateLeaf],
381
- node: Node,
462
+ node: Any,
382
463
  treefy_state: bool = False,
383
- ):
464
+ ) -> Union[NodeDef[Any], NodeRef[Any]]:
384
465
  """
385
466
  Recursive helper for graph flatten.
386
467
 
387
- Args:
388
- path: The path to the node.
389
- ref_index: A mapping from nodes to indexes.
390
- flatted_state_mapping: A mapping from paths to state leaves.
391
- node: The node to flatten.
468
+ Parameters
469
+ ----------
470
+ path : PathParts
471
+ The path to the node.
472
+ ref_index : RefMap[Any, Index]
473
+ A mapping from nodes to indexes.
474
+ flatted_state_mapping : Dict[PathParts, StateLeaf]
475
+ A mapping from paths to state leaves.
476
+ node : Node
477
+ The node to flatten.
478
+ treefy_state : bool, optional
479
+ Whether to convert states to TreefyState, by default False.
480
+
481
+ Returns
482
+ -------
483
+ NodeDef or NodeRef
484
+ A NodeDef or a NodeRef.
392
485
 
393
- Returns:
394
- A NodeDef or a NodeRef.
395
486
  """
396
487
  if not _is_node(node):
397
488
  raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')
@@ -417,9 +508,9 @@ def _graph_flatten(
417
508
  else:
418
509
  index = -1
419
510
 
420
- subgraphs: list[tuple[Key, NodeDef[Node] | NodeRef]] = []
421
- static_fields: list[tuple[Key, Any]] = []
422
- leaves: list[tuple[Key, NodeRef | None]] = []
511
+ subgraphs: list[tuple[Key, Union[NodeDef[Any], NodeRef[Any]]]] = []
512
+ static_fields: list[tuple] = []
513
+ leaves: list[tuple[Key, Union[NodeRef[Any], None]]] = []
423
514
 
424
515
  # Flatten the node into a sequence of key-value pairs.
425
516
  values, metadata = node_impl.flatten(node)
@@ -450,41 +541,56 @@ def _graph_flatten(
450
541
  # The value is a static field.
451
542
  static_fields.append((key, value))
452
543
 
453
- nodedef = NodeDef.create(type=node_impl.type,
454
- index=index,
455
- attributes=tuple(key for key, _ in values),
456
- subgraphs=subgraphs,
457
- static_fields=static_fields,
458
- leaves=leaves,
459
- metadata=metadata,
460
- index_mapping=None, )
544
+ nodedef = NodeDef.create(
545
+ type=node_impl.type,
546
+ index=index,
547
+ attributes=tuple(key for key, _ in values),
548
+ subgraphs=subgraphs,
549
+ static_fields=static_fields,
550
+ leaves=leaves,
551
+ metadata=metadata,
552
+ index_mapping=None,
553
+ )
461
554
  return nodedef
462
555
 
463
556
 
464
557
  @set_module_as('brainstate.graph')
465
558
  def flatten(
466
- node: Node,
559
+ node: Any,
467
560
  /,
468
561
  ref_index: Optional[RefMap[Any, Index]] = None,
469
562
  treefy_state: bool = True,
470
- ) -> Tuple[GraphDef, NestedDict]:
563
+ ) -> Tuple[GraphDef[Any], NestedDict]:
471
564
  """
472
565
  Flattens a graph node into a (graph_def, state_mapping) pair.
473
566
 
474
- Example::
475
-
476
- >>> import brainstate as brainstate
567
+ Parameters
568
+ ----------
569
+ node : Node
570
+ A graph node.
571
+ ref_index : RefMap[Any, Index], optional
572
+ A mapping from nodes to indexes, defaults to None. If not provided, a new
573
+ empty dictionary is created. This argument can be used to flatten a sequence of graph
574
+ nodes that share references.
575
+ treefy_state : bool, optional
576
+ If True, the state mapping will be a NestedDict instead of a flat dictionary.
577
+ Default is True.
578
+
579
+ Returns
580
+ -------
581
+ tuple[GraphDef, NestedDict]
582
+ A tuple containing the graph definition and state mapping.
583
+
584
+ Examples
585
+ --------
586
+ .. code-block:: python
587
+
588
+ >>> import brainstate
477
589
  >>> node = brainstate.graph.Node()
478
- >>> graph_def, state_mapping = flatten(node)
590
+ >>> graph_def, state_mapping = brainstate.graph.flatten(node)
479
591
  >>> print(graph_def)
480
592
  >>> print(state_mapping)
481
593
 
482
- Args:
483
- node: A graph node.
484
- ref_index: A mapping from nodes to indexes, defaults to None. If not provided, a new
485
- empty dictionary is created. This argument can be used to flatten a sequence of graph
486
- nodes that share references.
487
- treefy_state: If True, the state mapping will be a NestedDict instead of a flat dictionary.
488
594
  """
489
595
  ref_index = RefMap() if ref_index is None else ref_index
490
596
  assert isinstance(ref_index, RefMap), f"ref_index must be a RefMap. But we got: {ref_index}"
@@ -493,8 +599,13 @@ def flatten(
493
599
  return graph_def, NestedDict.from_flat(flatted_state_mapping)
494
600
 
495
601
 
496
- def _get_children(graph_def, state_mapping, index_ref, index_ref_cache):
497
- children: dict[Key, StateLeaf | Node] = {}
602
+ def _get_children(
603
+ graph_def: NodeDef[Any],
604
+ state_mapping: Mapping,
605
+ index_ref: dict[Index, Any],
606
+ index_ref_cache: Optional[dict[Index, Any]],
607
+ ) -> dict[Key, Union[StateLeaf, Any]]:
608
+ children: dict[Key, Union[StateLeaf, Any]] = {}
498
609
 
499
610
  # NOTE: we could allow adding new StateLeafs here
500
611
  # All state keys must be present in the graph definition (the object attributes)
@@ -506,8 +617,8 @@ def _get_children(graph_def, state_mapping, index_ref, index_ref_cache):
506
617
  # - (3) the key can be a subgraph, a leaf, or a static attribute
507
618
  for key in graph_def.attributes:
508
619
  if key not in state_mapping: # static field
509
- # TODO(cgarcia): maybe we shouldn't support unflattening with missing keys?
510
- # if key is not present, create an empty types
620
+ # Support unflattening with missing keys for static fields and subgraphs
621
+ # This allows partial state restoration and flexible graph reconstruction
511
622
  if key in graph_def.static_fields:
512
623
  children[key] = graph_def.static_fields[key]
513
624
 
@@ -534,8 +645,10 @@ def _get_children(graph_def, state_mapping, index_ref, index_ref_cache):
534
645
 
535
646
  else:
536
647
  # key for a variable is missing, raise an error
537
- raise ValueError(f'Expected key {key!r} in state while building node of type '
538
- f'{graph_def.type.__name__}.')
648
+ raise ValueError(
649
+ f'Expected key {key!r} in state while building node of type '
650
+ f'{graph_def.type.__name__}.'
651
+ )
539
652
 
540
653
  else:
541
654
  raise RuntimeError(f'Unknown static field: {key!r}')
@@ -551,8 +664,11 @@ def _get_children(graph_def, state_mapping, index_ref, index_ref_cache):
551
664
  if key in graph_def.subgraphs:
552
665
  # if _is_state_leaf(value):
553
666
  if isinstance(value, (TreefyState, State)):
554
- raise ValueError(f'Expected value of type {graph_def.subgraphs[key]} '
555
- f'for {key!r}, but got {value!r}')
667
+ raise ValueError(
668
+ f'Expected value of type {graph_def.subgraphs[key]} '
669
+ f'for {key!r}, but got {value!r}'
670
+ )
671
+
556
672
  if not isinstance(value, dict):
557
673
  raise TypeError(f'Expected a dict for {key!r}, but got {type(value)}.')
558
674
 
@@ -574,8 +690,8 @@ def _get_children(graph_def, state_mapping, index_ref, index_ref_cache):
574
690
  # TreefyState presumbly created by modifying the NestedDict
575
691
  if isinstance(value, TreefyState):
576
692
  value = value.to_state()
577
- # elif isinstance(value, State):
578
- # value = value
693
+ elif isinstance(value, State):
694
+ value = value
579
695
  children[key] = value
580
696
 
581
697
  elif noderef.index in index_ref:
@@ -585,7 +701,10 @@ def _get_children(graph_def, state_mapping, index_ref, index_ref_cache):
585
701
  else:
586
702
  # it is an unseen variable, create a new one
587
703
  if not isinstance(value, (TreefyState, State)):
588
- raise ValueError(f'Expected a State type for {key!r}, but got {type(value)}.')
704
+ raise ValueError(
705
+ f'Expected a State type for {key!r}, but got {type(value)}.'
706
+ )
707
+
589
708
  # when idxmap is present, check if the Varable exists there
590
709
  # and update existing variables if it does
591
710
  if index_ref_cache is not None and noderef.index in index_ref_cache:
@@ -618,11 +737,11 @@ def _get_children(graph_def, state_mapping, index_ref, index_ref_cache):
618
737
 
619
738
 
620
739
  def _graph_unflatten(
621
- graph_def: NodeDef[Node] | NodeRef[Node],
622
- state_mapping: Mapping[Key, StateLeaf | Mapping[Key, Any]],
740
+ graph_def: Union[NodeDef[Any], NodeRef[Any]],
741
+ state_mapping: Mapping[Key, Union[StateLeaf, Mapping]],
623
742
  index_ref: dict[Index, Any],
624
- index_ref_cache: dict[Index, Any] | None,
625
- ) -> Node:
743
+ index_ref_cache: Optional[dict[Index, Any]],
744
+ ) -> Any:
626
745
  """
627
746
  Recursive helper for graph unflatten.
628
747
 
@@ -697,175 +816,57 @@ def _graph_unflatten(
697
816
 
698
817
  @set_module_as('brainstate.graph')
699
818
  def unflatten(
700
- graph_def: GraphDef,
701
- state_mapping: NestedDict[Key, StateLeaf],
819
+ graph_def: GraphDef[Any],
820
+ state_mapping: NestedDict,
702
821
  /,
703
822
  *,
704
- index_ref: dict[Index, Any] | None = None,
705
- index_ref_cache: dict[Index, Any] | None = None,
706
- ) -> Node:
823
+ index_ref: Optional[dict[Index, Any]] = None,
824
+ index_ref_cache: Optional[dict[Index, Any]] = None,
825
+ ) -> Any:
707
826
  """
708
827
  Unflattens a graphdef into a node with the given state tree mapping.
709
828
 
710
- Example::
711
-
712
- >>> import brainstate as brainstate
713
- >>> class MyNode(brainstate.graph.Node):
714
- ... def __init__(self):
715
- ... self.a = brainstate.nn.Linear(2, 3)
716
- ... self.b = brainstate.nn.Linear(3, 4)
717
- ... self.c = [brainstate.nn.Linear(4, 5), brainstate.nn.Linear(5, 6)]
718
- ... self.d = {'x': brainstate.nn.Linear(6, 7), 'y': brainstate.nn.Linear(7, 8)}
719
- ...
720
- >>> graphdef, statetree = brainstate.graph.flatten(MyNode())
721
- >>> statetree
722
- NestedDict({
723
- 'a': {
724
- 'weight': TreefyState(
725
- type=ParamState,
726
- value={'weight': Array([[-0.8466386 , -2.0294454 , -0.6911647 ],
727
- [ 0.60034966, -1.1869028 , 0.84003365]], dtype=float32), 'bias': Array([0., 0., 0.], dtype=float32)}
728
- )
729
- },
730
- 'b': {
731
- 'weight': TreefyState(
732
- type=ParamState,
733
- value={'weight': Array([[ 0.8565106 , -0.10337489],
734
- [ 1.7449658 , 0.29128835],
735
- [ 0.11441387, 1.0012752 ]], dtype=float32), 'bias': Array([0., 0.], dtype=float32)}
736
- )
737
- },
738
- 'c': {
739
- 0: {
740
- 'weight': TreefyState(
741
- type=ParamState,
742
- value={'weight': Array([[ 2.4465137, -0.5711426]], dtype=float32), 'bias': Array([0., 0.], dtype=float32)}
743
- )
744
- },
745
- 1: {
746
- 'weight': TreefyState(
747
- type=ParamState,
748
- value={'weight': Array([[ 0.14321847, -2.4154725 , -0.6322363 ]], dtype=float32), 'bias': Array([0., 0., 0.], dtype=float32)}
749
- )
750
- }
751
- },
752
- 'd': {
753
- 'x': {
754
- 'weight': TreefyState(
755
- type=ParamState,
756
- value={'weight': Array([[ 0.9647322, -0.8958757, 1.585352 ]], dtype=float32), 'bias': Array([0., 0., 0.], dtype=float32)}
757
- )
758
- },
759
- 'y': {
760
- 'weight': TreefyState(
761
- type=ParamState,
762
- value={'weight': Array([[-1.2904786 , 0.5695903 , 0.40079263, 0.8769669 ]], dtype=float32), 'bias': Array([0., 0., 0., 0.], dtype=float32)}
763
- )
764
- }
765
- }
766
- })
767
- >>> node = brainstate.graph.unflatten(graphdef, statetree)
768
- >>> node
769
- MyNode(
770
- a=Linear(
771
- in_size=(2,),
772
- out_size=(3,),
773
- w_mask=None,
774
- weight=ParamState(
775
- value={'weight': Array([[ 0.55600464, -1.6276929 , 0.26805446],
776
- [ 1.175099 , 1.0077754 , 0.37592274]], dtype=float32), 'bias': Array([0., 0., 0.], dtype=float32)},
777
- )
778
- ),
779
- b=Linear(
780
- in_size=(3,),
781
- out_size=(4,),
782
- w_mask=None,
783
- weight=ParamState(
784
- value={'weight': Array([[-0.24753566, 0.18456966, -0.29438975, 0.16891003],
785
- [-0.803741 , -0.46037054, -0.21617596, 0.1260884 ],
786
- [-0.43074366, -0.24757433, 1.2237076 , -0.07842704]], dtype=float32), 'bias': Array([0., 0., 0., 0.], dtype=float32)},
787
- )
788
- ),
789
- c=[Linear(
790
- in_size=(4,),
791
- out_size=(5,),
792
- w_mask=None,
793
- weight=ParamState(
794
- value={'weight': Array([[-0.22384474, 0.79441446, -0.658726 , 0.05991402, 0.3014344 ],
795
- [-1.4755846 , -0.42272082, -0.07692316, 0.03077666, 0.34513143],
796
- [-0.69395834, 0.48617035, 1.1042316 , 0.13105175, -0.25620162],
797
- [ 0.50389856, 0.6998943 , 0.43716812, 1.2168779 , -0.47325954]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0.], dtype=float32)},
798
- )
799
- ), Linear(
800
- in_size=(5,),
801
- out_size=(6,),
802
- w_mask=None,
803
- weight=ParamState(
804
- value={'weight': Array([[ 0.07714394, 0.78213537, 0.6745718 , -0.22881542, 0.5523547 ,
805
- -0.6399196 ],
806
- [-0.22626828, -0.54522336, 0.07448788, -0.00464636, 1.1483842 ,
807
- -0.57049096],
808
- [-0.86659616, 0.5683135 , -0.7449975 , 1.1862832 , 0.15047254,
809
- 0.68890226],
810
- [-1.0325443 , 0.2658072 , -0.10083053, -0.66915905, 0.11258496,
811
- 0.5440655 ],
812
- [ 0.27917263, 0.05717273, -0.5682605 , -0.88345915, 0.01314917,
813
- 0.780759 ]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0.], dtype=float32)},
814
- )
815
- )],
816
- d={'x': Linear(
817
- in_size=(6,),
818
- out_size=(7,),
819
- w_mask=None,
820
- weight=ParamState(
821
- value={'weight': Array([[-0.24238771, -0.23202638, 0.13663477, -0.48858666, 0.80871904,
822
- 0.00593298, 0.7595096 ],
823
- [ 0.50457454, 0.24180941, 0.25048748, 0.8937061 , 0.25398138,
824
- -1.2400566 , 0.00151599],
825
- [-0.19136038, 0.34470603, -0.11892717, -0.12514868, -0.5871703 ,
826
- 0.13572927, -1.1859009 ],
827
- [-0.01580911, 0.9301295 , -1.1246226 , -0.137708 , -0.4952151 ,
828
- 0.17537868, 0.98440856],
829
- [ 0.6399284 , 0.01739843, 0.61856824, 0.93258303, 0.64012206,
830
- 0.22780116, -0.5763679 ],
831
- [ 0.14077143, -1.0359222 , 0.28072503, 0.2557584 , -0.50622064,
832
- 0.4388198 , -0.26106128]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0., 0.], dtype=float32)},
833
- )
834
- ), 'y': Linear(
835
- in_size=(7,),
836
- out_size=(8,),
837
- w_mask=None,
838
- weight=ParamState(
839
- value={'weight': Array([[-0.23334591, -0.2893582 , 0.8071877 , -0.49038902, -0.29646504,
840
- 0.13624157, 0.22763114, 0.01906361],
841
- [-0.26742765, 0.20136863, 0.35148615, 0.42135832, 0.06401154,
842
- -0.78036404, 0.6616062 , 0.19437549],
843
- [ 0.9229799 , -0.1205209 , 0.69602865, 0.9685676 , -0.99886954,
844
- -0.12649904, -0.15393028, 0.65067965],
845
- [ 0.7020109 , -0.5452006 , 0.3649151 , -0.42368713, 0.24738027,
846
- 0.29290223, -0.63721114, 0.6007214 ],
847
- [-0.45045808, -0.08538888, -0.01338054, -0.39983988, 0.4028439 ,
848
- 1.0498686 , -0.24730456, 0.37612835],
849
- [ 0.16273966, 0.9001257 , 0.15190877, -1.1129239 , -0.29441378,
850
- 0.5168159 , -0.4205143 , 0.45700482],
851
- [ 0.08611429, -0.9271384 , -0.562362 , -0.586757 , 1.1611121 ,
852
- 0.5137503 , -0.46277294, 0.84642583]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)},
853
- )
854
- )}
855
- )
856
-
857
- Args:
858
- graph_def: A GraphDef instance.
859
- state_mapping: A NestedDict instance.
860
- index_ref: A mapping from indexes to nodes references found during the graph
861
- traversal, defaults to None. If not provided, a new empty dictionary is
862
- created. This argument can be used to unflatten a sequence of (graphdef, state_mapping)
863
- pairs that share the same index space.
864
- index_ref_cache: A mapping from indexes to existing nodes that can be reused.
865
- When a reference is reused, ``GraphNodeImpl.clear`` is called to leave the
866
- object in an empty state and then filled by the unflatten process, as a result
867
- existing graph nodes are mutated to have the new content/topology
868
- specified by the graphdef.
829
+ Parameters
830
+ ----------
831
+ graph_def : GraphDef
832
+ A GraphDef instance.
833
+ state_mapping : NestedDict
834
+ A NestedDict instance containing the state mapping.
835
+ index_ref : dict[Index, Any], optional
836
+ A mapping from indexes to nodes references found during the graph
837
+ traversal. If not provided, a new empty dictionary is created. This argument
838
+ can be used to unflatten a sequence of (graphdef, state_mapping) pairs that
839
+ share the same index space.
840
+ index_ref_cache : dict[Index, Any], optional
841
+ A mapping from indexes to existing nodes that can be reused. When a reference
842
+ is reused, ``GraphNodeImpl.clear`` is called to leave the object in an empty
843
+ state and then filled by the unflatten process. As a result, existing graph
844
+ nodes are mutated to have the new content/topology specified by the graphdef.
845
+
846
+ Returns
847
+ -------
848
+ Node
849
+ The reconstructed node.
850
+
851
+ Examples
852
+ --------
853
+ .. code-block:: python
854
+
855
+ >>> import brainstate
856
+ >>> class MyNode(brainstate.graph.Node):
857
+ ... def __init__(self):
858
+ ... self.a = brainstate.nn.Linear(2, 3)
859
+ ... self.b = brainstate.nn.Linear(3, 4)
860
+ ...
861
+ >>> # Flatten a node
862
+ >>> node = MyNode()
863
+ >>> graphdef, statetree = brainstate.graph.flatten(node)
864
+ >>>
865
+ >>> # Unflatten back to node
866
+ >>> reconstructed_node = brainstate.graph.unflatten(graphdef, statetree)
867
+ >>> assert isinstance(reconstructed_node, MyNode)
868
+ >>> assert isinstance(reconstructed_node.a, brainstate.nn.Linear)
869
+ >>> assert isinstance(reconstructed_node.b, brainstate.nn.Linear)
869
870
  """
870
871
  index_ref = {} if index_ref is None else index_ref
871
872
  assert isinstance(graph_def, (NodeDef, NodeRef)), f"graph_def must be a NodeDef or NodeRef. But we got: {graph_def}"
@@ -874,7 +875,7 @@ def unflatten(
874
875
 
875
876
 
876
877
  def _graph_pop(
877
- node: Node,
878
+ node: Any,
878
879
  id_to_index: dict[int, Index],
879
880
  path_parts: PathParts,
880
881
  flatted_state_dicts: tuple[FlattedDict[PathParts, StateLeaf], ...],
@@ -922,61 +923,57 @@ def _graph_pop(
922
923
  pass
923
924
 
924
925
 
925
- @overload
926
- def pop_states(node, filter1: Filter, /) -> NestedDict:
927
- ...
928
-
929
-
930
- @overload
931
- def pop_states(node, filter1: Filter, filter2: Filter, /, *filters: Filter) -> tuple[NestedDict, ...]:
932
- ...
933
-
934
-
935
926
  @set_module_as('brainstate.graph')
936
927
  def pop_states(
937
- node: Node,
938
- *filters: Any
939
- ) -> Union[NestedDict[Key, State], Tuple[NestedDict[Key, State], ...]]:
928
+ node: Any, *filters: Any
929
+ ) -> Union[NestedDict, Tuple[NestedDict, ...]]:
940
930
  """
941
931
  Pop one or more :class:`State` types from the graph node.
942
932
 
943
- Example usage::
944
-
945
- >>> import brainstate as brainstate
946
- >>> import jax.numpy as jnp
947
-
948
- >>> class Model(brainstate.nn.Module):
949
- ... def __init__(self):
950
- ... super().__init__()
951
- ... self.a = brainstate.nn.Linear(2, 3)
952
- ... self.b = brainstate.nn.LIF([10, 2])
953
-
954
- >>> model = Model()
955
- >>> with brainstate.catch_new_states('new'):
956
- ... brainstate.nn.init_all_states(model)
957
-
958
- >>> assert len(model.states()) == 2
959
- >>> model_states = brainstate.graph.pop_states(model, 'new')
960
- >>> model_states
961
- NestedDict({
962
- 'b': {
963
- 'V': {
964
- 'st': ShortTermState(
965
- value=Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
966
- 0., 0., 0.], dtype=float32),
967
- tag='new'
968
- )
933
+ Parameters
934
+ ----------
935
+ node : Node
936
+ A graph node object.
937
+ *filters
938
+ One or more :class:`State` objects to filter by.
939
+
940
+ Returns
941
+ -------
942
+ NestedDict or tuple[NestedDict, ...]
943
+ The popped :class:`NestedDict` containing the :class:`State`
944
+ objects that were filtered for.
945
+
946
+ Examples
947
+ --------
948
+ .. code-block:: python
949
+
950
+ >>> import brainstate
951
+ >>> import jax.numpy as jnp
952
+
953
+ >>> class Model(brainstate.nn.Module):
954
+ ... def __init__(self):
955
+ ... super().__init__()
956
+ ... self.a = brainstate.nn.Linear(2, 3)
957
+ ... self.b = brainstate.nn.LIF([10, 2])
958
+
959
+ >>> model = Model()
960
+ >>> with brainstate.catch_new_states('new'):
961
+ ... brainstate.nn.init_all_states(model)
962
+
963
+ >>> assert len(model.states()) == 2
964
+ >>> model_states = brainstate.graph.pop_states(model, 'new')
965
+ >>> model_states # doctest: +SKIP
966
+ NestedDict({
967
+ 'b': {
968
+ 'V': {
969
+ 'st': ShortTermState(
970
+ value=Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
971
+ 0., 0., 0.], dtype=float32),
972
+ tag='new'
973
+ )
974
+ }
969
975
  }
970
- }
971
- })
972
-
973
- Args:
974
- node: A graph node object.
975
- *filters: One or more :class:`State` objects to filter by.
976
-
977
- Returns:
978
- The popped :class:`NestedDict` containing the :class:`State`
979
- objects that were filtered for.
976
+ })
980
977
  """
981
978
  if len(filters) == 0:
982
979
  raise ValueError('Expected at least one filter')
@@ -985,11 +982,13 @@ def pop_states(
985
982
  path_parts: PathParts = ()
986
983
  predicates = tuple(to_predicate(filter) for filter in filters)
987
984
  flatted_state_dicts: tuple[FlattedDict[PathParts, StateLeaf], ...] = tuple({} for _ in predicates)
988
- _graph_pop(node=node,
989
- id_to_index=id_to_index,
990
- path_parts=path_parts,
991
- flatted_state_dicts=flatted_state_dicts,
992
- predicates=predicates, )
985
+ _graph_pop(
986
+ node=node,
987
+ id_to_index=id_to_index,
988
+ path_parts=path_parts,
989
+ flatted_state_dicts=flatted_state_dicts,
990
+ predicates=predicates,
991
+ )
993
992
  states = tuple(NestedDict.from_flat(flat_state) for flat_state in flatted_state_dicts)
994
993
 
995
994
  if len(states) == 1:
@@ -1011,94 +1010,49 @@ def _split_state(
1011
1010
  return states # type: ignore[return-value]
1012
1011
 
1013
1012
 
1014
- @overload
1015
- def treefy_split(node: A, /) -> Tuple[GraphDef, NestedDict]:
1016
- ...
1017
-
1018
-
1019
- @overload
1020
- def treefy_split(node: A, first: Filter, /) -> Tuple[GraphDef, NestedDict]:
1021
- ...
1022
-
1023
-
1024
- @overload
1025
- def treefy_split(node: A, first: Filter, second: Filter, /) -> Tuple[GraphDef, NestedDict, NestedDict]:
1026
- ...
1027
-
1028
-
1029
- @overload
1030
- def treefy_split(
1031
- node: A, first: Filter, second: Filter, /, *filters: Filter,
1032
- ) -> Tuple[GraphDef, NestedDict, Unpack[Tuple[NestedDict, ...]]]:
1033
- ...
1034
-
1035
-
1036
1013
  @set_module_as('brainstate.graph')
1037
1014
  def treefy_split(
1038
- node: A,
1039
- *filters: Filter
1040
- ) -> Tuple[GraphDef[A], NestedDict, Unpack[Tuple[NestedDict, ...]]]:
1041
- """Split a graph node into a :class:`GraphDef` and one or more :class:`NestedDict`s. NestedDict is
1042
- a ``Mapping`` from strings or integers to ``Variables``, Arrays or nested States. GraphDef
1043
- contains all the static information needed to reconstruct a ``Module`` graph, it is analogous
1044
- to JAX’s ``PyTreeDef``. :func:`split` is used in conjunction with :func:`merge` to switch
1045
- seamlessly between stateful and stateless representations of the graph.
1046
-
1047
- Example usage::
1048
-
1049
- >>> from joblib.testing import param >>> import brainstate as brainstate
1050
- >>> import jax, jax.numpy as jnp
1051
- ...
1052
- >>> class Foo(brainstate.graph.Node):
1053
- ... def __init__(self):
1054
- ... self.a = brainstate.nn.BatchNorm1d([10, 2])
1055
- ... self.b = brainstate.nn.Linear(2, 3)
1056
- ...
1057
- >>> node = Foo()
1058
- >>> graphdef, params, others = brainstate.graph.treefy_split(node, brainstate.ParamState, ...)
1059
- ...
1060
- >>> params
1061
- NestedDict({
1062
- 'a': {
1063
- 'weight': TreefyState(
1064
- type=ParamState,
1065
- value={'weight': Array([[-1.0013659, 1.5763807],
1066
- [ 1.7149199, 2.0140953]], dtype=float32), 'bias': Array([0., 0.], dtype=float32)}
1067
- )
1068
- },
1069
- 'b': {
1070
- 'weight': TreefyState(
1071
- type=ParamState,
1072
- value={'bias': Array([[0., 0.]], dtype=float32), 'scale': Array([[1., 1.]], dtype=float32)}
1073
- )
1074
- }
1075
- })
1076
- >>> jax.tree.map(jnp.shape, others)
1077
- NestedDict({
1078
- 'b': {
1079
- 'running_mean': TreefyState(
1080
- type=LongTermState,
1081
- value=(1, 2)
1082
- ),
1083
- 'running_var': TreefyState(
1084
- type=LongTermState,
1085
- value=(1, 2)
1086
- )
1087
- }
1088
- })
1089
-
1090
- :func:`split` and :func:`merge` are primarily used to interact directly with JAX
1091
- transformations, see
1092
- `Functional API <https://flax.readthedocs.io/en/latest/nnx/nnx_basics.html#the-functional-api>`__
1093
- for more information.
1094
-
1095
- Arguments:
1096
- node: graph node to split.
1097
- *filters: some optional filters to group the state into mutually exclusive substates.
1015
+ node: A, *filters: Filter
1016
+ ):
1017
+ """
1018
+ Split a graph node into a :class:`GraphDef` and one or more :class:`NestedDict`s.
1098
1019
 
1099
- Returns:
1100
- ``GraphDef`` and one or more ``States`` equal to the number of filters passed. If no
1101
- filters are passed, a single ``NestedDict`` is returned.
1020
+ NestedDict is a ``Mapping`` from strings or integers to ``Variables``, Arrays or nested States.
1021
+ GraphDef contains all the static information needed to reconstruct a ``Module`` graph, it is
1022
+ analogous to JAX's ``PyTreeDef``. :func:`split` is used in conjunction with :func:`merge` to
1023
+ switch seamlessly between stateful and stateless representations of the graph.
1024
+
1025
+ Parameters
1026
+ ----------
1027
+ node : A
1028
+ Graph node to split.
1029
+ *filters
1030
+ Optional filters to group the state into mutually exclusive substates.
1031
+
1032
+ Returns
1033
+ -------
1034
+ tuple
1035
+ ``GraphDef`` and one or more ``States`` equal to the number of filters passed.
1036
+ If no filters are passed, a single ``NestedDict`` is returned.
1037
+
1038
+ Examples
1039
+ --------
1040
+ .. code-block:: python
1041
+
1042
+ >>> import brainstate
1043
+ >>> import jax, jax.numpy as jnp
1044
+
1045
+ >>> class Foo(brainstate.graph.Node):
1046
+ ... def __init__(self):
1047
+ ... self.a = brainstate.nn.BatchNorm1d([10, 2])
1048
+ ... self.b = brainstate.nn.Linear(2, 3)
1049
+ ...
1050
+ >>> node = Foo()
1051
+ >>> graphdef, params, others = brainstate.graph.treefy_split(
1052
+ ... node, brainstate.ParamState, ...
1053
+ ... )
1054
+ >>> # params contains ParamState variables
1055
+ >>> # others contains all other state variables
1102
1056
  """
1103
1057
  graphdef, state_tree = flatten(node)
1104
1058
  states = tuple(_split_state(state_tree, filters))
@@ -1106,49 +1060,47 @@ def treefy_split(
1106
1060
 
1107
1061
 
1108
1062
  @set_module_as('brainstate.graph')
1109
- def treefy_merge(
1110
- graphdef: GraphDef[A],
1111
- state_mapping: GraphStateMapping,
1112
- /,
1113
- *state_mappings: GraphStateMapping,
1114
- ) -> A:
1115
- """The inverse of :func:`split`.
1063
+ def treefy_merge(graphdef: GraphDef[A], *state_mappings) -> A:
1064
+ """
1065
+ The inverse of :func:`split`.
1116
1066
 
1117
1067
  ``merge`` takes a :class:`GraphDef` and one or more :class:`NestedDict`'s and creates
1118
1068
  a new node with the same structure as the original node.
1119
1069
 
1120
- Example usage::
1121
-
1122
- >>> import brainstate as brainstate
1123
- >>> import jax, jax.numpy as jnp
1124
- ...
1125
- >>> class Foo(brainstate.graph.Node):
1126
- ... def __init__(self):
1127
- ... self.a = brainstate.nn.BatchNorm1d([10, 2])
1128
- ... self.b = brainstate.nn.Linear(2, 3)
1129
- ...
1130
- >>> node = Foo()
1131
- >>> graphdef, params, others = brainstate.graph.treefy_split(node, brainstate.ParamState, ...)
1132
- ...
1133
- >>> new_node = brainstate.graph.treefy_merge(graphdef, params, others)
1134
- >>> assert isinstance(new_node, Foo)
1135
- >>> assert isinstance(new_node.b, brainstate.nn.BatchNorm1d)
1136
- >>> assert isinstance(new_node.a, brainstate.nn.Linear)
1137
-
1138
- :func:`split` and :func:`merge` are primarily used to interact directly with JAX
1139
- transformations, see
1140
- `Functional API <https://flax.readthedocs.io/en/latest/nnx/nnx_basics.html#the-functional-api>`__
1141
- for more information.
1142
-
1143
- Args:
1144
- graphdef: A :class:`GraphDef` object.
1145
- state_mapping: A :class:`NestedDict` object.
1146
- *state_mappings: Additional :class:`NestedDict` objects.
1147
-
1148
- Returns:
1149
- The merged :class:`Module`.
1070
+ Parameters
1071
+ ----------
1072
+ graphdef : GraphDef[A]
1073
+ A :class:`GraphDef` object.
1074
+ *state_mappings
1075
+ Additional :class:`NestedDict` objects.
1076
+
1077
+ Returns
1078
+ -------
1079
+ A
1080
+ The merged :class:`Module`.
1081
+
1082
+ Examples
1083
+ --------
1084
+ .. code-block:: python
1085
+
1086
+ >>> import brainstate
1087
+ >>> import jax, jax.numpy as jnp
1088
+
1089
+ >>> class Foo(brainstate.graph.Node):
1090
+ ... def __init__(self):
1091
+ ... self.a = brainstate.nn.BatchNorm1d([10, 2])
1092
+ ... self.b = brainstate.nn.Linear(2, 3)
1093
+ ...
1094
+ >>> node = Foo()
1095
+ >>> graphdef, params, others = brainstate.graph.treefy_split(
1096
+ ... node, brainstate.ParamState, ...
1097
+ ... )
1098
+ >>> new_node = brainstate.graph.treefy_merge(graphdef, params, others)
1099
+ >>> assert isinstance(new_node, Foo)
1100
+ >>> assert isinstance(new_node.b, brainstate.nn.BatchNorm1d)
1101
+ >>> assert isinstance(new_node.a, brainstate.nn.Linear)
1150
1102
  """
1151
- state_mapping = GraphStateMapping.merge(state_mapping, *state_mappings)
1103
+ state_mapping = GraphStateMapping.merge(*state_mappings)
1152
1104
  node = unflatten(graphdef, state_mapping)
1153
1105
  return node
1154
1106
 
@@ -1186,31 +1138,27 @@ def _split_flatted(
1186
1138
  return flat_states
1187
1139
 
1188
1140
 
1189
- @overload
1190
- def nodes(node, /, allowed_hierarchy=(0, _max_int)) -> FlattedDict[Key, Node]:
1191
- ...
1192
-
1193
-
1194
- @overload
1195
- def nodes(node, first: Filter, /, allowed_hierarchy=(0, _max_int)) -> FlattedDict[Key, Node]:
1196
- ...
1197
-
1198
-
1199
- @overload
1200
- def nodes(
1201
- node, first: Filter, second: Filter, /, *filters: Filter, allowed_hierarchy=(0, _max_int)
1202
- ) -> Tuple[FlattedDict[Key, Node], ...]:
1203
- ...
1204
-
1205
-
1206
1141
  @set_module_as('brainstate.graph')
1207
1142
  def nodes(
1208
- node,
1209
- *filters: Filter,
1210
- allowed_hierarchy: Tuple[int, int] = (0, _max_int)
1211
- ) -> Union[FlattedDict[Key, Node], Tuple[FlattedDict[Key, Node], ...]]:
1143
+ node, *filters: Filter, allowed_hierarchy: Tuple[int, int] = (0, MAX_INT)
1144
+ ):
1212
1145
  """
1213
1146
  Similar to :func:`split` but only returns the :class:`NestedDict`'s indicated by the filters.
1147
+
1148
+ Parameters
1149
+ ----------
1150
+ node : Node
1151
+ The node to get nodes from.
1152
+ *filters
1153
+ Filters to apply to the nodes.
1154
+ allowed_hierarchy : tuple[int, int], optional
1155
+ The allowed hierarchy levels, by default (0, MAX_INT).
1156
+
1157
+ Returns
1158
+ -------
1159
+ FlattedDict or tuple[FlattedDict, ...]
1160
+ The filtered nodes.
1161
+
1214
1162
  """
1215
1163
  num_filters = len(filters)
1216
1164
  if num_filters == 0:
@@ -1232,31 +1180,27 @@ def _states_generator(node, allowed_hierarchy) -> Iterable[Tuple[PathParts, Stat
1232
1180
  yield path, value
1233
1181
 
1234
1182
 
1235
- @overload
1236
- def states(node, /, allowed_hierarchy=(0, _max_int)) -> FlattedDict[Key, State]:
1237
- ...
1238
-
1239
-
1240
- @overload
1241
- def states(node, first: Filter, /, allowed_hierarchy=(0, _max_int)) -> FlattedDict[Key, State]:
1242
- ...
1243
-
1244
-
1245
- @overload
1246
- def states(
1247
- node, first: Filter, second: Filter, /, *filters: Filter, allowed_hierarchy=(0, _max_int)
1248
- ) -> tuple[FlattedDict[Key, State], ...]:
1249
- ...
1250
-
1251
-
1252
1183
  @set_module_as('brainstate.graph')
1253
1184
  def states(
1254
- node,
1255
- *filters: Filter,
1256
- allowed_hierarchy: Tuple[int, int] = (0, _max_int)
1257
- ) -> Union[FlattedDict[Key, State], tuple[FlattedDict[Key, State], ...]]:
1185
+ node, *filters: Filter, allowed_hierarchy: Tuple[int, int] = (0, MAX_INT)
1186
+ ) -> Union[FlattedDict, tuple[FlattedDict, ...]]:
1258
1187
  """
1259
1188
  Similar to :func:`split` but only returns the :class:`NestedDict`'s indicated by the filters.
1189
+
1190
+ Parameters
1191
+ ----------
1192
+ node : Node
1193
+ The node to get states from.
1194
+ *filters
1195
+ Filters to apply to the states.
1196
+ allowed_hierarchy : tuple[int, int], optional
1197
+ The allowed hierarchy levels, by default (0, MAX_INT).
1198
+
1199
+ Returns
1200
+ -------
1201
+ FlattedDict or tuple[FlattedDict, ...]
1202
+ The filtered states.
1203
+
1260
1204
  """
1261
1205
  num_filters = len(filters)
1262
1206
  if num_filters == 0:
@@ -1272,72 +1216,60 @@ def states(
1272
1216
  return state_maps[:num_filters]
1273
1217
 
1274
1218
 
1275
- @overload
1276
- def treefy_states(
1277
- node, /, flatted: bool = False
1278
- ) -> NestedDict[Key, TreefyState]:
1279
- ...
1280
-
1281
-
1282
- @overload
1283
- def treefy_states(
1284
- node, first: Filter, /, flatted: bool = False
1285
- ) -> NestedDict[Key, TreefyState]:
1286
- ...
1287
-
1288
-
1289
- @overload
1290
- def treefy_states(
1291
- node, first: Filter, second: Filter, /, *filters: Filter, flatted: bool = False
1292
- ) -> Tuple[NestedDict[Key, TreefyState], ...]:
1293
- ...
1294
-
1295
-
1296
1219
  @set_module_as('brainstate.graph')
1297
1220
  def treefy_states(
1298
1221
  node, *filters,
1299
- ) -> NestedDict[Key, TreefyState] | tuple[NestedDict[Key, TreefyState], ...]:
1222
+ ):
1300
1223
  """
1301
1224
  Similar to :func:`split` but only returns the :class:`NestedDict`'s indicated by the filters.
1302
1225
 
1303
- Example usage::
1304
-
1305
- >>> import brainstate as brainstate
1306
- >>> class Model(brainstate.nn.Module):
1307
- ... def __init__(self):
1308
- ... super().__init__()
1309
- ... self.l1 = brainstate.nn.Linear(2, 3)
1310
- ... self.l2 = brainstate.nn.Linear(3, 4)
1311
- ... def __call__(self, x):
1312
- ... return self.l2(self.l1(x))
1313
-
1314
- >>> model = Model()
1315
- >>> # get the learnable parameters from the batch norm and linear layer
1316
- >>> params = brainstate.graph.treefy_states(model, brainstate.ParamState)
1317
- >>> # get them separately
1318
- >>> params, others = brainstate.graph.treefy_states(model, brainstate.ParamState, brainstate.ShortTermState)
1319
- >>> # get them together
1320
- >>> states = brainstate.graph.treefy_states(model)
1321
-
1322
- Args:
1323
- node: A graph node object.
1324
- *filters: One or more :class:`State` objects to filter by.
1325
-
1326
- Returns:
1327
- One or more :class:`NestedDict` mappings.
1226
+ Parameters
1227
+ ----------
1228
+ node : Node
1229
+ A graph node object.
1230
+ *filters
1231
+ One or more :class:`State` objects to filter by.
1232
+
1233
+ Returns
1234
+ -------
1235
+ NestedDict or tuple of NestedDict
1236
+ One or more :class:`NestedDict` mappings.
1237
+
1238
+ Examples
1239
+ --------
1240
+ .. code-block:: python
1241
+
1242
+ >>> import brainstate
1243
+ >>> class Model(brainstate.nn.Module):
1244
+ ... def __init__(self):
1245
+ ... super().__init__()
1246
+ ... self.l1 = brainstate.nn.Linear(2, 3)
1247
+ ... self.l2 = brainstate.nn.Linear(3, 4)
1248
+ ... def __call__(self, x):
1249
+ ... return self.l2(self.l1(x))
1250
+
1251
+ >>> model = Model()
1252
+ >>> # Get the learnable parameters
1253
+ >>> params = brainstate.graph.treefy_states(model, brainstate.ParamState)
1254
+ >>> # Get them separately
1255
+ >>> params, others = brainstate.graph.treefy_states(
1256
+ ... model, brainstate.ParamState, brainstate.ShortTermState
1257
+ ... )
1258
+ >>> # Get all states together
1259
+ >>> states = brainstate.graph.treefy_states(model)
1328
1260
  """
1329
1261
  _, state_mapping = flatten(node)
1330
- state_mappings: GraphStateMapping | tuple[GraphStateMapping, ...]
1331
1262
  if len(filters) == 0:
1332
- state_mappings = state_mapping
1333
- elif len(filters) == 1:
1334
- state_mappings = state_mapping.filter(filters[0])
1263
+ return state_mapping
1335
1264
  else:
1336
- state_mappings = state_mapping.filter(filters[0], filters[1], *filters[2:])
1337
- return state_mappings
1265
+ state_mappings = state_mapping.filter(*filters)
1266
+ if len(filters) == 1:
1267
+ return state_mappings[0]
1268
+ else:
1269
+ return state_mappings
1338
1270
 
1339
1271
 
1340
- def _graph_update_dynamic(node: Any, state: Mapping[Key, Any]):
1272
+ def _graph_update_dynamic(node: Any, state: Mapping) -> None:
1341
1273
  if not _is_node(node):
1342
1274
  raise RuntimeError(f'Unsupported type: {type(node)}')
1343
1275
 
@@ -1350,7 +1282,8 @@ def _graph_update_dynamic(node: Any, state: Mapping[Key, Any]):
1350
1282
  raise ValueError(f'Cannot set key {key!r} on immutable node of '
1351
1283
  f'type {type(node).__name__}')
1352
1284
  if isinstance(value, State):
1353
- value = value.copy() # TODO: chenge it to state_ref
1285
+ # TODO: here maybe error? we should check if the state already belongs to another node?
1286
+ value = value.to_state_ref() # Convert to state reference for proper state management
1354
1287
  node_impl.set_key(node, key, value)
1355
1288
  continue
1356
1289
 
@@ -1379,18 +1312,23 @@ def _graph_update_dynamic(node: Any, state: Mapping[Key, Any]):
1379
1312
 
1380
1313
 
1381
1314
  def update_states(
1382
- node: Node,
1383
- state_dict: NestedDict | FlattedDict,
1315
+ node: Any,
1316
+ state_dict: Union[NestedDict, FlattedDict],
1384
1317
  /,
1385
- *state_dicts: NestedDict | FlattedDict
1318
+ *state_dicts: Union[NestedDict, FlattedDict]
1386
1319
  ) -> None:
1387
1320
  """
1388
1321
  Update the given graph node with a new :class:`NestedMapping` in-place.
1389
1322
 
1390
- Args:
1391
- node: A graph node to update.
1392
- state_dict: A :class:`NestedMapping` object.
1393
- *state_dicts: Additional :class:`NestedMapping` objects.
1323
+ Parameters
1324
+ ----------
1325
+ node : Node
1326
+ A graph node to update.
1327
+ state_dict : NestedDict | FlattedDict
1328
+ A :class:`NestedMapping` object.
1329
+ *state_dicts : NestedDict | FlattedDict
1330
+ Additional :class:`NestedMapping` objects.
1331
+
1394
1332
  """
1395
1333
  if state_dicts:
1396
1334
  state_dict = NestedDict.merge(state_dict, *state_dicts)
@@ -1398,177 +1336,110 @@ def update_states(
1398
1336
 
1399
1337
 
1400
1338
  @set_module_as('brainstate.graph')
1401
- def graphdef(node: Any, /) -> GraphDef[Any]:
1402
- """Get the :class:`GraphDef` of the given graph node.
1339
+ def graphdef(node: Any) -> GraphDef[Any]:
1340
+ """
1341
+ Get the :class:`GraphDef` of the given graph node.
1342
+
1343
+ Parameters
1344
+ ----------
1345
+ node : Any
1346
+ A graph node object.
1403
1347
 
1404
- Example usage::
1348
+ Returns
1349
+ -------
1350
+ GraphDef[Any]
1351
+ The :class:`GraphDef` of the :class:`Module` object.
1405
1352
 
1406
- >>> import brainstate as brainstate
1353
+ Examples
1354
+ --------
1355
+ .. code-block:: python
1407
1356
 
1408
- >>> model = brainstate.nn.Linear(2, 3)
1409
- >>> graphdef, _ = brainstate.graph.treefy_split(model)
1410
- >>> assert graphdef == brainstate.graph.graphdef(model)
1357
+ >>> import brainstate
1411
1358
 
1412
- Args:
1413
- node: A graph node object.
1359
+ >>> model = brainstate.nn.Linear(2, 3)
1360
+ >>> graphdef, _ = brainstate.graph.treefy_split(model)
1361
+ >>> assert graphdef == brainstate.graph.graphdef(model)
1414
1362
 
1415
- Returns:
1416
- The :class:`GraphDef` of the :class:`Module` object.
1417
1363
  """
1418
1364
  graphdef, _ = flatten(node)
1419
1365
  return graphdef
1420
1366
 
1421
1367
 
1422
1368
  @set_module_as('brainstate.graph')
1423
- def clone(node: Node) -> Node:
1369
+ def clone(node: A) -> A:
1424
1370
  """
1425
1371
  Create a deep copy of the given graph node.
1426
1372
 
1427
- Example usage::
1373
+ Parameters
1374
+ ----------
1375
+ node : Node
1376
+ A graph node object.
1428
1377
 
1429
- >>> import brainstate as brainstate
1430
- >>> model = brainstate.nn.Linear(2, 3)
1431
- >>> cloned_model = clone(model)
1432
- >>> model.weight.value['bias'] += 1
1433
- >>> assert (model.weight.value['bias'] != cloned_model.weight.value['bias']).all()
1378
+ Returns
1379
+ -------
1380
+ Node
1381
+ A deep copy of the :class:`Module` object.
1434
1382
 
1435
- Args:
1436
- node: A graph node object.
1383
+ Examples
1384
+ --------
1385
+ .. code-block:: python
1386
+
1387
+ >>> import brainstate
1388
+ >>> model = brainstate.nn.Linear(2, 3)
1389
+ >>> cloned_model = brainstate.graph.clone(model)
1390
+ >>> model.weight.value['bias'] += 1
1391
+ >>> assert (model.weight.value['bias'] != cloned_model.weight.value['bias']).all()
1437
1392
 
1438
- Returns:
1439
- A deep copy of the :class:`Module` object.
1440
1393
  """
1441
1394
  graphdef, state = treefy_split(node)
1442
1395
  return treefy_merge(graphdef, state)
1443
1396
 
1444
1397
 
1445
- @set_module_as('brainstate.graph')
1446
- def call(
1447
- graphdef_state: Tuple[GraphDef[A], GraphStateMapping],
1448
- ) -> ApplyCaller[Tuple[GraphDef[A], GraphStateMapping]]:
1449
- """Calls a method underlying graph node defined by a (GraphDef, NestedDict) pair.
1450
-
1451
- ``call`` takes a ``(GraphDef, NestedDict)`` pair and creates a proxy object that can be
1452
- used to call methods on the underlying graph node. When a method is called, the
1453
- output is returned along with a new (GraphDef, NestedDict) pair that represents the
1454
- updated state of the graph node. ``call`` is equivalent to :func:`merge` > ``method``
1455
- > :func:`split`` but is more convenient to use in pure JAX functions.
1456
-
1457
- Example::
1458
-
1459
- >>> import brainstate as brainstate
1460
- >>> import jax
1461
- >>> import jax.numpy as jnp
1462
- ...
1463
- >>> class StatefulLinear(brainstate.graph.Node):
1464
- ... def __init__(self, din, dout):
1465
- ... self.w = brainstate.ParamState(brainstate.random.rand(din, dout))
1466
- ... self.b = brainstate.ParamState(jnp.zeros((dout,)))
1467
- ... self.count = brainstate.State(jnp.array(0, dtype=jnp.uint32))
1468
- ...
1469
- ... def increment(self):
1470
- ... self.count.value += 1
1471
- ...
1472
- ... def __call__(self, x):
1473
- ... self.increment()
1474
- ... return x @ self.w.value + self.b.value
1475
- ...
1476
- >>> linear = StatefulLinear(3, 2)
1477
- >>> linear_state = brainstate.graph.treefy_split(linear)
1478
- ...
1479
- >>> @jax.jit
1480
- ... def forward(x, linear_state):
1481
- ... y, linear_state = brainstate.graph.call(linear_state)(x)
1482
- ... return y, linear_state
1483
- ...
1484
- >>> x = jnp.ones((1, 3))
1485
- >>> y, linear_state = forward(x, linear_state)
1486
- >>> y, linear_state = forward(x, linear_state)
1487
- ...
1488
- >>> linear = brainstate.graph.treefy_merge(*linear_state)
1489
- >>> linear.count.value
1490
- Array(2, dtype=uint32)
1491
-
1492
- The proxy object returned by ``call`` supports indexing and attribute access
1493
- to access nested methods. In the example below, the ``increment`` method indexing
1494
- is used to call the ``increment`` method of the ``StatefulLinear`` module
1495
- at the ``b`` key of a ``nodes`` dictionary.
1496
-
1497
- >>> class StatefulLinear(brainstate.graph.Node):
1498
- ... def __init__(self, din, dout):
1499
- ... self.w = brainstate.ParamState(brainstate.random.rand(din, dout))
1500
- ... self.b = brainstate.ParamState(jnp.zeros((dout,)))
1501
- ... self.count = brainstate.State(jnp.array(0, dtype=jnp.uint32))
1502
- ...
1503
- ... def increment(self):
1504
- ... self.count.value += 1
1505
- ...
1506
- ... def __call__(self, x):
1507
- ... self.increment()
1508
- ... return x @ self.w.value + self.b.value
1509
- ...
1510
- >>> nodes = dict(
1511
- ... a=StatefulLinear(3, 2),
1512
- ... b=StatefulLinear(2, 1),
1513
- ... )
1514
- ...
1515
- >>> node_state = treefy_split(nodes)
1516
- >>> # use attribute access
1517
- >>> _, node_state = brainstate.graph.call(node_state)['b'].increment()
1518
- ...
1519
- >>> nodes = treefy_merge(*node_state)
1520
- >>> nodes['a'].count.value
1521
- Array(0, dtype=uint32)
1522
- >>> nodes['b'].count.value
1523
- Array(1, dtype=uint32)
1524
- """
1525
-
1526
- def pure_caller(accessor: DelayedAccessor, *args, **kwargs):
1527
- node = treefy_merge(*graphdef_state)
1528
- method = accessor(node)
1529
- out = method(*args, **kwargs)
1530
- return out, treefy_split(node)
1531
-
1532
- return CallableProxy(pure_caller) # type: ignore
1533
-
1534
-
1535
1398
  @set_module_as('brainstate.graph')
1536
1399
  def iter_leaf(
1537
- node: Any,
1538
- allowed_hierarchy: Tuple[int, int] = (0, _max_int)
1400
+ node: Any, allowed_hierarchy: Tuple[int, int] = (0, MAX_INT)
1539
1401
  ) -> Iterator[tuple[PathParts, Any]]:
1540
- """Iterates over all nested leaves in the given graph node, including the current node.
1402
+ """
1403
+ Iterates over all nested leaves in the given graph node, including the current node.
1541
1404
 
1542
1405
  ``iter_leaf`` creates a generator that yields path and value pairs, where
1543
1406
  the path is a tuple of strings or integers representing the path to the value from the
1544
1407
  root. Repeated nodes are visited only once. Leaves include static values.
1545
1408
 
1546
- Example::
1547
- >>> import brainstate as brainstate
1548
- >>> import jax.numpy as jnp
1549
- ...
1550
- >>> class Linear(brainstate.nn.Module):
1551
- ... def __init__(self, din, dout):
1552
- ... super().__init__()
1553
- ... self.weight = brainstate.ParamState(brainstate.random.randn(din, dout))
1554
- ... self.bias = brainstate.ParamState(brainstate.random.randn(dout))
1555
- ... self.a = 1
1556
- ...
1557
- >>> module = Linear(3, 4)
1558
- ...
1559
- >>> for path, value in brainstate.graph.iter_leaf([module, module]):
1560
- ... print(path, type(value).__name__)
1561
- ...
1562
- (0, 'a') int
1563
- (0, 'bias') ParamState
1564
- (0, 'weight') ParamState
1565
-
1566
1409
  Parameters
1567
1410
  ----------
1568
- node: Node
1569
- The node to iterate over.
1570
- allowed_hierarchy: tuple of int
1571
- The allowed hierarchy.
1411
+ node : Any
1412
+ The node to iterate over.
1413
+ allowed_hierarchy : tuple[int, int], optional
1414
+ The allowed hierarchy levels, by default (0, MAX_INT).
1415
+
1416
+ Yields
1417
+ ------
1418
+ Iterator[tuple[PathParts, Any]]
1419
+ Path and value pairs.
1420
+
1421
+ Examples
1422
+ --------
1423
+ .. code-block:: python
1424
+
1425
+ >>> import brainstate
1426
+ >>> import jax.numpy as jnp
1427
+
1428
+ >>> class Linear(brainstate.nn.Module):
1429
+ ... def __init__(self, din, dout):
1430
+ ... super().__init__()
1431
+ ... self.weight = brainstate.ParamState(brainstate.random.randn(din, dout))
1432
+ ... self.bias = brainstate.ParamState(brainstate.random.randn(dout))
1433
+ ... self.a = 1
1434
+ ...
1435
+ >>> module = Linear(3, 4)
1436
+ ...
1437
+ >>> for path, value in brainstate.graph.iter_leaf([module, module]):
1438
+ ... print(path, type(value).__name__)
1439
+ ...
1440
+ (0, 'a') int
1441
+ (0, 'bias') ParamState
1442
+ (0, 'weight') ParamState
1572
1443
 
1573
1444
  """
1574
1445
 
@@ -1605,8 +1476,7 @@ def iter_leaf(
1605
1476
 
1606
1477
  @set_module_as('brainstate.graph')
1607
1478
  def iter_node(
1608
- node: Any,
1609
- allowed_hierarchy: Tuple[int, int] = (0, _max_int)
1479
+ node: Any, allowed_hierarchy: Tuple[int, int] = (0, MAX_INT)
1610
1480
  ) -> Iterator[Tuple[PathParts, Any]]:
1611
1481
  """
1612
1482
  Iterates over all nested nodes of the given graph node, including the current node.
@@ -1615,39 +1485,47 @@ def iter_node(
1615
1485
  the path is a tuple of strings or integers representing the path to the value from the
1616
1486
  root. Repeated nodes are visited only once. Leaves include static values.
1617
1487
 
1618
- Example::
1619
- >>> import brainstate as brainstate
1620
- >>> import jax.numpy as jnp
1621
- ...
1622
- >>> class Model(brainstate.nn.Module):
1623
- ... def __init__(self):
1624
- ... super().__init__()
1625
- ... self.a = brainstate.nn.Linear(1, 2)
1626
- ... self.b = brainstate.nn.Linear(2, 3)
1627
- ... self.c = [brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5)]
1628
- ... self.d = {'x': brainstate.nn.Linear(5, 6), 'y': brainstate.nn.Linear(6, 7)}
1629
- ... self.b.a = brainstate.nn.LIF(2)
1630
- ...
1631
- >>> model = Model()
1632
- ...
1633
- >>> for path, node in brainstate.graph.iter_node([model, model]):
1634
- ... print(path, node.__class__.__name__)
1635
- ...
1636
- (0, 'a') Linear
1637
- (0, 'b', 'a') LIF
1638
- (0, 'b') Linear
1639
- (0, 'c', 0) Linear
1640
- (0, 'c', 1) Linear
1641
- (0, 'd', 'x') Linear
1642
- (0, 'd', 'y') Linear
1643
- (0,) Model
1644
-
1645
1488
  Parameters
1646
1489
  ----------
1647
- node: Node
1648
- The node to iterate over.
1649
- allowed_hierarchy: tuple of int
1650
- The allowed hierarchy.
1490
+ node : Any
1491
+ The node to iterate over.
1492
+ allowed_hierarchy : tuple[int, int], optional
1493
+ The allowed hierarchy levels, by default (0, MAX_INT).
1494
+
1495
+ Yields
1496
+ ------
1497
+ tuple[PathParts, Any]
1498
+ Path and node pairs.
1499
+
1500
+ Examples
1501
+ --------
1502
+ .. code-block:: python
1503
+
1504
+ >>> import brainstate
1505
+ >>> import jax.numpy as jnp
1506
+
1507
+ >>> class Model(brainstate.nn.Module):
1508
+ ... def __init__(self):
1509
+ ... super().__init__()
1510
+ ... self.a = brainstate.nn.Linear(1, 2)
1511
+ ... self.b = brainstate.nn.Linear(2, 3)
1512
+ ... self.c = [brainstate.nn.Linear(3, 4), brainstate.nn.Linear(4, 5)]
1513
+ ... self.d = {'x': brainstate.nn.Linear(5, 6), 'y': brainstate.nn.Linear(6, 7)}
1514
+ ... self.b.a = brainstate.nn.LIF(2)
1515
+ ...
1516
+ >>> model = Model()
1517
+ ...
1518
+ >>> for path, node in brainstate.graph.iter_node([model, model]):
1519
+ ... print(path, node.__class__.__name__)
1520
+ ...
1521
+ (0, 'a') Linear
1522
+ (0, 'b', 'a') LIF
1523
+ (0, 'b') Linear
1524
+ (0, 'c', 0) Linear
1525
+ (0, 'c', 1) Linear
1526
+ (0, 'd', 'x') Linear
1527
+ (0, 'd', 'y') Linear
1528
+ (0,) Model
1651
1529
 
1652
1530
  """
1653
1531
 
@@ -1686,8 +1564,16 @@ def iter_node(
1686
1564
 
1687
1565
  @dataclasses.dataclass(frozen=True)
1688
1566
  class Static(Generic[A]):
1689
- """An empty pytree node that treats its inner value as static.
1567
+ """
1568
+ An empty pytree node that treats its inner value as static.
1569
+
1690
1570
  ``value`` must define ``__eq__`` and ``__hash__``.
1571
+
1572
+ Attributes
1573
+ ----------
1574
+ value : A
1575
+ The static value to wrap.
1576
+
1691
1577
  """
1692
1578
 
1693
1579
  value: A
@@ -1721,16 +1607,16 @@ def _key_path_to_key(key: Any) -> Key:
1721
1607
  return str(key)
1722
1608
 
1723
1609
 
1724
- def _flatten_pytree(pytree: Any):
1610
+ def _flatten_pytree(pytree: Any) -> Tuple[Tuple[Tuple, ...], jax.tree_util.PyTreeDef]:
1725
1611
  leaves, treedef = jax.tree_util.tree_flatten_with_path(pytree, is_leaf=lambda x: x is not pytree)
1726
1612
  nodes = tuple((_key_path_to_key(path[0]), value) for path, value in leaves)
1727
1613
  return nodes, treedef
1728
1614
 
1729
1615
 
1730
1616
  def _unflatten_pytree(
1731
- nodes: tuple[tuple[Key, Any], ...],
1617
+ nodes: tuple[tuple, ...],
1732
1618
  treedef: jax.tree_util.PyTreeDef
1733
- ):
1619
+ ) -> Any:
1734
1620
  pytree = treedef.unflatten(value for _, value in nodes)
1735
1621
  return pytree
1736
1622