brainstate 0.1.0.post20250104__py2.py3-none-any.whl → 0.1.0.post20250120__py2.py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- brainstate/_state.py +77 -44
- brainstate/_state_test.py +0 -17
- brainstate/augment/_eval_shape.py +9 -10
- brainstate/augment/_eval_shape_test.py +1 -1
- brainstate/augment/_mapping.py +265 -277
- brainstate/augment/_mapping_test.py +147 -175
- brainstate/compile/_ad_checkpoint.py +6 -4
- brainstate/compile/_error_if_test.py +1 -0
- brainstate/compile/_jit.py +37 -28
- brainstate/compile/_loop_collect_return.py +8 -5
- brainstate/compile/_loop_no_collection.py +2 -0
- brainstate/compile/_make_jaxpr.py +7 -3
- brainstate/compile/_make_jaxpr_test.py +2 -1
- brainstate/compile/_progress_bar.py +68 -40
- brainstate/compile/_unvmap.py +6 -2
- brainstate/environ.py +28 -18
- brainstate/environ_test.py +4 -0
- brainstate/event/__init__.py +0 -2
- brainstate/event/_csr.py +266 -23
- brainstate/event/_csr_test.py +187 -0
- brainstate/event/_fixedprob_mv.py +4 -2
- brainstate/event/_fixedprob_mv_test.py +2 -1
- brainstate/event/_xla_custom_op.py +16 -5
- brainstate/graph/__init__.py +8 -12
- brainstate/graph/_graph_node.py +1 -23
- brainstate/graph/_graph_operation.py +1 -1
- brainstate/graph/_graph_operation_test.py +0 -159
- brainstate/nn/_dyn_impl/_inputs.py +124 -39
- brainstate/nn/_interaction/_conv.py +4 -2
- brainstate/nn/_interaction/_linear.py +84 -10
- brainstate/random/_rand_funs.py +9 -2
- brainstate/random/_rand_seed.py +12 -2
- brainstate/random/_rand_state.py +50 -179
- brainstate/surrogate.py +5 -1
- brainstate/util/__init__.py +0 -4
- brainstate/util/_caller.py +1 -1
- brainstate/util/_dict.py +4 -1
- brainstate/util/_filter.py +1 -1
- brainstate/util/_pretty_repr.py +1 -1
- brainstate/util/_struct.py +1 -1
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/METADATA +2 -1
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/RECORD +46 -52
- brainstate/event/_csr_mv_test.py +0 -118
- brainstate/graph/_graph_context.py +0 -443
- brainstate/graph/_graph_context_test.py +0 -65
- brainstate/graph/_graph_convert.py +0 -246
- brainstate/util/_tracers.py +0 -68
- brainstate/util/_visualization.py +0 -47
- /brainstate/event/{_csr_mv_benchmark.py → _csr_benchmark.py} +0 -0
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/top_level.txt +0 -0
brainstate/augment/_mapping.py
CHANGED
@@ -15,152 +15,152 @@
 
 from __future__ import annotations
 
-import dataclasses
 import functools
-from typing import Any, TypeVar, Callable, Hashable, Sequence, Iterable,
+from typing import Any, TypeVar, Callable, Hashable, Sequence, Iterable, Tuple, Union, Optional, Dict, List
 
 import jax
+from jax.interpreters.batching import BatchTracer
 
-from brainstate.
-from brainstate.
+from brainstate._state import State, StateTraceStack
+from brainstate.compile._loop_collect_return import scan
 from brainstate.random import DEFAULT, RandomState
-from brainstate.typing import Missing
-from brainstate.util import NestedDict
+from brainstate.typing import Missing
+from brainstate.util import NestedDict, BrainStateError
 from ._random import restore_rngs
 
 __all__ = [
-    'StateAxes',
     'vmap',
     'pmap',
+    'map',
 ]
 
 AxisName = Hashable
 F = TypeVar("F", bound=Callable)
-
-
+AxisToState = Dict[int, List[State]]
+StateToAxis = Dict[State, int]
 
 
-class StateAxes:
-    """
-    A class to represent the axes of a state.
-
-    This class is used to control how graph nodes like Modules are vectorized or
-    parallelized by specifying the axes to be applied to substates of the graph
-    node given a Filter.
-
-    Args:
-        filter_axes: A mapping from filters to axes. The axes can be an index, a carry or None.
-
-    """
-
-    def __init__(
-        self,
-        filter_axes: Union[Mapping[Filter, Index | Carry | None], Iterable[Tuple[Filter, Index | Carry | None]]],
-    ):
-        iterable = filter_axes.items() if isinstance(filter_axes, Mapping) else filter_axes
-        self._filters = tuple(filter_ for filter_, _ in iterable)
-        self._axes = tuple(axis for _, axis in iterable)
-
-    @property
-    def filters(self) -> Tuple[Filter, ...]:
-        return self._filters
-
-    @property
-    def axes(self) -> Tuple[Index | Carry | None, ...]:
-        return self._axes
-
-    def __repr__(self):
-        return f'StateAxes({dict(self.items())})'
-
-    def items(self):
-        return zip(self.filters, self.axes)
-
-    def __eq__(self, other):
-        return isinstance(other, StateAxes) and self.filters == other.filters and self.axes == other.axes
-
-    def __hash__(self):
-        return hash((self.filters, self.axes))
+class BatchAxisError(BrainStateError):
+    pass
 
 
-def
-
-
-
-
-
-
-
-
-
-
-    ctxtag: str
-
-    def __post_init__(self):
-        functools.update_wrapper(self, self.f)
-
-    def __call__(self, *pure_args: Tuple[Any, ...]):
-        # pytree to graph
-        args = tree_to_graph(pure_args, ctxtag=self.ctxtag)
-
-        # call the function
-        out = self.f(*args)
-
-        # graph to pytree
-        args_out = clear_non_graph_nodes(args)
-        pure_args_out, pure_out = graph_to_tree(
-            (args_out, out),
-            prefix=(self.in_axes, self.out_axes),
-            split_fn=_map_split_fn,
-            ctxtag=self.ctxtag,
+def _flatten_in_out_states(
+    in_states: Dict[int, Dict] | Any = None,
+) -> Tuple[AxisToState, StateToAxis]:
+    if in_states is None:
+        return dict(), dict()
+    if isinstance(in_states, dict):
+        keys = tuple(in_states.keys())
+        values = tuple(in_states.values())
+        is_axis_in_states = (
+            all([isinstance(key, int) for key in keys]) and
+            all([isinstance(value, dict) for value in values])
         )
-
-
-
-
-
-
+    else:
+        is_axis_in_states = False
+    if is_axis_in_states:
+        axis_to_states = {key: list(value.values()) for key, value in in_states.items()}
+        state_to_axis = {}
+        for key, value in in_states.items():
+            for state in value.values():
+                state_to_axis[state] = key
+        return axis_to_states, state_to_axis
+    else:
+        in_states = jax.tree.leaves(in_states)
+        axis_to_states = {0: list(in_states)}
+        state_to_axis = {state: 0 for state in in_states}
+        return axis_to_states, state_to_axis
+
+
+def _vmap_transform(
     f: F,
     *,
-    in_axes:
+    in_axes: int | None | Sequence[Any] = 0,
     out_axes: Any = 0,
+    in_states: Dict[int, Dict] | Any | None = None,
+    out_states: Dict[int, Dict] | Any | None = None,
     rngs: Union[RandomState, Sequence[RandomState]] = DEFAULT,
     **transform_kwargs,
 ):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    axis_to_in_states, in_state_to_axis = _flatten_in_out_states(in_states)
+    axis_to_out_states, out_state_to_axis = _flatten_in_out_states(out_states)
+    for _in_state, _axis in in_state_to_axis.items():
+        if _in_state in out_state_to_axis:
+            _out_axis = out_state_to_axis[_in_state]
+            if _out_axis != _axis:
+                _in_state.raise_error_with_source_info(
+                    BatchAxisError(
+                        f"State {_in_state} has been mapped to axis {_axis} in in_states, "
+                        f"However, it is mapped to axis {_out_axis} in out_states."
+                    )
+                )
+        else:
+            out_state_to_axis[_in_state] = _axis
+            if _axis not in axis_to_out_states:
+                axis_to_out_states[_axis] = []
+            axis_to_out_states[_axis].append(_in_state)
+    if isinstance(rngs, RandomState):
+        rngs = (rngs,)
+    rng_ids = set([id(rng) for rng in rngs])
 
     @functools.wraps(f)
-
-
-
-
-
-
-        # vmap with pytree
-        pure_args_out, pure_out = mapped_fn(*pure_args)
-
-        # pytree to graph
-        _args_out, out = tree_to_graph((pure_args_out, pure_out), ctxtag=ctxtag)
-        return out
+    def new_fn(in_states_, args):
+        # restore state values
+        for i, states in enumerate(axis_to_in_states.values()):
+            for state, state_val in zip(states, in_states_[i]):
+                state.restore_value(state_val)
 
-
+        # call the function
+        with StateTraceStack() as stack:
+            outs = f(*args)
+
+        # analyze
+        for state in stack.get_write_states():
+            leaves = jax.tree.leaves(state.value)
+            if isinstance(leaves[0], BatchTracer) and state not in out_state_to_axis:
+                if isinstance(state, RandomState) and id(state) in rng_ids:
+                    continue
+                state.raise_error_with_source_info(
+                    BatchAxisError(
+                        f"The value of State {state} is batched, but it is not in the out_states."
+                    )
+                )
+
+        out_states_ = [
+            [state.value for state in states]
+            for axis, states in axis_to_out_states.items()
+        ]
+        return out_states_, outs
+
+    def vmapped_fn(*args):
+        # vmapping
+        in_state_vals = [
+            [st.value for st in states]
+            for axis, states in axis_to_in_states.items()
+        ]
+        in_axes_st = list(axis_to_in_states.keys())
+        out_axes_st = list(axis_to_out_states.keys())
+        if len(in_axes_st) == 0:
+            in_axes_st = 0
+        if len(out_axes_st) == 0:
+            out_axes_st = 0
+        out_state_vals, outs = restore_rngs(
+            jax.vmap(
+                new_fn,
+                in_axes=(in_axes_st, in_axes),
+                out_axes=(out_axes_st, out_axes),
+                **transform_kwargs
+            ),
+            rngs=rngs
+        )(in_state_vals, args)
+
+        # restore mapped state values
+        for i, states in enumerate(axis_to_out_states.values()):
+            for state, st_val in zip(states, out_state_vals[i]):
+                state.restore_value(st_val)
+        return outs
+
+    return vmapped_fn
 
 
 def vmap(
@@ -172,6 +172,8 @@ def vmap(
     axis_size: int | None = None,
     spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
     # brainstate specific arguments
+    in_states: Dict[int, Dict] | Any | None = None,
+    out_states: Dict[int, Dict] | Any | None = None,
     rngs: Union[RandomState, Sequence[RandomState]] = DEFAULT,
 ) -> F | Callable[[F], F]:
     """
@@ -183,103 +185,38 @@ def vmap(
 
     More information please see `jax.vmap <https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html>`__.
 
-
     These are several example usage::
 
     >>> import brainstate as bst
     >>> import jax.numpy as jnp
 
-    >>>
-    >>>
-
-    >>>
-
-
-
-    >>> y = forward(model, x)
-    >>> print(y.shape)
-    (5, 3)
-
-    Another example with a more complex model::
-
-    >>> class LinearEnsemble(bst.nn.Module):
-    ...     def __init__(self, n: int):
-    ...         super().__init__()
-    ...         self.n = n
-    ...         self.w = bst.ParamState(bst.random.random((n, 2, 3)))
-
-    >>> model = LinearEnsemble(5)
-    >>> x = jnp.ones((2,))
-
-    >>> @bst.augment.vmap(in_axes=(0, None), out_axes=0)
-    ... def forward(model, x):
-    ...     return jnp.dot(x, model.w.value)
-
-    >>> y = forward(model, x)
-    >>> print(y.shape)
-    (5, 3)
+    >>> class Model(bst.nn.Module):
+    >>>     def __init__(self):
+    >>>         super().__init__()
+    >>>
+    >>>         self.a = bst.ShortTermState(bst.random.randn(5))
+    >>>         self.b = bst.ShortTermState(bst.random.randn(5))
+    >>>         self.c = bst.State(bst.random.randn(1))
 
-
-
-
-    share the parameters between the ensemble members which keeping different
-    batch statistics and dropout random state::
+    >>>     def __call__(self, *args, **kwargs):
+    >>>         self.c.value = self.a.value * self.b.value
+    >>>         return self.c.value + 1.
 
-    >>>
-    ...     def __init__(self):
-    ...         super().__init__()
-    ...         self.a = bst.ParamState(jnp.arange(4))
-    ...         self.b = bst.ShortTermState(jnp.arange(4))
+    >>> model = Model()
 
-    >>>
-    >>>
-
-
-
-    >>> model = Foo()
-    >>> y = mul(model)
-    >>> print(y.shape)
-    (4, 4)
+    >>> r = bst.augment.vmap(
+    >>>     model,
+    >>>     in_states=model.states(bst.ShortTermState),
+    >>>     out_states=model.c
+    >>> )()
 
     Args:
         fn: Function to be mapped over additional axes.
         in_axes: An integer, None, or sequence of values specifying which input
             array axes to map over.
-
-            If each positional argument to ``fun`` is an array, then ``in_axes`` can
-            be an integer, a None, or a tuple of integers and Nones with length equal
-            to the number of positional arguments to ``fun``. An integer or ``None``
-            indicates which array axis to map over for all arguments (with ``None``
-            indicating not to map any axis), and a tuple indicates which axis to map
-            for each corresponding positional argument. Axis integers must be in the
-            range ``[-ndim, ndim)`` for each array, where ``ndim`` is the number of
-            dimensions (axes) of the corresponding input array.
-
-            If the positional arguments to ``fun`` are container (pytree) types, ``in_axes``
-            must be a sequence with length equal to the number of positional arguments to
-            ``fun``, and for each argument the corresponding element of ``in_axes`` can
-            be a container with a matching pytree structure specifying the mapping of its
-            container elements. In other words, ``in_axes`` must be a container tree prefix
-            of the positional argument tuple passed to ``fun``. See this link for more detail:
-            https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees
-
-            Either ``axis_size`` must be provided explicitly, or at least one
-            positional argument must have ``in_axes`` not None. The sizes of the
-            mapped input axes for all mapped positional arguments must all be equal.
-
-            Arguments passed as keywords are always mapped over their leading axis
-            (i.e. axis index 0).
-
-            See below for examples.
-
         out_axes: An integer, None, or (nested) standard Python container
            (tuple/list/dict) thereof indicating where the mapped axis should appear
-            in the output.
-            ``out_axes`` specification. Axis integers must be in the range ``[-ndim,
-            ndim)`` for each output array, where ``ndim`` is the number of dimensions
-            (axes) of the array returned by the :func:`vmap`-ed function, which is one
-            more than the number of dimensions (axes) of the corresponding array
-            returned by ``fun``.
+            in the output.
        axis_name: Optional, a hashable Python object used to identify the mapped
            axis so that parallel collectives can be applied.
        axis_size: Optional, an integer indicating the size of the axis to be
@@ -296,6 +233,8 @@ def vmap(
            generators to be used in the mapped function. These random number
            generators are restored their random key after the mapped function is
            executed.
+        in_states: Optional, the :class:`State` objects to be mapped over in the inputs.
+        out_states: Optional, the :class:`State` objects to be mapped over in the outputs.
 
    Returns:
        Batched/vectorized version of ``fun`` with arguments that correspond to
@@ -304,23 +243,26 @@ def vmap(
        with extra array axes at positions indicated by ``out_axes``.
 
    """
+
    if isinstance(fn, Missing):
        return functools.partial(
-
+            _vmap_transform,
            in_axes=in_axes,
            out_axes=out_axes,
+            in_states=in_states,
+            out_states=out_states,
            axis_name=axis_name,
            axis_size=axis_size,
            spmd_axis_name=spmd_axis_name,
            rngs=rngs,
        )  # type: ignore[return-value]
 
-    return
-        'vmap',  # ctxtag
-        jax.vmap,
+    return _vmap_transform(
        fn,
        in_axes=in_axes,
        out_axes=out_axes,
+        in_states=in_states,
+        out_states=out_states,
        axis_name=axis_name,
        axis_size=axis_size,
        spmd_axis_name=spmd_axis_name,
@@ -363,65 +305,6 @@ def pmap(
 
    More information please see `jax.vmap <https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html>`__.
 
-    If there are 4 XLA devices available, the following example will execute
-    the function in parallel on each device::
-
-
-    >>> import brainstate as bst
-    >>> import jax.numpy as jnp
-
-    >>> model = bst.nn.Linear(2, 3)
-    >>> x = jnp.ones((4, 2))
-
-    >>> @bst.augment.vmap(in_axes=(None, 0), out_axes=0)
-    ... def forward(model, x):
-    ...     return model(x)
-
-    >>> y = forward(model, x)
-    >>> print(y.shape)
-    (4, 3)
-
-    Another example with a more complex model::
-
-    >>> class LinearEnsemble(bst.nn.Module):
-    ...     def __init__(self, n: int):
-    ...         super().__init__()
-    ...         self.n = n
-    ...         self.w = bst.ParamState(bst.random.random((n, 2, 3)))
-
-    >>> model = LinearEnsemble(4)
-    >>> x = jnp.ones((2,))
-
-    >>> @bst.augment.vmap(in_axes=(0, None), out_axes=0)
-    ... def forward(model, x):
-    ...     return jnp.dot(x, model.w.value)
-
-    >>> y = forward(model, x)
-    >>> print(y.shape)
-    (4, 3)
-
-    To control how different types of states are vectorized, ``StateAxes``
-    can be passed to ``in_axes`` and ``out_axes`` specifying the axes to be
-    applied to each substate given a filter. The following example shows how to
-    share the parameters between the ensemble members which keeping different
-    batch statistics and dropout random state::
-
-    >>> class Foo(bst.nn.Module):
-    ...     def __init__(self):
-    ...         super().__init__()
-    ...         self.a = bst.ParamState(jnp.arange(4))
-    ...         self.b = bst.ShortTermState(jnp.arange(4))
-
-    >>> state_axes = bst.augment.StateAxes({bst.ParamState: 0, bst.ShortTermState: None})
-    >>> @bst.augment.vmap(in_axes=(state_axes,), out_axes=0)
-    ... def mul(foo):
-    ...     return foo.a.value * foo.b.value
-
-    >>> model = Foo()
-    >>> y = mul(model)
-    >>> print(y.shape)
-    (4, 4)
-
 
    Args:
        fn: Function to be mapped over argument axes. Its arguments and return
@@ -508,18 +391,123 @@ def pmap(
        rngs=rngs,
    )  # type: ignore[return-value]
 
-    return
-
-
-
-
-
-
-
-
-
-
-
-
-        rngs=rngs
+    return restore_rngs(
+        jax.pmap(
+            fn,
+            in_axes=in_axes,
+            out_axes=out_axes,
+            axis_name=axis_name,
+            static_broadcasted_argnums=static_broadcasted_argnums,
+            devices=devices,
+            backend=backend,
+            axis_size=axis_size,
+            donate_argnums=donate_argnums,
+            global_arg_shapes=global_arg_shapes,
+        ),
+        rngs=rngs
    )
+
+
+def _batch_and_remainder(x, batch_size: int):
+    leaves, tree_def = jax.tree.flatten(x)
+
+    scan_leaves = []
+    remainder_leaves = []
+
+    length = None
+    for leaf in leaves:
+        if length is None:
+            length = leaf.shape[0]
+        if length != leaf.shape[0]:
+            raise ValueError(f"All inputs must have the same length. Got {length} and {leaf.shape[0]}.")
+
+    num_batches, num_remainder = divmod(length, batch_size)
+    for leaf in leaves:
+        total_batch_elems = num_batches * batch_size
+        scan_leaves.append(leaf[:total_batch_elems].reshape(num_batches, batch_size, *leaf.shape[1:]))
+        if num_remainder:
+            remainder_leaves.append(leaf[total_batch_elems:])
+
+    scan_tree = tree_def.unflatten(scan_leaves)
+    if num_remainder:
+        remainder_tree = tree_def.unflatten(remainder_leaves)
+        return scan_tree, remainder_tree
+    else:
+        return scan_tree, None
+
+
+def map(
+    f,
+    *xs,
+    batch_size: int | None = None,
+):
+    """
+    Map a function over leading array axes.
+
+    Like Python's builtin map, except inputs and outputs are in the form of
+    stacked arrays. Consider using the :func:`~jax.vmap` transform instead, unless you
+    need to apply a function element by element for reduced memory usage or
+    heterogeneous computation with other control flow primitives.
+
+    When ``xs`` is an array type, the semantics of :func:`~map` are given by this
+    Python implementation::
+
+        def map(f, *xs):
+            return np.stack([f(*x) for x in xs])
+
+    Like :func:`~scan`, :func:`~map` is implemented in terms of JAX primitives so
+    many of the same advantages over a Python loop apply: ``xs`` may be an
+    arbitrary nested pytree type, and the mapped computation is compiled only
+    once.
+
+    If ``batch_size`` is provided, the computation is executed in batches of that size
+    and parallelized using :func:`~jax.vmap`. This can be used as either a more performant
+    version of ``map`` or as a memory-efficient version of ``vmap``. If the axis is not
+    divisible by the batch size, the remainder is processed in a separate ``vmap`` and
+    concatenated to the result.
+
+    >>> import jax.numpy as jnp
+    >>> x = jnp.ones((10, 3, 4))
+    >>> def f(x):
+    ...     print('inner shape:', x.shape)
+    ...     return x + 1
+    >>> y = map(f, x, batch_size=3)
+    inner shape: (3, 4)
+    inner shape: (3, 4)
+    >>> y.shape
+    (10, 3, 4)
+
+    In the example above, "inner shape" is printed twice, once while tracing the batched
+    computation and once while tracing the remainder computation.
+
+    Args:
+        f: a Python function to apply element-wise over the first axis or axes of
+            ``xs``.
+        xs: values over which to map along the leading axis.
+        batch_size: (optional) integer specifying the size of the batch for each step to execute
+            in parallel.
+
+    Returns:
+        Mapped values.
+    """
+    if batch_size is not None:
+        scan_xs, remainder_xs = _batch_and_remainder(xs, batch_size)
+        g = lambda _, x: ((), vmap(f)(*x))
+        _, scan_ys = scan(g, (), scan_xs)
+        if remainder_xs is None:
+            ys = jax.tree.map(lambda x: flatten_(x), scan_ys)
+        else:
+            remainder_ys = vmap(f)(*remainder_xs)
+            ys = jax.tree.map(
+                lambda x, y: jax.lax.concatenate([flatten_(x), y], dimension=0),
+                scan_ys,
+                remainder_ys,
+            )
+    else:
+        g = lambda _, x: ((), f(*x))
+        _, ys = scan(g, (), xs)
+    return ys
+
+
+def flatten_(x):
+    return x.reshape(-1, *x.shape[2:])