brainstate-0.1.10-py2.py3-none-any.whl → brainstate-0.2.0-py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +15 -28
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
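
Most of the churn in this listing is a package-level reorganisation rather than isolated edits: the former `brainstate.compile` and `brainstate.augment` packages are folded into a new `brainstate.transform` package, the `init` helpers move under `brainstate.nn`, and several `nn` modules are renamed (`_fixedprob` → `_event_fixedprob`, `_linear_mv` → `_event_linear`, `_rate_rnns` → `_rnns`). A hedged migration sketch follows; the exact public symbols are assumptions inferred from the renamed file paths (only `vmap` living in `brainstate.transform` is confirmed by the diff below), so verify against the released 0.2.0 API:

    # brainstate 0.1.10 layout (hypothetical import paths inferred from the old file names):
    #   from brainstate.compile import jit         # brainstate/compile/_jit.py
    #   from brainstate.augment import grad, vmap   # brainstate/augment/_autograd.py, _mapping.py
    # brainstate 0.2.0 layout (same caveat; `vmap` is imported this way in the diff below):
    from brainstate.transform import jit, grad, vmap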
brainstate/nn/_collective_ops.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2024
+# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,71 +12,87 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
-from collections import
-from typing import Callable, TypeVar,
+import warnings
+from collections.abc import Sequence, Mapping
+from typing import Callable, TypeVar, Any

 import jax

 from brainstate._state import catch_new_states
 from brainstate._utils import set_module_as
-from brainstate.augment import vmap, vmap_new_states
 from brainstate.graph import nodes
-from brainstate.
+from brainstate.transform import vmap, vmap_new_states
 from brainstate.typing import Filter
 from ._module import Module

 # the maximum order
 MAX_ORDER = 10

-# State Load Results
-StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys'])
-
 T = TypeVar('T', bound=Module)

 __all__ = [
-    'MAX_ORDER',
     'call_order',
-    '
-    '
+    'call_all_fns',
+    'vmap_call_all_fns',
     'init_all_states',
     'vmap_init_all_states',
     'reset_all_states',
-    '
-    'save_all_states',
+    'vmap_reset_all_states',
     'assign_state_values',
 ]


 @set_module_as('brainstate.nn')
-def call_order(
-
-
-
-
-
+def call_order(
+    level: int = 0,
+    check_order_boundary: bool = True
+) -> Callable[[Callable], Callable]:
+    """
+    Decorator for specifying the execution order of functions in collective operations.

-
-
-
-    >>> brainstate.nn.call_order(-2)
+    This decorator attaches a `call_order` attribute to a function, which is used by
+    collective operations like `call_all_functions`, `init_all_states`, and `reset_all_states`
+    to determine the execution order. Functions with lower order levels are executed first.

     Parameters
     ----------
-    level: int
-
-
-
-
+    level : int, optional
+        The execution order level. Lower values indicate earlier execution.
+        Must be in the range [0, MAX_ORDER) when `check_order_boundary` is True.
+        Default is 0.
+    check_order_boundary : bool, optional
+        Whether to validate that the order level is within the valid range [0, MAX_ORDER).
+        Default is True.

     Returns
     -------
-
+    Callable[[Callable], Callable]
+        A decorator function that adds the `call_order` attribute to the decorated function.
+
+    Raises
+    ------
+    ValueError
+        If `check_order_boundary` is True and `level` is not in [0, MAX_ORDER).
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate
+        >>>
+        >>> class MyModule(brainstate.nn.Module):
+        ...     @brainstate.nn.call_order(0)
+        ...     def reset_state(self):
+        ...         print("Reset first")
+        ...
+        ...     @brainstate.nn.call_order(1)
+        ...     def another_reset(self):
+        ...         print("Reset second")
     """
     if check_order_boundary and (level < 0 or level >= MAX_ORDER):
-        raise ValueError(f'"
+        raise ValueError(f'"level" must be an integer in [0, {MAX_ORDER}), but got {level}.')

-    def wrap(fun: Callable):
+    def wrap(fun: Callable) -> Callable:
         fun.call_order = level
         return fun

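
The contract described in the `call_order` docstring above is small: the decorator only attaches a `call_order` attribute, and the collective operations later sort decorated methods by that attribute before invoking them. A minimal, self-contained sketch of that mechanism in plain Python (independent of brainstate):

    def call_order(level=0):
        # Attach the execution-order level to the decorated function.
        def wrap(fun):
            fun.call_order = level
            return fun
        return wrap

    @call_order(1)
    def second():
        print("runs second")

    @call_order(0)
    def first():
        print("runs first")

    # Collective operations execute decorated functions in ascending order of level.
    for fn in sorted([second, first], key=lambda f: f.call_order):
        fn()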
@@ -84,164 +100,196 @@ def call_order(level: int = 0, check_order_boundary: bool = True):


 @set_module_as('brainstate.nn')
-def
+def call_all_fns(
     target: T,
-
-    args:
-    kwargs:
+    fn_name: str,
+    args: Sequence[Any] | Any = (),
+    kwargs: Mapping[str, Any] | None = None,
     node_to_exclude: Filter = None,
-
+    fn_if_not_exist: str = 'raise',
 ) -> T:
     """
-    Call a specified function on all nodes
+    Call a specified function on all module nodes within a target, respecting call order.

-    This function
-    on each node.
-
+    This function traverses all module nodes in the target and invokes the specified method
+    on each node. Functions decorated with `@call_order()` are executed in ascending order
+    of their level values, while functions without the decorator are executed first.

     Parameters
-
-    target :
+    ----------
+    target : Module
         The target module on which to call functions.
-
-        The name of the
-    args : Tuple[Any, ...] | Any, optional
-        Positional arguments to pass to the called function. Default is an empty tuple.
-    kwargs : Dict[str, Any] | None, optional
-        Keyword arguments to pass to the called function. Default is None.
+    fn_name : str
+        The name of the method to call on each module node.
     node_to_exclude : Filter, optional
-        A filter
-
-
-
-
-        - '
-
-
-
-
-
+        A filter to exclude certain nodes from the function call.
+        Can be a type, predicate function, or any filter supported by the graph API.
+    fn_if_not_exist : str, optional
+        Behavior when the specified method doesn't exist on a node:
+
+        - 'raise': Raise an AttributeError (default)
+        - 'pass' or 'none': Skip the node silently
+        - 'warn': Issue a warning and skip the node
+    args
+        Positional arguments to pass to the called method. A single non-tuple
+        argument will be automatically wrapped in a tuple. Default is ().
+    kwargs
+        Keyword arguments to pass to the called method. Default is None.

     Raises
-
-
-        If fun_name is not a string or kwargs is not a
+    ------
+    TypeError
+        If `fun_name` is not a string or `kwargs` is not a mapping.
     ValueError
-        If
+        If `fn_if_not_exist` is not one of the allowed values.
     AttributeError
-        If the specified
+        If the specified method doesn't exist on a node and `fn_if_not_exist` is 'raise'.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate
+        >>>
+        >>> net = brainstate.nn.Sequential(brainstate.nn.Linear(10, 20), brainstate.nn.ReLU())
+        >>> brainstate.nn.call_all_fns(net, 'init_state')
     """
-
+    if not isinstance(fn_name, str):
+        raise TypeError(f'fn_name must be a string, but got {type(fn_name).__name__}.')

     args = (args,) if not isinstance(args, tuple) else args
     kwargs = kwargs or {}
-
+    if not isinstance(kwargs, Mapping):
+        raise TypeError(f'kwargs must be a mapping, but got {type(kwargs).__name__}.')

     all_nodes = nodes(target).filter(Module)
     if node_to_exclude is not None:
         all_nodes -= all_nodes.filter(node_to_exclude)

+    # Separate nodes with and without call_order
     nodes_with_order = []
-    for node in all_nodes.
+    for path, node in all_nodes.items():
         try:
-            fun = getattr(node,
+            fun = getattr(node, fn_name)
         except AttributeError as e:
-            if
-                raise
-
+            if fn_if_not_exist == 'raise':
+                raise AttributeError(
+                    f"Module {type(node).__name__} with the path {path} does not have method '{fn_name}'"
+                ) from e
+            elif fn_if_not_exist in ('pass', 'none'):
+                continue
+            elif fn_if_not_exist == 'warn':
+                warnings.warn(
+                    f"Module {type(node).__name__} with the path {path} does not have method '{fn_name}'. "
+                    f"Skipping.",
+                    UserWarning
+                )
                 continue
             else:
                 raise ValueError(
-                    f
+                    f"fn_if_not_exist must be one of ['raise', 'pass', 'none'], but got '{fn_if_not_exist}'."
+                )
+
+        if not callable(fun):
+            raise TypeError(f"'{fn_name}' must be callable, but got {type(fun).__name__}.")

-        assert callable(fun), f'{fun_name} must be a callable function, but got {fun}.'
         if hasattr(fun, 'call_order'):
             nodes_with_order.append(node)
         else:
             fun(*args, **kwargs)

-
-
-
+    # Execute nodes with call_order in sorted order
+    for node in sorted(nodes_with_order, key=lambda x: getattr(x, fn_name).call_order):
+        getattr(node, fn_name)(*args, **kwargs)
     return target


-def
+def vmap_call_all_fns(
     target: T,
-
-    args:
-    kwargs:
+    fn_name: str,
+    args: Sequence[Any] | Any = (),
+    kwargs: Mapping[str, Any] | None = None,
     axis_size: int = None,
     node_to_exclude: Filter = None,
-
-
+    state_tag: str | None = None,
+    fn_if_not_exist: str = 'raise',
 ) -> T:
     """
-    Apply vectorized mapping
+    Apply vectorized mapping to call a function on all module nodes with batched state handling.

-    This function
-
+    This function creates multiple batched instances by applying vmap to the specified method
+    call across all module nodes. Each batch element maintains its own random key and state
+    values. This is particularly useful for creating ensembles or batched models.

     Parameters
-
-    target :
+    ----------
+    target : Module
         The target module on which to call functions.
-
-        The name of the
-    args :
-        Positional arguments to pass to the called
-
-
-
-
+    fn_name : str
+        The name of the method to call on each module node.
+    args : Sequence[Any] or Any, optional
+        Positional arguments to pass to the called method. A single non-tuple
+        argument will be automatically wrapped in a tuple. Default is ().
+    kwargs : Mapping[str, Any], optional
+        Keyword arguments to pass to the called method. Default is None.
+    axis_size : int
+        The size of the batch dimension for vmap. Must be a positive integer.
     node_to_exclude : Filter, optional
-        A filter
-
-
-
-
+        A filter to exclude certain nodes from the function call.
+    state_tag : str, optional
+        An optional tag to categorize newly created states during the vmap operation.
+    fn_if_not_exist : str, optional
+        Behavior when the specified method doesn't exist on a node:

-        - 'raise': Raise an
-        - 'pass' or 'none': Skip the node
+        - 'raise': Raise an AttributeError (default)
+        - 'pass' or 'none': Skip the node silently
+        - 'warn': Issue a warning and skip the node

-
+    Raises
+    ------
+    ValueError
+        If `axis_size` is None or not a positive integer.
+    TypeError
+        If `kwargs` is not a mapping.
+
+    Examples
     --------
-
-        The target module after applying the vectorized function call on all applicable nodes.
+    .. code-block:: python

-
-
-
-
+        >>> import brainstate
+        >>>
+        >>> net = brainstate.nn.Linear(10, 20)
+        >>> # Create 5 batched instances with different initializations
+        >>> brainstate.nn.vmap_call_all_fns(net, 'init_state', axis_size=5)
     """
-
+
+    if axis_size is None or axis_size <= 0:
+        raise ValueError(f"axis_size must be a positive integer, got {axis_size}")

     if not isinstance(args, tuple):
         args = (args,)
     kwargs = kwargs or {}
-
+    if not isinstance(kwargs, Mapping):
+        raise TypeError(f'kwargs must be a mapping, but got {type(kwargs).__name__}.')

-    @vmap(
-    def vmapped_fn(
-
-
-        call_all_functions(
+    @vmap(axis_size=axis_size)
+    def vmapped_fn():
+        with catch_new_states(state_tag) as inner_catcher:
+            call_all_fns(
                 target,
-
+                fn_name=fn_name,
                 args=args,
                 kwargs=kwargs,
                 node_to_exclude=node_to_exclude,
-
+                fn_if_not_exist=fn_if_not_exist
             )
-
-        return values
+        return inner_catcher.get_state_values()

-    with catch_new_states(
-        values = vmapped_fn(
+    with catch_new_states(state_tag) as outer_catcher:
+        values = vmapped_fn()
     states = outer_catcher.get_states()
     for state, value in zip(states, values):
         state.value = value
-
     return target

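
`call_all_fns` and `vmap_call_all_fns` above generalise the older `call_all_functions` helpers: any method name can be broadcast over the module tree, with `fn_if_not_exist` controlling what happens on modules that lack it. A short usage sketch assembled from the docstrings above (the layer constructors and keyword names are taken from those examples, so treat them as illustrative rather than authoritative):

    import brainstate

    net = brainstate.nn.Sequential(brainstate.nn.Linear(10, 20), brainstate.nn.ReLU())
    # Call 'init_state' on every sub-module; modules without the method only warn.
    brainstate.nn.call_all_fns(net, 'init_state', fn_if_not_exist='warn')
    # Batched variant: run the same call under vmap to build 5 independent instances.
    brainstate.nn.vmap_call_all_fns(net, 'init_state', axis_size=5)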
@@ -253,88 +301,116 @@ def init_all_states(
     **init_kwargs,
 ) -> T:
     """
-    Initialize
+    Initialize states for all module nodes within the target.

-    This
-
+    This is a convenience wrapper around `call_all_functions` that specifically calls
+    the `init_state` method on all module nodes. The execution order respects any
+    `@call_order()` decorators on the `init_state` methods.

     Parameters
     ----------
-    target :
+    target : Module
         The target module whose states are to be initialized.
-    init_args
-
-        If a single non-tuple argument is provided, it will be wrapped in a tuple.
-    init_kwargs : Dict[str, Any] | None, optional
-        Keyword arguments to be passed to each init_state method.
-        If None, an empty dictionary will be used.
+    *init_args
+        Variable positional arguments to pass to each `init_state` method.
     node_to_exclude : Filter, optional
-        A filter
-
-
-
-    T
-        The target module with all states initialized.
+        A filter to exclude certain nodes from initialization.
+        Can be a type, predicate function, or any filter supported by the graph API.
+    **init_kwargs
+        Variable keyword arguments to pass to each `init_state` method.

-
-
-
-
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate
+        >>>
+        >>> net = brainstate.nn.Sequential(
+        ...     brainstate.nn.Linear(10, 20),
+        ...     brainstate.nn.Dropout(0.5)
+        ... )
+        >>> # Initialize all states
+        >>> brainstate.nn.init_all_states(net)
+        >>>
+        >>> # Initialize with custom arguments
+        >>> brainstate.nn.init_all_states(net, batch_size=32)
+
+    See Also
+    --------
+    call_all_functions : The underlying function that executes the calls.
+    vmap_init_all_states : Vectorized version for batched initialization.
     """
-
+    call_all_fns(target, 'init_state', init_args, init_kwargs, node_to_exclude)
+    return target


 @set_module_as('brainstate.nn')
 def vmap_init_all_states(
     target: T,
-    *init_args
+    *init_args,
     axis_size: int = None,
     node_to_exclude: Filter = None,
     state_to_exclude: Filter = None,
     state_tag: str | None = None,
-    **init_kwargs
+    **init_kwargs
 ) -> T:
     """
-    Initialize
+    Initialize states with vectorized mapping for creating batched module instances.

-    This function applies
-    instances of
+    This function applies vmap to the initialization process, creating multiple batched
+    instances of module states. Each batch element will have independent state values
+    and random keys. This is useful for ensemble models or parameter sweeps.

     Parameters
-
-    target :
+    ----------
+    target : Module
         The target module whose states are to be initialized.
-    init_args
-
-
-
-    axis_size : int, optional
-        The size of the batch axis for vmap. This must be specified and should be greater than 0.
+    *init_args
+        Variable positional arguments to pass to each `init_state` method.
+    axis_size : int
+        The size of the batch dimension. Must be a positive integer.
     node_to_exclude : Filter, optional
         A filter to exclude certain nodes from initialization.
-
-    A
-
-
-
-
-
+    state_to_exclude : Filter, optional
+        A filter to exclude certain states from being vmapped.
+        Excluded states will remain shared across all batched instances.
+    state_tag : str, optional
+        An optional tag to categorize newly created states.
+    **init_kwargs
+        Variable keyword arguments to pass to each `init_state` method.

     Raises
-
-
-        If axis_size is
-
+    ------
+    ValueError
+        If `axis_size` is None or not a positive integer.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate
+        >>>
+        >>> net = brainstate.nn.Linear(10, 20)
+        >>> # Create 8 batched instances with different random initializations
+        >>> brainstate.nn.vmap_init_all_states(net, axis_size=8)
+        >>>
+        >>> # The weight parameter now has shape (8, 20, 10) instead of (20, 10)
+        >>> print(net.weight.shape)
+
+    See Also
+    --------
+    init_all_states : Non-vectorized version.
+    vmap_new_states : The underlying vmap transformation for states.
     """

-    #
+    # vmap_call_all_functions(
     #     target,
-    #     'init_state',
+    #     fun_name='init_state',
     #     args=init_args,
     #     kwargs=init_kwargs,
     #     axis_size=axis_size,
     #     node_to_exclude=node_to_exclude,
-    #
+    #     state_tag=state_tag,
     # )

     def init_fn():
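
The two initialisers above differ only in whether the `init_state` calls run once or under `vmap`. A sketch of the difference, based on the docstring examples above (the leading-batch-axis behaviour is the docstring's own claim, not verified here):

    import brainstate

    net = brainstate.nn.Linear(10, 20)
    brainstate.nn.init_all_states(net)            # a single set of states

    ensemble = brainstate.nn.Linear(10, 20)
    brainstate.nn.vmap_init_all_states(ensemble, axis_size=8)
    # Per the docstring above, each vmapped state now carries a leading batch axis of size 8.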
@@ -353,162 +429,205 @@ def vmap_init_all_states(
 @set_module_as('brainstate.nn')
 def reset_all_states(
     target: T,
-    reset_args
-    reset_kwargs: Dict[str, Any] | None = None,
+    *reset_args,
     node_to_exclude: Filter = None,
+    **reset_kwargs,
 ) -> T:
     """
-    Reset
+    Reset states for all module nodes within the target.

-    This
-
+    This is a convenience wrapper around `call_all_functions` that specifically calls
+    the `reset_state` method on all module nodes. The execution order respects any
+    `@call_order()` decorators on the `reset_state` methods. This is typically used
+    to reset recurrent neural network states between sequences.

     Parameters
     ----------
-    target :
+    target : Module
         The target module whose states are to be reset.
-    reset_args
-        Positional arguments to
-
-
-
-
+    reset_args
+        Positional arguments to pass to each `reset_state` method.
+        A single non-tuple argument will be automatically wrapped in a tuple.
+        Default is ().
+    reset_kwargs
+        Keyword arguments to pass to each `reset_state` method.
+        Default is None.
     node_to_exclude : Filter, optional
-        A filter
-
-    Returns
-    -------
-    T
-        The target module with all states reset.
+        A filter to exclude certain nodes from reset.
+        Can be a type, predicate function, or any filter supported by the graph API.

-
-
-
-
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate
+        >>>
+        >>> rnn = brainstate.nn.RNNCell(10, 20)
+        >>> brainstate.nn.init_all_states(rnn, batch_size=32)
+        >>>
+        >>> # Process a sequence
+        >>> for x in sequence:
+        ...     output = rnn(x)
+        >>>
+        >>> # Reset states before processing next sequence
+        >>> brainstate.nn.reset_all_states(rnn)
+
+    See Also
+    --------
+    call_all_functions : The underlying function that executes the calls.
+    vmap_reset_all_states : Vectorized version for batched reset.
     """
-
+    call_all_fns(
         target,
-
+        fn_name='reset_state',
         args=reset_args,
         kwargs=reset_kwargs,
         node_to_exclude=node_to_exclude
     )
+    return target


 def vmap_reset_all_states(
     target: T,
-    reset_args
-    reset_kwargs: Dict[str, Any] | None = None,
+    *reset_args,
     axis_size: int = None,
     node_to_exclude: Filter = None,
-
+    state_tag: str | None = None,
+    **reset_kwargs,
 ) -> T:
     """
-    Reset
+    Reset states with vectorized mapping across batched module instances.

-    This function applies
-    instances of the
+    This function applies vmap to the reset process, resetting states across all
+    batched instances of the module. Each batch element will have its state reset
+    independently with its own random key. This is useful when working with batched
+    recurrent models or ensembles.

     Parameters
-
-    target :
+    ----------
+    target : Module
         The target module whose states are to be reset.
-    reset_args
-        Positional arguments to
-
-
-
-
+    reset_args
+        Positional arguments to pass to each `reset_state` method.
+        A single non-tuple argument will be automatically wrapped in a tuple.
+        Default is ().
+    reset_kwargs
+        Keyword arguments to pass to each `reset_state` method.
+        Default is None.
+    axis_size : int
+        The size of the batch dimension. Must be a positive integer.
     node_to_exclude : Filter, optional
         A filter to exclude certain nodes from reset.
-
-
-
-    Returns
-    --------
-    T
-        The target module with reset states.
+    state_tag : str, optional
+        An optional tag to categorize newly created states during the reset.

     Raises
-
-
-        If axis_size is
-
+    ------
+    ValueError
+        If `axis_size` is None or not a positive integer.
+    TypeError
+        If `reset_kwargs` is not a mapping.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate
+        >>>
+        >>> rnn = brainstate.nn.RNNCell(10, 20)
+        >>> # Initialize with 16 batched instances
+        >>> brainstate.nn.vmap_init_all_states(rnn, batch_size=32, axis_size=16)
+        >>>
+        >>> # Process sequences...
+        >>>
+        >>> # Reset all 16 batched instances
+        >>> brainstate.nn.vmap_reset_all_states(rnn, axis_size=16)
+
+    See Also
+    --------
+    reset_all_states : Non-vectorized version.
+    vmap_call_all_functions : The underlying vmap function call mechanism.
     """
-
+    vmap_call_all_fns(
         target,
-
+        fn_name='reset_state',
         args=reset_args,
         kwargs=reset_kwargs,
         axis_size=axis_size,
         node_to_exclude=node_to_exclude,
-
+        state_tag=state_tag,
     )
+    return target


 @set_module_as('brainstate.nn')
-def
-
-
-
-
-    Args:
-        target: Module. The dynamical system to load its states.
-        state_dict: dict. A dict containing parameters and persistent buffers.
-
-    Returns
-    -------
-    ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
-
-    * **missing_keys** is a list of str containing the missing keys
-    * **unexpected_keys** is a list of str containing the unexpected keys
+def assign_state_values(
+    target: Module,
+    *state_by_abs_path: Mapping[str, Any]
+) -> tuple[list[str], list[str]]:
     """
-
-    unexpected_keys = []
-    for path, node in nodes(target).items():
-        r = node.load_state(state_dict[path], **kwargs)
-        if r is not None:
-            missing, unexpected = r
-            missing_keys.extend([f'{path}.{key}' for key in missing])
-            unexpected_keys.extend([f'{path}.{key}' for key in unexpected])
-    return StateLoadResult(missing_keys, unexpected_keys)
+    Assign state values to a module from one or more state dictionaries.

+    This function updates the state values of a module based on provided state dictionaries.
+    State dictionaries should use absolute paths as keys (e.g., 'layer1.weight', 'layer2.bias').
+    The function handles missing and unexpected keys, returning them for inspection.

-
-
-
-
-
-
-
+    Parameters
+    ----------
+    target : Module
+        The target module whose states will be updated.
+    *state_by_abs_path : Mapping[str, Any]
+        One or more state dictionaries with absolute path keys mapping to state values.
+        If multiple dictionaries are provided, they will be merged (later dictionaries
+        override earlier ones for duplicate keys).

     Returns
-
-
-
+    -------
+    tuple[list[str], list[str]]
+        A tuple of (unexpected_keys, missing_keys):

+        - unexpected_keys: Keys present in the state dictionaries but not in the module
+        - missing_keys: Keys present in the module but not in the state dictionaries

-
-
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate
+        >>>
+        >>> net = brainstate.nn.Linear(10, 20)
+        >>> brainstate.nn.init_all_states(net)
+        >>>
+        >>> # Save state values
+        >>> state_dict = {path: state.value for path, state in net.states().items()}
+        >>>
+        >>> # Later, restore state values
+        >>> unexpected, missing = brainstate.nn.assign_state_values(net, state_dict)
+        >>> print(f"Unexpected keys: {unexpected}")
+        >>> print(f"Missing keys: {missing}")
+
+    Notes
+    -----
+    - All values are automatically converted to JAX arrays using `jax.numpy.asarray`.
+    - Only states with matching keys are updated; unexpected and missing keys are
+      returned but do not cause errors.
+    - If multiple dictionaries contain the same key, the last one takes precedence.
     """
-
+    # Merge all state dictionaries
+    all_states = {}
+    for state_dict in state_by_abs_path:
+        all_states.update(state_dict)

-
-    ----------
-    target: Module
-        The target module.
-    state_by_abs_path: dict
-        The state dictionary which is accessed by the "absolute" accessing method.
-
-    """
-    all_states = dict()
-    for state in state_by_abs_path:
-        all_states.update(state)
+    # Get current module states
     variables = target.states()
     keys1 = set(all_states.keys())
     keys2 = set(variables.keys())
+
+    # Update matching states
     for key in keys2.intersection(keys1):
         variables[key].value = jax.numpy.asarray(all_states[key])
-
-
+
+    # Return mismatched keys
+    unexpected_keys = sorted(keys1 - keys2)
+    missing_keys = sorted(keys2 - keys1)
     return unexpected_keys, missing_keys