brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +169 -58
  2. brainstate/_compatible_import.py +340 -148
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +45 -55
  7. brainstate/_state.py +1652 -1605
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -563
  11. brainstate/environ_test.py +1223 -62
  12. brainstate/graph/__init__.py +22 -29
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1433 -365
  18. brainstate/mixin_test.py +1017 -77
  19. brainstate/nn/__init__.py +137 -135
  20. brainstate/nn/_activations.py +1100 -808
  21. brainstate/nn/_activations_test.py +354 -331
  22. brainstate/nn/_collective_ops.py +633 -514
  23. brainstate/nn/_collective_ops_test.py +774 -43
  24. brainstate/nn/_common.py +226 -178
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +2010 -501
  27. brainstate/nn/_conv_test.py +849 -238
  28. brainstate/nn/_delay.py +575 -588
  29. brainstate/nn/_delay_test.py +243 -238
  30. brainstate/nn/_dropout.py +618 -426
  31. brainstate/nn/_dropout_test.py +477 -100
  32. brainstate/nn/_dynamics.py +1267 -1343
  33. brainstate/nn/_dynamics_test.py +67 -78
  34. brainstate/nn/_elementwise.py +1298 -1119
  35. brainstate/nn/_elementwise_test.py +830 -169
  36. brainstate/nn/_embedding.py +408 -58
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
  42. brainstate/nn/_exp_euler.py +254 -92
  43. brainstate/nn/_exp_euler_test.py +377 -35
  44. brainstate/nn/_linear.py +744 -424
  45. brainstate/nn/_linear_test.py +475 -107
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +384 -377
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -975
  51. brainstate/nn/_normalizations_test.py +699 -73
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +2239 -1177
  55. brainstate/nn/_poolings_test.py +953 -217
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +216 -89
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +809 -553
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
  62. brainstate/random/__init__.py +270 -24
  63. brainstate/random/_rand_funs.py +3938 -3616
  64. brainstate/random/_rand_funs_test.py +640 -567
  65. brainstate/random/_rand_seed.py +675 -210
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1409
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
  72. brainstate/{augment → transform}/_autograd.py +1025 -778
  73. brainstate/{augment → transform}/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +220 -220
  76. brainstate/{compile → transform}/_error_if.py +94 -92
  77. brainstate/{compile → transform}/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +38 -38
  80. brainstate/{compile → transform}/_jit.py +399 -346
  81. brainstate/{compile → transform}/_jit_test.py +143 -143
  82. brainstate/{compile → transform}/_loop_collect_return.py +675 -536
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
  84. brainstate/{compile → transform}/_loop_no_collection.py +283 -184
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +255 -202
  91. brainstate/{augment → transform}/_random.py +171 -151
  92. brainstate/{compile → transform}/_unvmap.py +256 -159
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +837 -304
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +27 -50
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +945 -469
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +910 -523
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/augment/__init__.py +0 -30
  111. brainstate/augment/_eval_shape.py +0 -99
  112. brainstate/augment/_mapping.py +0 -1060
  113. brainstate/augment/_mapping_test.py +0 -597
  114. brainstate/compile/__init__.py +0 -38
  115. brainstate/compile/_ad_checkpoint.py +0 -204
  116. brainstate/compile/_conditions.py +0 -256
  117. brainstate/compile/_make_jaxpr.py +0 -888
  118. brainstate/compile/_make_jaxpr_test.py +0 -156
  119. brainstate/compile/_util.py +0 -147
  120. brainstate/functional/__init__.py +0 -27
  121. brainstate/graph/_graph_node.py +0 -244
  122. brainstate/graph/_graph_node_test.py +0 -73
  123. brainstate/graph/_graph_operation_test.py +0 -563
  124. brainstate/init/__init__.py +0 -26
  125. brainstate/init/_base.py +0 -52
  126. brainstate/init/_generic.py +0 -244
  127. brainstate/init/_regular_inits.py +0 -105
  128. brainstate/init/_regular_inits_test.py +0 -50
  129. brainstate/nn/_inputs.py +0 -608
  130. brainstate/nn/_ltp.py +0 -28
  131. brainstate/nn/_neuron.py +0 -705
  132. brainstate/nn/_neuron_test.py +0 -161
  133. brainstate/nn/_others.py +0 -46
  134. brainstate/nn/_projection.py +0 -486
  135. brainstate/nn/_rate_rnns_test.py +0 -63
  136. brainstate/nn/_readout.py +0 -209
  137. brainstate/nn/_readout_test.py +0 -53
  138. brainstate/nn/_stp.py +0 -236
  139. brainstate/nn/_synapse.py +0 -505
  140. brainstate/nn/_synapse_test.py +0 -131
  141. brainstate/nn/_synaptic_projection.py +0 -423
  142. brainstate/nn/_synouts.py +0 -162
  143. brainstate/nn/_synouts_test.py +0 -57
  144. brainstate/nn/metrics.py +0 -388
  145. brainstate/optim/__init__.py +0 -38
  146. brainstate/optim/_base.py +0 -64
  147. brainstate/optim/_lr_scheduler.py +0 -448
  148. brainstate/optim/_lr_scheduler_test.py +0 -50
  149. brainstate/optim/_optax_optimizer.py +0 -152
  150. brainstate/optim/_optax_optimizer_test.py +0 -53
  151. brainstate/optim/_sgd_optimizer.py +0 -1104
  152. brainstate/random/_random_for_unit.py +0 -52
  153. brainstate/surrogate.py +0 -1957
  154. brainstate/transform.py +0 -23
  155. brainstate/util/caller.py +0 -98
  156. brainstate/util/others.py +0 -540
  157. brainstate/util/pretty_pytree.py +0 -945
  158. brainstate/util/pretty_pytree_test.py +0 -159
  159. brainstate/util/pretty_table.py +0 -2954
  160. brainstate/util/scaling.py +0 -258
  161. brainstate-0.1.10.dist-info/RECORD +0 -130
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
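The only hunk reproduced below is the rewrite of brainstate/nn/_collective_ops.py (its -514/+633 line counts match item 22 in the list above). The rename entries in the file list (brainstate/{augment → transform}/..., brainstate/{compile → transform}/..., brainstate/{init/_random_inits.py → nn/init.py}) also imply that the `augment` and `compile` subpackages were merged into a single `brainstate.transform` namespace and that the initializers moved under `brainstate.nn`. The sketch below shows how downstream imports would likely change; only `vmap` and `vmap_new_states` are confirmed by the diff itself, the other re-export paths are assumptions to verify against the 0.2.1 package.

# brainstate 0.1.10
# from brainstate.augment import vmap, vmap_new_states   # subpackage removed in 0.2.1
# import brainstate.init as init                         # subpackage removed in 0.2.1

# brainstate 0.2.1
from brainstate.transform import vmap, vmap_new_states   # augment/compile consolidated into transform
import brainstate.nn.init as init                        # initializers relocated under brainstate.nn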
@@ -1,514 +1,633 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- from collections import namedtuple
- from typing import Callable, TypeVar, Tuple, Any, Dict
-
- import jax
-
- from brainstate._state import catch_new_states
- from brainstate._utils import set_module_as
- from brainstate.augment import vmap, vmap_new_states
- from brainstate.graph import nodes
- from brainstate.random import set_key, split_key
- from brainstate.typing import Filter
- from ._module import Module
-
- # the maximum order
- MAX_ORDER = 10
-
- # State Load Results
- StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys'])
-
- T = TypeVar('T', bound=Module)
-
- __all__ = [
- 'MAX_ORDER',
- 'call_order',
- 'call_all_functions',
- 'vmap_call_all_functions',
- 'init_all_states',
- 'vmap_init_all_states',
- 'reset_all_states',
- 'load_all_states',
- 'save_all_states',
- 'assign_state_values',
- ]
-
-
- @set_module_as('brainstate.nn')
- def call_order(level: int = 0, check_order_boundary: bool = True):
- """The decorator for indicating the resetting level.
-
- The function takes an optional integer argument level with a default value of 0.
-
- The lower the level, the earlier the function is called.
-
- >>> import brainstate as brainstate
- >>> brainstate.nn.call_order(0)
- >>> brainstate.nn.call_order(-1)
- >>> brainstate.nn.call_order(-2)
-
- Parameters
- ----------
- level: int
- The call order level.
- check_order_boundary: bool
- Whether check the boundary of function call order. If True,
- the order that not in [0, 10) will raise a ValueError.
-
- Returns
- -------
- The function to warp.
- """
- if check_order_boundary and (level < 0 or level >= MAX_ORDER):
- raise ValueError(f'"call_order" must be an integer in [0, {MAX_ORDER}). but we got {level}.')
-
- def wrap(fun: Callable):
- fun.call_order = level
- return fun
-
- return wrap
-
-
- @set_module_as('brainstate.nn')
- def call_all_functions(
- target: T,
- fun_name: str,
- args: Tuple[Any, ...] | Any = (),
- kwargs: Dict[str, Any] | None = None,
- node_to_exclude: Filter = None,
- fun_if_not_exist: str = 'raise',
- ) -> T:
- """
- Call a specified function on all nodes of a target module, respecting call order if defined.
-
- This function iterates through all nodes of the target module, calling a specified function
- on each node. It respects the call order of functions if defined, and provides options for
- handling cases where the specified function does not exist on a node.
-
- Parameters
- -----------
- target : T
- The target module on which to call functions.
- fun_name : str
- The name of the function to call on each node.
- args : Tuple[Any, ...] | Any, optional
- Positional arguments to pass to the called function. Default is an empty tuple.
- kwargs : Dict[str, Any] | None, optional
- Keyword arguments to pass to the called function. Default is None.
- node_to_exclude : Filter, optional
- A filter function to exclude certain nodes from the function call.
- fun_if_not_exist : str, optional
- Specifies behavior when the function doesn't exist on a node. Options are:
-
- - 'raise': Raise an exception (default)
- - 'pass' or 'none': Skip the node and continue
-
- Returns
- --------
- T
- The target module after calling the specified function on all applicable nodes.
-
- Raises
- -------
- AssertionError
- If fun_name is not a string or kwargs is not a dictionary.
- ValueError
- If fun_if_not_exist is not one of the allowed values.
- AttributeError
- If the specified function doesn't exist on a node and fun_if_not_exist is 'raise'.
- """
- assert isinstance(fun_name, str), f'fun_name must be a string, but got {fun_name}.'
-
- args = (args,) if not isinstance(args, tuple) else args
- kwargs = kwargs or {}
- assert isinstance(kwargs, dict), f'kwargs must be a dict, but got {kwargs}.'
-
- all_nodes = nodes(target).filter(Module)
- if node_to_exclude is not None:
- all_nodes -= all_nodes.filter(node_to_exclude)
-
- nodes_with_order = []
- for node in all_nodes.values():
- try:
- fun = getattr(node, fun_name)
- except AttributeError as e:
- if fun_if_not_exist == 'raise':
- raise
- elif fun_if_not_exist in ('pass', 'none'):
- continue
- else:
- raise ValueError(
- f'fun_if_not_exist must be one of ["raise", "pass", "none"], but got {fun_if_not_exist}.')
-
- assert callable(fun), f'{fun_name} must be a callable function, but got {fun}.'
- if hasattr(fun, 'call_order'):
- nodes_with_order.append(node)
- else:
- fun(*args, **kwargs)
-
- for node in sorted(nodes_with_order, key=lambda x: getattr(x, fun_name).call_order):
- getattr(node, fun_name)(*args, **kwargs)
-
- return target
-
-
- def vmap_call_all_functions(
- target: T,
- fun_name: str,
- args: Tuple[Any, ...] | Any = (),
- kwargs: Dict[str, Any] | None = None,
- axis_size: int = None,
- node_to_exclude: Filter = None,
- tag: str | None = None,
- fun_if_not_exist: str = 'raise',
- ) -> T:
- """
- Apply vectorized mapping (vmap) to call a specified function on all nodes of a target module.
-
- This function vectorizes the process of calling a specified function across multiple instances
- of the target module, effectively batching the operation.
-
- Parameters
- -----------
- target : T
- The target module on which to call functions.
- fun_name : str
- The name of the function to call on each node.
- args : Tuple[Any, ...] | Any, optional
- Positional arguments to pass to the called function. Default is an empty tuple.
- kwargs : Dict[str, Any] | None, optional
- Keyword arguments to pass to the called function. Default is None.
- axis_size : int, optional
- The size of the batch axis for vmap. Must be a positive integer.
- node_to_exclude : Filter, optional
- A filter function to exclude certain nodes from the function call.
- tag : str | None, optional
- A tag to be used for catching new states.
- fun_if_not_exist : str, optional
- Specifies behavior when the function doesn't exist on a node. Options are:
-
- - 'raise': Raise an exception (default)
- - 'pass' or 'none': Skip the node and continue
-
- Returns
- --------
- T
- The target module after applying the vectorized function call on all applicable nodes.
-
- Raises
- -------
- AssertionError
- If axis_size is not specified or is not a positive integer.
- """
- assert axis_size is not None and axis_size > 0, f"axis_size must be a positive integer, got {axis_size}"
-
- if not isinstance(args, tuple):
- args = (args,)
- kwargs = kwargs or {}
- assert isinstance(kwargs, dict), f'kwargs must be a dict, but got {kwargs}.'
-
- @vmap(out_axes=0, axis_size=axis_size)
- def vmapped_fn(key):
- set_key(key)
- with catch_new_states(tag) as inner_catcher:
- call_all_functions(
- target,
- fun_name=fun_name,
- args=args,
- kwargs=kwargs,
- node_to_exclude=node_to_exclude,
- fun_if_not_exist=fun_if_not_exist
- )
- values = inner_catcher.get_state_values()
- return values
-
- with catch_new_states(tag) as outer_catcher:
- values = vmapped_fn(split_key(axis_size))
- states = outer_catcher.get_states()
- for state, value in zip(states, values):
- state.value = value
-
- return target
-
-
- @set_module_as('brainstate.nn')
- def init_all_states(
- target: T,
- *init_args,
- node_to_exclude: Filter = None,
- **init_kwargs,
- ) -> T:
- """
- Initialize all states for the given target module and its submodules.
-
- This function initializes the states of the target module and all its submodules,
- respecting any call order decorators that may be present on the init_state methods.
-
- Parameters
- ----------
- target : T
- The target module whose states are to be initialized.
- init_args : Tuple[Any, ...] | Any, optional
- Positional arguments to be passed to each init_state method.
- If a single non-tuple argument is provided, it will be wrapped in a tuple.
- init_kwargs : Dict[str, Any] | None, optional
- Keyword arguments to be passed to each init_state method.
- If None, an empty dictionary will be used.
- node_to_exclude : Filter, optional
- A filter function or predicate to exclude certain nodes from initialization.
-
- Returns
- -------
- T
- The target module with all states initialized.
-
- Raises
- ------
- AssertionError
- If init_kwargs is provided but is not a dictionary.
- """
- return call_all_functions(target, 'init_state', init_args, init_kwargs, node_to_exclude)
-
-
- @set_module_as('brainstate.nn')
- def vmap_init_all_states(
- target: T,
- *init_args: Tuple[Any, ...] | Any,
- axis_size: int = None,
- node_to_exclude: Filter = None,
- state_to_exclude: Filter = None,
- state_tag: str | None = None,
- **init_kwargs: Dict[str, Any] | None
- ) -> T:
- """
- Initialize all vmap states for the given target module.
-
- This function applies vectorized mapping (vmap) to initialize states across multiple
- instances of the target module, effectively batching the initialization process.
-
- Parameters
- -----------
- target : T
- The target module whose states are to be initialized.
- init_args : Tuple[Any, ...] | Any, optional
- Positional arguments to be passed to the init_all_states function. Default is an empty tuple.
- init_kwargs : Dict[str, Any] | None, optional
- Keyword arguments to be passed to the init_all_states function. Default is None.
- axis_size : int, optional
- The size of the batch axis for vmap. This must be specified and should be greater than 0.
- node_to_exclude : Filter, optional
- A filter to exclude certain nodes from initialization.
- state_tag : str | None, optional
- A tag to be used for catching new states.
-
- Returns
- --------
- T
- The target module with initialized states.
-
- Raises
- -------
- AssertionError
- If axis_size is not specified or is not greater than 0.
- If init_kwargs is not a dictionary.
- """
-
- # return vmap_call_all_functions(
- # target,
- # 'init_state',
- # args=init_args,
- # kwargs=init_kwargs,
- # axis_size=axis_size,
- # node_to_exclude=node_to_exclude,
- # tag=tag,
- # )
-
- def init_fn():
- init_all_states(
- target,
- *init_args,
- **init_kwargs,
- node_to_exclude=node_to_exclude,
- )
- return
-
- vmap_new_states(init_fn, state_tag=state_tag, axis_size=axis_size, state_to_exclude=state_to_exclude)()
- return target
-
-
- @set_module_as('brainstate.nn')
- def reset_all_states(
- target: T,
- reset_args: Tuple[Any, ...] | Any = (),
- reset_kwargs: Dict[str, Any] | None = None,
- node_to_exclude: Filter = None,
- ) -> T:
- """
- Reset all states for the given target module and its submodules.
-
- This function resets the states of the target module and all its submodules,
- respecting any call order decorators that may be present on the reset_state methods.
-
- Parameters
- ----------
- target : T
- The target module whose states are to be reset.
- reset_args : Tuple[Any, ...] | Any, optional
- Positional arguments to be passed to each reset_state method.
- If a single non-tuple argument is provided, it will be wrapped in a tuple.
- reset_kwargs : Dict[str, Any] | None, optional
- Keyword arguments to be passed to each reset_state method.
- If None, an empty dictionary will be used.
- node_to_exclude : Filter, optional
- A filter function or predicate to exclude certain nodes from reset.
-
- Returns
- -------
- T
- The target module with all states reset.
-
- Raises
- ------
- AssertionError
- If init_kwargs is provided but is not a dictionary.
- """
- return call_all_functions(
- target,
- fun_name='reset_state',
- args=reset_args,
- kwargs=reset_kwargs,
- node_to_exclude=node_to_exclude
- )
-
-
- def vmap_reset_all_states(
- target: T,
- reset_args: Tuple[Any, ...] | Any = (),
- reset_kwargs: Dict[str, Any] | None = None,
- axis_size: int = None,
- node_to_exclude: Filter = None,
- tag: str | None = None,
- ) -> T:
- """
- Reset all vmap states for the given target module.
-
- This function applies vectorized mapping (vmap) to reset states across multiple
- instances of the target module, effectively batching the reset process.
-
- Parameters
- -----------
- target : T
- The target module whose states are to be reset.
- reset_args : Tuple[Any, ...] | Any, optional
- Positional arguments to be passed to the reset_all_states function. Default is an empty tuple.
- reset_kwargs : Dict[str, Any] | None, optional
- Keyword arguments to be passed to the reset_all_states function. Default is None.
- axis_size : int, optional
- The size of the batch axis for vmap. This must be specified and should be greater than 0.
- node_to_exclude : Filter, optional
- A filter to exclude certain nodes from reset.
- tag : str | None, optional
- A tag to be used for catching new states.
-
- Returns
- --------
- T
- The target module with reset states.
-
- Raises
- -------
- AssertionError
- If axis_size is not specified or is not greater than 0.
- If reset_kwargs is not a dictionary.
- """
- return vmap_call_all_functions(
- target,
- fun_name='reset_state',
- args=reset_args,
- kwargs=reset_kwargs,
- axis_size=axis_size,
- node_to_exclude=node_to_exclude,
- tag=tag,
- )
-
-
- @set_module_as('brainstate.nn')
- def load_all_states(target: Module, state_dict: Dict, **kwargs):
- """
- Copy parameters and buffers from :attr:`state_dict` into
- this module and its descendants.
-
- Args:
- target: Module. The dynamical system to load its states.
- state_dict: dict. A dict containing parameters and persistent buffers.
-
- Returns
- -------
- ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
-
- * **missing_keys** is a list of str containing the missing keys
- * **unexpected_keys** is a list of str containing the unexpected keys
- """
- missing_keys = []
- unexpected_keys = []
- for path, node in nodes(target).items():
- r = node.load_state(state_dict[path], **kwargs)
- if r is not None:
- missing, unexpected = r
- missing_keys.extend([f'{path}.{key}' for key in missing])
- unexpected_keys.extend([f'{path}.{key}' for key in unexpected])
- return StateLoadResult(missing_keys, unexpected_keys)
-
-
- @set_module_as('brainstate.nn')
- def save_all_states(target: Module, **kwargs) -> Dict:
- """
- Save all states in the ``target`` as a dictionary for later disk serialization.
-
- Args:
- target: Module. The node to save its states.
-
- Returns
- Dict. The state dict for serialization.
- """
- return {key: node.save_state(**kwargs) for key, node in target.nodes().items()}
-
-
- @set_module_as('brainstate.nn')
- def assign_state_values(target: Module, *state_by_abs_path: Dict):
- """
- Assign state values according to the given state dictionary.
-
- Parameters
- ----------
- target: Module
- The target module.
- state_by_abs_path: dict
- The state dictionary which is accessed by the "absolute" accessing method.
-
- """
- all_states = dict()
- for state in state_by_abs_path:
- all_states.update(state)
- variables = target.states()
- keys1 = set(all_states.keys())
- keys2 = set(variables.keys())
- for key in keys2.intersection(keys1):
- variables[key].value = jax.numpy.asarray(all_states[key])
- unexpected_keys = list(keys1 - keys2)
- missing_keys = list(keys2 - keys1)
- return unexpected_keys, missing_keys
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ import warnings
+ from collections.abc import Sequence, Mapping
+ from typing import Callable, TypeVar, Any
+
+ import jax
+
+ from brainstate._state import catch_new_states
+ from brainstate._utils import set_module_as
+ from brainstate.graph import nodes
+ from brainstate.transform import vmap, vmap_new_states
+ from brainstate.typing import Filter
+ from ._module import Module
+
+ # the maximum order
+ MAX_ORDER = 10
+
+ T = TypeVar('T', bound=Module)
+
+ __all__ = [
+ 'call_order',
+ 'call_all_fns',
+ 'vmap_call_all_fns',
+ 'init_all_states',
+ 'vmap_init_all_states',
+ 'reset_all_states',
+ 'vmap_reset_all_states',
+ 'assign_state_values',
+ ]
+
+
+ @set_module_as('brainstate.nn')
+ def call_order(
+ level: int = 0,
+ check_order_boundary: bool = True
+ ) -> Callable[[Callable], Callable]:
+ """
+ Decorator for specifying the execution order of functions in collective operations.
+
+ This decorator attaches a `call_order` attribute to a function, which is used by
+ collective operations like `call_all_fns`, `init_all_states`, and `reset_all_states`
+ to determine the execution order. Functions with lower order levels are executed first.
+
+ Parameters
+ ----------
+ level : int, optional
+ The execution order level. Lower values indicate earlier execution.
+ Must be in the range [0, MAX_ORDER) when `check_order_boundary` is True.
+ Default is 0.
+ check_order_boundary : bool, optional
+ Whether to validate that the order level is within the valid range [0, MAX_ORDER).
+ Default is True.
+
+ Returns
+ -------
+ Callable[[Callable], Callable]
+ A decorator function that adds the `call_order` attribute to the decorated function.
+
+ Raises
+ ------
+ ValueError
+ If `check_order_boundary` is True and `level` is not in [0, MAX_ORDER).
+
+ Examples
+ --------
+ .. code-block:: python
+
+ >>> import brainstate
+ >>>
+ >>> class MyModule(brainstate.nn.Module):
+ ... @brainstate.nn.call_order(0)
+ ... def reset_state(self):
+ ... print("Reset first")
+ ...
+ ... @brainstate.nn.call_order(1)
+ ... def another_reset(self):
+ ... print("Reset second")
+ """
+ if check_order_boundary and (level < 0 or level >= MAX_ORDER):
+ raise ValueError(f'"level" must be an integer in [0, {MAX_ORDER}), but got {level}.')
+
+ def wrap(fun: Callable) -> Callable:
+ fun.call_order = level
+ return fun
+
+ return wrap
+
+
+ @set_module_as('brainstate.nn')
+ def call_all_fns(
+ target: T,
+ fn_name: str,
+ args: Sequence[Any] | Any = (),
+ kwargs: Mapping[str, Any] | None = None,
+ node_to_exclude: Filter = None,
+ fn_if_not_exist: str = 'raise',
+ ) -> T:
+ """
+ Call a specified function on all module nodes within a target, respecting call order.
+
+ This function traverses all module nodes in the target and invokes the specified method
+ on each node. Functions decorated with `@call_order()` are executed in ascending order
+ of their level values, while functions without the decorator are executed first.
+
+ Parameters
+ ----------
+ target : Module
+ The target module on which to call functions.
+ fn_name : str
+ The name of the method to call on each module node.
+ node_to_exclude : Filter, optional
+ A filter to exclude certain nodes from the function call.
+ Can be a type, predicate function, or any filter supported by the graph API.
+ fn_if_not_exist : str, optional
+ Behavior when the specified method doesn't exist on a node:
+
+ - 'raise': Raise an AttributeError (default)
+ - 'pass' or 'none': Skip the node silently
+ - 'warn': Issue a warning and skip the node
+ args
+ Positional arguments to pass to the called method. A single non-tuple
+ argument will be automatically wrapped in a tuple. Default is ().
+ kwargs
+ Keyword arguments to pass to the called method. Default is None.
+
+ Raises
+ ------
+ TypeError
+ If `fn_name` is not a string or `kwargs` is not a mapping.
+ ValueError
+ If `fn_if_not_exist` is not one of the allowed values.
+ AttributeError
+ If the specified method doesn't exist on a node and `fn_if_not_exist` is 'raise'.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ >>> import brainstate
+ >>>
+ >>> net = brainstate.nn.Sequential(brainstate.nn.Linear(10, 20), brainstate.nn.ReLU())
+ >>> brainstate.nn.call_all_fns(net, 'init_state')
+ """
+ if not isinstance(fn_name, str):
+ raise TypeError(f'fn_name must be a string, but got {type(fn_name).__name__}.')
+
+ args = (args,) if not isinstance(args, tuple) else args
+ kwargs = kwargs or {}
+ if not isinstance(kwargs, Mapping):
+ raise TypeError(f'kwargs must be a mapping, but got {type(kwargs).__name__}.')
+
+ all_nodes = nodes(target).filter(Module)
+ if node_to_exclude is not None:
+ all_nodes -= all_nodes.filter(node_to_exclude)
+
+ # Separate nodes with and without call_order
+ nodes_with_order = []
+ for path, node in all_nodes.items():
+ try:
+ fun = getattr(node, fn_name)
+ except AttributeError as e:
+ if fn_if_not_exist == 'raise':
+ raise AttributeError(
+ f"Module {type(node).__name__} with the path {path} does not have method '{fn_name}'"
+ ) from e
+ elif fn_if_not_exist in ('pass', 'none'):
+ continue
+ elif fn_if_not_exist == 'warn':
+ warnings.warn(
+ f"Module {type(node).__name__} with the path {path} does not have method '{fn_name}'. "
+ f"Skipping.",
+ UserWarning
+ )
+ continue
+ else:
+ raise ValueError(
+ f"fn_if_not_exist must be one of ['raise', 'pass', 'none', 'warn'], but got '{fn_if_not_exist}'."
+ )
+
+ if not callable(fun):
+ raise TypeError(f"'{fn_name}' must be callable, but got {type(fun).__name__}.")
+
+ if hasattr(fun, 'call_order'):
+ nodes_with_order.append(node)
+ else:
+ fun(*args, **kwargs)
+
+ # Execute nodes with call_order in sorted order
+ for node in sorted(nodes_with_order, key=lambda x: getattr(x, fn_name).call_order):
+ getattr(node, fn_name)(*args, **kwargs)
+ return target
+
+
+ def vmap_call_all_fns(
+ target: T,
+ fn_name: str,
+ args: Sequence[Any] | Any = (),
+ kwargs: Mapping[str, Any] | None = None,
+ axis_size: int = None,
+ node_to_exclude: Filter = None,
+ state_tag: str | None = None,
+ fn_if_not_exist: str = 'raise',
+ ) -> T:
+ """
+ Apply vectorized mapping to call a function on all module nodes with batched state handling.
+
+ This function creates multiple batched instances by applying vmap to the specified method
+ call across all module nodes. Each batch element maintains its own random key and state
+ values. This is particularly useful for creating ensembles or batched models.
+
+ Parameters
+ ----------
+ target : Module
+ The target module on which to call functions.
+ fn_name : str
+ The name of the method to call on each module node.
+ args : Sequence[Any] or Any, optional
+ Positional arguments to pass to the called method. A single non-tuple
+ argument will be automatically wrapped in a tuple. Default is ().
+ kwargs : Mapping[str, Any], optional
+ Keyword arguments to pass to the called method. Default is None.
+ axis_size : int
+ The size of the batch dimension for vmap. Must be a positive integer.
+ node_to_exclude : Filter, optional
+ A filter to exclude certain nodes from the function call.
+ state_tag : str, optional
+ An optional tag to categorize newly created states during the vmap operation.
+ fn_if_not_exist : str, optional
+ Behavior when the specified method doesn't exist on a node:
+
+ - 'raise': Raise an AttributeError (default)
+ - 'pass' or 'none': Skip the node silently
+ - 'warn': Issue a warning and skip the node
+
+ Raises
+ ------
+ ValueError
+ If `axis_size` is None or not a positive integer.
+ TypeError
+ If `kwargs` is not a mapping.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ >>> import brainstate
+ >>>
+ >>> net = brainstate.nn.Linear(10, 20)
+ >>> # Create 5 batched instances with different initializations
+ >>> brainstate.nn.vmap_call_all_fns(net, 'init_state', axis_size=5)
+ """
+
+ if axis_size is None or axis_size <= 0:
+ raise ValueError(f"axis_size must be a positive integer, got {axis_size}")
+
+ if not isinstance(args, tuple):
+ args = (args,)
+ kwargs = kwargs or {}
+ if not isinstance(kwargs, Mapping):
+ raise TypeError(f'kwargs must be a mapping, but got {type(kwargs).__name__}.')
+
+ @vmap(axis_size=axis_size)
+ def vmapped_fn():
+ with catch_new_states(state_tag) as inner_catcher:
+ call_all_fns(
+ target,
+ fn_name=fn_name,
+ args=args,
+ kwargs=kwargs,
+ node_to_exclude=node_to_exclude,
+ fn_if_not_exist=fn_if_not_exist
+ )
+ return inner_catcher.get_state_values()
+
+ with catch_new_states(state_tag) as outer_catcher:
+ values = vmapped_fn()
+ states = outer_catcher.get_states()
+ for state, value in zip(states, values):
+ state.value = value
+ return target
+
+
+ @set_module_as('brainstate.nn')
+ def init_all_states(
+ target: T,
+ *init_args,
+ node_to_exclude: Filter = None,
+ **init_kwargs,
+ ) -> T:
+ """
+ Initialize states for all module nodes within the target.
+
+ This is a convenience wrapper around `call_all_fns` that specifically calls
+ the `init_state` method on all module nodes. The execution order respects any
+ `@call_order()` decorators on the `init_state` methods.
+
+ Parameters
+ ----------
+ target : Module
+ The target module whose states are to be initialized.
+ *init_args
+ Variable positional arguments to pass to each `init_state` method.
+ node_to_exclude : Filter, optional
+ A filter to exclude certain nodes from initialization.
+ Can be a type, predicate function, or any filter supported by the graph API.
+ **init_kwargs
+ Variable keyword arguments to pass to each `init_state` method.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ >>> import brainstate
+ >>>
+ >>> net = brainstate.nn.Sequential(
+ ... brainstate.nn.Linear(10, 20),
+ ... brainstate.nn.Dropout(0.5)
+ ... )
+ >>> # Initialize all states
+ >>> brainstate.nn.init_all_states(net)
+ >>>
+ >>> # Initialize with custom arguments
+ >>> brainstate.nn.init_all_states(net, batch_size=32)
+
+ See Also
+ --------
+ call_all_fns : The underlying function that executes the calls.
+ vmap_init_all_states : Vectorized version for batched initialization.
+ """
+ call_all_fns(target, 'init_state', init_args, init_kwargs, node_to_exclude)
+ return target
+
+
+ @set_module_as('brainstate.nn')
+ def vmap_init_all_states(
+ target: T,
+ *init_args,
+ axis_size: int = None,
+ node_to_exclude: Filter = None,
+ state_to_exclude: Filter = None,
+ state_tag: str | None = None,
+ **init_kwargs
+ ) -> T:
+ """
+ Initialize states with vectorized mapping for creating batched module instances.
+
+ This function applies vmap to the initialization process, creating multiple batched
+ instances of module states. Each batch element will have independent state values
+ and random keys. This is useful for ensemble models or parameter sweeps.
+
+ Parameters
+ ----------
+ target : Module
+ The target module whose states are to be initialized.
+ *init_args
+ Variable positional arguments to pass to each `init_state` method.
+ axis_size : int
+ The size of the batch dimension. Must be a positive integer.
+ node_to_exclude : Filter, optional
+ A filter to exclude certain nodes from initialization.
+ state_to_exclude : Filter, optional
+ A filter to exclude certain states from being vmapped.
+ Excluded states will remain shared across all batched instances.
+ state_tag : str, optional
+ An optional tag to categorize newly created states.
+ **init_kwargs
+ Variable keyword arguments to pass to each `init_state` method.
+
+ Raises
+ ------
+ ValueError
+ If `axis_size` is None or not a positive integer.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ >>> import brainstate
+ >>>
+ >>> net = brainstate.nn.Linear(10, 20)
+ >>> # Create 8 batched instances with different random initializations
+ >>> brainstate.nn.vmap_init_all_states(net, axis_size=8)
+ >>>
+ >>> # The weight parameter now has shape (8, 20, 10) instead of (20, 10)
+ >>> print(net.weight.shape)
+
+ See Also
+ --------
+ init_all_states : Non-vectorized version.
+ vmap_new_states : The underlying vmap transformation for states.
+ """
+
+ # vmap_call_all_functions(
+ # target,
+ # fun_name='init_state',
+ # args=init_args,
+ # kwargs=init_kwargs,
+ # axis_size=axis_size,
+ # node_to_exclude=node_to_exclude,
+ # state_tag=state_tag,
+ # )
+
+ def init_fn():
+ init_all_states(
+ target,
+ *init_args,
+ **init_kwargs,
+ node_to_exclude=node_to_exclude,
+ )
+ return
+
+ vmap_new_states(init_fn, state_tag=state_tag, axis_size=axis_size, state_to_exclude=state_to_exclude)()
+ return target
+
+
+ @set_module_as('brainstate.nn')
+ def reset_all_states(
+ target: T,
+ *reset_args,
+ node_to_exclude: Filter = None,
+ **reset_kwargs,
+ ) -> T:
+ """
+ Reset states for all module nodes within the target.
+
+ This is a convenience wrapper around `call_all_fns` that specifically calls
+ the `reset_state` method on all module nodes. The execution order respects any
+ `@call_order()` decorators on the `reset_state` methods. This is typically used
+ to reset recurrent neural network states between sequences.
+
+ Parameters
+ ----------
+ target : Module
+ The target module whose states are to be reset.
+ reset_args
+ Positional arguments to pass to each `reset_state` method.
+ A single non-tuple argument will be automatically wrapped in a tuple.
+ Default is ().
+ reset_kwargs
+ Keyword arguments to pass to each `reset_state` method.
+ Default is None.
+ node_to_exclude : Filter, optional
+ A filter to exclude certain nodes from reset.
+ Can be a type, predicate function, or any filter supported by the graph API.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ >>> import brainstate
+ >>>
+ >>> rnn = brainstate.nn.RNNCell(10, 20)
+ >>> brainstate.nn.init_all_states(rnn, batch_size=32)
+ >>>
+ >>> # Process a sequence
+ >>> for x in sequence:
+ ... output = rnn(x)
+ >>>
+ >>> # Reset states before processing next sequence
+ >>> brainstate.nn.reset_all_states(rnn)
+
+ See Also
+ --------
+ call_all_fns : The underlying function that executes the calls.
+ vmap_reset_all_states : Vectorized version for batched reset.
+ """
+ call_all_fns(
+ target,
+ fn_name='reset_state',
+ args=reset_args,
+ kwargs=reset_kwargs,
+ node_to_exclude=node_to_exclude
+ )
+ return target
+
+
+ def vmap_reset_all_states(
+ target: T,
+ *reset_args,
+ axis_size: int = None,
+ node_to_exclude: Filter = None,
+ state_tag: str | None = None,
+ **reset_kwargs,
+ ) -> T:
+ """
+ Reset states with vectorized mapping across batched module instances.
+
+ This function applies vmap to the reset process, resetting states across all
+ batched instances of the module. Each batch element will have its state reset
+ independently with its own random key. This is useful when working with batched
+ recurrent models or ensembles.
+
+ Parameters
+ ----------
+ target : Module
+ The target module whose states are to be reset.
+ reset_args
+ Positional arguments to pass to each `reset_state` method.
+ A single non-tuple argument will be automatically wrapped in a tuple.
+ Default is ().
+ reset_kwargs
+ Keyword arguments to pass to each `reset_state` method.
+ Default is None.
+ axis_size : int
+ The size of the batch dimension. Must be a positive integer.
+ node_to_exclude : Filter, optional
+ A filter to exclude certain nodes from reset.
+ state_tag : str, optional
+ An optional tag to categorize newly created states during the reset.
+
+ Raises
+ ------
+ ValueError
+ If `axis_size` is None or not a positive integer.
+ TypeError
+ If `reset_kwargs` is not a mapping.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ >>> import brainstate
+ >>>
+ >>> rnn = brainstate.nn.RNNCell(10, 20)
+ >>> # Initialize with 16 batched instances
+ >>> brainstate.nn.vmap_init_all_states(rnn, batch_size=32, axis_size=16)
+ >>>
+ >>> # Process sequences...
+ >>>
+ >>> # Reset all 16 batched instances
+ >>> brainstate.nn.vmap_reset_all_states(rnn, axis_size=16)
+
+ See Also
+ --------
+ reset_all_states : Non-vectorized version.
+ vmap_call_all_fns : The underlying vmap function call mechanism.
+ """
+ vmap_call_all_fns(
+ target,
+ fn_name='reset_state',
+ args=reset_args,
+ kwargs=reset_kwargs,
+ axis_size=axis_size,
+ node_to_exclude=node_to_exclude,
+ state_tag=state_tag,
+ )
+ return target
+
+
+ @set_module_as('brainstate.nn')
+ def assign_state_values(
+ target: Module,
+ *state_by_abs_path: Mapping[str, Any]
+ ) -> tuple[list[str], list[str]]:
+ """
+ Assign state values to a module from one or more state dictionaries.
+
+ This function updates the state values of a module based on provided state dictionaries.
+ State dictionaries should use absolute paths as keys (e.g., 'layer1.weight', 'layer2.bias').
+ The function handles missing and unexpected keys, returning them for inspection.
+
+ Parameters
+ ----------
+ target : Module
+ The target module whose states will be updated.
+ *state_by_abs_path : Mapping[str, Any]
+ One or more state dictionaries with absolute path keys mapping to state values.
+ If multiple dictionaries are provided, they will be merged (later dictionaries
+ override earlier ones for duplicate keys).
+
+ Returns
+ -------
+ tuple[list[str], list[str]]
+ A tuple of (unexpected_keys, missing_keys):
+
+ - unexpected_keys: Keys present in the state dictionaries but not in the module
+ - missing_keys: Keys present in the module but not in the state dictionaries
+
+ Examples
+ --------
+ .. code-block:: python
+
+ >>> import brainstate
+ >>>
+ >>> net = brainstate.nn.Linear(10, 20)
+ >>> brainstate.nn.init_all_states(net)
+ >>>
+ >>> # Save state values
+ >>> state_dict = {path: state.value for path, state in net.states().items()}
+ >>>
+ >>> # Later, restore state values
+ >>> unexpected, missing = brainstate.nn.assign_state_values(net, state_dict)
+ >>> print(f"Unexpected keys: {unexpected}")
+ >>> print(f"Missing keys: {missing}")
+
+ Notes
+ -----
+ - All values are automatically converted to JAX arrays using `jax.numpy.asarray`.
+ - Only states with matching keys are updated; unexpected and missing keys are
+ returned but do not cause errors.
+ - If multiple dictionaries contain the same key, the last one takes precedence.
+ """
+ # Merge all state dictionaries
+ all_states = {}
+ for state_dict in state_by_abs_path:
+ all_states.update(state_dict)
+
+ # Get current module states
+ variables = target.states()
+ keys1 = set(all_states.keys())
+ keys2 = set(variables.keys())
+
+ # Update matching states
+ for key in keys2.intersection(keys1):
+ variables[key].value = jax.numpy.asarray(all_states[key])
+
+ # Return mismatched keys
+ unexpected_keys = sorted(keys1 - keys2)
+ missing_keys = sorted(keys2 - keys1)
+ return unexpected_keys, missing_keys
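Taken together, the new _collective_ops.py renames call_all_functions/vmap_call_all_functions to call_all_fns/vmap_call_all_fns, makes init_all_states and reset_all_states variadic, drops load_all_states/save_all_states from this module, and adds a 'warn' option for missing methods. A minimal usage sketch assembled from the docstring examples above (brainstate.nn.Linear, RNNCell, and the init_state/reset_state signatures are assumed from those examples, not verified against the released package):

import brainstate

# Initialize parameters/states of every submodule, respecting @call_order decorators.
net = brainstate.nn.Linear(10, 20)
brainstate.nn.init_all_states(net)

# Create 8 batched (vmapped) copies of the newly created states, ensemble-style.
ensemble = brainstate.nn.Linear(10, 20)
brainstate.nn.vmap_init_all_states(ensemble, axis_size=8)

# Recurrent cell: reset hidden states between sequences.
rnn = brainstate.nn.RNNCell(10, 20)
brainstate.nn.init_all_states(rnn, batch_size=32)
brainstate.nn.reset_all_states(rnn)

# Generic form: call any method by name on all submodules, warning on nodes that lack it.
brainstate.nn.call_all_fns(rnn, 'reset_state', fn_if_not_exist='warn')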