brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115) hide show
  1. brainstate/__init__.py +167 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2297 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +2157 -1652
  8. brainstate/_state_test.py +1129 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1620 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1447 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +146 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +635 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +134 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +480 -477
  32. brainstate/nn/_dynamics.py +870 -1267
  33. brainstate/nn/_dynamics_test.py +53 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +391 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
  64. brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
  65. brainstate/random/_impl.py +672 -0
  66. brainstate/random/{_rand_seed.py → _seed.py} +675 -675
  67. brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
  68. brainstate/random/{_rand_state.py → _state.py} +1320 -1617
  69. brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
  70. brainstate/transform/__init__.py +56 -59
  71. brainstate/transform/_ad_checkpoint.py +176 -176
  72. brainstate/transform/_ad_checkpoint_test.py +49 -49
  73. brainstate/transform/_autograd.py +1025 -1025
  74. brainstate/transform/_autograd_test.py +1289 -1289
  75. brainstate/transform/_conditions.py +316 -316
  76. brainstate/transform/_conditions_test.py +220 -220
  77. brainstate/transform/_error_if.py +94 -94
  78. brainstate/transform/_error_if_test.py +52 -52
  79. brainstate/transform/_find_state.py +200 -0
  80. brainstate/transform/_find_state_test.py +84 -0
  81. brainstate/transform/_jit.py +399 -399
  82. brainstate/transform/_jit_test.py +143 -143
  83. brainstate/transform/_loop_collect_return.py +675 -675
  84. brainstate/transform/_loop_collect_return_test.py +58 -58
  85. brainstate/transform/_loop_no_collection.py +283 -283
  86. brainstate/transform/_loop_no_collection_test.py +50 -50
  87. brainstate/transform/_make_jaxpr.py +2176 -2016
  88. brainstate/transform/_make_jaxpr_test.py +1634 -1510
  89. brainstate/transform/_mapping.py +607 -529
  90. brainstate/transform/_mapping_test.py +104 -194
  91. brainstate/transform/_progress_bar.py +255 -255
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
  108. brainstate-0.2.2.dist-info/RECORD +111 -0
  109. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/transform/_eval_shape.py +0 -145
  111. brainstate/transform/_eval_shape_test.py +0 -38
  112. brainstate/transform/_random.py +0 -171
  113. brainstate-0.2.1.dist-info/RECORD +0 -111
  114. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  115. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
@@ -1,633 +1,635 @@
1
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- import warnings
16
- from collections.abc import Sequence, Mapping
17
- from typing import Callable, TypeVar, Any
18
-
19
- import jax
20
-
21
- from brainstate._state import catch_new_states
22
- from brainstate._utils import set_module_as
23
- from brainstate.graph import nodes
24
- from brainstate.transform import vmap, vmap_new_states
25
- from brainstate.typing import Filter
26
- from ._module import Module
27
-
28
- # the maximum order
29
- MAX_ORDER = 10
30
-
31
- T = TypeVar('T', bound=Module)
32
-
33
- __all__ = [
34
- 'call_order',
35
- 'call_all_fns',
36
- 'vmap_call_all_fns',
37
- 'init_all_states',
38
- 'vmap_init_all_states',
39
- 'reset_all_states',
40
- 'vmap_reset_all_states',
41
- 'assign_state_values',
42
- ]
43
-
44
-
45
- @set_module_as('brainstate.nn')
46
- def call_order(
47
- level: int = 0,
48
- check_order_boundary: bool = True
49
- ) -> Callable[[Callable], Callable]:
50
- """
51
- Decorator for specifying the execution order of functions in collective operations.
52
-
53
- This decorator attaches a `call_order` attribute to a function, which is used by
54
- collective operations like `call_all_functions`, `init_all_states`, and `reset_all_states`
55
- to determine the execution order. Functions with lower order levels are executed first.
56
-
57
- Parameters
58
- ----------
59
- level : int, optional
60
- The execution order level. Lower values indicate earlier execution.
61
- Must be in the range [0, MAX_ORDER) when `check_order_boundary` is True.
62
- Default is 0.
63
- check_order_boundary : bool, optional
64
- Whether to validate that the order level is within the valid range [0, MAX_ORDER).
65
- Default is True.
66
-
67
- Returns
68
- -------
69
- Callable[[Callable], Callable]
70
- A decorator function that adds the `call_order` attribute to the decorated function.
71
-
72
- Raises
73
- ------
74
- ValueError
75
- If `check_order_boundary` is True and `level` is not in [0, MAX_ORDER).
76
-
77
- Examples
78
- --------
79
- .. code-block:: python
80
-
81
- >>> import brainstate
82
- >>>
83
- >>> class MyModule(brainstate.nn.Module):
84
- ... @brainstate.nn.call_order(0)
85
- ... def reset_state(self):
86
- ... print("Reset first")
87
- ...
88
- ... @brainstate.nn.call_order(1)
89
- ... def another_reset(self):
90
- ... print("Reset second")
91
- """
92
- if check_order_boundary and (level < 0 or level >= MAX_ORDER):
93
- raise ValueError(f'"level" must be an integer in [0, {MAX_ORDER}), but got {level}.')
94
-
95
- def wrap(fun: Callable) -> Callable:
96
- fun.call_order = level
97
- return fun
98
-
99
- return wrap
100
-
101
-
102
- @set_module_as('brainstate.nn')
103
- def call_all_fns(
104
- target: T,
105
- fn_name: str,
106
- args: Sequence[Any] | Any = (),
107
- kwargs: Mapping[str, Any] | None = None,
108
- node_to_exclude: Filter = None,
109
- fn_if_not_exist: str = 'raise',
110
- ) -> T:
111
- """
112
- Call a specified function on all module nodes within a target, respecting call order.
113
-
114
- This function traverses all module nodes in the target and invokes the specified method
115
- on each node. Functions decorated with `@call_order()` are executed in ascending order
116
- of their level values, while functions without the decorator are executed first.
117
-
118
- Parameters
119
- ----------
120
- target : Module
121
- The target module on which to call functions.
122
- fn_name : str
123
- The name of the method to call on each module node.
124
- node_to_exclude : Filter, optional
125
- A filter to exclude certain nodes from the function call.
126
- Can be a type, predicate function, or any filter supported by the graph API.
127
- fn_if_not_exist : str, optional
128
- Behavior when the specified method doesn't exist on a node:
129
-
130
- - 'raise': Raise an AttributeError (default)
131
- - 'pass' or 'none': Skip the node silently
132
- - 'warn': Issue a warning and skip the node
133
- args
134
- Positional arguments to pass to the called method. A single non-tuple
135
- argument will be automatically wrapped in a tuple. Default is ().
136
- kwargs
137
- Keyword arguments to pass to the called method. Default is None.
138
-
139
- Raises
140
- ------
141
- TypeError
142
- If `fun_name` is not a string or `kwargs` is not a mapping.
143
- ValueError
144
- If `fn_if_not_exist` is not one of the allowed values.
145
- AttributeError
146
- If the specified method doesn't exist on a node and `fn_if_not_exist` is 'raise'.
147
-
148
- Examples
149
- --------
150
- .. code-block:: python
151
-
152
- >>> import brainstate
153
- >>>
154
- >>> net = brainstate.nn.Sequential(brainstate.nn.Linear(10, 20), brainstate.nn.ReLU())
155
- >>> brainstate.nn.call_all_fns(net, 'init_state')
156
- """
157
- if not isinstance(fn_name, str):
158
- raise TypeError(f'fn_name must be a string, but got {type(fn_name).__name__}.')
159
-
160
- args = (args,) if not isinstance(args, tuple) else args
161
- kwargs = kwargs or {}
162
- if not isinstance(kwargs, Mapping):
163
- raise TypeError(f'kwargs must be a mapping, but got {type(kwargs).__name__}.')
164
-
165
- all_nodes = nodes(target).filter(Module)
166
- if node_to_exclude is not None:
167
- all_nodes -= all_nodes.filter(node_to_exclude)
168
-
169
- # Separate nodes with and without call_order
170
- nodes_with_order = []
171
- for path, node in all_nodes.items():
172
- try:
173
- fun = getattr(node, fn_name)
174
- except AttributeError as e:
175
- if fn_if_not_exist == 'raise':
176
- raise AttributeError(
177
- f"Module {type(node).__name__} with the path {path} does not have method '{fn_name}'"
178
- ) from e
179
- elif fn_if_not_exist in ('pass', 'none'):
180
- continue
181
- elif fn_if_not_exist == 'warn':
182
- warnings.warn(
183
- f"Module {type(node).__name__} with the path {path} does not have method '{fn_name}'. "
184
- f"Skipping.",
185
- UserWarning
186
- )
187
- continue
188
- else:
189
- raise ValueError(
190
- f"fn_if_not_exist must be one of ['raise', 'pass', 'none'], but got '{fn_if_not_exist}'."
191
- )
192
-
193
- if not callable(fun):
194
- raise TypeError(f"'{fn_name}' must be callable, but got {type(fun).__name__}.")
195
-
196
- if hasattr(fun, 'call_order'):
197
- nodes_with_order.append(node)
198
- else:
199
- fun(*args, **kwargs)
200
-
201
- # Execute nodes with call_order in sorted order
202
- for node in sorted(nodes_with_order, key=lambda x: getattr(x, fn_name).call_order):
203
- getattr(node, fn_name)(*args, **kwargs)
204
- return target
205
-
206
-
207
- def vmap_call_all_fns(
208
- target: T,
209
- fn_name: str,
210
- args: Sequence[Any] | Any = (),
211
- kwargs: Mapping[str, Any] | None = None,
212
- axis_size: int = None,
213
- node_to_exclude: Filter = None,
214
- state_tag: str | None = None,
215
- fn_if_not_exist: str = 'raise',
216
- ) -> T:
217
- """
218
- Apply vectorized mapping to call a function on all module nodes with batched state handling.
219
-
220
- This function creates multiple batched instances by applying vmap to the specified method
221
- call across all module nodes. Each batch element maintains its own random key and state
222
- values. This is particularly useful for creating ensembles or batched models.
223
-
224
- Parameters
225
- ----------
226
- target : Module
227
- The target module on which to call functions.
228
- fn_name : str
229
- The name of the method to call on each module node.
230
- args : Sequence[Any] or Any, optional
231
- Positional arguments to pass to the called method. A single non-tuple
232
- argument will be automatically wrapped in a tuple. Default is ().
233
- kwargs : Mapping[str, Any], optional
234
- Keyword arguments to pass to the called method. Default is None.
235
- axis_size : int
236
- The size of the batch dimension for vmap. Must be a positive integer.
237
- node_to_exclude : Filter, optional
238
- A filter to exclude certain nodes from the function call.
239
- state_tag : str, optional
240
- An optional tag to categorize newly created states during the vmap operation.
241
- fn_if_not_exist : str, optional
242
- Behavior when the specified method doesn't exist on a node:
243
-
244
- - 'raise': Raise an AttributeError (default)
245
- - 'pass' or 'none': Skip the node silently
246
- - 'warn': Issue a warning and skip the node
247
-
248
- Raises
249
- ------
250
- ValueError
251
- If `axis_size` is None or not a positive integer.
252
- TypeError
253
- If `kwargs` is not a mapping.
254
-
255
- Examples
256
- --------
257
- .. code-block:: python
258
-
259
- >>> import brainstate
260
- >>>
261
- >>> net = brainstate.nn.Linear(10, 20)
262
- >>> # Create 5 batched instances with different initializations
263
- >>> brainstate.nn.vmap_call_all_fns(net, 'init_state', axis_size=5)
264
- """
265
-
266
- if axis_size is None or axis_size <= 0:
267
- raise ValueError(f"axis_size must be a positive integer, got {axis_size}")
268
-
269
- if not isinstance(args, tuple):
270
- args = (args,)
271
- kwargs = kwargs or {}
272
- if not isinstance(kwargs, Mapping):
273
- raise TypeError(f'kwargs must be a mapping, but got {type(kwargs).__name__}.')
274
-
275
- @vmap(axis_size=axis_size)
276
- def vmapped_fn():
277
- with catch_new_states(state_tag) as inner_catcher:
278
- call_all_fns(
279
- target,
280
- fn_name=fn_name,
281
- args=args,
282
- kwargs=kwargs,
283
- node_to_exclude=node_to_exclude,
284
- fn_if_not_exist=fn_if_not_exist
285
- )
286
- return inner_catcher.get_state_values()
287
-
288
- with catch_new_states(state_tag) as outer_catcher:
289
- values = vmapped_fn()
290
- states = outer_catcher.get_states()
291
- for state, value in zip(states, values):
292
- state.value = value
293
- return target
294
-
295
-
296
- @set_module_as('brainstate.nn')
297
- def init_all_states(
298
- target: T,
299
- *init_args,
300
- node_to_exclude: Filter = None,
301
- **init_kwargs,
302
- ) -> T:
303
- """
304
- Initialize states for all module nodes within the target.
305
-
306
- This is a convenience wrapper around `call_all_functions` that specifically calls
307
- the `init_state` method on all module nodes. The execution order respects any
308
- `@call_order()` decorators on the `init_state` methods.
309
-
310
- Parameters
311
- ----------
312
- target : Module
313
- The target module whose states are to be initialized.
314
- *init_args
315
- Variable positional arguments to pass to each `init_state` method.
316
- node_to_exclude : Filter, optional
317
- A filter to exclude certain nodes from initialization.
318
- Can be a type, predicate function, or any filter supported by the graph API.
319
- **init_kwargs
320
- Variable keyword arguments to pass to each `init_state` method.
321
-
322
- Examples
323
- --------
324
- .. code-block:: python
325
-
326
- >>> import brainstate
327
- >>>
328
- >>> net = brainstate.nn.Sequential(
329
- ... brainstate.nn.Linear(10, 20),
330
- ... brainstate.nn.Dropout(0.5)
331
- ... )
332
- >>> # Initialize all states
333
- >>> brainstate.nn.init_all_states(net)
334
- >>>
335
- >>> # Initialize with custom arguments
336
- >>> brainstate.nn.init_all_states(net, batch_size=32)
337
-
338
- See Also
339
- --------
340
- call_all_functions : The underlying function that executes the calls.
341
- vmap_init_all_states : Vectorized version for batched initialization.
342
- """
343
- call_all_fns(target, 'init_state', init_args, init_kwargs, node_to_exclude)
344
- return target
345
-
346
-
347
- @set_module_as('brainstate.nn')
348
- def vmap_init_all_states(
349
- target: T,
350
- *init_args,
351
- axis_size: int = None,
352
- node_to_exclude: Filter = None,
353
- state_to_exclude: Filter = None,
354
- state_tag: str | None = None,
355
- **init_kwargs
356
- ) -> T:
357
- """
358
- Initialize states with vectorized mapping for creating batched module instances.
359
-
360
- This function applies vmap to the initialization process, creating multiple batched
361
- instances of module states. Each batch element will have independent state values
362
- and random keys. This is useful for ensemble models or parameter sweeps.
363
-
364
- Parameters
365
- ----------
366
- target : Module
367
- The target module whose states are to be initialized.
368
- *init_args
369
- Variable positional arguments to pass to each `init_state` method.
370
- axis_size : int
371
- The size of the batch dimension. Must be a positive integer.
372
- node_to_exclude : Filter, optional
373
- A filter to exclude certain nodes from initialization.
374
- state_to_exclude : Filter, optional
375
- A filter to exclude certain states from being vmapped.
376
- Excluded states will remain shared across all batched instances.
377
- state_tag : str, optional
378
- An optional tag to categorize newly created states.
379
- **init_kwargs
380
- Variable keyword arguments to pass to each `init_state` method.
381
-
382
- Raises
383
- ------
384
- ValueError
385
- If `axis_size` is None or not a positive integer.
386
-
387
- Examples
388
- --------
389
- .. code-block:: python
390
-
391
- >>> import brainstate
392
- >>>
393
- >>> net = brainstate.nn.Linear(10, 20)
394
- >>> # Create 8 batched instances with different random initializations
395
- >>> brainstate.nn.vmap_init_all_states(net, axis_size=8)
396
- >>>
397
- >>> # The weight parameter now has shape (8, 20, 10) instead of (20, 10)
398
- >>> print(net.weight.shape)
399
-
400
- See Also
401
- --------
402
- init_all_states : Non-vectorized version.
403
- vmap_new_states : The underlying vmap transformation for states.
404
- """
405
-
406
- # vmap_call_all_functions(
407
- # target,
408
- # fun_name='init_state',
409
- # args=init_args,
410
- # kwargs=init_kwargs,
411
- # axis_size=axis_size,
412
- # node_to_exclude=node_to_exclude,
413
- # state_tag=state_tag,
414
- # )
415
-
416
- def init_fn():
417
- init_all_states(
418
- target,
419
- *init_args,
420
- **init_kwargs,
421
- node_to_exclude=node_to_exclude,
422
- )
423
- return
424
-
425
- vmap_new_states(init_fn, state_tag=state_tag, axis_size=axis_size, state_to_exclude=state_to_exclude)()
426
- return target
427
-
428
-
429
- @set_module_as('brainstate.nn')
430
- def reset_all_states(
431
- target: T,
432
- *reset_args,
433
- node_to_exclude: Filter = None,
434
- **reset_kwargs,
435
- ) -> T:
436
- """
437
- Reset states for all module nodes within the target.
438
-
439
- This is a convenience wrapper around `call_all_functions` that specifically calls
440
- the `reset_state` method on all module nodes. The execution order respects any
441
- `@call_order()` decorators on the `reset_state` methods. This is typically used
442
- to reset recurrent neural network states between sequences.
443
-
444
- Parameters
445
- ----------
446
- target : Module
447
- The target module whose states are to be reset.
448
- reset_args
449
- Positional arguments to pass to each `reset_state` method.
450
- A single non-tuple argument will be automatically wrapped in a tuple.
451
- Default is ().
452
- reset_kwargs
453
- Keyword arguments to pass to each `reset_state` method.
454
- Default is None.
455
- node_to_exclude : Filter, optional
456
- A filter to exclude certain nodes from reset.
457
- Can be a type, predicate function, or any filter supported by the graph API.
458
-
459
- Examples
460
- --------
461
- .. code-block:: python
462
-
463
- >>> import brainstate
464
- >>>
465
- >>> rnn = brainstate.nn.RNNCell(10, 20)
466
- >>> brainstate.nn.init_all_states(rnn, batch_size=32)
467
- >>>
468
- >>> # Process a sequence
469
- >>> for x in sequence:
470
- ... output = rnn(x)
471
- >>>
472
- >>> # Reset states before processing next sequence
473
- >>> brainstate.nn.reset_all_states(rnn)
474
-
475
- See Also
476
- --------
477
- call_all_functions : The underlying function that executes the calls.
478
- vmap_reset_all_states : Vectorized version for batched reset.
479
- """
480
- call_all_fns(
481
- target,
482
- fn_name='reset_state',
483
- args=reset_args,
484
- kwargs=reset_kwargs,
485
- node_to_exclude=node_to_exclude
486
- )
487
- return target
488
-
489
-
490
- def vmap_reset_all_states(
491
- target: T,
492
- *reset_args,
493
- axis_size: int = None,
494
- node_to_exclude: Filter = None,
495
- state_tag: str | None = None,
496
- **reset_kwargs,
497
- ) -> T:
498
- """
499
- Reset states with vectorized mapping across batched module instances.
500
-
501
- This function applies vmap to the reset process, resetting states across all
502
- batched instances of the module. Each batch element will have its state reset
503
- independently with its own random key. This is useful when working with batched
504
- recurrent models or ensembles.
505
-
506
- Parameters
507
- ----------
508
- target : Module
509
- The target module whose states are to be reset.
510
- reset_args
511
- Positional arguments to pass to each `reset_state` method.
512
- A single non-tuple argument will be automatically wrapped in a tuple.
513
- Default is ().
514
- reset_kwargs
515
- Keyword arguments to pass to each `reset_state` method.
516
- Default is None.
517
- axis_size : int
518
- The size of the batch dimension. Must be a positive integer.
519
- node_to_exclude : Filter, optional
520
- A filter to exclude certain nodes from reset.
521
- state_tag : str, optional
522
- An optional tag to categorize newly created states during the reset.
523
-
524
- Raises
525
- ------
526
- ValueError
527
- If `axis_size` is None or not a positive integer.
528
- TypeError
529
- If `reset_kwargs` is not a mapping.
530
-
531
- Examples
532
- --------
533
- .. code-block:: python
534
-
535
- >>> import brainstate
536
- >>>
537
- >>> rnn = brainstate.nn.RNNCell(10, 20)
538
- >>> # Initialize with 16 batched instances
539
- >>> brainstate.nn.vmap_init_all_states(rnn, batch_size=32, axis_size=16)
540
- >>>
541
- >>> # Process sequences...
542
- >>>
543
- >>> # Reset all 16 batched instances
544
- >>> brainstate.nn.vmap_reset_all_states(rnn, axis_size=16)
545
-
546
- See Also
547
- --------
548
- reset_all_states : Non-vectorized version.
549
- vmap_call_all_functions : The underlying vmap function call mechanism.
550
- """
551
- vmap_call_all_fns(
552
- target,
553
- fn_name='reset_state',
554
- args=reset_args,
555
- kwargs=reset_kwargs,
556
- axis_size=axis_size,
557
- node_to_exclude=node_to_exclude,
558
- state_tag=state_tag,
559
- )
560
- return target
561
-
562
-
563
- @set_module_as('brainstate.nn')
564
- def assign_state_values(
565
- target: Module,
566
- *state_by_abs_path: Mapping[str, Any]
567
- ) -> tuple[list[str], list[str]]:
568
- """
569
- Assign state values to a module from one or more state dictionaries.
570
-
571
- This function updates the state values of a module based on provided state dictionaries.
572
- State dictionaries should use absolute paths as keys (e.g., 'layer1.weight', 'layer2.bias').
573
- The function handles missing and unexpected keys, returning them for inspection.
574
-
575
- Parameters
576
- ----------
577
- target : Module
578
- The target module whose states will be updated.
579
- *state_by_abs_path : Mapping[str, Any]
580
- One or more state dictionaries with absolute path keys mapping to state values.
581
- If multiple dictionaries are provided, they will be merged (later dictionaries
582
- override earlier ones for duplicate keys).
583
-
584
- Returns
585
- -------
586
- tuple[list[str], list[str]]
587
- A tuple of (unexpected_keys, missing_keys):
588
-
589
- - unexpected_keys: Keys present in the state dictionaries but not in the module
590
- - missing_keys: Keys present in the module but not in the state dictionaries
591
-
592
- Examples
593
- --------
594
- .. code-block:: python
595
-
596
- >>> import brainstate
597
- >>>
598
- >>> net = brainstate.nn.Linear(10, 20)
599
- >>> brainstate.nn.init_all_states(net)
600
- >>>
601
- >>> # Save state values
602
- >>> state_dict = {path: state.value for path, state in net.states().items()}
603
- >>>
604
- >>> # Later, restore state values
605
- >>> unexpected, missing = brainstate.nn.assign_state_values(net, state_dict)
606
- >>> print(f"Unexpected keys: {unexpected}")
607
- >>> print(f"Missing keys: {missing}")
608
-
609
- Notes
610
- -----
611
- - All values are automatically converted to JAX arrays using `jax.numpy.asarray`.
612
- - Only states with matching keys are updated; unexpected and missing keys are
613
- returned but do not cause errors.
614
- - If multiple dictionaries contain the same key, the last one takes precedence.
615
- """
616
- # Merge all state dictionaries
617
- all_states = {}
618
- for state_dict in state_by_abs_path:
619
- all_states.update(state_dict)
620
-
621
- # Get current module states
622
- variables = target.states()
623
- keys1 = set(all_states.keys())
624
- keys2 = set(variables.keys())
625
-
626
- # Update matching states
627
- for key in keys2.intersection(keys1):
628
- variables[key].value = jax.numpy.asarray(all_states[key])
629
-
630
- # Return mismatched keys
631
- unexpected_keys = sorted(keys1 - keys2)
632
- missing_keys = sorted(keys2 - keys1)
633
- return unexpected_keys, missing_keys
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
+ import warnings
18
+ from collections.abc import Sequence, Mapping
19
+ from typing import Callable, TypeVar, Any
20
+
21
+ import jax
22
+
23
+ from brainstate._state import catch_new_states
24
+ from brainstate._utils import set_module_as
25
+ from brainstate.graph import nodes
26
+ from brainstate.transform import vmap, vmap_new_states
27
+ from brainstate.typing import Filter
28
+ from ._module import Module
29
+
30
+ # the maximum order
31
+ MAX_ORDER = 10
32
+
33
+ T = TypeVar('T', bound=Module)
34
+
35
+ __all__ = [
36
+ 'call_order',
37
+ 'call_all_fns',
38
+ 'vmap_call_all_fns',
39
+ 'init_all_states',
40
+ 'vmap_init_all_states',
41
+ 'reset_all_states',
42
+ 'vmap_reset_all_states',
43
+ 'assign_state_values',
44
+ ]
45
+
46
+
47
+ @set_module_as('brainstate.nn')
48
+ def call_order(
49
+ level: int = 0,
50
+ check_order_boundary: bool = True
51
+ ) -> Callable[[Callable], Callable]:
52
+ """
53
+ Decorator for specifying the execution order of functions in collective operations.
54
+
55
+ This decorator attaches a `call_order` attribute to a function, which is used by
56
+ collective operations like `call_all_functions`, `init_all_states`, and `reset_all_states`
57
+ to determine the execution order. Functions with lower order levels are executed first.
58
+
59
+ Parameters
60
+ ----------
61
+ level : int, optional
62
+ The execution order level. Lower values indicate earlier execution.
63
+ Must be in the range [0, MAX_ORDER) when `check_order_boundary` is True.
64
+ Default is 0.
65
+ check_order_boundary : bool, optional
66
+ Whether to validate that the order level is within the valid range [0, MAX_ORDER).
67
+ Default is True.
68
+
69
+ Returns
70
+ -------
71
+ Callable[[Callable], Callable]
72
+ A decorator function that adds the `call_order` attribute to the decorated function.
73
+
74
+ Raises
75
+ ------
76
+ ValueError
77
+ If `check_order_boundary` is True and `level` is not in [0, MAX_ORDER).
78
+
79
+ Examples
80
+ --------
81
+ .. code-block:: python
82
+
83
+ >>> import brainstate
84
+ >>>
85
+ >>> class MyModule(brainstate.nn.Module):
86
+ ... @brainstate.nn.call_order(0)
87
+ ... def reset_state(self):
88
+ ... print("Reset first")
89
+ ...
90
+ ... @brainstate.nn.call_order(1)
91
+ ... def another_reset(self):
92
+ ... print("Reset second")
93
+ """
94
+ if check_order_boundary and (level < 0 or level >= MAX_ORDER):
95
+ raise ValueError(f'"level" must be an integer in [0, {MAX_ORDER}), but got {level}.')
96
+
97
+ def wrap(fun: Callable) -> Callable:
98
+ fun.call_order = level
99
+ return fun
100
+
101
+ return wrap
102
+
103
+
104
@set_module_as('brainstate.nn')
def call_all_fns(
    target: T,
    fn_name: str,
    args: Sequence[Any] | Any = (),
    kwargs: Mapping[str, Any] | None = None,
    node_to_exclude: Filter = None,
    fn_if_not_exist: str = 'raise',
) -> T:
    """
    Call a specified function on all module nodes within a target, respecting call order.

    This function traverses all module nodes in the target and invokes the specified method
    on each node. Functions decorated with `@call_order()` are executed in ascending order
    of their level values, while functions without the decorator are executed first.

    Parameters
    ----------
    target : Module
        The target module on which to call functions.
    fn_name : str
        The name of the method to call on each module node.
    args : Sequence[Any] or Any, optional
        Positional arguments to pass to the called method. A single non-tuple
        argument will be automatically wrapped in a tuple. Default is ().
    kwargs : Mapping[str, Any], optional
        Keyword arguments to pass to the called method. Default is None.
    node_to_exclude : Filter, optional
        A filter to exclude certain nodes from the function call.
        Can be a type, predicate function, or any filter supported by the graph API.
    fn_if_not_exist : str, optional
        Behavior when the specified method doesn't exist on a node:

        - 'raise': Raise an AttributeError (default)
        - 'pass' or 'none': Skip the node silently
        - 'warn': Issue a warning and skip the node

    Returns
    -------
    Module
        The target module, returned to allow call chaining.

    Raises
    ------
    TypeError
        If `fn_name` is not a string, `kwargs` is not a mapping, or the resolved
        attribute is not callable.
    ValueError
        If `fn_if_not_exist` is not one of the allowed values.
    AttributeError
        If the specified method doesn't exist on a node and `fn_if_not_exist` is 'raise'.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> net = brainstate.nn.Sequential(brainstate.nn.Linear(10, 20), brainstate.nn.ReLU())
        >>> brainstate.nn.call_all_fns(net, 'init_state')
    """
    if not isinstance(fn_name, str):
        raise TypeError(f'fn_name must be a string, but got {type(fn_name).__name__}.')

    # Validate the fallback policy eagerly so an invalid value always fails,
    # not only when some node happens to be missing the method. The previous
    # message also omitted the supported 'warn' option.
    if fn_if_not_exist not in ('raise', 'pass', 'none', 'warn'):
        raise ValueError(
            f"fn_if_not_exist must be one of ['raise', 'pass', 'none', 'warn'], "
            f"but got '{fn_if_not_exist}'."
        )

    args = (args,) if not isinstance(args, tuple) else args
    kwargs = kwargs or {}
    if not isinstance(kwargs, Mapping):
        raise TypeError(f'kwargs must be a mapping, but got {type(kwargs).__name__}.')

    all_nodes = nodes(target).filter(Module)
    if node_to_exclude is not None:
        all_nodes -= all_nodes.filter(node_to_exclude)

    # Nodes whose target method carries a @call_order level are deferred and
    # executed afterwards in ascending order; all others run immediately.
    nodes_with_order = []
    for path, node in all_nodes.items():
        try:
            fun = getattr(node, fn_name)
        except AttributeError as e:
            if fn_if_not_exist == 'raise':
                raise AttributeError(
                    f"Module {type(node).__name__} with the path {path} does not have method '{fn_name}'"
                ) from e
            elif fn_if_not_exist == 'warn':
                warnings.warn(
                    f"Module {type(node).__name__} with the path {path} does not have method '{fn_name}'. "
                    f"Skipping.",
                    UserWarning
                )
            # 'pass' / 'none' (and 'warn' after warning) skip this node.
            continue

        if not callable(fun):
            raise TypeError(f"'{fn_name}' must be callable, but got {type(fun).__name__}.")

        if hasattr(fun, 'call_order'):
            nodes_with_order.append(node)
        else:
            fun(*args, **kwargs)

    # Execute deferred nodes in ascending call_order.
    for node in sorted(nodes_with_order, key=lambda x: getattr(x, fn_name).call_order):
        getattr(node, fn_name)(*args, **kwargs)
    return target
207
+
208
+
209
def vmap_call_all_fns(
    target: T,
    fn_name: str,
    args: Sequence[Any] | Any = (),
    kwargs: Mapping[str, Any] | None = None,
    axis_size: int | None = None,
    node_to_exclude: Filter = None,
    state_tag: str | None = None,
    fn_if_not_exist: str = 'raise',
) -> T:
    """
    Apply vectorized mapping to call a function on all module nodes with batched state handling.

    This function creates multiple batched instances by applying vmap to the specified method
    call across all module nodes. Each batch element maintains its own random key and state
    values. This is particularly useful for creating ensembles or batched models.

    Parameters
    ----------
    target : Module
        The target module on which to call functions.
    fn_name : str
        The name of the method to call on each module node.
    args : Sequence[Any] or Any, optional
        Positional arguments to pass to the called method. A single non-tuple
        argument will be automatically wrapped in a tuple. Default is ().
    kwargs : Mapping[str, Any], optional
        Keyword arguments to pass to the called method. Default is None.
    axis_size : int
        The size of the batch dimension for vmap. Must be a positive integer.
    node_to_exclude : Filter, optional
        A filter to exclude certain nodes from the function call.
    state_tag : str, optional
        An optional tag to categorize newly created states during the vmap operation.
    fn_if_not_exist : str, optional
        Behavior when the specified method doesn't exist on a node:

        - 'raise': Raise an AttributeError (default)
        - 'pass' or 'none': Skip the node silently
        - 'warn': Issue a warning and skip the node

    Returns
    -------
    Module
        The target module, with newly created states carrying a leading batch axis.

    Raises
    ------
    ValueError
        If `axis_size` is None or not a positive integer.
    TypeError
        If `kwargs` is not a mapping.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> net = brainstate.nn.Linear(10, 20)
        >>> # Create 5 batched instances with different initializations
        >>> brainstate.nn.vmap_call_all_fns(net, 'init_state', axis_size=5)
    """
    if axis_size is None or axis_size <= 0:
        raise ValueError(f"axis_size must be a positive integer, got {axis_size}")

    if not isinstance(args, tuple):
        args = (args,)
    kwargs = kwargs or {}
    if not isinstance(kwargs, Mapping):
        raise TypeError(f'kwargs must be a mapping, but got {type(kwargs).__name__}.')

    @vmap(axis_size=axis_size)
    def vmapped_fn():
        # The inner catcher records states created inside the vmapped trace;
        # their (batched) values are returned out of the transform.
        with catch_new_states(state_tag) as inner_catcher:
            call_all_fns(
                target,
                fn_name=fn_name,
                args=args,
                kwargs=kwargs,
                node_to_exclude=node_to_exclude,
                fn_if_not_exist=fn_if_not_exist
            )
        return inner_catcher.get_state_values()

    with catch_new_states(state_tag) as outer_catcher:
        values = vmapped_fn()
    states = outer_catcher.get_states()
    # NOTE(review): pairs outer-catcher states with inner-catcher values by
    # position — assumes both catchers observe new states in the same
    # deterministic order; confirm against catch_new_states' traversal order.
    for state, value in zip(states, values):
        state.value = value
    return target
296
+
297
+
298
@set_module_as('brainstate.nn')
def init_all_states(
    target: T,
    *init_args,
    node_to_exclude: Filter = None,
    **init_kwargs,
) -> T:
    """
    Initialize states for all module nodes within the target.

    This is a convenience wrapper around `call_all_fns` that specifically calls
    the `init_state` method on all module nodes. The execution order respects any
    `@call_order()` decorators on the `init_state` methods.

    Parameters
    ----------
    target : Module
        The target module whose states are to be initialized.
    *init_args
        Variable positional arguments to pass to each `init_state` method.
    node_to_exclude : Filter, optional
        A filter to exclude certain nodes from initialization.
        Can be a type, predicate function, or any filter supported by the graph API.
    **init_kwargs
        Variable keyword arguments to pass to each `init_state` method.

    Returns
    -------
    Module
        The target module, returned to allow call chaining.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> net = brainstate.nn.Sequential(
        ...     brainstate.nn.Linear(10, 20),
        ...     brainstate.nn.Dropout(0.5)
        ... )
        >>> # Initialize all states
        >>> brainstate.nn.init_all_states(net)
        >>>
        >>> # Initialize with custom arguments
        >>> brainstate.nn.init_all_states(net, batch_size=32)

    See Also
    --------
    call_all_fns : The underlying function that executes the calls.
    vmap_init_all_states : Vectorized version for batched initialization.
    """
    # Explicit keywords so the argument routing into call_all_fns is obvious.
    call_all_fns(
        target,
        fn_name='init_state',
        args=init_args,
        kwargs=init_kwargs,
        node_to_exclude=node_to_exclude,
    )
    return target
347
+
348
+
349
@set_module_as('brainstate.nn')
def vmap_init_all_states(
    target: T,
    *init_args,
    axis_size: int | None = None,
    node_to_exclude: Filter = None,
    state_to_exclude: Filter = None,
    state_tag: str | None = None,
    **init_kwargs
) -> T:
    """
    Initialize states with vectorized mapping for creating batched module instances.

    This function applies vmap to the initialization process, creating multiple batched
    instances of module states. Each batch element will have independent state values
    and random keys. This is useful for ensemble models or parameter sweeps.

    Parameters
    ----------
    target : Module
        The target module whose states are to be initialized.
    *init_args
        Variable positional arguments to pass to each `init_state` method.
    axis_size : int
        The size of the batch dimension. Must be a positive integer.
    node_to_exclude : Filter, optional
        A filter to exclude certain nodes from initialization.
    state_to_exclude : Filter, optional
        A filter to exclude certain states from being vmapped.
        Excluded states will remain shared across all batched instances.
    state_tag : str, optional
        An optional tag to categorize newly created states.
    **init_kwargs
        Variable keyword arguments to pass to each `init_state` method.

    Returns
    -------
    Module
        The target module, with newly created states carrying a leading batch axis.

    Raises
    ------
    ValueError
        If `axis_size` is None or not a positive integer.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> net = brainstate.nn.Linear(10, 20)
        >>> # Create 8 batched instances with different random initializations
        >>> brainstate.nn.vmap_init_all_states(net, axis_size=8)
        >>>
        >>> # The weight parameter now has shape (8, 20, 10) instead of (20, 10)
        >>> print(net.weight.shape)

    See Also
    --------
    init_all_states : Non-vectorized version.
    vmap_new_states : The underlying vmap transformation for states.
    """

    def init_fn():
        # Plain initialization; vmap_new_states batches whatever states
        # this call creates (except those matched by state_to_exclude).
        init_all_states(
            target,
            *init_args,
            **init_kwargs,
            node_to_exclude=node_to_exclude,
        )

    vmap_new_states(init_fn, state_tag=state_tag, axis_size=axis_size, state_to_exclude=state_to_exclude)()
    return target
429
+
430
+
431
@set_module_as('brainstate.nn')
def reset_all_states(
    target: T,
    *reset_args,
    node_to_exclude: Filter = None,
    **reset_kwargs,
) -> T:
    """
    Reset states for all module nodes within the target.

    A thin convenience wrapper that invokes ``reset_state`` on every module
    node of ``target`` via `call_all_fns`, honoring any `@call_order()`
    decorators on the ``reset_state`` methods. Typically used to clear
    recurrent network state between sequences.

    Parameters
    ----------
    target : Module
        The target module whose states are to be reset.
    *reset_args
        Positional arguments forwarded to each `reset_state` method.
    node_to_exclude : Filter, optional
        A filter to exclude certain nodes from reset.
        Can be a type, predicate function, or any filter supported by the graph API.
    **reset_kwargs
        Keyword arguments forwarded to each `reset_state` method.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> rnn = brainstate.nn.RNNCell(10, 20)
        >>> brainstate.nn.init_all_states(rnn, batch_size=32)
        >>>
        >>> # Process a sequence
        >>> for x in sequence:
        ...     output = rnn(x)
        >>>
        >>> # Reset states before processing next sequence
        >>> brainstate.nn.reset_all_states(rnn)

    See Also
    --------
    call_all_fns : The underlying function that executes the calls.
    vmap_reset_all_states : Vectorized version for batched reset.
    """
    call_all_fns(
        target,
        'reset_state',
        reset_args,
        reset_kwargs,
        node_to_exclude,
    )
    return target
490
+
491
+
492
def vmap_reset_all_states(
    target: T,
    *reset_args,
    axis_size: int | None = None,
    node_to_exclude: Filter = None,
    state_tag: str | None = None,
    **reset_kwargs,
) -> T:
    """
    Reset states with vectorized mapping across batched module instances.

    This function applies vmap to the reset process, resetting states across all
    batched instances of the module. Each batch element will have its state reset
    independently with its own random key. This is useful when working with batched
    recurrent models or ensembles.

    Parameters
    ----------
    target : Module
        The target module whose states are to be reset.
    *reset_args
        Positional arguments to pass to each `reset_state` method.
    axis_size : int
        The size of the batch dimension. Must be a positive integer.
    node_to_exclude : Filter, optional
        A filter to exclude certain nodes from reset.
    state_tag : str, optional
        An optional tag to categorize newly created states during the reset.
    **reset_kwargs
        Keyword arguments to pass to each `reset_state` method.

    Returns
    -------
    Module
        The target module, returned to allow call chaining.

    Raises
    ------
    ValueError
        If `axis_size` is None or not a positive integer.
    TypeError
        If `reset_kwargs` is not a mapping.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> rnn = brainstate.nn.RNNCell(10, 20)
        >>> # Initialize with 16 batched instances
        >>> brainstate.nn.vmap_init_all_states(rnn, batch_size=32, axis_size=16)
        >>>
        >>> # Process sequences...
        >>>
        >>> # Reset all 16 batched instances
        >>> brainstate.nn.vmap_reset_all_states(rnn, axis_size=16)

    See Also
    --------
    reset_all_states : Non-vectorized version.
    vmap_call_all_fns : The underlying vmap function call mechanism.
    """
    vmap_call_all_fns(
        target,
        fn_name='reset_state',
        args=reset_args,
        kwargs=reset_kwargs,
        axis_size=axis_size,
        node_to_exclude=node_to_exclude,
        state_tag=state_tag,
    )
    return target
563
+
564
+
565
@set_module_as('brainstate.nn')
def assign_state_values(
    target: Module,
    *state_by_abs_path: Mapping[str, Any]
) -> tuple[list[str], list[str]]:
    """
    Assign state values to a module from one or more state dictionaries.

    Updates the states of ``target`` from the given dictionaries, whose keys
    are absolute state paths (e.g. ``'layer1.weight'``). Keys that do not
    match any module state — and module states with no matching key — are
    reported back to the caller instead of raising.

    Parameters
    ----------
    target : Module
        The target module whose states will be updated.
    *state_by_abs_path : Mapping[str, Any]
        One or more state dictionaries with absolute path keys mapping to state values.
        If multiple dictionaries are provided, they will be merged (later dictionaries
        override earlier ones for duplicate keys).

    Returns
    -------
    tuple[list[str], list[str]]
        A tuple of (unexpected_keys, missing_keys):

        - unexpected_keys: Keys present in the state dictionaries but not in the module
        - missing_keys: Keys present in the module but not in the state dictionaries

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> net = brainstate.nn.Linear(10, 20)
        >>> brainstate.nn.init_all_states(net)
        >>>
        >>> # Save state values
        >>> state_dict = {path: state.value for path, state in net.states().items()}
        >>>
        >>> # Later, restore state values
        >>> unexpected, missing = brainstate.nn.assign_state_values(net, state_dict)
        >>> print(f"Unexpected keys: {unexpected}")
        >>> print(f"Missing keys: {missing}")

    Notes
    -----
    - All values are automatically converted to JAX arrays using `jax.numpy.asarray`.
    - Only states with matching keys are updated; unexpected and missing keys are
      returned but do not cause errors.
    - If multiple dictionaries contain the same key, the last one takes precedence.
    """
    # Merge the provided dictionaries; later mappings win on duplicate keys.
    merged = {}
    for mapping in state_by_abs_path:
        merged.update(mapping)

    module_states = target.states()
    provided = set(merged)
    existing = set(module_states)

    # Write every value whose path exists on the module, coercing to a JAX array.
    for key in provided & existing:
        module_states[key].value = jax.numpy.asarray(merged[key])

    # Report the mismatches in deterministic (sorted) order.
    return sorted(provided - existing), sorted(existing - provided)