brainstate 0.1.7__py2.py3-none-any.whl → 0.1.9__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. brainstate/__init__.py +58 -51
  2. brainstate/_compatible_import.py +148 -148
  3. brainstate/_state.py +1605 -1663
  4. brainstate/_state_test.py +52 -52
  5. brainstate/_utils.py +47 -47
  6. brainstate/augment/__init__.py +30 -30
  7. brainstate/augment/_autograd.py +778 -778
  8. brainstate/augment/_autograd_test.py +1289 -1289
  9. brainstate/augment/_eval_shape.py +99 -99
  10. brainstate/augment/_eval_shape_test.py +38 -38
  11. brainstate/augment/_mapping.py +1060 -1060
  12. brainstate/augment/_mapping_test.py +597 -597
  13. brainstate/augment/_random.py +151 -151
  14. brainstate/compile/__init__.py +38 -38
  15. brainstate/compile/_ad_checkpoint.py +204 -204
  16. brainstate/compile/_ad_checkpoint_test.py +49 -49
  17. brainstate/compile/_conditions.py +256 -256
  18. brainstate/compile/_conditions_test.py +220 -220
  19. brainstate/compile/_error_if.py +92 -92
  20. brainstate/compile/_error_if_test.py +52 -52
  21. brainstate/compile/_jit.py +346 -346
  22. brainstate/compile/_jit_test.py +143 -143
  23. brainstate/compile/_loop_collect_return.py +536 -536
  24. brainstate/compile/_loop_collect_return_test.py +58 -58
  25. brainstate/compile/_loop_no_collection.py +184 -184
  26. brainstate/compile/_loop_no_collection_test.py +50 -50
  27. brainstate/compile/_make_jaxpr.py +888 -888
  28. brainstate/compile/_make_jaxpr_test.py +156 -146
  29. brainstate/compile/_progress_bar.py +202 -202
  30. brainstate/compile/_unvmap.py +159 -159
  31. brainstate/compile/_util.py +147 -147
  32. brainstate/environ.py +563 -563
  33. brainstate/environ_test.py +62 -62
  34. brainstate/functional/__init__.py +27 -26
  35. brainstate/graph/__init__.py +29 -29
  36. brainstate/graph/_graph_node.py +244 -244
  37. brainstate/graph/_graph_node_test.py +73 -73
  38. brainstate/graph/_graph_operation.py +1738 -1738
  39. brainstate/graph/_graph_operation_test.py +563 -563
  40. brainstate/init/__init__.py +26 -26
  41. brainstate/init/_base.py +52 -52
  42. brainstate/init/_generic.py +244 -244
  43. brainstate/init/_random_inits.py +553 -553
  44. brainstate/init/_random_inits_test.py +149 -149
  45. brainstate/init/_regular_inits.py +105 -105
  46. brainstate/init/_regular_inits_test.py +50 -50
  47. brainstate/mixin.py +365 -363
  48. brainstate/mixin_test.py +77 -73
  49. brainstate/nn/__init__.py +135 -131
  50. brainstate/{functional → nn}/_activations.py +808 -813
  51. brainstate/{functional → nn}/_activations_test.py +331 -331
  52. brainstate/nn/_collective_ops.py +514 -514
  53. brainstate/nn/_collective_ops_test.py +43 -43
  54. brainstate/nn/_common.py +178 -178
  55. brainstate/nn/_conv.py +501 -501
  56. brainstate/nn/_conv_test.py +238 -238
  57. brainstate/nn/_delay.py +509 -470
  58. brainstate/nn/_delay_test.py +238 -0
  59. brainstate/nn/_dropout.py +426 -426
  60. brainstate/nn/_dropout_test.py +100 -100
  61. brainstate/nn/_dynamics.py +1343 -1361
  62. brainstate/nn/_dynamics_test.py +78 -78
  63. brainstate/nn/_elementwise.py +1119 -1120
  64. brainstate/nn/_elementwise_test.py +169 -169
  65. brainstate/nn/_embedding.py +58 -58
  66. brainstate/nn/_exp_euler.py +92 -92
  67. brainstate/nn/_exp_euler_test.py +35 -35
  68. brainstate/nn/_fixedprob.py +239 -239
  69. brainstate/nn/_fixedprob_test.py +114 -114
  70. brainstate/nn/_inputs.py +608 -608
  71. brainstate/nn/_linear.py +424 -424
  72. brainstate/nn/_linear_mv.py +83 -83
  73. brainstate/nn/_linear_mv_test.py +120 -120
  74. brainstate/nn/_linear_test.py +107 -107
  75. brainstate/nn/_ltp.py +28 -28
  76. brainstate/nn/_module.py +377 -377
  77. brainstate/nn/_module_test.py +40 -208
  78. brainstate/nn/_neuron.py +705 -705
  79. brainstate/nn/_neuron_test.py +161 -161
  80. brainstate/nn/_normalizations.py +975 -918
  81. brainstate/nn/_normalizations_test.py +73 -73
  82. brainstate/{functional → nn}/_others.py +46 -46
  83. brainstate/nn/_poolings.py +1177 -1177
  84. brainstate/nn/_poolings_test.py +217 -217
  85. brainstate/nn/_projection.py +486 -486
  86. brainstate/nn/_rate_rnns.py +554 -554
  87. brainstate/nn/_rate_rnns_test.py +63 -63
  88. brainstate/nn/_readout.py +209 -209
  89. brainstate/nn/_readout_test.py +53 -53
  90. brainstate/nn/_stp.py +236 -236
  91. brainstate/nn/_synapse.py +505 -505
  92. brainstate/nn/_synapse_test.py +131 -131
  93. brainstate/nn/_synaptic_projection.py +423 -423
  94. brainstate/nn/_synouts.py +162 -162
  95. brainstate/nn/_synouts_test.py +57 -57
  96. brainstate/nn/_utils.py +89 -89
  97. brainstate/nn/metrics.py +388 -388
  98. brainstate/optim/__init__.py +38 -38
  99. brainstate/optim/_base.py +64 -64
  100. brainstate/optim/_lr_scheduler.py +448 -448
  101. brainstate/optim/_lr_scheduler_test.py +50 -50
  102. brainstate/optim/_optax_optimizer.py +152 -152
  103. brainstate/optim/_optax_optimizer_test.py +53 -53
  104. brainstate/optim/_sgd_optimizer.py +1104 -1104
  105. brainstate/random/__init__.py +24 -24
  106. brainstate/random/_rand_funs.py +3616 -3616
  107. brainstate/random/_rand_funs_test.py +567 -567
  108. brainstate/random/_rand_seed.py +210 -210
  109. brainstate/random/_rand_seed_test.py +48 -48
  110. brainstate/random/_rand_state.py +1409 -1409
  111. brainstate/random/_random_for_unit.py +52 -52
  112. brainstate/surrogate.py +1957 -1957
  113. brainstate/transform.py +23 -23
  114. brainstate/typing.py +304 -304
  115. brainstate/util/__init__.py +50 -50
  116. brainstate/util/caller.py +98 -98
  117. brainstate/util/error.py +55 -55
  118. brainstate/util/filter.py +469 -469
  119. brainstate/util/others.py +540 -540
  120. brainstate/util/pretty_pytree.py +945 -945
  121. brainstate/util/pretty_pytree_test.py +159 -159
  122. brainstate/util/pretty_repr.py +328 -328
  123. brainstate/util/pretty_table.py +2954 -2954
  124. brainstate/util/scaling.py +258 -258
  125. brainstate/util/struct.py +523 -523
  126. {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/METADATA +91 -99
  127. brainstate-0.1.9.dist-info/RECORD +130 -0
  128. {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/WHEEL +1 -1
  129. {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info/licenses}/LICENSE +202 -202
  130. brainstate/functional/_normalization.py +0 -81
  131. brainstate/functional/_spikes.py +0 -204
  132. brainstate-0.1.7.dist-info/RECORD +0 -131
  133. {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/top_level.txt +0 -0
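Of note in the list above: `brainstate/functional/_activations.py` and `_others.py` moved into `brainstate/nn/` (items 50 and 82), while `brainstate/functional/_normalization.py` and `_spikes.py` were removed (items 130 and 131). A minimal compatibility sketch for downstream code, assuming the import paths mirror this file move (the symbol `relu` here is illustrative, not confirmed by this diff):

    try:
        # brainstate 0.1.9: activation functions live under brainstate.nn,
        # per the `brainstate/{functional -> nn}/_activations.py` entry above
        from brainstate.nn import relu  # 'relu' is an assumed example symbol
    except ImportError:
        # brainstate 0.1.7 and earlier kept them under brainstate.functional
        from brainstate.functional import relu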
@@ -1,1060 +1,1060 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- import functools
- from typing import (
-     Any,
-     TypeVar,
-     Callable,
-     Hashable,
-     Sequence,
-     Iterable,
-     Tuple,
-     Union,
-     Optional,
-     Dict,
-     List
- )
-
- import jax
- from jax.interpreters.batching import BatchTracer
-
- from brainstate._compatible_import import Device
- from brainstate._state import State, catch_new_states
- from brainstate.compile import scan, StatefulFunction
- from brainstate.random import RandomState, DEFAULT
- from brainstate.typing import Missing, Filter
- from brainstate.util import NestedDict, BrainStateError
- from ._random import restore_rngs
-
- __all__ = [
-     'vmap',
-     'pmap',
-     'map',
-     'vmap_new_states',
- ]
-
- F = TypeVar("F", bound=Callable)
- AxisName = Hashable
- AxisToState = Dict[int, List[State]]
- StateToAxis = Dict[State, int]
-
-
- class BatchAxisError(BrainStateError):
-     """
-     Exception raised for errors related to batch axis operations.
-
-     This custom exception is used to indicate errors that occur during
-     batch processing or vectorization operations, particularly in the
-     context of state management in the BrainState framework.
-
-     Inherits from:
-         BrainStateError: The base error class for BrainState-related exceptions.
-     """
-     pass
-
-
- def _flatten_in_out_states(
-     in_states: Dict[int, Dict] | Any = None,
- ) -> Tuple[AxisToState, StateToAxis]:
-     """
-     Flattens and organizes input or output states into axis-based mappings.
-
-     This function processes the input or output states, converting them into two
-     dictionary representations: one mapping axes to states, and another mapping
-     states to axes. It handles both structured (Dict[int, Dict]) and unstructured
-     input formats.
-
-     Args:
-         in_states (Dict[int, Dict] | Any, optional): The input or output states to be
-             flattened. Can be a nested dictionary structure where the outer keys are
-             axes and inner dictionaries contain states, or any other structure
-             containing states. Defaults to None.
-
-     Returns:
-         Tuple[AxisToState, StateToAxis]: A tuple containing two dictionaries:
-             - AxisToState: Maps axes (int) to lists of states.
-             - StateToAxis: Maps individual states to their corresponding axes (int).
-
-     Note:
-         If in_states is None, empty dictionaries are returned for both mappings.
-         If in_states is not in the expected Dict[int, Dict] format, all states are
-         assigned to axis 0.
-     """
-     if in_states is None:
-         return dict(), dict()
-     if isinstance(in_states, dict):
-         keys = tuple(in_states.keys())
-         values = tuple(in_states.values())
-         is_axis_in_states = (
-             all([isinstance(key, int) for key in keys]) and
-             all([isinstance(value, dict) for value in values])
-         )
-     else:
-         is_axis_in_states = False
-     if is_axis_in_states:
-         axis_to_states = {key: list(value.values()) for key, value in in_states.items()}
-         state_to_axis = {}
-         for key, value in in_states.items():
-             for state in value.values():
-                 state_to_axis[state] = key
-         return axis_to_states, state_to_axis
-     else:
-         in_states = jax.tree.leaves(in_states)
-         axis_to_states = {0: list(in_states)}
-         state_to_axis = {state: 0 for state in in_states}
-         return axis_to_states, state_to_axis
-
-
- def _remove_axis(x, axis: int):
-     """
-     Remove a specified axis from an array or nested structure.
-
-     This function removes a specified axis from an array or nested structure,
-     adjusting the shape and structure of the output accordingly.
-
-     Args:
-         x (Any): The input array or nested structure to remove the axis from.
-         axis (int): The axis to remove from the input.
-
-     Returns:
-         Any: The output array or nested structure with the specified axis removed.
-     """
-     assert isinstance(axis, int), f"Expected axis to be an integer, but got {type(axis)}"
-     if axis < 0:
-         axis += x.ndim
-     if axis < 0 or axis >= x.ndim:
-         raise IndexError(f"Axis {axis} is out of bounds for array of shape {x.shape}")
-     return x[tuple(slice(None, None, None) if i != axis else 0 for i in range(x.ndim))]
-
-
- def _compile_stateful_function(
-     stateful_fn: StatefulFunction,
-     in_axes: int | Tuple[int, ...],
-     args: Tuple
- ):
-     """
-     Compile a stateful function with specified input axes and arguments.
-
-     This function prepares and compiles a stateful function for vectorized mapping (vmap)
-     by adjusting the input arguments based on the specified axes and then generating
-     the function's JAX program representation (jaxpr).
-
-     Args:
-         stateful_fn (StatefulFunction): The stateful function to be compiled.
-         in_axes (int | Tuple[int, ...]): Specifies which axes of the input arguments
-             to map over. Can be a single integer (same for all args) or a tuple of integers.
-         args (Tuple): The input arguments to the function.
-
-     Raises:
-         ValueError: If the length of in_axes tuple doesn't match the number of arguments.
-
-     Returns:
-         The argument cache key returned by ``stateful_fn.get_arg_cache_key``; the function also modifies ``stateful_fn`` in place by calling ``make_jaxpr``.
-     """
-     in_axes_st, in_axes = in_axes
-     state_vals, args = args
-
-     # check in_axes
-     if isinstance(in_axes, tuple) and len(in_axes) != len(args):
-         raise ValueError(
-             "vmap in_axes must be an int, None, or a tuple of entries corresponding "
-             "to the positional arguments passed to the function, "
-             f"but got {len(in_axes)=}, {len(args)=}"
-         )
-
-     # check state_vals
-     if len(state_vals) > 0:
-         state_vals = [jax.tree.map(lambda x: _remove_axis(x, axis), vals)
-                       for vals, axis in zip(state_vals, in_axes_st)]
-     else:
-         state_vals = []
-
-     if isinstance(in_axes, int):
-         args = jax.tree.map(lambda x: _remove_axis(x, in_axes), args)
-     elif isinstance(in_axes, tuple):
-         args = tuple([
-             arg if in_axis is None else _remove_axis(arg, in_axis)
-             for arg, in_axis in zip(args, in_axes)
-         ])
-     stateful_fn.make_jaxpr(state_vals, args)
-     return stateful_fn.get_arg_cache_key(state_vals, args)
-
-
- def _get_batch_size(
-     args: Tuple,
-     in_axes: int | Tuple[int, ...],
-     in_states: AxisToState,
-     axis_size: Optional[int] = None,
- ) -> int:
-     """
-     Determine the batch size from input arguments, axes, and states.
-
-     This function calculates the batch size by examining the shapes of input arguments
-     and states along specified axes. It ensures consistency across all inputs.
-
-     Args:
-         args (Tuple): The input arguments to the function being vectorized.
-         in_axes (int | Tuple[int, ...]): The axes along which to vectorize for each argument.
-             Can be a single integer (same for all args) or a tuple of integers.
-         in_states (AxisToState): A dictionary mapping axes to lists of states.
-         axis_size (Optional[int]): A fallback batch size used when none can be inferred.
-
-     Returns:
-         int: The determined batch size.
-
-     Raises:
-         ValueError: If unable to determine batch size or if inconsistent batch sizes are found.
-     """
-     batch_sizes = []
-
-     # Check batch size from args and in_axes
-     if isinstance(in_axes, int):
-         in_axes = (in_axes,) * len(args)
-     for arg, in_axis in zip(args, in_axes):
-         if in_axis is not None:
-             arg_leaves = jax.tree.leaves(arg)
-             if arg_leaves:
-                 batch_sizes.append(arg_leaves[0].shape[in_axis])
-
-     # Check batch size from in_states
-     if in_states is not None:
-         for axis, states in in_states.items():
-             for state in states:
-                 state_leaves = jax.tree.leaves(state.value)
-                 if len(state_leaves):
-                     batch_sizes.append(state_leaves[0].shape[axis])
-
-     if len(batch_sizes) == 0:
-         assert axis_size is not None, (
-             "Unable to determine batch size. Please provide the 'axis_size' argument."
-         )
-         return axis_size
-     else:
-         # Ensure all batch sizes are consistent
-         if len(set(batch_sizes)) > 1:
-             raise ValueError(f"Inconsistent batch sizes found: {set(batch_sizes)}")
-
-         return batch_sizes[0]
-
-
- def _format_state_axes(
-     in_states,
-     out_states,
- ):
-     """
-     Format and validate the axes of input and output states.
-
-     This function processes the input and output states, ensuring consistency
-     between their axis mappings. It also handles cases where a state appears
-     in the input but not in the output.
-
-     Args:
-         in_states: The input states to be formatted. Can be a dictionary mapping
-             axes to states, or any other structure containing states.
-         out_states: The output states to be formatted. Can be a dictionary mapping
-             axes to states, or any other structure containing states.
-
-     Returns:
-         A tuple containing four elements:
-             - axis_to_in_states (dict): Mapping of axes to input states.
-             - in_state_to_axis (dict): Mapping of input states to their axes.
-             - axis_to_out_states (dict): Mapping of axes to output states.
-             - out_state_to_axis (dict): Mapping of output states to their axes.
-
-     Raises:
-         BatchAxisError: If there's an inconsistency between the axis mappings
-             of input and output states.
-     """
-     axis_to_in_states, in_state_to_axis = _flatten_in_out_states(in_states)
-     axis_to_out_states, out_state_to_axis = _flatten_in_out_states(out_states)
-     for _in_state, _axis in in_state_to_axis.items():
-         if _in_state in out_state_to_axis:
-             _out_axis = out_state_to_axis[_in_state]
-             if _out_axis != _axis:
-                 _in_state.raise_error_with_source_info(
-                     BatchAxisError(
-                         f"State {_in_state} has been mapped to axis {_axis} in 'in_states', "
-                         f"but it is mapped to axis {_out_axis} in 'out_states'."
-                     )
-                 )
-         else:
-             out_state_to_axis[_in_state] = _axis
-             if _axis not in axis_to_out_states:
-                 axis_to_out_states[_axis] = []
-             axis_to_out_states[_axis].append(_in_state)
-
-     return axis_to_in_states, in_state_to_axis, axis_to_out_states, out_state_to_axis
-
-
- def _vmap_transform(
-     f: F,
-     *,
-     in_axes: int | None | Sequence[Any] = 0,
-     out_axes: Any = 0,
-     in_states: Dict[int, Dict] | Any | None = None,
-     out_states: Dict[int, Dict] | Any | None = None,
-     axis_size: Optional[int] = None,
-     axis_name: AxisName | None = None,
-     spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
- ):
-     """
-     Transforms a function for vectorized mapping (vmap) with state management.
-
-     This internal function applies vectorized mapping to the input function while
-     handling state management for input and output states. It supports custom
-     axis specifications for both inputs and outputs.
-
-     Args:
-         f (F): The function to be transformed for vectorized mapping.
-         in_axes (int | None | Sequence[Any]): Specifies which axes of the input
-             arguments to map over. Default is 0.
-         out_axes (Any): Specifies where the mapped axis should appear in the output.
-             Default is 0.
-         in_states (Dict[int, Dict] | Any | None): Specifies the input states and
-             their corresponding axes for mapping. Default is None.
-         out_states (Dict[int, Dict] | Any | None): Specifies the output states and
-             their corresponding axes for mapping. Default is None.
-         axis_size (Optional[int]): The size of the mapped axis; inferred from the
-             arguments when not provided.
-         axis_name (AxisName | None): Optional hashable name for the mapped axis.
-         spmd_axis_name: Optional SPMD axis name(s) forwarded to ``jax.vmap``.
-
-     Returns:
-         Callable: A new function that applies vectorized mapping to the input
-         function while managing states.
-     """
-
-     # TODO: support jax.disable_jit()
-
-     # format state axes
-     (
-         axis_to_in_states,
-         in_state_to_axis,
-         axis_to_out_states,
-         out_state_to_axis
-     ) = _format_state_axes(in_states, out_states)
-
-     # check in_axes
-     if isinstance(in_axes, list):
-         # To be a tree prefix of the positional args tuple, in_axes can never be a
-         # list: if in_axes is not a leaf, it must be a tuple of trees. However,
-         # in cases like these users expect tuples and lists to be treated
-         # essentially interchangeably, so we canonicalize lists to tuples here
-         # rather than raising an error. https://github.com/jax-ml/jax/issues/2367
-         in_axes = tuple(in_axes)
-
-     def _vmap_fn_for_compilation(in_vmap_state_vals, args):
-         """
-         Compile a function for vectorized mapping (vmap) with state restoration.
-
-         This internal function is used to prepare a function for vectorized mapping
-         by restoring state values before calling the original function.
-
-         Args:
-             in_vmap_state_vals (List[List]): A nested list containing the state values
-                 to be restored. The outer list corresponds to different axes, while
-                 the inner lists contain the state values for each axis.
-             args (Tuple): The arguments to be passed to the original function after
-                 state restoration.
-
-         Returns:
-             Any: The result of calling the original function 'f' with the restored
-             state and provided arguments.
-         """
-         # restore state values
-         for i, states in enumerate(axis_to_in_states.values()):
-             for state, state_val in zip(states, in_vmap_state_vals[i]):
-                 state.restore_value(state_val)
-
-         # call the function
-         return f(*args)
-
-     def _set_axis_env(batch_size):
-         axis_env = None if axis_name is None else [(axis_name, batch_size)]
-         stateful_fn.axis_env = axis_env
-
-     # stateful function
-     stateful_fn = StatefulFunction(_vmap_fn_for_compilation, name='vmap')
-
-     @functools.wraps(f)
-     def new_fn_for_vmap(
-         rng_keys,
-         in_state_vmap_vals,
-         in_state_oth_vals,
-         args,
-     ):
-         """
-         Wrapper function for vectorized mapping (vmap) that handles state restoration and function execution.
-
-         This function restores state values, random number generators (RNGs), and other state values
-         before calling the original function. It then processes the outputs and prepares them for
-         vectorized mapping.
-
-         Args:
-             rng_keys (Sequence): Random number generator keys for each mapped instance.
-             in_state_vmap_vals (Sequence[Sequence]): Input state values for vectorized mapping,
-                 organized by axis.
-             in_state_oth_vals (Sequence): Other input state values not involved in vectorized mapping.
-             args (Tuple): Arguments to be passed to the original function.
-
-         Returns:
-             Tuple: A tuple containing four elements:
-                 - out_rng_keys (List): Updated RNG keys after function execution.
-                 - out_state_vmap_vals (List[List]): Output state values for vectorized mapping,
-                     organized by axis.
-                 - out_state_oth_vals (List): Other output state values not involved in vectorized mapping.
-                 - outs: The output of the original function call.
-
-         Raises:
-             AssertionError: If there's a mismatch in the number of states, state values, or RNG keys.
-             BatchAxisError: If a state value is batched but not included in out_states.
-         """
-         # restore vmapping state values
-         for i, states in enumerate(axis_to_in_states.values()):
-             assert len(states) == len(in_state_vmap_vals[i]), (
-                 f"The number of states in axis {i} should be equal to the number "
-                 f"of state values, but got {len(states)} and {len(in_state_vmap_vals[i])}."
-             )
-             for state, state_val in zip(states, in_state_vmap_vals[i]):
-                 state.restore_value(state_val)
-
-         # restore rngs
-         cache_key = stateful_fn.get_arg_cache_key(in_state_vmap_vals, args)
-         state_trace = stateful_fn.get_state_trace(cache_key)
-         rngs = state_trace.state_subset(RandomState)
-         rng_sets = set(rngs)
-         assert len(rngs) == len(rng_keys), (
-             f"The number of random states in the function should be equal to the number "
-             f"of random keys, but got {len(rngs)} and {len(rng_keys)}."
-         )
-         for rng, key in zip(rngs, rng_keys):
-             rng.restore_value(key)
-
-         # restore other state values
-         oth_in_state = [
-             st for st in state_trace.states
-             if st not in in_state_to_axis and st not in rng_sets
-         ]
-         assert len(oth_in_state) == len(in_state_oth_vals), (
-             f"The number of states in 'in_states' should be equal to the number "
-             f"of state values, but got {len(oth_in_state)} and {len(in_state_oth_vals)}."
-         )
-         for state, state_val in zip(oth_in_state, in_state_oth_vals):
-             state.restore_value(state_val)
-
-         # call the function
-         outs = stateful_fn.jaxpr_call_auto(in_state_vmap_vals, args)
-
-         # analyze vmapping axis error
-         for state in state_trace.get_write_states():
-             leaves = jax.tree.leaves(state.value)
-             if (
-                 any([isinstance(leaf, BatchTracer) and (leaf.batch_dim is not None) for leaf in leaves])
-                 and state not in out_state_to_axis
-             ):
-                 if isinstance(state, RandomState) and state in rng_sets:
-                     continue
-                 state.raise_error_with_source_info(
-                     BatchAxisError(f"The value of State {state} is batched, "
-                                    f"but it is not in the out_states.")
-                 )
-
-         # out state values for vmapping
-         out_state_vmap_vals = [
-             [state.value for state in states]
-             for axis, states in axis_to_out_states.items()
-         ]
-         out_state_oth_vals = [
-             st.value for st in state_trace.states
-             if st not in out_state_to_axis and st not in rng_sets
-         ]
-         out_rng_keys = [rng.value for rng in rngs]
-         return out_rng_keys, out_state_vmap_vals, out_state_oth_vals, outs
-
-     @functools.wraps(f)
-     def vmapped_fn(*args, **kwargs):
-         """
-         Applies vectorized mapping (vmap) to the input function while managing state.
-
-         This function handles the vectorization process, including state management,
-         random number generation, and function compilation. It prepares the input
-         states, compiles the stateful function, manages random number generators,
-         applies the vmap transformation, and restores the output states.
-
-         Args:
-             *args: Variable length argument list containing the input arguments
-                 to be passed to the vectorized function.
-
-         Returns:
-             Any: The output of the vectorized function after applying vmap and
-             managing states.
-
-         Note:
-             This function assumes the existence of several helper functions and
-             data structures (e.g., axis_to_in_states, in_state_to_axis) which
-             should be defined in the broader context.
-         """
-         if len(kwargs):
-             raise NotImplementedError(
-                 "Keyword arguments `f(**kwargs)` are not supported in brainstate.augment.vmap"
-             )
-
-         # in states values
-         in_state_map_vals = [
-             [st.value for st in states]
-             for axis, states in axis_to_in_states.items()
-         ]
-         st_in_axes = list(axis_to_in_states.keys())
-         if len(st_in_axes) == 0:
-             st_in_axes = 0
-
-         # compile stateful function
-         batch_size = None
-         if axis_name is not None:
-             batch_size = _get_batch_size(args, in_axes, axis_to_in_states, axis_size)
-             _set_axis_env(batch_size)
-         cache_key = _compile_stateful_function(
-             stateful_fn,
-             (st_in_axes, in_axes),
-             (in_state_map_vals, args)
-         )
-
-         # random keys
-         state_trace = stateful_fn.get_state_trace(cache_key)
-         rngs = state_trace.state_subset(RandomState)
-         rng_sets = set(rngs)
-         if len(rngs):
-             # batch size
-             if batch_size is None:
-                 batch_size = _get_batch_size(args, in_axes, axis_to_in_states, axis_size)
-             rng_keys = tuple(rng.split_key(batch_size) for rng in rngs)
-             rng_backup = tuple(rng.split_key() for rng in rngs)
-         else:
-             rng_keys = tuple()
-             rng_backup = tuple()
-
-         # in states other values
-         in_state_oth_vals = [
-             st.value
-             for st in state_trace.states
-             if st not in in_state_to_axis and st not in rng_sets
-         ]
-
-         # out state axis
-         st_out_axes = list(axis_to_out_states.keys())
-         if len(st_out_axes) == 0:
-             st_out_axes = 0
-
-         # --- vmapping --- #
-         fn = jax.vmap(
-             new_fn_for_vmap,
-             in_axes=(0, st_in_axes, None, in_axes),
-             out_axes=(0, st_out_axes, None, out_axes),
-             axis_size=axis_size,
-             axis_name=axis_name,
-             spmd_axis_name=spmd_axis_name,
-         )
-         _, out_state_map_vals, out_state_oth_vals, outs = fn(
-             rng_keys, in_state_map_vals, in_state_oth_vals, args
-         )
-
-         # restore mapped state values
-         for i, states in enumerate(axis_to_out_states.values()):
-             assert len(states) == len(out_state_map_vals[i]), (
-                 f"The number of states in axis {i} should be equal to the number "
-                 f"of state values, but got {len(states)} and {len(out_state_map_vals[i])}."
-             )
-             for state, st_val in zip(states, out_state_map_vals[i]):
-                 state.restore_value(st_val)
-
-         # restore other state values
-         out_oth_states = [
-             st for st in state_trace.states
-             if st not in out_state_to_axis and st not in rng_sets
-         ]
-         assert len(out_oth_states) == len(out_state_oth_vals), (
-             f"The number of states in 'out_states' should be equal to the number "
-             f"of state values, but got {len(out_oth_states)} and {len(out_state_oth_vals)}."
-         )
-         for state, st_val in zip(out_oth_states, out_state_oth_vals):
-             state.restore_value(st_val)
-
-         # restore random keys
-         for rng, key in zip(rngs, rng_backup):
-             rng.restore_value(key)
-         return outs
-
-     return vmapped_fn
-
-
- def vmap(
-     fn: F | Missing = Missing(),
-     *,
-     # --- normal jax.vmap arguments --- #
-     in_axes: int | None | Sequence[Any] = 0,
-     out_axes: Any = 0,
-     axis_name: AxisName | None = None,
-     axis_size: int | None = None,
-     spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
-     # --- brainstate specific arguments --- #
-     in_states: Dict[int, Dict] | Any | None = None,
-     out_states: Dict[int, Dict] | Any | None = None,
- ) -> F | Callable[[F], F]:
-     """
-     Vectorizing map. Creates a function which maps ``fun`` over argument axes.
-
-     The transformation :func:`vmap` is designed to work with ``pygraph`` structure
-     defined in the ``brainstate`` library. It is used to vectorize functions by
-     pushing the mapped axis down into primitive operations.
-
-     For more information, please see `jax.vmap <https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html>`__.
-
-     Here is an example usage::
-
-         >>> import brainstate
-         >>> import jax.numpy as jnp
-
-         >>> class Model(brainstate.nn.Module):
-         >>>     def __init__(self):
-         >>>         super().__init__()
-         >>>
-         >>>         self.a = brainstate.ShortTermState(brainstate.random.randn(5))
-         >>>         self.b = brainstate.ShortTermState(brainstate.random.randn(5))
-         >>>         self.c = brainstate.State(brainstate.random.randn(1))
-
-         >>>     def __call__(self, *args, **kwargs):
-         >>>         self.c.value = self.a.value * self.b.value
-         >>>         return self.c.value + 1.
-
-         >>> model = Model()
-
-         >>> r = brainstate.augment.vmap(
-         >>>     model,
-         >>>     in_states=model.states(brainstate.ShortTermState),
-         >>>     out_states=model.c
-         >>> )()
-
-     Args:
-         fn: Function to be mapped over additional axes.
-         in_axes: An integer, None, or sequence of values specifying which input
-             array axes to map over.
-         out_axes: An integer, None, or (nested) standard Python container
-             (tuple/list/dict) thereof indicating where the mapped axis should appear
-             in the output.
-         axis_name: Optional, a hashable Python object used to identify the mapped
-             axis so that parallel collectives can be applied.
-         axis_size: Optional, an integer indicating the size of the axis to be
-             mapped. If not provided, the mapped axis size is inferred from arguments.
-         spmd_axis_name: Optional, a hashable Python object or tuple of hashable
-             Python objects used to identify the mapped axis so that parallel collectives
-             can be applied. This is used to specify multiple axes to be mapped over
-             in a nested :func:`vmap` call. The length of the tuple must match the
-             number of nested :func:`vmap` calls. The first element of the tuple
-             corresponds to the outermost :func:`vmap` call, the second element to
-             the next outermost, and so on. If the tuple is not provided, the
-             ``axis_name`` is used for all nested :func:`vmap` calls.
-         in_states: Optional, the :class:`State` objects to be mapped over in the inputs.
-         out_states: Optional, the :class:`State` objects to be mapped over in the outputs.
-
-     Returns:
-         Batched/vectorized version of ``fun`` with arguments that correspond to
-         those of ``fun``, but with extra array axes at positions indicated by
-         ``in_axes``, and a return value that corresponds to that of ``fun``, but
-         with extra array axes at positions indicated by ``out_axes``.
-
-     """
-
-     if isinstance(fn, Missing):
-         return functools.partial(
-             _vmap_transform,
-             in_axes=in_axes,
-             out_axes=out_axes,
-             in_states=in_states,
-             out_states=out_states,
-             axis_name=axis_name,
-             axis_size=axis_size,
-             spmd_axis_name=spmd_axis_name,
-         )  # type: ignore[return-value]
-
-     return _vmap_transform(
-         fn,
-         in_axes=in_axes,
-         out_axes=out_axes,
-         in_states=in_states,
-         out_states=out_states,
-         axis_name=axis_name,
-         axis_size=axis_size,
-         spmd_axis_name=spmd_axis_name,
-     )
-
-
- def pmap(
-     fn: Callable[[NestedDict, ...], Any] | Missing = Missing(),
-     axis_name: Optional[AxisName] = None,
-     *,
-     in_axes: Any = 0,
-     out_axes: Any = 0,
-     static_broadcasted_argnums: int | Iterable[int] = (),
-     devices: Optional[Sequence[Device]] = None,  # noqa: F811
-     backend: Optional[str] = None,
-     axis_size: Optional[int] = None,
-     donate_argnums: int | Iterable[int] = (),
-     global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
-     # brainstate specific arguments
-     rngs: Union[RandomState, Sequence[RandomState]] = DEFAULT,
- ) -> Callable[[F], F] | F:
-     """
-     Parallel map with support for collective operations.
-
-     The purpose of :py:func:`pmap` is to express single-program multiple-data
-     (SPMD) programs. Applying :py:func:`pmap` to a function will compile the
-     function with XLA (similarly to :py:func:`jit`), then execute it in parallel
-     on XLA devices, such as multiple GPUs or multiple TPU cores. Semantically it
-     is comparable to :py:func:`vmap` because both transformations map a function
-     over array axes, but where :py:func:`vmap` vectorizes functions by pushing the
-     mapped axis down into primitive operations, :py:func:`pmap` instead replicates
-     the function and executes each replica on its own XLA device in parallel.
-
-     The mapped axis size must be less than or equal to the number of local XLA
-     devices available, as returned by :py:func:`jax.local_device_count()` (unless
-     ``devices`` is specified, see below). For nested :py:func:`pmap` calls, the
-     product of the mapped axis sizes must be less than or equal to the number of
-     XLA devices.
-
-     For more information, please see `jax.pmap <https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html>`__.
-
-
-     Args:
-         fn: Function to be mapped over argument axes. Its arguments and return
-             value should be arrays, scalars, or (nested) standard Python containers
-             (tuple/list/dict) thereof. Positional arguments indicated by
-             ``static_broadcasted_argnums`` can be anything at all, provided they are
-             hashable and have an equality operation defined.
-         axis_name: Optional, a hashable Python object used to identify the mapped
-             axis so that parallel collectives can be applied.
-         in_axes: A non-negative integer, None, or nested Python container thereof
-             that specifies which axes of positional arguments to map over. Arguments
-             passed as keywords are always mapped over their leading axis (i.e. axis
-             index 0). See :py:func:`vmap` for details.
-         out_axes: A non-negative integer, None, or nested Python container thereof
-             indicating where the mapped axis should appear in the output. All outputs
-             with a mapped axis must have a non-None ``out_axes`` specification
-             (see :py:func:`vmap`).
-         static_broadcasted_argnums: An int or collection of ints specifying which
-             positional arguments to treat as static (compile-time constant).
-             Operations that only depend on static arguments will be constant-folded.
-             Calling the pmapped function with different values for these constants
-             will trigger recompilation. If the pmapped function is called with fewer
-             positional arguments than indicated by ``static_broadcasted_argnums`` then
-             an error is raised. Each of the static arguments will be broadcasted to
-             all devices. Arguments that are not arrays or containers thereof must be
-             marked as static. Defaults to ().
-
-             Static arguments must be hashable, meaning both ``__hash__`` and
-             ``__eq__`` are implemented, and should be immutable.
-
-         devices: This is an experimental feature and the API is likely to change.
-             Optional, a sequence of Devices to map over. (Available devices can be
-             retrieved via jax.devices()). Must be given identically for each process
-             in multi-process settings (and will therefore include devices across
-             processes). If specified, the size of the mapped axis must be equal to
-             the number of devices in the sequence local to the given process. Nested
-             :py:func:`pmap` s with ``devices`` specified in either the inner or outer
-             :py:func:`pmap` are not yet supported.
-         backend: This is an experimental feature and the API is likely to change.
-             Optional, a string representing the XLA backend. 'cpu', 'gpu', or 'tpu'.
-         axis_size: Optional; the size of the mapped axis.
-         donate_argnums: Specify which positional argument buffers are "donated" to
-             the computation. It is safe to donate argument buffers if you no longer need
-             them once the computation has finished. In some cases XLA can make use of
-             donated buffers to reduce the amount of memory needed to perform a
-             computation, for example recycling one of your input buffers to store a
-             result. You should not reuse buffers that you donate to a computation, JAX
-             will raise an error if you try to.
-             Note that donate_argnums only work for positional arguments, and keyword
-             arguments will not be donated.
-
-             For more details on buffer donation see the
-             `FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
-         global_arg_shapes: Optional; a tuple of tuples of integers representing the
-             shapes of the global arguments. These are arguments that are not replicated
-             across devices, but are broadcasted to all devices. The tuple should have
-             the same length as the number of global arguments, and each inner tuple
-             should have the same length as the corresponding argument. The shapes of
-             the global arguments must be the same on all devices.
-         rngs: Optional, a random number generator or sequence of random number
-             generators to be used in the mapped function. These random number
-             generators have their random keys restored after the mapped function
-             is executed.
-
-     Returns:
-         A parallelized version of ``fun`` with arguments that correspond to those of
-         ``fun`` but with extra array axes at positions indicated by ``in_axes`` and
-         with output that has an additional leading array axis (with the same size).
-
-     """
-
-     if isinstance(fn, Missing):
-         return functools.partial(
-             pmap,
-             axis_name=axis_name,
-             in_axes=in_axes,
-             out_axes=out_axes,
-             static_broadcasted_argnums=static_broadcasted_argnums,
-             devices=devices,
-             backend=backend,
-             axis_size=axis_size,
-             donate_argnums=donate_argnums,
-             global_arg_shapes=global_arg_shapes,
-             rngs=rngs,
-         )  # type: ignore[return-value]
-
-     return restore_rngs(
-         jax.pmap(
-             fn,
-             in_axes=in_axes,
-             out_axes=out_axes,
-             axis_name=axis_name,
-             static_broadcasted_argnums=static_broadcasted_argnums,
-             devices=devices,
-             backend=backend,
-             axis_size=axis_size,
-             donate_argnums=donate_argnums,
-             global_arg_shapes=global_arg_shapes,
-         ),
-         rngs=rngs
-     )
-
-
- def _batch_and_remainder(x, batch_size: int):
-     leaves, tree_def = jax.tree.flatten(x)
-
-     scan_leaves = []
-     remainder_leaves = []
-
-     length = None
-     for leaf in leaves:
-         if length is None:
-             length = leaf.shape[0]
-         if length != leaf.shape[0]:
-             raise ValueError(f"All inputs must have the same length. Got {length} and {leaf.shape[0]}.")
-
-     num_batches, num_remainder = divmod(length, batch_size)
-     for leaf in leaves:
-         total_batch_elems = num_batches * batch_size
-         scan_leaves.append(leaf[:total_batch_elems].reshape(num_batches, batch_size, *leaf.shape[1:]))
-         if num_remainder:
-             remainder_leaves.append(leaf[total_batch_elems:])
-
-     scan_tree = tree_def.unflatten(scan_leaves)
-     if num_remainder:
-         remainder_tree = tree_def.unflatten(remainder_leaves)
-         return scan_tree, remainder_tree
-     else:
-         return scan_tree, None
-
-
- def map(
-     f,
-     *xs,
-     batch_size: int | None = None,
- ):
-     """
-     Map a function over leading array axes.
-
-     Like Python's builtin map, except inputs and outputs are in the form of
-     stacked arrays. Consider using the :func:`~jax.vmap` transform instead, unless you
-     need to apply a function element by element for reduced memory usage or
-     heterogeneous computation with other control flow primitives.
-
-     When ``xs`` is an array type, the semantics of :func:`~map` are given by this
-     Python implementation::
-
-         def map(f, *xs):
-             return np.stack([f(*x) for x in zip(*xs)])
-
-     Like :func:`~scan`, :func:`~map` is implemented in terms of JAX primitives so
-     many of the same advantages over a Python loop apply: ``xs`` may be an
-     arbitrary nested pytree type, and the mapped computation is compiled only
-     once.
-
-     If ``batch_size`` is provided, the computation is executed in batches of that size
-     and parallelized using :func:`~jax.vmap`. This can be used as either a more performant
-     version of ``map`` or as a memory-efficient version of ``vmap``. If the axis is not
-     divisible by the batch size, the remainder is processed in a separate ``vmap`` and
-     concatenated to the result.
-
-     >>> import jax.numpy as jnp
-     >>> x = jnp.ones((10, 3, 4))
-     >>> def f(x):
-     ...     print('inner shape:', x.shape)
-     ...     return x + 1
-     >>> y = map(f, x, batch_size=3)
-     inner shape: (3, 4)
-     inner shape: (3, 4)
-     >>> y.shape
-     (10, 3, 4)
-
-     In the example above, "inner shape" is printed twice, once while tracing the batched
-     computation and once while tracing the remainder computation.
-
-     Args:
-         f: a Python function to apply element-wise over the first axis or axes of
-             ``xs``.
-         xs: values over which to map along the leading axis.
-         batch_size: (optional) integer specifying the size of the batch for each step to execute
-             in parallel.
-
-     Returns:
-         Mapped values.
-     """
-     if batch_size is not None:
-         scan_xs, remainder_xs = _batch_and_remainder(xs, batch_size)
-         g = lambda _, x: ((), vmap(f)(*x))
-         _, scan_ys = scan(g, (), scan_xs)
-         if remainder_xs is None:
-             ys = jax.tree.map(lambda x: _flatten(x), scan_ys)
-         else:
-             remainder_ys = vmap(f)(*remainder_xs)
-             ys = jax.tree.map(
-                 lambda x, y: jax.lax.concatenate([_flatten(x), y], dimension=0),
-                 scan_ys,
-                 remainder_ys,
-             )
-     else:
-         g = lambda _, x: ((), f(*x))
-         _, ys = scan(g, (), xs)
-     return ys
-
-
- def _flatten(x):
-     return x.reshape(-1, *x.shape[2:])
-
-
- def _vmap_new_states_transform(
-     fun: Callable[..., Any],
-     *,
-     # -- normal jax.vmap arguments -- #
-     in_axes: int | None | Sequence[Any] = 0,
-     out_axes: Any = 0,
-     axis_name: AxisName | None = None,
-     axis_size: int | None = None,
-     spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
-     # -- brainstate specific arguments -- #
-     state_tag: str | None = None,
-     state_to_exclude: Filter | None = None,
-     in_states: Dict[int, Dict] | Any | None = None,
-     out_states: Dict[int, Dict] | Any | None = None,
- ):
-     # TODO: How about nested call ``vmap_new_states``?
-     if isinstance(axis_size, int) and axis_size <= 0:
-         raise ValueError(f"axis_size must be greater than 0, got {axis_size}.")
-
-     @vmap(
-         in_axes=in_axes,
-         out_axes=out_axes,
-         axis_name=axis_name,
-         axis_size=axis_size,
-         spmd_axis_name=spmd_axis_name,
-         in_states=in_states,
-         out_states=out_states,
-     )
-     def new_fun(args):
-         # call the function
-         with catch_new_states(state_tag=state_tag, state_to_exclude=state_to_exclude) as catcher:
-             out = fun(*args)
-
-         # get vmap state values
-         vmap_state_vals = catcher.get_state_values()
-
-         return out, vmap_state_vals
-
-     @functools.wraps(fun)
-     def vmapped_fn(*args):
-         # vmapping
-         with catch_new_states(state_to_exclude=state_to_exclude) as catcher:
-             outs, vmap_state_vals = new_fun(args)
-             vmap_states = catcher.get_states()
-
-         # restore vmapped state values
-         for st_val, st in zip(vmap_state_vals, vmap_states):
-             st.restore_value(st_val)
-             # ------------------------------------------------
-             # --- this is CRUCIAL to avoid jax tracing leakage
-             # ------------------------------------------------
-             st.decrease_stack_level()
-         return outs
-
-     return vmapped_fn
-
-
- def vmap_new_states(
-     fun: Callable = Missing(),
-     *,
-     # -- normal jax.vmap arguments -- #
-     in_axes: int | None | Sequence[Any] = 0,
-     out_axes: Any = 0,
-     axis_name: AxisName | None = None,
-     axis_size: int | None = None,
-     spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
-     # -- brainstate specific arguments -- #
-     state_tag: str | None = None,
-     state_to_exclude: Filter = None,
-     in_states: Dict[int, Dict] | Any | None = None,
-     out_states: Dict[int, Dict] | Any | None = None,
- ):
-     """
-     Vectorize a function over new states created within it.
-
-     This function applies JAX's vmap transformation to newly created states
-     during the function's execution. It allows for more
-     flexible vectorization in the context of stateful computations.
-
-     Args:
-         fun (Callable, optional): The function to be vectorized. Defaults to Missing().
-         in_axes (int | None | Sequence[Any], optional): Specification of input axes for vectorization. Defaults to 0.
-         out_axes (Any, optional): Specification of output axes after vectorization. Defaults to 0.
-         axis_name (AxisName, optional): Name of the axis being vectorized over. Defaults to None.
-         axis_size (int, optional): Size of the axis being vectorized over. Defaults to None.
-         spmd_axis_name (AxisName | tuple[AxisName, ...], optional): Name(s) of SPMD axis/axes. Defaults to None.
-         state_tag (str, optional): A tag to identify specific states. Defaults to None.
-         state_to_exclude (Filter, optional): A filter selecting states to exclude from vectorization. Defaults to None.
-         in_states (Dict[int, Dict] | Any, optional): The :class:`State` objects to be mapped over in the inputs. Defaults to None.
-         out_states (Dict[int, Dict] | Any, optional): The :class:`State` objects to be mapped over in the outputs. Defaults to None.
-
-     Returns:
-         Callable: A vectorized version of the input function that handles new state creation.
-     """
-     if isinstance(fun, Missing):
-         return functools.partial(
-             _vmap_new_states_transform,
-             in_axes=in_axes,
-             out_axes=out_axes,
-             axis_name=axis_name,
-             axis_size=axis_size,
-             spmd_axis_name=spmd_axis_name,
-             state_tag=state_tag,
-             state_to_exclude=state_to_exclude,
-             in_states=in_states,
-             out_states=out_states,
-         )
-     else:
-         return _vmap_new_states_transform(
-             fun,
-             in_axes=in_axes,
-             out_axes=out_axes,
-             axis_name=axis_name,
-             axis_size=axis_size,
-             spmd_axis_name=spmd_axis_name,
-             state_tag=state_tag,
-             state_to_exclude=state_to_exclude,
-             in_states=in_states,
-             out_states=out_states,
-         )
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import functools
17
+ from typing import (
18
+ Any,
19
+ TypeVar,
20
+ Callable,
21
+ Hashable,
22
+ Sequence,
23
+ Iterable,
24
+ Tuple,
25
+ Union,
26
+ Optional,
27
+ Dict,
28
+ List
29
+ )
30
+
31
+ import jax
32
+ from jax.interpreters.batching import BatchTracer
33
+
34
+ from brainstate._compatible_import import Device
35
+ from brainstate._state import State, catch_new_states
36
+ from brainstate.compile import scan, StatefulFunction
37
+ from brainstate.random import RandomState, DEFAULT
38
+ from brainstate.typing import Missing, Filter
39
+ from brainstate.util import NestedDict, BrainStateError
40
+ from ._random import restore_rngs
41
+
42
+ __all__ = [
43
+ 'vmap',
44
+ 'pmap',
45
+ 'map',
46
+ 'vmap_new_states',
47
+ ]
48
+
49
+ F = TypeVar("F", bound=Callable)
50
+ AxisName = Hashable
51
+ AxisToState = Dict[int, List[State]]
52
+ StateToAxis = Dict[State, int]
53
+
54
+
55
+ class BatchAxisError(BrainStateError):
56
+ """
57
+ Exception raised for errors related to batch axis operations.
58
+
59
+ This custom exception is used to indicate errors that occur during
60
+ batch processing or vectorization operations, particularly in the
61
+ context of state management in the BrainState framework.
62
+
63
+ Inherits from:
64
+ BrainStateError: The base error class for BrainState-related exceptions.
65
+ """
66
+ pass
67
+
68
+
69
+ def _flatten_in_out_states(
70
+ in_states: Dict[int, Dict] | Any = None,
71
+ ) -> Tuple[AxisToState, StateToAxis]:
72
+ """
73
+ Flattens and organizes input or output states into axis-based mappings.
74
+
75
+ This function processes the input or output states, converting them into two
76
+ dictionary representations: one mapping axes to states, and another mapping
77
+ states to axes. It handles both structured (Dict[int, Dict]) and unstructured
78
+ input formats.
79
+
80
+ Args:
81
+ in_states (Dict[int, Dict] | Any, optional): The input or output states to be
82
+ flattened. Can be a nested dictionary structure where the outer keys are
83
+ axes and inner dictionaries contain states, or any other structure
84
+ containing states. Defaults to None.
85
+
86
+ Returns:
87
+ Tuple[AxisToState, StateToAxis]: A tuple containing two dictionaries:
88
+ - AxisToState: Maps axes (int) to lists of states.
89
+ - StateToAxis: Maps individual states to their corresponding axes (int).
90
+
91
+ Note:
92
+ If in_states is None, empty dictionaries are returned for both mappings.
93
+ If in_states is not in the expected Dict[int, Dict] format, all states are
94
+ assigned to axis 0.
95
+ """
96
+ if in_states is None:
97
+ return dict(), dict()
98
+ if isinstance(in_states, dict):
99
+ keys = tuple(in_states.keys())
100
+ values = tuple(in_states.values())
101
+ is_axis_in_states = (
102
+ all([isinstance(key, int) for key in keys]) and
103
+ all([isinstance(value, dict) for value in values])
104
+ )
105
+ else:
106
+ is_axis_in_states = False
107
+ if is_axis_in_states:
108
+ axis_to_states = {key: list(value.values()) for key, value in in_states.items()}
109
+ state_to_axis = {}
110
+ for key, value in in_states.items():
111
+ for state in value.values():
112
+ state_to_axis[state] = key
113
+ return axis_to_states, state_to_axis
114
+ else:
115
+ in_states = jax.tree.leaves(in_states)
116
+ axis_to_states = {0: list(in_states)}
117
+ state_to_axis = {state: 0 for state in in_states}
118
+ return axis_to_states, state_to_axis
119
+
120
+
121
+ def _remove_axis(x, axis: int):
122
+ """
123
+ Remove a specified axis from an array or nested structure.
124
+
125
+ This function removes a specified axis from an array or nested structure,
126
+ adjusting the shape and structure of the output accordingly.
127
+
128
+ Args:
129
+ x (Any): The input array or nested structure to remove the axis from.
130
+ axis (int): The axis to remove from the input.
131
+
132
+ Returns:
133
+ Any: The output array or nested structure with the specified axis removed.
134
+ """
135
+ assert isinstance(axis, int), f"Expected axis to be an integer, but got {type(axis)}"
136
+ if axis < 0:
137
+ axis += x.ndim
138
+ if axis < 0 or axis >= x.ndim:
139
+ raise IndexError(f"Axis {axis} is out of bounds for array of shape {x.shape}")
140
+ return x[tuple(slice(None, None, None) if i != axis else 0 for i in range(x.ndim))]
141
+
142
+
143
+ def _compile_stateful_function(
144
+ stateful_fn: StatefulFunction,
145
+ in_axes: int | Tuple[int, ...],
146
+ args: Tuple
147
+ ):
148
+ """
149
+ Compile a stateful function with specified input axes and arguments.
150
+
151
+ This function prepares and compiles a stateful function for vectorized mapping (vmap)
152
+ by adjusting the input arguments based on the specified axes and then generating
153
+ the function's JAX program representation (jaxpr).
154
+
155
+ Args:
156
+ stateful_fn (StatefulFunction): The stateful function to be compiled.
157
+ in_axes (int | Tuple[int, ...]): Specifies which axes of the input arguments
158
+ to map over. Can be a single integer (same for all args) or a tuple of integers.
159
+ args (Tuple): The input arguments to the function.
160
+
161
+ Raises:
162
+ ValueError: If the length of in_axes tuple doesn't match the number of arguments.
163
+
164
+ Returns:
165
+ None. The function modifies the stateful_fn in-place by calling make_jaxpr.
166
+ """
167
+ in_axes_st, in_axes = in_axes
168
+ state_vals, args = args
169
+
170
+ # check in_axes
171
+ if isinstance(in_axes, tuple) and len(in_axes) != len(args):
172
+ raise ValueError(
173
+ "vmap in_axes must be an int, None, or a tuple of entries corresponding "
174
+ "to the positional arguments passed to the function, "
175
+ f"but got {len(in_axes)=}, {len(args)=}"
176
+ )
177
+
178
+ # check state_vals
179
+ if len(state_vals) > 0:
180
+ state_vals = [jax.tree.map(lambda x: _remove_axis(x, axis), vals)
181
+ for vals, axis in zip(state_vals, in_axes_st)]
182
+ else:
183
+ state_vals = []
184
+
185
+ if isinstance(in_axes, int):
186
+ args = jax.tree.map(lambda x: _remove_axis(x, in_axes), args)
187
+ elif isinstance(in_axes, tuple):
188
+ args = tuple([
189
+ arg if in_axis is None else _remove_axis(arg, in_axis)
190
+ for arg, in_axis in zip(args, in_axes)
191
+ ])
192
+ stateful_fn.make_jaxpr(state_vals, args)
193
+ return stateful_fn.get_arg_cache_key(state_vals, args)
194
+
195
+
+ def _get_batch_size(
+     args: Tuple,
+     in_axes: int | Tuple[int, ...],
+     in_states: AxisToState,
+     axis_size: Optional[int] = None,
+ ) -> int:
+     """
+     Determine the batch size from input arguments, axes, and states.
+
+     This function calculates the batch size by examining the shapes of input arguments
+     and states along specified axes. It ensures consistency across all inputs.
+
+     Args:
+         args (Tuple): The input arguments to the function being vectorized.
+         in_axes (int | Tuple[int, ...]): The axes along which to vectorize for each argument.
+             Can be a single integer (same for all args) or a tuple of integers.
+         in_states (AxisToState): A dictionary mapping axes to lists of states.
+         axis_size (Optional[int]): The batch size to fall back on when it cannot be
+             inferred from ``args`` or ``in_states``.
+
+     Returns:
+         int: The determined batch size.
+
+     Raises:
+         ValueError: If unable to determine the batch size or if inconsistent batch sizes are found.
+     """
+     batch_sizes = []
+
+     # Check batch size from args and in_axes
+     if isinstance(in_axes, int):
+         in_axes = (in_axes,) * len(args)
+     for arg, in_axis in zip(args, in_axes):
+         if in_axis is not None:
+             arg_leaves = jax.tree.leaves(arg)
+             if arg_leaves:
+                 batch_sizes.append(arg_leaves[0].shape[in_axis])
+
+     # Check batch size from in_states
+     if in_states is not None:
+         for axis, states in in_states.items():
+             for state in states:
+                 state_leaves = jax.tree.leaves(state.value)
+                 if len(state_leaves):
+                     batch_sizes.append(state_leaves[0].shape[axis])
+
+     if len(batch_sizes) == 0:
+         if axis_size is None:
+             raise ValueError(
+                 "Unable to determine batch size. Please provide the 'axis_size' argument."
+             )
+         return axis_size
+     else:
+         # Ensure all batch sizes are consistent
+         if len(set(batch_sizes)) > 1:
+             raise ValueError(f"Inconsistent batch sizes found: {set(batch_sizes)}")
+
+         return batch_sizes[0]
+
+
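+ # Editorial sketch (illustrative only): batch-size inference and the
+ # ``axis_size`` fallback, assuming ordinary JAX arrays:
+ #
+ #   >>> import jax.numpy as jnp
+ #   >>> _get_batch_size((jnp.zeros((8, 3)), jnp.zeros((8, 4))), 0, None)
+ #   8
+ #   >>> _get_batch_size((), 0, None, axis_size=16)
+ #   16
+
+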
+ def _format_state_axes(
+     in_states,
+     out_states,
+ ):
+     """
+     Format and validate the axes of input and output states.
+
+     This function processes the input and output states, ensuring consistency
+     between their axis mappings. It also handles cases where a state appears
+     in the input but not in the output.
+
+     Args:
+         in_states: The input states to be formatted. Can be a dictionary mapping
+             axes to states, or any other structure containing states.
+         out_states: The output states to be formatted. Can be a dictionary mapping
+             axes to states, or any other structure containing states.
+
+     Returns:
+         A tuple containing four elements:
+             - axis_to_in_states (dict): Mapping of axes to input states.
+             - in_state_to_axis (dict): Mapping of input states to their axes.
+             - axis_to_out_states (dict): Mapping of axes to output states.
+             - out_state_to_axis (dict): Mapping of output states to their axes.
+
+     Raises:
+         BatchAxisError: If there's an inconsistency between the axis mappings
+             of input and output states.
+     """
+     axis_to_in_states, in_state_to_axis = _flatten_in_out_states(in_states)
+     axis_to_out_states, out_state_to_axis = _flatten_in_out_states(out_states)
+     for _in_state, _axis in in_state_to_axis.items():
+         if _in_state in out_state_to_axis:
+             _out_axis = out_state_to_axis[_in_state]
+             if _out_axis != _axis:
+                 _in_state.raise_error_with_source_info(
+                     BatchAxisError(
+                         f"State {_in_state} has been mapped to axis {_axis} in 'in_states', "
+                         f"but it is mapped to axis {_out_axis} in 'out_states'."
+                     )
+                 )
+         else:
+             out_state_to_axis[_in_state] = _axis
+             if _axis not in axis_to_out_states:
+                 axis_to_out_states[_axis] = []
+             axis_to_out_states[_axis].append(_in_state)
+
+     return axis_to_in_states, in_state_to_axis, axis_to_out_states, out_state_to_axis
+
+
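+ # Editorial note: mapping the same State to different axes is rejected, e.g.
+ # (schematically) passing a state under axis 0 in 'in_states' but under axis 1
+ # in 'out_states' raises a BatchAxisError, while a state that appears only in
+ # 'in_states' is carried over to the output mapping on the same axis.
+
+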
+ def _vmap_transform(
+     f: F,
+     *,
+     in_axes: int | None | Sequence[Any] = 0,
+     out_axes: Any = 0,
+     in_states: Dict[int, Dict] | Any | None = None,
+     out_states: Dict[int, Dict] | Any | None = None,
+     axis_size: Optional[int] = None,
+     axis_name: AxisName | None = None,
+     spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
+ ):
+     """
+     Transform a function for vectorized mapping (vmap) with state management.
+
+     This internal function applies vectorized mapping to the input function while
+     handling state management for input and output states. It supports custom
+     axis specifications for both inputs and outputs.
+
+     Args:
+         f (F): The function to be transformed for vectorized mapping.
+         in_axes (int | None | Sequence[Any]): Specifies which axes of the input
+             arguments to map over. Default is 0.
+         out_axes (Any): Specifies where the mapped axis should appear in the output.
+             Default is 0.
+         in_states (Dict[int, Dict] | Any | None): Specifies the input states and
+             their corresponding axes for mapping. Default is None.
+         out_states (Dict[int, Dict] | Any | None): Specifies the output states and
+             their corresponding axes for mapping. Default is None.
+         axis_size (Optional[int]): The size of the mapped axis, used when it cannot
+             be inferred from the inputs. Default is None.
+         axis_name (AxisName | None): The name of the mapped axis, enabling parallel
+             collectives. Default is None.
+         spmd_axis_name (AxisName | tuple[AxisName, ...] | None): Axis name(s) used
+             for SPMD sharding annotations. Default is None.
+
+     Returns:
+         Callable: A new function that applies vectorized mapping to the input
+         function while managing states.
+     """
+
+     # TODO: support jax.disable_jit()
+
+     # format state axes
+     (
+         axis_to_in_states,
+         in_state_to_axis,
+         axis_to_out_states,
+         out_state_to_axis
+     ) = _format_state_axes(in_states, out_states)
+
+     # check in_axes
+     if isinstance(in_axes, list):
+         # To be a tree prefix of the positional args tuple, in_axes can never be a
+         # list: if in_axes is not a leaf, it must be a tuple of trees. However,
+         # in cases like these users expect tuples and lists to be treated
+         # essentially interchangeably, so we canonicalize lists to tuples here
+         # rather than raising an error. https://github.com/jax-ml/jax/issues/2367
+         in_axes = tuple(in_axes)
+
+     def _vmap_fn_for_compilation(in_vmap_state_vals, args):
+         """
+         The function traced during compilation for vectorized mapping (vmap).
+
+         This internal function is used to prepare a function for vectorized mapping
+         by restoring state values before calling the original function.
+
+         Args:
+             in_vmap_state_vals (List[List]): A nested list containing the state values
+                 to be restored. The outer list corresponds to different axes, while
+                 the inner lists contain the state values for each axis.
+             args (Tuple): The arguments to be passed to the original function after
+                 state restoration.
+
+         Returns:
+             Any: The result of calling the original function 'f' with the restored
+                 state and provided arguments.
+         """
+         # restore state values
+         for i, states in enumerate(axis_to_in_states.values()):
+             for state, state_val in zip(states, in_vmap_state_vals[i]):
+                 state.restore_value(state_val)
+
+         # call the function
+         return f(*args)
+
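+     # Editorial note: by the time this compilation wrapper runs, the per-axis
+     # state values have had their mapped axis removed (see
+     # _compile_stateful_function), so tracing sees element-shaped values, just
+     # as the function body does inside jax.vmap.
+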
+     def _set_axis_env(batch_size):
+         axis_env = None if axis_name is None else [(axis_name, batch_size)]
+         stateful_fn.axis_env = axis_env
+
+     # stateful function
+     stateful_fn = StatefulFunction(_vmap_fn_for_compilation, name='vmap')
+
+     @functools.wraps(f)
+     def new_fn_for_vmap(
+         rng_keys,
+         in_state_vmap_vals,
+         in_state_oth_vals,
+         args,
+     ):
+         """
+         Wrapper function for vectorized mapping (vmap) that handles state restoration and function execution.
+
+         This function restores state values, random number generators (RNGs), and other state values
+         before calling the original function. It then processes the outputs and prepares them for
+         vectorized mapping.
+
+         Args:
+             rng_keys (Sequence): Random number generator keys for each mapped instance.
+             in_state_vmap_vals (Sequence[Sequence]): Input state values for vectorized mapping,
+                 organized by axis.
+             in_state_oth_vals (Sequence): Other input state values not involved in vectorized mapping.
+             args (Tuple): Arguments to be passed to the original function.
+
+         Returns:
+             Tuple: A tuple containing four elements:
+                 - out_rng_keys (List): Updated RNG keys after function execution.
+                 - out_state_vmap_vals (List[List]): Output state values for vectorized mapping,
+                   organized by axis.
+                 - out_state_oth_vals (List): Other output state values not involved in vectorized mapping.
+                 - outs: The output of the original function call.
+
+         Raises:
+             AssertionError: If there's a mismatch in the number of states, state values, or RNG keys.
+             BatchAxisError: If a state value is batched but not included in out_states.
+         """
+         # restore vmapping state values
+         for i, states in enumerate(axis_to_in_states.values()):
+             assert len(states) == len(in_state_vmap_vals[i]), (
+                 f"The number of states in axis {i} should be equal to the number "
+                 f"of state values, but got {len(states)} and {len(in_state_vmap_vals[i])}."
+             )
+             for state, state_val in zip(states, in_state_vmap_vals[i]):
+                 state.restore_value(state_val)
+
+         # restore rngs
+         cache_key = stateful_fn.get_arg_cache_key(in_state_vmap_vals, args)
+         state_trace = stateful_fn.get_state_trace(cache_key)
+         rngs = state_trace.state_subset(RandomState)
+         rng_sets = set(rngs)
+         assert len(rngs) == len(rng_keys), (
+             f"The number of random states in the function should be equal to the number "
+             f"of random keys, but got {len(rngs)} and {len(rng_keys)}."
+         )
+         for rng, key in zip(rngs, rng_keys):
+             rng.restore_value(key)
+
+         # restore other state values
+         oth_in_state = [
+             st for st in state_trace.states
+             if st not in in_state_to_axis and st not in rng_sets
+         ]
+         assert len(oth_in_state) == len(in_state_oth_vals), (
+             f"The number of states in 'in_states' should be equal to the number "
+             f"of state values, but got {len(oth_in_state)} and {len(in_state_oth_vals)}."
+         )
+         for state, state_val in zip(oth_in_state, in_state_oth_vals):
+             state.restore_value(state_val)
+
+         # call the function
+         outs = stateful_fn.jaxpr_call_auto(in_state_vmap_vals, args)
+
+         # analyze vmapping axis error
+         for state in state_trace.get_write_states():
+             leaves = jax.tree.leaves(state.value)
+             if (
+                 any(isinstance(leaf, BatchTracer) and leaf.batch_dim is not None for leaf in leaves)
+                 and state not in out_state_to_axis
+             ):
+                 if isinstance(state, RandomState) and state in rng_sets:
+                     continue
+                 state.raise_error_with_source_info(
+                     BatchAxisError(f"The value of State {state} is batched, "
+                                    f"but it is not in the 'out_states'.")
+                 )
+
+         # out state values for vmapping
+         out_state_vmap_vals = [
+             [state.value for state in states]
+             for axis, states in axis_to_out_states.items()
+         ]
+         out_state_oth_vals = [
+             st.value for st in state_trace.states
+             if st not in out_state_to_axis and st not in rng_sets
+         ]
+         out_rng_keys = [rng.value for rng in rngs]
+         return out_rng_keys, out_state_vmap_vals, out_state_oth_vals, outs
+
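+     # Editorial note: RandomState instances are deliberately excluded from the
+     # "other" state values above; their keys are split per batch element before
+     # the vmap call and restored from a fresh backup key afterwards (see
+     # ``vmapped_fn`` below), so they travel through the ``rng_keys`` argument.
+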
+     @functools.wraps(f)
+     def vmapped_fn(*args, **kwargs):
+         """
+         Apply vectorized mapping (vmap) to the input function while managing state.
+
+         This function handles the vectorization process, including state management,
+         random number generation, and function compilation. It prepares the input
+         states, compiles the stateful function, manages random number generators,
+         applies the vmap transformation, and restores the output states.
+
+         Args:
+             *args: Variable length argument list containing the input arguments
+                 to be passed to the vectorized function.
+
+         Returns:
+             Any: The output of the vectorized function after applying vmap and
+                 managing states.
+
+         Note:
+             This function relies on several names from the enclosing scope
+             (e.g., axis_to_in_states, in_state_to_axis) that are defined by
+             ``_vmap_transform``.
+         """
+         if len(kwargs):
+             raise NotImplementedError(
+                 "Keyword arguments `f(**kwargs)` are not supported in brainstate.augment.vmap"
+             )
+
+         # in states values
+         in_state_map_vals = [
+             [st.value for st in states]
+             for axis, states in axis_to_in_states.items()
+         ]
+         st_in_axes = list(axis_to_in_states.keys())
+         if len(st_in_axes) == 0:
+             st_in_axes = 0
+
+         # compile stateful function
+         batch_size = None
+         if axis_name is not None:
+             batch_size = _get_batch_size(args, in_axes, axis_to_in_states, axis_size)
+             _set_axis_env(batch_size)
+         cache_key = _compile_stateful_function(
+             stateful_fn,
+             (st_in_axes, in_axes),
+             (in_state_map_vals, args)
+         )
+
+         # random keys
+         state_trace = stateful_fn.get_state_trace(cache_key)
+         rngs = state_trace.state_subset(RandomState)
+         rng_sets = set(rngs)
+         if len(rngs):
+             # batch size
+             if batch_size is None:
+                 batch_size = _get_batch_size(args, in_axes, axis_to_in_states, axis_size)
+             rng_keys = tuple(rng.split_key(batch_size) for rng in rngs)
+             rng_backup = tuple(rng.split_key() for rng in rngs)
+         else:
+             rng_keys = tuple()
+             rng_backup = tuple()
+
+         # in states other values
+         in_state_oth_vals = [
+             st.value
+             for st in state_trace.states
+             if st not in in_state_to_axis and st not in rng_sets
+         ]
+
+         # out state axis
+         st_out_axes = list(axis_to_out_states.keys())
+         if len(st_out_axes) == 0:
+             st_out_axes = 0
+
+         # --- vmapping --- #
+         fn = jax.vmap(
+             new_fn_for_vmap,
+             in_axes=(0, st_in_axes, None, in_axes),
+             out_axes=(0, st_out_axes, None, out_axes),
+             axis_size=axis_size,
+             axis_name=axis_name,
+             spmd_axis_name=spmd_axis_name,
+         )
+         _, out_state_map_vals, out_state_oth_vals, outs = fn(
+             rng_keys, in_state_map_vals, in_state_oth_vals, args
+         )
+
+         # restore mapped state values
+         for i, states in enumerate(axis_to_out_states.values()):
+             assert len(states) == len(out_state_map_vals[i]), (
+                 f"The number of states in axis {i} should be equal to the number "
+                 f"of state values, but got {len(states)} and {len(out_state_map_vals[i])}."
+             )
+             for state, st_val in zip(states, out_state_map_vals[i]):
+                 state.restore_value(st_val)
+
+         # restore other state values
+         out_oth_states = [
+             st for st in state_trace.states
+             if st not in out_state_to_axis and st not in rng_sets
+         ]
+         assert len(out_oth_states) == len(out_state_oth_vals), (
+             f"The number of states in 'out_states' should be equal to the number "
+             f"of state values, but got {len(out_oth_states)} and {len(out_state_oth_vals)}."
+         )
+         for state, st_val in zip(out_oth_states, out_state_oth_vals):
+             state.restore_value(st_val)
+
+         # restore random keys
+         for rng, key in zip(rngs, rng_backup):
+             rng.restore_value(key)
+         return outs
+
+     return vmapped_fn
+
+
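+ # Editorial note on the wiring above: the quadruple passed through jax.vmap is
+ # (rng_keys, state_vmap_vals, state_oth_vals, args) with in_axes
+ # (0, st_in_axes, None, in_axes): per-instance RNG keys are always mapped along
+ # axis 0, non-mapped state values are broadcast (None), and user arguments
+ # follow the caller's ``in_axes``; outputs mirror this with
+ # (0, st_out_axes, None, out_axes).
+
+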
+ def vmap(
+     fn: F | Missing = Missing(),
+     *,
+     # --- normal jax.vmap arguments --- #
+     in_axes: int | None | Sequence[Any] = 0,
+     out_axes: Any = 0,
+     axis_name: AxisName | None = None,
+     axis_size: int | None = None,
+     spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
+     # --- brainstate specific arguments --- #
+     in_states: Dict[int, Dict] | Any | None = None,
+     out_states: Dict[int, Dict] | Any | None = None,
+ ) -> F | Callable[[F], F]:
+     """
+     Vectorizing map. Creates a function which maps ``fun`` over argument axes.
+
+     The transformation :func:`vmap` is designed to work with the ``pygraph`` structure
+     defined in the ``brainstate`` library. It is used to vectorize functions by
+     pushing the mapped axis down into primitive operations.
+
+     For more information, please see `jax.vmap <https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html>`__.
+
+     Here are several usage examples::
+
+         >>> import brainstate
+         >>> import jax.numpy as jnp
+
+         >>> class Model(brainstate.nn.Module):
+         ...     def __init__(self):
+         ...         super().__init__()
+         ...
+         ...         self.a = brainstate.ShortTermState(brainstate.random.randn(5))
+         ...         self.b = brainstate.ShortTermState(brainstate.random.randn(5))
+         ...         self.c = brainstate.State(brainstate.random.randn(1))
+         ...
+         ...     def __call__(self, *args, **kwargs):
+         ...         self.c.value = self.a.value * self.b.value
+         ...         return self.c.value + 1.
+
+         >>> model = Model()
+
+         >>> r = brainstate.augment.vmap(
+         ...     model,
+         ...     in_states=model.states(brainstate.ShortTermState),
+         ...     out_states=model.c
+         ... )()
+
+     Args:
+         fn: Function to be mapped over additional axes.
+         in_axes: An integer, None, or sequence of values specifying which input
+             array axes to map over.
+         out_axes: An integer, None, or (nested) standard Python container
+             (tuple/list/dict) thereof indicating where the mapped axis should appear
+             in the output.
+         axis_name: Optional, a hashable Python object used to identify the mapped
+             axis so that parallel collectives can be applied.
+         axis_size: Optional, an integer indicating the size of the axis to be
+             mapped. If not provided, the mapped axis size is inferred from arguments.
+         spmd_axis_name: Optional, a hashable Python object or tuple of hashable
+             Python objects used to identify the mapped axis so that parallel collectives
+             can be applied. This is used to specify multiple axes to be mapped over
+             in a nested :func:`vmap` call. The length of the tuple must match the
+             number of nested :func:`vmap` calls. The first element of the tuple
+             corresponds to the outermost :func:`vmap` call, the second element to
+             the next outermost, and so on. If the tuple is not provided, the
+             ``axis_name`` is used for all nested :func:`vmap` calls.
+         in_states: Optional, the :class:`State` objects to be mapped over in the inputs.
+         out_states: Optional, the :class:`State` objects to be mapped over in the outputs.
+
+     Returns:
+         Batched/vectorized version of ``fun`` with arguments that correspond to
+         those of ``fun``, but with extra array axes at positions indicated by
+         ``in_axes``, and a return value that corresponds to that of ``fun``, but
+         with extra array axes at positions indicated by ``out_axes``.
+
+     """
+
+     if isinstance(fn, Missing):
+         return functools.partial(
+             _vmap_transform,
+             in_axes=in_axes,
+             out_axes=out_axes,
+             in_states=in_states,
+             out_states=out_states,
+             axis_name=axis_name,
+             axis_size=axis_size,
+             spmd_axis_name=spmd_axis_name,
+         )  # type: ignore[return-value]
+
+     return _vmap_transform(
+         fn,
+         in_axes=in_axes,
+         out_axes=out_axes,
+         in_states=in_states,
+         out_states=out_states,
+         axis_name=axis_name,
+         axis_size=axis_size,
+         spmd_axis_name=spmd_axis_name,
+     )
+
+
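+ # Editorial sketch (illustrative, not part of the released file): because a
+ # missing ``fn`` returns a partially applied transform, ``vmap`` also works as
+ # a decorator:
+ #
+ #   >>> import jax.numpy as jnp
+ #   >>> @vmap(in_axes=0, out_axes=0)
+ #   ... def double(x):
+ #   ...     return x * 2.0
+ #   >>> double(jnp.arange(3.0)).shape
+ #   (3,)
+
+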
+ def pmap(
+     fn: Callable[[NestedDict, ...], Any] | Missing = Missing(),
+     axis_name: Optional[AxisName] = None,
+     *,
+     in_axes: Any = 0,
+     out_axes: Any = 0,
+     static_broadcasted_argnums: int | Iterable[int] = (),
+     devices: Optional[Sequence[Device]] = None,  # noqa: F811
+     backend: Optional[str] = None,
+     axis_size: Optional[int] = None,
+     donate_argnums: int | Iterable[int] = (),
+     global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
+     # brainstate specific arguments
+     rngs: Union[RandomState, Sequence[RandomState]] = DEFAULT,
+ ) -> Callable[[F], F] | F:
+     """
+     Parallel map with support for collective operations.
+
+     The purpose of :py:func:`pmap` is to express single-program multiple-data
+     (SPMD) programs. Applying :py:func:`pmap` to a function will compile the
+     function with XLA (similarly to :py:func:`jit`), then execute it in parallel
+     on XLA devices, such as multiple GPUs or multiple TPU cores. Semantically it
+     is comparable to :py:func:`vmap` because both transformations map a function
+     over array axes, but where :py:func:`vmap` vectorizes functions by pushing the
+     mapped axis down into primitive operations, :py:func:`pmap` instead replicates
+     the function and executes each replica on its own XLA device in parallel.
+
+     The mapped axis size must be less than or equal to the number of local XLA
+     devices available, as returned by :py:func:`jax.local_device_count` (unless
+     ``devices`` is specified, see below). For nested :py:func:`pmap` calls, the
+     product of the mapped axis sizes must be less than or equal to the number of
+     XLA devices.
+
+     For more information, please see `jax.pmap <https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html>`__.
+
+     Args:
+         fn: Function to be mapped over argument axes. Its arguments and return
+             value should be arrays, scalars, or (nested) standard Python containers
+             (tuple/list/dict) thereof. Positional arguments indicated by
+             ``static_broadcasted_argnums`` can be anything at all, provided they are
+             hashable and have an equality operation defined.
+         axis_name: Optional, a hashable Python object used to identify the mapped
+             axis so that parallel collectives can be applied.
+         in_axes: A non-negative integer, None, or nested Python container thereof
+             that specifies which axes of positional arguments to map over. Arguments
+             passed as keywords are always mapped over their leading axis (i.e. axis
+             index 0). See :py:func:`vmap` for details.
+         out_axes: A non-negative integer, None, or nested Python container thereof
+             indicating where the mapped axis should appear in the output. All outputs
+             with a mapped axis must have a non-None ``out_axes`` specification
+             (see :py:func:`vmap`).
+         static_broadcasted_argnums: An int or collection of ints specifying which
+             positional arguments to treat as static (compile-time constant).
+             Operations that only depend on static arguments will be constant-folded.
+             Calling the pmapped function with different values for these constants
+             will trigger recompilation. If the pmapped function is called with fewer
+             positional arguments than indicated by ``static_broadcasted_argnums`` then
+             an error is raised. Each of the static arguments will be broadcasted to
+             all devices. Arguments that are not arrays or containers thereof must be
+             marked as static. Defaults to ().
+
+             Static arguments must be hashable, meaning both ``__hash__`` and
+             ``__eq__`` are implemented, and should be immutable.
+
+         devices: This is an experimental feature and the API is likely to change.
+             Optional, a sequence of Devices to map over. (Available devices can be
+             retrieved via jax.devices()). Must be given identically for each process
+             in multi-process settings (and will therefore include devices across
+             processes). If specified, the size of the mapped axis must be equal to
+             the number of devices in the sequence local to the given process. Nested
+             :py:func:`pmap` s with ``devices`` specified in either the inner or outer
+             :py:func:`pmap` are not yet supported.
+         backend: This is an experimental feature and the API is likely to change.
+             Optional, a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.
+         axis_size: Optional; the size of the mapped axis.
+         donate_argnums: Specify which positional argument buffers are "donated" to
+             the computation. It is safe to donate argument buffers if you no longer need
+             them once the computation has finished. In some cases XLA can make use of
+             donated buffers to reduce the amount of memory needed to perform a
+             computation, for example recycling one of your input buffers to store a
+             result. You should not reuse buffers that you donate to a computation; JAX
+             will raise an error if you try to.
+             Note that donate_argnums only work for positional arguments, and keyword
+             arguments will not be donated.
+
+             For more details on buffer donation see the
+             `FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
+         global_arg_shapes: Optional; a tuple of tuples of integers representing the
+             shapes of the global arguments. These are arguments that are not replicated
+             across devices, but are broadcasted to all devices. The tuple should have
+             the same length as the number of global arguments, and each inner tuple
+             should have the same length as the corresponding argument. The shapes of
+             the global arguments must be the same on all devices.
+         rngs: Optional, a random number generator or sequence of random number
+             generators to be used in the mapped function. These random number
+             generators have their random keys restored after the mapped function is
+             executed.
+
+     Returns:
+         A parallelized version of ``fun`` with arguments that correspond to those of
+         ``fun`` but with extra array axes at positions indicated by ``in_axes`` and
+         with output that has an additional leading array axis (with the same size).
+
+     """
+
+     if isinstance(fn, Missing):
+         return functools.partial(
+             pmap,
+             axis_name=axis_name,
+             in_axes=in_axes,
+             out_axes=out_axes,
+             static_broadcasted_argnums=static_broadcasted_argnums,
+             devices=devices,
+             backend=backend,
+             axis_size=axis_size,
+             donate_argnums=donate_argnums,
+             global_arg_shapes=global_arg_shapes,
+             rngs=rngs,
+         )  # type: ignore[return-value]
+
+     return restore_rngs(
+         jax.pmap(
+             fn,
+             in_axes=in_axes,
+             out_axes=out_axes,
+             axis_name=axis_name,
+             static_broadcasted_argnums=static_broadcasted_argnums,
+             devices=devices,
+             backend=backend,
+             axis_size=axis_size,
+             donate_argnums=donate_argnums,
+             global_arg_shapes=global_arg_shapes,
+         ),
+         rngs=rngs
+     )
+
+
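+ # Editorial sketch (illustrative, not part of the released file): a minimal
+ # pmap call over all local XLA devices:
+ #
+ #   >>> import jax, jax.numpy as jnp
+ #   >>> n = jax.local_device_count()
+ #   >>> y = pmap(lambda x: x ** 2)(jnp.arange(n, dtype=jnp.float32))
+ #   >>> y.shape == (n,)
+ #   True
+
+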
+ def _batch_and_remainder(x, batch_size: int):
+     leaves, tree_def = jax.tree.flatten(x)
+
+     scan_leaves = []
+     remainder_leaves = []
+
+     length = None
+     for leaf in leaves:
+         if length is None:
+             length = leaf.shape[0]
+         if length != leaf.shape[0]:
+             raise ValueError(f"All inputs must have the same length. Got {length} and {leaf.shape[0]}.")
+
+     num_batches, num_remainder = divmod(length, batch_size)
+     total_batch_elems = num_batches * batch_size
+     for leaf in leaves:
+         scan_leaves.append(leaf[:total_batch_elems].reshape(num_batches, batch_size, *leaf.shape[1:]))
+         if num_remainder:
+             remainder_leaves.append(leaf[total_batch_elems:])
+
+     scan_tree = tree_def.unflatten(scan_leaves)
+     if num_remainder:
+         remainder_tree = tree_def.unflatten(remainder_leaves)
+         return scan_tree, remainder_tree
+     else:
+         return scan_tree, None
+
+
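+ # Editorial sketch (illustrative only): splitting a length-10 input into
+ # batches of 4 plus a remainder of 2:
+ #
+ #   >>> import jax.numpy as jnp
+ #   >>> scan_part, remainder = _batch_and_remainder(jnp.arange(10), 4)
+ #   >>> scan_part.shape, remainder.shape
+ #   ((2, 4), (2,))
+
+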
+ def map(
+     f,
+     *xs,
+     batch_size: int | None = None,
+ ):
+     """
+     Map a function over leading array axes.
+
+     Like Python's builtin map, except inputs and outputs are in the form of
+     stacked arrays. Consider using the :func:`~jax.vmap` transform instead, unless you
+     need to apply a function element by element for reduced memory usage or
+     heterogeneous computation with other control flow primitives.
+
+     When ``xs`` is an array type, the semantics of :func:`~map` are given by this
+     Python implementation::
+
+         def map(f, *xs):
+             return np.stack([f(*x) for x in zip(*xs)])
+
+     Like :func:`~scan`, :func:`~map` is implemented in terms of JAX primitives so
+     many of the same advantages over a Python loop apply: ``xs`` may be an
+     arbitrary nested pytree type, and the mapped computation is compiled only
+     once.
+
+     If ``batch_size`` is provided, the computation is executed in batches of that size
+     and parallelized using :func:`~jax.vmap`. This can be used as either a more performant
+     version of ``map`` or as a memory-efficient version of ``vmap``. If the axis is not
+     divisible by the batch size, the remainder is processed in a separate ``vmap`` and
+     concatenated to the result.
+
+     >>> import jax.numpy as jnp
+     >>> x = jnp.ones((10, 3, 4))
+     >>> def f(x):
+     ...   print('inner shape:', x.shape)
+     ...   return x + 1
+     >>> y = map(f, x, batch_size=3)
+     inner shape: (3, 4)
+     inner shape: (3, 4)
+     >>> y.shape
+     (10, 3, 4)
+
+     In the example above, "inner shape" is printed twice: once while tracing the batched
+     computation and once while tracing the remainder computation.
+
+     Args:
+         f: a Python function to apply element-wise over the first axis or axes of
+             ``xs``.
+         xs: values over which to map along the leading axis.
+         batch_size: (optional) integer specifying the size of the batch for each step to execute
+             in parallel.
+
+     Returns:
+         Mapped values.
+     """
+     if batch_size is not None:
+         scan_xs, remainder_xs = _batch_and_remainder(xs, batch_size)
+         g = lambda _, x: ((), vmap(f)(*x))
+         _, scan_ys = scan(g, (), scan_xs)
+         if remainder_xs is None:
+             ys = jax.tree.map(_flatten, scan_ys)
+         else:
+             remainder_ys = vmap(f)(*remainder_xs)
+             ys = jax.tree.map(
+                 lambda x, y: jax.lax.concatenate([_flatten(x), y], dimension=0),
+                 scan_ys,
+                 remainder_ys,
+             )
+     else:
+         g = lambda _, x: ((), f(*x))
+         _, ys = scan(g, (), xs)
+     return ys
+
+
+ def _flatten(x):
+     # Merge the leading (num_batches, batch_size) axes produced by the scan.
+     return x.reshape(-1, *x.shape[2:])
+
+
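+ # Editorial sketch (illustrative only): ``_flatten`` merges the
+ # (num_batches, batch_size) axes produced by the scan:
+ #
+ #   >>> import jax.numpy as jnp
+ #   >>> _flatten(jnp.zeros((2, 4, 3))).shape
+ #   (8, 3)
+
+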
+ def _vmap_new_states_transform(
+     fun: Callable[..., Any],
+     *,
+     # -- normal jax.vmap arguments -- #
+     in_axes: int | None | Sequence[Any] = 0,
+     out_axes: Any = 0,
+     axis_name: AxisName | None = None,
+     axis_size: int | None = None,
+     spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
+     # -- brainstate specific arguments -- #
+     state_tag: str | None = None,
+     state_to_exclude: Filter | None = None,
+     in_states: Dict[int, Dict] | Any | None = None,
+     out_states: Dict[int, Dict] | Any | None = None,
+ ):
+     # TODO: how should nested calls to ``vmap_new_states`` be handled?
+     if isinstance(axis_size, int) and axis_size <= 0:
+         raise ValueError(f"axis_size must be greater than 0, got {axis_size}.")
+
+     @vmap(
+         in_axes=in_axes,
+         out_axes=out_axes,
+         axis_name=axis_name,
+         axis_size=axis_size,
+         spmd_axis_name=spmd_axis_name,
+         in_states=in_states,
+         out_states=out_states,
+     )
+     def new_fun(args):
+         # call the function
+         with catch_new_states(state_tag=state_tag, state_to_exclude=state_to_exclude) as catcher:
+             out = fun(*args)
+
+         # get vmap state values
+         vmap_state_vals = catcher.get_state_values()
+
+         return out, vmap_state_vals
+
+     @functools.wraps(fun)
+     def vmapped_fn(*args):
+         # vmapping
+         with catch_new_states(state_to_exclude=state_to_exclude) as catcher:
+             outs, vmap_state_vals = new_fun(args)
+             vmap_states = catcher.get_states()
+
+         # restore vmapped state values
+         for st_val, st in zip(vmap_state_vals, vmap_states):
+             st.restore_value(st_val)
+             # ------------------------------------------------
+             # --- this is CRUCIAL to avoid jax tracing leakage
+             # ------------------------------------------------
+             st.decrease_stack_level()
+         return outs
+
+     return vmapped_fn
+
+
+ def vmap_new_states(
+     fun: Callable | Missing = Missing(),
+     *,
+     # -- normal jax.vmap arguments -- #
+     in_axes: int | None | Sequence[Any] = 0,
+     out_axes: Any = 0,
+     axis_name: AxisName | None = None,
+     axis_size: int | None = None,
+     spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
+     # -- brainstate specific arguments -- #
+     state_tag: str | None = None,
+     state_to_exclude: Filter = None,
+     in_states: Dict[int, Dict] | Any | None = None,
+     out_states: Dict[int, Dict] | Any | None = None,
+ ):
+     """
+     Vectorize a function over new states created within it.
+
+     This function applies JAX's vmap transformation to newly created states
+     during the function's execution. It allows for more flexible vectorization
+     in the context of stateful computations.
+
+     Args:
+         fun (Callable, optional): The function to be vectorized. Defaults to Missing().
+         in_axes (int | None | Sequence[Any], optional): Specification of input axes for vectorization. Defaults to 0.
+         out_axes (Any, optional): Specification of output axes after vectorization. Defaults to 0.
+         axis_name (AxisName, optional): Name of the axis being vectorized over. Defaults to None.
+         axis_size (int, optional): Size of the axis being vectorized over. Defaults to None.
+         spmd_axis_name (AxisName | tuple[AxisName, ...], optional): Name(s) of SPMD axis/axes. Defaults to None.
+         state_tag (str, optional): A tag to identify specific states. Defaults to None.
+         state_to_exclude (Filter, optional): A filter selecting states to exclude from vectorization. Defaults to None.
+         in_states (Dict[int, Dict] | Any, optional): The :class:`State` objects to be mapped over in the inputs. Defaults to None.
+         out_states (Dict[int, Dict] | Any, optional): The :class:`State` objects to be mapped over in the outputs. Defaults to None.
+
+     Returns:
+         Callable: A vectorized version of the input function that handles new state creation.
+     """
+     if isinstance(fun, Missing):
+         return functools.partial(
+             _vmap_new_states_transform,
+             in_axes=in_axes,
+             out_axes=out_axes,
+             axis_name=axis_name,
+             axis_size=axis_size,
+             spmd_axis_name=spmd_axis_name,
+             state_tag=state_tag,
+             state_to_exclude=state_to_exclude,
+             in_states=in_states,
+             out_states=out_states,
+         )
+     else:
+         return _vmap_new_states_transform(
+             fun,
+             in_axes=in_axes,
+             out_axes=out_axes,
+             axis_name=axis_name,
+             axis_size=axis_size,
+             spmd_axis_name=spmd_axis_name,
+             state_tag=state_tag,
+             state_to_exclude=state_to_exclude,
+             in_states=in_states,
+             out_states=out_states,
+         )
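+
+
+ # Editorial sketch (hypothetical usage, not part of the released file):
+ # vectorizing a function that creates its states internally, assuming the
+ # public ``brainstate.ParamState`` and ``brainstate.random.randn`` APIs:
+ #
+ #   >>> import brainstate
+ #   >>> @brainstate.augment.vmap_new_states(axis_size=10)
+ #   ... def make_weights():
+ #   ...     w = brainstate.ParamState(brainstate.random.randn(3))
+ #   ...     return w.value
+ #   >>> make_weights().shape   # each of the 10 instances gets its own weights
+ #   (10, 3)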