brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
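Most of the churn in this release is a package-level reorganization: the brainstate.augment and brainstate.compile sub-packages are consolidated into brainstate.transform, the init sub-package moves under brainstate.nn, and the optim, functional, and surrogate modules are no longer shipped in this wheel. Below is a minimal migration sketch; apart from the vmap import, which appears verbatim in the rendered diff further down, the public names are assumptions inferred from the file moves listed above, so verify them against the released 0.2.0 API.

    # Hypothetical migration sketch -- names inferred from the file moves above,
    # not verified against the released 0.2.0 API.

    # 0.1.10:
    #   from brainstate.augment import vmap, vmap_new_states   # augment/ -> transform/
    #   from brainstate.compile import jit                      # compile/_jit.py -> transform/_jit.py

    # 0.2.0:
    from brainstate.transform import vmap, vmap_new_states, jit

Only one file's diff is rendered below; judging from the function names it appears to be the collective-ops module in brainstate.nn.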
@@ -1,4 +1,4 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -12,71 +12,87 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
-
- from collections import namedtuple
- from typing import Callable, TypeVar, Tuple, Any, Dict
+ import warnings
+ from collections.abc import Sequence, Mapping
+ from typing import Callable, TypeVar, Any

  import jax

  from brainstate._state import catch_new_states
  from brainstate._utils import set_module_as
- from brainstate.augment import vmap, vmap_new_states
  from brainstate.graph import nodes
- from brainstate.random import set_key, split_key
+ from brainstate.transform import vmap, vmap_new_states
  from brainstate.typing import Filter
  from ._module import Module

  # the maximum order
  MAX_ORDER = 10

- # State Load Results
- StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys'])
-
  T = TypeVar('T', bound=Module)

  __all__ = [
-     'MAX_ORDER',
      'call_order',
-     'call_all_functions',
-     'vmap_call_all_functions',
+     'call_all_fns',
+     'vmap_call_all_fns',
      'init_all_states',
      'vmap_init_all_states',
      'reset_all_states',
-     'load_all_states',
-     'save_all_states',
+     'vmap_reset_all_states',
      'assign_state_values',
  ]


  @set_module_as('brainstate.nn')
- def call_order(level: int = 0, check_order_boundary: bool = True):
-     """The decorator for indicating the resetting level.
-
-     The function takes an optional integer argument level with a default value of 0.
-
-     The lower the level, the earlier the function is called.
+ def call_order(
+     level: int = 0,
+     check_order_boundary: bool = True
+ ) -> Callable[[Callable], Callable]:
+     """
+     Decorator for specifying the execution order of functions in collective operations.

-     >>> import brainstate as brainstate
-     >>> brainstate.nn.call_order(0)
-     >>> brainstate.nn.call_order(-1)
-     >>> brainstate.nn.call_order(-2)
+     This decorator attaches a `call_order` attribute to a function, which is used by
+     collective operations like `call_all_functions`, `init_all_states`, and `reset_all_states`
+     to determine the execution order. Functions with lower order levels are executed first.

      Parameters
      ----------
-     level: int
-         The call order level.
-     check_order_boundary: bool
-         Whether check the boundary of function call order. If True,
-         the order that not in [0, 10) will raise a ValueError.
+     level : int, optional
+         The execution order level. Lower values indicate earlier execution.
+         Must be in the range [0, MAX_ORDER) when `check_order_boundary` is True.
+         Default is 0.
+     check_order_boundary : bool, optional
+         Whether to validate that the order level is within the valid range [0, MAX_ORDER).
+         Default is True.

      Returns
      -------
-     The function to warp.
+     Callable[[Callable], Callable]
+         A decorator function that adds the `call_order` attribute to the decorated function.
+
+     Raises
+     ------
+     ValueError
+         If `check_order_boundary` is True and `level` is not in [0, MAX_ORDER).
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate
+         >>>
+         >>> class MyModule(brainstate.nn.Module):
+         ...     @brainstate.nn.call_order(0)
+         ...     def reset_state(self):
+         ...         print("Reset first")
+         ...
+         ...     @brainstate.nn.call_order(1)
+         ...     def another_reset(self):
+         ...         print("Reset second")
      """
      if check_order_boundary and (level < 0 or level >= MAX_ORDER):
-         raise ValueError(f'"call_order" must be an integer in [0, {MAX_ORDER}). but we got {level}.')
+         raise ValueError(f'"level" must be an integer in [0, {MAX_ORDER}), but got {level}.')

-     def wrap(fun: Callable):
+     def wrap(fun: Callable) -> Callable:
          fun.call_order = level
          return fun

@@ -84,164 +100,196 @@ def call_order(level: int = 0, check_order_boundary: bool = True):


  @set_module_as('brainstate.nn')
- def call_all_functions(
+ def call_all_fns(
      target: T,
-     fun_name: str,
-     args: Tuple[Any, ...] | Any = (),
-     kwargs: Dict[str, Any] | None = None,
+     fn_name: str,
+     args: Sequence[Any] | Any = (),
+     kwargs: Mapping[str, Any] | None = None,
      node_to_exclude: Filter = None,
-     fun_if_not_exist: str = 'raise',
+     fn_if_not_exist: str = 'raise',
  ) -> T:
      """
-     Call a specified function on all nodes of a target module, respecting call order if defined.
+     Call a specified function on all module nodes within a target, respecting call order.

-     This function iterates through all nodes of the target module, calling a specified function
-     on each node. It respects the call order of functions if defined, and provides options for
-     handling cases where the specified function does not exist on a node.
+     This function traverses all module nodes in the target and invokes the specified method
+     on each node. Functions decorated with `@call_order()` are executed in ascending order
+     of their level values, while functions without the decorator are executed first.

      Parameters
-     -----------
-     target : T
+     ----------
+     target : Module
          The target module on which to call functions.
-     fun_name : str
-         The name of the function to call on each node.
-     args : Tuple[Any, ...] | Any, optional
-         Positional arguments to pass to the called function. Default is an empty tuple.
-     kwargs : Dict[str, Any] | None, optional
-         Keyword arguments to pass to the called function. Default is None.
+     fn_name : str
+         The name of the method to call on each module node.
      node_to_exclude : Filter, optional
-         A filter function to exclude certain nodes from the function call.
-     fun_if_not_exist : str, optional
-         Specifies behavior when the function doesn't exist on a node. Options are:
-
-         - 'raise': Raise an exception (default)
-         - 'pass' or 'none': Skip the node and continue
-
-     Returns
-     --------
-     T
-         The target module after calling the specified function on all applicable nodes.
+         A filter to exclude certain nodes from the function call.
+         Can be a type, predicate function, or any filter supported by the graph API.
+     fn_if_not_exist : str, optional
+         Behavior when the specified method doesn't exist on a node:
+
+         - 'raise': Raise an AttributeError (default)
+         - 'pass' or 'none': Skip the node silently
+         - 'warn': Issue a warning and skip the node
+     args
+         Positional arguments to pass to the called method. A single non-tuple
+         argument will be automatically wrapped in a tuple. Default is ().
+     kwargs
+         Keyword arguments to pass to the called method. Default is None.

      Raises
-     -------
-     AssertionError
-         If fun_name is not a string or kwargs is not a dictionary.
+     ------
+     TypeError
+         If `fun_name` is not a string or `kwargs` is not a mapping.
      ValueError
-         If fun_if_not_exist is not one of the allowed values.
+         If `fn_if_not_exist` is not one of the allowed values.
      AttributeError
-         If the specified function doesn't exist on a node and fun_if_not_exist is 'raise'.
+         If the specified method doesn't exist on a node and `fn_if_not_exist` is 'raise'.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate
+         >>>
+         >>> net = brainstate.nn.Sequential(brainstate.nn.Linear(10, 20), brainstate.nn.ReLU())
+         >>> brainstate.nn.call_all_fns(net, 'init_state')
      """
-     assert isinstance(fun_name, str), f'fun_name must be a string, but got {fun_name}.'
+     if not isinstance(fn_name, str):
+         raise TypeError(f'fn_name must be a string, but got {type(fn_name).__name__}.')

      args = (args,) if not isinstance(args, tuple) else args
      kwargs = kwargs or {}
-     assert isinstance(kwargs, dict), f'kwargs must be a dict, but got {kwargs}.'
+     if not isinstance(kwargs, Mapping):
+         raise TypeError(f'kwargs must be a mapping, but got {type(kwargs).__name__}.')

      all_nodes = nodes(target).filter(Module)
      if node_to_exclude is not None:
          all_nodes -= all_nodes.filter(node_to_exclude)

+     # Separate nodes with and without call_order
      nodes_with_order = []
-     for node in all_nodes.values():
+     for path, node in all_nodes.items():
          try:
-             fun = getattr(node, fun_name)
+             fun = getattr(node, fn_name)
          except AttributeError as e:
-             if fun_if_not_exist == 'raise':
-                 raise
-             elif fun_if_not_exist in ('pass', 'none'):
+             if fn_if_not_exist == 'raise':
+                 raise AttributeError(
+                     f"Module {type(node).__name__} with the path {path} does not have method '{fn_name}'"
+                 ) from e
+             elif fn_if_not_exist in ('pass', 'none'):
+                 continue
+             elif fn_if_not_exist == 'warn':
+                 warnings.warn(
+                     f"Module {type(node).__name__} with the path {path} does not have method '{fn_name}'. "
+                     f"Skipping.",
+                     UserWarning
+                 )
                  continue
              else:
                  raise ValueError(
-                     f'fun_if_not_exist must be one of ["raise", "pass", "none"], but got {fun_if_not_exist}.')
+                     f"fn_if_not_exist must be one of ['raise', 'pass', 'none'], but got '{fn_if_not_exist}'."
+                 )
+
+         if not callable(fun):
+             raise TypeError(f"'{fn_name}' must be callable, but got {type(fun).__name__}.")

-         assert callable(fun), f'{fun_name} must be a callable function, but got {fun}.'
          if hasattr(fun, 'call_order'):
              nodes_with_order.append(node)
          else:
              fun(*args, **kwargs)

-     for node in sorted(nodes_with_order, key=lambda x: getattr(x, fun_name).call_order):
-         getattr(node, fun_name)(*args, **kwargs)
-
+     # Execute nodes with call_order in sorted order
+     for node in sorted(nodes_with_order, key=lambda x: getattr(x, fn_name).call_order):
+         getattr(node, fn_name)(*args, **kwargs)
      return target


- def vmap_call_all_functions(
+ def vmap_call_all_fns(
      target: T,
-     fun_name: str,
-     args: Tuple[Any, ...] | Any = (),
-     kwargs: Dict[str, Any] | None = None,
+     fn_name: str,
+     args: Sequence[Any] | Any = (),
+     kwargs: Mapping[str, Any] | None = None,
      axis_size: int = None,
      node_to_exclude: Filter = None,
-     tag: str | None = None,
-     fun_if_not_exist: str = 'raise',
+     state_tag: str | None = None,
+     fn_if_not_exist: str = 'raise',
  ) -> T:
      """
-     Apply vectorized mapping (vmap) to call a specified function on all nodes of a target module.
+     Apply vectorized mapping to call a function on all module nodes with batched state handling.

-     This function vectorizes the process of calling a specified function across multiple instances
-     of the target module, effectively batching the operation.
+     This function creates multiple batched instances by applying vmap to the specified method
+     call across all module nodes. Each batch element maintains its own random key and state
+     values. This is particularly useful for creating ensembles or batched models.

      Parameters
-     -----------
-     target : T
+     ----------
+     target : Module
          The target module on which to call functions.
-     fun_name : str
-         The name of the function to call on each node.
-     args : Tuple[Any, ...] | Any, optional
-         Positional arguments to pass to the called function. Default is an empty tuple.
-     kwargs : Dict[str, Any] | None, optional
-         Keyword arguments to pass to the called function. Default is None.
-     axis_size : int, optional
-         The size of the batch axis for vmap. Must be a positive integer.
+     fn_name : str
+         The name of the method to call on each module node.
+     args : Sequence[Any] or Any, optional
+         Positional arguments to pass to the called method. A single non-tuple
+         argument will be automatically wrapped in a tuple. Default is ().
+     kwargs : Mapping[str, Any], optional
+         Keyword arguments to pass to the called method. Default is None.
+     axis_size : int
+         The size of the batch dimension for vmap. Must be a positive integer.
      node_to_exclude : Filter, optional
-         A filter function to exclude certain nodes from the function call.
-     tag : str | None, optional
-         A tag to be used for catching new states.
-     fun_if_not_exist : str, optional
-         Specifies behavior when the function doesn't exist on a node. Options are:
+         A filter to exclude certain nodes from the function call.
+     state_tag : str, optional
+         An optional tag to categorize newly created states during the vmap operation.
+     fn_if_not_exist : str, optional
+         Behavior when the specified method doesn't exist on a node:

-         - 'raise': Raise an exception (default)
-         - 'pass' or 'none': Skip the node and continue
+         - 'raise': Raise an AttributeError (default)
+         - 'pass' or 'none': Skip the node silently
+         - 'warn': Issue a warning and skip the node

-     Returns
+     Raises
+     ------
+     ValueError
+         If `axis_size` is None or not a positive integer.
+     TypeError
+         If `kwargs` is not a mapping.
+
+     Examples
      --------
-     T
-         The target module after applying the vectorized function call on all applicable nodes.
+     .. code-block:: python

-     Raises
-     -------
-     AssertionError
-         If axis_size is not specified or is not a positive integer.
+         >>> import brainstate
+         >>>
+         >>> net = brainstate.nn.Linear(10, 20)
+         >>> # Create 5 batched instances with different initializations
+         >>> brainstate.nn.vmap_call_all_fns(net, 'init_state', axis_size=5)
      """
-     assert axis_size is not None and axis_size > 0, f"axis_size must be a positive integer, got {axis_size}"
+
+     if axis_size is None or axis_size <= 0:
+         raise ValueError(f"axis_size must be a positive integer, got {axis_size}")

      if not isinstance(args, tuple):
          args = (args,)
      kwargs = kwargs or {}
-     assert isinstance(kwargs, dict), f'kwargs must be a dict, but got {kwargs}.'
+     if not isinstance(kwargs, Mapping):
+         raise TypeError(f'kwargs must be a mapping, but got {type(kwargs).__name__}.')

-     @vmap(out_axes=0, axis_size=axis_size)
-     def vmapped_fn(key):
-         set_key(key)
-         with catch_new_states(tag) as inner_catcher:
-             call_all_functions(
+     @vmap(axis_size=axis_size)
+     def vmapped_fn():
+         with catch_new_states(state_tag) as inner_catcher:
+             call_all_fns(
                  target,
-                 fun_name=fun_name,
+                 fn_name=fn_name,
                  args=args,
                  kwargs=kwargs,
                  node_to_exclude=node_to_exclude,
-                 fun_if_not_exist=fun_if_not_exist
+                 fn_if_not_exist=fn_if_not_exist
              )
-         values = inner_catcher.get_state_values()
-         return values
+         return inner_catcher.get_state_values()

-     with catch_new_states(tag) as outer_catcher:
-         values = vmapped_fn(split_key(axis_size))
+     with catch_new_states(state_tag) as outer_catcher:
+         values = vmapped_fn()
      states = outer_catcher.get_states()
      for state, value in zip(states, values):
          state.value = value
-
      return target


@@ -253,88 +301,116 @@ def init_all_states(
      **init_kwargs,
  ) -> T:
      """
-     Initialize all states for the given target module and its submodules.
+     Initialize states for all module nodes within the target.

-     This function initializes the states of the target module and all its submodules,
-     respecting any call order decorators that may be present on the init_state methods.
+     This is a convenience wrapper around `call_all_functions` that specifically calls
+     the `init_state` method on all module nodes. The execution order respects any
+     `@call_order()` decorators on the `init_state` methods.

      Parameters
      ----------
-     target : T
+     target : Module
          The target module whose states are to be initialized.
-     init_args : Tuple[Any, ...] | Any, optional
-         Positional arguments to be passed to each init_state method.
-         If a single non-tuple argument is provided, it will be wrapped in a tuple.
-     init_kwargs : Dict[str, Any] | None, optional
-         Keyword arguments to be passed to each init_state method.
-         If None, an empty dictionary will be used.
+     *init_args
+         Variable positional arguments to pass to each `init_state` method.
      node_to_exclude : Filter, optional
-         A filter function or predicate to exclude certain nodes from initialization.
-
-     Returns
-     -------
-     T
-         The target module with all states initialized.
+         A filter to exclude certain nodes from initialization.
+         Can be a type, predicate function, or any filter supported by the graph API.
+     **init_kwargs
+         Variable keyword arguments to pass to each `init_state` method.

-     Raises
-     ------
-     AssertionError
-         If init_kwargs is provided but is not a dictionary.
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate
+         >>>
+         >>> net = brainstate.nn.Sequential(
+         ...     brainstate.nn.Linear(10, 20),
+         ...     brainstate.nn.Dropout(0.5)
+         ... )
+         >>> # Initialize all states
+         >>> brainstate.nn.init_all_states(net)
+         >>>
+         >>> # Initialize with custom arguments
+         >>> brainstate.nn.init_all_states(net, batch_size=32)
+
+     See Also
+     --------
+     call_all_functions : The underlying function that executes the calls.
+     vmap_init_all_states : Vectorized version for batched initialization.
      """
-     return call_all_functions(target, 'init_state', init_args, init_kwargs, node_to_exclude)
+     call_all_fns(target, 'init_state', init_args, init_kwargs, node_to_exclude)
+     return target


  @set_module_as('brainstate.nn')
  def vmap_init_all_states(
      target: T,
-     *init_args: Tuple[Any, ...] | Any,
+     *init_args,
      axis_size: int = None,
      node_to_exclude: Filter = None,
      state_to_exclude: Filter = None,
      state_tag: str | None = None,
-     **init_kwargs: Dict[str, Any] | None
+     **init_kwargs
  ) -> T:
      """
-     Initialize all vmap states for the given target module.
+     Initialize states with vectorized mapping for creating batched module instances.

-     This function applies vectorized mapping (vmap) to initialize states across multiple
-     instances of the target module, effectively batching the initialization process.
+     This function applies vmap to the initialization process, creating multiple batched
+     instances of module states. Each batch element will have independent state values
+     and random keys. This is useful for ensemble models or parameter sweeps.

      Parameters
-     -----------
-     target : T
+     ----------
+     target : Module
          The target module whose states are to be initialized.
-     init_args : Tuple[Any, ...] | Any, optional
-         Positional arguments to be passed to the init_all_states function. Default is an empty tuple.
-     init_kwargs : Dict[str, Any] | None, optional
-         Keyword arguments to be passed to the init_all_states function. Default is None.
-     axis_size : int, optional
-         The size of the batch axis for vmap. This must be specified and should be greater than 0.
+     *init_args
+         Variable positional arguments to pass to each `init_state` method.
+     axis_size : int
+         The size of the batch dimension. Must be a positive integer.
      node_to_exclude : Filter, optional
          A filter to exclude certain nodes from initialization.
-     state_tag : str | None, optional
-         A tag to be used for catching new states.
-
-     Returns
-     --------
-     T
-         The target module with initialized states.
+     state_to_exclude : Filter, optional
+         A filter to exclude certain states from being vmapped.
+         Excluded states will remain shared across all batched instances.
+     state_tag : str, optional
+         An optional tag to categorize newly created states.
+     **init_kwargs
+         Variable keyword arguments to pass to each `init_state` method.

      Raises
-     -------
-     AssertionError
-         If axis_size is not specified or is not greater than 0.
-         If init_kwargs is not a dictionary.
+     ------
+     ValueError
+         If `axis_size` is None or not a positive integer.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate
+         >>>
+         >>> net = brainstate.nn.Linear(10, 20)
+         >>> # Create 8 batched instances with different random initializations
+         >>> brainstate.nn.vmap_init_all_states(net, axis_size=8)
+         >>>
+         >>> # The weight parameter now has shape (8, 20, 10) instead of (20, 10)
+         >>> print(net.weight.shape)
+
+     See Also
+     --------
+     init_all_states : Non-vectorized version.
+     vmap_new_states : The underlying vmap transformation for states.
      """

-     # return vmap_call_all_functions(
+     # vmap_call_all_functions(
      #     target,
-     #     'init_state',
+     #     fun_name='init_state',
      #     args=init_args,
      #     kwargs=init_kwargs,
      #     axis_size=axis_size,
      #     node_to_exclude=node_to_exclude,
-     #     tag=tag,
+     #     state_tag=state_tag,
      # )

      def init_fn():
@@ -353,162 +429,205 @@ def vmap_init_all_states(
  @set_module_as('brainstate.nn')
  def reset_all_states(
      target: T,
-     reset_args: Tuple[Any, ...] | Any = (),
-     reset_kwargs: Dict[str, Any] | None = None,
+     *reset_args,
      node_to_exclude: Filter = None,
+     **reset_kwargs,
  ) -> T:
      """
-     Reset all states for the given target module and its submodules.
+     Reset states for all module nodes within the target.

-     This function resets the states of the target module and all its submodules,
-     respecting any call order decorators that may be present on the reset_state methods.
+     This is a convenience wrapper around `call_all_functions` that specifically calls
+     the `reset_state` method on all module nodes. The execution order respects any
+     `@call_order()` decorators on the `reset_state` methods. This is typically used
+     to reset recurrent neural network states between sequences.

      Parameters
      ----------
-     target : T
+     target : Module
          The target module whose states are to be reset.
-     reset_args : Tuple[Any, ...] | Any, optional
-         Positional arguments to be passed to each reset_state method.
-         If a single non-tuple argument is provided, it will be wrapped in a tuple.
-     reset_kwargs : Dict[str, Any] | None, optional
-         Keyword arguments to be passed to each reset_state method.
-         If None, an empty dictionary will be used.
+     reset_args
+         Positional arguments to pass to each `reset_state` method.
+         A single non-tuple argument will be automatically wrapped in a tuple.
+         Default is ().
+     reset_kwargs
+         Keyword arguments to pass to each `reset_state` method.
+         Default is None.
      node_to_exclude : Filter, optional
-         A filter function or predicate to exclude certain nodes from reset.
-
-     Returns
-     -------
-     T
-         The target module with all states reset.
+         A filter to exclude certain nodes from reset.
+         Can be a type, predicate function, or any filter supported by the graph API.

-     Raises
-     ------
-     AssertionError
-         If init_kwargs is provided but is not a dictionary.
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate
+         >>>
+         >>> rnn = brainstate.nn.RNNCell(10, 20)
+         >>> brainstate.nn.init_all_states(rnn, batch_size=32)
+         >>>
+         >>> # Process a sequence
+         >>> for x in sequence:
+         ...     output = rnn(x)
+         >>>
+         >>> # Reset states before processing next sequence
+         >>> brainstate.nn.reset_all_states(rnn)
+
+     See Also
+     --------
+     call_all_functions : The underlying function that executes the calls.
+     vmap_reset_all_states : Vectorized version for batched reset.
      """
-     return call_all_functions(
+     call_all_fns(
          target,
-         fun_name='reset_state',
+         fn_name='reset_state',
          args=reset_args,
          kwargs=reset_kwargs,
          node_to_exclude=node_to_exclude
      )
+     return target


  def vmap_reset_all_states(
      target: T,
-     reset_args: Tuple[Any, ...] | Any = (),
-     reset_kwargs: Dict[str, Any] | None = None,
+     *reset_args,
      axis_size: int = None,
      node_to_exclude: Filter = None,
-     tag: str | None = None,
+     state_tag: str | None = None,
+     **reset_kwargs,
  ) -> T:
      """
-     Reset all vmap states for the given target module.
+     Reset states with vectorized mapping across batched module instances.

-     This function applies vectorized mapping (vmap) to reset states across multiple
-     instances of the target module, effectively batching the reset process.
+     This function applies vmap to the reset process, resetting states across all
+     batched instances of the module. Each batch element will have its state reset
+     independently with its own random key. This is useful when working with batched
+     recurrent models or ensembles.

      Parameters
-     -----------
-     target : T
+     ----------
+     target : Module
          The target module whose states are to be reset.
-     reset_args : Tuple[Any, ...] | Any, optional
-         Positional arguments to be passed to the reset_all_states function. Default is an empty tuple.
-     reset_kwargs : Dict[str, Any] | None, optional
-         Keyword arguments to be passed to the reset_all_states function. Default is None.
-     axis_size : int, optional
-         The size of the batch axis for vmap. This must be specified and should be greater than 0.
+     reset_args
+         Positional arguments to pass to each `reset_state` method.
+         A single non-tuple argument will be automatically wrapped in a tuple.
+         Default is ().
+     reset_kwargs
+         Keyword arguments to pass to each `reset_state` method.
+         Default is None.
+     axis_size : int
+         The size of the batch dimension. Must be a positive integer.
      node_to_exclude : Filter, optional
          A filter to exclude certain nodes from reset.
-     tag : str | None, optional
-         A tag to be used for catching new states.
-
-     Returns
-     --------
-     T
-         The target module with reset states.
+     state_tag : str, optional
+         An optional tag to categorize newly created states during the reset.

      Raises
-     -------
-     AssertionError
-         If axis_size is not specified or is not greater than 0.
-         If reset_kwargs is not a dictionary.
+     ------
+     ValueError
+         If `axis_size` is None or not a positive integer.
+     TypeError
+         If `reset_kwargs` is not a mapping.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate
+         >>>
+         >>> rnn = brainstate.nn.RNNCell(10, 20)
+         >>> # Initialize with 16 batched instances
+         >>> brainstate.nn.vmap_init_all_states(rnn, batch_size=32, axis_size=16)
+         >>>
+         >>> # Process sequences...
+         >>>
+         >>> # Reset all 16 batched instances
+         >>> brainstate.nn.vmap_reset_all_states(rnn, axis_size=16)
+
+     See Also
+     --------
+     reset_all_states : Non-vectorized version.
+     vmap_call_all_functions : The underlying vmap function call mechanism.
      """
-     return vmap_call_all_functions(
+     vmap_call_all_fns(
          target,
-         fun_name='reset_state',
+         fn_name='reset_state',
          args=reset_args,
          kwargs=reset_kwargs,
          axis_size=axis_size,
          node_to_exclude=node_to_exclude,
-         tag=tag,
+         state_tag=state_tag,
      )
+     return target


  @set_module_as('brainstate.nn')
- def load_all_states(target: Module, state_dict: Dict, **kwargs):
-     """
-     Copy parameters and buffers from :attr:`state_dict` into
-     this module and its descendants.
-
-     Args:
-       target: Module. The dynamical system to load its states.
-       state_dict: dict. A dict containing parameters and persistent buffers.
-
-     Returns
-     -------
-     ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
-
-     * **missing_keys** is a list of str containing the missing keys
-     * **unexpected_keys** is a list of str containing the unexpected keys
+ def assign_state_values(
+     target: Module,
+     *state_by_abs_path: Mapping[str, Any]
+ ) -> tuple[list[str], list[str]]:
      """
-     missing_keys = []
-     unexpected_keys = []
-     for path, node in nodes(target).items():
-         r = node.load_state(state_dict[path], **kwargs)
-         if r is not None:
-             missing, unexpected = r
-             missing_keys.extend([f'{path}.{key}' for key in missing])
-             unexpected_keys.extend([f'{path}.{key}' for key in unexpected])
-     return StateLoadResult(missing_keys, unexpected_keys)
+     Assign state values to a module from one or more state dictionaries.

+     This function updates the state values of a module based on provided state dictionaries.
+     State dictionaries should use absolute paths as keys (e.g., 'layer1.weight', 'layer2.bias').
+     The function handles missing and unexpected keys, returning them for inspection.

- @set_module_as('brainstate.nn')
- def save_all_states(target: Module, **kwargs) -> Dict:
-     """
-     Save all states in the ``target`` as a dictionary for later disk serialization.
-
-     Args:
-       target: Module. The node to save its states.
+     Parameters
+     ----------
+     target : Module
+         The target module whose states will be updated.
+     *state_by_abs_path : Mapping[str, Any]
+         One or more state dictionaries with absolute path keys mapping to state values.
+         If multiple dictionaries are provided, they will be merged (later dictionaries
+         override earlier ones for duplicate keys).

      Returns
-       Dict. The state dict for serialization.
-     """
-     return {key: node.save_state(**kwargs) for key, node in target.nodes().items()}
+     -------
+     tuple[list[str], list[str]]
+         A tuple of (unexpected_keys, missing_keys):

+         - unexpected_keys: Keys present in the state dictionaries but not in the module
+         - missing_keys: Keys present in the module but not in the state dictionaries

- @set_module_as('brainstate.nn')
- def assign_state_values(target: Module, *state_by_abs_path: Dict):
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate
+         >>>
+         >>> net = brainstate.nn.Linear(10, 20)
+         >>> brainstate.nn.init_all_states(net)
+         >>>
+         >>> # Save state values
+         >>> state_dict = {path: state.value for path, state in net.states().items()}
+         >>>
+         >>> # Later, restore state values
+         >>> unexpected, missing = brainstate.nn.assign_state_values(net, state_dict)
+         >>> print(f"Unexpected keys: {unexpected}")
+         >>> print(f"Missing keys: {missing}")
+
+     Notes
+     -----
+     - All values are automatically converted to JAX arrays using `jax.numpy.asarray`.
+     - Only states with matching keys are updated; unexpected and missing keys are
+       returned but do not cause errors.
+     - If multiple dictionaries contain the same key, the last one takes precedence.
      """
-     Assign state values according to the given state dictionary.
+     # Merge all state dictionaries
+     all_states = {}
+     for state_dict in state_by_abs_path:
+         all_states.update(state_dict)

-     Parameters
-     ----------
-     target: Module
-       The target module.
-     state_by_abs_path: dict
-       The state dictionary which is accessed by the "absolute" accessing method.
-
-     """
-     all_states = dict()
-     for state in state_by_abs_path:
-         all_states.update(state)
+     # Get current module states
      variables = target.states()
      keys1 = set(all_states.keys())
      keys2 = set(variables.keys())
+
+     # Update matching states
      for key in keys2.intersection(keys1):
          variables[key].value = jax.numpy.asarray(all_states[key])
-     unexpected_keys = list(keys1 - keys2)
-     missing_keys = list(keys2 - keys1)
+
+     # Return mismatched keys
+     unexpected_keys = sorted(keys1 - keys2)
+     missing_keys = sorted(keys2 - keys1)
      return unexpected_keys, missing_keys
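
The diff above renames call_all_functions/vmap_call_all_functions to call_all_fns/vmap_call_all_fns, drops load_all_states and save_all_states, and keeps assign_state_values, which now returns sorted (unexpected_keys, missing_keys). The sketch below shows the replacement save/restore round trip; it is assembled from the docstring examples in the diff (the Linear layer and the states() accessor are taken from those examples, not independently verified):

    import brainstate

    net = brainstate.nn.Linear(10, 20)   # example layer, as used in the docstrings above
    brainstate.nn.init_all_states(net)   # respects @call_order on init_state methods

    # save_all_states/load_all_states are gone in 0.2.0; a plain dict of state
    # values keyed by absolute path covers the same round trip.
    snapshot = {path: state.value for path, state in net.states().items()}

    # ... train or otherwise mutate the states ...

    unexpected, missing = brainstate.nn.assign_state_values(net, snapshot)
    assert not unexpected and not missing   # keys should match exactly here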