brainstate 0.1.0.post20250212__py2.py3-none-any.whl → 0.1.0.post20250217__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97):
  1. brainstate/_state.py +853 -90
  2. brainstate/_state_test.py +1 -3
  3. brainstate/augment/__init__.py +2 -2
  4. brainstate/augment/_autograd.py +257 -115
  5. brainstate/augment/_autograd_test.py +2 -3
  6. brainstate/augment/_eval_shape.py +3 -4
  7. brainstate/augment/_mapping.py +582 -62
  8. brainstate/augment/_mapping_test.py +114 -30
  9. brainstate/augment/_random.py +61 -7
  10. brainstate/compile/_ad_checkpoint.py +2 -3
  11. brainstate/compile/_conditions.py +4 -5
  12. brainstate/compile/_conditions_test.py +1 -2
  13. brainstate/compile/_error_if.py +1 -2
  14. brainstate/compile/_error_if_test.py +1 -2
  15. brainstate/compile/_jit.py +23 -16
  16. brainstate/compile/_jit_test.py +1 -2
  17. brainstate/compile/_loop_collect_return.py +18 -10
  18. brainstate/compile/_loop_collect_return_test.py +1 -1
  19. brainstate/compile/_loop_no_collection.py +5 -5
  20. brainstate/compile/_make_jaxpr.py +23 -21
  21. brainstate/compile/_make_jaxpr_test.py +1 -2
  22. brainstate/compile/_progress_bar.py +1 -2
  23. brainstate/compile/_unvmap.py +1 -0
  24. brainstate/compile/_util.py +4 -2
  25. brainstate/environ.py +4 -4
  26. brainstate/environ_test.py +1 -2
  27. brainstate/functional/_activations.py +1 -2
  28. brainstate/functional/_activations_test.py +1 -1
  29. brainstate/functional/_normalization.py +1 -2
  30. brainstate/functional/_others.py +1 -2
  31. brainstate/functional/_spikes.py +136 -20
  32. brainstate/graph/_graph_node.py +2 -43
  33. brainstate/graph/_graph_operation.py +4 -20
  34. brainstate/graph/_graph_operation_test.py +3 -4
  35. brainstate/init/_base.py +1 -2
  36. brainstate/init/_generic.py +1 -2
  37. brainstate/nn/__init__.py +8 -0
  38. brainstate/nn/_collective_ops.py +351 -48
  39. brainstate/nn/_collective_ops_test.py +36 -0
  40. brainstate/nn/_common.py +193 -0
  41. brainstate/nn/_dyn_impl/_dynamics_neuron.py +1 -2
  42. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +1 -2
  43. brainstate/nn/_dyn_impl/_dynamics_synapse.py +1 -2
  44. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +1 -2
  45. brainstate/nn/_dyn_impl/_inputs.py +1 -2
  46. brainstate/nn/_dyn_impl/_rate_rnns.py +1 -2
  47. brainstate/nn/_dyn_impl/_rate_rnns_test.py +1 -2
  48. brainstate/nn/_dyn_impl/_readout.py +2 -3
  49. brainstate/nn/_dyn_impl/_readout_test.py +1 -2
  50. brainstate/nn/_dynamics/_dynamics_base.py +6 -1
  51. brainstate/nn/_dynamics/_dynamics_base_test.py +1 -2
  52. brainstate/nn/_dynamics/_state_delay.py +3 -3
  53. brainstate/nn/_dynamics/_synouts_test.py +1 -2
  54. brainstate/nn/_elementwise/_dropout.py +6 -7
  55. brainstate/nn/_elementwise/_dropout_test.py +1 -2
  56. brainstate/nn/_elementwise/_elementwise.py +1 -2
  57. brainstate/nn/_exp_euler.py +1 -2
  58. brainstate/nn/_exp_euler_test.py +1 -2
  59. brainstate/nn/_interaction/_conv.py +1 -2
  60. brainstate/nn/_interaction/_conv_test.py +1 -0
  61. brainstate/nn/_interaction/_linear.py +1 -2
  62. brainstate/nn/_interaction/_linear_test.py +1 -2
  63. brainstate/nn/_interaction/_normalizations.py +1 -2
  64. brainstate/nn/_interaction/_poolings.py +3 -4
  65. brainstate/nn/_module.py +68 -19
  66. brainstate/nn/_module_test.py +1 -2
  67. brainstate/nn/_utils.py +89 -0
  68. brainstate/nn/metrics.py +3 -4
  69. brainstate/optim/_lr_scheduler.py +1 -2
  70. brainstate/optim/_lr_scheduler_test.py +2 -3
  71. brainstate/optim/_optax_optimizer_test.py +1 -2
  72. brainstate/optim/_sgd_optimizer.py +2 -3
  73. brainstate/random/_rand_funs.py +1 -2
  74. brainstate/random/_rand_funs_test.py +2 -3
  75. brainstate/random/_rand_seed.py +2 -3
  76. brainstate/random/_rand_seed_test.py +1 -2
  77. brainstate/random/_rand_state.py +3 -4
  78. brainstate/surrogate.py +5 -5
  79. brainstate/transform.py +0 -3
  80. brainstate/typing.py +28 -25
  81. brainstate/util/__init__.py +9 -7
  82. brainstate/util/_caller.py +1 -2
  83. brainstate/util/_error.py +27 -0
  84. brainstate/util/_others.py +60 -15
  85. brainstate/util/{_dict.py → _pretty_pytree.py} +2 -2
  86. brainstate/util/{_dict_test.py → _pretty_pytree_test.py} +1 -2
  87. brainstate/util/_pretty_repr.py +1 -2
  88. brainstate/util/_pretty_table.py +2900 -0
  89. brainstate/util/_struct.py +11 -11
  90. brainstate/util/filter.py +472 -0
  91. {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250217.dist-info}/METADATA +2 -2
  92. brainstate-0.1.0.post20250217.dist-info/RECORD +128 -0
  93. brainstate/util/_filter.py +0 -178
  94. brainstate-0.1.0.post20250212.dist-info/RECORD +0 -124
  95. {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250217.dist-info}/LICENSE +0 -0
  96. {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250217.dist-info}/WHEEL +0 -0
  97. {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250217.dist-info}/top_level.txt +0 -0
@@ -16,15 +16,26 @@
16
16
  from __future__ import annotations
17
17
 
18
18
  import functools
19
- from typing import Any, TypeVar, Callable, Hashable, Sequence, Iterable, Tuple, Union, Optional, Dict, List
20
-
21
19
  import jax
22
20
  from jax.interpreters.batching import BatchTracer
23
-
24
- from brainstate._state import State, StateTraceStack
25
- from brainstate.compile._loop_collect_return import scan
26
- from brainstate.random import DEFAULT, RandomState
27
- from brainstate.typing import Missing
21
+ from typing import (
22
+ Any,
23
+ TypeVar,
24
+ Callable,
25
+ Hashable,
26
+ Sequence,
27
+ Iterable,
28
+ Tuple,
29
+ Union,
30
+ Optional,
31
+ Dict,
32
+ List
33
+ )
34
+
35
+ from brainstate._state import State, catch_new_states
36
+ from brainstate.compile import scan, StatefulFunction
37
+ from brainstate.random import RandomState, DEFAULT
38
+ from brainstate.typing import Missing, Filter
28
39
  from brainstate.util import NestedDict, BrainStateError
29
40
  from ._random import restore_rngs
30
41
 
@@ -32,21 +43,56 @@ __all__ = [
32
43
  'vmap',
33
44
  'pmap',
34
45
  'map',
46
+ 'vmap_new_states',
35
47
  ]
36
48
 
37
- AxisName = Hashable
38
49
# Type of the user function being transformed; ``vmap`` hands back the same type.
F = TypeVar("F", bound=Callable)
# Any hashable value naming a mapped axis (forwarded to ``jax.vmap``'s ``axis_name``).
AxisName = Hashable
# Mapping from a batch axis to the list of states mapped along that axis.
AxisToState = Dict[int, List[State]]
# Inverse mapping: each state to the batch axis it is mapped along.
StateToAxis = Dict[State, int]
41
53
 
42
54
 
43
55
class BatchAxisError(BrainStateError):
    """
    Error signalling that a state's batch (vmap) axis is used inconsistently.

    The vectorized-mapping utilities raise it, for example, when one state is
    assigned different axes in ``in_states`` and ``out_states``, or when a
    state's value ends up batched although the state is not listed in
    ``out_states``.

    Inherits from:
        BrainStateError: The base error class for BrainState-related exceptions.
    """
45
67
 
46
68
 
47
69
  def _flatten_in_out_states(
48
70
  in_states: Dict[int, Dict] | Any = None,
49
71
  ) -> Tuple[AxisToState, StateToAxis]:
72
+ """
73
+ Flattens and organizes input or output states into axis-based mappings.
74
+
75
+ This function processes the input or output states, converting them into two
76
+ dictionary representations: one mapping axes to states, and another mapping
77
+ states to axes. It handles both structured (Dict[int, Dict]) and unstructured
78
+ input formats.
79
+
80
+ Args:
81
+ in_states (Dict[int, Dict] | Any, optional): The input or output states to be
82
+ flattened. Can be a nested dictionary structure where the outer keys are
83
+ axes and inner dictionaries contain states, or any other structure
84
+ containing states. Defaults to None.
85
+
86
+ Returns:
87
+ Tuple[AxisToState, StateToAxis]: A tuple containing two dictionaries:
88
+ - AxisToState: Maps axes (int) to lists of states.
89
+ - StateToAxis: Maps individual states to their corresponding axes (int).
90
+
91
+ Note:
92
+ If in_states is None, empty dictionaries are returned for both mappings.
93
+ If in_states is not in the expected Dict[int, Dict] format, all states are
94
+ assigned to axis 0.
95
+ """
50
96
  if in_states is None:
51
97
  return dict(), dict()
52
98
  if isinstance(in_states, dict):
@@ -72,16 +118,165 @@ def _flatten_in_out_states(
72
118
  return axis_to_states, state_to_axis
73
119
 
74
120
 
75
- def _vmap_transform(
76
- f: F,
77
- *,
78
- in_axes: int | None | Sequence[Any] = 0,
79
- out_axes: Any = 0,
80
- in_states: Dict[int, Dict] | Any | None = None,
81
- out_states: Dict[int, Dict] | Any | None = None,
82
- rngs: Union[RandomState, Sequence[RandomState]] = DEFAULT,
83
- **transform_kwargs,
121
+ def _remove_axis(x, axis: int):
122
+ """
123
+ Remove a specified axis from an array or nested structure.
124
+
125
+ This function removes a specified axis from an array or nested structure,
126
+ adjusting the shape and structure of the output accordingly.
127
+
128
+ Args:
129
+ x (Any): The input array or nested structure to remove the axis from.
130
+ axis (int): The axis to remove from the input.
131
+
132
+ Returns:
133
+ Any: The output array or nested structure with the specified axis removed.
134
+ """
135
+ assert isinstance(axis, int), f"Expected axis to be an integer, but got {type(axis)}"
136
+ if axis < 0:
137
+ axis += x.ndim
138
+ if axis < 0 or axis >= x.ndim:
139
+ raise IndexError(f"Axis {axis} is out of bounds for array of shape {x.shape}")
140
+ return x[tuple(slice(None, None, None) if i != axis else 0 for i in range(x.ndim))]
141
+
142
+
143
def _compile_stateful_function(
    stateful_fn: StatefulFunction,
    in_axes: int | Tuple[int, ...],
    args: Tuple
):
    """
    Trace a stateful function on one un-batched slice of its inputs.

    Every mapped input (state values and positional arguments) is reduced to
    index 0 along its mapped axis. Then ``stateful_fn.make_jaxpr`` is called on
    the sliced inputs, so the function is traced with per-example shapes.

    Args:
        stateful_fn (StatefulFunction): The stateful function to trace.
        in_axes: A pair ``(state_in_axes, arg_in_axes)``.

            - ``state_in_axes`` holds one axis per group of state values. It is
              only read when there is at least one group.
            - ``arg_in_axes`` is an ``int`` (applied to every leaf of every
              argument), ``None`` (nothing mapped), or a tuple with one entry
              per positional argument.

            Each tuple entry may be ``None`` (argument not mapped), an ``int``
            (applied to every leaf of that argument), or a pytree prefix of the
            argument whose leaves are ints or ``None``.
        args: A pair ``(state_vals, args)``. ``state_vals`` is a sequence of
            per-axis groups of state values; ``args`` is the tuple of
            positional arguments.

    Raises:
        ValueError: If ``arg_in_axes`` is a tuple whose length differs from the
            number of positional arguments.

    Returns:
        The key returned by ``stateful_fn.get_arg_cache_key`` for the sliced
        inputs, which identifies the trace that was just built.
    """
    in_axes_st, in_axes = in_axes
    state_vals, args = args

    # check in_axes
    if isinstance(in_axes, tuple) and len(in_axes) != len(args):
        raise ValueError(
            "vmap in_axes must be an int, None, or a tuple of entries corresponding "
            "to the positional arguments passed to the function, "
            f"but got {len(in_axes)=}, {len(args)=}"
        )

    def _unbatch(tree, axis_spec):
        # Take index 0 along the mapped axes of every leaf of ``tree``.
        if axis_spec is None:
            # Not mapped: keep the value untouched.
            return tree
        if isinstance(axis_spec, int):
            # A single axis shared by every leaf of ``tree``.
            return jax.tree.map(lambda x: _remove_axis(x, axis_spec), tree)

        # ``axis_spec`` is a pytree prefix of ``tree``; ``None`` must be kept
        # as a leaf (jax otherwise treats it as an empty subtree).
        def _per_leaf(ax, sub):
            if ax is None:
                return sub
            return jax.tree.map(lambda x: _remove_axis(x, ax), sub)

        return jax.tree.map(_per_leaf, axis_spec, tree, is_leaf=lambda x: x is None)

    # check state_vals; ``in_axes_st`` is only zipped when there are groups,
    # since it need not be iterable otherwise.
    if len(state_vals) > 0:
        state_vals = [_unbatch(vals, axis) for vals, axis in zip(state_vals, in_axes_st)]
    else:
        state_vals = []

    if isinstance(in_axes, int):
        args = _unbatch(args, in_axes)
    elif isinstance(in_axes, tuple):
        args = tuple(_unbatch(arg, in_axis) for arg, in_axis in zip(args, in_axes))
    stateful_fn.make_jaxpr(state_vals, args)
    return stateful_fn.get_arg_cache_key(state_vals, args)
194
+
195
+
196
+ def _get_batch_size(
197
+ args: Tuple,
198
+ in_axes: int | Tuple[int, ...],
199
+ in_states: AxisToState,
200
+ axis_size: Optional[int] = None,
201
+ ) -> int:
202
+ """
203
+ Determine the batch size from input arguments, axes, and states.
204
+
205
+ This function calculates the batch size by examining the shapes of input arguments
206
+ and states along specified axes. It ensures consistency across all inputs.
207
+
208
+ Args:
209
+ args (Tuple): The input arguments to the function being vectorized.
210
+ in_axes (int | Tuple[int, ...]): The axes along which to vectorize for each argument.
211
+ Can be a single integer (same for all args) or a tuple of integers.
212
+ in_states (AxisToState): A dictionary mapping axes to lists of states.
213
+
214
+ Returns:
215
+ int: The determined batch size.
216
+
217
+ Raises:
218
+ ValueError: If unable to determine batch size or if inconsistent batch sizes are found.
219
+ """
220
+ batch_sizes = []
221
+
222
+ # Check batch size from args and in_axes
223
+ if isinstance(in_axes, int):
224
+ in_axes = (in_axes,) * len(args)
225
+ for arg, in_axis in zip(args, in_axes):
226
+ if in_axis is not None:
227
+ arg_leaves = jax.tree.leaves(arg)
228
+ if arg_leaves:
229
+ batch_sizes.append(arg_leaves[0].shape[in_axis])
230
+
231
+ # Check batch size from in_states
232
+ if in_states is not None:
233
+ for axis, states in in_states.items():
234
+ for state in states:
235
+ state_leaves = jax.tree.leaves(state.value)
236
+ if len(state_leaves):
237
+ batch_sizes.append(state_leaves[0].shape[axis])
238
+
239
+ if len(batch_sizes) == 0:
240
+ assert axis_size is not None, (
241
+ "Unable to determine batch size. Please provide the 'axis_size' argument."
242
+ )
243
+ return axis_size
244
+ else:
245
+ # Ensure all batch sizes are consistent
246
+ if len(set(batch_sizes)) > 1:
247
+ raise ValueError(f"Inconsistent batch sizes found: {set(batch_sizes)}")
248
+
249
+ return batch_sizes[0]
250
+
251
+
252
+ def _format_state_axes(
253
+ in_states,
254
+ out_states,
84
255
  ):
256
+ """
257
+ Format and validate the axes of input and output states.
258
+
259
+ This function processes the input and output states, ensuring consistency
260
+ between their axis mappings. It also handles cases where a state appears
261
+ in the input but not in the output.
262
+
263
+ Args:
264
+ in_states: The input states to be formatted. Can be a dictionary mapping
265
+ axes to states, or any other structure containing states.
266
+ out_states: The output states to be formatted. Can be a dictionary mapping
267
+ axes to states, or any other structure containing states.
268
+
269
+ Returns:
270
+ A tuple containing four elements:
271
+ - axis_to_in_states (dict): Mapping of axes to input states.
272
+ - in_state_to_axis (dict): Mapping of input states to their axes.
273
+ - axis_to_out_states (dict): Mapping of axes to output states.
274
+ - out_state_to_axis (dict): Mapping of output states to their axes.
275
+
276
+ Raises:
277
+ BatchAxisError: If there's an inconsistency between the axis mappings
278
+ of input and output states.
279
+ """
85
280
  axis_to_in_states, in_state_to_axis = _flatten_in_out_states(in_states)
86
281
  axis_to_out_states, out_state_to_axis = _flatten_in_out_states(out_states)
87
282
  for _in_state, _axis in in_state_to_axis.items():
@@ -90,8 +285,8 @@ def _vmap_transform(
90
285
  if _out_axis != _axis:
91
286
  _in_state.raise_error_with_source_info(
92
287
  BatchAxisError(
93
- f"State {_in_state} has been mapped to axis {_axis} in in_states, "
94
- f"However, it is mapped to axis {_out_axis} in out_states."
288
+ f"State {_in_state} has been mapped to axis {_axis} in 'in_states', "
289
+ f"However, it is mapped to axis {_out_axis} in 'out_states'."
95
290
  )
96
291
  )
97
292
  else:
@@ -99,65 +294,286 @@ def _vmap_transform(
99
294
  if _axis not in axis_to_out_states:
100
295
  axis_to_out_states[_axis] = []
101
296
  axis_to_out_states[_axis].append(_in_state)
102
- if isinstance(rngs, RandomState):
103
- rngs = (rngs,)
104
- rng_ids = set([id(rng) for rng in rngs])
105
297
 
106
- @functools.wraps(f)
107
- def new_fn(in_states_, args):
298
+ return axis_to_in_states, in_state_to_axis, axis_to_out_states, out_state_to_axis
299
+
300
+
301
def _vmap_transform(
    f: F,
    *,
    in_axes: int | None | Sequence[Any] = 0,
    out_axes: Any = 0,
    in_states: Dict[int, Dict] | Any | None = None,
    out_states: Dict[int, Dict] | Any | None = None,
    axis_size: Optional[int] = None,
    axis_name: AxisName | None = None,
    spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
):
    """
    Transforms a function for vectorized mapping (vmap) with state management.

    The returned function first traces ``f`` on an un-batched slice of its
    inputs (via ``_compile_stateful_function``) to discover every state it
    touches. It then evaluates the traced program under ``jax.vmap``. States
    are handled in three groups:

    - states listed in ``in_states`` / ``out_states`` are mapped along their
      axes;
    - ``RandomState`` instances are given one split key per batch element and
      are reset to a fresh key afterwards;
    - all other states are passed in and out unmapped (axis ``None``).

    Args:
        f (F): The function to be transformed for vectorized mapping.
        in_axes (int | None | Sequence[Any]): Specifies which axes of the input
            arguments to map over. Lists are converted to tuples. Default is 0.
        out_axes (Any): Specifies where the mapped axis should appear in the
            output. Default is 0.
        in_states (Dict[int, Dict] | Any | None): Specifies the input states and
            their corresponding axes for mapping. Default is None.
        out_states (Dict[int, Dict] | Any | None): Specifies the output states
            and their corresponding axes for mapping. Default is None.
        axis_size (Optional[int]): Forwarded to ``jax.vmap``. When random
            states are used and no batch size can be read from the mapped
            arguments or ``in_states``, it is also the number of random keys
            that are split.
        axis_name (AxisName | None): Forwarded to ``jax.vmap``.
        spmd_axis_name (AxisName | tuple[AxisName, ...] | None): Forwarded to
            ``jax.vmap``.

    Returns:
        Callable: A new function that applies vectorized mapping to the input
        function while managing states.

    Raises:
        BatchAxisError: Raised while formatting the state axes if a state is
            mapped to different axes in ``in_states`` and ``out_states``.
    """

    # TODO: support jax.disable_jit()

    # format state axes
    (
        axis_to_in_states,
        in_state_to_axis,
        axis_to_out_states,
        out_state_to_axis
    ) = _format_state_axes(in_states, out_states)

    # check in_axes
    if isinstance(in_axes, list):
        # To be a tree prefix of the positional args tuple, in_axes can never be a
        # list: if in_axes is not a leaf, it must be a tuple of trees. However,
        # in cases like these users expect tuples and lists to be treated
        # essentially interchangeably, so we canonicalize lists to tuples here
        # rather than raising an error. https://github.com/jax-ml/jax/issues/2367
        in_axes = tuple(in_axes)

    def _vmap_fn_for_compilation(in_vmap_state_vals, args):
        """
        Function traced by ``stateful_fn``: restore mapped states, then call ``f``.

        Args:
            in_vmap_state_vals (List[List]): A nested list of state values. The
                outer list follows the order of ``axis_to_in_states``; each inner
                list holds the values of the states mapped along that axis.
            args (Tuple): The positional arguments passed to ``f``.

        Returns:
            Any: The result of ``f(*args)``.
        """
        # restore state values
        for i, states in enumerate(axis_to_in_states.values()):
            for state, state_val in zip(states, in_vmap_state_vals[i]):
                state.restore_value(state_val)

        # call the function
        return f(*args)

    # stateful function
    stateful_fn = StatefulFunction(_vmap_fn_for_compilation, name='vmap')

    @functools.wraps(f)
    def new_fn_for_vmap(
        rng_keys,
        in_state_vmap_vals,
        in_state_oth_vals,
        args,
    ):
        """
        Per-example body handed to ``jax.vmap``.

        It writes the (per-example) values of all states touched by ``f`` back
        into those states, evaluates the traced program, and returns the new
        state values alongside the outputs.

        Args:
            rng_keys (Sequence): One key per ``RandomState`` found in the trace.
            in_state_vmap_vals (Sequence[Sequence]): Values of the mapped input
                states, grouped by axis in ``axis_to_in_states`` order.
            in_state_oth_vals (Sequence): Values of the remaining (unmapped,
                non-random) states of the trace.
            args (Tuple): Arguments to be passed to the original function.

        Returns:
            Tuple: A tuple containing four elements:
                - out_rng_keys (List): The random states' values after the call.
                - out_state_vmap_vals (List[List]): Output state values grouped
                  by axis in ``axis_to_out_states`` order.
                - out_state_oth_vals (List): Values of the remaining states.
                - outs: The output of the original function call.

        Raises:
            AssertionError: If there's a mismatch in the number of states,
                state values, or RNG keys.
            BatchAxisError: If a written state value is batched but the state is
                not included in ``out_states``.
        """
        # restore vmapping state values
        for i, states in enumerate(axis_to_in_states.values()):
            assert len(states) == len(in_state_vmap_vals[i]), (
                f"The number of states in axis {i} should be equal to the number "
                f"of state values, but got {len(states)} and {len(in_state_vmap_vals[i])}."
            )
            for state, state_val in zip(states, in_state_vmap_vals[i]):
                state.restore_value(state_val)

        # restore rngs
        # Look up the trace built by ``_compile_stateful_function`` for these inputs.
        cache_key = stateful_fn.get_arg_cache_key(in_state_vmap_vals, args)
        state_trace = stateful_fn.get_state_trace(cache_key)
        rngs = state_trace.state_subset(RandomState)
        rng_sets = set(rngs)
        assert len(rngs) == len(rng_keys), (
            f"The number of random states in the function should be equal to the number "
            f"of random keys, but got {len(rngs)} and {len(rng_keys)}."
        )
        for rng, key in zip(rngs, rng_keys):
            rng.restore_value(key)

        # restore other state values
        oth_in_state = [
            st for st in state_trace.states
            if st not in in_state_to_axis and st not in rng_sets
        ]
        assert len(oth_in_state) == len(in_state_oth_vals), (
            f"The number of states in 'in_states' should be equal to the number "
            f"of state values, but got {len(oth_in_state)} and {len(in_state_oth_vals)}."
        )
        for state, state_val in zip(oth_in_state, in_state_oth_vals):
            state.restore_value(state_val)

        # call the function
        outs = stateful_fn.jaxpr_call_auto(in_state_vmap_vals, args)

        # analyze vmapping axis error
        for state in state_trace.get_write_states():
            leaves = jax.tree.leaves(state.value)
            if any([isinstance(leaf, BatchTracer) for leaf in leaves]) and state not in out_state_to_axis:
                # Random states are expected to be batched (one key per
                # example); their values are reset by the caller afterwards.
                if isinstance(state, RandomState) and state in rng_sets:
                    continue
                state.raise_error_with_source_info(
                    BatchAxisError(f"The value of State {state} is batched, "
                                   f"but it is not in the out_states.")
                )

        # out state values for vmapping
        out_state_vmap_vals = [
            [state.value for state in states]
            for axis, states in axis_to_out_states.items()
        ]
        out_state_oth_vals = [
            st.value for st in state_trace.states
            if st not in out_state_to_axis and st not in rng_sets
        ]
        out_rng_keys = [rng.value for rng in rngs]
        return out_rng_keys, out_state_vmap_vals, out_state_oth_vals, outs

    @functools.wraps(f)
    def vmapped_fn(*args):
        """
        Applies vectorized mapping (vmap) to the input function while managing state.

        Steps:
            1. Trace ``f`` on un-batched inputs.
            2. Split one random key per batch element for every ``RandomState``
               in the trace.
            3. Run ``new_fn_for_vmap`` under ``jax.vmap``.
            4. Write the resulting values back into the states.

        Args:
            *args: The positional arguments of the vectorized function.

        Returns:
            Any: The (batched) outputs of the function.

        Note:
            ``axis_to_in_states``, ``stateful_fn`` and the other free names used
            here are closure variables of the enclosing ``_vmap_transform`` call.
        """
        # in states values
        in_state_map_vals = [
            [st.value for st in states]
            for axis, states in axis_to_in_states.items()
        ]
        st_in_axes = list(axis_to_in_states.keys())
        # With no mapped states, the (empty) state list gets a plain int axis.
        if len(st_in_axes) == 0:
            st_in_axes = 0

        # compile stateful function
        cache_key = _compile_stateful_function(
            stateful_fn,
            (st_in_axes, in_axes),
            (in_state_map_vals, args)
        )

        # random keys
        state_trace = stateful_fn.get_state_trace(cache_key)
        rngs = state_trace.state_subset(RandomState)
        rng_sets = set(rngs)
        if len(rngs):
            # batch size
            batch_size = _get_batch_size(args, in_axes, axis_to_in_states, axis_size)
            rng_keys = tuple(rng.split_key(batch_size) for rng in rngs)
            # The per-example keys returned by the vmapped call are discarded
            # below; instead, each RandomState is reset to this backup key.
            rng_backup = tuple(rng.split_key() for rng in rngs)
        else:
            rng_keys = tuple()
            rng_backup = tuple()

        # in states other values
        in_state_oth_vals = [
            st.value
            for st in state_trace.states
            if st not in in_state_to_axis and st not in rng_sets
        ]

        # out state axis
        st_out_axes = list(axis_to_out_states.keys())
        if len(st_out_axes) == 0:
            st_out_axes = 0

        # --- vmapping --- #
        # Argument/result order: (rng keys mapped on axis 0, mapped states,
        # other states unmapped, user args/outputs).
        fn = jax.vmap(
            new_fn_for_vmap,
            in_axes=(0, st_in_axes, None, in_axes),
            out_axes=(0, st_out_axes, None, out_axes),
            axis_size=axis_size,
            axis_name=axis_name,
            spmd_axis_name=spmd_axis_name,
        )
        _, out_state_map_vals, out_state_oth_vals, outs = fn(
            rng_keys, in_state_map_vals, in_state_oth_vals, args
        )

        # restore mapped state values
        for i, states in enumerate(axis_to_out_states.values()):
            assert len(states) == len(out_state_map_vals[i]), (
                f"The number of states in axis {i} should be equal to the number "
                f"of state values, but got {len(states)} and {len(out_state_map_vals[i])}."
            )
            for state, st_val in zip(states, out_state_map_vals[i]):
                state.restore_value(st_val)

        # restore other state values
        out_oth_states = [
            st for st in state_trace.states
            if st not in out_state_to_axis and st not in rng_sets
        ]
        assert len(out_oth_states) == len(out_state_oth_vals), (
            f"The number of states in 'out_states' should be equal to the number "
            f"of state values, but got {len(out_oth_states)} and {len(out_state_oth_vals)}."
        )
        for state, st_val in zip(out_oth_states, out_state_oth_vals):
            state.restore_value(st_val)

        # restore random keys
        for rng, key in zip(rngs, rng_backup):
            rng.restore_value(key)
        return outs

    return vmapped_fn
@@ -166,15 +582,15 @@ def _vmap_transform(
166
582
  def vmap(
167
583
  fn: F | Missing = Missing(),
168
584
  *,
585
+ # --- normal jax.vmap arguments --- #
169
586
  in_axes: int | None | Sequence[Any] = 0,
170
587
  out_axes: Any = 0,
171
588
  axis_name: AxisName | None = None,
172
589
  axis_size: int | None = None,
173
590
  spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
174
- # brainstate specific arguments
591
+ # --- brainstate specific arguments --- #
175
592
  in_states: Dict[int, Dict] | Any | None = None,
176
593
  out_states: Dict[int, Dict] | Any | None = None,
177
- rngs: Union[RandomState, Sequence[RandomState]] = DEFAULT,
178
594
  ) -> F | Callable[[F], F]:
179
595
  """
180
596
  Vectorizing map. Creates a function which maps ``fun`` over argument axes.
@@ -229,10 +645,6 @@ def vmap(
229
645
  corresponds to the outermost :func:`vmap` call, the second element to
230
646
  the next outermost, and so on. If the tuple is not provided, the
231
647
  ``axis_name`` is used for all nested :func:`vmap` calls.
232
- rngs: Optional, a random number generator or sequence of random number
233
- generators to be used in the mapped function. These random number
234
- generators are restored their random key after the mapped function is
235
- executed.
236
648
  in_states: Optional, the :class:`State` objects to be mapped over in the inputs.
237
649
  out_states: Optional, the :class:`State` objects to be mapped over in the outputs.
238
650
 
@@ -254,7 +666,6 @@ def vmap(
254
666
  axis_name=axis_name,
255
667
  axis_size=axis_size,
256
668
  spmd_axis_name=spmd_axis_name,
257
- rngs=rngs,
258
669
  ) # type: ignore[return-value]
259
670
 
260
671
  return _vmap_transform(
@@ -266,7 +677,6 @@ def vmap(
266
677
  axis_name=axis_name,
267
678
  axis_size=axis_size,
268
679
  spmd_axis_name=spmd_axis_name,
269
- rngs=rngs
270
680
  )
271
681
 
272
682
 
@@ -511,3 +921,113 @@ def map(
511
921
 
512
922
  def flatten_(x):
513
923
  return x.reshape(-1, *x.shape[2:])
924
+
925
+
926
def _vmap_new_states_transform(
    fun: Callable[..., Any],
    *,
    # -- normal jax.vmap arguments -- #
    in_axes: int | None | Sequence[Any] = 0,
    out_axes: Any = 0,
    axis_name: AxisName | None = None,
    axis_size: int | None = None,
    spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
    # -- brainstate specific arguments -- #
    state_tag: str | None = None,
    state_to_exclude: Filter | None = None,
):
    """
    Build the vectorized wrapper used by ``vmap_new_states``.

    ``fun`` is wrapped in ``vmap``:

    - Inside the mapped call, the states created by ``fun`` are collected with
      ``catch_new_states``, and their values are returned as an extra output
      so they come back batched.
    - Outside, the newly created states are collected again, and each one is
      restored with its batched value.

    Args:
        fun: The function to vectorize.
        in_axes: Input axes, applied to the positional arguments of ``fun``
            exactly as in ``jax.vmap`` (an int, ``None``, or one entry per
            positional argument).
        out_axes: Output axes, applied to the pair ``(fun output, new-state
            values)`` returned by the mapped function.
        axis_name: Forwarded to ``vmap``.
        axis_size: Forwarded to ``vmap``.
        spmd_axis_name: Forwarded to ``vmap``.
        state_tag: Forwarded as ``state_tag`` to the inner ``catch_new_states``.
        state_to_exclude: Forwarded as ``state_to_exclude`` to both
            ``catch_new_states`` calls.

    Returns:
        Callable: The vectorized function.
    """

    # TODO: How about nested call ``vmap_new_states``?

    @vmap(
        in_axes=in_axes,
        out_axes=out_axes,
        axis_name=axis_name,
        axis_size=axis_size,
        spmd_axis_name=spmd_axis_name,
    )
    def new_fun(*args):
        # Positional arguments are forwarded one-to-one, so a tuple ``in_axes``
        # lines up with the parameters of ``fun`` (packing them into a single
        # tuple argument would make ``len(in_axes)`` mismatch).
        with catch_new_states(state_tag=state_tag, state_to_exclude=state_to_exclude) as catcher:
            out = fun(*args)

        # get vmap state values
        vmap_state_vals = catcher.get_state_values()

        # NOTE(review): ``out_axes`` applies to ``(out, vmap_state_vals)``; an
        # int value maps the new-state values along the same axis as ``out``.
        return out, vmap_state_vals

    @functools.wraps(fun)
    def vmapped_fn(*args):
        # vmapping
        with catch_new_states(state_to_exclude=state_to_exclude) as catcher:
            outs, vmap_state_vals = new_fun(*args)
            vmap_states = catcher.get_states()

        # restore vmapped state values
        # NOTE(review): pairs values and states positionally; assumes both
        # catchers report the new states in the same order — confirm.
        for st_val, st in zip(vmap_state_vals, vmap_states):
            st.restore_value(st_val)
            # ------------------------------------------------
            # --- this is CRUCIAL to avoid jax tracing leakage
            # ------------------------------------------------
            st.decrease_stack_level()
        return outs

    return vmapped_fn
977
+
978
+
979
def vmap_new_states(
    fun: Callable | Missing = Missing(),
    *,
    # -- normal jax.vmap arguments -- #
    in_axes: int | None | Sequence[Any] = 0,
    out_axes: Any = 0,
    axis_name: AxisName | None = None,
    axis_size: int | None = None,
    spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
    # -- brainstate specific arguments -- #
    state_tag: str | None = None,
    state_to_exclude: Filter | None = None,
):
    """
    Vectorize a function over the new states created within it.

    ``fun`` is run under ``vmap``. The states it creates are collected with
    ``catch_new_states``, and their batched values are written back into those
    states after the mapped call.

    It can be called directly, ``vmap_new_states(fun, ...)``, or used as a
    decorator factory, ``@vmap_new_states(in_axes=...)``. In the latter case
    (``fun`` omitted) a ``functools.partial`` awaiting ``fun`` is returned.

    Args:
        fun (Callable, optional): The function to be vectorized. When omitted,
            a decorator is returned.
        in_axes (int | None | Sequence[Any], optional): Specification of input
            axes for vectorization. Defaults to 0.
        out_axes (Any, optional): Specification of output axes after
            vectorization. Defaults to 0.
        axis_name (AxisName, optional): Name of the axis being vectorized over.
            Defaults to None.
        axis_size (int, optional): Size of the axis being vectorized over.
            Defaults to None.
        spmd_axis_name (AxisName | tuple[AxisName, ...], optional): Name(s) of
            SPMD axis/axes. Defaults to None.
        state_tag (str, optional): Forwarded as ``state_tag`` to
            ``catch_new_states`` inside the mapped call. Defaults to None.
        state_to_exclude (Filter, optional): A state filter forwarded as
            ``state_to_exclude`` to ``catch_new_states``. Presumably, states it
            matches are not collected and therefore not vectorized. Defaults
            to None.

    Returns:
        Callable: A vectorized version of ``fun``, or a decorator producing one
        when ``fun`` is omitted.
    """
    options = dict(
        in_axes=in_axes,
        out_axes=out_axes,
        axis_name=axis_name,
        axis_size=axis_size,
        spmd_axis_name=spmd_axis_name,
        state_tag=state_tag,
        state_to_exclude=state_to_exclude,
    )
    if isinstance(fun, Missing):
        # Decorator-factory usage: ``@vmap_new_states(...)``.
        return functools.partial(_vmap_new_states_transform, **options)
    return _vmap_new_states_transform(fun, **options)