brainstate 0.1.0.post20250105__py2.py3-none-any.whl → 0.1.0.post20250126__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. brainstate/__init__.py +1 -2
  2. brainstate/_state.py +77 -44
  3. brainstate/_state_test.py +0 -17
  4. brainstate/augment/__init__.py +10 -20
  5. brainstate/augment/_eval_shape.py +9 -10
  6. brainstate/augment/_eval_shape_test.py +1 -1
  7. brainstate/augment/_mapping.py +265 -277
  8. brainstate/augment/_mapping_test.py +147 -175
  9. brainstate/compile/__init__.py +18 -37
  10. brainstate/compile/_ad_checkpoint.py +6 -4
  11. brainstate/compile/_jit.py +37 -28
  12. brainstate/compile/_loop_collect_return.py +6 -3
  13. brainstate/compile/_loop_no_collection.py +2 -0
  14. brainstate/compile/_make_jaxpr.py +15 -4
  15. brainstate/compile/_make_jaxpr_test.py +10 -6
  16. brainstate/compile/_progress_bar.py +68 -40
  17. brainstate/compile/_unvmap.py +9 -6
  18. brainstate/graph/__init__.py +12 -16
  19. brainstate/graph/_graph_node.py +1 -23
  20. brainstate/graph/_graph_operation.py +1 -1
  21. brainstate/graph/_graph_operation_test.py +0 -159
  22. brainstate/nn/_dyn_impl/_inputs.py +124 -39
  23. brainstate/nn/_elementwise/_dropout_test.py +1 -1
  24. brainstate/nn/_interaction/_conv.py +4 -2
  25. brainstate/nn/_interaction/_linear.py +84 -10
  26. brainstate/random/_rand_funs.py +9 -2
  27. brainstate/random/_rand_seed.py +12 -2
  28. brainstate/random/_rand_state.py +50 -179
  29. brainstate/surrogate.py +5 -1
  30. brainstate/util/__init__.py +0 -4
  31. brainstate/util/_caller.py +1 -1
  32. brainstate/util/_dict.py +4 -1
  33. brainstate/util/_filter.py +1 -1
  34. brainstate/util/_pretty_repr.py +1 -1
  35. brainstate/util/_struct.py +1 -1
  36. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/METADATA +2 -1
  37. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/RECORD +40 -60
  38. brainstate/event/__init__.py +0 -29
  39. brainstate/event/_csr.py +0 -906
  40. brainstate/event/_csr_mv.py +0 -303
  41. brainstate/event/_csr_mv_benchmark.py +0 -14
  42. brainstate/event/_csr_mv_test.py +0 -118
  43. brainstate/event/_csr_test.py +0 -90
  44. brainstate/event/_fixedprob_mv.py +0 -730
  45. brainstate/event/_fixedprob_mv_benchmark.py +0 -128
  46. brainstate/event/_fixedprob_mv_test.py +0 -132
  47. brainstate/event/_linear_mv.py +0 -359
  48. brainstate/event/_linear_mv_benckmark.py +0 -82
  49. brainstate/event/_linear_mv_test.py +0 -117
  50. brainstate/event/_misc.py +0 -34
  51. brainstate/event/_xla_custom_op.py +0 -313
  52. brainstate/event/_xla_custom_op_test.py +0 -55
  53. brainstate/graph/_graph_context.py +0 -443
  54. brainstate/graph/_graph_context_test.py +0 -65
  55. brainstate/graph/_graph_convert.py +0 -246
  56. brainstate/util/_tracers.py +0 -68
  57. brainstate/util/_visualization.py +0 -47
  58. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/LICENSE +0 -0
  59. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/WHEEL +0 -0
  60. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/top_level.txt +0 -0
@@ -15,152 +15,152 @@

  from __future__ import annotations

- import dataclasses
  import functools
- from typing import Any, TypeVar, Callable, Hashable, Sequence, Iterable, Mapping, Tuple, Union, Optional
+ from typing import Any, TypeVar, Callable, Hashable, Sequence, Iterable, Tuple, Union, Optional, Dict, List

  import jax
+ from jax.interpreters.batching import BatchTracer

- from brainstate.graph import (NodeStates, graph_to_tree, tree_to_graph, update_context)
- from brainstate.graph._graph_convert import clear_non_graph_nodes
+ from brainstate._state import State, StateTraceStack
+ from brainstate.compile._loop_collect_return import scan
  from brainstate.random import DEFAULT, RandomState
- from brainstate.typing import Missing, Filter
- from brainstate.util import NestedDict
+ from brainstate.typing import Missing
+ from brainstate.util import NestedDict, BrainStateError
  from ._random import restore_rngs

  __all__ = [
- 'StateAxes',
  'vmap',
  'pmap',
+ 'map',
  ]

  AxisName = Hashable
  F = TypeVar("F", bound=Callable)
- Index = int
- Carry = TypeVar("Carry")
+ AxisToState = Dict[int, List[State]]
+ StateToAxis = Dict[State, int]


- class StateAxes:
- """
- A class to represent the axes of a state.
-
- This class is used to control how graph nodes like Modules are vectorized or
- parallelized by specifying the axes to be applied to substates of the graph
- node given a Filter.
-
- Args:
- filter_axes: A mapping from filters to axes. The axes can be an index, a carry or None.
-
- """
-
- def __init__(
- self,
- filter_axes: Union[Mapping[Filter, Index | Carry | None], Iterable[Tuple[Filter, Index | Carry | None]]],
- ):
- iterable = filter_axes.items() if isinstance(filter_axes, Mapping) else filter_axes
- self._filters = tuple(filter_ for filter_, _ in iterable)
- self._axes = tuple(axis for _, axis in iterable)
-
- @property
- def filters(self) -> Tuple[Filter, ...]:
- return self._filters
-
- @property
- def axes(self) -> Tuple[Index | Carry | None, ...]:
- return self._axes
-
- def __repr__(self):
- return f'StateAxes({dict(self.items())})'
-
- def items(self):
- return zip(self.filters, self.axes)
-
- def __eq__(self, other):
- return isinstance(other, StateAxes) and self.filters == other.filters and self.axes == other.axes
-
- def __hash__(self):
- return hash((self.filters, self.axes))
+ class BatchAxisError(BrainStateError):
+ pass


- def _map_split_fn(ctx, path, prefix, x):
- if isinstance(prefix, StateAxes):
- return NodeStates.from_split(*ctx.treefy_split(x, *prefix.filters), metadata=prefix)
- return NodeStates.from_split(*ctx.treefy_split(x), metadata=prefix)
-
-
- @dataclasses.dataclass(eq=False)
- class MapFn:
- f: Callable[..., Any]
- in_axes: Any
- out_axes: Any
- ctxtag: str
-
- def __post_init__(self):
- functools.update_wrapper(self, self.f)
-
- def __call__(self, *pure_args: Tuple[Any, ...]):
- # pytree to graph
- args = tree_to_graph(pure_args, ctxtag=self.ctxtag)
-
- # call the function
- out = self.f(*args)
-
- # graph to pytree
- args_out = clear_non_graph_nodes(args)
- pure_args_out, pure_out = graph_to_tree(
- (args_out, out),
- prefix=(self.in_axes, self.out_axes),
- split_fn=_map_split_fn,
- ctxtag=self.ctxtag,
+ def _flatten_in_out_states(
+ in_states: Dict[int, Dict] | Any = None,
+ ) -> Tuple[AxisToState, StateToAxis]:
+ if in_states is None:
+ return dict(), dict()
+ if isinstance(in_states, dict):
+ keys = tuple(in_states.keys())
+ values = tuple(in_states.values())
+ is_axis_in_states = (
+ all([isinstance(key, int) for key in keys]) and
+ all([isinstance(value, dict) for value in values])
  )
- return pure_args_out, pure_out
-
-
- def _map_transform(
- ctxtag,
- transform,
+ else:
+ is_axis_in_states = False
+ if is_axis_in_states:
+ axis_to_states = {key: list(value.values()) for key, value in in_states.items()}
+ state_to_axis = {}
+ for key, value in in_states.items():
+ for state in value.values():
+ state_to_axis[state] = key
+ return axis_to_states, state_to_axis
+ else:
+ in_states = jax.tree.leaves(in_states)
+ axis_to_states = {0: list(in_states)}
+ state_to_axis = {state: 0 for state in in_states}
+ return axis_to_states, state_to_axis
+
+
+ def _vmap_transform(
  f: F,
  *,
- in_axes: Optional[int | Sequence[Any]] = 0,
+ in_axes: int | None | Sequence[Any] = 0,
  out_axes: Any = 0,
+ in_states: Dict[int, Dict] | Any | None = None,
+ out_states: Dict[int, Dict] | Any | None = None,
  rngs: Union[RandomState, Sequence[RandomState]] = DEFAULT,
  **transform_kwargs,
  ):
- # jax in axes
- jax_in_axes = jax.tree.map(
- lambda x: NodeStates.from_prefixes(x.axes, metadata=x) if isinstance(x, StateAxes) else x,
- in_axes,
- )
-
- # jax out axes
- jax_out_axes = jax.tree.map(
- lambda x: NodeStates.from_prefixes(x.axes, metadata=x) if isinstance(x, StateAxes) else x,
- out_axes,
- )
-
- # mapped function
- mapped_fn = transform(
- MapFn(f, in_axes, out_axes, ctxtag),
- in_axes=jax_in_axes,
- out_axes=(jax_in_axes, jax_out_axes),
- **transform_kwargs
- )
+ axis_to_in_states, in_state_to_axis = _flatten_in_out_states(in_states)
+ axis_to_out_states, out_state_to_axis = _flatten_in_out_states(out_states)
+ for _in_state, _axis in in_state_to_axis.items():
+ if _in_state in out_state_to_axis:
+ _out_axis = out_state_to_axis[_in_state]
+ if _out_axis != _axis:
+ _in_state.raise_error_with_source_info(
+ BatchAxisError(
+ f"State {_in_state} has been mapped to axis {_axis} in in_states, "
+ f"However, it is mapped to axis {_out_axis} in out_states."
+ )
+ )
+ else:
+ out_state_to_axis[_in_state] = _axis
+ if _axis not in axis_to_out_states:
+ axis_to_out_states[_axis] = []
+ axis_to_out_states[_axis].append(_in_state)
+ if isinstance(rngs, RandomState):
+ rngs = (rngs,)
+ rng_ids = set([id(rng) for rng in rngs])

  @functools.wraps(f)
- @restore_rngs(rngs=rngs) # restore the random key of default random number generator
- @update_context(ctxtag)
- def map_wrapper(*args):
- # graph to pytree
- pure_args = graph_to_tree(args, prefix=in_axes, split_fn=_map_split_fn, ctxtag=ctxtag)
-
- # vmap with pytree
- pure_args_out, pure_out = mapped_fn(*pure_args)
-
- # pytree to graph
- _args_out, out = tree_to_graph((pure_args_out, pure_out), ctxtag=ctxtag)
- return out
+ def new_fn(in_states_, args):
+ # restore state values
+ for i, states in enumerate(axis_to_in_states.values()):
+ for state, state_val in zip(states, in_states_[i]):
+ state.restore_value(state_val)

- return map_wrapper # type: ignore
+ # call the function
+ with StateTraceStack() as stack:
+ outs = f(*args)
+
+ # analyze
+ for state in stack.get_write_states():
+ leaves = jax.tree.leaves(state.value)
+ if isinstance(leaves[0], BatchTracer) and state not in out_state_to_axis:
+ if isinstance(state, RandomState) and id(state) in rng_ids:
+ continue
+ state.raise_error_with_source_info(
+ BatchAxisError(
+ f"The value of State {state} is batched, but it is not in the out_states."
+ )
+ )
+
+ out_states_ = [
+ [state.value for state in states]
+ for axis, states in axis_to_out_states.items()
+ ]
+ return out_states_, outs
+
+ def vmapped_fn(*args):
+ # vmapping
+ in_state_vals = [
+ [st.value for st in states]
+ for axis, states in axis_to_in_states.items()
+ ]
+ in_axes_st = list(axis_to_in_states.keys())
+ out_axes_st = list(axis_to_out_states.keys())
+ if len(in_axes_st) == 0:
+ in_axes_st = 0
+ if len(out_axes_st) == 0:
+ out_axes_st = 0
+ out_state_vals, outs = restore_rngs(
+ jax.vmap(
+ new_fn,
+ in_axes=(in_axes_st, in_axes),
+ out_axes=(out_axes_st, out_axes),
+ **transform_kwargs
+ ),
+ rngs=rngs
+ )(in_state_vals, args)
+
+ # restore mapped state values
+ for i, states in enumerate(axis_to_out_states.values()):
+ for state, st_val in zip(states, out_state_vals[i]):
+ state.restore_value(st_val)
+ return outs
+
+ return vmapped_fn


  def vmap(
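
The hunk above replaces the graph-based `StateAxes`/`MapFn` machinery with a direct state-value transform: `_flatten_in_out_states` normalizes the new `in_states`/`out_states` arguments, and `_vmap_transform` shuttles the flattened state values through `jax.vmap`, checking any written states against `out_states`. A minimal sketch of just the normalization step, using a hypothetical `FakeState` placeholder instead of real `brainstate.State` instances:

    from typing import Any, Dict, List, Tuple

    import jax


    class FakeState:
        """Hashable stand-in for brainstate.State, for illustration only."""
        def __init__(self, name: str):
            self.name = name


    def flatten_in_out_states(in_states: Any = None) -> Tuple[Dict[int, List[Any]], Dict[Any, int]]:
        # Nothing passed: no states participate in the mapping.
        if in_states is None:
            return {}, {}
        # Form 1: {axis: {name: state}} -- an explicit per-axis grouping.
        if (isinstance(in_states, dict)
                and all(isinstance(k, int) for k in in_states)
                and all(isinstance(v, dict) for v in in_states.values())):
            axis_to_states = {axis: list(group.values()) for axis, group in in_states.items()}
            state_to_axis = {s: axis for axis, group in in_states.items() for s in group.values()}
            return axis_to_states, state_to_axis
        # Form 2: any pytree of states -- everything is mapped along axis 0.
        leaves = jax.tree.leaves(in_states)
        return {0: list(leaves)}, {s: 0 for s in leaves}


    a, b, c = FakeState('a'), FakeState('b'), FakeState('c')
    axis_to_states, state_to_axis = flatten_in_out_states({0: {'a': a, 'b': b}, 1: {'c': c}})
    # axis_to_states == {0: [a, b], 1: [c]}
    # state_to_axis  == {a: 0, b: 0, c: 1}
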
@@ -172,6 +172,8 @@ def vmap(
  axis_size: int | None = None,
  spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
  # brainstate specific arguments
+ in_states: Dict[int, Dict] | Any | None = None,
+ out_states: Dict[int, Dict] | Any | None = None,
  rngs: Union[RandomState, Sequence[RandomState]] = DEFAULT,
  ) -> F | Callable[[F], F]:
  """
@@ -183,103 +185,38 @@ def vmap(

  More information please see `jax.vmap <https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html>`__.

-
  These are several example usage::

  >>> import brainstate as bst
  >>> import jax.numpy as jnp

- >>> model = bst.nn.Linear(2, 3)
- >>> x = jnp.ones((5, 2))
-
- >>> @bst.augment.vmap(in_axes=(None, 0), out_axes=0)
- ... def forward(model, x):
- ... return model(x)
-
- >>> y = forward(model, x)
- >>> print(y.shape)
- (5, 3)
-
- Another example with a more complex model::
-
- >>> class LinearEnsemble(bst.nn.Module):
- ... def __init__(self, n: int):
- ... super().__init__()
- ... self.n = n
- ... self.w = bst.ParamState(bst.random.random((n, 2, 3)))
-
- >>> model = LinearEnsemble(5)
- >>> x = jnp.ones((2,))
-
- >>> @bst.augment.vmap(in_axes=(0, None), out_axes=0)
- ... def forward(model, x):
- ... return jnp.dot(x, model.w.value)
-
- >>> y = forward(model, x)
- >>> print(y.shape)
- (5, 3)
+ >>> class Model(bst.nn.Module):
+ >>> def __init__(self):
+ >>> super().__init__()
+ >>>
+ >>> self.a = bst.ShortTermState(bst.random.randn(5))
+ >>> self.b = bst.ShortTermState(bst.random.randn(5))
+ >>> self.c = bst.State(bst.random.randn(1))

- To control how different types of states are vectorized, ``StateAxes``
- can be passed to ``in_axes`` and ``out_axes`` specifying the axes to be
- applied to each substate given a filter. The following example shows how to
- share the parameters between the ensemble members which keeping different
- batch statistics and dropout random state::
+ >>> def __call__(self, *args, **kwargs):
+ >>> self.c.value = self.a.value * self.b.value
+ >>> return self.c.value + 1.

- >>> class Foo(bst.nn.Module):
- ... def __init__(self):
- ... super().__init__()
- ... self.a = bst.ParamState(jnp.arange(4))
- ... self.b = bst.ShortTermState(jnp.arange(4))
+ >>> model = Model()

- >>> state_axes = bst.augment.StateAxes({bst.ParamState: 0, bst.ShortTermState: None})
- >>> @bst.augment.vmap(in_axes=(state_axes,), out_axes=0)
- ... def mul(foo):
- ... return foo.a.value * foo.b.value
-
- >>> model = Foo()
- >>> y = mul(model)
- >>> print(y.shape)
- (4, 4)
+ >>> r = bst.augment.vmap(
+ >>> model,
+ >>> in_states=model.states(bst.ShortTermState),
+ >>> out_states=model.c
+ >>> )()

  Args:
  fn: Function to be mapped over additional axes.
  in_axes: An integer, None, or sequence of values specifying which input
  array axes to map over.
-
- If each positional argument to ``fun`` is an array, then ``in_axes`` can
- be an integer, a None, or a tuple of integers and Nones with length equal
- to the number of positional arguments to ``fun``. An integer or ``None``
- indicates which array axis to map over for all arguments (with ``None``
- indicating not to map any axis), and a tuple indicates which axis to map
- for each corresponding positional argument. Axis integers must be in the
- range ``[-ndim, ndim)`` for each array, where ``ndim`` is the number of
- dimensions (axes) of the corresponding input array.
-
- If the positional arguments to ``fun`` are container (pytree) types, ``in_axes``
- must be a sequence with length equal to the number of positional arguments to
- ``fun``, and for each argument the corresponding element of ``in_axes`` can
- be a container with a matching pytree structure specifying the mapping of its
- container elements. In other words, ``in_axes`` must be a container tree prefix
- of the positional argument tuple passed to ``fun``. See this link for more detail:
- https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees
-
- Either ``axis_size`` must be provided explicitly, or at least one
- positional argument must have ``in_axes`` not None. The sizes of the
- mapped input axes for all mapped positional arguments must all be equal.
-
- Arguments passed as keywords are always mapped over their leading axis
- (i.e. axis index 0).
-
- See below for examples.
-
  out_axes: An integer, None, or (nested) standard Python container
  (tuple/list/dict) thereof indicating where the mapped axis should appear
- in the output. All outputs with a mapped axis must have a non-None
- ``out_axes`` specification. Axis integers must be in the range ``[-ndim,
- ndim)`` for each output array, where ``ndim`` is the number of dimensions
- (axes) of the array returned by the :func:`vmap`-ed function, which is one
- more than the number of dimensions (axes) of the corresponding array
- returned by ``fun``.
+ in the output.
  axis_name: Optional, a hashable Python object used to identify the mapped
  axis so that parallel collectives can be applied.
  axis_size: Optional, an integer indicating the size of the axis to be
@@ -296,6 +233,8 @@ def vmap(
  generators to be used in the mapped function. These random number
  generators are restored their random key after the mapped function is
  executed.
+ in_states: Optional, the :class:`State` objects to be mapped over in the inputs.
+ out_states: Optional, the :class:`State` objects to be mapped over in the outputs.

  Returns:
  Batched/vectorized version of ``fun`` with arguments that correspond to
@@ -304,23 +243,26 @@ def vmap(
  with extra array axes at positions indicated by ``out_axes``.

  """
+
  if isinstance(fn, Missing):
  return functools.partial(
- vmap,
+ _vmap_transform,
  in_axes=in_axes,
  out_axes=out_axes,
+ in_states=in_states,
+ out_states=out_states,
  axis_name=axis_name,
  axis_size=axis_size,
  spmd_axis_name=spmd_axis_name,
  rngs=rngs,
  ) # type: ignore[return-value]

- return _map_transform(
- 'vmap', # ctxtag
- jax.vmap,
+ return _vmap_transform(
  fn,
  in_axes=in_axes,
  out_axes=out_axes,
+ in_states=in_states,
+ out_states=out_states,
  axis_name=axis_name,
  axis_size=axis_size,
  spmd_axis_name=spmd_axis_name,
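
With the dispatch above, `vmap` now forwards `in_states` and `out_states` straight to `_vmap_transform`. A plain-code rendering of the example in the new docstring (class, state names, and shapes all follow that example):

    import brainstate as bst

    class Model(bst.nn.Module):
        def __init__(self):
            super().__init__()
            self.a = bst.ShortTermState(bst.random.randn(5))
            self.b = bst.ShortTermState(bst.random.randn(5))
            self.c = bst.State(bst.random.randn(1))

        def __call__(self, *args, **kwargs):
            # writes a batched value into self.c, so c must be listed in out_states
            self.c.value = self.a.value * self.b.value
            return self.c.value + 1.

    model = Model()

    # map over the leading axis of every ShortTermState value and collect
    # the batched result that the call writes into model.c
    r = bst.augment.vmap(
        model,
        in_states=model.states(bst.ShortTermState),
        out_states=model.c,
    )()
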
@@ -363,65 +305,6 @@ def pmap(

  More information please see `jax.vmap <https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html>`__.

- If there are 4 XLA devices available, the following example will execute
- the function in parallel on each device::
-
-
- >>> import brainstate as bst
- >>> import jax.numpy as jnp
-
- >>> model = bst.nn.Linear(2, 3)
- >>> x = jnp.ones((4, 2))
-
- >>> @bst.augment.vmap(in_axes=(None, 0), out_axes=0)
- ... def forward(model, x):
- ... return model(x)
-
- >>> y = forward(model, x)
- >>> print(y.shape)
- (4, 3)
-
- Another example with a more complex model::
-
- >>> class LinearEnsemble(bst.nn.Module):
- ... def __init__(self, n: int):
- ... super().__init__()
- ... self.n = n
- ... self.w = bst.ParamState(bst.random.random((n, 2, 3)))
-
- >>> model = LinearEnsemble(4)
- >>> x = jnp.ones((2,))
-
- >>> @bst.augment.vmap(in_axes=(0, None), out_axes=0)
- ... def forward(model, x):
- ... return jnp.dot(x, model.w.value)
-
- >>> y = forward(model, x)
- >>> print(y.shape)
- (4, 3)
-
- To control how different types of states are vectorized, ``StateAxes``
- can be passed to ``in_axes`` and ``out_axes`` specifying the axes to be
- applied to each substate given a filter. The following example shows how to
- share the parameters between the ensemble members which keeping different
- batch statistics and dropout random state::
-
- >>> class Foo(bst.nn.Module):
- ... def __init__(self):
- ... super().__init__()
- ... self.a = bst.ParamState(jnp.arange(4))
- ... self.b = bst.ShortTermState(jnp.arange(4))
-
- >>> state_axes = bst.augment.StateAxes({bst.ParamState: 0, bst.ShortTermState: None})
- >>> @bst.augment.vmap(in_axes=(state_axes,), out_axes=0)
- ... def mul(foo):
- ... return foo.a.value * foo.b.value
-
- >>> model = Foo()
- >>> y = mul(model)
- >>> print(y.shape)
- (4, 4)
-

  Args:
  fn: Function to be mapped over argument axes. Its arguments and return
@@ -508,18 +391,123 @@ def pmap(
  rngs=rngs,
  ) # type: ignore[return-value]

- return _map_transform(
- 'pmap', # ctxtag
- jax.pmap,
- fn,
- in_axes=in_axes,
- out_axes=out_axes,
- axis_name=axis_name,
- static_broadcasted_argnums=static_broadcasted_argnums,
- devices=devices,
- backend=backend,
- axis_size=axis_size,
- donate_argnums=donate_argnums,
- global_arg_shapes=global_arg_shapes,
- rngs=rngs,
+ return restore_rngs(
+ jax.pmap(
+ fn,
+ in_axes=in_axes,
+ out_axes=out_axes,
+ axis_name=axis_name,
+ static_broadcasted_argnums=static_broadcasted_argnums,
+ devices=devices,
+ backend=backend,
+ axis_size=axis_size,
+ donate_argnums=donate_argnums,
+ global_arg_shapes=global_arg_shapes,
+ ),
+ rngs=rngs
  )
+
+
+ def _batch_and_remainder(x, batch_size: int):
+ leaves, tree_def = jax.tree.flatten(x)
+
+ scan_leaves = []
+ remainder_leaves = []
+
+ length = None
+ for leaf in leaves:
+ if length is None:
+ length = leaf.shape[0]
+ if length != leaf.shape[0]:
+ raise ValueError(f"All inputs must have the same length. Got {length} and {leaf.shape[0]}.")
+
+ num_batches, num_remainder = divmod(length, batch_size)
+ for leaf in leaves:
+ total_batch_elems = num_batches * batch_size
+ scan_leaves.append(leaf[:total_batch_elems].reshape(num_batches, batch_size, *leaf.shape[1:]))
+ if num_remainder:
+ remainder_leaves.append(leaf[total_batch_elems:])
+
+ scan_tree = tree_def.unflatten(scan_leaves)
+ if num_remainder:
+ remainder_tree = tree_def.unflatten(remainder_leaves)
+ return scan_tree, remainder_tree
+ else:
+ return scan_tree, None
+
+
+ def map(
+ f,
+ *xs,
+ batch_size: int | None = None,
+ ):
+ """
+ Map a function over leading array axes.
+
+ Like Python's builtin map, except inputs and outputs are in the form of
+ stacked arrays. Consider using the :func:`~jax.vmap` transform instead, unless you
+ need to apply a function element by element for reduced memory usage or
+ heterogeneous computation with other control flow primitives.
+
+ When ``xs`` is an array type, the semantics of :func:`~map` are given by this
+ Python implementation::
+
+ def map(f, *xs):
+ return np.stack([f(*x) for x in xs])
+
+ Like :func:`~scan`, :func:`~map` is implemented in terms of JAX primitives so
+ many of the same advantages over a Python loop apply: ``xs`` may be an
+ arbitrary nested pytree type, and the mapped computation is compiled only
+ once.
+
+ If ``batch_size`` is provided, the computation is executed in batches of that size
+ and parallelized using :func:`~jax.vmap`. This can be used as either a more performant
+ version of ``map`` or as a memory-efficient version of ``vmap``. If the axis is not
+ divisible by the batch size, the remainder is processed in a separate ``vmap`` and
+ concatenated to the result.
+
+ >>> import jax.numpy as jnp
+ >>> x = jnp.ones((10, 3, 4))
+ >>> def f(x):
+ ... print('inner shape:', x.shape)
+ ... return x + 1
+ >>> y = map(f, x, batch_size=3)
+ inner shape: (3, 4)
+ inner shape: (3, 4)
+ >>> y.shape
+ (10, 3, 4)
+
+ In the example above, "inner shape" is printed twice, once while tracing the batched
+ computation and once while tracing the remainder computation.
+
+ Args:
+ f: a Python function to apply element-wise over the first axis or axes of
+ ``xs``.
+ xs: values over which to map along the leading axis.
+ batch_size: (optional) integer specifying the size of the batch for each step to execute
+ in parallel.
+
+ Returns:
+ Mapped values.
+ """
+ if batch_size is not None:
+ scan_xs, remainder_xs = _batch_and_remainder(xs, batch_size)
+ g = lambda _, x: ((), vmap(f)(*x))
+ _, scan_ys = scan(g, (), scan_xs)
+ if remainder_xs is None:
+ ys = jax.tree.map(lambda x: flatten_(x), scan_ys)
+ else:
+ remainder_ys = vmap(f)(*remainder_xs)
+ ys = jax.tree.map(
+ lambda x, y: jax.lax.concatenate([flatten_(x), y], dimension=0),
+ scan_ys,
+ remainder_ys,
+ )
+ else:
+ g = lambda _, x: ((), f(*x))
+ _, ys = scan(g, (), xs)
+ return ys
+
+
+ def flatten_(x):
+ return x.reshape(-1, *x.shape[2:])
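
The new `map` utility runs `f` element-wise over the leading axis via `scan`; with `batch_size` it reshapes the inputs into batches, applies `vmap(f)` inside the scan, and handles any remainder with one extra `vmap` call before concatenating. A usage sketch based on the doctest in the docstring above (it assumes `map` is re-exported as `brainstate.augment.map`, as its addition to `__all__` suggests):

    import jax.numpy as jnp
    import brainstate as bst

    def add_one(x):
        # f sees one element at a time: x has shape (3, 4)
        return x + 1

    x = jnp.ones((10, 3, 4))

    # 10 elements are processed as three scanned batches of 3 plus a remainder of 1
    y = bst.augment.map(add_one, x, batch_size=3)
    assert y.shape == (10, 3, 4)
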