brainstate 0.2.0__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +2 -4
- brainstate/_deprecation_test.py +2 -24
- brainstate/_state.py +540 -35
- brainstate/_state_test.py +1085 -8
- brainstate/graph/_operation.py +1 -5
- brainstate/mixin.py +14 -0
- brainstate/nn/__init__.py +42 -33
- brainstate/nn/_collective_ops.py +2 -0
- brainstate/nn/_common_test.py +0 -20
- brainstate/nn/_delay.py +1 -1
- brainstate/nn/_dropout_test.py +9 -6
- brainstate/nn/_dynamics.py +67 -464
- brainstate/nn/_dynamics_test.py +0 -14
- brainstate/nn/_embedding.py +7 -7
- brainstate/nn/_exp_euler.py +9 -9
- brainstate/nn/_linear.py +21 -21
- brainstate/nn/_module.py +25 -18
- brainstate/nn/_normalizations.py +27 -27
- brainstate/random/__init__.py +6 -6
- brainstate/random/{_rand_funs.py → _fun.py} +1 -1
- brainstate/random/{_rand_funs_test.py → _fun_test.py} +0 -2
- brainstate/random/_impl.py +672 -0
- brainstate/random/{_rand_seed.py → _seed.py} +1 -1
- brainstate/random/{_rand_state.py → _state.py} +121 -418
- brainstate/random/{_rand_state_test.py → _state_test.py} +7 -7
- brainstate/transform/__init__.py +6 -9
- brainstate/transform/_conditions.py +2 -2
- brainstate/transform/_find_state.py +200 -0
- brainstate/transform/_find_state_test.py +84 -0
- brainstate/transform/_make_jaxpr.py +221 -61
- brainstate/transform/_make_jaxpr_test.py +125 -1
- brainstate/transform/_mapping.py +287 -209
- brainstate/transform/_mapping_test.py +94 -184
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/METADATA +1 -1
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/RECORD +39 -39
- brainstate/transform/_eval_shape.py +0 -145
- brainstate/transform/_eval_shape_test.py +0 -38
- brainstate/transform/_random.py +0 -171
- /brainstate/random/{_rand_seed_test.py → _seed_test.py} +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
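The `brainstate/random` rows above are private-module renames (`_rand_funs.py → _fun.py`, `_rand_seed.py → _seed.py`, `_rand_state.py → _state.py`, plus a new `_impl.py`). A minimal sketch of what that implies for imports, under the assumption that the public `brainstate.random` namespace re-exports the same functions in both versions:

# Sketch only; assumes the public namespace is stable across 0.2.0 -> 0.2.2.
import brainstate

x = brainstate.random.randn(5)  # public API call, unaffected by the renames

# Deep imports of the renamed private modules would need updating:
#   0.2.0: from brainstate.random import _rand_funs
#   0.2.2: from brainstate.random import _fun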
brainstate/transform/_mapping.py
CHANGED
@@ -61,80 +61,93 @@ def vmap(
     # --- brainstate specific arguments --- #
     state_in_axes: Union[Dict[AxisName, Filter], Filter] = None,
     state_out_axes: Union[Dict[AxisName, Filter], Filter] = None,
+    unexpected_out_state_mapping: str = 'raise',
 ) -> StatefulMapping | Callable[[F], StatefulMapping]:
     """
-
+    Vectorize a callable while preserving BrainState state semantics.

-
-
-
-
-
-
-    These are several example usage::
-
-    >>> import brainstate as brainstate
-    >>> import jax.numpy as jnp
-
-    >>> class Model(brainstate.nn.Module):
-    >>>     def __init__(self):
-    >>>         super().__init__()
-    >>>
-    >>>         self.a = brainstate.ShortTermState(brainstate.random.randn(5))
-    >>>         self.b = brainstate.ShortTermState(brainstate.random.randn(5))
-    >>>         self.c = brainstate.State(brainstate.random.randn(1))
-
-    >>>     def __call__(self, *args, **kwargs):
-    >>>         self.c.value = self.a.value * self.b.value
-    >>>         return self.c.value + 1.
-
-    >>> model = Model()
-
-    >>> r = brainstate.transform.vmap(
-    >>>     model,
-    >>>     in_states=model.states(brainstate.ShortTermState),
-    >>>     out_states=model.c
-    >>> )()
+    This helper mirrors :func:`jax.vmap` but routes execution through
+    :class:`~brainstate.transform.StatefulMapping` so that reads and writes to
+    :class:`~brainstate.State` instances (including newly created random states)
+    are tracked correctly across the mapped axis. The returned object can be used
+    directly or as a decorator when ``fn`` is omitted.

     Parameters
     ----------
     fn : callable, optional
-        Function to be
-    in_axes : int
-
-
-    out_axes :
-
-
-        in the output.
+        Function to be vectorised. If omitted, the function acts as a decorator.
+    in_axes : int | None | sequence, default 0
+        Mapping specification for positional arguments, following the semantics
+        of :func:`jax.vmap`.
+    out_axes : any, default 0
+        Placement of the mapped axis in the result. Must broadcast with the
+        structure of the outputs.
     axis_name : hashable, optional
-
-
+        Name for the mapped axis so that collective primitives (e.g. ``lax.psum``)
+        can target it.
     axis_size : int, optional
-
-
-    spmd_axis_name : hashable or tuple
-
-
-
-
-
-
-
-
-
-
-
+        Explicit size of the mapped axis. If omitted, the size is inferred from
+        the arguments.
+    spmd_axis_name : hashable or tuple[hashable], optional
+        Axis labels used when the transformed function is itself executed inside
+        another SPMD transform (e.g. nested :func:`vmap` or :func:`pmap`).
+    state_in_axes : dict[AxisName, Filter] or Filter, optional
+        Filters identifying which :class:`State` objects should be batched on
+        input. Passing a single filter is shorthand for ``{0: filter}``. Filters
+        are converted with :func:`brainstate.util.filter.to_predicate`.
+    state_out_axes : dict[AxisName, Filter] or Filter, optional
+        Filters describing how written states are scattered back across the
+        mapped axis. Semantics mirror ``state_in_axes``.
+    unexpected_out_state_mapping : {'raise', 'warn', 'ignore'}, default 'raise'
+        Policy when a state is written during the mapped call but not matched by
+        ``state_out_axes``. ``'raise'`` propagates a :class:`BatchAxisError`,
+        ``'warn'`` emits a warning, and ``'ignore'`` silently accepts the state.

     Returns
     -------
-    callable
-
-
-
-
-
+    StatefulMapping or callable
+        If ``fn`` is supplied, returns a :class:`StatefulMapping` instance that
+        behaves like ``fn`` but with batch semantics. Otherwise a decorator is
+        returned.
+
+    Raises
+    ------
+    ValueError
+        If axis sizes are inconsistent or cannot be inferred.
+    BatchAxisError
+        If a state write violates ``state_out_axes`` and the policy is ``'raise'``.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate as bst
+        >>> import jax.numpy as jnp
+        >>> from brainstate.util.filter import OfType
+        >>>
+        >>> counter = bst.ShortTermState(jnp.array(0.0))
+        >>>
+        >>> @bst.transform.vmap(
+        ...     in_axes=0,
+        ...     out_axes=0,
+        ...     state_in_axes={0: OfType(bst.ShortTermState)},
+        ...     state_out_axes={0: OfType(bst.ShortTermState)},
+        ... )
+        ... def accumulate(x):
+        ...     counter.value = counter.value + x
+        ...     return counter.value
+        >>>
+        >>> xs = jnp.arange(3.0)
+        >>> accumulate(xs)
+        Array([0., 1., 3.], dtype=float32)
+        >>> counter.value
+        Array(3., dtype=float32)
+
+    See Also
+    --------
+    brainstate.transform.StatefulMapping : Underlying state-aware mapping helper.
+    pmap : Parallel mapping variant for multiple devices.
+    vmap_new_states : Vectorize newly created states within ``fn``.
     """

     if isinstance(fn, Missing):
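The hunk above adds the `unexpected_out_state_mapping` knob to `vmap`. A hedged sketch of how the three policies behave, inferred from the new docstring rather than from `StatefulMapping`'s source (the `hidden` state and `step` function are illustrative):

# Illustrative sketch; semantics taken from the docstring above.
import brainstate as bst
import jax.numpy as jnp

hidden = bst.ShortTermState(jnp.zeros(()))

@bst.transform.vmap(in_axes=0, unexpected_out_state_mapping='warn')
def step(x):
    hidden.value = hidden.value + x  # write not matched by any state_out_axes filter
    return 2.0 * x

step(jnp.arange(4.0))  # default 'raise' -> BatchAxisError; 'warn' -> warning; 'ignore' -> silent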
@@ -147,6 +160,7 @@ def vmap(
             axis_name=axis_name,
             axis_size=axis_size,
             spmd_axis_name=spmd_axis_name,
+            unexpected_out_state_mapping=unexpected_out_state_mapping,
         ) # type: ignore[return-value]

     return StatefulMapping(
@@ -157,7 +171,9 @@ def vmap(
         state_out_axes=state_out_axes,
         axis_name=axis_name,
         axis_size=axis_size,
-
+        unexpected_out_state_mapping=unexpected_out_state_mapping,
+        mapping_fn=functools.partial(jax.vmap, spmd_axis_name=spmd_axis_name),
+        name='vmap'
     )


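The second vmap hunk shows the new delegation style: `StatefulMapping` receives a `mapping_fn` built by partially applying `jax.vmap`, so one wrapper class can host both `vmap` and `pmap`. A standalone illustration of that partial-application pattern (not brainstate's internals):

# Pre-bind transform options once, then apply the mapper to any callable.
import functools
import jax
import jax.numpy as jnp

mapper = functools.partial(jax.vmap, in_axes=0, out_axes=0)
batched_square = mapper(lambda x: x * x)
print(batched_square(jnp.arange(3.0)))  # [0. 1. 4.]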
@@ -177,96 +193,97 @@ def pmap(
     # --- brainstate specific arguments --- #
     state_in_axes: Union[Dict[AxisName, Filter], Filter] = None,
     state_out_axes: Union[Dict[AxisName, Filter], Filter] = None,
+    unexpected_out_state_mapping: str = 'raise',
 ) -> Callable[[F], F] | F:
     """
-    Parallel
-
-
-
-
-
-
-    over array axes, but where :py:func:`vmap` vectorizes functions by pushing the
-    mapped axis down into primitive operations, :py:func:`pmap` instead replicates
-    the function and executes each replica on its own XLA device in parallel.
-
-    The mapped axis size must be less than or equal to the number of local XLA
-    devices available, as returned by :py:func:`jax.local_device_count()` (unless
-    ``devices`` is specified, see below). For nested :py:func:`pmap` calls, the
-    product of the mapped axis sizes must be less than or equal to the number of
-    XLA devices.
-
-    More information please see `jax.vmap <https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html>`__.
-
-
-    Args:
-      fn: Function to be mapped over argument axes. Its arguments and return
-        value should be arrays, scalars, or (nested) standard Python containers
-        (tuple/list/dict) thereof. Positional arguments indicated by
-        ``static_broadcasted_argnums`` can be anything at all, provided they are
-        hashable and have an equality operation defined.
-      axis_name: Optional, a hashable Python object used to identify the mapped
-        axis so that parallel collectives can be applied.
-      in_axes: A non-negative integer, None, or nested Python container thereof
-        that specifies which axes of positional arguments to map over. Arguments
-        passed as keywords are always mapped over their leading axis (i.e. axis
-        index 0). See :py:func:`vmap` for details.
-      out_axes: A non-negative integer, None, or nested Python container thereof
-        indicating where the mapped axis should appear in the output. All outputs
-        with a mapped axis must have a non-None ``out_axes`` specification
-        (see :py:func:`vmap`).
-      static_broadcasted_argnums: An int or collection of ints specifying which
-        positional arguments to treat as static (compile-time constant).
-        Operations that only depend on static arguments will be constant-folded.
-        Calling the pmapped function with different values for these constants
-        will trigger recompilation. If the pmapped function is called with fewer
-        positional arguments than indicated by ``static_broadcasted_argnums`` then
-        an error is raised. Each of the static arguments will be broadcasted to
-        all devices. Arguments that are not arrays or containers thereof must be
-        marked as static. Defaults to ().
-
-        Static arguments must be hashable, meaning both ``__hash__`` and
-        ``__eq__`` are implemented, and should be immutable.
-
-      devices: This is an experimental feature and the API is likely to change.
-        Optional, a sequence of Devices to map over. (Available devices can be
-        retrieved via jax.devices()). Must be given identically for each process
-        in multi-process settings (and will therefore include devices across
-        processes). If specified, the size of the mapped axis must be equal to
-        the number of devices in the sequence local to the given process. Nested
-        :py:func:`pmap` s with ``devices`` specified in either the inner or outer
-        :py:func:`pmap` are not yet supported.
-      backend: This is an experimental feature and the API is likely to change.
-        Optional, a string representing the XLA backend. 'cpu', 'gpu', or 'tpu'.
-      axis_size: Optional; the size of the mapped axis.
-      donate_argnums: Specify which positional argument buffers are "donated" to
-        the computation. It is safe to donate argument buffers if you no longer need
-        them once the computation has finished. In some cases XLA can make use of
-        donated buffers to reduce the amount of memory needed to perform a
-        computation, for example recycling one of your input buffers to store a
-        result. You should not reuse buffers that you donate to a computation, JAX
-        will raise an error if you try to.
-        Note that donate_argnums only work for positional arguments, and keyword
-        arguments will not be donated.
-
-        For more details on buffer donation see the
-        `FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
-      global_arg_shapes: Optional; a tuple of tuples of integers representing the
-        shapes of the global arguments. These are arguments that are not replicated
-        across devices, but are broadcasted to all devices. The tuple should have
-        the same length as the number of global arguments, and each inner tuple
-        should have the same length as the corresponding argument. The shapes of
-        the global arguments must be the same on all devices.
-      rngs: Optional, a random number generator or sequence of random number
-        generators to be used in the mapped function. These random number
-        generators are restored their random key after the mapped function is
-        executed.
-
-    Returns:
-      A parallelized version of ``fun`` with arguments that correspond to those of
-      ``fun`` but with extra array axes at positions indicated by ``in_axes`` and
-      with output that has an additional leading array axis (with the same size).
+    Parallel mapping with state-aware semantics across devices.
+
+    This function mirrors :func:`jax.pmap` but integrates with
+    :class:`~brainstate.transform.StatefulMapping` so that
+    :class:`~brainstate.State` objects (including random states) are replicated
+    and restored correctly on every device. When ``fn`` is omitted the function
+    can be used as a decorator.

+    Parameters
+    ----------
+    fn : callable, optional
+        Function to execute in SPMD style. If omitted, a decorator is returned.
+    axis_name : hashable, optional
+        Name for the mapped axis used by collective primitives.
+    in_axes : any, default 0
+        Axis mapping for positional arguments, identical to :func:`jax.pmap`.
+    out_axes : any, default 0
+        Placement of the mapped axis in the outputs.
+    static_broadcasted_argnums : int or iterable[int], default ()
+        Indices of positional arguments to treat as compile-time constants.
+    devices : sequence[Device], optional
+        Explicit device list to map over. Must be identical on every host in
+        multi-host setups.
+    backend : str, optional
+        Backend identifier (``'cpu'``, ``'gpu'``, or ``'tpu'``).
+    axis_size : int, optional
+        Size of the mapped axis. Defaults to ``len(devices)`` or the local device
+        count when ``devices`` is ``None``.
+    donate_argnums : int or iterable[int], default ()
+        Positional arguments whose buffers may be donated to the computation.
+    global_arg_shapes : tuple[tuple[int, ...], ...], optional
+        Shapes for globally distributed arguments (i.e. arguments not replicated
+        across devices).
+    state_in_axes : dict[AxisName, Filter] or Filter, optional
+        Filters indicating which states should be treated as device-mapped inputs.
+    state_out_axes : dict[AxisName, Filter] or Filter, optional
+        Filters describing how state writes are scattered back to devices.
+    unexpected_out_state_mapping : {'raise', 'warn', 'ignore'}, default 'raise'
+        Policy applied when a state write is not covered by ``state_out_axes``.
+    rngs : Any, optional
+        Optional RNG seeds passed through to ``fn``. They are restored to their
+        original values after execution.
+
+    Returns
+    -------
+    StatefulMapping or callable
+        If ``fn`` is provided, returns a :class:`StatefulMapping` executing ``fn``
+        over devices. Otherwise returns a decorator that produces such an object.
+
+    Raises
+    ------
+    ValueError
+        If ``axis_size`` or argument shapes are inconsistent.
+    BatchAxisError
+        If an unexpected state write occurs and the policy is ``'raise'``.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate as bst
+        >>> import jax.numpy as jnp
+        >>> from brainstate.util.filter import OfType
+        >>>
+        >>> weights = bst.ParamState(jnp.ones((4,)))
+        >>>
+        >>> @bst.transform.pmap(
+        ...     axis_name='devices',
+        ...     in_axes=0,
+        ...     out_axes=0,
+        ...     state_in_axes={0: OfType(bst.ParamState)},
+        ...     state_out_axes={0: OfType(bst.ParamState)},
+        ... )
+        ... def update(delta):
+        ...     weights.value = weights.value + delta
+        ...     return weights.value
+        >>>
+        >>> deltas = jnp.arange(jax.local_device_count() * 4.).reshape(
+        ...     jax.local_device_count(), 4
+        ... )
+        >>> updated = update(deltas)
+        >>> updated.shape
+        (jax.local_device_count(), 4)
+
+    See Also
+    --------
+    jax.pmap : Underlying JAX primitive.
+    vmap : Single-host vectorisation with the same state semantics.
     """

     if isinstance(fn, Missing):
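Note that the doctest added above calls `jax.local_device_count()` while importing only `jax.numpy`; a self-contained variant of the same example (assuming at least one local device) would read:

# Self-contained rewrite of the docstring example; `import jax` is required
# for jax.local_device_count().
import jax
import jax.numpy as jnp
import brainstate as bst
from brainstate.util.filter import OfType

n_dev = jax.local_device_count()
weights = bst.ParamState(jnp.ones((4,)))

@bst.transform.pmap(
    axis_name='devices',
    state_in_axes={0: OfType(bst.ParamState)},
    state_out_axes={0: OfType(bst.ParamState)},
)
def update(delta):
    weights.value = weights.value + delta
    return weights.value

deltas = jnp.arange(n_dev * 4.0).reshape(n_dev, 4)
print(update(deltas).shape)  # (n_dev, 4)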
@@ -281,6 +298,7 @@ def pmap(
             axis_size=axis_size,
             donate_argnums=donate_argnums,
             global_arg_shapes=global_arg_shapes,
+            unexpected_out_state_mapping=unexpected_out_state_mapping,
         ) # type: ignore[return-value]

     return StatefulMapping(
@@ -299,6 +317,8 @@ def pmap(
             donate_argnums=donate_argnums,
             global_arg_shapes=global_arg_shapes,
         ),
+        unexpected_out_state_mapping=unexpected_out_state_mapping,
+        name='pmap'
     )


@@ -337,53 +357,56 @@ def map(
     batch_size: int | None = None,
 ):
     """
-
+    Apply a Python function over the leading axis of one or more pytrees.

-
-
-
-
+    Compared with :func:`jax.vmap`, this helper executes sequentially by default
+    (via :func:`jax.lax.scan`), making it useful when auto-vectorisation is
+    impractical or when memory usage must be reduced. Providing ``batch_size``
+    enables chunked evaluation that internally leverages :func:`vmap` to improve
+    throughput while keeping peak memory bounded.

-
-
+    Parameters
+    ----------
+    f : callable
+        Function applied element-wise across the leading dimension. Its return
+        value must be a pytree whose leaves can be stacked along axis ``0``.
+    *xs : Any
+        Positional pytrees sharing the same length along their leading axis.
+    batch_size : int, optional
+        Size of vectorised blocks. When given, ``map`` first processes full
+        batches using :func:`vmap` then handles any remainder sequentially.

-
-
+    Returns
+    -------
+    Any
+        PyTree matching the structure of ``f``'s outputs with results stacked
+        along the leading dimension.

-
-
-
-
+    Raises
+    ------
+    ValueError
+        If the inputs do not share the same leading length.

-
-
-
-        divisible by the batch size, the remainder is processed in a separate ``vmap`` and
-        concatenated to the result.
+    Examples
+    --------
+    .. code-block:: python

         >>> import jax.numpy as jnp
-        >>>
-        >>>
-
-
-        >>>
-
-
-        >>>
-
-
-
-
-
-
-
-        ``xs``.
-      xs: values over which to map along the leading axis.
-      batch_size: (optional) integer specifying the size of the batch for each step to execute
-        in parallel.
-
-      Returns:
-        Mapped values.
+        >>> from brainstate.transform import map
+        >>>
+        >>> xs = jnp.arange(6).reshape(6, 1)
+        >>>
+        >>> def normalize(row):
+        ...     return row / (1.0 + jnp.linalg.norm(row))
+        >>>
+        >>> stacked = map(normalize, xs, batch_size=2)
+        >>> stacked.shape
+        (6, 1)
+
+    See Also
+    --------
+    vmap : Vectorised mapping with automatic batching.
+    jax.lax.scan : Primitive used for the sequential fallback.
     """
     if batch_size is not None:
         scan_xs, remainder_xs = _batch_and_remainder(xs, batch_size)
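The context line `scan_xs, remainder_xs = _batch_and_remainder(xs, batch_size)` shows the strategy the new docstring describes: full batches are scanned with a `vmap` inside, the leftover rows go through one extra `vmap`, and the two results are concatenated. A self-contained sketch of that idea in plain JAX (`chunked_map` is illustrative, not brainstate's private `_batch_and_remainder`):

# Sketch of chunked evaluation: scan over full blocks, vmap the remainder.
import jax
import jax.numpy as jnp

def chunked_map(f, xs, batch_size):
    n = xs.shape[0]
    n_full = (n // batch_size) * batch_size
    # Full batches as [num_blocks, batch_size, ...]; vmap(f) runs per block.
    blocks = xs[:n_full].reshape(n // batch_size, batch_size, *xs.shape[1:])
    _, ys = jax.lax.scan(lambda carry, block: (carry, jax.vmap(f)(block)), None, blocks)
    ys = ys.reshape(n_full, *ys.shape[2:])
    if n_full < n:
        # Remainder handled by a separate vmap, then concatenated.
        ys = jnp.concatenate([ys, jax.vmap(f)(xs[n_full:])], axis=0)
    return ys

print(chunked_map(lambda r: r * 2.0, jnp.arange(6.0), batch_size=4))  # shape (6,)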
@@ -422,6 +445,7 @@ def _vmap_new_states_transform(
     state_to_exclude: Filter | None = None,
     state_in_axes: Union[Dict[AxisName, Filter], Filter] = None,
     state_out_axes: Union[Dict[AxisName, Filter], Filter] = None,
+    unexpected_out_state_mapping: str = 'raise',
 ):
     # TODO: How about nested call ``vmap_new_states``?
     if isinstance(axis_size, int) and axis_size <= 0:
@@ -435,6 +459,7 @@ def _vmap_new_states_transform(
         spmd_axis_name=spmd_axis_name,
         state_in_axes=state_in_axes,
         state_out_axes=state_out_axes,
+        unexpected_out_state_mapping=unexpected_out_state_mapping,
     )
     def new_fun(args):
         # call the function
@@ -480,26 +505,78 @@ def vmap_new_states(
     state_to_exclude: Filter = None,
     state_in_axes: Union[Dict[AxisName, Filter], Filter] = None,
     state_out_axes: Union[Dict[AxisName, Filter], Filter] = None,
+    unexpected_out_state_mapping: str = 'raise',
 ):
     """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    Vectorise a function that creates new BrainState states on the fly.
+
+    The helper wraps :func:`vmap` but also captures states instantiated inside
+    ``fun`` via :func:`brainstate._state.catch_new_states`. Newly created states
+    are materialised for each batch element and restored after execution so that
+    their side effects persist exactly once. When ``fun`` is omitted the helper
+    can be used as a decorator.
+
+    Parameters
+    ----------
+    fun : callable, optional
+        Function to transform. If omitted, :func:`vmap_new_states` returns a
+        decorator expecting ``fun``.
+    in_axes : int | None | sequence, default 0
+        Mapping specification for positional arguments, following
+        :func:`jax.vmap` semantics.
+    out_axes : any, default 0
+        Placement of the mapped axis in the outputs.
+    axis_name : hashable, optional
+        Name of the mapped axis for collective primitives.
+    axis_size : int, optional
+        Explicit size of the mapped axis. Must be positive when provided.
+    spmd_axis_name : hashable or tuple[hashable], optional
+        Axis labels used when nesting inside other SPMD transforms.
+    state_tag : str, optional
+        Tag used to limit which newly created states are tracked.
+    state_to_exclude : Filter, optional
+        Filter describing states that should *not* participate in the mapping.
+    state_in_axes : dict[AxisName, Filter] or Filter, optional
+        Filters indicating which existing states are batched on input.
+    state_out_axes : dict[AxisName, Filter] or Filter, optional
+        Filters describing how written states are scattered over the mapped axis.
+    unexpected_out_state_mapping : {'raise', 'warn', 'ignore'}, default 'raise'
+        Behaviour when a state write is not covered by ``state_out_axes``.
+
+    Returns
+    -------
+    callable
+        A function with vectorised semantics that also mirrors new state
+        creation across the mapped axis.
+
+    Raises
+    ------
+    ValueError
+        If ``axis_size`` is provided and is not strictly positive.
+    BatchAxisError
+        If unexpected state writes occur and the policy is ``'raise'``.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate as bst
+        >>> import jax.numpy as jnp
+        >>> from brainstate.transform import vmap_new_states
+        >>>
+        >>> @vmap_new_states(in_axes=0, out_axes=0)
+        ... def forward(x):
+        ...     scratch = bst.ShortTermState(jnp.array(0.0), tag='scratch')
+        ...     scratch.value = scratch.value + x
+        ...     return scratch.value
+        >>>
+        >>> forward(jnp.arange(3.0))
+        Array([0., 1., 2.], dtype=float32)
+
+    See Also
+    --------
+    vmap : State-aware vectorisation for existing states.
+    catch_new_states : Context manager used internally to intercept state creation.
     """
     if isinstance(fun, Missing):
         return functools.partial(
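Beyond the doctest above, a typical use of `vmap_new_states` is batched initialisation: states created inside the call come back carrying the mapped axis. A hedged sketch, with the resulting shape inferred from the docstring's "materialised for each batch element" wording rather than verified:

# Hedged sketch; the (8, 4) shape is an inference, not a verified result.
import brainstate as bst

created = {}

@bst.transform.vmap_new_states(axis_size=8)
def init():
    # A state created inside the mapped call is intercepted by catch_new_states.
    created['w'] = bst.ShortTermState(bst.random.randn(4))

init()
# created['w'].value.shape should now be (8, 4) along the new mapped axis.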
@@ -513,6 +590,7 @@ def vmap_new_states(
             state_to_exclude=state_to_exclude,
             state_in_axes=state_in_axes,
             state_out_axes=state_out_axes,
+            unexpected_out_state_mapping=unexpected_out_state_mapping,
         )
     else:
         return _vmap_new_states_transform(
@@ -525,5 +603,5 @@ def vmap_new_states(
             state_tag=state_tag,
             state_to_exclude=state_to_exclude,
             state_in_axes=state_in_axes,
-
+            unexpected_out_state_mapping=unexpected_out_state_mapping,
         )