brainstate 0.2.0__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +169 -169
- brainstate/_compatible_import.py +340 -340
- brainstate/_compatible_import_test.py +681 -681
- brainstate/_deprecation.py +210 -210
- brainstate/_deprecation_test.py +2319 -2319
- brainstate/_error.py +45 -45
- brainstate/_state.py +1652 -1652
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -1495
- brainstate/environ_test.py +1223 -1223
- brainstate/graph/__init__.py +22 -22
- brainstate/graph/_node.py +240 -240
- brainstate/graph/_node_test.py +589 -589
- brainstate/graph/_operation.py +1624 -1624
- brainstate/graph/_operation_test.py +1147 -1147
- brainstate/mixin.py +1433 -1433
- brainstate/mixin_test.py +1017 -1017
- brainstate/nn/__init__.py +137 -137
- brainstate/nn/_activations.py +1100 -1100
- brainstate/nn/_activations_test.py +354 -354
- brainstate/nn/_collective_ops.py +633 -633
- brainstate/nn/_collective_ops_test.py +774 -774
- brainstate/nn/_common.py +226 -226
- brainstate/nn/_common_test.py +154 -154
- brainstate/nn/_conv.py +2010 -2010
- brainstate/nn/_conv_test.py +849 -849
- brainstate/nn/_delay.py +575 -575
- brainstate/nn/_delay_test.py +243 -243
- brainstate/nn/_dropout.py +618 -618
- brainstate/nn/_dropout_test.py +477 -477
- brainstate/nn/_dynamics.py +1267 -1267
- brainstate/nn/_dynamics_test.py +67 -67
- brainstate/nn/_elementwise.py +1298 -1298
- brainstate/nn/_elementwise_test.py +829 -829
- brainstate/nn/_embedding.py +408 -408
- brainstate/nn/_embedding_test.py +156 -156
- brainstate/nn/_event_fixedprob.py +233 -233
- brainstate/nn/_event_fixedprob_test.py +115 -115
- brainstate/nn/_event_linear.py +83 -83
- brainstate/nn/_event_linear_test.py +121 -121
- brainstate/nn/_exp_euler.py +254 -254
- brainstate/nn/_exp_euler_test.py +377 -377
- brainstate/nn/_linear.py +744 -744
- brainstate/nn/_linear_test.py +475 -475
- brainstate/nn/_metrics.py +1070 -1070
- brainstate/nn/_metrics_test.py +611 -611
- brainstate/nn/_module.py +384 -384
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -1334
- brainstate/nn/_normalizations_test.py +699 -699
- brainstate/nn/_paddings.py +1020 -1020
- brainstate/nn/_paddings_test.py +722 -722
- brainstate/nn/_poolings.py +2239 -2239
- brainstate/nn/_poolings_test.py +952 -952
- brainstate/nn/_rnns.py +946 -946
- brainstate/nn/_rnns_test.py +592 -592
- brainstate/nn/_utils.py +216 -216
- brainstate/nn/_utils_test.py +401 -401
- brainstate/nn/init.py +809 -809
- brainstate/nn/init_test.py +180 -180
- brainstate/random/__init__.py +270 -270
- brainstate/random/_rand_funs.py +3938 -3938
- brainstate/random/_rand_funs_test.py +640 -640
- brainstate/random/_rand_seed.py +675 -675
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1617 -1617
- brainstate/random/_rand_state_test.py +551 -551
- brainstate/transform/__init__.py +59 -59
- brainstate/transform/_ad_checkpoint.py +176 -176
- brainstate/transform/_ad_checkpoint_test.py +49 -49
- brainstate/transform/_autograd.py +1025 -1025
- brainstate/transform/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -316
- brainstate/transform/_conditions_test.py +220 -220
- brainstate/transform/_error_if.py +94 -94
- brainstate/transform/_error_if_test.py +52 -52
- brainstate/transform/_eval_shape.py +145 -145
- brainstate/transform/_eval_shape_test.py +38 -38
- brainstate/transform/_jit.py +399 -399
- brainstate/transform/_jit_test.py +143 -143
- brainstate/transform/_loop_collect_return.py +675 -675
- brainstate/transform/_loop_collect_return_test.py +58 -58
- brainstate/transform/_loop_no_collection.py +283 -283
- brainstate/transform/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2016 -2016
- brainstate/transform/_make_jaxpr_test.py +1510 -1510
- brainstate/transform/_mapping.py +529 -529
- brainstate/transform/_mapping_test.py +194 -194
- brainstate/transform/_progress_bar.py +255 -255
- brainstate/transform/_random.py +171 -171
- brainstate/transform/_unvmap.py +256 -256
- brainstate/transform/_util.py +286 -286
- brainstate/typing.py +837 -837
- brainstate/typing_test.py +780 -780
- brainstate/util/__init__.py +27 -27
- brainstate/util/_others.py +1024 -1024
- brainstate/util/_others_test.py +962 -962
- brainstate/util/_pretty_pytree.py +1301 -1301
- brainstate/util/_pretty_pytree_test.py +675 -675
- brainstate/util/_pretty_repr.py +462 -462
- brainstate/util/_pretty_repr_test.py +696 -696
- brainstate/util/filter.py +945 -945
- brainstate/util/filter_test.py +911 -911
- brainstate/util/struct.py +910 -910
- brainstate/util/struct_test.py +602 -602
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -108
- brainstate-0.2.1.dist-info/RECORD +111 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
- brainstate-0.2.0.dist-info/RECORD +0 -111
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
brainstate/nn/_collective_ops.py
CHANGED
@@ -1,633 +1,633 @@
|
|
1
|
-
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
import warnings
|
16
|
-
from collections.abc import Sequence, Mapping
|
17
|
-
from typing import Callable, TypeVar, Any
|
18
|
-
|
19
|
-
import jax
|
20
|
-
|
21
|
-
from brainstate._state import catch_new_states
|
22
|
-
from brainstate._utils import set_module_as
|
23
|
-
from brainstate.graph import nodes
|
24
|
-
from brainstate.transform import vmap, vmap_new_states
|
25
|
-
from brainstate.typing import Filter
|
26
|
-
from ._module import Module
|
27
|
-
|
28
|
-
# the maximum order
|
29
|
-
MAX_ORDER = 10
|
30
|
-
|
31
|
-
T = TypeVar('T', bound=Module)
|
32
|
-
|
33
|
-
__all__ = [
|
34
|
-
'call_order',
|
35
|
-
'call_all_fns',
|
36
|
-
'vmap_call_all_fns',
|
37
|
-
'init_all_states',
|
38
|
-
'vmap_init_all_states',
|
39
|
-
'reset_all_states',
|
40
|
-
'vmap_reset_all_states',
|
41
|
-
'assign_state_values',
|
42
|
-
]
|
43
|
-
|
44
|
-
|
45
|
-
@set_module_as('brainstate.nn')
def call_order(
    level: int = 0,
    check_order_boundary: bool = True
) -> Callable[[Callable], Callable]:
    """
    Decorator assigning an execution-order level to a function.

    Collective operations such as ``call_all_fns``, ``init_all_states``, and
    ``reset_all_states`` read the ``call_order`` attribute this decorator
    attaches: functions with smaller levels run earlier.

    Parameters
    ----------
    level : int, optional
        Execution-order level; smaller runs first. Must lie in
        [0, MAX_ORDER) when ``check_order_boundary`` is True. Default is 0.
    check_order_boundary : bool, optional
        Whether to validate ``level`` against [0, MAX_ORDER). Default is True.

    Returns
    -------
    Callable[[Callable], Callable]
        A decorator that stamps ``call_order`` onto the wrapped function.

    Raises
    ------
    ValueError
        If ``check_order_boundary`` is True and ``level`` is outside
        [0, MAX_ORDER).

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> class MyModule(brainstate.nn.Module):
        ...     @brainstate.nn.call_order(0)
        ...     def reset_state(self):
        ...         print("Reset first")
        ...
        ...     @brainstate.nn.call_order(1)
        ...     def another_reset(self):
        ...         print("Reset second")
    """
    if check_order_boundary and not (0 <= level < MAX_ORDER):
        raise ValueError(f'"level" must be an integer in [0, {MAX_ORDER}), but got {level}.')

    def wrap(fun: Callable) -> Callable:
        # The attribute is read back by the collective-op sorters.
        fun.call_order = level
        return fun

    return wrap
100
|
-
|
101
|
-
|
102
|
-
@set_module_as('brainstate.nn')
def call_all_fns(
    target: T,
    fn_name: str,
    args: Sequence[Any] | Any = (),
    kwargs: Mapping[str, Any] | None = None,
    node_to_exclude: Filter = None,
    fn_if_not_exist: str = 'raise',
) -> T:
    """
    Call a named method on all module nodes within a target, respecting call order.

    Traverses every module node in ``target`` and invokes the method named
    ``fn_name`` on each. Methods decorated with ``@call_order()`` run after
    undecorated ones, in ascending order of their level values.

    Parameters
    ----------
    target : Module
        The target module on which to call functions.
    fn_name : str
        The name of the method to call on each module node.
    args : Sequence[Any] or Any, optional
        Positional arguments to pass to the called method. A single non-tuple
        argument is automatically wrapped in a tuple. Default is ().
    kwargs : Mapping[str, Any], optional
        Keyword arguments to pass to the called method. Default is None.
    node_to_exclude : Filter, optional
        A filter to exclude certain nodes from the function call.
        Can be a type, predicate function, or any filter supported by the graph API.
    fn_if_not_exist : str, optional
        Behavior when the specified method doesn't exist on a node:

        - 'raise': Raise an AttributeError (default)
        - 'pass' or 'none': Skip the node silently
        - 'warn': Issue a warning and skip the node

    Returns
    -------
    Module
        The ``target`` module, enabling call chaining.

    Raises
    ------
    TypeError
        If ``fn_name`` is not a string, ``kwargs`` is not a mapping, or the
        resolved attribute is not callable.
    ValueError
        If ``fn_if_not_exist`` is not one of the allowed values.
    AttributeError
        If the method doesn't exist on a node and ``fn_if_not_exist`` is 'raise'.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> net = brainstate.nn.Sequential(brainstate.nn.Linear(10, 20), brainstate.nn.ReLU())
        >>> brainstate.nn.call_all_fns(net, 'init_state')
    """
    if not isinstance(fn_name, str):
        raise TypeError(f'fn_name must be a string, but got {type(fn_name).__name__}.')

    args = (args,) if not isinstance(args, tuple) else args
    kwargs = kwargs or {}
    if not isinstance(kwargs, Mapping):
        raise TypeError(f'kwargs must be a mapping, but got {type(kwargs).__name__}.')

    all_nodes = nodes(target).filter(Module)
    if node_to_exclude is not None:
        all_nodes -= all_nodes.filter(node_to_exclude)

    # Undecorated methods run immediately; methods carrying a `call_order`
    # attribute are deferred and executed afterwards in sorted order.
    nodes_with_order = []
    for path, node in all_nodes.items():
        try:
            fun = getattr(node, fn_name)
        except AttributeError as e:
            if fn_if_not_exist == 'raise':
                raise AttributeError(
                    f"Module {type(node).__name__} with the path {path} does not have method '{fn_name}'"
                ) from e
            elif fn_if_not_exist in ('pass', 'none'):
                continue
            elif fn_if_not_exist == 'warn':
                warnings.warn(
                    f"Module {type(node).__name__} with the path {path} does not have method '{fn_name}'. "
                    f"Skipping.",
                    UserWarning
                )
                continue
            else:
                # FIX: previous message omitted the valid 'warn' option.
                raise ValueError(
                    f"fn_if_not_exist must be one of ['raise', 'pass', 'none', 'warn'], "
                    f"but got '{fn_if_not_exist}'."
                )

        if not callable(fun):
            raise TypeError(f"'{fn_name}' must be callable, but got {type(fun).__name__}.")

        if hasattr(fun, 'call_order'):
            nodes_with_order.append(node)
        else:
            fun(*args, **kwargs)

    # Execute deferred nodes in ascending call_order.
    for node in sorted(nodes_with_order, key=lambda x: getattr(x, fn_name).call_order):
        getattr(node, fn_name)(*args, **kwargs)
    return target
|
205
|
-
|
206
|
-
|
207
|
-
@set_module_as('brainstate.nn')
def vmap_call_all_fns(
    target: T,
    fn_name: str,
    args: Sequence[Any] | Any = (),
    kwargs: Mapping[str, Any] | None = None,
    axis_size: int | None = None,
    node_to_exclude: Filter = None,
    state_tag: str | None = None,
    fn_if_not_exist: str = 'raise',
) -> T:
    """
    Apply vectorized mapping to call a function on all module nodes with batched state handling.

    Wraps the method call in ``vmap`` so that each of the ``axis_size`` batch
    elements runs with its own random key, then writes the batched values of
    any newly created states back into the (single) target. Useful for
    building ensembles or batched models.

    Parameters
    ----------
    target : Module
        The target module on which to call functions.
    fn_name : str
        The name of the method to call on each module node.
    args : Sequence[Any] or Any, optional
        Positional arguments to pass to the called method. A single non-tuple
        argument is automatically wrapped in a tuple. Default is ().
    kwargs : Mapping[str, Any], optional
        Keyword arguments to pass to the called method. Default is None.
    axis_size : int
        The size of the batch dimension for vmap. Must be a positive integer.
    node_to_exclude : Filter, optional
        A filter to exclude certain nodes from the function call.
    state_tag : str, optional
        An optional tag to categorize newly created states during the vmap operation.
    fn_if_not_exist : str, optional
        Behavior when the specified method doesn't exist on a node:

        - 'raise': Raise an AttributeError (default)
        - 'pass' or 'none': Skip the node silently
        - 'warn': Issue a warning and skip the node

    Returns
    -------
    Module
        The ``target`` module, with newly created states batched.

    Raises
    ------
    ValueError
        If ``axis_size`` is None or not a positive integer.
    TypeError
        If ``kwargs`` is not a mapping.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> net = brainstate.nn.Linear(10, 20)
        >>> # Create 5 batched instances with different initializations
        >>> brainstate.nn.vmap_call_all_fns(net, 'init_state', axis_size=5)
    """
    if axis_size is None or axis_size <= 0:
        raise ValueError(f"axis_size must be a positive integer, got {axis_size}")

    if not isinstance(args, tuple):
        args = (args,)
    kwargs = kwargs or {}
    if not isinstance(kwargs, Mapping):
        raise TypeError(f'kwargs must be a mapping, but got {type(kwargs).__name__}.')

    @vmap(axis_size=axis_size)
    def vmapped_fn():
        # The inner catcher records states created inside the vmapped trace;
        # returning their values makes vmap stack them along the batch axis.
        with catch_new_states(state_tag) as inner_catcher:
            call_all_fns(
                target,
                fn_name=fn_name,
                args=args,
                kwargs=kwargs,
                node_to_exclude=node_to_exclude,
                fn_if_not_exist=fn_if_not_exist
            )
        return inner_catcher.get_state_values()

    # The outer catcher observes the same new states outside the trace so the
    # batched values can be written back into them.
    with catch_new_states(state_tag) as outer_catcher:
        values = vmapped_fn()
    states = outer_catcher.get_states()
    for state, value in zip(states, values):
        state.value = value
    return target
|
294
|
-
|
295
|
-
|
296
|
-
@set_module_as('brainstate.nn')
def init_all_states(
    target: T,
    *init_args,
    node_to_exclude: Filter = None,
    **init_kwargs,
) -> T:
    """
    Initialize states for all module nodes within the target.

    Convenience wrapper around ``call_all_fns`` that invokes ``init_state``
    on every module node. Execution order respects any ``@call_order()``
    decorators on the ``init_state`` methods.

    Parameters
    ----------
    target : Module
        The target module whose states are to be initialized.
    *init_args
        Positional arguments forwarded to each ``init_state`` method.
    node_to_exclude : Filter, optional
        A filter to exclude certain nodes from initialization.
        Can be a type, predicate function, or any filter supported by the graph API.
    **init_kwargs
        Keyword arguments forwarded to each ``init_state`` method.

    Returns
    -------
    Module
        The ``target`` module, for chaining.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> net = brainstate.nn.Sequential(
        ...     brainstate.nn.Linear(10, 20),
        ...     brainstate.nn.Dropout(0.5)
        ... )
        >>> brainstate.nn.init_all_states(net)
        >>> brainstate.nn.init_all_states(net, batch_size=32)

    See Also
    --------
    call_all_fns : The underlying function that executes the calls.
    vmap_init_all_states : Vectorized version for batched initialization.
    """
    call_all_fns(
        target,
        fn_name='init_state',
        args=init_args,
        kwargs=init_kwargs,
        node_to_exclude=node_to_exclude,
    )
    return target
|
345
|
-
|
346
|
-
|
347
|
-
@set_module_as('brainstate.nn')
def vmap_init_all_states(
    target: T,
    *init_args,
    axis_size: int | None = None,
    node_to_exclude: Filter = None,
    state_to_exclude: Filter = None,
    state_tag: str | None = None,
    **init_kwargs
) -> T:
    """
    Initialize states with vectorized mapping for creating batched module instances.

    Applies vmap to the initialization process so each of the ``axis_size``
    batch elements gets independent state values and random keys. Useful for
    ensemble models or parameter sweeps.

    Parameters
    ----------
    target : Module
        The target module whose states are to be initialized.
    *init_args
        Positional arguments forwarded to each ``init_state`` method.
    axis_size : int
        The size of the batch dimension. Must be a positive integer.
    node_to_exclude : Filter, optional
        A filter to exclude certain nodes from initialization.
    state_to_exclude : Filter, optional
        A filter to exclude certain states from being vmapped.
        Excluded states remain shared across all batched instances.
    state_tag : str, optional
        An optional tag to categorize newly created states.
    **init_kwargs
        Keyword arguments forwarded to each ``init_state`` method.

    Returns
    -------
    Module
        The ``target`` module, with newly created states batched.

    Raises
    ------
    ValueError
        If ``axis_size`` is None or not a positive integer.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> net = brainstate.nn.Linear(10, 20)
        >>> # Create 8 batched instances with different random initializations
        >>> brainstate.nn.vmap_init_all_states(net, axis_size=8)
        >>> # The weight parameter now has shape (8, 20, 10) instead of (20, 10)
        >>> print(net.weight.shape)

    See Also
    --------
    init_all_states : Non-vectorized version.
    vmap_new_states : The underlying vmap transformation for states.
    """
    # FIX: validate axis_size explicitly, as the docstring has always
    # promised; previously an invalid value was passed straight through.
    if axis_size is None or axis_size <= 0:
        raise ValueError(f"axis_size must be a positive integer, got {axis_size}")

    def init_fn():
        # Run the plain (non-batched) initialization inside the vmapped trace.
        init_all_states(
            target,
            *init_args,
            **init_kwargs,
            node_to_exclude=node_to_exclude,
        )
        return

    vmap_new_states(init_fn, state_tag=state_tag, axis_size=axis_size, state_to_exclude=state_to_exclude)()
    return target
|
427
|
-
|
428
|
-
|
429
|
-
@set_module_as('brainstate.nn')
def reset_all_states(
    target: T,
    *reset_args,
    node_to_exclude: Filter = None,
    **reset_kwargs,
) -> T:
    """
    Reset states for all module nodes within the target.

    Convenience wrapper around ``call_all_fns`` that invokes ``reset_state``
    on every module node, respecting any ``@call_order()`` decorators.
    Typically used to reset recurrent network states between sequences.

    Parameters
    ----------
    target : Module
        The target module whose states are to be reset.
    *reset_args
        Positional arguments forwarded to each ``reset_state`` method.
    node_to_exclude : Filter, optional
        A filter to exclude certain nodes from reset.
        Can be a type, predicate function, or any filter supported by the graph API.
    **reset_kwargs
        Keyword arguments forwarded to each ``reset_state`` method.

    Returns
    -------
    Module
        The ``target`` module, for chaining.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> rnn = brainstate.nn.RNNCell(10, 20)
        >>> brainstate.nn.init_all_states(rnn, batch_size=32)
        >>> for x in sequence:
        ...     output = rnn(x)
        >>> # Reset states before processing next sequence
        >>> brainstate.nn.reset_all_states(rnn)

    See Also
    --------
    call_all_fns : The underlying function that executes the calls.
    vmap_reset_all_states : Vectorized version for batched reset.
    """
    call_all_fns(
        target,
        fn_name='reset_state',
        args=reset_args,
        kwargs=reset_kwargs,
        node_to_exclude=node_to_exclude,
    )
    return target
|
488
|
-
|
489
|
-
|
490
|
-
@set_module_as('brainstate.nn')
def vmap_reset_all_states(
    target: T,
    *reset_args,
    axis_size: int | None = None,
    node_to_exclude: Filter = None,
    state_tag: str | None = None,
    **reset_kwargs,
) -> T:
    """
    Reset states with vectorized mapping across batched module instances.

    Applies vmap to the reset process so each of the ``axis_size`` batch
    elements is reset independently with its own random key. Useful when
    working with batched recurrent models or ensembles.

    Parameters
    ----------
    target : Module
        The target module whose states are to be reset.
    *reset_args
        Positional arguments forwarded to each ``reset_state`` method.
    axis_size : int
        The size of the batch dimension. Must be a positive integer.
    node_to_exclude : Filter, optional
        A filter to exclude certain nodes from reset.
    state_tag : str, optional
        An optional tag to categorize newly created states during the reset.
    **reset_kwargs
        Keyword arguments forwarded to each ``reset_state`` method.

    Returns
    -------
    Module
        The ``target`` module, for chaining.

    Raises
    ------
    ValueError
        If ``axis_size`` is None or not a positive integer.
    TypeError
        If ``reset_kwargs`` is not a mapping.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> rnn = brainstate.nn.RNNCell(10, 20)
        >>> # Initialize with 16 batched instances
        >>> brainstate.nn.vmap_init_all_states(rnn, batch_size=32, axis_size=16)
        >>> # Reset all 16 batched instances
        >>> brainstate.nn.vmap_reset_all_states(rnn, axis_size=16)

    See Also
    --------
    reset_all_states : Non-vectorized version.
    vmap_call_all_fns : The underlying vmap function call mechanism.
    """
    vmap_call_all_fns(
        target,
        fn_name='reset_state',
        args=reset_args,
        kwargs=reset_kwargs,
        axis_size=axis_size,
        node_to_exclude=node_to_exclude,
        state_tag=state_tag,
    )
    return target
|
561
|
-
|
562
|
-
|
563
|
-
@set_module_as('brainstate.nn')
def assign_state_values(
    target: Module,
    *state_by_abs_path: Mapping[str, Any]
) -> tuple[list[str], list[str]]:
    """
    Assign state values to a module from one or more state dictionaries.

    State dictionaries use absolute paths as keys (e.g. 'layer1.weight',
    'layer2.bias'). Matching states are updated in place; mismatched keys
    are reported back rather than raising.

    Parameters
    ----------
    target : Module
        The target module whose states will be updated.
    *state_by_abs_path : Mapping[str, Any]
        One or more state dictionaries with absolute path keys mapping to
        state values. When dictionaries share a key, later ones win.

    Returns
    -------
    tuple[list[str], list[str]]
        A tuple of (unexpected_keys, missing_keys):

        - unexpected_keys: Keys present in the state dictionaries but not in the module
        - missing_keys: Keys present in the module but not in the state dictionaries

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> net = brainstate.nn.Linear(10, 20)
        >>> brainstate.nn.init_all_states(net)
        >>> state_dict = {path: state.value for path, state in net.states().items()}
        >>> unexpected, missing = brainstate.nn.assign_state_values(net, state_dict)
        >>> print(f"Unexpected keys: {unexpected}")
        >>> print(f"Missing keys: {missing}")

    Notes
    -----
    - All values are converted to JAX arrays via ``jax.numpy.asarray``.
    - Only states with matching keys are updated; mismatches are returned,
      not raised.
    """
    # Merge the provided dictionaries; later ones override earlier ones.
    merged: dict = {}
    for mapping in state_by_abs_path:
        merged.update(mapping)

    module_states = target.states()
    provided = set(merged)
    existing = set(module_states)

    # Write values into every state present on both sides.
    for key in existing & provided:
        module_states[key].value = jax.numpy.asarray(merged[key])

    # Report the asymmetric difference in both directions, sorted for stability.
    return sorted(provided - existing), sorted(existing - provided)
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
import warnings
|
16
|
+
from collections.abc import Sequence, Mapping
|
17
|
+
from typing import Callable, TypeVar, Any
|
18
|
+
|
19
|
+
import jax
|
20
|
+
|
21
|
+
from brainstate._state import catch_new_states
|
22
|
+
from brainstate._utils import set_module_as
|
23
|
+
from brainstate.graph import nodes
|
24
|
+
from brainstate.transform import vmap, vmap_new_states
|
25
|
+
from brainstate.typing import Filter
|
26
|
+
from ._module import Module
|
27
|
+
|
28
|
+
# the maximum order
|
29
|
+
MAX_ORDER = 10
|
30
|
+
|
31
|
+
T = TypeVar('T', bound=Module)
|
32
|
+
|
33
|
+
__all__ = [
|
34
|
+
'call_order',
|
35
|
+
'call_all_fns',
|
36
|
+
'vmap_call_all_fns',
|
37
|
+
'init_all_states',
|
38
|
+
'vmap_init_all_states',
|
39
|
+
'reset_all_states',
|
40
|
+
'vmap_reset_all_states',
|
41
|
+
'assign_state_values',
|
42
|
+
]
|
43
|
+
|
44
|
+
|
45
|
+
@set_module_as('brainstate.nn')
def call_order(
    level: int = 0,
    check_order_boundary: bool = True
) -> Callable[[Callable], Callable]:
    """
    Decorator that tags a function with an execution-order level.

    Collective operations (``call_all_fns``, ``init_all_states``,
    ``reset_all_states``) consult the attached ``call_order`` attribute to
    schedule calls: lower levels run first.

    Parameters
    ----------
    level : int, optional
        Execution order; lower values execute earlier. Must be within
        [0, MAX_ORDER) when ``check_order_boundary`` is True. Default is 0.
    check_order_boundary : bool, optional
        Validate ``level`` against the [0, MAX_ORDER) range. Default is True.

    Returns
    -------
    Callable[[Callable], Callable]
        A decorator that stores ``level`` on the decorated function.

    Raises
    ------
    ValueError
        If boundary checking is enabled and ``level`` is out of range.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> class MyModule(brainstate.nn.Module):
        ...     @brainstate.nn.call_order(0)
        ...     def reset_state(self):
        ...         print("Reset first")
        ...
        ...     @brainstate.nn.call_order(1)
        ...     def another_reset(self):
        ...         print("Reset second")
    """
    out_of_range = level < 0 or level >= MAX_ORDER
    if check_order_boundary and out_of_range:
        raise ValueError(f'"level" must be an integer in [0, {MAX_ORDER}), but got {level}.')

    def attach_order(fun: Callable) -> Callable:
        # Stamp the level so sorters in the collective ops can read it back.
        fun.call_order = level
        return fun

    return attach_order
|
98
|
+
|
99
|
+
return wrap
|
100
|
+
|
101
|
+
|
102
|
+
@set_module_as('brainstate.nn')
def call_all_fns(
    target: T,
    fn_name: str,
    args: Sequence[Any] | Any = (),
    kwargs: Mapping[str, Any] | None = None,
    node_to_exclude: Filter = None,
    fn_if_not_exist: str = 'raise',
) -> T:
    """
    Call a named method on all module nodes within a target, respecting call order.

    Traverses every ``Module`` node reachable from ``target`` and invokes the
    method named ``fn_name`` on each. Methods decorated with ``@call_order()``
    run after undecorated ones, in ascending order of their level values.

    Parameters
    ----------
    target : Module
        The target module on which to call functions.
    fn_name : str
        The name of the method to call on each module node.
    args : Sequence[Any] or Any, optional
        Positional arguments passed to each call. A single non-tuple argument
        is automatically wrapped in a tuple. Default is ().
    kwargs : Mapping[str, Any], optional
        Keyword arguments passed to each call. Default is None.
    node_to_exclude : Filter, optional
        A filter (type, predicate, or any filter supported by the graph API)
        selecting nodes to skip.
    fn_if_not_exist : str, optional
        Behavior when a node lacks the method:

        - 'raise': raise an AttributeError (default)
        - 'warn': issue a warning and skip the node
        - 'pass' or 'none': skip the node silently

    Returns
    -------
    Module
        The ``target`` module, for chaining.

    Raises
    ------
    TypeError
        If ``fn_name`` is not a string, ``kwargs`` is not a mapping, or the
        resolved attribute is not callable.
    ValueError
        If ``fn_if_not_exist`` is not one of the allowed values.
    AttributeError
        If a node lacks the method and ``fn_if_not_exist`` is 'raise'.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> net = brainstate.nn.Sequential(brainstate.nn.Linear(10, 20), brainstate.nn.ReLU())
        >>> brainstate.nn.call_all_fns(net, 'init_state')
    """
    if not isinstance(fn_name, str):
        raise TypeError(f'fn_name must be a string, but got {type(fn_name).__name__}.')

    # Validate eagerly. Previously an invalid value was only detected when a
    # node happened to lack the method, so typos could pass silently; the old
    # error message also omitted the supported 'warn' option.
    if fn_if_not_exist not in ('raise', 'warn', 'pass', 'none'):
        raise ValueError(
            f"fn_if_not_exist must be one of ['raise', 'warn', 'pass', 'none'], "
            f"but got '{fn_if_not_exist}'."
        )

    # Wrap a single non-tuple argument so it can be splatted uniformly below.
    args = (args,) if not isinstance(args, tuple) else args
    kwargs = kwargs or {}
    if not isinstance(kwargs, Mapping):
        raise TypeError(f'kwargs must be a mapping, but got {type(kwargs).__name__}.')

    all_nodes = nodes(target).filter(Module)
    if node_to_exclude is not None:
        all_nodes -= all_nodes.filter(node_to_exclude)

    # Methods carrying a `call_order` attribute are deferred and executed
    # afterwards in sorted order; all others run immediately.
    nodes_with_order = []
    for path, node in all_nodes.items():
        try:
            fun = getattr(node, fn_name)
        except AttributeError as e:
            if fn_if_not_exist == 'raise':
                raise AttributeError(
                    f"Module {type(node).__name__} with the path {path} does not have method '{fn_name}'"
                ) from e
            elif fn_if_not_exist == 'warn':
                warnings.warn(
                    f"Module {type(node).__name__} with the path {path} does not have method '{fn_name}'. "
                    f"Skipping.",
                    UserWarning
                )
            # 'warn', 'pass', and 'none' all skip this node.
            continue

        if not callable(fun):
            raise TypeError(f"'{fn_name}' must be callable, but got {type(fun).__name__}.")

        if hasattr(fun, 'call_order'):
            nodes_with_order.append(node)
        else:
            fun(*args, **kwargs)

    # Execute the deferred nodes in ascending call_order.
    for node in sorted(nodes_with_order, key=lambda x: getattr(x, fn_name).call_order):
        getattr(node, fn_name)(*args, **kwargs)
    return target
|
205
|
+
|
206
|
+
|
207
|
+
@set_module_as('brainstate.nn')
def vmap_call_all_fns(
    target: T,
    fn_name: str,
    args: Sequence[Any] | Any = (),
    kwargs: Mapping[str, Any] | None = None,
    axis_size: int = None,
    node_to_exclude: Filter = None,
    state_tag: str | None = None,
    fn_if_not_exist: str = 'raise',
) -> T:
    """
    Apply vectorized mapping to call a method on all module nodes with batched states.

    Runs :func:`call_all_fns` inside a ``vmap`` of size ``axis_size`` so that
    every newly created state receives a batched (leading-axis) value, with an
    independent random key per batch element. Useful for ensembles or batched
    models.

    Parameters
    ----------
    target : Module
        The target module on which to call functions.
    fn_name : str
        The name of the method to call on each module node.
    args : Sequence[Any] or Any, optional
        Positional arguments passed to each call. A single non-tuple argument
        is automatically wrapped in a tuple. Default is ().
    kwargs : Mapping[str, Any], optional
        Keyword arguments passed to each call. Default is None.
    axis_size : int
        The size of the batch dimension. Must be a positive integer.
    node_to_exclude : Filter, optional
        A filter selecting nodes to skip.
    state_tag : str, optional
        An optional tag attached to states created during the vmap operation.
    fn_if_not_exist : str, optional
        Behavior when a node lacks the method:

        - 'raise': raise an AttributeError (default)
        - 'warn': issue a warning and skip the node
        - 'pass' or 'none': skip the node silently

    Returns
    -------
    Module
        The ``target`` module, for chaining.

    Raises
    ------
    ValueError
        If ``axis_size`` is None or not a positive integer.
    TypeError
        If ``kwargs`` is not a mapping.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> net = brainstate.nn.Linear(10, 20)
        >>> # Create 5 batched instances with different initializations
        >>> brainstate.nn.vmap_call_all_fns(net, 'init_state', axis_size=5)
    """
    if axis_size is None or axis_size <= 0:
        raise ValueError(f"axis_size must be a positive integer, got {axis_size}")

    if not isinstance(args, tuple):
        args = (args,)
    kwargs = kwargs or {}
    if not isinstance(kwargs, Mapping):
        raise TypeError(f'kwargs must be a mapping, but got {type(kwargs).__name__}.')

    @vmap(axis_size=axis_size)
    def vmapped_fn():
        # Track states created inside the vmapped call so their per-batch
        # values can be returned out of the transform.
        with catch_new_states(state_tag) as inner_catcher:
            call_all_fns(
                target,
                fn_name=fn_name,
                args=args,
                kwargs=kwargs,
                node_to_exclude=node_to_exclude,
                fn_if_not_exist=fn_if_not_exist
            )
        return inner_catcher.get_state_values()

    with catch_new_states(state_tag) as outer_catcher:
        values = vmapped_fn()
        states = outer_catcher.get_states()
    # Write the batched values back onto the newly created states.
    for state, value in zip(states, values):
        state.value = value
    return target
|
294
|
+
|
295
|
+
|
296
|
+
@set_module_as('brainstate.nn')
def init_all_states(
    target: T,
    *init_args,
    node_to_exclude: Filter = None,
    **init_kwargs,
) -> T:
    """
    Initialize states for all module nodes within the target.

    Convenience wrapper around :func:`call_all_fns` that calls the
    ``init_state`` method on every module node. Execution order respects any
    ``@call_order()`` decorators on the ``init_state`` methods.

    Parameters
    ----------
    target : Module
        The target module whose states are to be initialized.
    *init_args
        Positional arguments passed to each ``init_state`` method.
    node_to_exclude : Filter, optional
        A filter (type, predicate, or any filter supported by the graph API)
        selecting nodes to skip.
    **init_kwargs
        Keyword arguments passed to each ``init_state`` method.

    Returns
    -------
    Module
        The ``target`` module, for chaining.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> net = brainstate.nn.Sequential(
        ...     brainstate.nn.Linear(10, 20),
        ...     brainstate.nn.Dropout(0.5)
        ... )
        >>> # Initialize all states
        >>> brainstate.nn.init_all_states(net)
        >>>
        >>> # Initialize with custom arguments
        >>> brainstate.nn.init_all_states(net, batch_size=32)

    See Also
    --------
    call_all_fns : The underlying function that executes the calls.
    vmap_init_all_states : Vectorized version for batched initialization.
    """
    # Keyword arguments mirror reset_all_states for consistency; positional
    # passing previously relied on exact parameter order in call_all_fns.
    call_all_fns(
        target,
        fn_name='init_state',
        args=init_args,
        kwargs=init_kwargs,
        node_to_exclude=node_to_exclude,
    )
    return target
|
345
|
+
|
346
|
+
|
347
|
+
@set_module_as('brainstate.nn')
def vmap_init_all_states(
    target: T,
    *init_args,
    axis_size: int = None,
    node_to_exclude: Filter = None,
    state_to_exclude: Filter = None,
    state_tag: str | None = None,
    **init_kwargs
) -> T:
    """
    Initialize states with vectorized mapping for creating batched module instances.

    Applies ``vmap_new_states`` to the initialization process so each batch
    element gets independent state values and random keys. Useful for ensemble
    models or parameter sweeps.

    Parameters
    ----------
    target : Module
        The target module whose states are to be initialized.
    *init_args
        Positional arguments passed to each ``init_state`` method.
    axis_size : int
        The size of the batch dimension. Must be a positive integer.
    node_to_exclude : Filter, optional
        A filter selecting nodes to skip during initialization.
    state_to_exclude : Filter, optional
        A filter selecting states to leave un-vmapped; excluded states remain
        shared across all batched instances.
    state_tag : str, optional
        An optional tag attached to newly created states.
    **init_kwargs
        Keyword arguments passed to each ``init_state`` method.

    Returns
    -------
    Module
        The ``target`` module, for chaining.

    Raises
    ------
    ValueError
        If ``axis_size`` is None or not a positive integer.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> net = brainstate.nn.Linear(10, 20)
        >>> # Create 8 batched instances with different random initializations
        >>> brainstate.nn.vmap_init_all_states(net, axis_size=8)
        >>>
        >>> # The weight parameter now has shape (8, 20, 10) instead of (20, 10)
        >>> print(net.weight.shape)

    See Also
    --------
    init_all_states : Non-vectorized version.
    vmap_new_states : The underlying vmap transformation for states.
    """
    # Validate here so the documented ValueError is actually raised, matching
    # vmap_call_all_fns; previously a bad axis_size was passed straight through.
    if axis_size is None or axis_size <= 0:
        raise ValueError(f"axis_size must be a positive integer, got {axis_size}")

    def init_fn():
        # Runs inside the vmapped context: every state created in here is
        # caught by vmap_new_states and given a batched value.
        init_all_states(
            target,
            *init_args,
            **init_kwargs,
            node_to_exclude=node_to_exclude,
        )

    vmap_new_states(init_fn, state_tag=state_tag, axis_size=axis_size, state_to_exclude=state_to_exclude)()
    return target
|
427
|
+
|
428
|
+
|
429
|
+
@set_module_as('brainstate.nn')
def reset_all_states(
    target: T,
    *reset_args,
    node_to_exclude: Filter = None,
    **reset_kwargs,
) -> T:
    """
    Reset states for all module nodes within the target.

    Convenience wrapper around :func:`call_all_fns` that calls the
    ``reset_state`` method on every module node. Execution order respects any
    ``@call_order()`` decorators on the ``reset_state`` methods. Typically used
    to reset recurrent network states between sequences.

    Parameters
    ----------
    target : Module
        The target module whose states are to be reset.
    *reset_args
        Positional arguments passed to each ``reset_state`` method.
    node_to_exclude : Filter, optional
        A filter (type, predicate, or any filter supported by the graph API)
        selecting nodes to skip.
    **reset_kwargs
        Keyword arguments passed to each ``reset_state`` method.

    Returns
    -------
    Module
        The ``target`` module, for chaining.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> rnn = brainstate.nn.RNNCell(10, 20)
        >>> brainstate.nn.init_all_states(rnn, batch_size=32)
        >>>
        >>> # Process a sequence
        >>> for x in sequence:
        ...     output = rnn(x)
        >>>
        >>> # Reset states before processing next sequence
        >>> brainstate.nn.reset_all_states(rnn)

    See Also
    --------
    call_all_fns : The underlying function that executes the calls.
    vmap_reset_all_states : Vectorized version for batched reset.
    """
    call_all_fns(
        target,
        fn_name='reset_state',
        args=reset_args,
        kwargs=reset_kwargs,
        node_to_exclude=node_to_exclude
    )
    return target
|
488
|
+
|
489
|
+
|
490
|
+
@set_module_as('brainstate.nn')
def vmap_reset_all_states(
    target: T,
    *reset_args,
    axis_size: int = None,
    node_to_exclude: Filter = None,
    state_tag: str | None = None,
    **reset_kwargs,
) -> T:
    """
    Reset states with vectorized mapping across batched module instances.

    Applies vmap to the reset process so each batch element has its state
    reset independently with its own random key. Useful for batched recurrent
    models or ensembles.

    Parameters
    ----------
    target : Module
        The target module whose states are to be reset.
    *reset_args
        Positional arguments passed to each ``reset_state`` method.
    axis_size : int
        The size of the batch dimension. Must be a positive integer.
    node_to_exclude : Filter, optional
        A filter selecting nodes to skip.
    state_tag : str, optional
        An optional tag attached to states created during the reset.
    **reset_kwargs
        Keyword arguments passed to each ``reset_state`` method.

    Returns
    -------
    Module
        The ``target`` module, for chaining.

    Raises
    ------
    ValueError
        If ``axis_size`` is None or not a positive integer.
    TypeError
        If ``reset_kwargs`` is not a mapping.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> rnn = brainstate.nn.RNNCell(10, 20)
        >>> # Initialize with 16 batched instances
        >>> brainstate.nn.vmap_init_all_states(rnn, batch_size=32, axis_size=16)
        >>>
        >>> # Process sequences...
        >>>
        >>> # Reset all 16 batched instances
        >>> brainstate.nn.vmap_reset_all_states(rnn, axis_size=16)

    See Also
    --------
    reset_all_states : Non-vectorized version.
    vmap_call_all_fns : The underlying vmap function call mechanism.
    """
    vmap_call_all_fns(
        target,
        fn_name='reset_state',
        args=reset_args,
        kwargs=reset_kwargs,
        axis_size=axis_size,
        node_to_exclude=node_to_exclude,
        state_tag=state_tag,
    )
    return target
|
561
|
+
|
562
|
+
|
563
|
+
@set_module_as('brainstate.nn')
def assign_state_values(
    target: Module,
    *state_by_abs_path: Mapping[str, Any]
) -> tuple[list[str], list[str]]:
    """
    Assign state values to a module from one or more state dictionaries.

    Each dictionary maps absolute state paths (e.g. ``'layer1.weight'``) to
    values. Matching states on ``target`` are updated in place; keys that do
    not line up on either side are reported back rather than raising.

    Parameters
    ----------
    target : Module
        The target module whose states will be updated.
    *state_by_abs_path : Mapping[str, Any]
        One or more state dictionaries keyed by absolute path. When several
        dictionaries share a key, the last one provided wins.

    Returns
    -------
    tuple[list[str], list[str]]
        ``(unexpected_keys, missing_keys)`` — keys present only in the
        dictionaries, and keys present only on the module, respectively;
        both sorted.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> net = brainstate.nn.Linear(10, 20)
        >>> brainstate.nn.init_all_states(net)
        >>>
        >>> # Save state values
        >>> state_dict = {path: state.value for path, state in net.states().items()}
        >>>
        >>> # Later, restore state values
        >>> unexpected, missing = brainstate.nn.assign_state_values(net, state_dict)
        >>> print(f"Unexpected keys: {unexpected}")
        >>> print(f"Missing keys: {missing}")

    Notes
    -----
    - All assigned values are converted with ``jax.numpy.asarray``.
    - Mismatched keys never raise; they are only returned for inspection.
    """
    # Merge all provided dictionaries; later ones override earlier ones.
    merged: dict = {}
    for mapping in state_by_abs_path:
        merged.update(mapping)

    module_states = target.states()
    provided = set(merged)
    present = set(module_states)

    # Assign only keys found on both sides, converting to JAX arrays.
    for key in provided & present:
        module_states[key].value = jax.numpy.asarray(merged[key])

    return sorted(provided - present), sorted(present - provided)
|