brainstate 0.1.0.post20250211__py2.py3-none-any.whl → 0.1.0.post20250216__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_state.py +875 -93
- brainstate/_state_test.py +1 -3
- brainstate/augment/__init__.py +2 -2
- brainstate/augment/_autograd.py +257 -115
- brainstate/augment/_autograd_test.py +2 -3
- brainstate/augment/_eval_shape.py +3 -4
- brainstate/augment/_mapping.py +582 -62
- brainstate/augment/_mapping_test.py +114 -30
- brainstate/augment/_random.py +61 -7
- brainstate/compile/_ad_checkpoint.py +2 -3
- brainstate/compile/_conditions.py +4 -5
- brainstate/compile/_conditions_test.py +1 -2
- brainstate/compile/_error_if.py +1 -2
- brainstate/compile/_error_if_test.py +1 -2
- brainstate/compile/_jit.py +23 -16
- brainstate/compile/_jit_test.py +1 -2
- brainstate/compile/_loop_collect_return.py +18 -10
- brainstate/compile/_loop_collect_return_test.py +1 -1
- brainstate/compile/_loop_no_collection.py +5 -5
- brainstate/compile/_make_jaxpr.py +23 -21
- brainstate/compile/_make_jaxpr_test.py +1 -2
- brainstate/compile/_progress_bar.py +1 -2
- brainstate/compile/_unvmap.py +1 -0
- brainstate/compile/_util.py +4 -2
- brainstate/environ.py +4 -4
- brainstate/environ_test.py +1 -2
- brainstate/functional/_activations.py +1 -2
- brainstate/functional/_activations_test.py +1 -1
- brainstate/functional/_normalization.py +1 -2
- brainstate/functional/_others.py +1 -2
- brainstate/functional/_spikes.py +136 -20
- brainstate/graph/_graph_node.py +2 -43
- brainstate/graph/_graph_operation.py +4 -20
- brainstate/graph/_graph_operation_test.py +3 -4
- brainstate/init/_base.py +1 -2
- brainstate/init/_generic.py +1 -2
- brainstate/nn/__init__.py +4 -0
- brainstate/nn/_collective_ops.py +351 -48
- brainstate/nn/_collective_ops_test.py +36 -0
- brainstate/nn/_common.py +194 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +1 -2
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +1 -2
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +1 -2
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +1 -2
- brainstate/nn/_dyn_impl/_inputs.py +1 -2
- brainstate/nn/_dyn_impl/_rate_rnns.py +1 -2
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +1 -2
- brainstate/nn/_dyn_impl/_readout.py +2 -3
- brainstate/nn/_dyn_impl/_readout_test.py +1 -2
- brainstate/nn/_dynamics/_dynamics_base.py +2 -3
- brainstate/nn/_dynamics/_dynamics_base_test.py +1 -2
- brainstate/nn/_dynamics/_state_delay.py +3 -3
- brainstate/nn/_dynamics/_synouts_test.py +1 -2
- brainstate/nn/_elementwise/_dropout.py +6 -7
- brainstate/nn/_elementwise/_dropout_test.py +1 -2
- brainstate/nn/_elementwise/_elementwise.py +1 -2
- brainstate/nn/_exp_euler.py +1 -2
- brainstate/nn/_exp_euler_test.py +1 -2
- brainstate/nn/_interaction/_conv.py +1 -2
- brainstate/nn/_interaction/_conv_test.py +1 -0
- brainstate/nn/_interaction/_linear.py +1 -2
- brainstate/nn/_interaction/_linear_test.py +1 -2
- brainstate/nn/_interaction/_normalizations.py +1 -2
- brainstate/nn/_interaction/_poolings.py +3 -4
- brainstate/nn/_module.py +63 -19
- brainstate/nn/_module_test.py +1 -2
- brainstate/nn/metrics.py +3 -4
- brainstate/optim/_lr_scheduler.py +1 -2
- brainstate/optim/_lr_scheduler_test.py +2 -3
- brainstate/optim/_optax_optimizer_test.py +1 -2
- brainstate/optim/_sgd_optimizer.py +2 -3
- brainstate/random/_rand_funs.py +1 -2
- brainstate/random/_rand_funs_test.py +2 -3
- brainstate/random/_rand_seed.py +2 -3
- brainstate/random/_rand_seed_test.py +1 -2
- brainstate/random/_rand_state.py +3 -4
- brainstate/surrogate.py +183 -35
- brainstate/transform.py +0 -3
- brainstate/typing.py +28 -25
- brainstate/util/__init__.py +9 -7
- brainstate/util/_caller.py +1 -2
- brainstate/util/_error.py +27 -0
- brainstate/util/_others.py +60 -15
- brainstate/util/{_dict.py → _pretty_pytree.py} +108 -29
- brainstate/util/{_dict_test.py → _pretty_pytree_test.py} +1 -2
- brainstate/util/_pretty_repr.py +128 -10
- brainstate/util/_pretty_table.py +2900 -0
- brainstate/util/_struct.py +11 -11
- brainstate/util/filter.py +472 -0
- {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/METADATA +2 -2
- brainstate-0.1.0.post20250216.dist-info/RECORD +127 -0
- brainstate/util/_filter.py +0 -178
- brainstate-0.1.0.post20250211.dist-info/RECORD +0 -124
- {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/top_level.txt +0 -0
brainstate/augment/_mapping.py
CHANGED
@@ -16,15 +16,26 @@
 from __future__ import annotations
 
 import functools
-from typing import Any, TypeVar, Callable, Hashable, Sequence, Iterable, Tuple, Union, Optional, Dict, List
-
 import jax
 from jax.interpreters.batching import BatchTracer
-
-
-
-
-
+from typing import (
+    Any,
+    TypeVar,
+    Callable,
+    Hashable,
+    Sequence,
+    Iterable,
+    Tuple,
+    Union,
+    Optional,
+    Dict,
+    List
+)
+
+from brainstate._state import State, catch_new_states
+from brainstate.compile import scan, StatefulFunction
+from brainstate.random import RandomState, DEFAULT
+from brainstate.typing import Missing, Filter
 from brainstate.util import NestedDict, BrainStateError
 from ._random import restore_rngs
 
@@ -32,21 +43,56 @@ __all__ = [
     'vmap',
     'pmap',
     'map',
+    'vmap_new_states',
 ]
 
-AxisName = Hashable
 F = TypeVar("F", bound=Callable)
+AxisName = Hashable
 AxisToState = Dict[int, List[State]]
 StateToAxis = Dict[State, int]
 
 
 class BatchAxisError(BrainStateError):
+    """
+    Exception raised for errors related to batch axis operations.
+
+    This custom exception is used to indicate errors that occur during
+    batch processing or vectorization operations, particularly in the
+    context of state management in the BrainState framework.
+
+    Inherits from:
+        BrainStateError: The base error class for BrainState-related exceptions.
+    """
     pass
 
 
 def _flatten_in_out_states(
     in_states: Dict[int, Dict] | Any = None,
 ) -> Tuple[AxisToState, StateToAxis]:
+    """
+    Flattens and organizes input or output states into axis-based mappings.
+
+    This function processes the input or output states, converting them into two
+    dictionary representations: one mapping axes to states, and another mapping
+    states to axes. It handles both structured (Dict[int, Dict]) and unstructured
+    input formats.
+
+    Args:
+        in_states (Dict[int, Dict] | Any, optional): The input or output states to be
+            flattened. Can be a nested dictionary structure where the outer keys are
+            axes and inner dictionaries contain states, or any other structure
+            containing states. Defaults to None.
+
+    Returns:
+        Tuple[AxisToState, StateToAxis]: A tuple containing two dictionaries:
+            - AxisToState: Maps axes (int) to lists of states.
+            - StateToAxis: Maps individual states to their corresponding axes (int).
+
+    Note:
+        If in_states is None, empty dictionaries are returned for both mappings.
+        If in_states is not in the expected Dict[int, Dict] format, all states are
+        assigned to axis 0.
+    """
     if in_states is None:
         return dict(), dict()
     if isinstance(in_states, dict):
@@ -72,16 +118,165 @@ def _flatten_in_out_states(
     return axis_to_states, state_to_axis
 
 
-def _vmap_transform(
-
-
-
-
-
-
-
-
+def _remove_axis(x, axis: int):
+    """
+    Remove a specified axis from an array or nested structure.
+
+    This function removes a specified axis from an array or nested structure,
+    adjusting the shape and structure of the output accordingly.
+
+    Args:
+        x (Any): The input array or nested structure to remove the axis from.
+        axis (int): The axis to remove from the input.
+
+    Returns:
+        Any: The output array or nested structure with the specified axis removed.
+    """
+    assert isinstance(axis, int), f"Expected axis to be an integer, but got {type(axis)}"
+    if axis < 0:
+        axis += x.ndim
+    if axis < 0 or axis >= x.ndim:
+        raise IndexError(f"Axis {axis} is out of bounds for array of shape {x.shape}")
+    return x[tuple(slice(None, None, None) if i != axis else 0 for i in range(x.ndim))]
+
+
+def _compile_stateful_function(
+    stateful_fn: StatefulFunction,
+    in_axes: int | Tuple[int, ...],
+    args: Tuple
+):
+    """
+    Compile a stateful function with specified input axes and arguments.
+
+    This function prepares and compiles a stateful function for vectorized mapping (vmap)
+    by adjusting the input arguments based on the specified axes and then generating
+    the function's JAX program representation (jaxpr).
+
+    Args:
+        stateful_fn (StatefulFunction): The stateful function to be compiled.
+        in_axes (int | Tuple[int, ...]): Specifies which axes of the input arguments
+            to map over. Can be a single integer (same for all args) or a tuple of integers.
+        args (Tuple): The input arguments to the function.
+
+    Raises:
+        ValueError: If the length of in_axes tuple doesn't match the number of arguments.
+
+    Returns:
+        None. The function modifies the stateful_fn in-place by calling make_jaxpr.
+    """
+    in_axes_st, in_axes = in_axes
+    state_vals, args = args
+
+    # check in_axes
+    if isinstance(in_axes, tuple) and len(in_axes) != len(args):
+        raise ValueError(
+            "vmap in_axes must be an int, None, or a tuple of entries corresponding "
+            "to the positional arguments passed to the function, "
+            f"but got {len(in_axes)=}, {len(args)=}"
+        )
+
+    # check state_vals
+    if len(state_vals) > 0:
+        state_vals = [jax.tree.map(lambda x: _remove_axis(x, axis), vals)
+                      for vals, axis in zip(state_vals, in_axes_st)]
+    else:
+        state_vals = []
+
+    if isinstance(in_axes, int):
+        args = jax.tree.map(lambda x: _remove_axis(x, in_axes), args)
+    elif isinstance(in_axes, tuple):
+        args = tuple(
+            [arg if in_axis is None else _remove_axis(arg, in_axis)
+             for arg, in_axis in zip(args, in_axes)]
+        )
+    stateful_fn.make_jaxpr(state_vals, args)
+    return stateful_fn.get_arg_cache_key(state_vals, args)
+
+
+def _get_batch_size(
+    args: Tuple,
+    in_axes: int | Tuple[int, ...],
+    in_states: AxisToState,
+    axis_size: Optional[int] = None,
+) -> int:
+    """
+    Determine the batch size from input arguments, axes, and states.
+
+    This function calculates the batch size by examining the shapes of input arguments
+    and states along specified axes. It ensures consistency across all inputs.
+
+    Args:
+        args (Tuple): The input arguments to the function being vectorized.
+        in_axes (int | Tuple[int, ...]): The axes along which to vectorize for each argument.
+            Can be a single integer (same for all args) or a tuple of integers.
+        in_states (AxisToState): A dictionary mapping axes to lists of states.
+
+    Returns:
+        int: The determined batch size.
+
+    Raises:
+        ValueError: If unable to determine batch size or if inconsistent batch sizes are found.
+    """
+    batch_sizes = []
+
+    # Check batch size from args and in_axes
+    if isinstance(in_axes, int):
+        in_axes = (in_axes,) * len(args)
+    for arg, in_axis in zip(args, in_axes):
+        if in_axis is not None:
+            arg_leaves = jax.tree.leaves(arg)
+            if arg_leaves:
+                batch_sizes.append(arg_leaves[0].shape[in_axis])
+
+    # Check batch size from in_states
+    if in_states is not None:
+        for axis, states in in_states.items():
+            for state in states:
+                state_leaves = jax.tree.leaves(state.value)
+                if len(state_leaves):
+                    batch_sizes.append(state_leaves[0].shape[axis])
+
+    if len(batch_sizes) == 0:
+        assert axis_size is not None, (
+            "Unable to determine batch size. Please provide the 'axis_size' argument."
+        )
+        return axis_size
+    else:
+        # Ensure all batch sizes are consistent
+        if len(set(batch_sizes)) > 1:
+            raise ValueError(f"Inconsistent batch sizes found: {set(batch_sizes)}")
+
+        return batch_sizes[0]
+
+
+def _format_state_axes(
+    in_states,
+    out_states,
 ):
+    """
+    Format and validate the axes of input and output states.
+
+    This function processes the input and output states, ensuring consistency
+    between their axis mappings. It also handles cases where a state appears
+    in the input but not in the output.
+
+    Args:
+        in_states: The input states to be formatted. Can be a dictionary mapping
+            axes to states, or any other structure containing states.
+        out_states: The output states to be formatted. Can be a dictionary mapping
+            axes to states, or any other structure containing states.
+
+    Returns:
+        A tuple containing four elements:
+            - axis_to_in_states (dict): Mapping of axes to input states.
+            - in_state_to_axis (dict): Mapping of input states to their axes.
+            - axis_to_out_states (dict): Mapping of axes to output states.
+            - out_state_to_axis (dict): Mapping of output states to their axes.
+
+    Raises:
+        BatchAxisError: If there's an inconsistency between the axis mappings
+            of input and output states.
+    """
     axis_to_in_states, in_state_to_axis = _flatten_in_out_states(in_states)
     axis_to_out_states, out_state_to_axis = _flatten_in_out_states(out_states)
     for _in_state, _axis in in_state_to_axis.items():
@@ -90,8 +285,8 @@ def _vmap_transform(
             if _out_axis != _axis:
                 _in_state.raise_error_with_source_info(
                     BatchAxisError(
-                        f"State {_in_state} has been mapped to axis {_axis} in in_states, "
-                        f"However, it is mapped to axis {_out_axis} in out_states."
+                        f"State {_in_state} has been mapped to axis {_axis} in 'in_states', "
+                        f"However, it is mapped to axis {_out_axis} in 'out_states'."
                     )
                 )
         else:
@@ -99,65 +294,286 @@ def _vmap_transform(
             if _axis not in axis_to_out_states:
                 axis_to_out_states[_axis] = []
             axis_to_out_states[_axis].append(_in_state)
-    if isinstance(rngs, RandomState):
-        rngs = (rngs,)
-    rng_ids = set([id(rng) for rng in rngs])
 
-
-
+    return axis_to_in_states, in_state_to_axis, axis_to_out_states, out_state_to_axis
+
+
+def _vmap_transform(
+    f: F,
+    *,
+    in_axes: int | None | Sequence[Any] = 0,
+    out_axes: Any = 0,
+    in_states: Dict[int, Dict] | Any | None = None,
+    out_states: Dict[int, Dict] | Any | None = None,
+    axis_size: Optional[int] = None,
+    axis_name: AxisName | None = None,
+    spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
+):
+    """
+    Transforms a function for vectorized mapping (vmap) with state management.
+
+    This internal function applies vectorized mapping to the input function while
+    handling state management for input and output states. It supports custom
+    axis specifications for both inputs and outputs.
+
+    Args:
+        f (F): The function to be transformed for vectorized mapping.
+        in_axes (int | None | Sequence[Any]): Specifies which axes of the input
+            arguments to map over. Default is 0.
+        out_axes (Any): Specifies where the mapped axis should appear in the output.
+            Default is 0.
+        in_states (Dict[int, Dict] | Any | None): Specifies the input states and
+            their corresponding axes for mapping. Default is None.
+        out_states (Dict[int, Dict] | Any | None): Specifies the output states and
+            their corresponding axes for mapping. Default is None.
+        **transform_kwargs: Additional keyword arguments for the transformation.
+
+    Returns:
+        Callable: A new function that applies vectorized mapping to the input
+            function while managing states.
+    """
+
+    # TODO: support jax.disable_jit()
+
+    # format state axes
+    (
+        axis_to_in_states,
+        in_state_to_axis,
+        axis_to_out_states,
+        out_state_to_axis
+    ) = _format_state_axes(in_states, out_states)
+
+    # check in_axes
+    if isinstance(in_axes, list):
+        # To be a tree prefix of the positional args tuple, in_axes can never be a
+        # list: if in_axes is not a leaf, it must be a tuple of trees. However,
+        # in cases like these users expect tuples and lists to be treated
+        # essentially interchangeably, so we canonicalize lists to tuples here
+        # rather than raising an error. https://github.com/jax-ml/jax/issues/2367
+        in_axes = tuple(in_axes)
+
+    def _vmap_fn_for_compilation(in_vmap_state_vals, args):
+        """
+        Compile a function for vectorized mapping (vmap) with state restoration.
+
+        This internal function is used to prepare a function for vectorized mapping
+        by restoring state values before calling the original function.
+
+        Args:
+            in_vmap_state_vals (List[List]): A nested list containing the state values
+                to be restored. The outer list corresponds to different axes, while
+                the inner lists contain the state values for each axis.
+            args (Tuple): The arguments to be passed to the original function after
+                state restoration.
+
+        Returns:
+            Any: The result of calling the original function 'f' with the restored
+                state and provided arguments.
+        """
         # restore state values
         for i, states in enumerate(axis_to_in_states.values()):
-            for state, state_val in zip(states,
+            for state, state_val in zip(states, in_vmap_state_vals[i]):
                 state.restore_value(state_val)
 
         # call the function
-
-
+        return f(*args)
+
+    # stateful function
+    stateful_fn = StatefulFunction(_vmap_fn_for_compilation, name='vmap')
 
-
-
+    @functools.wraps(f)
+    def new_fn_for_vmap(
+        rng_keys,
+        in_state_vmap_vals,
+        in_state_oth_vals,
+        args,
+    ):
+        """
+        Wrapper function for vectorized mapping (vmap) that handles state restoration and function execution.
+
+        This function restores state values, random number generators (RNGs), and other state values
+        before calling the original function. It then processes the outputs and prepares them for
+        vectorized mapping.
+
+        Args:
+            rng_keys (Sequence): Random number generator keys for each mapped instance.
+            in_state_vmap_vals (Sequence[Sequence]): Input state values for vectorized mapping,
+                organized by axis.
+            in_state_oth_vals (Sequence): Other input state values not involved in vectorized mapping.
+            args (Tuple): Arguments to be passed to the original function.
+
+        Returns:
+            Tuple: A tuple containing four elements:
+                - out_rng_keys (List): Updated RNG keys after function execution.
+                - out_state_vmap_vals (List[List]): Output state values for vectorized mapping,
+                    organized by axis.
+                - out_state_oth_vals (List): Other output state values not involved in vectorized mapping.
+                - outs: The output of the original function call.
+
+        Raises:
+            AssertionError: If there's a mismatch in the number of states, state values, or RNG keys.
+            BatchAxisError: If a state value is batched but not included in out_states.
+        """
+        # restore vmapping state values
+        for i, states in enumerate(axis_to_in_states.values()):
+            assert len(states) == len(in_state_vmap_vals[i]), (
+                f"The number of states in axis {i} should be equal to the number "
+                f"of state values, but got {len(states)} and {len(in_state_vmap_vals[i])}."
+            )
+            for state, state_val in zip(states, in_state_vmap_vals[i]):
+                state.restore_value(state_val)
+
+        # restore rngs
+        cache_key = stateful_fn.get_arg_cache_key(in_state_vmap_vals, args)
+        state_trace = stateful_fn.get_state_trace(cache_key)
+        rngs = state_trace.state_subset(RandomState)
+        rng_sets = set(rngs)
+        assert len(rngs) == len(rng_keys), (
+            f"The number of random states in the function should be equal to the number "
+            f"of random keys, but got {len(rngs)} and {len(rng_keys)}."
+        )
+        for rng, key in zip(rngs, rng_keys):
+            rng.restore_value(key)
+
+        # restore other state values
+        oth_in_state = [
+            st for st in state_trace.states
+            if st not in in_state_to_axis and st not in rng_sets
+        ]
+        assert len(oth_in_state) == len(in_state_oth_vals), (
+            f"The number of states in 'in_states' should be equal to the number "
+            f"of state values, but got {len(oth_in_state)} and {len(in_state_oth_vals)}."
+        )
+        for state, state_val in zip(oth_in_state, in_state_oth_vals):
+            state.restore_value(state_val)
+
+        # call the function
+        outs = stateful_fn.jaxpr_call_auto(in_state_vmap_vals, args)
+
+        # analyze vmapping axis error
+        for state in state_trace.get_write_states():
             leaves = jax.tree.leaves(state.value)
-            if isinstance(
-            if isinstance(state, RandomState) and
+            if any([isinstance(leaf, BatchTracer) for leaf in leaves]) and state not in out_state_to_axis:
+                if isinstance(state, RandomState) and state in rng_sets:
                     continue
                 state.raise_error_with_source_info(
-                    BatchAxisError(
-
-                    )
+                    BatchAxisError(f"The value of State {state} is batched, "
+                                   f"but it is not in the out_states.")
                 )
 
-
+        # out state values for vmapping
+        out_state_vmap_vals = [
             [state.value for state in states]
             for axis, states in axis_to_out_states.items()
         ]
-
+        out_state_oth_vals = [
+            st.value for st in state_trace.states
+            if st not in out_state_to_axis and st not in rng_sets
+        ]
+        out_rng_keys = [rng.value for rng in rngs]
+        return out_rng_keys, out_state_vmap_vals, out_state_oth_vals, outs
 
+    @functools.wraps(f)
     def vmapped_fn(*args):
-
-
+        """
+        Applies vectorized mapping (vmap) to the input function while managing state.
+
+        This function handles the vectorization process, including state management,
+        random number generation, and function compilation. It prepares the input
+        states, compiles the stateful function, manages random number generators,
+        applies the vmap transformation, and restores the output states.
+
+        Args:
+            *args: Variable length argument list containing the input arguments
+                to be passed to the vectorized function.
+
+        Returns:
+            Any: The output of the vectorized function after applying vmap and
+                managing states.
+
+        Note:
+            This function assumes the existence of several helper functions and
+            data structures (e.g., axis_to_in_states, in_state_to_axis) which
+            should be defined in the broader context.
+        """
+        # in states values
+        in_state_map_vals = [
             [st.value for st in states]
             for axis, states in axis_to_in_states.items()
         ]
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        st_in_axes = list(axis_to_in_states.keys())
+        if len(st_in_axes) == 0:
+            st_in_axes = 0
+
+        # compile stateful function
+        cache_key = _compile_stateful_function(
+            stateful_fn,
+            (st_in_axes, in_axes),
+            (in_state_map_vals, args)
+        )
+
+        # random keys
+        state_trace = stateful_fn.get_state_trace(cache_key)
+        rngs = state_trace.state_subset(RandomState)
+        rng_sets = set(rngs)
+        if len(rngs):
+            # batch size
+            batch_size = _get_batch_size(args, in_axes, axis_to_in_states, axis_size)
+            rng_keys = tuple(rng.split_key(batch_size) for rng in rngs)
+            rng_backup = tuple(rng.split_key() for rng in rngs)
+        else:
+            rng_keys = tuple()
+            rng_backup = tuple()
+
+        # in states other values
+        in_state_oth_vals = [
+            st.value
+            for st in state_trace.states
+            if st not in in_state_to_axis and st not in rng_sets
+        ]
+
+        # out state axis
+        st_out_axes = list(axis_to_out_states.keys())
+        if len(st_out_axes) == 0:
+            st_out_axes = 0
+
+        # --- vmapping --- #
+        fn = jax.vmap(
+            new_fn_for_vmap,
+            in_axes=(0, st_in_axes, None, in_axes),
+            out_axes=(0, st_out_axes, None, out_axes),
+            axis_size=axis_size,
+            axis_name=axis_name,
+            spmd_axis_name=spmd_axis_name,
+        )
+        _, out_state_map_vals, out_state_oth_vals, outs = fn(
+            rng_keys, in_state_map_vals, in_state_oth_vals, args
+        )
 
         # restore mapped state values
         for i, states in enumerate(axis_to_out_states.values()):
-
+            assert len(states) == len(out_state_map_vals[i]), (
+                f"The number of states in axis {i} should be equal to the number "
+                f"of state values, but got {len(states)} and {len(out_state_map_vals[i])}."
+            )
+            for state, st_val in zip(states, out_state_map_vals[i]):
                 state.restore_value(st_val)
+
+        # restore other state values
+        out_oth_states = [
+            st for st in state_trace.states
+            if st not in out_state_to_axis and st not in rng_sets
+        ]
+        assert len(out_oth_states) == len(out_state_oth_vals), (
+            f"The number of states in 'out_states' should be equal to the number "
+            f"of state values, but got {len(out_oth_states)} and {len(out_state_oth_vals)}."
+        )
+        for state, st_val in zip(out_oth_states, out_state_oth_vals):
+            state.restore_value(st_val)
+
+        # restore random keys
+        for rng, key in zip(rngs, rng_backup):
+            rng.restore_value(key)
         return outs
 
     return vmapped_fn
@@ -166,15 +582,15 @@ def _vmap_transform(
 def vmap(
     fn: F | Missing = Missing(),
     *,
+    # --- normal jax.vmap arguments --- #
     in_axes: int | None | Sequence[Any] = 0,
     out_axes: Any = 0,
     axis_name: AxisName | None = None,
     axis_size: int | None = None,
     spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
-    # brainstate specific arguments
+    # --- brainstate specific arguments --- #
    in_states: Dict[int, Dict] | Any | None = None,
    out_states: Dict[int, Dict] | Any | None = None,
-    rngs: Union[RandomState, Sequence[RandomState]] = DEFAULT,
 ) -> F | Callable[[F], F]:
     """
     Vectorizing map. Creates a function which maps ``fun`` over argument axes.
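Note that the `rngs` argument is dropped from this signature: `RandomState` instances used inside the mapped function are now discovered from the state trace and given per-instance split keys automatically (see the `new_fn_for_vmap` wrapper in the hunk above). Below is a hedged usage sketch of the keyword-only `in_states`/`out_states` arguments; the function, shapes, and dictionary keys are illustrative, not taken from the package's test suite.

    import jax.numpy as jnp
    import brainstate

    hidden = brainstate.State(jnp.zeros((8, 4)))   # leading axis is the batch axis

    def step(x):
        # inside the mapped call, `hidden.value` is the per-instance slice of shape (4,)
        hidden.value = hidden.value + x
        return hidden.value.sum(axis=-1)

    # map `hidden` over axis 0 on both the input and output side
    batched_step = brainstate.augment.vmap(
        step,
        in_states={0: {'h': hidden}},
        out_states={0: {'h': hidden}},
    )
    y = batched_step(jnp.ones((8, 4)))   # expected: y.shape == (8,)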
@@ -229,10 +645,6 @@ def vmap(
       corresponds to the outermost :func:`vmap` call, the second element to
       the next outermost, and so on. If the tuple is not provided, the
       ``axis_name`` is used for all nested :func:`vmap` calls.
-    rngs: Optional, a random number generator or sequence of random number
-      generators to be used in the mapped function. These random number
-      generators are restored their random key after the mapped function is
-      executed.
     in_states: Optional, the :class:`State` objects to be mapped over in the inputs.
     out_states: Optional, the :class:`State` objects to be mapped over in the outputs.
 
@@ -254,7 +666,6 @@ def vmap(
             axis_name=axis_name,
             axis_size=axis_size,
             spmd_axis_name=spmd_axis_name,
-            rngs=rngs,
         )  # type: ignore[return-value]
 
     return _vmap_transform(
@@ -266,7 +677,6 @@ def vmap(
         axis_name=axis_name,
         axis_size=axis_size,
         spmd_axis_name=spmd_axis_name,
-        rngs=rngs
     )
 
 
@@ -511,3 +921,113 @@ def map(
 
     def flatten_(x):
         return x.reshape(-1, *x.shape[2:])
+
+
+def _vmap_new_states_transform(
+    fun: Callable[..., Any],
+    *,
+    # -- normal jax.vmap arguments -- #
+    in_axes: int | None | Sequence[Any] = 0,
+    out_axes: Any = 0,
+    axis_name: AxisName | None = None,
+    axis_size: int | None = None,
+    spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
+    # -- brainstate specific arguments -- #
+    state_tag: str | None = None,
+    state_to_exclude: Filter | None = None,
+):
+
+    # TODO: How about nested call ``vmap_new_states``?
+
+
+    @vmap(
+        in_axes=in_axes,
+        out_axes=out_axes,
+        axis_name=axis_name,
+        axis_size=axis_size,
+        spmd_axis_name=spmd_axis_name,
+    )
+    def new_fun(args):
+        # call the function
+        with catch_new_states(state_tag=state_tag, state_to_exclude=state_to_exclude) as catcher:
+            out = fun(*args)
+
+        # get vmap state values
+        vmap_state_vals = catcher.get_state_values()
+
+        return out, vmap_state_vals
+
+    @functools.wraps(fun)
+    def vmapped_fn(*args):
+        # vmapping
+        with catch_new_states(state_to_exclude=state_to_exclude) as catcher:
+            outs, vmap_state_vals = new_fun(args)
+            vmap_states = catcher.get_states()
+
+        # restore vmapped state values
+        for st_val, st in zip(vmap_state_vals, vmap_states):
+            st.restore_value(st_val)
+            # ------------------------------------------------
+            # --- this is CRUCIAL to avoid jax tracing leakage
+            # ------------------------------------------------
+            st.decrease_stack_level()
+        return outs
+
+    return vmapped_fn
+
+
+def vmap_new_states(
+    fun: Callable = Missing(),
+    *,
+    # -- normal jax.vmap arguments -- #
+    in_axes: int | None | Sequence[Any] = 0,
+    out_axes: Any = 0,
+    axis_name: AxisName | None = None,
+    axis_size: int | None = None,
+    spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
+    # -- brainstate specific arguments -- #
+    state_tag: str | None = None,
+    state_to_exclude: Filter = None,
+):
+    """
+    Vectorize a function over new states created within it.
+
+    This function applies JAX's vmap transformation to newly created states
+    during the function's execution. It allows for more
+    flexible vectorization in the context of stateful computations.
+
+    Args:
+        fun (Callable, optional): The function to be vectorized. Defaults to Missing().
+        in_axes (int | None | Sequence[Any], optional): Specification of input axes for vectorization. Defaults to 0.
+        out_axes (Any, optional): Specification of output axes after vectorization. Defaults to 0.
+        axis_name (AxisName, optional): Name of the axis being vectorized over. Defaults to None.
+        axis_size (int, optional): Size of the axis being vectorized over. Defaults to None.
+        spmd_axis_name (AxisName | tuple[AxisName, ...], optional): Name(s) of SPMD axis/axes. Defaults to None.
+        state_tag (str, optional): A tag to identify specific states. Defaults to None.
+        state_to_exclude (Sequence[int], optional): Indices of states to exclude from vectorization. Defaults to ().
+
+    Returns:
+        Callable: A vectorized version of the input function that handles new state creation.
+    """
+    if isinstance(fun, Missing):
+        return functools.partial(
+            _vmap_new_states_transform,
+            in_axes=in_axes,
+            out_axes=out_axes,
+            axis_name=axis_name,
+            axis_size=axis_size,
+            spmd_axis_name=spmd_axis_name,
+            state_tag=state_tag,
+            state_to_exclude=state_to_exclude,
+        )
+    else:
+        return _vmap_new_states_transform(
+            fun,
+            in_axes=in_axes,
+            out_axes=out_axes,
+            axis_name=axis_name,
+            axis_size=axis_size,
+            spmd_axis_name=spmd_axis_name,
+            state_tag=state_tag,
+            state_to_exclude=state_to_exclude,
+        )