brainstate-0.2.0-py2.py3-none-any.whl → brainstate-0.2.1-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112):
  1. brainstate/__init__.py +169 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2319 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +1652 -1652
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1624 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1433 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +137 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +633 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +154 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +477 -477
  32. brainstate/nn/_dynamics.py +1267 -1267
  33. brainstate/nn/_dynamics_test.py +67 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +384 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/_rand_funs.py +3938 -3938
  64. brainstate/random/_rand_funs_test.py +640 -640
  65. brainstate/random/_rand_seed.py +675 -675
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1617
  68. brainstate/random/_rand_state_test.py +551 -551
  69. brainstate/transform/__init__.py +59 -59
  70. brainstate/transform/_ad_checkpoint.py +176 -176
  71. brainstate/transform/_ad_checkpoint_test.py +49 -49
  72. brainstate/transform/_autograd.py +1025 -1025
  73. brainstate/transform/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -316
  75. brainstate/transform/_conditions_test.py +220 -220
  76. brainstate/transform/_error_if.py +94 -94
  77. brainstate/transform/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -145
  79. brainstate/transform/_eval_shape_test.py +38 -38
  80. brainstate/transform/_jit.py +399 -399
  81. brainstate/transform/_jit_test.py +143 -143
  82. brainstate/transform/_loop_collect_return.py +675 -675
  83. brainstate/transform/_loop_collect_return_test.py +58 -58
  84. brainstate/transform/_loop_no_collection.py +283 -283
  85. brainstate/transform/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -2016
  87. brainstate/transform/_make_jaxpr_test.py +1510 -1510
  88. brainstate/transform/_mapping.py +529 -529
  89. brainstate/transform/_mapping_test.py +194 -194
  90. brainstate/transform/_progress_bar.py +255 -255
  91. brainstate/transform/_random.py +171 -171
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -108
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate-0.2.0.dist-info/RECORD +0 -111
  111. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  112. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
brainstate/nn/_collective_ops.py
@@ -1,633 +1,633 @@
1
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- import warnings
16
- from collections.abc import Sequence, Mapping
17
- from typing import Callable, TypeVar, Any
18
-
19
- import jax
20
-
21
- from brainstate._state import catch_new_states
22
- from brainstate._utils import set_module_as
23
- from brainstate.graph import nodes
24
- from brainstate.transform import vmap, vmap_new_states
25
- from brainstate.typing import Filter
26
- from ._module import Module
27
-
28
- # the maximum order
29
- MAX_ORDER = 10
30
-
31
- T = TypeVar('T', bound=Module)
32
-
33
- __all__ = [
34
- 'call_order',
35
- 'call_all_fns',
36
- 'vmap_call_all_fns',
37
- 'init_all_states',
38
- 'vmap_init_all_states',
39
- 'reset_all_states',
40
- 'vmap_reset_all_states',
41
- 'assign_state_values',
42
- ]
43
-
44
-
45
- @set_module_as('brainstate.nn')
46
- def call_order(
47
- level: int = 0,
48
- check_order_boundary: bool = True
49
- ) -> Callable[[Callable], Callable]:
50
- """
51
- Decorator for specifying the execution order of functions in collective operations.
52
-
53
- This decorator attaches a `call_order` attribute to a function, which is used by
54
- collective operations like `call_all_functions`, `init_all_states`, and `reset_all_states`
55
- to determine the execution order. Functions with lower order levels are executed first.
56
-
57
- Parameters
58
- ----------
59
- level : int, optional
60
- The execution order level. Lower values indicate earlier execution.
61
- Must be in the range [0, MAX_ORDER) when `check_order_boundary` is True.
62
- Default is 0.
63
- check_order_boundary : bool, optional
64
- Whether to validate that the order level is within the valid range [0, MAX_ORDER).
65
- Default is True.
66
-
67
- Returns
68
- -------
69
- Callable[[Callable], Callable]
70
- A decorator function that adds the `call_order` attribute to the decorated function.
71
-
72
- Raises
73
- ------
74
- ValueError
75
- If `check_order_boundary` is True and `level` is not in [0, MAX_ORDER).
76
-
77
- Examples
78
- --------
79
- .. code-block:: python
80
-
81
- >>> import brainstate
82
- >>>
83
- >>> class MyModule(brainstate.nn.Module):
84
- ... @brainstate.nn.call_order(0)
85
- ... def reset_state(self):
86
- ... print("Reset first")
87
- ...
88
- ... @brainstate.nn.call_order(1)
89
- ... def another_reset(self):
90
- ... print("Reset second")
91
- """
92
- if check_order_boundary and (level < 0 or level >= MAX_ORDER):
93
- raise ValueError(f'"level" must be an integer in [0, {MAX_ORDER}), but got {level}.')
94
-
95
- def wrap(fun: Callable) -> Callable:
96
- fun.call_order = level
97
- return fun
98
-
99
- return wrap
100
-
101
-
102
- @set_module_as('brainstate.nn')
103
- def call_all_fns(
104
- target: T,
105
- fn_name: str,
106
- args: Sequence[Any] | Any = (),
107
- kwargs: Mapping[str, Any] | None = None,
108
- node_to_exclude: Filter = None,
109
- fn_if_not_exist: str = 'raise',
110
- ) -> T:
111
- """
112
- Call a specified function on all module nodes within a target, respecting call order.
113
-
114
- This function traverses all module nodes in the target and invokes the specified method
115
- on each node. Functions decorated with `@call_order()` are executed in ascending order
116
- of their level values, while functions without the decorator are executed first.
117
-
118
- Parameters
119
- ----------
120
- target : Module
121
- The target module on which to call functions.
122
- fn_name : str
123
- The name of the method to call on each module node.
124
- node_to_exclude : Filter, optional
125
- A filter to exclude certain nodes from the function call.
126
- Can be a type, predicate function, or any filter supported by the graph API.
127
- fn_if_not_exist : str, optional
128
- Behavior when the specified method doesn't exist on a node:
129
-
130
- - 'raise': Raise an AttributeError (default)
131
- - 'pass' or 'none': Skip the node silently
132
- - 'warn': Issue a warning and skip the node
133
- args
134
- Positional arguments to pass to the called method. A single non-tuple
135
- argument will be automatically wrapped in a tuple. Default is ().
136
- kwargs
137
- Keyword arguments to pass to the called method. Default is None.
138
-
139
- Raises
140
- ------
141
- TypeError
142
- If `fun_name` is not a string or `kwargs` is not a mapping.
143
- ValueError
144
- If `fn_if_not_exist` is not one of the allowed values.
145
- AttributeError
146
- If the specified method doesn't exist on a node and `fn_if_not_exist` is 'raise'.
147
-
148
- Examples
149
- --------
150
- .. code-block:: python
151
-
152
- >>> import brainstate
153
- >>>
154
- >>> net = brainstate.nn.Sequential(brainstate.nn.Linear(10, 20), brainstate.nn.ReLU())
155
- >>> brainstate.nn.call_all_fns(net, 'init_state')
156
- """
157
- if not isinstance(fn_name, str):
158
- raise TypeError(f'fn_name must be a string, but got {type(fn_name).__name__}.')
159
-
160
- args = (args,) if not isinstance(args, tuple) else args
161
- kwargs = kwargs or {}
162
- if not isinstance(kwargs, Mapping):
163
- raise TypeError(f'kwargs must be a mapping, but got {type(kwargs).__name__}.')
164
-
165
- all_nodes = nodes(target).filter(Module)
166
- if node_to_exclude is not None:
167
- all_nodes -= all_nodes.filter(node_to_exclude)
168
-
169
- # Separate nodes with and without call_order
170
- nodes_with_order = []
171
- for path, node in all_nodes.items():
172
- try:
173
- fun = getattr(node, fn_name)
174
- except AttributeError as e:
175
- if fn_if_not_exist == 'raise':
176
- raise AttributeError(
177
- f"Module {type(node).__name__} with the path {path} does not have method '{fn_name}'"
178
- ) from e
179
- elif fn_if_not_exist in ('pass', 'none'):
180
- continue
181
- elif fn_if_not_exist == 'warn':
182
- warnings.warn(
183
- f"Module {type(node).__name__} with the path {path} does not have method '{fn_name}'. "
184
- f"Skipping.",
185
- UserWarning
186
- )
187
- continue
188
- else:
189
- raise ValueError(
190
- f"fn_if_not_exist must be one of ['raise', 'pass', 'none'], but got '{fn_if_not_exist}'."
191
- )
192
-
193
- if not callable(fun):
194
- raise TypeError(f"'{fn_name}' must be callable, but got {type(fun).__name__}.")
195
-
196
- if hasattr(fun, 'call_order'):
197
- nodes_with_order.append(node)
198
- else:
199
- fun(*args, **kwargs)
200
-
201
- # Execute nodes with call_order in sorted order
202
- for node in sorted(nodes_with_order, key=lambda x: getattr(x, fn_name).call_order):
203
- getattr(node, fn_name)(*args, **kwargs)
204
- return target
205
-
206
-
207
- def vmap_call_all_fns(
208
- target: T,
209
- fn_name: str,
210
- args: Sequence[Any] | Any = (),
211
- kwargs: Mapping[str, Any] | None = None,
212
- axis_size: int = None,
213
- node_to_exclude: Filter = None,
214
- state_tag: str | None = None,
215
- fn_if_not_exist: str = 'raise',
216
- ) -> T:
217
- """
218
- Apply vectorized mapping to call a function on all module nodes with batched state handling.
219
-
220
- This function creates multiple batched instances by applying vmap to the specified method
221
- call across all module nodes. Each batch element maintains its own random key and state
222
- values. This is particularly useful for creating ensembles or batched models.
223
-
224
- Parameters
225
- ----------
226
- target : Module
227
- The target module on which to call functions.
228
- fn_name : str
229
- The name of the method to call on each module node.
230
- args : Sequence[Any] or Any, optional
231
- Positional arguments to pass to the called method. A single non-tuple
232
- argument will be automatically wrapped in a tuple. Default is ().
233
- kwargs : Mapping[str, Any], optional
234
- Keyword arguments to pass to the called method. Default is None.
235
- axis_size : int
236
- The size of the batch dimension for vmap. Must be a positive integer.
237
- node_to_exclude : Filter, optional
238
- A filter to exclude certain nodes from the function call.
239
- state_tag : str, optional
240
- An optional tag to categorize newly created states during the vmap operation.
241
- fn_if_not_exist : str, optional
242
- Behavior when the specified method doesn't exist on a node:
243
-
244
- - 'raise': Raise an AttributeError (default)
245
- - 'pass' or 'none': Skip the node silently
246
- - 'warn': Issue a warning and skip the node
247
-
248
- Raises
249
- ------
250
- ValueError
251
- If `axis_size` is None or not a positive integer.
252
- TypeError
253
- If `kwargs` is not a mapping.
254
-
255
- Examples
256
- --------
257
- .. code-block:: python
258
-
259
- >>> import brainstate
260
- >>>
261
- >>> net = brainstate.nn.Linear(10, 20)
262
- >>> # Create 5 batched instances with different initializations
263
- >>> brainstate.nn.vmap_call_all_fns(net, 'init_state', axis_size=5)
264
- """
265
-
266
- if axis_size is None or axis_size <= 0:
267
- raise ValueError(f"axis_size must be a positive integer, got {axis_size}")
268
-
269
- if not isinstance(args, tuple):
270
- args = (args,)
271
- kwargs = kwargs or {}
272
- if not isinstance(kwargs, Mapping):
273
- raise TypeError(f'kwargs must be a mapping, but got {type(kwargs).__name__}.')
274
-
275
- @vmap(axis_size=axis_size)
276
- def vmapped_fn():
277
- with catch_new_states(state_tag) as inner_catcher:
278
- call_all_fns(
279
- target,
280
- fn_name=fn_name,
281
- args=args,
282
- kwargs=kwargs,
283
- node_to_exclude=node_to_exclude,
284
- fn_if_not_exist=fn_if_not_exist
285
- )
286
- return inner_catcher.get_state_values()
287
-
288
- with catch_new_states(state_tag) as outer_catcher:
289
- values = vmapped_fn()
290
- states = outer_catcher.get_states()
291
- for state, value in zip(states, values):
292
- state.value = value
293
- return target
294
-
295
-
296
- @set_module_as('brainstate.nn')
297
- def init_all_states(
298
- target: T,
299
- *init_args,
300
- node_to_exclude: Filter = None,
301
- **init_kwargs,
302
- ) -> T:
303
- """
304
- Initialize states for all module nodes within the target.
305
-
306
- This is a convenience wrapper around `call_all_functions` that specifically calls
307
- the `init_state` method on all module nodes. The execution order respects any
308
- `@call_order()` decorators on the `init_state` methods.
309
-
310
- Parameters
311
- ----------
312
- target : Module
313
- The target module whose states are to be initialized.
314
- *init_args
315
- Variable positional arguments to pass to each `init_state` method.
316
- node_to_exclude : Filter, optional
317
- A filter to exclude certain nodes from initialization.
318
- Can be a type, predicate function, or any filter supported by the graph API.
319
- **init_kwargs
320
- Variable keyword arguments to pass to each `init_state` method.
321
-
322
- Examples
323
- --------
324
- .. code-block:: python
325
-
326
- >>> import brainstate
327
- >>>
328
- >>> net = brainstate.nn.Sequential(
329
- ... brainstate.nn.Linear(10, 20),
330
- ... brainstate.nn.Dropout(0.5)
331
- ... )
332
- >>> # Initialize all states
333
- >>> brainstate.nn.init_all_states(net)
334
- >>>
335
- >>> # Initialize with custom arguments
336
- >>> brainstate.nn.init_all_states(net, batch_size=32)
337
-
338
- See Also
339
- --------
340
- call_all_functions : The underlying function that executes the calls.
341
- vmap_init_all_states : Vectorized version for batched initialization.
342
- """
343
- call_all_fns(target, 'init_state', init_args, init_kwargs, node_to_exclude)
344
- return target
345
-
346
-
347
- @set_module_as('brainstate.nn')
348
- def vmap_init_all_states(
349
- target: T,
350
- *init_args,
351
- axis_size: int = None,
352
- node_to_exclude: Filter = None,
353
- state_to_exclude: Filter = None,
354
- state_tag: str | None = None,
355
- **init_kwargs
356
- ) -> T:
357
- """
358
- Initialize states with vectorized mapping for creating batched module instances.
359
-
360
- This function applies vmap to the initialization process, creating multiple batched
361
- instances of module states. Each batch element will have independent state values
362
- and random keys. This is useful for ensemble models or parameter sweeps.
363
-
364
- Parameters
365
- ----------
366
- target : Module
367
- The target module whose states are to be initialized.
368
- *init_args
369
- Variable positional arguments to pass to each `init_state` method.
370
- axis_size : int
371
- The size of the batch dimension. Must be a positive integer.
372
- node_to_exclude : Filter, optional
373
- A filter to exclude certain nodes from initialization.
374
- state_to_exclude : Filter, optional
375
- A filter to exclude certain states from being vmapped.
376
- Excluded states will remain shared across all batched instances.
377
- state_tag : str, optional
378
- An optional tag to categorize newly created states.
379
- **init_kwargs
380
- Variable keyword arguments to pass to each `init_state` method.
381
-
382
- Raises
383
- ------
384
- ValueError
385
- If `axis_size` is None or not a positive integer.
386
-
387
- Examples
388
- --------
389
- .. code-block:: python
390
-
391
- >>> import brainstate
392
- >>>
393
- >>> net = brainstate.nn.Linear(10, 20)
394
- >>> # Create 8 batched instances with different random initializations
395
- >>> brainstate.nn.vmap_init_all_states(net, axis_size=8)
396
- >>>
397
- >>> # The weight parameter now has shape (8, 20, 10) instead of (20, 10)
398
- >>> print(net.weight.shape)
399
-
400
- See Also
401
- --------
402
- init_all_states : Non-vectorized version.
403
- vmap_new_states : The underlying vmap transformation for states.
404
- """
405
-
406
- # vmap_call_all_functions(
407
- # target,
408
- # fun_name='init_state',
409
- # args=init_args,
410
- # kwargs=init_kwargs,
411
- # axis_size=axis_size,
412
- # node_to_exclude=node_to_exclude,
413
- # state_tag=state_tag,
414
- # )
415
-
416
- def init_fn():
417
- init_all_states(
418
- target,
419
- *init_args,
420
- **init_kwargs,
421
- node_to_exclude=node_to_exclude,
422
- )
423
- return
424
-
425
- vmap_new_states(init_fn, state_tag=state_tag, axis_size=axis_size, state_to_exclude=state_to_exclude)()
426
- return target
427
-
428
-
429
- @set_module_as('brainstate.nn')
430
- def reset_all_states(
431
- target: T,
432
- *reset_args,
433
- node_to_exclude: Filter = None,
434
- **reset_kwargs,
435
- ) -> T:
436
- """
437
- Reset states for all module nodes within the target.
438
-
439
- This is a convenience wrapper around `call_all_functions` that specifically calls
440
- the `reset_state` method on all module nodes. The execution order respects any
441
- `@call_order()` decorators on the `reset_state` methods. This is typically used
442
- to reset recurrent neural network states between sequences.
443
-
444
- Parameters
445
- ----------
446
- target : Module
447
- The target module whose states are to be reset.
448
- reset_args
449
- Positional arguments to pass to each `reset_state` method.
450
- A single non-tuple argument will be automatically wrapped in a tuple.
451
- Default is ().
452
- reset_kwargs
453
- Keyword arguments to pass to each `reset_state` method.
454
- Default is None.
455
- node_to_exclude : Filter, optional
456
- A filter to exclude certain nodes from reset.
457
- Can be a type, predicate function, or any filter supported by the graph API.
458
-
459
- Examples
460
- --------
461
- .. code-block:: python
462
-
463
- >>> import brainstate
464
- >>>
465
- >>> rnn = brainstate.nn.RNNCell(10, 20)
466
- >>> brainstate.nn.init_all_states(rnn, batch_size=32)
467
- >>>
468
- >>> # Process a sequence
469
- >>> for x in sequence:
470
- ... output = rnn(x)
471
- >>>
472
- >>> # Reset states before processing next sequence
473
- >>> brainstate.nn.reset_all_states(rnn)
474
-
475
- See Also
476
- --------
477
- call_all_functions : The underlying function that executes the calls.
478
- vmap_reset_all_states : Vectorized version for batched reset.
479
- """
480
- call_all_fns(
481
- target,
482
- fn_name='reset_state',
483
- args=reset_args,
484
- kwargs=reset_kwargs,
485
- node_to_exclude=node_to_exclude
486
- )
487
- return target
488
-
489
-
490
- def vmap_reset_all_states(
491
- target: T,
492
- *reset_args,
493
- axis_size: int = None,
494
- node_to_exclude: Filter = None,
495
- state_tag: str | None = None,
496
- **reset_kwargs,
497
- ) -> T:
498
- """
499
- Reset states with vectorized mapping across batched module instances.
500
-
501
- This function applies vmap to the reset process, resetting states across all
502
- batched instances of the module. Each batch element will have its state reset
503
- independently with its own random key. This is useful when working with batched
504
- recurrent models or ensembles.
505
-
506
- Parameters
507
- ----------
508
- target : Module
509
- The target module whose states are to be reset.
510
- reset_args
511
- Positional arguments to pass to each `reset_state` method.
512
- A single non-tuple argument will be automatically wrapped in a tuple.
513
- Default is ().
514
- reset_kwargs
515
- Keyword arguments to pass to each `reset_state` method.
516
- Default is None.
517
- axis_size : int
518
- The size of the batch dimension. Must be a positive integer.
519
- node_to_exclude : Filter, optional
520
- A filter to exclude certain nodes from reset.
521
- state_tag : str, optional
522
- An optional tag to categorize newly created states during the reset.
523
-
524
- Raises
525
- ------
526
- ValueError
527
- If `axis_size` is None or not a positive integer.
528
- TypeError
529
- If `reset_kwargs` is not a mapping.
530
-
531
- Examples
532
- --------
533
- .. code-block:: python
534
-
535
- >>> import brainstate
536
- >>>
537
- >>> rnn = brainstate.nn.RNNCell(10, 20)
538
- >>> # Initialize with 16 batched instances
539
- >>> brainstate.nn.vmap_init_all_states(rnn, batch_size=32, axis_size=16)
540
- >>>
541
- >>> # Process sequences...
542
- >>>
543
- >>> # Reset all 16 batched instances
544
- >>> brainstate.nn.vmap_reset_all_states(rnn, axis_size=16)
545
-
546
- See Also
547
- --------
548
- reset_all_states : Non-vectorized version.
549
- vmap_call_all_functions : The underlying vmap function call mechanism.
550
- """
551
- vmap_call_all_fns(
552
- target,
553
- fn_name='reset_state',
554
- args=reset_args,
555
- kwargs=reset_kwargs,
556
- axis_size=axis_size,
557
- node_to_exclude=node_to_exclude,
558
- state_tag=state_tag,
559
- )
560
- return target
561
-
562
-
563
- @set_module_as('brainstate.nn')
564
- def assign_state_values(
565
- target: Module,
566
- *state_by_abs_path: Mapping[str, Any]
567
- ) -> tuple[list[str], list[str]]:
568
- """
569
- Assign state values to a module from one or more state dictionaries.
570
-
571
- This function updates the state values of a module based on provided state dictionaries.
572
- State dictionaries should use absolute paths as keys (e.g., 'layer1.weight', 'layer2.bias').
573
- The function handles missing and unexpected keys, returning them for inspection.
574
-
575
- Parameters
576
- ----------
577
- target : Module
578
- The target module whose states will be updated.
579
- *state_by_abs_path : Mapping[str, Any]
580
- One or more state dictionaries with absolute path keys mapping to state values.
581
- If multiple dictionaries are provided, they will be merged (later dictionaries
582
- override earlier ones for duplicate keys).
583
-
584
- Returns
585
- -------
586
- tuple[list[str], list[str]]
587
- A tuple of (unexpected_keys, missing_keys):
588
-
589
- - unexpected_keys: Keys present in the state dictionaries but not in the module
590
- - missing_keys: Keys present in the module but not in the state dictionaries
591
-
592
- Examples
593
- --------
594
- .. code-block:: python
595
-
596
- >>> import brainstate
597
- >>>
598
- >>> net = brainstate.nn.Linear(10, 20)
599
- >>> brainstate.nn.init_all_states(net)
600
- >>>
601
- >>> # Save state values
602
- >>> state_dict = {path: state.value for path, state in net.states().items()}
603
- >>>
604
- >>> # Later, restore state values
605
- >>> unexpected, missing = brainstate.nn.assign_state_values(net, state_dict)
606
- >>> print(f"Unexpected keys: {unexpected}")
607
- >>> print(f"Missing keys: {missing}")
608
-
609
- Notes
610
- -----
611
- - All values are automatically converted to JAX arrays using `jax.numpy.asarray`.
612
- - Only states with matching keys are updated; unexpected and missing keys are
613
- returned but do not cause errors.
614
- - If multiple dictionaries contain the same key, the last one takes precedence.
615
- """
616
- # Merge all state dictionaries
617
- all_states = {}
618
- for state_dict in state_by_abs_path:
619
- all_states.update(state_dict)
620
-
621
- # Get current module states
622
- variables = target.states()
623
- keys1 = set(all_states.keys())
624
- keys2 = set(variables.keys())
625
-
626
- # Update matching states
627
- for key in keys2.intersection(keys1):
628
- variables[key].value = jax.numpy.asarray(all_states[key])
629
-
630
- # Return mismatched keys
631
- unexpected_keys = sorted(keys1 - keys2)
632
- missing_keys = sorted(keys2 - keys1)
633
- return unexpected_keys, missing_keys
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ import warnings
16
+ from collections.abc import Sequence, Mapping
17
+ from typing import Callable, TypeVar, Any
18
+
19
+ import jax
20
+
21
+ from brainstate._state import catch_new_states
22
+ from brainstate._utils import set_module_as
23
+ from brainstate.graph import nodes
24
+ from brainstate.transform import vmap, vmap_new_states
25
+ from brainstate.typing import Filter
26
+ from ._module import Module
27
+
28
+ # the maximum order
29
+ MAX_ORDER = 10
30
+
31
+ T = TypeVar('T', bound=Module)
32
+
33
+ __all__ = [
34
+ 'call_order',
35
+ 'call_all_fns',
36
+ 'vmap_call_all_fns',
37
+ 'init_all_states',
38
+ 'vmap_init_all_states',
39
+ 'reset_all_states',
40
+ 'vmap_reset_all_states',
41
+ 'assign_state_values',
42
+ ]
43
+
44
+
45
+ @set_module_as('brainstate.nn')
46
+ def call_order(
47
+ level: int = 0,
48
+ check_order_boundary: bool = True
49
+ ) -> Callable[[Callable], Callable]:
50
+ """
51
+ Decorator for specifying the execution order of functions in collective operations.
52
+
53
+ This decorator attaches a `call_order` attribute to a function, which is used by
54
+ collective operations like `call_all_functions`, `init_all_states`, and `reset_all_states`
55
+ to determine the execution order. Functions with lower order levels are executed first.
56
+
57
+ Parameters
58
+ ----------
59
+ level : int, optional
60
+ The execution order level. Lower values indicate earlier execution.
61
+ Must be in the range [0, MAX_ORDER) when `check_order_boundary` is True.
62
+ Default is 0.
63
+ check_order_boundary : bool, optional
64
+ Whether to validate that the order level is within the valid range [0, MAX_ORDER).
65
+ Default is True.
66
+
67
+ Returns
68
+ -------
69
+ Callable[[Callable], Callable]
70
+ A decorator function that adds the `call_order` attribute to the decorated function.
71
+
72
+ Raises
73
+ ------
74
+ ValueError
75
+ If `check_order_boundary` is True and `level` is not in [0, MAX_ORDER).
76
+
77
+ Examples
78
+ --------
79
+ .. code-block:: python
80
+
81
+ >>> import brainstate
82
+ >>>
83
+ >>> class MyModule(brainstate.nn.Module):
84
+ ... @brainstate.nn.call_order(0)
85
+ ... def reset_state(self):
86
+ ... print("Reset first")
87
+ ...
88
+ ... @brainstate.nn.call_order(1)
89
+ ... def another_reset(self):
90
+ ... print("Reset second")
91
+ """
92
+ if check_order_boundary and (level < 0 or level >= MAX_ORDER):
93
+ raise ValueError(f'"level" must be an integer in [0, {MAX_ORDER}), but got {level}.')
94
+
95
+ def wrap(fun: Callable) -> Callable:
96
+ fun.call_order = level
97
+ return fun
98
+
99
+ return wrap
100
+
101
+
102
@set_module_as('brainstate.nn')
def call_all_fns(
    target: T,
    fn_name: str,
    args: Sequence[Any] | Any = (),
    kwargs: Mapping[str, Any] | None = None,
    node_to_exclude: Filter = None,
    fn_if_not_exist: str = 'raise',
) -> T:
    """
    Call a named method on every module node of ``target``, honouring call order.

    All ``Module`` nodes reachable from ``target`` are visited. A method that
    carries no ``call_order`` attribute is invoked immediately, in traversal
    order. A method decorated with ``@call_order()`` is deferred and invoked
    after the traversal, in ascending order of its ``call_order`` value.

    Parameters
    ----------
    target : Module
        The target module on which to call functions.
    fn_name : str
        The name of the method to call on each module node.
    args : Sequence[Any] or Any, optional
        Positional arguments passed to each call. Anything that is not a
        ``tuple`` is wrapped into a one-element tuple. Default is ``()``.
    kwargs : Mapping[str, Any], optional
        Keyword arguments passed to each call. Default is None (no keywords).
    node_to_exclude : Filter, optional
        A filter selecting nodes that must be skipped.
        Can be a type, predicate function, or any filter supported by the graph API.
    fn_if_not_exist : str, optional
        Behavior when the method does not exist on a node:

        - 'raise': Raise an AttributeError (default)
        - 'pass' or 'none': Skip the node silently
        - 'warn': Issue a warning and skip the node

    Returns
    -------
    Module
        ``target`` itself, to allow chaining.

    Raises
    ------
    TypeError
        If `fn_name` is not a string, `kwargs` is not a mapping, or the
        attribute found under `fn_name` is not callable.
    ValueError
        If `fn_if_not_exist` is not one of the allowed values.
    AttributeError
        If the method doesn't exist on a node and `fn_if_not_exist` is 'raise'.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> net = brainstate.nn.Sequential(brainstate.nn.Linear(10, 20), brainstate.nn.ReLU())
        >>> brainstate.nn.call_all_fns(net, 'init_state')
    """
    if not isinstance(fn_name, str):
        raise TypeError(f'fn_name must be a string, but got {type(fn_name).__name__}.')

    args = (args,) if not isinstance(args, tuple) else args
    kwargs = kwargs or {}
    if not isinstance(kwargs, Mapping):
        raise TypeError(f'kwargs must be a mapping, but got {type(kwargs).__name__}.')

    # Validate the policy up front: checking it only when a node lacks the
    # method would let a misspelled value go unnoticed on most inputs.
    if fn_if_not_exist not in ('raise', 'pass', 'none', 'warn'):
        raise ValueError(
            f"fn_if_not_exist must be one of ['raise', 'pass', 'none', 'warn'], "
            f"but got '{fn_if_not_exist}'."
        )

    all_nodes = nodes(target).filter(Module)
    if node_to_exclude is not None:
        all_nodes -= all_nodes.filter(node_to_exclude)

    # Bound callables that carry a ``call_order``; they run after the traversal.
    deferred = []
    for path, node in all_nodes.items():
        try:
            fun = getattr(node, fn_name)
        except AttributeError as e:
            if fn_if_not_exist == 'raise':
                raise AttributeError(
                    f"Module {type(node).__name__} with the path {path} does not have method '{fn_name}'"
                ) from e
            if fn_if_not_exist == 'warn':
                warnings.warn(
                    f"Module {type(node).__name__} with the path {path} does not have method '{fn_name}'. "
                    f"Skipping.",
                    UserWarning
                )
            # 'pass', 'none' and 'warn' all skip the node.
            continue

        if not callable(fun):
            raise TypeError(f"'{fn_name}' must be callable, but got {type(fun).__name__}.")

        if hasattr(fun, 'call_order'):
            deferred.append(fun)
        else:
            fun(*args, **kwargs)

    # ``sorted`` is stable, so equal levels keep their traversal order.
    for fun in sorted(deferred, key=lambda f: f.call_order):
        fun(*args, **kwargs)
    return target
205
+
206
+
207
def vmap_call_all_fns(
    target: T,
    fn_name: str,
    args: Sequence[Any] | Any = (),
    kwargs: Mapping[str, Any] | None = None,
    axis_size: int = None,
    node_to_exclude: Filter = None,
    state_tag: str | None = None,
    fn_if_not_exist: str = 'raise',
) -> T:
    """
    Run ``call_all_fns`` under ``vmap`` and write the batched values into the new states.

    The method named ``fn_name`` is called on the module nodes of ``target``
    through ``call_all_fns`` inside a function mapped with
    ``vmap(axis_size=axis_size)``. The state values collected inside the mapped
    function are returned from it and then assigned to the states collected
    around the mapped call.

    Parameters
    ----------
    target : Module
        The target module on which to call functions.
    fn_name : str
        The name of the method to call on each module node.
    args : Sequence[Any] or Any, optional
        Positional arguments to pass to the called method. A single non-tuple
        argument will be automatically wrapped in a tuple. Default is ().
    kwargs : Mapping[str, Any], optional
        Keyword arguments to pass to the called method. Default is None.
    axis_size : int
        The size of the batch dimension for vmap. Must be a positive integer.
    node_to_exclude : Filter, optional
        A filter to exclude certain nodes from the function call.
    state_tag : str, optional
        Tag forwarded to both ``catch_new_states`` contexts.
    fn_if_not_exist : str, optional
        Forwarded to ``call_all_fns``; behavior when the method doesn't exist on a node:

        - 'raise': Raise an AttributeError (default)
        - 'pass' or 'none': Skip the node silently
        - 'warn': Issue a warning and skip the node

    Returns
    -------
    Module
        ``target`` itself.

    Raises
    ------
    ValueError
        If `axis_size` is None or not a positive integer.
    TypeError
        If `kwargs` is not a mapping.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> net = brainstate.nn.Linear(10, 20)
        >>> # Call `init_state` under a vmap of size 5
        >>> brainstate.nn.vmap_call_all_fns(net, 'init_state', axis_size=5)
    """

    # Reject a missing or non-positive batch size before building the mapped function.
    if axis_size is None or axis_size <= 0:
        raise ValueError(f"axis_size must be a positive integer, got {axis_size}")

    # A lone non-tuple argument becomes a 1-tuple so it can be unpacked by the callee.
    if not isinstance(args, tuple):
        args = (args,)
    # None (or any other falsy value) is replaced by an empty dict before the Mapping check.
    kwargs = kwargs or {}
    if not isinstance(kwargs, Mapping):
        raise TypeError(f'kwargs must be a mapping, but got {type(kwargs).__name__}.')

    # The mapped function takes no inputs, so the mapped axis length comes from axis_size alone.
    @vmap(axis_size=axis_size)
    def vmapped_fn():
        # NOTE(review): catch_new_states presumably records the State objects
        # created inside the block - confirm against its definition.
        with catch_new_states(state_tag) as inner_catcher:
            call_all_fns(
                target,
                fn_name=fn_name,
                args=args,
                kwargs=kwargs,
                node_to_exclude=node_to_exclude,
                fn_if_not_exist=fn_if_not_exist
            )
        # Only the values leave the mapped function, not the state objects.
        return inner_catcher.get_state_values()

    # The outer catcher supplies the state objects; the mapped call supplies their values.
    with catch_new_states(state_tag) as outer_catcher:
        values = vmapped_fn()
        states = outer_catcher.get_states()
    # Pairing is positional; assumes both catchers report states in the same order - TODO confirm.
    for state, value in zip(states, values):
        state.value = value
    return target
294
+
295
+
296
@set_module_as('brainstate.nn')
def init_all_states(
    target: T,
    *init_args,
    node_to_exclude: Filter = None,
    **init_kwargs,
) -> T:
    """
    Call ``init_state`` on every module node of ``target``.

    Thin wrapper over `call_all_fns` with the method name fixed to
    ``'init_state'``. Ordering follows `call_all_fns`: methods decorated with
    ``@call_order()`` run after the undecorated ones, in ascending level.

    Parameters
    ----------
    target : Module
        The target module whose states are to be initialized.
    *init_args
        Positional arguments forwarded to each `init_state` method.
    node_to_exclude : Filter, optional
        A filter to exclude certain nodes from initialization.
        Can be a type, predicate function, or any filter supported by the graph API.
    **init_kwargs
        Keyword arguments forwarded to each `init_state` method.

    Returns
    -------
    Module
        ``target`` itself.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> net = brainstate.nn.Sequential(
        ...     brainstate.nn.Linear(10, 20),
        ...     brainstate.nn.Dropout(0.5)
        ... )
        >>> # Initialize all states
        >>> brainstate.nn.init_all_states(net)
        >>>
        >>> # Initialize with custom arguments
        >>> brainstate.nn.init_all_states(net, batch_size=32)

    See Also
    --------
    call_all_fns : The underlying function that executes the calls.
    vmap_init_all_states : Vectorized version for batched initialization.
    """
    call_all_fns(
        target,
        fn_name='init_state',
        args=init_args,
        kwargs=init_kwargs,
        node_to_exclude=node_to_exclude,
    )
    return target
345
+
346
+
347
@set_module_as('brainstate.nn')
def vmap_init_all_states(
    target: T,
    *init_args,
    axis_size: int = None,
    node_to_exclude: Filter = None,
    state_to_exclude: Filter = None,
    state_tag: str | None = None,
    **init_kwargs
) -> T:
    """
    Initialize the states of ``target`` under ``vmap_new_states``.

    `init_all_states` is wrapped in a zero-argument function, which is handed
    to `vmap_new_states` together with ``axis_size``, ``state_tag`` and
    ``state_to_exclude``; the resulting callable is invoked once.

    Parameters
    ----------
    target : Module
        The target module whose states are to be initialized.
    *init_args
        Positional arguments forwarded to each `init_state` method.
    axis_size : int
        The size of the batch dimension. Must be a positive integer.
    node_to_exclude : Filter, optional
        A filter to exclude certain nodes from initialization.
    state_to_exclude : Filter, optional
        Forwarded to `vmap_new_states` as the filter of states that should
        not be vmapped.
    state_tag : str, optional
        An optional tag forwarded to `vmap_new_states` for the new states.
    **init_kwargs
        Keyword arguments forwarded to each `init_state` method.

    Returns
    -------
    Module
        ``target`` itself.

    Raises
    ------
    ValueError
        If `axis_size` is None or not a positive integer.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> rnn = brainstate.nn.RNNCell(10, 20)
        >>> # Initialize the states with a mapped axis of size 16
        >>> brainstate.nn.vmap_init_all_states(rnn, batch_size=32, axis_size=16)

    See Also
    --------
    init_all_states : Non-vectorized version.
    vmap_new_states : The underlying vmap transformation for states.
    """
    # Enforce the documented contract here, with the same message as
    # ``vmap_call_all_fns``, instead of relying on whatever the transform
    # does with a missing or non-positive size.
    if axis_size is None or axis_size <= 0:
        raise ValueError(f"axis_size must be a positive integer, got {axis_size}")

    def init_fn():
        # Zero-argument closure: everything it needs is captured from the enclosing call.
        init_all_states(
            target,
            *init_args,
            node_to_exclude=node_to_exclude,
            **init_kwargs,
        )

    vmap_new_states(
        init_fn,
        state_tag=state_tag,
        axis_size=axis_size,
        state_to_exclude=state_to_exclude,
    )()
    return target
427
+
428
+
429
@set_module_as('brainstate.nn')
def reset_all_states(
    target: T,
    *reset_args,
    node_to_exclude: Filter = None,
    **reset_kwargs,
) -> T:
    """
    Call ``reset_state`` on every module node of ``target``.

    Thin wrapper over `call_all_fns` with the method name fixed to
    ``'reset_state'``. Ordering follows `call_all_fns`: methods decorated with
    ``@call_order()`` run after the undecorated ones, in ascending level.

    Parameters
    ----------
    target : Module
        The target module whose states are to be reset.
    *reset_args
        Positional arguments forwarded to each `reset_state` method.
    node_to_exclude : Filter, optional
        A filter to exclude certain nodes from reset.
        Can be a type, predicate function, or any filter supported by the graph API.
    **reset_kwargs
        Keyword arguments forwarded to each `reset_state` method.

    Returns
    -------
    Module
        ``target`` itself.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> rnn = brainstate.nn.RNNCell(10, 20)
        >>> brainstate.nn.init_all_states(rnn, batch_size=32)
        >>>
        >>> # Process a sequence
        >>> for x in sequence:
        ...     output = rnn(x)
        >>>
        >>> # Reset states before processing next sequence
        >>> brainstate.nn.reset_all_states(rnn)

    See Also
    --------
    call_all_fns : The underlying function that executes the calls.
    vmap_reset_all_states : Vectorized version for batched reset.
    """
    call_all_fns(target, 'reset_state', reset_args, reset_kwargs, node_to_exclude)
    return target
488
+
489
+
490
def vmap_reset_all_states(
    target: T,
    *reset_args,
    axis_size: int = None,
    node_to_exclude: Filter = None,
    state_tag: str | None = None,
    **reset_kwargs,
) -> T:
    """
    Call ``reset_state`` on every module node of ``target`` under ``vmap``.

    Thin wrapper over `vmap_call_all_fns` with the method name fixed to
    ``'reset_state'``; argument validation and the vmapped call happen there.

    Parameters
    ----------
    target : Module
        The target module whose states are to be reset.
    *reset_args
        Positional arguments forwarded to each `reset_state` method.
    axis_size : int
        The size of the batch dimension. Must be a positive integer.
    node_to_exclude : Filter, optional
        A filter to exclude certain nodes from reset.
    state_tag : str, optional
        An optional tag forwarded to `vmap_call_all_fns`.
    **reset_kwargs
        Keyword arguments forwarded to each `reset_state` method.

    Returns
    -------
    Module
        ``target`` itself.

    Raises
    ------
    ValueError
        If `axis_size` is None or not a positive integer.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> rnn = brainstate.nn.RNNCell(10, 20)
        >>> # Initialize with 16 batched instances
        >>> brainstate.nn.vmap_init_all_states(rnn, batch_size=32, axis_size=16)
        >>>
        >>> # Process sequences...
        >>>
        >>> # Reset all 16 batched instances
        >>> brainstate.nn.vmap_reset_all_states(rnn, axis_size=16)

    See Also
    --------
    reset_all_states : Non-vectorized version.
    vmap_call_all_fns : The underlying vmap function call mechanism.
    """
    vmap_call_all_fns(
        target,
        'reset_state',
        reset_args,
        reset_kwargs,
        axis_size=axis_size,
        node_to_exclude=node_to_exclude,
        state_tag=state_tag,
    )
    return target
561
+
562
+
563
@set_module_as('brainstate.nn')
def assign_state_values(
    target: Module,
    *state_by_abs_path: Mapping[str, Any]
) -> tuple[list[str], list[str]]:
    """
    Write values from one or more path-keyed dictionaries into a module's states.

    The dictionaries are merged left to right, so a later dictionary wins on
    duplicate keys. Every key that also appears in ``target.states()`` has its
    state's ``value`` replaced; keys found on only one side are reported back
    instead of raising.

    Parameters
    ----------
    target : Module
        The target module whose states will be updated.
    *state_by_abs_path : Mapping[str, Any]
        One or more dictionaries mapping absolute state paths
        (e.g., 'layer1.weight', 'layer2.bias') to the values to assign.

    Returns
    -------
    tuple[list[str], list[str]]
        A tuple of (unexpected_keys, missing_keys), each sorted:

        - unexpected_keys: Keys present in the state dictionaries but not in the module
        - missing_keys: Keys present in the module but not in the state dictionaries

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> net = brainstate.nn.Linear(10, 20)
        >>> brainstate.nn.init_all_states(net)
        >>>
        >>> # Save state values
        >>> state_dict = {path: state.value for path, state in net.states().items()}
        >>>
        >>> # Later, restore state values
        >>> unexpected, missing = brainstate.nn.assign_state_values(net, state_dict)
        >>> print(f"Unexpected keys: {unexpected}")
        >>> print(f"Missing keys: {missing}")

    Notes
    -----
    - Every assigned value is passed through `jax.numpy.asarray` first.
    - Mismatched keys never raise; they are only returned.
    """
    # Fold all inputs into one dictionary; later inputs overwrite earlier ones.
    merged = {}
    for mapping in state_by_abs_path:
        merged.update(mapping)

    module_states = target.states()
    provided = set(merged.keys())
    present = set(module_states.keys())

    # Only paths known to both sides are written.
    for path in present & provided:
        module_states[path].value = jax.numpy.asarray(merged[path])

    return sorted(provided - present), sorted(present - provided)