brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. brainstate/__init__.py +167 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2297 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +2157 -1652
  8. brainstate/_state_test.py +1129 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1620 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1447 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +146 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +635 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +134 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +480 -477
  32. brainstate/nn/_dynamics.py +870 -1267
  33. brainstate/nn/_dynamics_test.py +53 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +391 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
  64. brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
  65. brainstate/random/_impl.py +672 -0
  66. brainstate/random/{_rand_seed.py → _seed.py} +675 -675
  67. brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
  68. brainstate/random/{_rand_state.py → _state.py} +1320 -1617
  69. brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
  70. brainstate/transform/__init__.py +56 -59
  71. brainstate/transform/_ad_checkpoint.py +176 -176
  72. brainstate/transform/_ad_checkpoint_test.py +49 -49
  73. brainstate/transform/_autograd.py +1025 -1025
  74. brainstate/transform/_autograd_test.py +1289 -1289
  75. brainstate/transform/_conditions.py +316 -316
  76. brainstate/transform/_conditions_test.py +220 -220
  77. brainstate/transform/_error_if.py +94 -94
  78. brainstate/transform/_error_if_test.py +52 -52
  79. brainstate/transform/_find_state.py +200 -0
  80. brainstate/transform/_find_state_test.py +84 -0
  81. brainstate/transform/_jit.py +399 -399
  82. brainstate/transform/_jit_test.py +143 -143
  83. brainstate/transform/_loop_collect_return.py +675 -675
  84. brainstate/transform/_loop_collect_return_test.py +58 -58
  85. brainstate/transform/_loop_no_collection.py +283 -283
  86. brainstate/transform/_loop_no_collection_test.py +50 -50
  87. brainstate/transform/_make_jaxpr.py +2176 -2016
  88. brainstate/transform/_make_jaxpr_test.py +1634 -1510
  89. brainstate/transform/_mapping.py +607 -529
  90. brainstate/transform/_mapping_test.py +104 -194
  91. brainstate/transform/_progress_bar.py +255 -255
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
  108. brainstate-0.2.2.dist-info/RECORD +111 -0
  109. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/transform/_eval_shape.py +0 -145
  111. brainstate/transform/_eval_shape_test.py +0 -38
  112. brainstate/transform/_random.py +0 -171
  113. brainstate-0.2.1.dist-info/RECORD +0 -111
  114. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  115. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
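Three structural changes stand out in the list above: the private `brainstate.random` modules were renamed (`_rand_funs.py → _fun.py`, `_rand_seed.py → _seed.py`, `_rand_state.py → _state.py`, with a new `_impl.py`), the `transform` package gained `_find_state.py` and dropped `_eval_shape.py` and `_random.py`, and `transform/_mapping.py` was rewritten (the diff below). A minimal compatibility sketch, assuming the public `brainstate.random` namespace still re-exports the same functions after the renames; only the file moves are visible in this diff, and `randn` is taken from the docstrings shown below:

    # Hypothetical check after upgrading 0.2.1 -> 0.2.2.
    # Public namespace access should be unaffected by the private-module renames:
    import brainstate

    x = brainstate.random.randn(5)  # backed by _rand_funs.py in 0.2.1, _fun.py in 0.2.2

    # Deep imports of the old private modules break after the rename:
    # from brainstate.random import _rand_funs   # ModuleNotFoundError in 0.2.2
    from brainstate.random import _fun           # the renamed module in 0.2.2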
@@ -1,529 +1,607 @@
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- import functools
- from typing import (
-     Any,
-     TypeVar,
-     Callable,
-     Hashable,
-     Sequence,
-     Iterable,
-     Tuple,
-     Union,
-     Optional,
-     Dict
- )
-
- import jax
-
- from brainstate._compatible_import import Device
- from brainstate._state import catch_new_states
- from brainstate._utils import set_module_as
- from brainstate.typing import Missing, Filter
- from brainstate.util import NestedDict
- from ._loop_collect_return import scan
- from ._make_jaxpr import StatefulMapping
-
- __all__ = [
-     'vmap',
-     'pmap',
-     'map',
-     'vmap_new_states',
- ]
-
- F = TypeVar("F", bound=Callable)
- AxisName = Hashable
-
-
- @set_module_as('brainstate.transform')
- def vmap(
-     fn: F | Missing = Missing(),
-     *,
-     # --- normal jax.vmap arguments --- #
-     in_axes: int | None | Sequence[Any] = 0,
-     out_axes: Any = 0,
-     axis_name: AxisName | None = None,
-     axis_size: int | None = None,
-     spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
-     # --- brainstate specific arguments --- #
-     state_in_axes: Union[Dict[AxisName, Filter], Filter] = None,
-     state_out_axes: Union[Dict[AxisName, Filter], Filter] = None,
- ) -> StatefulMapping | Callable[[F], StatefulMapping]:
-     """
-     Vectorizing map. Creates a function which maps ``fun`` over argument axes.
-
-     The transformation :func:`vmap` is designed to work with ``pygraph`` structure
-     defined in the ``brainstate`` library. It is used to vectorize functions by
-     pushing the mapped axis down into primitive operations.
-
-     More information please see `jax.vmap <https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html>`__.
-
-     These are several example usage::
-
-         >>> import brainstate as brainstate
-         >>> import jax.numpy as jnp
-
-         >>> class Model(brainstate.nn.Module):
-         >>>     def __init__(self):
-         >>>         super().__init__()
-         >>>
-         >>>         self.a = brainstate.ShortTermState(brainstate.random.randn(5))
-         >>>         self.b = brainstate.ShortTermState(brainstate.random.randn(5))
-         >>>         self.c = brainstate.State(brainstate.random.randn(1))
-
-         >>>     def __call__(self, *args, **kwargs):
-         >>>         self.c.value = self.a.value * self.b.value
-         >>>         return self.c.value + 1.
-
-         >>> model = Model()
-
-         >>> r = brainstate.transform.vmap(
-         >>>     model,
-         >>>     in_states=model.states(brainstate.ShortTermState),
-         >>>     out_states=model.c
-         >>> )()
-
-     Parameters
-     ----------
-     fn : callable, optional
-         Function to be mapped over additional axes.
-     in_axes : int, None, or sequence, default 0
-         An integer, None, or sequence of values specifying which input
-         array axes to map over.
-     out_axes : int, None, or sequence, default 0
-         An integer, None, or (nested) standard Python container
-         (tuple/list/dict) thereof indicating where the mapped axis should appear
-         in the output.
-     axis_name : hashable, optional
-         A hashable Python object used to identify the mapped
-         axis so that parallel collectives can be applied.
-     axis_size : int, optional
-         An integer indicating the size of the axis to be
-         mapped. If not provided, the mapped axis size is inferred from arguments.
-     spmd_axis_name : hashable or tuple of hashable, optional
-         A hashable Python object or tuple of hashable
-         Python objects used to identify the mapped axis so that parallel collectives
-         can be applied. This is used to specify multiple axes to be mapped over
-         in a nested :func:`vmap` call. The length of the tuple must match the
-         number of nested :func:`vmap` calls. The first element of the tuple
-         corresponds to the outermost :func:`vmap` call, the second element to
-         the next outermost, and so on. If the tuple is not provided, the
-         ``axis_name`` is used for all nested :func:`vmap` calls.
-     in_states : dict or State objects, optional
-         The :class:`State` objects to be mapped over in the inputs.
-     out_states : dict or State objects, optional
-         The :class:`State` objects to be mapped over in the outputs.
-
-     Returns
-     -------
-     callable
-         Batched/vectorized version of ``fun`` with arguments that correspond to
-         those of ``fun``, but with extra array axes at positions indicated by
-         ``in_axes``, and a return value that corresponds to that of ``fun``, but
-         with extra array axes at positions indicated by ``out_axes``.
-
-     """
-
-     if isinstance(fn, Missing):
-         return functools.partial(
-             vmap,
-             in_axes=in_axes,
-             out_axes=out_axes,
-             state_in_axes=state_in_axes,
-             state_out_axes=state_out_axes,
-             axis_name=axis_name,
-             axis_size=axis_size,
-             spmd_axis_name=spmd_axis_name,
-         )  # type: ignore[return-value]
-
-     return StatefulMapping(
-         fn,
-         in_axes=in_axes,
-         out_axes=out_axes,
-         state_in_axes=state_in_axes,
-         state_out_axes=state_out_axes,
-         axis_name=axis_name,
-         axis_size=axis_size,
-         mapping_fn=functools.partial(jax.vmap, spmd_axis_name=spmd_axis_name)
-     )
-
-
- @set_module_as('brainstate.transform')
- def pmap(
-     fn: Callable[[NestedDict, ...], Any] | Missing = Missing(),
-     axis_name: Optional[AxisName] = None,
-     *,
-     in_axes: Any = 0,
-     out_axes: Any = 0,
-     static_broadcasted_argnums: int | Iterable[int] = (),
-     devices: Optional[Sequence[Device]] = None,  # noqa: F811
-     backend: Optional[str] = None,
-     axis_size: Optional[int] = None,
-     donate_argnums: int | Iterable[int] = (),
-     global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
-     # --- brainstate specific arguments --- #
-     state_in_axes: Union[Dict[AxisName, Filter], Filter] = None,
-     state_out_axes: Union[Dict[AxisName, Filter], Filter] = None,
- ) -> Callable[[F], F] | F:
-     """
-     Parallel map with support for collective operations.
-
-     The purpose of :py:func:`pmap` is to express single-program multiple-data
-     (SPMD) programs. Applying :py:func:`pmap` to a function will compile the
-     function with XLA (similarly to :py:func:`jit`), then execute it in parallel
-     on XLA devices, such as multiple GPUs or multiple TPU cores. Semantically it
-     is comparable to :py:func:`vmap` because both transformations map a function
-     over array axes, but where :py:func:`vmap` vectorizes functions by pushing the
-     mapped axis down into primitive operations, :py:func:`pmap` instead replicates
-     the function and executes each replica on its own XLA device in parallel.
-
-     The mapped axis size must be less than or equal to the number of local XLA
-     devices available, as returned by :py:func:`jax.local_device_count()` (unless
-     ``devices`` is specified, see below). For nested :py:func:`pmap` calls, the
-     product of the mapped axis sizes must be less than or equal to the number of
-     XLA devices.
-
-     More information please see `jax.vmap <https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html>`__.
-
-
-     Args:
-         fn: Function to be mapped over argument axes. Its arguments and return
-             value should be arrays, scalars, or (nested) standard Python containers
-             (tuple/list/dict) thereof. Positional arguments indicated by
-             ``static_broadcasted_argnums`` can be anything at all, provided they are
-             hashable and have an equality operation defined.
-         axis_name: Optional, a hashable Python object used to identify the mapped
-             axis so that parallel collectives can be applied.
-         in_axes: A non-negative integer, None, or nested Python container thereof
-             that specifies which axes of positional arguments to map over. Arguments
-             passed as keywords are always mapped over their leading axis (i.e. axis
-             index 0). See :py:func:`vmap` for details.
-         out_axes: A non-negative integer, None, or nested Python container thereof
-             indicating where the mapped axis should appear in the output. All outputs
-             with a mapped axis must have a non-None ``out_axes`` specification
-             (see :py:func:`vmap`).
-         static_broadcasted_argnums: An int or collection of ints specifying which
-             positional arguments to treat as static (compile-time constant).
-             Operations that only depend on static arguments will be constant-folded.
-             Calling the pmapped function with different values for these constants
-             will trigger recompilation. If the pmapped function is called with fewer
-             positional arguments than indicated by ``static_broadcasted_argnums`` then
-             an error is raised. Each of the static arguments will be broadcasted to
-             all devices. Arguments that are not arrays or containers thereof must be
-             marked as static. Defaults to ().
-
-             Static arguments must be hashable, meaning both ``__hash__`` and
-             ``__eq__`` are implemented, and should be immutable.
-
-         devices: This is an experimental feature and the API is likely to change.
-             Optional, a sequence of Devices to map over. (Available devices can be
-             retrieved via jax.devices()). Must be given identically for each process
-             in multi-process settings (and will therefore include devices across
-             processes). If specified, the size of the mapped axis must be equal to
-             the number of devices in the sequence local to the given process. Nested
-             :py:func:`pmap` s with ``devices`` specified in either the inner or outer
-             :py:func:`pmap` are not yet supported.
-         backend: This is an experimental feature and the API is likely to change.
-             Optional, a string representing the XLA backend. 'cpu', 'gpu', or 'tpu'.
-         axis_size: Optional; the size of the mapped axis.
-         donate_argnums: Specify which positional argument buffers are "donated" to
-             the computation. It is safe to donate argument buffers if you no longer need
-             them once the computation has finished. In some cases XLA can make use of
-             donated buffers to reduce the amount of memory needed to perform a
-             computation, for example recycling one of your input buffers to store a
-             result. You should not reuse buffers that you donate to a computation, JAX
-             will raise an error if you try to.
-             Note that donate_argnums only work for positional arguments, and keyword
-             arguments will not be donated.
-
-             For more details on buffer donation see the
-             `FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
-         global_arg_shapes: Optional; a tuple of tuples of integers representing the
-             shapes of the global arguments. These are arguments that are not replicated
-             across devices, but are broadcasted to all devices. The tuple should have
-             the same length as the number of global arguments, and each inner tuple
-             should have the same length as the corresponding argument. The shapes of
-             the global arguments must be the same on all devices.
-         rngs: Optional, a random number generator or sequence of random number
-             generators to be used in the mapped function. These random number
-             generators are restored their random key after the mapped function is
-             executed.
-
-     Returns:
-         A parallelized version of ``fun`` with arguments that correspond to those of
-         ``fun`` but with extra array axes at positions indicated by ``in_axes`` and
-         with output that has an additional leading array axis (with the same size).
-
-     """
-
-     if isinstance(fn, Missing):
-         return functools.partial(
-             pmap,
-             axis_name=axis_name,
-             in_axes=in_axes,
-             out_axes=out_axes,
-             static_broadcasted_argnums=static_broadcasted_argnums,
-             devices=devices,
-             backend=backend,
-             axis_size=axis_size,
-             donate_argnums=donate_argnums,
-             global_arg_shapes=global_arg_shapes,
-         )  # type: ignore[return-value]
-
-     return StatefulMapping(
-         fn,
-         in_axes=in_axes,
-         out_axes=out_axes,
-         state_in_axes=state_in_axes,
-         state_out_axes=state_out_axes,
-         axis_name=axis_name,
-         axis_size=axis_size,
-         mapping_fn=functools.partial(
-             jax.pmap,
-             static_broadcasted_argnums=static_broadcasted_argnums,
-             devices=devices,
-             backend=backend,
-             donate_argnums=donate_argnums,
-             global_arg_shapes=global_arg_shapes,
-         ),
-     )
-
-
- def _batch_and_remainder(x, batch_size: int):
-     leaves, tree_def = jax.tree.flatten(x)
-
-     scan_leaves = []
-     remainder_leaves = []
-
-     length = None
-     for leaf in leaves:
-         if length is None:
-             length = leaf.shape[0]
-         if length != leaf.shape[0]:
-             raise ValueError(f"All inputs must have the same length. Got {length} and {leaf.shape[0]}.")
-
-     num_batches, num_remainder = divmod(length, batch_size)
-     for leaf in leaves:
-         total_batch_elems = num_batches * batch_size
-         scan_leaves.append(leaf[:total_batch_elems].reshape(num_batches, batch_size, *leaf.shape[1:]))
-         if num_remainder:
-             remainder_leaves.append(leaf[total_batch_elems:])
-
-     scan_tree = tree_def.unflatten(scan_leaves)
-     if num_remainder:
-         remainder_tree = tree_def.unflatten(remainder_leaves)
-         return scan_tree, remainder_tree
-     else:
-         return scan_tree, None
-
-
- @set_module_as('brainstate.transform')
- def map(
-     f,
-     *xs,
-     batch_size: int | None = None,
- ):
-     """
-     Map a function over leading array axes.
-
-     Like Python's builtin map, except inputs and outputs are in the form of
-     stacked arrays. Consider using the :func:`~jax.vmap` transform instead, unless you
-     need to apply a function element by element for reduced memory usage or
-     heterogeneous computation with other control flow primitives.
-
-     When ``xs`` is an array type, the semantics of :func:`~map` are given by this
-     Python implementation::
-
-         def map(f, *xs):
-             return np.stack([f(*x) for x in xs])
-
-     Like :func:`~scan`, :func:`~map` is implemented in terms of JAX primitives so
-     many of the same advantages over a Python loop apply: ``xs`` may be an
-     arbitrary nested pytree type, and the mapped computation is compiled only
-     once.
-
-     If ``batch_size`` is provided, the computation is executed in batches of that size
-     and parallelized using :func:`~jax.vmap`. This can be used as either a more performant
-     version of ``map`` or as a memory-efficient version of ``vmap``. If the axis is not
-     divisible by the batch size, the remainder is processed in a separate ``vmap`` and
-     concatenated to the result.
-
-         >>> import jax.numpy as jnp
-         >>> x = jnp.ones((10, 3, 4))
-         >>> def f(x):
-         ...     print('inner shape:', x.shape)
-         ...     return x + 1
-         >>> y = map(f, x, batch_size=3)
-         inner shape: (3, 4)
-         inner shape: (3, 4)
-         >>> y.shape
-         (10, 3, 4)
-
-     In the example above, "inner shape" is printed twice, once while tracing the batched
-     computation and once while tracing the remainder computation.
-
-     Args:
-         f: a Python function to apply element-wise over the first axis or axes of
-             ``xs``.
-         xs: values over which to map along the leading axis.
-         batch_size: (optional) integer specifying the size of the batch for each step to execute
-             in parallel.
-
-     Returns:
-         Mapped values.
-     """
-     if batch_size is not None:
-         scan_xs, remainder_xs = _batch_and_remainder(xs, batch_size)
-         g = lambda _, x: ((), vmap(f)(*x))
-         _, scan_ys = scan(g, (), scan_xs)
-         if remainder_xs is None:
-             ys = jax.tree.map(lambda x: _flatten(x), scan_ys)
-         else:
-             remainder_ys = vmap(f)(*remainder_xs)
-             ys = jax.tree.map(
-                 lambda x, y: jax.lax.concatenate([_flatten(x), y], dimension=0),
-                 scan_ys,
-                 remainder_ys,
-             )
-     else:
-         g = lambda _, x: ((), f(*x))
-         _, ys = scan(g, (), xs)
-     return ys
-
-
- def _flatten(x):
-     return x.reshape(-1, *x.shape[2:])
-
-
- def _vmap_new_states_transform(
-     fun: Callable[..., Any],
-     *,
-     # -- normal jax.vmap arguments -- #
-     in_axes: int | None | Sequence[Any] = 0,
-     out_axes: Any = 0,
-     axis_name: AxisName | None = None,
-     axis_size: int | None = None,
-     spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
-     # -- brainstate specific arguments -- #
-     state_tag: str | None = None,
-     state_to_exclude: Filter | None = None,
-     state_in_axes: Union[Dict[AxisName, Filter], Filter] = None,
-     state_out_axes: Union[Dict[AxisName, Filter], Filter] = None,
- ):
-     # TODO: How about nested call ``vmap_new_states``?
-     if isinstance(axis_size, int) and axis_size <= 0:
-         raise ValueError(f"axis_size must be greater than 0, got {axis_size}.")
-
-     @vmap(
-         in_axes=in_axes,
-         out_axes=out_axes,
-         axis_name=axis_name,
-         axis_size=axis_size,
-         spmd_axis_name=spmd_axis_name,
-         state_in_axes=state_in_axes,
-         state_out_axes=state_out_axes,
-     )
-     def new_fun(args):
-         # call the function
-         with catch_new_states(state_tag=state_tag, state_to_exclude=state_to_exclude) as catcher:
-             out = fun(*args)
-
-         # get vmap state values
-         vmap_state_vals = catcher.get_state_values()
-
-         return out, vmap_state_vals
-
-     @functools.wraps(fun)
-     def vmapped_fn(*args):
-         # vmapping
-         with catch_new_states(state_to_exclude=state_to_exclude) as catcher:
-             outs, vmap_state_vals = new_fun(args)
-             vmap_states = catcher.get_states()
-
-         # restore vmapped state values
-         for st_val, st in zip(vmap_state_vals, vmap_states):
-             st.restore_value(st_val)
-             # ------------------------------------------------
-             # --- this is CRUCIAL to avoid jax tracing leakage
-             # ------------------------------------------------
-             st.decrease_stack_level()
-         return outs
-
-     return vmapped_fn
-
-
- @set_module_as('brainstate.transform')
- def vmap_new_states(
-     fun: Callable = Missing(),
-     *,
-     # -- normal jax.vmap arguments -- #
-     in_axes: int | None | Sequence[Any] = 0,
-     out_axes: Any = 0,
-     axis_name: AxisName | None = None,
-     axis_size: int | None = None,
-     spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
-     # -- brainstate specific arguments -- #
-     state_tag: str | None = None,
-     state_to_exclude: Filter = None,
-     state_in_axes: Union[Dict[AxisName, Filter], Filter] = None,
-     state_out_axes: Union[Dict[AxisName, Filter], Filter] = None,
- ):
-     """
-     Vectorize a function over new states created within it.
-
-     This function applies JAX's vmap transformation to newly created states
-     during the function's execution. It allows for more
-     flexible vectorization in the context of stateful computations.
-
-     Args:
-         fun (Callable, optional): The function to be vectorized. Defaults to Missing().
-         in_axes (int | None | Sequence[Any], optional): Specification of input axes for vectorization. Defaults to 0.
-         out_axes (Any, optional): Specification of output axes after vectorization. Defaults to 0.
-         axis_name (AxisName, optional): Name of the axis being vectorized over. Defaults to None.
-         axis_size (int, optional): Size of the axis being vectorized over. Defaults to None.
-         spmd_axis_name (AxisName | tuple[AxisName, ...], optional): Name(s) of SPMD axis/axes. Defaults to None.
-         state_tag (str, optional): A tag to identify specific states. Defaults to None.
-         state_to_exclude (Sequence[int], optional): Indices of states to exclude from vectorization. Defaults to ().
-
-     Returns:
-         Callable: A vectorized version of the input function that handles new state creation.
-     """
-     if isinstance(fun, Missing):
-         return functools.partial(
-             _vmap_new_states_transform,
-             in_axes=in_axes,
-             out_axes=out_axes,
-             axis_name=axis_name,
-             axis_size=axis_size,
-             spmd_axis_name=spmd_axis_name,
-             state_tag=state_tag,
-             state_to_exclude=state_to_exclude,
-             state_in_axes=state_in_axes,
-             state_out_axes=state_out_axes,
-         )
-     else:
-         return _vmap_new_states_transform(
-             fun,
-             in_axes=in_axes,
-             out_axes=out_axes,
-             axis_name=axis_name,
-             axis_size=axis_size,
-             spmd_axis_name=spmd_axis_name,
-             state_tag=state_tag,
-             state_to_exclude=state_to_exclude,
-             state_in_axes=state_in_axes,
-             state_out_axes=state_out_axes,
-         )
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import functools
+ from typing import (
+     Any,
+     TypeVar,
+     Callable,
+     Hashable,
+     Sequence,
+     Iterable,
+     Tuple,
+     Union,
+     Optional,
+     Dict
+ )
+
+ import jax
+
+ from brainstate._compatible_import import Device
+ from brainstate._state import catch_new_states
+ from brainstate._utils import set_module_as
+ from brainstate.typing import Missing, Filter
+ from brainstate.util import NestedDict
+ from ._loop_collect_return import scan
+ from ._make_jaxpr import StatefulMapping
+
+ __all__ = [
+     'vmap',
+     'pmap',
+     'map',
+     'vmap_new_states',
+ ]
+
+ F = TypeVar("F", bound=Callable)
+ AxisName = Hashable
+
+
+ @set_module_as('brainstate.transform')
+ def vmap(
+     fn: F | Missing = Missing(),
+     *,
+     # --- normal jax.vmap arguments --- #
+     in_axes: int | None | Sequence[Any] = 0,
+     out_axes: Any = 0,
+     axis_name: AxisName | None = None,
+     axis_size: int | None = None,
+     spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
+     # --- brainstate specific arguments --- #
+     state_in_axes: Union[Dict[AxisName, Filter], Filter] = None,
+     state_out_axes: Union[Dict[AxisName, Filter], Filter] = None,
+     unexpected_out_state_mapping: str = 'raise',
+ ) -> StatefulMapping | Callable[[F], StatefulMapping]:
+     """
+     Vectorize a callable while preserving BrainState state semantics.
+
+     This helper mirrors :func:`jax.vmap` but routes execution through
+     :class:`~brainstate.transform.StatefulMapping` so that reads and writes to
+     :class:`~brainstate.State` instances (including newly created random states)
+     are tracked correctly across the mapped axis. The returned object can be used
+     directly or as a decorator when ``fn`` is omitted.
+
+     Parameters
+     ----------
+     fn : callable, optional
+         Function to be vectorised. If omitted, the function acts as a decorator.
+     in_axes : int | None | sequence, default 0
+         Mapping specification for positional arguments, following the semantics
+         of :func:`jax.vmap`.
+     out_axes : any, default 0
+         Placement of the mapped axis in the result. Must broadcast with the
+         structure of the outputs.
+     axis_name : hashable, optional
+         Name for the mapped axis so that collective primitives (e.g. ``lax.psum``)
+         can target it.
+     axis_size : int, optional
+         Explicit size of the mapped axis. If omitted, the size is inferred from
+         the arguments.
+     spmd_axis_name : hashable or tuple[hashable], optional
+         Axis labels used when the transformed function is itself executed inside
+         another SPMD transform (e.g. nested :func:`vmap` or :func:`pmap`).
+     state_in_axes : dict[AxisName, Filter] or Filter, optional
+         Filters identifying which :class:`State` objects should be batched on
+         input. Passing a single filter is shorthand for ``{0: filter}``. Filters
+         are converted with :func:`brainstate.util.filter.to_predicate`.
+     state_out_axes : dict[AxisName, Filter] or Filter, optional
+         Filters describing how written states are scattered back across the
+         mapped axis. Semantics mirror ``state_in_axes``.
+     unexpected_out_state_mapping : {'raise', 'warn', 'ignore'}, default 'raise'
+         Policy when a state is written during the mapped call but not matched by
+         ``state_out_axes``. ``'raise'`` propagates a :class:`BatchAxisError`,
+         ``'warn'`` emits a warning, and ``'ignore'`` silently accepts the state.
+
+     Returns
+     -------
+     StatefulMapping or callable
+         If ``fn`` is supplied, returns a :class:`StatefulMapping` instance that
+         behaves like ``fn`` but with batch semantics. Otherwise a decorator is
+         returned.
+
+     Raises
+     ------
+     ValueError
+         If axis sizes are inconsistent or cannot be inferred.
+     BatchAxisError
+         If a state write violates ``state_out_axes`` and the policy is ``'raise'``.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate as bst
+         >>> import jax.numpy as jnp
+         >>> from brainstate.util.filter import OfType
+         >>>
+         >>> counter = bst.ShortTermState(jnp.array(0.0))
+         >>>
+         >>> @bst.transform.vmap(
+         ...     in_axes=0,
+         ...     out_axes=0,
+         ...     state_in_axes={0: OfType(bst.ShortTermState)},
+         ...     state_out_axes={0: OfType(bst.ShortTermState)},
+         ... )
+         ... def accumulate(x):
+         ...     counter.value = counter.value + x
+         ...     return counter.value
+         >>>
+         >>> xs = jnp.arange(3.0)
+         >>> accumulate(xs)
+         Array([0., 1., 3.], dtype=float32)
+         >>> counter.value
+         Array(3., dtype=float32)
+
+     See Also
+     --------
+     brainstate.transform.StatefulMapping : Underlying state-aware mapping helper.
+     pmap : Parallel mapping variant for multiple devices.
+     vmap_new_states : Vectorize newly created states within ``fn``.
+     """
+
+     if isinstance(fn, Missing):
+         return functools.partial(
+             vmap,
+             in_axes=in_axes,
+             out_axes=out_axes,
+             state_in_axes=state_in_axes,
+             state_out_axes=state_out_axes,
+             axis_name=axis_name,
+             axis_size=axis_size,
+             spmd_axis_name=spmd_axis_name,
+             unexpected_out_state_mapping=unexpected_out_state_mapping,
+         )  # type: ignore[return-value]
+
+     return StatefulMapping(
+         fn,
+         in_axes=in_axes,
+         out_axes=out_axes,
+         state_in_axes=state_in_axes,
+         state_out_axes=state_out_axes,
+         axis_name=axis_name,
+         axis_size=axis_size,
+         unexpected_out_state_mapping=unexpected_out_state_mapping,
+         mapping_fn=functools.partial(jax.vmap, spmd_axis_name=spmd_axis_name),
+         name='vmap'
+     )
+
+
+ @set_module_as('brainstate.transform')
+ def pmap(
+     fn: Callable[[NestedDict, ...], Any] | Missing = Missing(),
+     axis_name: Optional[AxisName] = None,
+     *,
+     in_axes: Any = 0,
+     out_axes: Any = 0,
+     static_broadcasted_argnums: int | Iterable[int] = (),
+     devices: Optional[Sequence[Device]] = None,  # noqa: F811
+     backend: Optional[str] = None,
+     axis_size: Optional[int] = None,
+     donate_argnums: int | Iterable[int] = (),
+     global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
+     # --- brainstate specific arguments --- #
+     state_in_axes: Union[Dict[AxisName, Filter], Filter] = None,
+     state_out_axes: Union[Dict[AxisName, Filter], Filter] = None,
+     unexpected_out_state_mapping: str = 'raise',
+ ) -> Callable[[F], F] | F:
+     """
+     Parallel mapping with state-aware semantics across devices.
+
+     This function mirrors :func:`jax.pmap` but integrates with
+     :class:`~brainstate.transform.StatefulMapping` so that
+     :class:`~brainstate.State` objects (including random states) are replicated
+     and restored correctly on every device. When ``fn`` is omitted the function
+     can be used as a decorator.
+
+     Parameters
+     ----------
+     fn : callable, optional
+         Function to execute in SPMD style. If omitted, a decorator is returned.
+     axis_name : hashable, optional
+         Name for the mapped axis used by collective primitives.
+     in_axes : any, default 0
+         Axis mapping for positional arguments, identical to :func:`jax.pmap`.
+     out_axes : any, default 0
+         Placement of the mapped axis in the outputs.
+     static_broadcasted_argnums : int or iterable[int], default ()
+         Indices of positional arguments to treat as compile-time constants.
+     devices : sequence[Device], optional
+         Explicit device list to map over. Must be identical on every host in
+         multi-host setups.
+     backend : str, optional
+         Backend identifier (``'cpu'``, ``'gpu'``, or ``'tpu'``).
+     axis_size : int, optional
+         Size of the mapped axis. Defaults to ``len(devices)`` or the local device
+         count when ``devices`` is ``None``.
+     donate_argnums : int or iterable[int], default ()
+         Positional arguments whose buffers may be donated to the computation.
+     global_arg_shapes : tuple[tuple[int, ...], ...], optional
+         Shapes for globally distributed arguments (i.e. arguments not replicated
+         across devices).
+     state_in_axes : dict[AxisName, Filter] or Filter, optional
+         Filters indicating which states should be treated as device-mapped inputs.
+     state_out_axes : dict[AxisName, Filter] or Filter, optional
+         Filters describing how state writes are scattered back to devices.
+     unexpected_out_state_mapping : {'raise', 'warn', 'ignore'}, default 'raise'
+         Policy applied when a state write is not covered by ``state_out_axes``.
+
+     Returns
+     -------
+     StatefulMapping or callable
+         If ``fn`` is provided, returns a :class:`StatefulMapping` executing ``fn``
+         over devices. Otherwise returns a decorator that produces such an object.
+
+     Raises
+     ------
+     ValueError
+         If ``axis_size`` or argument shapes are inconsistent.
+     BatchAxisError
+         If an unexpected state write occurs and the policy is ``'raise'``.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate as bst
+         >>> import jax
+         >>> import jax.numpy as jnp
+         >>> from brainstate.util.filter import OfType
+         >>>
+         >>> weights = bst.ParamState(jnp.ones((4,)))
+         >>>
+         >>> @bst.transform.pmap(
+         ...     axis_name='devices',
+         ...     in_axes=0,
+         ...     out_axes=0,
+         ...     state_in_axes={0: OfType(bst.ParamState)},
+         ...     state_out_axes={0: OfType(bst.ParamState)},
+         ... )
+         ... def update(delta):
+         ...     weights.value = weights.value + delta
+         ...     return weights.value
+         >>>
+         >>> deltas = jnp.arange(jax.local_device_count() * 4.).reshape(
+         ...     jax.local_device_count(), 4
+         ... )
+         >>> updated = update(deltas)
+         >>> updated.shape  # (jax.local_device_count(), 4)
+
+     See Also
+     --------
+     jax.pmap : Underlying JAX primitive.
+     vmap : Single-host vectorisation with the same state semantics.
+     """
+
+     if isinstance(fn, Missing):
+         return functools.partial(
+             pmap,
+             axis_name=axis_name,
+             in_axes=in_axes,
+             out_axes=out_axes,
+             static_broadcasted_argnums=static_broadcasted_argnums,
+             devices=devices,
+             backend=backend,
+             axis_size=axis_size,
+             donate_argnums=donate_argnums,
+             global_arg_shapes=global_arg_shapes,
+             unexpected_out_state_mapping=unexpected_out_state_mapping,
+         )  # type: ignore[return-value]
+
+     return StatefulMapping(
+         fn,
+         in_axes=in_axes,
+         out_axes=out_axes,
+         state_in_axes=state_in_axes,
+         state_out_axes=state_out_axes,
+         axis_name=axis_name,
+         axis_size=axis_size,
+         mapping_fn=functools.partial(
+             jax.pmap,
+             static_broadcasted_argnums=static_broadcasted_argnums,
+             devices=devices,
+             backend=backend,
+             donate_argnums=donate_argnums,
+             global_arg_shapes=global_arg_shapes,
+         ),
+         unexpected_out_state_mapping=unexpected_out_state_mapping,
+         name='pmap'
+     )
+
+
+ def _batch_and_remainder(x, batch_size: int):
+     leaves, tree_def = jax.tree.flatten(x)
+
+     scan_leaves = []
+     remainder_leaves = []
+
+     length = None
+     for leaf in leaves:
+         if length is None:
+             length = leaf.shape[0]
+         if length != leaf.shape[0]:
+             raise ValueError(f"All inputs must have the same length. Got {length} and {leaf.shape[0]}.")
+
+     num_batches, num_remainder = divmod(length, batch_size)
+     for leaf in leaves:
+         total_batch_elems = num_batches * batch_size
+         scan_leaves.append(leaf[:total_batch_elems].reshape(num_batches, batch_size, *leaf.shape[1:]))
+         if num_remainder:
+             remainder_leaves.append(leaf[total_batch_elems:])
+
+     scan_tree = tree_def.unflatten(scan_leaves)
+     if num_remainder:
+         remainder_tree = tree_def.unflatten(remainder_leaves)
+         return scan_tree, remainder_tree
+     else:
+         return scan_tree, None
+
+
+ @set_module_as('brainstate.transform')
+ def map(
+     f,
+     *xs,
+     batch_size: int | None = None,
+ ):
+     """
+     Apply a Python function over the leading axis of one or more pytrees.
+
+     Compared with :func:`jax.vmap`, this helper executes sequentially by default
+     (via :func:`jax.lax.scan`), making it useful when auto-vectorisation is
+     impractical or when memory usage must be reduced. Providing ``batch_size``
+     enables chunked evaluation that internally leverages :func:`vmap` to improve
+     throughput while keeping peak memory bounded.
+
+     Parameters
+     ----------
+     f : callable
+         Function applied element-wise across the leading dimension. Its return
+         value must be a pytree whose leaves can be stacked along axis ``0``.
+     *xs : Any
+         Positional pytrees sharing the same length along their leading axis.
+     batch_size : int, optional
+         Size of vectorised blocks. When given, ``map`` first processes full
+         batches using :func:`vmap`, then handles any remainder in a separate
+         ``vmap`` call and concatenates the results.
+
+     Returns
+     -------
+     Any
+         PyTree matching the structure of ``f``'s outputs with results stacked
+         along the leading dimension.
+
+     Raises
+     ------
+     ValueError
+         If the inputs do not share the same leading length.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import jax.numpy as jnp
+         >>> from brainstate.transform import map
+         >>>
+         >>> xs = jnp.arange(6).reshape(6, 1)
+         >>>
+         >>> def normalize(row):
+         ...     return row / (1.0 + jnp.linalg.norm(row))
+         >>>
+         >>> stacked = map(normalize, xs, batch_size=2)
+         >>> stacked.shape
+         (6, 1)
+
+     See Also
+     --------
+     vmap : Vectorised mapping with automatic batching.
+     jax.lax.scan : Primitive used for the sequential fallback.
+     """
+     if batch_size is not None:
+         scan_xs, remainder_xs = _batch_and_remainder(xs, batch_size)
+         g = lambda _, x: ((), vmap(f)(*x))
+         _, scan_ys = scan(g, (), scan_xs)
+         if remainder_xs is None:
+             ys = jax.tree.map(lambda x: _flatten(x), scan_ys)
+         else:
+             remainder_ys = vmap(f)(*remainder_xs)
+             ys = jax.tree.map(
+                 lambda x, y: jax.lax.concatenate([_flatten(x), y], dimension=0),
+                 scan_ys,
+                 remainder_ys,
+             )
+     else:
+         g = lambda _, x: ((), f(*x))
+         _, ys = scan(g, (), xs)
+     return ys
+
+
+ def _flatten(x):
+     return x.reshape(-1, *x.shape[2:])
+
+
+ def _vmap_new_states_transform(
+     fun: Callable[..., Any],
+     *,
+     # -- normal jax.vmap arguments -- #
+     in_axes: int | None | Sequence[Any] = 0,
+     out_axes: Any = 0,
+     axis_name: AxisName | None = None,
+     axis_size: int | None = None,
+     spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
+     # -- brainstate specific arguments -- #
+     state_tag: str | None = None,
+     state_to_exclude: Filter | None = None,
+     state_in_axes: Union[Dict[AxisName, Filter], Filter] = None,
+     state_out_axes: Union[Dict[AxisName, Filter], Filter] = None,
+     unexpected_out_state_mapping: str = 'raise',
+ ):
+     # TODO: How about nested call ``vmap_new_states``?
+     if isinstance(axis_size, int) and axis_size <= 0:
+         raise ValueError(f"axis_size must be greater than 0, got {axis_size}.")
+
+     @vmap(
+         in_axes=in_axes,
+         out_axes=out_axes,
+         axis_name=axis_name,
+         axis_size=axis_size,
+         spmd_axis_name=spmd_axis_name,
+         state_in_axes=state_in_axes,
+         state_out_axes=state_out_axes,
+         unexpected_out_state_mapping=unexpected_out_state_mapping,
+     )
+     def new_fun(args):
+         # call the function
+         with catch_new_states(state_tag=state_tag, state_to_exclude=state_to_exclude) as catcher:
+             out = fun(*args)
+
+         # get vmap state values
+         vmap_state_vals = catcher.get_state_values()
+
+         return out, vmap_state_vals
+
+     @functools.wraps(fun)
+     def vmapped_fn(*args):
+         # vmapping
+         with catch_new_states(state_to_exclude=state_to_exclude) as catcher:
+             outs, vmap_state_vals = new_fun(args)
+             vmap_states = catcher.get_states()
+
+         # restore vmapped state values
+         for st_val, st in zip(vmap_state_vals, vmap_states):
+             st.restore_value(st_val)
+             # ------------------------------------------------
+             # --- this is CRUCIAL to avoid jax tracing leakage
+             # ------------------------------------------------
+             st.decrease_stack_level()
+         return outs
+
+     return vmapped_fn
+
+
+ @set_module_as('brainstate.transform')
+ def vmap_new_states(
+     fun: Callable = Missing(),
+     *,
+     # -- normal jax.vmap arguments -- #
+     in_axes: int | None | Sequence[Any] = 0,
+     out_axes: Any = 0,
+     axis_name: AxisName | None = None,
+     axis_size: int | None = None,
+     spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
+     # -- brainstate specific arguments -- #
+     state_tag: str | None = None,
+     state_to_exclude: Filter = None,
+     state_in_axes: Union[Dict[AxisName, Filter], Filter] = None,
+     state_out_axes: Union[Dict[AxisName, Filter], Filter] = None,
+     unexpected_out_state_mapping: str = 'raise',
+ ):
+     """
+     Vectorise a function that creates new BrainState states on the fly.
+
+     The helper wraps :func:`vmap` but also captures states instantiated inside
+     ``fun`` via :func:`brainstate._state.catch_new_states`. Newly created states
+     are materialised for each batch element and restored after execution so that
+     their side effects persist exactly once. When ``fun`` is omitted the helper
+     can be used as a decorator.
+
+     Parameters
+     ----------
+     fun : callable, optional
+         Function to transform. If omitted, :func:`vmap_new_states` returns a
+         decorator expecting ``fun``.
+     in_axes : int | None | sequence, default 0
+         Mapping specification for positional arguments, following
+         :func:`jax.vmap` semantics.
+     out_axes : any, default 0
+         Placement of the mapped axis in the outputs.
+     axis_name : hashable, optional
+         Name of the mapped axis for collective primitives.
+     axis_size : int, optional
+         Explicit size of the mapped axis. Must be positive when provided.
+     spmd_axis_name : hashable or tuple[hashable], optional
+         Axis labels used when nesting inside other SPMD transforms.
+     state_tag : str, optional
+         Tag used to limit which newly created states are tracked.
+     state_to_exclude : Filter, optional
+         Filter describing states that should *not* participate in the mapping.
+     state_in_axes : dict[AxisName, Filter] or Filter, optional
+         Filters indicating which existing states are batched on input.
+     state_out_axes : dict[AxisName, Filter] or Filter, optional
+         Filters describing how written states are scattered over the mapped axis.
+     unexpected_out_state_mapping : {'raise', 'warn', 'ignore'}, default 'raise'
+         Behaviour when a state write is not covered by ``state_out_axes``.
+
+     Returns
+     -------
+     callable
+         A function with vectorised semantics that also mirrors new state
+         creation across the mapped axis.
+
+     Raises
+     ------
+     ValueError
+         If ``axis_size`` is provided and is not strictly positive.
+     BatchAxisError
+         If unexpected state writes occur and the policy is ``'raise'``.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate as bst
+         >>> import jax.numpy as jnp
+         >>> from brainstate.transform import vmap_new_states
+         >>>
+         >>> @vmap_new_states(in_axes=0, out_axes=0)
+         ... def forward(x):
+         ...     scratch = bst.ShortTermState(jnp.array(0.0), tag='scratch')
+         ...     scratch.value = scratch.value + x
+         ...     return scratch.value
+         >>>
+         >>> forward(jnp.arange(3.0))
+         Array([0., 1., 2.], dtype=float32)
+
+     See Also
+     --------
+     vmap : State-aware vectorisation for existing states.
+     catch_new_states : Context manager used internally to intercept state creation.
+     """
+     if isinstance(fun, Missing):
+         return functools.partial(
+             _vmap_new_states_transform,
+             in_axes=in_axes,
+             out_axes=out_axes,
+             axis_name=axis_name,
+             axis_size=axis_size,
+             spmd_axis_name=spmd_axis_name,
+             state_tag=state_tag,
+             state_to_exclude=state_to_exclude,
+             state_in_axes=state_in_axes,
+             state_out_axes=state_out_axes,
+             unexpected_out_state_mapping=unexpected_out_state_mapping,
+         )
+     else:
+         return _vmap_new_states_transform(
+             fun,
+             in_axes=in_axes,
+             out_axes=out_axes,
+             axis_name=axis_name,
+             axis_size=axis_size,
+             spmd_axis_name=spmd_axis_name,
+             state_tag=state_tag,
+             state_to_exclude=state_to_exclude,
+             state_in_axes=state_in_axes,
+             state_out_axes=state_out_axes,
+             unexpected_out_state_mapping=unexpected_out_state_mapping,
+         )
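The rewritten `_mapping.py` threads one new keyword, `unexpected_out_state_mapping`, through `vmap`, `pmap`, and `vmap_new_states`. A minimal sketch of the documented behaviour, assuming the `'warn'` policy downgrades the `BatchAxisError` described in the docstrings to a warning; the `hidden` state and `step` function here are illustrative, not from the package:

    import brainstate as bst
    import jax.numpy as jnp

    hidden = bst.ShortTermState(jnp.zeros(()))

    @bst.transform.vmap(
        in_axes=0,
        # No state_out_axes filter matches writes to `hidden`, so the write
        # below is "unexpected"; 'warn' reports it instead of raising.
        unexpected_out_state_mapping='warn',
    )
    def step(x):
        hidden.value = hidden.value + x   # the unmatched state write
        return 2.0 * x

    ys = step(jnp.arange(3.0))   # warns about the unmatched write to `hidden`

With the default `'raise'` policy the same call would propagate a `BatchAxisError`; `'ignore'` accepts the write silently.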