brainstate-0.2.1-py2.py3-none-any.whl → brainstate-0.2.2-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. brainstate/__init__.py +167 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2297 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +2157 -1652
  8. brainstate/_state_test.py +1129 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1620 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1447 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +146 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +635 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +134 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +480 -477
  32. brainstate/nn/_dynamics.py +870 -1267
  33. brainstate/nn/_dynamics_test.py +53 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +391 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
  64. brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
  65. brainstate/random/_impl.py +672 -0
  66. brainstate/random/{_rand_seed.py → _seed.py} +675 -675
  67. brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
  68. brainstate/random/{_rand_state.py → _state.py} +1320 -1617
  69. brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
  70. brainstate/transform/__init__.py +56 -59
  71. brainstate/transform/_ad_checkpoint.py +176 -176
  72. brainstate/transform/_ad_checkpoint_test.py +49 -49
  73. brainstate/transform/_autograd.py +1025 -1025
  74. brainstate/transform/_autograd_test.py +1289 -1289
  75. brainstate/transform/_conditions.py +316 -316
  76. brainstate/transform/_conditions_test.py +220 -220
  77. brainstate/transform/_error_if.py +94 -94
  78. brainstate/transform/_error_if_test.py +52 -52
  79. brainstate/transform/_find_state.py +200 -0
  80. brainstate/transform/_find_state_test.py +84 -0
  81. brainstate/transform/_jit.py +399 -399
  82. brainstate/transform/_jit_test.py +143 -143
  83. brainstate/transform/_loop_collect_return.py +675 -675
  84. brainstate/transform/_loop_collect_return_test.py +58 -58
  85. brainstate/transform/_loop_no_collection.py +283 -283
  86. brainstate/transform/_loop_no_collection_test.py +50 -50
  87. brainstate/transform/_make_jaxpr.py +2176 -2016
  88. brainstate/transform/_make_jaxpr_test.py +1634 -1510
  89. brainstate/transform/_mapping.py +607 -529
  90. brainstate/transform/_mapping_test.py +104 -194
  91. brainstate/transform/_progress_bar.py +255 -255
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
  108. brainstate-0.2.2.dist-info/RECORD +111 -0
  109. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/transform/_eval_shape.py +0 -145
  111. brainstate/transform/_eval_shape_test.py +0 -38
  112. brainstate/transform/_random.py +0 -171
  113. brainstate-0.2.1.dist-info/RECORD +0 -111
  114. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  115. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
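The hunk reproduced below is from brainstate/transform/_make_jaxpr.py (entry 87 above); only the removed 0.2.1 lines are shown. That module defines `StatefulFunction` and `make_jaxpr`. For orientation before the raw diff, here is a minimal usage sketch assembled from the module's own docstring examples; it is illustrative only and has not been run against either release:

    import jax.numpy as jnp
    import brainstate

    state = brainstate.State(jnp.array([1.0, 2.0]))

    def f(x):
        state.value += x          # state write, tracked by StatefulFunction
        return state.value * 2

    sf = brainstate.transform.StatefulFunction(f)
    x = jnp.array([0.5, 0.5])
    sf.make_jaxpr(x)              # compile and cache the jaxpr for this argument signature
    out = sf(x)                   # execute with automatic state updates

    # make_jaxpr returns the ClosedJaxpr together with the states the function touches
    jaxpr, states = brainstate.transform.make_jaxpr(f)(x)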
@@ -1,2016 +1,2176 @@
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- """
- This module implements how to create a JAX Jaxpr from a given function by considering the states that are read and
- written by the function. These state transformations are foundational for the BrainCore library. These utilities
- include two basic functions: `StatefulFunction` and `make_jaxpr`.
-
-
- ``StatefulFunction``
- --------------------
-
- The module provides a class called ``StatefulFunction`` that wraps a function and provides methods to get the
- JAX Jaxpr, the output shapes, the states that are read and written by the function, and the output of the function.
- The class provides the following methods:
-
- - `make_jaxpr`: creates the JAX Jaxpr of the function.
- - `jaxpr_call`: calls the function at the JAX Jaxpr level.
- - `jaxpr_call_without_states`: calls the function at the JAX Jaxpr level without considering the states.
- - `get_states`: returns the states that are read and written by the function.
- - `get_read_states`: returns the states that are read by the function.
- - `get_write_states`: returns the states that are written by the function.
- - `get_static_args`: returns the static arguments from the arguments.
- - `compile_and_get_states_by_static_args`: compiles the function and returns the states that are read and
-   written by the function.
- - `get_jaxpr`: returns the JAX Jaxpr of the function.
- - `get_out_shapes`: returns the output shapes of the function.
- - `get_out_treedef`: returns the output tree of the function.
-
- ``make_jaxpr``
- --------------
-
- The module provides a function called `make_jaxpr` that creates a function that produces its JAX Jaxpr given example
- arguments. The function returns a wrapped version of the function that when applied to example arguments returns a
- `ClosedJaxpr` representation of the function on those arguments. If the argument `return_shape` is `True`, then the
- returned function instead returns a pair where the first element is the `ClosedJaxpr` representation of the function
- and the second element is a pytree representing the structure, shape, dtypes, and named shapes of the output of the
- function.
-
- """
-
- import functools
- import inspect
- import operator
- import threading
- from collections import OrderedDict, defaultdict
- from collections.abc import Hashable, Iterable, Sequence
- from collections.abc import MutableSet
- from contextlib import ExitStack
- from typing import Any, Callable, Dict, Optional, Tuple, Union
-
- import jax
- import jax.numpy as jnp
- from jax._src import source_info_util
- from jax._src.linear_util import annotate
- from jax._src.traceback_util import api_boundary
- from jax._src.util import memoize
- from jax.api_util import shaped_abstractify
- from jax.extend.linear_util import transformation_with_aux
- from jax.interpreters import partial_eval as pe
-
- from brainstate._compatible_import import (
-     ClosedJaxpr, extend_axis_env_nd, safe_map, safe_zip, unzip2, wraps, wrap_init,
-     Literal, Var, Jaxpr, make_iota, to_elt, BatchTracer, BatchTrace,
- )
- from brainstate._state import State, StateTraceStack
- from brainstate._utils import set_module_as
- from brainstate.random import RandomState
- from brainstate.typing import Filter, PyTree
- from brainstate.util import PrettyObject
- from brainstate.util.filter import to_predicate
-
- AxisName = Hashable
-
- __all__ = [
-     "StatefulFunction",
-     "make_jaxpr",
-     "StatefulMapping",
- ]
-
-
- class hashabledict(dict):
-     def __hash__(self):
-         return hash(tuple(sorted(self.items())))
-
-
- class _BoundedCache:
-     """
-     A thread-safe LRU cache with bounded size.
-
-     This cache stores a limited number of items and evicts the least recently used item
-     when the cache reaches its maximum size. All operations are thread-safe.
-
-     Parameters
-     ----------
-     maxsize : int, default 128
-         Maximum number of items to store in the cache.
-     """
-
-     def __init__(self, maxsize: int = 128):
-         self._cache = OrderedDict()
-         self._maxsize = maxsize
-         self._lock = threading.RLock()
-         self._hits = 0
-         self._misses = 0
-
-     def get(
-         self,
-         key: Any,
-         default: Any = None,
-         raise_on_miss: bool = False,
-         error_context: str = "item"
-     ) -> Any:
-         """
-         Get an item from the cache.
-
-         Parameters
-         ----------
-         key : Any
-             The cache key.
-         default : Any, optional
-             The default value to return if the key is not found.
-         raise_on_miss : bool, optional
-             If True, raise a detailed ValueError when the key is not found.
-         error_context : str, optional
-             Context description for the error message (e.g., "Function", "JAX expression").
-
-         Returns
-         -------
-         Any
-             The cached value or the default value.
-
-         Raises
-         ------
-         ValueError
-             If raise_on_miss is True and the key is not found.
-         """
-         with self._lock:
-             if key in self._cache:
-                 self._cache.move_to_end(key)
-                 self._hits += 1
-                 return self._cache[key]
-             self._misses += 1
-
-             if raise_on_miss:
-                 available_keys = list(self._cache.keys())
-                 error_msg = [
-                     f"{error_context} not compiled for the requested cache key.",
-                     f"",
-                     f"Requested key:",
-                     f" {key}",
-                     f"",
-                     f"Available {{len(available_keys)}} keys:",
-                 ]
-                 if available_keys:
-                     for i, k in enumerate(available_keys, 1):
-                         error_msg.append(f" [{i}] {k}")
-                 else:
-                     error_msg.append(" (none - not compiled yet)")
-                 error_msg.append("")
-                 error_msg.append("Call make_jaxpr() first with matching arguments.")
-                 raise ValueError("\n".join(error_msg))
-
-             return default
-
-     def set(self, key: Any, value: Any) -> None:
-         """
-         Set an item in the cache.
-
-         Parameters
-         ----------
-         key : Any
-             The cache key.
-         value : Any
-             The value to cache.
-
-         Raises
-         ------
-         ValueError
-             If the key already exists in the cache.
-         """
-         with self._lock:
-             if key in self._cache:
-                 raise ValueError(
-                     f"Cache key already exists: {key}. "
-                     f"Cannot overwrite existing cached value. "
-                     f"Clear the cache first if you need to recompile."
-                 )
-             if len(self._cache) >= self._maxsize:
-                 self._cache.popitem(last=False)
-             self._cache[key] = value
-
-     def pop(self, key: Any, default: Any = None) -> Any:
-         """
-         Remove and return an item from the cache.
-
-         Parameters
-         ----------
-         key : Any
-             The cache key to remove.
-         default : Any, optional
-             The default value to return if the key is not found.
-
-         Returns
-         -------
-         Any
-             The cached value or the default value if the key is not found.
-         """
-         with self._lock:
-             if key in self._cache:
-                 return self._cache.pop(key)
-             return default
-
-     def replace(self, key: Any, value: Any) -> None:
-         """
-         Replace an existing item in the cache.
-
-         Parameters
-         ----------
-         key : Any
-             The cache key to replace.
-         value : Any
-             The new value to cache.
-
-         Raises
-         ------
-         KeyError
-             If the key does not exist in the cache.
-         """
-         with self._lock:
-             if key not in self._cache:
-                 raise KeyError(
-                     f"Cache key does not exist: {key}. "
-                     f"Cannot replace non-existent cached value. "
-                     f"Use set() to add a new cache entry."
-                 )
-             self._cache[key] = value
-             self._cache.move_to_end(key)
-
-     def __contains__(self, key: Any) -> bool:
-         """
-         Check if a key exists in the cache.
-
-         Parameters
-         ----------
-         key : Any
-             The cache key to check.
-
-         Returns
-         -------
-         bool
-             True if the key exists in the cache, False otherwise.
-         """
-         with self._lock:
-             return key in self._cache
-
-     def __len__(self) -> int:
-         """
-         Get the number of items in the cache.
-
-         Returns
-         -------
-         int
-             The number of items currently in the cache.
-         """
-         with self._lock:
-             return len(self._cache)
-
-     def clear(self) -> None:
-         """
-         Clear all items from the cache and reset statistics.
-
-         This method removes all cached items and resets hit/miss counters to zero.
-         """
-         with self._lock:
-             self._cache.clear()
-             self._hits = 0
-             self._misses = 0
-
-     def keys(self):
-         """
-         Return all keys in the cache.
-
-         Returns
-         -------
-         list
-             A list of all keys currently in the cache.
-         """
-         with self._lock:
-             return list(self._cache.keys())
-
-     def get_stats(self) -> Dict[str, Any]:
-         """
-         Get cache statistics.
-
-         Returns
-         -------
-         dict
-             A dictionary with cache statistics including:
-
-             - 'size': Current number of items in cache
-             - 'maxsize': Maximum cache size
-             - 'hits': Number of cache hits
-             - 'misses': Number of cache misses
-             - 'hit_rate': Hit rate percentage (0-100)
-         """
-         with self._lock:
-             total = self._hits + self._misses
-             hit_rate = (self._hits / total * 100) if total > 0 else 0.0
-             return {
-                 'size': len(self._cache),
-                 'maxsize': self._maxsize,
-                 'hits': self._hits,
-                 'misses': self._misses,
-                 'hit_rate': hit_rate,
-             }
-
-
- def _ensure_str(x: str) -> str:
-     if not isinstance(x, str):
-         raise TypeError(f"argument is not a string: {x}")
-     return x
-
-
- def _ensure_index_tuple(x: Any) -> tuple[int, ...]:
-     """Convert x to a tuple of indices."""
-     x = jax.core.concrete_or_error(None, x, "expected a static index or sequence of indices.")
-     try:
-         return (operator.index(x),)
-     except TypeError:
-         return tuple(safe_map(operator.index, x))
-
-
- def _ensure_str_tuple(x: str | Iterable[str]) -> tuple[str, ...]:
-     """Convert x to a tuple of strings."""
-     if isinstance(x, str):
-         return (x,)
-     else:
-         return tuple(safe_map(_ensure_str, x))
-
-
- def _jax_v04_new_arg_fn(frame, trace, aval):
-     """
-     Transform a new argument to a tracer.
-
-     Modified from jax.interpreters.partial_eval.DynamicJaxprTrace.new_arg()
-
-     Args:
-         frame: The frame.
-         trace: The trace.
-         aval: The abstract value.
-
-     Returns:
-         The tracer.
-     """
-     tracer = pe.DynamicJaxprTracer(trace, aval, source_info_util.current())
-     frame.tracers.append(tracer)
-     frame.tracer_to_var[id(tracer)] = var = frame.newvar(aval)
-     frame.invars.append(var)
-     return tracer
-
-
- def _jax_v04_new_jax_trace():
-     main = jax.core.thread_local_state.trace_state.trace_stack.stack[-1]
-     frame = main.jaxpr_stack[-1]
-     trace = pe.DynamicJaxprTrace(main, jax.core.cur_sublevel())
-     return frame, trace
-
-
- class StatefulFunction(PrettyObject):
-     """
-     A wrapper class for functions that tracks state reads and writes during execution.
-
-     This class wraps a function to enable state management in JAX programs by tracking
-     which states are read from and written to during function execution. It provides
-     methods to compile the function into JAX's intermediate representation (jaxpr),
-     inspect state usage, and execute the function with proper state handling.
-
-     When you define a function:
-
-     .. code-block:: python
-
-         >>> state = brainstate.State(1.)
-         >>> def f(x):
-         ...     # Your function logic here
-         ...     y = x * 2 + state.value
-         ...     state.value = y
-
-     Calling ``sf = StatefulFunction(f)`` creates a stateful version of ``f``. You can
-     then call it directly with compatibility with JIT:
-
-     .. code-block:: python
-
-         >>> sf = brainstate.transform.StatefulFunction(f)
-         >>> out = sf(x)  # Automatically compiles and executes
-
-     Parameters
-     ----------
-     fun : callable
-         The function whose ``jaxpr`` is to be computed. Its positional
-         arguments and return value should be arrays, scalars, or standard Python
-         containers (tuple/list/dict) thereof.
-     static_argnums : int or iterable of int, optional
-         Indices of positional arguments to treat as static (known at compile time).
-         See :py:func:`jax.jit` for details. Default is ().
-     static_argnames : str or iterable of str, optional
-         Names of keyword arguments to treat as static (known at compile time).
-         See :py:func:`jax.jit` for details. Default is ().
-     axis_env : sequence of tuple, optional
-         A sequence of pairs where the first element is an axis name and the second
-         element is a positive integer representing the size of the mapped axis with
-         that name. This parameter is useful when lowering functions that involve
-         parallel communication collectives, and it specifies the axis name/size
-         environment that would be set up by applications of :py:func:`jax.pmap`.
-         Default is None.
-     abstracted_axes : pytree, optional
-         A pytree with the same structure as the input arguments to ``fun``. The
-         leaves of the pytree can be either None or a dict with axis names as keys
-         and integers as values. If the leaf is None, then the corresponding axis
-         is not abstracted. If the leaf is a dict, then the corresponding axis is
-         abstracted, and the dict specifies the axis name and size. The abstracted
-         axes are used to infer the input type of the function. If None, then all
-         axes are abstracted. Default is None.
-     name : str, optional
-         Name for the stateful function. Default is None.
-     return_only_write : bool, optional
-         If True, only return states that were written to during execution
-         (not just read). This can reduce memory usage when you only care
-         about modified states. Default is True.
-
-     Attributes
-     ----------
-     fun : callable
-         The wrapped function.
-     static_argnums : tuple of int
-         Indices of static positional arguments.
-     static_argnames : tuple of str
-         Names of static keyword arguments.
-     axis_env : sequence of tuple or None
-         Axis environment for parallel operations.
-     abstracted_axes : pytree or None
-         Abstract axes specification.
-     name : str or None
-         Name identifier for the function.
-     return_only_write : bool
-         Whether to return only written states.
-
-     Examples
-     --------
-     Basic usage with state management:
-
-     .. code-block:: python
-
-         >>> import brainstate
-         >>> import jax.numpy as jnp
-         >>>
-         >>> # Create a state
-         >>> state = brainstate.State(jnp.array([1.0, 2.0]))
-         >>>
-         >>> def f(x):
-         ...     state.value += x
-         ...     return state.value * 2
-         >>>
-         >>> # Create a stateful function
-         >>> sf = brainstate.transform.StatefulFunction(f)
-         >>>
-         >>> # Compile and get jaxpr
-         >>> x = jnp.array([0.5, 0.5])
-         >>> sf.make_jaxpr(x)
-         >>>
-         >>> # Get states that are read/written
-         >>> cache_key = sf.get_arg_cache_key(x)
-         >>> states = sf.get_states_by_cache(cache_key)
-         >>> read_states = sf.get_read_states_by_cache(cache_key)
-         >>> write_states = sf.get_write_states_by_cache(cache_key)
-
-     Using with static arguments:
-
-     .. code-block:: python
-
-         >>> def g(x, n):
-         ...     state.value = state.value ** n
-         ...     return state.value
-         >>>
-         >>> sf_static = brainstate.transform.StatefulFunction(
-         ...     g, static_argnums=(1,)
-         ... )
-         >>> sf_static.make_jaxpr(x, 2)
-
-     Automatic state management:
-
-     .. code-block:: python
-
-         >>> # Execute with automatic state handling
-         >>> result = sf.jaxpr_call_auto(x)
-         >>> print(state.value)  # State is automatically updated
-
-     See Also
-     --------
-     make_jaxpr : Function to create jaxpr from a function.
-     brainstate.State : The state container class.
-
-     Notes
-     -----
-     This class maintains internal thread-safe caches for compiled jaxprs, output
-     shapes, and state traces. The cache size is bounded at 128 entries per cache
-     type. Use ``clear_cache()`` to manually clear the caches if needed.
-
-     State objects should not be passed as direct inputs or outputs to the wrapped
-     function. Instead, they should be accessed within the function body, and the
-     class will automatically track their usage.
-     """
-     __module__ = "brainstate.transform"
-
-     def __init__(
-         self,
-         fun: Callable,
-         static_argnums: Union[int, Iterable[int]] = (),
-         static_argnames: Union[str, Iterable[str]] = (),
-         axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
-         abstracted_axes: Optional[Any] = None,
-         name: Optional[str] = None,
-         return_only_write: bool = True,
-     ):
-         # explicit parameters
-         self.fun = fun
-         self.static_argnums = tuple() if static_argnums is None else _ensure_index_tuple(static_argnums)
-         self.static_argnames = tuple() if static_argnames is None else _ensure_str_tuple(static_argnames)
-         self.axis_env = axis_env
-         self.abstracted_axes = abstracted_axes
-         self.name = name
-         self.return_only_write = return_only_write
-
-         # implicit parameters - thread-safe bounded caches
-         self._cached_jaxpr = _BoundedCache(maxsize=128)
-         self._cached_out_shapes = _BoundedCache(maxsize=128)
-         self._cached_jaxpr_out_tree = _BoundedCache(maxsize=128)
-         self._cached_state_trace = _BoundedCache(maxsize=128)
-         self._cache_lock = threading.RLock()
-
-     def __pretty_repr_item__(self, k, v):
-         if k.startswith('_'):
-             return None
-         return k, v
-
-     def get_jaxpr_by_cache(self, cache_key: Hashable) -> ClosedJaxpr:
-         """
-         Read the JAX Jaxpr representation of the function.
-
-         Parameters
-         ----------
-         cache_key : Hashable
-             The hashable cache key for retrieving the compiled jaxpr.
-
-         Returns
-         -------
-         ClosedJaxpr
-             The JAX Jaxpr representation of the function.
-
-         Raises
-         ------
-         ValueError
-             If the function has not been compiled for the given cache key.
-         """
-         return self._cached_jaxpr.get(cache_key, raise_on_miss=True, error_context="JAX expression")
-
-     def get_jaxpr(self, *args, compile_if_miss: bool = True, **kwargs) -> ClosedJaxpr:
-         """
-         Read the JAX Jaxpr representation of the function by calling with args.
-
-         Parameters
-         ----------
-         *args
-             The arguments to the function.
-         compile_if_miss : bool, optional
-             Whether to compile the function if the cache key is not found. Default is True.
-         **kwargs
-             The keyword arguments to the function.
-
-         Returns
-         -------
-         ClosedJaxpr
-             The JAX Jaxpr representation of the function.
-         """
-         cache_key = self.get_arg_cache_key(*args, **kwargs, compile_if_miss=compile_if_miss)
-         return self.get_jaxpr_by_cache(cache_key)
-
-     def get_out_shapes_by_cache(self, cache_key: Hashable) -> PyTree:
-         """
-         Read the output shapes of the function.
-
-         Parameters
-         ----------
-         cache_key : Hashable
-             The hashable cache key.
-
-         Returns
-         -------
-         PyTree
-             The output shapes of the function.
-
-         Raises
-         ------
-         ValueError
-             If the function has not been compiled for the given cache key.
-         """
-         return self._cached_out_shapes.get(cache_key, raise_on_miss=True, error_context="Output shapes")
-
-     def get_out_shapes(self, *args, compile_if_miss: bool = True, **kwargs) -> PyTree:
-         """
-         Read the output shapes of the function.
-
-         Parameters
-         ----------
-         *args
-             The arguments to the function.
-         compile_if_miss : bool, optional
-             Whether to compile the function if the cache key is not found. Default is True.
-         **kwargs
-             The keyword arguments to the function.
-
-         Returns
-         -------
-         PyTree
-             The output shapes of the function.
-         """
-         cache_key = self.get_arg_cache_key(*args, **kwargs, compile_if_miss=compile_if_miss)
-         return self.get_out_shapes_by_cache(cache_key)
-
-     def get_out_treedef_by_cache(self, cache_key: Hashable) -> PyTree:
-         """
-         Read the output tree definition of the function.
-
-         Parameters
-         ----------
-         cache_key : Hashable
-             The hashable cache key.
-
-         Returns
-         -------
-         PyTree
-             The output tree definition of the function.
-
-         Raises
-         ------
-         ValueError
-             If the function has not been compiled for the given cache key.
-         """
-         return self._cached_jaxpr_out_tree.get(cache_key, raise_on_miss=True, error_context="Output tree")
-
-     def get_out_treedef(self, *args, compile_if_miss: bool = True, **kwargs) -> PyTree:
-         """
-         Read the output tree of the function.
-
-         Parameters
-         ----------
-         *args
-             The arguments to the function.
-         compile_if_miss : bool, optional
-             Whether to compile the function if the cache key is not found. Default is True.
-         **kwargs
-             The keyword arguments to the function.
-
-         Returns
-         -------
-         PyTree
-             The output tree of the function.
-         """
-         cache_key = self.get_arg_cache_key(*args, **kwargs, compile_if_miss=compile_if_miss)
-         return self.get_out_treedef_by_cache(cache_key)
-
-     def get_state_trace_by_cache(self, cache_key: Hashable) -> StateTraceStack:
-         """
-         Read the state trace of the function.
-
-         Parameters
-         ----------
-         cache_key : Hashable
-             The hashable cache key.
-
-         Returns
-         -------
-         StateTraceStack
-             The state trace stack containing all tracked states.
-
-         Raises
-         ------
-         ValueError
-             If the function has not been compiled for the given cache key.
-         """
-         return self._cached_state_trace.get(cache_key, raise_on_miss=True, error_context="State trace")
-
-     def get_state_trace(self, *args, compile_if_miss: bool = True, **kwargs) -> StateTraceStack:
-         """
-         Read the state trace of the function.
-
-         Parameters
-         ----------
-         *args
-             The arguments to the function.
-         compile_if_miss : bool, optional
-             Whether to compile the function if the cache key is not found. Default is True.
-         **kwargs
-             The keyword arguments to the function.
-
-         Returns
-         -------
-         StateTraceStack
-             The state trace of the function.
-         """
-         cache_key = self.get_arg_cache_key(*args, **kwargs, compile_if_miss=compile_if_miss)
-         return self.get_state_trace_by_cache(cache_key)
-
-     def get_states_by_cache(self, cache_key: Hashable) -> Tuple[State, ...]:
-         """
-         Read the states that are accessed by the function.
-
-         Parameters
-         ----------
-         cache_key : Hashable
-             The hashable cache key.
-
-         Returns
-         -------
-         Tuple[State, ...]
-             The states that are read from or written to by the function.
-
-         Raises
-         ------
-         ValueError
-             If the function has not been compiled for the given cache key.
-         """
-         return tuple(self.get_state_trace_by_cache(cache_key).states)
-
-     def get_states(self, *args, compile_if_miss: bool = True, **kwargs) -> Tuple[State, ...]:
-         """
-         Compile the function, and get the states that are read and written by this function.
-
-         Parameters
-         ----------
-         *args
-             The arguments to the function.
-         compile_if_miss : bool, optional
-             Whether to compile the function if the cache key is not found. Default is True.
-         **kwargs
-             The keyword arguments to the function.
-
-         Returns
-         -------
-         Tuple[State, ...]
-             The states that are read and written by the function.
-         """
-         cache_key = self.get_arg_cache_key(*args, **kwargs, compile_if_miss=compile_if_miss)
-         return self.get_states_by_cache(cache_key)
-
-     def get_read_states_by_cache(self, cache_key: Hashable) -> Tuple[State, ...]:
-         """
-         Read the states that are read by the function.
-
-         Parameters
-         ----------
-         cache_key : Hashable
-             The hashable key.
-
-         Returns
-         -------
-         Tuple[State, ...]
-             The states that are read by the function.
-         """
-         return self.get_state_trace_by_cache(cache_key).get_read_states()
-
-     def get_read_states(self, *args, compile_if_miss: bool = True, **kwargs) -> Tuple[State, ...]:
-         """
-         Compile the function, and get the states that are read by this function.
-
-         Parameters
-         ----------
-         *args
-             The arguments to the function.
-         compile_if_miss : bool, optional
-             Whether to compile the function if the cache key is not found. Default is True.
-         **kwargs
-             The keyword arguments to the function.
-
-         Returns
-         -------
-         Tuple[State, ...]
-             The states that are read by the function.
-         """
-         cache_key = self.get_arg_cache_key(*args, **kwargs, compile_if_miss=compile_if_miss)
-         return self.get_read_states_by_cache(cache_key)
-
-     def get_write_states_by_cache(self, cache_key: Hashable) -> Tuple[State, ...]:
-         """
-         Read the states that are written by the function.
-
-         Parameters
-         ----------
-         cache_key : Hashable
-             The hashable cache key.
-
-         Returns
-         -------
-         Tuple[State, ...]
-             The states that are written by the function.
-         """
-         return self.get_state_trace_by_cache(cache_key).get_write_states()
-
-     def get_write_states(self, *args, compile_if_miss: bool = True, **kwargs) -> Tuple[State, ...]:
-         """
-         Compile the function, and get the states that are written by this function.
-
-         Parameters
-         ----------
-         *args
-             The arguments to the function.
-         compile_if_miss : bool, optional
-             Whether to compile the function if the cache key is not found. Default is True.
-         **kwargs
-             The keyword arguments to the function.
-
-         Returns
-         -------
-         Tuple[State, ...]
-             The states that are written by the function.
-         """
-         cache_key = self.get_arg_cache_key(*args, **kwargs, compile_if_miss=compile_if_miss)
-         return self.get_write_states_by_cache(cache_key)
-
-     def _check_input_ouput(self, x):
-         if isinstance(x, State):
-             x.raise_error_with_source_info(
-                 ValueError(
-                     'Inputs/outputs for brainstate transformations cannot be an instance of State. '
-                     f'But we got {x}'
-                 )
-             )
-
-     def get_arg_cache_key(self, *args, compile_if_miss: bool = False, **kwargs) -> hashabledict:
-         """
-         Compute the cache key for the given arguments.
-
-         This method separates static and dynamic arguments and creates a hashable
-         key that can be used to cache compiled jaxpr representations.
-
-         Parameters
-         ----------
-         *args
-             The positional arguments to the function.
-         compile_if_miss : bool, optional
-             Whether to compile the function if the cache key does not exist.
-             Default is False.
-         **kwargs
-             The keyword arguments to the function.
-
-         Returns
-         -------
-         hashabledict
-             A hashable dictionary containing the cache key with fields:
-             'static_args', 'dyn_args', 'static_kwargs', 'dyn_kwargs'.
-
-         Examples
-         --------
-         .. code-block:: python
-
-             >>> import brainstate
-             >>> import jax.numpy as jnp
-             >>>
-             >>> def f(x, n):
-             ...     return x ** n
-             >>>
-             >>> sf = brainstate.transform.StatefulFunction(
-             ...     f, static_argnums=(1,)
-             ... )
-             >>> cache_key = sf.get_arg_cache_key(jnp.array([1.0, 2.0]), 2)
-         """
-         static_args, dyn_args = [], []
-         for i, arg in enumerate(args):
-             if i in self.static_argnums:
-                 static_args.append(arg)
-             else:
-                 dyn_args.append(arg)
-         dyn_args = jax.tree.map(shaped_abstractify, dyn_args)
-         static_kwargs, dyn_kwargs = [], []
-         for k, v in sorted(kwargs.items()):
-             if k in self.static_argnames:
-                 static_kwargs.append((k, v))
-             else:
-                 dyn_kwargs.append((k, jax.tree.map(shaped_abstractify, v)))
-
-         static_args = make_hashable(tuple(static_args))
-         dyn_args = make_hashable(tuple(dyn_args))
-         static_kwargs = make_hashable(static_kwargs)
-         dyn_kwargs = make_hashable(dyn_kwargs)
-
-         cache_key = hashabledict(
-             static_args=static_args,
-             dyn_args=dyn_args,
-             static_kwargs=static_kwargs,
-             dyn_kwargs=dyn_kwargs,
-         )
-
-         if cache_key not in self._cached_state_trace and compile_if_miss:
-             self.make_jaxpr(*args, **kwargs)
-
-         return cache_key
-
-     def clear_cache(self) -> None:
-         """
-         Clear all compilation caches.
-
-         This method removes all cached jaxprs, output shapes, output trees,
-         and state traces. Use this when you need to recompile the function
-         or free memory.
-
-         Examples
-         --------
-         .. code-block:: python
-
-             >>> import brainstate
-             >>> import jax.numpy as jnp
-             >>>
-             >>> def f(x):
-             ...     return x * 2
-             >>>
-             >>> sf = brainstate.transform.StatefulFunction(f)
-             >>> sf.make_jaxpr(jnp.array([1.0, 2.0]))
-             >>> sf.clear_cache()  # Clear all cached compilations
-         """
-         self._cached_jaxpr.clear()
-         self._cached_out_shapes.clear()
-         self._cached_jaxpr_out_tree.clear()
-         self._cached_state_trace.clear()
-
-     def __jax_v04_new_arg(self):
-         # Should be within the calling of ``jax.make_jaxpr()``
-         frame, trace = _jax_v04_new_jax_trace()
-         # Set the function to transform the new argument to a tracer
-         fn = functools.partial(_jax_v04_new_arg_fn, frame, trace)
-         return fn
-
-     def __jax_new_version_new_arg(self):
-         trace = jax.core.trace_ctx.trace
-
-         def wrapper(x):
-             if jax.__version_info__ < (0, 6, 1):
-                 fn = lambda xx: trace.new_arg(shaped_abstractify(xx))
-             else:
-                 fn = lambda xx: trace.new_arg(shaped_abstractify(xx), source_info=source_info_util.current())
-             return jax.tree.map(fn, x._value)
-
-         return wrapper
-
-     def _wrapped_fun_to_eval(
-         self,
-         cache_key,
-         static_kwargs: dict,
-         *args,
-         **dyn_kwargs,
-     ) -> Tuple[Any, Tuple[State, ...]]:
-         """
-         Internal wrapper that executes the function and tracks state operations.
-
-         This method wraps the original function to track which states are read
-         and written during execution. It is used internally during jaxpr compilation.
-
-         Parameters
-         ----------
-         cache_key
-             The cache key for storing the state trace.
-         static_kwargs : dict
-             Static keyword arguments that were separated out.
-         *args
-             The positional arguments to the function.
-         **dyn_kwargs
-             Dynamic keyword arguments to the function.
-
-         Returns
-         -------
-         tuple
-             A tuple of (output, state_values) where output is the function result
-             and state_values are the tracked state values (either all or write-only
-             depending on return_only_write setting).
-         """
-         # state trace
-         state_trace: StateTraceStack = StateTraceStack(self.name)
-         if jax.__version_info__ < (0, 4, 36):
-             state_trace.set_new_arg(self.__jax_v04_new_arg())
-         else:
-             state_trace.set_new_arg(self.__jax_new_version_new_arg())
-         self._cached_state_trace.set(cache_key, state_trace)
-         with state_trace:
-             out = self.fun(*args, **dyn_kwargs, **static_kwargs)
-             state_values = (
-                 state_trace.get_write_state_values(True)
-                 if self.return_only_write else
-                 state_trace.get_state_values()
-             )
-         state_trace.recovery_original_values()
-
-         # State instance as functional returns is not allowed.
-         # Checking whether the states are returned.
-         jax.tree.map(self._check_input_ouput, out, is_leaf=lambda x: isinstance(x, State))
-         return out, state_values
-
-     def make_jaxpr(self, *args, **kwargs):
-         """
-         Create the JAX Jaxpr representation given example arguments.
-
-         This method compiles the function with the given arguments and caches
-         the resulting Jaxpr, output shapes, and state trace for later use.
-
-         Parameters
-         ----------
-         *args
-             The arguments to the function.
-         **kwargs
-             The keyword arguments to the function.
-
-         Returns
-         -------
-         StatefulFunction
-             Returns self for method chaining.
-
-         Raises
-         ------
-         TypeError
-             If State objects are passed as arguments or returned from the function.
-         """
-
-         # check input types
-         jax.tree.map(self._check_input_ouput, (args, kwargs), is_leaf=lambda x: isinstance(x, State))
-
-         # static args
-         cache_key = self.get_arg_cache_key(*args, **kwargs)
-
-         if cache_key not in self._cached_state_trace:
-             try:
-
-                 # jaxpr
-                 static_kwargs, dyn_kwargs = {}, {}
-                 for k, v in kwargs.items():
-                     if k in self.static_argnames:
-                         static_kwargs[k] = v
-                     else:
-                         dyn_kwargs[k] = v
-                 jaxpr, (out_shapes, state_shapes) = _make_jaxpr(
-                     functools.partial(
-                         self._wrapped_fun_to_eval,
-                         cache_key,
-                         static_kwargs,
-                     ),
-                     static_argnums=self.static_argnums,
-                     axis_env=self.axis_env,
-                     return_shape=True,
-                     abstracted_axes=self.abstracted_axes,
-                 )(*args, **dyn_kwargs)
-
-                 # returns
-                 self._cached_jaxpr_out_tree.set(cache_key, jax.tree.structure((out_shapes, state_shapes)))
-                 self._cached_out_shapes.set(cache_key, (out_shapes, state_shapes))
-                 self._cached_jaxpr.set(cache_key, jaxpr)
-
-             except Exception as e:
-                 # Clean up partial cache entries on error
-                 self._cached_state_trace.pop(cache_key, None)
-                 self._cached_out_shapes.pop(cache_key, None)
-                 self._cached_jaxpr.pop(cache_key, None)
-                 raise e
-
-         return self
-
-     def jaxpr_call(self, state_vals, *args, **kwargs) -> Any:
-         """
-         Call the function at the JAX Jaxpr level.
-
-         This method evaluates the compiled Jaxpr with the provided state values
-         and arguments, returning updated state values and function outputs.
-
-         Parameters
-         ----------
-         state_vals : Sequence
-             The current state values.
-         *args
-             The arguments to the function.
-         **kwargs
-             The keyword arguments to the function.
-
-         Returns
-         -------
-         tuple
-             A tuple of (new_state_vals, out) where new_state_vals are the
-             updated state values and out is the function output.
-
-         Raises
-         ------
-         ValueError
-             If the number of state values doesn't match the expected number.
-         """
-         # state checking
-         cache_key = self.get_arg_cache_key(*args, **kwargs)
-         states: Sequence[State] = self.get_states_by_cache(cache_key)
-         if len(state_vals) != len(states):
-             raise ValueError(f'State length mismatch: expected {len(states)} states, got {len(state_vals)}')
-
-         # parameters
-         kwargs = {k: v for k, v in kwargs.items() if k not in self.static_argnames}  # remove static kwargs
-         args = tuple(args[i] for i in range(len(args)) if i not in self.static_argnums)
-         args = jax.tree.flatten((args, kwargs, state_vals))[0]
-
-         # calling the function,
-         # note that this function always returns state values
-         # that both write and read by the function
-         closed_jaxpr = self.get_jaxpr_by_cache(cache_key)
-         out_treedef = self.get_out_treedef_by_cache(cache_key)
-         jaxpr_outs = jax.core.eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)
-
-         # output processing
-         out, new_state_vals = out_treedef.unflatten(jaxpr_outs)
-         if len(new_state_vals) != len(state_vals):
-             raise ValueError(f'State length mismatch in output: expected '
-                              f'{len(state_vals)} states, got {len(new_state_vals)}')
-         return new_state_vals, out
-
-     def get_cache_stats(self) -> Dict[str, Any]:
-         """
-         Get comprehensive cache statistics for all internal caches.
-
-         Returns
-         -------
-         dict
-             A dictionary with statistics for each cache including size, hits,
-             misses, and hit rates. Keys are 'jaxpr_cache', 'out_shapes_cache',
-             'jaxpr_out_tree_cache', and 'state_trace_cache'.
-         """
-         return {
-             'jaxpr_cache': self._cached_jaxpr.get_stats(),
-             'out_shapes_cache': self._cached_out_shapes.get_stats(),
-             'jaxpr_out_tree_cache': self._cached_jaxpr_out_tree.get_stats(),
-             'state_trace_cache': self._cached_state_trace.get_stats(),
-         }
-
-     def validate_states(self, cache_key: Hashable) -> bool:
-         """
-         Validate that all tracked states for a given cache key are still valid.
-
-         Parameters
-         ----------
-         cache_key : Hashable
-             The cache key to validate states for.
-
-         Returns
-         -------
-         bool
-             True if all states are valid.
-
-         Raises
-         ------
-         ValueError
-             If any states are invalid or missing required attributes.
-         """
-         state_trace = self.get_state_trace_by_cache(cache_key)
-         invalid_states = []
-         for i, state in enumerate(state_trace.states):
-             if not hasattr(state, 'value'):
-                 invalid_states.append((i, state))
-
-         if invalid_states:
-             raise ValueError(
-                 f"Found {len(invalid_states)} invalid states at indices: "
-                 f"{[idx for idx, _ in invalid_states]}. "
-                 f"States must have a 'value' attribute."
-             )
-         return True
-
-     def validate_all_states(self) -> Dict[Any, bool]:
-         """
-         Validate states for all cached compilations.
-
-         Returns
-         -------
-         dict
-             A dictionary mapping cache keys to validation results. Each value
-             is either True (valid) or an error message string (invalid).
-         """
-         results = {}
-         for cache_key in self._cached_state_trace.keys():
-             try:
-                 results[cache_key] = self.validate_states(cache_key)
-             except ValueError as e:
-                 results[cache_key] = str(e)
-         return results
-
-     def jaxpr_call_auto(self, *args, **kwargs) -> Any:
-         """
-         Execute the function at the jaxpr level with automatic state management.
-
-         This method automatically retrieves current state values, executes the
-         jaxpr-compiled function, and updates the states with the new values.
-         It provides a convenient interface that handles all state management
-         automatically.
-
-         Parameters
-         ----------
-         *args
-             The positional arguments to the function.
-         **kwargs
-             The keyword arguments to the function.
-
-         Returns
-         -------
-         Any
-             The output of the function.
-
-         Examples
-         --------
-         .. code-block:: python
-
-             >>> import brainstate
-             >>> import jax.numpy as jnp
-             >>>
-             >>> state = brainstate.State(jnp.array([1.0, 2.0]))
-             >>>
-             >>> def f(x):
-             ...     state.value += x
-             ...     return state.value * 2
-             >>>
-             >>> sf = brainstate.transform.StatefulFunction(f)
-             >>> x = jnp.array([0.5, 0.5])
-             >>> sf.make_jaxpr(x)
-             >>>
-             >>> # Automatic state management
-             >>> result = sf.jaxpr_call_auto(x)
-             >>> # or
-             >>> result = sf(x)
-             >>> print(state.value)  # State is automatically updated
-         """
-         state_trace = self.get_state_trace_by_cache(self.get_arg_cache_key(*args, **kwargs, compile_if_miss=True))
-         all_read_state_vals = state_trace.get_read_state_values(True)
-         state_vals, out = self.jaxpr_call(state_trace.get_state_values(), *args, **kwargs)
-         state_trace.assign_state_vals_v2(all_read_state_vals, state_vals)
-         return out
-
-     def __call__(self, *args, **kwargs):
-         return self.jaxpr_call_auto(*args, **kwargs)
-
-
- @set_module_as("brainstate.transform")
- def make_jaxpr(
-     fun: Callable,
-     static_argnums: Union[int, Iterable[int]] = (),
-     static_argnames: Union[str, Iterable[str]] = (),
-     axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
-     return_shape: bool = False,
-     abstracted_axes: Optional[Any] = None,
-     return_only_write: bool = False,
- ) -> Callable[
-     ...,
-     (Tuple[ClosedJaxpr, Tuple[State, ...]] |
-      Tuple[ClosedJaxpr, Tuple[State, ...], PyTree])
- ]:
-     """
-     Creates a function that produces its jaxpr given example args.
-
-     A ``jaxpr`` is JAX's intermediate representation for program traces. The
-     ``jaxpr`` language is based on the simply-typed first-order lambda calculus
-     with let-bindings. :py:func:`make_jaxpr` adapts a function to return its
-     ``jaxpr``, which we can inspect to understand what JAX is doing internally.
-     The ``jaxpr`` returned is a trace of ``fun`` abstracted to
-     :py:class:`ShapedArray` level. Other levels of abstraction exist internally.
-
-     Parameters
-     ----------
-     fun : callable
-         The function whose ``jaxpr`` is to be computed. Its positional
-         arguments and return value should be arrays, scalars, or standard Python
-         containers (tuple/list/dict) thereof.
-     static_argnums : int or iterable of int, optional
-         See the :py:func:`jax.jit` docstring.
-     static_argnames : str or iterable of str, optional
-         See the :py:func:`jax.jit` docstring.
-     axis_env : sequence of tuple, optional
-         A sequence of pairs where the first element is an axis
-         name and the second element is a positive integer representing the size of
-         the mapped axis with that name. This parameter is useful when lowering
-         functions that involve parallel communication collectives, and it
-         specifies the axis name/size environment that would be set up by
-         applications of :py:func:`jax.pmap`.
-     return_shape : bool, default False
-         If ``True``, the
-         wrapped function returns a pair where the first element is the XLA
-         computation and the second element is a pytree with the same structure as
-         the output of ``fun`` and where the leaves are objects with ``shape``,
-         ``dtype``, and ``named_shape`` attributes representing the corresponding
-         types of the output leaves.
-     abstracted_axes : pytree, optional
-         A pytree with the same structure as the input
-         arguments to ``fun``. The leaves of the pytree can be either None or a
-         dict with axis names as keys and integers as values. If the leaf is None,
-         then the corresponding axis is not abstracted. If the leaf is a dict, then
-         the corresponding axis is abstracted, and the dict specifies the axis name
-         and size. The abstracted axes are used to infer the input type of the
-         function. If None, then all axes are abstracted.
-     return_only_write : bool, default False
-         If True, only return states that were written to during execution
-         (not just read). This can reduce memory usage when you only care
-         about modified states.
-
-     Returns
-     -------
-     callable
-         A wrapped version of ``fun`` that when applied to example arguments returns
-         a ``ClosedJaxpr`` representation of ``fun`` on those arguments. If the
-         argument ``return_shape`` is ``True``, then the returned function instead
-         returns a pair where the first element is the ``ClosedJaxpr``
-         representation of ``fun`` and the second element is a pytree representing
-         the structure, shape, dtypes, and named shapes of the output of ``fun``.
-
-     Examples
-     --------
-     Basic usage:
-
-     .. code-block:: python
-
-         >>> import jax
-         >>> import brainstate
-         >>> import jax.numpy as jnp
-         >>>
-         >>> def f(x):
-         ...     return jnp.sin(jnp.cos(x))
-         >>>
-         >>> # Create jaxpr maker
-         >>> jaxpr_maker = brainstate.transform.make_jaxpr(f)
-         >>> jaxpr, states = jaxpr_maker(3.0)
-
-     With gradient:
-
-     .. code-block:: python
-
-         >>> jaxpr_grad_maker = brainstate.transform.make_jaxpr(jax.grad(f))
-         >>> jaxpr, states = jaxpr_grad_maker(3.0)
-
-     With shape information:
-
-     .. code-block:: python
-
-         >>> jaxpr_maker_with_shape = brainstate.transform.make_jaxpr(f, return_shape=True)
-         >>> jaxpr, states, shapes = jaxpr_maker_with_shape(3.0)
-
-     With stateful function:
-
-     .. code-block:: python
-
-         >>> state = brainstate.State(jnp.array([1.0, 2.0]))
-         >>>
-         >>> def stateful_f(x):
-         ...     state.value += x
-         ...     return state.value
-         >>>
-         >>> jaxpr_maker = brainstate.transform.make_jaxpr(stateful_f)
-         >>> jaxpr, states = jaxpr_maker(jnp.array([0.5, 0.5]))
-     """
-
-     stateful_fun = StatefulFunction(
-         fun,
-         static_argnums=static_argnums,
-         static_argnames=static_argnames,
-         axis_env=axis_env,
-         abstracted_axes=abstracted_axes,
-         return_only_write=return_only_write,
-         name='make_jaxpr'
-     )
-
-     @wraps(fun)
-     def make_jaxpr_f(*args, **kwargs):
-         stateful_fun.make_jaxpr(*args, **kwargs)
-         cache_key = stateful_fun.get_arg_cache_key(*args, **kwargs)
-         if return_shape:
-             return (
-                 stateful_fun.get_jaxpr_by_cache(cache_key),
-                 stateful_fun.get_states_by_cache(cache_key),
-                 stateful_fun.get_out_shapes_by_cache(cache_key)[0]
-             )
-         else:
-             return (
-                 stateful_fun.get_jaxpr_by_cache(cache_key),
-                 stateful_fun.get_states_by_cache(cache_key)
-             )
-
-     # wrapped jaxpr builder function
-     make_jaxpr_f.__module__ = "brainstate.transform"
-     if hasattr(fun, "__qualname__"):
-         make_jaxpr_f.__qualname__ = f"make_jaxpr({fun.__qualname__})"
-     if hasattr(fun, "__name__"):
-         make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
-     return make_jaxpr_f
-
-
- class StatefulMapping(StatefulFunction):
-     __module__ = "brainstate.transform"
-
-     def __init__(
-         self,
-         fun: Callable,
-         in_axes: Union[int, Tuple[int, ...], None] = 0,
-         out_axes: Union[int, Tuple[int, ...], None] = 0,
-         state_in_axes: Optional[Union[Dict[AxisName, Filter], Filter]] = None,
-         state_out_axes: Optional[Union[Dict[AxisName, Filter], Filter]] = None,
-         # jit specific parameters
-         static_argnums: Union[int, Iterable[int]] = (),
-         static_argnames: Union[str, Iterable[str]] = (),
-         axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
-         abstracted_axes: Optional[Any] = None,
-         # mapping specific parameters
-         axis_size: Optional[int] = None,
-         axis_name: AxisName | None = None,
-         name: Optional[str] = None,
-         # mapping function
-         mapping_fn: Callable = jax.vmap,
-     ):
-         self.origin_fun = fun
-         super().__init__(
-             fun=self._wrapped_fun,
-             static_argnums=static_argnums,
-             static_argnames=static_argnames,
-             axis_env=axis_env,
-             abstracted_axes=abstracted_axes,
-             name=name,
-             return_only_write=False,
-         )
-         self.in_axes = in_axes
-         self.out_axes = out_axes
-         if state_in_axes is None:
-             state_in_axes = dict()
-         elif not isinstance(state_in_axes, dict):
-             state_in_axes = {0: to_predicate(state_in_axes)}
-         state_in_axes = {k: to_predicate(v) for k, v in state_in_axes.items()}  # type: ignore
-         self.state_in_axes = state_in_axes
-
-         if state_out_axes is None:
-             state_out_axes = dict()
-         elif not isinstance(state_out_axes, dict):
-             state_out_axes = {0: to_predicate(state_out_axes)}
-         state_out_axes = {k: to_predicate(v) for k, v in state_out_axes.items()}  # type: ignore
-         self.state_out_axes = state_out_axes
-
-         self.axis_size = axis_size
-         self.axis_name = axis_name
-         self.mapping_fn = mapping_fn
-
-         # Cache for discovered state-to-axis mappings
-         self._cached_map_dim_to_in_states = _BoundedCache(maxsize=128)
-         self._cached_map_dim_to_out_states = _BoundedCache(maxsize=128)
-         self._cached_map_state_trace = _BoundedCache(maxsize=128)
-         self._cached_map_batch_size = _BoundedCache(maxsize=128)
-
-     def _infer_batch_size(self, args, in_axes):
-         if in_axes is None:
-             raise ValueError("Cannot infer batch size when in_axes is None")
-
-         batch_sizes = []
-
-         def get_batch_size_from_arg(arg_, axis_):
-             if axis_ is None:
-                 return None
-
-             def _get_size(arr):
-                 if not hasattr(arr, 'shape'):
-                     return None
-                 if arr.ndim == 0:
-                     return None
-                 ax = axis_ if axis_ >= 0 else arr.ndim + axis_
-                 if ax < 0 or ax >= arr.ndim:
-                     raise IndexError(f"Axis {ax} is out of bounds for array of shape {arr.shape}")
-                 return arr.shape[ax]
-
-             # Get all sizes from the pytree
-             sizes = [s for s in jax.tree.leaves(jax.tree.map(_get_size, arg_)) if s is not None]
-             return sizes[0] if sizes else None
-
-         if isinstance(in_axes, int):
-             # All args batched along the same axis
-             for arg in args:
-                 size = get_batch_size_from_arg(arg, in_axes)
-                 if size is not None:
-                     batch_sizes.append(size)
-         elif isinstance(in_axes, (tuple, list)):
-             # Different axes for different args
-             if len(in_axes) != len(args):
-                 raise ValueError(
-                     f"Length of in_axes ({len(in_axes)}) must match number of arguments ({len(args)})"
-                 )
-             for arg, axis in zip(args, in_axes):
-                 size = get_batch_size_from_arg(arg, axis)
-                 if size is not None:
-                     batch_sizes.append(size)
-         else:
-             raise TypeError(f"Unsupported in_axes type: {type(in_axes)}")
-
-         if not batch_sizes:
-             if self.axis_size is None:
-                 raise ValueError("Cannot infer batch size when axis_size is None")
-             batch_sizes.append(self.axis_size)
-
-         # Check all batch sizes are consistent
-         if not all(s == batch_sizes[0] for s in batch_sizes):
-             raise ValueError(
-                 f"Inconsistent batch sizes found: {batch_sizes}. "
-                 f"All batched arguments must have the same size along their batch axes."
-             )
-
-         return batch_sizes[0]
-
-     def __new_batch_arg(self, batch_size: int, dim_to_states: dict):
-         trace = jax.core.trace_ctx.trace
-         assert isinstance(trace, BatchTrace), f"Expected to be called within a BatchTrace context, but got {trace}"
-
-         def wrapper(x):
-             if isinstance(x, RandomState):
-                 idx = memoize(lambda: BatchTracer(trace, make_iota(batch_size), 0, source_info_util.current()))
-                 dim_to_states['random'].append(x)
-                 return to_elt(trace, idx, jnp.ones((batch_size,) + x._value.shape, x._value.dtype), 0)
-             for dim, filter_ in self.state_in_axes.items():
-                 idx = memoize(lambda: BatchTracer(trace, make_iota(batch_size), dim, source_info_util.current()))
-                 if filter_(tuple(), x):
-                     dim_to_states[dim].append(x)
-                     return jax.tree.map(lambda xx: to_elt(trace, idx, xx, dim), x._value)
-             return x._value
-
-         return wrapper
-
-     def __eval(self, cache_key, *args, **kwargs):
-         def fn_to_eval(*new_args, **new_kwargs):
-             dim_to_in_states = defaultdict(list)
-             state_trace = StateTraceStack(name=self.name)
-             state_trace.set_new_arg(
-                 self.__new_batch_arg(self._cached_map_batch_size.get(cache_key), dim_to_in_states)
-             )
-             self._cached_map_state_trace.set(cache_key, state_trace)
-
-             # call functions
-             with state_trace:
-                 out_ = self.origin_fun(*new_args, **new_kwargs)
-
-             # cache
-             self._cached_map_dim_to_in_states.set(cache_key, dim_to_in_states)
-
-             # vmapped state values
-             out_states = defaultdict(list)
-             out_states['random'] = [st for st in state_trace.states if isinstance(st, RandomState)]
-             for st in state_trace.states:
-                 if not isinstance(st, RandomState):
-                     leaves = jax.tree.leaves(st._value)
-                     batch_dims = set([leaf.batch_dim if isinstance(leaf, BatchTracer) else None for leaf in leaves])
-                     if len(batch_dims) != 1:
-                         raise ValueError(
-                             f"State {st} has inconsistent batch dimensions in its leaves: {batch_dims}. "
-                             "All leaves must have the same batch dimension."
-                         )
-                     batch_dim = batch_dims.pop()
-                     out_states[batch_dim].append(st)
-             self._cached_map_dim_to_out_states.set(cache_key, out_states)
-
-         try:
-             jax.vmap(
-                 fn_to_eval,
-                 in_axes=self.in_axes,
-                 out_axes=self.out_axes,
-                 axis_name=self.axis_name,
1582
- axis_size=self.axis_size
1583
- )(*args, **kwargs)
1584
- self._cached_map_state_trace.get(cache_key).recovery_original_values()
1585
- except Exception as e:
1586
- if cache_key in self._cached_map_state_trace:
1587
- self._cached_map_state_trace.get(cache_key).recovery_original_values()
1588
- self._cached_map_state_trace.pop(cache_key, None)
1589
- self._cached_map_dim_to_in_states.pop(cache_key, None)
1590
- self._cached_map_dim_to_out_states.pop(cache_key, None)
1591
- self._cached_map_batch_size.pop(cache_key, None)
1592
- raise RuntimeError(f"Failed to evaluate {self}") from e
1593
-
1594
- def __assign_vals_from_in_states(self, cache_key, rand_st, *other_st):
1595
- in_states = self._cached_map_dim_to_in_states.get(cache_key)
1596
- for st, val in zip(in_states['random'], rand_st):
1597
- assert isinstance(st, RandomState)
1598
- st.restore_value(val)
1599
- for group, group_vals in zip([in_states[dim] for dim in in_states.keys() if dim != 'random'], other_st):
1600
- for st, val in zip(group, group_vals):
1601
- st.restore_value(val)
1602
-
1603
- def __assign_vals_from_out_states(self, cache_key, rand_st, *other_st):
1604
- out_states = self._cached_map_dim_to_out_states.get(cache_key)
1605
- for st, val in zip(out_states['random'], rand_st):
1606
- assert isinstance(st, RandomState)
1607
- st.restore_value(val)
1608
- for group, group_vals in zip([out_states[dim] for dim in out_states.keys() if dim != 'random'], other_st):
1609
- for st, val in zip(group, group_vals):
1610
- st.restore_value(val)
1611
-
1612
- def __get_in_state_vals(self, cache_key: Hashable):
1613
- in_states = self._cached_map_dim_to_in_states.get(cache_key)
1614
- in_axes = []
1615
- in_values = []
1616
- for dim, states in in_states.items():
1617
- if dim == 'random':
1618
- continue
1619
- in_axes.append(dim)
1620
- in_values.append([st.value for st in states])
1621
- return tuple(in_axes), in_values
1622
-
1623
- def __get_out_state_vals(self, cache_key: Hashable):
1624
- out_states = self._cached_map_dim_to_out_states.get(cache_key)
1625
- out_axes = []
1626
- out_values = []
1627
- for dim, state in out_states.items():
1628
- if dim == 'random':
1629
- continue
1630
- out_axes.append(dim)
1631
- out_values.append([st.value for st in state])
1632
- return tuple(out_axes), out_values
1633
-
1634
- def __get_rand_state_vals(self, cache_key: Hashable):
1635
- in_states = self._cached_map_dim_to_in_states.get(cache_key)
1636
- batch_size = self._cached_map_batch_size.get(cache_key)
1637
- rand_vals, rand_recover_vals = [], []
1638
- for st in in_states['random']:
1639
- assert isinstance(st, RandomState)
1640
- rand_vals.append(st.split_key(batch_size))
1641
- rand_recover_vals.append(st.value)
1642
- return tuple(rand_vals), tuple(rand_recover_vals)
1643
-
1644
- def __recover_rand_state_vals(self, cache_key: Hashable, rand_recover_vals):
1645
- state_trace = self._cached_map_state_trace.get(cache_key)
1646
- rand_states = [st for st in state_trace.states if isinstance(st, RandomState)]
1647
- for st, val in zip(rand_states, rand_recover_vals):
1648
- st.restore_value(val)
1649
-
1650
- def _wrapped_fun(self, *args, **kwargs) -> Tuple[Any, Tuple[State, ...]]:
1651
- batch_size = self._infer_batch_size(args, self.in_axes)
1652
- cache_key = self.get_arg_cache_key(*args, **kwargs)
1653
- self._cached_map_batch_size.set(cache_key, batch_size)
1654
- if cache_key not in self._cached_map_state_trace:
1655
- self.__eval(cache_key, *args, **kwargs)
1656
-
1657
- def fn_to_map(origin_args, rand_st, *non_rand_st):
1658
- self.__assign_vals_from_in_states(cache_key, rand_st, *non_rand_st)
1659
- out = self.origin_fun(*origin_args[0], **origin_args[1])
1660
- return out, *self.__get_out_state_vals(cache_key)[1]
1661
-
1662
- in_axes, in_state_vals = self.__get_in_state_vals(cache_key)
1663
- out_axes, out_state_vals = self.__get_out_state_vals(cache_key)
1664
- rand_vals, rand_recover_vals = self.__get_rand_state_vals(cache_key)
1665
- mapped_fn = self.mapping_fn(
1666
- fn_to_map,
1667
- in_axes=(self.in_axes, 0) + in_axes,
1668
- out_axes=(self.out_axes,) + out_axes,
1669
- axis_size=self.axis_size,
1670
- axis_name=self.axis_name,
1671
- )
1672
- out_, *out_state_vals = mapped_fn((args, kwargs), rand_vals, *in_state_vals)
1673
- self.__assign_vals_from_out_states(cache_key, rand_recover_vals, *out_state_vals)
1674
- return out_
1675
-
1676
-
1677
- def _check_callable(fun):
1678
- # In Python 3.10+, the only thing stopping us from supporting static methods
1679
- # is that we can't take weak references to them, which the C++ JIT requires.
1680
- if isinstance(fun, staticmethod):
1681
- raise TypeError(f"staticmethod arguments are not supported, got {fun}")
1682
- if not callable(fun):
1683
- raise TypeError(f"Expected a callable value, got {fun}")
1684
- if inspect.isgeneratorfunction(fun):
1685
- raise TypeError(f"Expected a function, got a generator function: {fun}")
1686
-
1687
-
1688
- def _broadcast_prefix(
1689
- prefix_tree: Any,
1690
- full_tree: Any,
1691
- is_leaf: Callable[[Any], bool] | None = None
1692
- ) -> list[Any]:
1693
- # If prefix_tree is not a tree prefix of full_tree, this code can raise a
1694
- # ValueError; use prefix_errors to find disagreements and raise more precise
1695
- # error messages.
1696
- result = []
1697
- num_leaves = lambda t: jax.tree.structure(t).num_leaves
1698
- add_leaves = lambda x, subtree: result.extend([x] * num_leaves(subtree))
1699
- jax.tree.map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf)
1700
- return result
1701
-
1702
-
1703
- def _flat_axes_specs(
1704
- abstracted_axes, *args, **kwargs
1705
- ) -> list[pe.AbstractedAxesSpec]:
1706
- if kwargs:
1707
- raise NotImplementedError
1708
-
1709
- def ax_leaf(l):
1710
- return (isinstance(l, dict) and jax.tree_util.all_leaves(l.values()) or
1711
- isinstance(l, tuple) and jax.tree_util.all_leaves(l, lambda x: x is None))
1712
-
1713
- return _broadcast_prefix(abstracted_axes, args, ax_leaf)
1714
-
1715
-
1716
- @transformation_with_aux
1717
- def _flatten_fun(in_tree, *args_flat):
1718
- py_args, py_kwargs = jax.tree.unflatten(in_tree, args_flat)
1719
- ans = yield py_args, py_kwargs
1720
- yield jax.tree.flatten(ans)
1721
-
1722
-
1723
- def _make_jaxpr(
1724
- fun: Callable,
1725
- static_argnums: int | Iterable[int] = (),
1726
- axis_env: Sequence[tuple[AxisName, int]] | None = None,
1727
- return_shape: bool = False,
1728
- abstracted_axes: Any | None = None,
1729
- ) -> Callable[..., (ClosedJaxpr | tuple[ClosedJaxpr, Any])]:
1730
- """
1731
- Create a function that produces its jaxpr given example args (internal implementation).
1732
-
1733
- This is an internal implementation function. Users should use the public
1734
- ``make_jaxpr`` function instead.
1735
-
1736
- Parameters
1737
- ----------
1738
- fun : Callable
1739
- The function whose ``jaxpr`` is to be computed. Its positional
1740
- arguments and return value should be arrays, scalars, or standard Python
1741
- containers (tuple/list/dict) thereof.
1742
- static_argnums : int or iterable of int, optional
1743
- See the :py:func:`jax.jit` docstring.
1744
- axis_env : sequence of tuple, optional
1745
- A sequence of pairs where the first element is an axis
1746
- name and the second element is a positive integer representing the size of
1747
- the mapped axis with that name. This parameter is useful when lowering
1748
- functions that involve parallel communication collectives, and it
1749
- specifies the axis name/size environment that would be set up by
1750
- applications of :py:func:`jax.pmap`.
1751
- return_shape : bool, default False
1752
- If ``True``, the wrapped function returns a pair where the first element
1753
- is the ``ClosedJaxpr`` representation of ``fun`` and the second element
1754
- is a pytree with the same structure as the output of ``fun`` and where
1755
- the leaves are objects with ``shape``, ``dtype``, and ``named_shape``
1756
- attributes representing the corresponding types of the output leaves.
1757
- abstracted_axes : Any, optional
1758
- Axes specifications for abstract interpretation.
1759
-
1760
- Returns
1761
- -------
1762
- Callable
1763
- A wrapped version of ``fun`` that when applied to example arguments returns
1764
- a ``ClosedJaxpr`` representation of ``fun`` on those arguments. If the
1765
- argument ``return_shape`` is ``True``, then the returned function instead
1766
- returns a pair where the first element is the ``ClosedJaxpr``
1767
- representation of ``fun`` and the second element is a pytree representing
1768
- the structure, shape, dtypes, and named shapes of the output of ``fun``.
1769
-
1770
- Notes
1771
- -----
1772
- A ``jaxpr`` is JAX's intermediate representation for program traces. The
1773
- ``jaxpr`` language is based on the simply-typed first-order lambda calculus
1774
- with let-bindings. This function adapts a function to return its
1775
- ``jaxpr``, which we can inspect to understand what JAX is doing internally.
1776
- The ``jaxpr`` returned is a trace of ``fun`` abstracted to
1777
- :py:class:`ShapedArray` level. Other levels of abstraction exist internally.
1778
-
1779
- Examples
1780
- --------
1781
- .. code-block:: python
1782
-
1783
- >>> import jax
1784
- >>>
1785
- >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
1786
- >>> print(f(3.0))
1787
- -0.83602
1788
- >>> _make_jaxpr(f)(3.0)
1789
- { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
1790
- >>> _make_jaxpr(jax.grad(f))(3.0)
1791
- { lambda ; a:f32[]. let
1792
- b:f32[] = cos a
1793
- c:f32[] = sin a
1794
- _:f32[] = sin b
1795
- d:f32[] = cos b
1796
- e:f32[] = mul 1.0 d
1797
- f:f32[] = neg e
1798
- g:f32[] = mul f c
1799
- in (g,) }
1800
- """
1801
- _check_callable(fun)
1802
- static_argnums = _ensure_index_tuple(static_argnums)
1803
-
1804
- def _abstractify(args, kwargs):
1805
- flat_args, in_tree = jax.tree.flatten((args, kwargs))
1806
- if abstracted_axes is None:
1807
- return map(shaped_abstractify, flat_args), in_tree, [True] * len(flat_args)
1808
- else:
1809
- axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
1810
- in_type = pe.infer_lambda_input_type(axes_specs, flat_args)
1811
- in_avals, keep_inputs = unzip2(in_type)
1812
- return in_avals, in_tree, keep_inputs
1813
-
1814
- @wraps(fun)
1815
- @api_boundary
1816
- def make_jaxpr_f(*args, **kwargs):
1817
- f = wrap_init(fun, (), {}, 'brainstate.transform.make_jaxpr')
1818
- if static_argnums:
1819
- dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
1820
- f, args = jax.api_util.argnums_partial(f, dyn_argnums, args)
1821
- in_avals, in_tree, keep_inputs = _abstractify(args, kwargs)
1822
- in_type = tuple(safe_zip(in_avals, keep_inputs))
1823
- f, out_tree = _flatten_fun(f, in_tree)
1824
- f = annotate(f, in_type)
1825
- if jax.__version_info__ < (0, 5, 0):
1826
- debug_info_ = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
1827
- with ExitStack() as stack:
1828
- if axis_env is not None:
1829
- stack.enter_context(extend_axis_env_nd(axis_env))
1830
- if jax.__version_info__ < (0, 5, 0):
1831
- jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, debug_info=debug_info_)
1832
- else:
1833
- jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f)
1834
- closed_jaxpr = ClosedJaxpr(jaxpr, consts)
1835
- if return_shape:
1836
- out_avals, _ = unzip2(out_type)
1837
- out_shapes_flat = [jax.ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
1838
- return closed_jaxpr, jax.tree.unflatten(out_tree(), out_shapes_flat)
1839
- return closed_jaxpr
1840
-
1841
- make_jaxpr_f.__module__ = "brainstate.transform"
1842
- if hasattr(fun, "__qualname__"):
1843
- make_jaxpr_f.__qualname__ = f"make_jaxpr({fun.__qualname__})"
1844
- if hasattr(fun, "__name__"):
1845
- make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
1846
- return make_jaxpr_f
1847
-
1848
-
1849
- def make_hashable(obj):
1850
- """
1851
- Convert a pytree into a hashable representation.
1852
-
1853
- Parameters
1854
- ----------
1855
- obj : Any
1856
- A pytree object (list, tuple, dict, set, or JAX pytree structure).
1857
-
1858
- Returns
1859
- -------
1860
- Hashable
1861
- A hashable representation of the input object. Lists become tuples,
1862
- dicts become sorted tuples of key-value pairs, sets become frozensets,
1863
- and other pytrees are flattened using JAX's tree utilities.
1864
- """
1865
- if isinstance(obj, (list, tuple)):
1866
- return tuple(make_hashable(item) for item in obj)
1867
- elif isinstance(obj, dict):
1868
- return tuple(sorted((k, make_hashable(v)) for k, v in obj.items()))
1869
- elif isinstance(obj, set):
1870
- return frozenset(make_hashable(item) for item in obj)
1871
- else:
1872
- # return obj
1873
- # Use JAX's tree_util for any other pytree structures
1874
- try:
1875
- leaves, treedef = jax.tree.flatten(obj)
1876
- return treedef, tuple(leaves)
1877
- except (TypeError, ValueError):
1878
- # Assume obj is already hashable
1879
- return obj
1880
-
1881
-
1882
- class IdentitySet(MutableSet):
1883
- """Set that compares objects by identity.
1884
-
1885
- This is a set that compares objects by identity instead of equality. It is
1886
- useful for storing objects that are not hashable or that should be compared
1887
- by identity.
1888
-
1889
- This is a mutable set, but it does not support the ``__hash__`` method and
1890
- therefore cannot be used as a dictionary key or as an element of another set.
1891
- """
1892
-
1893
- def __init__(self, iterable=None):
1894
- self._data = {}
1895
- if iterable is not None:
1896
- self.update(iterable)
1897
-
1898
- def __contains__(self, value):
1899
- return id(value) in self._data
1900
-
1901
- def __iter__(self):
1902
- return iter(self._data.values())
1903
-
1904
- def __len__(self):
1905
- return len(self._data)
1906
-
1907
- def add(self, value):
1908
- self._data[id(value)] = value
1909
-
1910
- def discard(self, value):
1911
- self._data.pop(id(value), None)
1912
-
1913
- def __repr__(self):
1914
- return f"IdentitySet({list(repr(x) for x in self._data.values())})"
1915
-
1916
- def __str__(self):
1917
- return f"IdentitySet({list(str(x) for x in self._data.values())})"
1918
-
1919
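(A minimal sketch of the identity semantics implemented by the class above, which this diff removes: two equal but distinct objects are kept as separate members, and unhashable objects can be stored at all, unlike with a built-in ``set``.)

.. code-block:: python

    a, b = [1, 2], [1, 2]          # equal values, distinct objects
    s = IdentitySet()
    s.add(a)
    s.add(b)
    assert a in s and b in s       # membership is by object identity
    assert len(s) == 2             # a built-in set could not even hold lists
    s.discard(a)
    assert a not in s and b in s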
-
1920
- def constant_fold_jaxpr(jaxpr: Jaxpr):
1921
- """
1922
- Given a jaxpr, return a new jaxpr with all constant folding done.
1923
- """
1924
- return _partial_eval_jaxpr(jaxpr, {})
1925
-
1926
-
1927
- def _partial_eval_jaxpr(jaxpr, env):
1928
- env = env.copy()
1929
- new_eqns = []
1930
-
1931
- def read(var):
1932
- if isinstance(var, Literal):
1933
- return var.val
1934
- else:
1935
- return env.get(var, None)
1936
-
1937
- def read_or_self(var):
1938
- out = read(var)
1939
- if out is None:
1940
- return var
1941
- elif isinstance(out, Var):
1942
- return out
1943
- elif isinstance(out, Literal):
1944
- return Literal(out.val, var.aval)
1945
- else:
1946
- assert not isinstance(out, Jaxpr)
1947
- return Literal(out, var.aval)
1948
-
1949
- for eqn in jaxpr.eqns:
1950
- vals = [read(var) for var in eqn.invars]
1951
- if eqn.primitive.name in _constant_fold_blacklist:
1952
- new_eqns.append(eqn)
1953
- elif all(val is not None for val in vals):
1954
- # go ahead and eval it
1955
- out = _eval_eqn(eqn, vals)
1956
-
1957
- # two options: either it's a jaxpr result (partial eval) or it's a value or a list of values
1958
- if isinstance(out, Jaxpr):
1959
- # we need to inline this
1960
- new_eqns.extend(out.eqns)
1961
- out = out.outvars
1962
- elif not isinstance(out, tuple) and not isinstance(out, list):
1963
- out = (out,)
1964
-
1965
- for var, val in zip(eqn.outvars, out):
1966
- assert not isinstance(val, Jaxpr)
1967
- if isinstance(val, Literal):
1968
- env[var] = val.val
1969
- else:
1970
- env[var] = val
1971
- else:
1972
- new_eqns.append(eqn)
1973
-
1974
- # now that we've eval everything, inline all the constants
1975
- out_eqns = []
1976
- for eqn in new_eqns:
1977
- eqn = eqn.replace(invars=tuple(read_or_self(var) for var in eqn.invars))
1978
- out_eqns.append(eqn)
1979
-
1980
- invars_still_used = IdentitySet()
1981
- for eqn in out_eqns:
1982
- for var in eqn.invars:
1983
- invars_still_used.add(var)
1984
-
1985
- invars = tuple(var for var in jaxpr.invars if var in invars_still_used)
1986
-
1987
- # sub in any constants for outvars
1988
- outvars = tuple(read_or_self(var) for var in jaxpr.outvars)
1989
-
1990
- return jaxpr.replace(eqns=out_eqns, outvars=outvars, invars=invars, debug_info=None)
1991
-
1992
-
1993
- def _eval_eqn(eqn, vals) -> Union[Jaxpr, tuple, list, jax.Array]:
1994
- if eqn.primitive.name == "closed_call":
1995
- assert eqn.primitive.call_primitive
1996
- assert not eqn.primitive.map_primitive
1997
-
1998
- out = _partial_eval_jaxpr(
1999
- eqn.params['call_jaxpr'].jaxpr,
2000
- {
2001
- var: val
2002
- for var, val in
2003
- zip(eqn.params['call_jaxpr'].jaxpr.invars, vals)
2004
- }
2005
- )
2006
- elif eqn.primitive.name == "scan":
2007
- out = eqn.primitive.bind(*vals, **eqn.params)
2008
- else:
2009
- out = eqn.primitive.bind(*vals, **eqn.params)
2010
- return out
2011
-
2012
-
2013
- _constant_fold_blacklist = {
2014
- 'broadcast_in_dim',
2015
- 'broadcast',
2016
- }
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """
17
+ This module implements how to create a JAX Jaxpr from a given function while tracking the states that the
18
+ function reads and writes. These state transformations are foundational for the ``brainstate`` library. The module
19
+ provides two basic utilities: the ``StatefulFunction`` class and the ``make_jaxpr`` function.
20
+
21
+
22
+ ``StatefulFunction``
23
+ --------------------
24
+
25
+ The module provides a class called ``StatefulFunction`` that wraps a function and exposes methods to get the
26
+ JAX Jaxpr, the output shapes, the states that are read and written, and the output of the function.
27
+ The class provides the following methods:
28
+
29
+ - `make_jaxpr`: creates the JAX Jaxpr of the function.
30
+ - `jaxpr_call`: calls the function at the JAX Jaxpr level.
31
+ - `jaxpr_call_without_states`: calls the function at the JAX Jaxpr level without considering the states.
32
+ - `get_states`: returns the states that are read and written by the function.
33
+ - `get_read_states`: returns the states that are read by the function.
34
+ - `get_write_states`: returns the states that are written by the function.
35
+ - `get_static_args`: returns the static arguments from the arguments.
36
+ - `compile_and_get_states_by_static_args`: compiles the function and returns the states that are read and
37
+ written by the function.
38
+ - `get_jaxpr`: returns the JAX Jaxpr of the function.
39
+ - `get_out_shapes`: returns the output shapes of the function.
40
+ - `get_out_treedef`: returns the output tree of the function.
41
+
42
+ ``make_jaxpr``
43
+ --------------
44
+
45
+ The module provides a function called `make_jaxpr` that creates a function that produces its JAX Jaxpr given example
46
+ arguments. The function returns a wrapped version of the function that when applied to example arguments returns a
47
+ `ClosedJaxpr` representation of the function on those arguments. If the argument `return_shape` is `True`, then the
48
+ returned function instead returns a pair where the first element is the `ClosedJaxpr` representation of the function
49
+ and the second element is a pytree representing the structure, shape, dtypes, and named shapes of the output of the
50
+ function.
51
+
52
+ """
53
+
54
+ import functools
55
+ import inspect
56
+ import operator
57
+ import threading
58
+ import warnings
59
+ from collections import OrderedDict, defaultdict
60
+ from collections.abc import Hashable, Iterable, Sequence
61
+ from collections.abc import MutableSet
62
+ from contextlib import ExitStack
63
+ from typing import Any, Callable, Dict, Optional, Tuple, Union
64
+
65
+ import jax
66
+ import jax.numpy as jnp
67
+ from jax._src import source_info_util
68
+ from jax._src.linear_util import annotate
69
+ from jax._src.traceback_util import api_boundary
70
+ from jax.api_util import shaped_abstractify
71
+ from jax.extend.linear_util import transformation_with_aux
72
+ from jax.interpreters import partial_eval as pe
73
+
74
+ from brainstate._compatible_import import (
75
+ ClosedJaxpr, extend_axis_env_nd, safe_map, safe_zip, unzip2, wraps, wrap_init,
76
+ Literal, Var, Jaxpr, make_iota, to_elt, BatchTracer, BatchTrace,
77
+ )
78
+ from brainstate._error import BatchAxisError
79
+ from brainstate._state import State, StateTraceStack
80
+ from brainstate._utils import set_module_as
81
+ from brainstate.random import RandomState
82
+ from brainstate.typing import Filter, PyTree
83
+ from brainstate.util import PrettyObject
84
+ from brainstate.util.filter import to_predicate
85
+
86
+ AxisName = Hashable
87
+
88
+ __all__ = [
89
+ "StatefulFunction",
90
+ "make_jaxpr",
91
+ "StatefulMapping",
92
+ ]
93
+
94
+
95
+ class hashabledict(dict):
96
+ def __hash__(self):
97
+ return hash(tuple(sorted(self.items())))
98
+
99
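(Why the subclass exists, in a minimal sketch: the argument cache keys built below are dictionaries, and a plain ``dict`` is unhashable, so it could not key the internal caches. The ``'f32[2]'`` value here is only an illustrative stand-in for an abstract array type.)

.. code-block:: python

    key = hashabledict(static_args=(), dyn_args=('f32[2]',))
    cache = {key: 'compiled'}      # a plain dict key would raise TypeError here
    # equality and hashing follow the sorted (key, value) pairs
    assert hashabledict(dyn_args=('f32[2]',), static_args=()) in cache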
+
100
+ class _BoundedCache:
101
+ """
102
+ A thread-safe LRU cache with bounded size.
103
+
104
+ This cache stores a limited number of items and evicts the least recently used item
105
+ when the cache reaches its maximum size. All operations are thread-safe.
106
+
107
+ Parameters
108
+ ----------
109
+ maxsize : int, default 128
110
+ Maximum number of items to store in the cache.
111
+ """
112
+
113
+ def __init__(self, maxsize: int = 128):
114
+ self._cache = OrderedDict()
115
+ self._maxsize = maxsize
116
+ self._lock = threading.RLock()
117
+ self._hits = 0
118
+ self._misses = 0
119
+
120
+ def get(
121
+ self,
122
+ key: Any,
123
+ default: Any = None,
124
+ raise_on_miss: bool = False,
125
+ error_context: str = "item"
126
+ ) -> Any:
127
+ """
128
+ Get an item from the cache.
129
+
130
+ Parameters
131
+ ----------
132
+ key : Any
133
+ The cache key.
134
+ default : Any, optional
135
+ The default value to return if the key is not found.
136
+ raise_on_miss : bool, optional
137
+ If True, raise a detailed ValueError when the key is not found.
138
+ error_context : str, optional
139
+ Context description for the error message (e.g., "Function", "JAX expression").
140
+
141
+ Returns
142
+ -------
143
+ Any
144
+ The cached value or the default value.
145
+
146
+ Raises
147
+ ------
148
+ ValueError
149
+ If raise_on_miss is True and the key is not found.
150
+ """
151
+ with self._lock:
152
+ if key in self._cache:
153
+ self._cache.move_to_end(key)
154
+ self._hits += 1
155
+ return self._cache[key]
156
+ self._misses += 1
157
+
158
+ if raise_on_miss:
159
+ available_keys = list(self._cache.keys())
160
+ error_msg = [
161
+ f"{error_context} not compiled for the requested cache key.",
162
+ f"",
163
+ f"Requested key:",
164
+ f" {key}",
165
+ f"",
166
+ f"Available {{len(available_keys)}} keys:",
167
+ ]
168
+ if available_keys:
169
+ for i, k in enumerate(available_keys, 1):
170
+ error_msg.append(f" [{i}] {k}")
171
+ else:
172
+ error_msg.append(" (none - not compiled yet)")
173
+ error_msg.append("")
174
+ error_msg.append("Call make_jaxpr() first with matching arguments.")
175
+ raise ValueError("\n".join(error_msg))
176
+
177
+ return default
178
+
179
+ def set(self, key: Any, value: Any) -> None:
180
+ """
181
+ Set an item in the cache.
182
+
183
+ Parameters
184
+ ----------
185
+ key : Any
186
+ The cache key.
187
+ value : Any
188
+ The value to cache.
189
+
190
+ Raises
191
+ ------
192
+ ValueError
193
+ If the key already exists in the cache.
194
+ """
195
+ with self._lock:
196
+ if key in self._cache:
197
+ raise ValueError(
198
+ f"Cache key already exists: {key}. "
199
+ f"Cannot overwrite existing cached value. "
200
+ f"Clear the cache first if you need to recompile."
201
+ )
202
+ if len(self._cache) >= self._maxsize:
203
+ self._cache.popitem(last=False)
204
+ self._cache[key] = value
205
+
206
+ def pop(self, key: Any, default: Any = None) -> Any:
207
+ """
208
+ Remove and return an item from the cache.
209
+
210
+ Parameters
211
+ ----------
212
+ key : Any
213
+ The cache key to remove.
214
+ default : Any, optional
215
+ The default value to return if the key is not found.
216
+
217
+ Returns
218
+ -------
219
+ Any
220
+ The cached value or the default value if the key is not found.
221
+ """
222
+ with self._lock:
223
+ if key in self._cache:
224
+ return self._cache.pop(key)
225
+ return default
226
+
227
+ def replace(self, key: Any, value: Any) -> None:
228
+ """
229
+ Replace an existing item in the cache.
230
+
231
+ Parameters
232
+ ----------
233
+ key : Any
234
+ The cache key to replace.
235
+ value : Any
236
+ The new value to cache.
237
+
238
+ Raises
239
+ ------
240
+ KeyError
241
+ If the key does not exist in the cache.
242
+ """
243
+ with self._lock:
244
+ if key not in self._cache:
245
+ raise KeyError(
246
+ f"Cache key does not exist: {key}. "
247
+ f"Cannot replace non-existent cached value. "
248
+ f"Use set() to add a new cache entry."
249
+ )
250
+ self._cache[key] = value
251
+ self._cache.move_to_end(key)
252
+
253
+ def __contains__(self, key: Any) -> bool:
254
+ """
255
+ Check if a key exists in the cache.
256
+
257
+ Parameters
258
+ ----------
259
+ key : Any
260
+ The cache key to check.
261
+
262
+ Returns
263
+ -------
264
+ bool
265
+ True if the key exists in the cache, False otherwise.
266
+ """
267
+ with self._lock:
268
+ return key in self._cache
269
+
270
+ def __len__(self) -> int:
271
+ """
272
+ Get the number of items in the cache.
273
+
274
+ Returns
275
+ -------
276
+ int
277
+ The number of items currently in the cache.
278
+ """
279
+ with self._lock:
280
+ return len(self._cache)
281
+
282
+ def clear(self) -> None:
283
+ """
284
+ Clear all items from the cache and reset statistics.
285
+
286
+ This method removes all cached items and resets hit/miss counters to zero.
287
+ """
288
+ with self._lock:
289
+ self._cache.clear()
290
+ self._hits = 0
291
+ self._misses = 0
292
+
293
+ def keys(self):
294
+ """
295
+ Return all keys in the cache.
296
+
297
+ Returns
298
+ -------
299
+ list
300
+ A list of all keys currently in the cache.
301
+ """
302
+ with self._lock:
303
+ return list(self._cache.keys())
304
+
305
+ def get_stats(self) -> Dict[str, Any]:
306
+ """
307
+ Get cache statistics.
308
+
309
+ Returns
310
+ -------
311
+ dict
312
+ A dictionary with cache statistics including:
313
+
314
+ - 'size': Current number of items in cache
315
+ - 'maxsize': Maximum cache size
316
+ - 'hits': Number of cache hits
317
+ - 'misses': Number of cache misses
318
+ - 'hit_rate': Hit rate percentage (0-100)
319
+ """
320
+ with self._lock:
321
+ total = self._hits + self._misses
322
+ hit_rate = (self._hits / total * 100) if total > 0 else 0.0
323
+ return {
324
+ 'size': len(self._cache),
325
+ 'maxsize': self._maxsize,
326
+ 'hits': self._hits,
327
+ 'misses': self._misses,
328
+ 'hit_rate': hit_rate,
329
+ }
330
+
331
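(A minimal sketch of the LRU behaviour implemented above: ``get`` refreshes recency and counts a hit, ``set`` evicts the least recently used entry once ``maxsize`` is reached, and ``get_stats`` reports the counters.)

.. code-block:: python

    cache = _BoundedCache(maxsize=2)
    cache.set('a', 1)
    cache.set('b', 2)
    cache.get('a')                 # refreshes 'a' and counts one hit
    cache.set('c', 3)              # evicts 'b', the least recently used key
    assert 'a' in cache and 'b' not in cache
    print(cache.get_stats())       # {'size': 2, 'maxsize': 2, 'hits': 1, ...}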
+
332
+ def _ensure_str(x: str) -> str:
333
+ if not isinstance(x, str):
334
+ raise TypeError(f"argument is not a string: {x}")
335
+ return x
336
+
337
+
338
+ def _ensure_index_tuple(x: Any) -> tuple[int, ...]:
339
+ """Convert x to a tuple of indices."""
340
+ x = jax.core.concrete_or_error(None, x, "expected a static index or sequence of indices.")
341
+ try:
342
+ return (operator.index(x),)
343
+ except TypeError:
344
+ return tuple(safe_map(operator.index, x))
345
+
346
+
347
+ def _ensure_str_tuple(x: str | Iterable[str]) -> tuple[str, ...]:
348
+ """Convert x to a tuple of strings."""
349
+ if isinstance(x, str):
350
+ return (x,)
351
+ else:
352
+ return tuple(safe_map(_ensure_str, x))
353
+
354
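(The normalization these helpers perform, in short; a sketch of the expected behaviour.)

.. code-block:: python

    assert _ensure_index_tuple(2) == (2,)
    assert _ensure_index_tuple([0, 2]) == (0, 2)
    assert _ensure_str_tuple('weight') == ('weight',)
    assert _ensure_str_tuple(['weight', 'bias']) == ('weight', 'bias')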
+
355
+ def _jax_v04_new_arg_fn(frame, trace, aval):
356
+ """
357
+ Transform a new argument to a tracer.
358
+
359
+ Modified from jax.interpreters.partial_eval.DynamicJaxprTrace.new_arg()
360
+
361
+ Args:
362
+ frame: The frame.
363
+ trace: The trace.
364
+ aval: The abstract value.
365
+
366
+ Returns:
367
+ The tracer.
368
+ """
369
+ tracer = pe.DynamicJaxprTracer(trace, aval, source_info_util.current())
370
+ frame.tracers.append(tracer)
371
+ frame.tracer_to_var[id(tracer)] = var = frame.newvar(aval)
372
+ frame.invars.append(var)
373
+ return tracer
374
+
375
+
376
+ def _jax_v04_new_jax_trace():
377
+ main = jax.core.thread_local_state.trace_state.trace_stack.stack[-1]
378
+ frame = main.jaxpr_stack[-1]
379
+ trace = pe.DynamicJaxprTrace(main, jax.core.cur_sublevel())
380
+ return frame, trace
381
+
382
+
383
+ class StatefulFunction(PrettyObject):
384
+ """
385
+ A wrapper class for functions that tracks state reads and writes during execution.
386
+
387
+ This class wraps a function to enable state management in JAX programs by tracking
388
+ which states are read from and written to during function execution. It provides
389
+ methods to compile the function into JAX's intermediate representation (jaxpr),
390
+ inspect state usage, and execute the function with proper state handling.
391
+
392
+ When you define a function:
393
+
394
+ .. code-block:: python
395
+
396
+ >>> state = brainstate.State(1.)
397
+ >>> def f(x):
398
+ ... # Your function logic here
399
+ ... y = x * 2 + state.value
400
+ ... state.value = y
401
+
402
+ Calling ``sf = StatefulFunction(f)`` creates a stateful version of ``f``. You can
403
+ then call it directly, and the call remains compatible with JIT:
404
+
405
+ .. code-block:: python
406
+
407
+ >>> sf = brainstate.transform.StatefulFunction(f)
408
+ >>> out = sf(x) # Automatically compiles and executes
409
+
410
+ Parameters
411
+ ----------
412
+ fun : callable
413
+ The function whose ``jaxpr`` is to be computed. Its positional
414
+ arguments and return value should be arrays, scalars, or standard Python
415
+ containers (tuple/list/dict) thereof.
416
+ static_argnums : int or iterable of int, optional
417
+ Indices of positional arguments to treat as static (known at compile time).
418
+ See :py:func:`jax.jit` for details. Default is ().
419
+ static_argnames : str or iterable of str, optional
420
+ Names of keyword arguments to treat as static (known at compile time).
421
+ See :py:func:`jax.jit` for details. Default is ().
422
+ axis_env : sequence of tuple, optional
423
+ A sequence of pairs where the first element is an axis name and the second
424
+ element is a positive integer representing the size of the mapped axis with
425
+ that name. This parameter is useful when lowering functions that involve
426
+ parallel communication collectives, and it specifies the axis name/size
427
+ environment that would be set up by applications of :py:func:`jax.pmap`.
428
+ Default is None.
429
+ abstracted_axes : pytree, optional
430
+ A pytree with the same structure as the input arguments to ``fun``. The
431
+ leaves of the pytree can be either None or a dict with axis names as keys
432
+ and integers as values. If the leaf is None, then the corresponding axis
433
+ is not abstracted. If the leaf is a dict, then the corresponding axis is
434
+ abstracted, and the dict specifies the axis name and size. The abstracted
435
+ axes are used to infer the input type of the function. If the argument itself
436
+ is None, no axes are abstracted. Default is None.
437
+ name : str, optional
438
+ Name for the stateful function. Default is None.
439
+ return_only_write : bool, optional
440
+ If True, only return states that were written to during execution
441
+ (not just read). This can reduce memory usage when you only care
442
+ about modified states. Default is True.
443
+
444
+ Attributes
445
+ ----------
446
+ fun : callable
447
+ The wrapped function.
448
+ static_argnums : tuple of int
449
+ Indices of static positional arguments.
450
+ static_argnames : tuple of str
451
+ Names of static keyword arguments.
452
+ axis_env : sequence of tuple or None
453
+ Axis environment for parallel operations.
454
+ abstracted_axes : pytree or None
455
+ Abstract axes specification.
456
+ name : str or None
457
+ Name identifier for the function.
458
+ return_only_write : bool
459
+ Whether to return only written states.
460
+
461
+ Examples
462
+ --------
463
+ Basic usage with state management:
464
+
465
+ .. code-block:: python
466
+
467
+ >>> import brainstate
468
+ >>> import jax.numpy as jnp
469
+ >>>
470
+ >>> # Create a state
471
+ >>> state = brainstate.State(jnp.array([1.0, 2.0]))
472
+ >>>
473
+ >>> def f(x):
474
+ ... state.value += x
475
+ ... return state.value * 2
476
+ >>>
477
+ >>> # Create a stateful function
478
+ >>> sf = brainstate.transform.StatefulFunction(f)
479
+ >>>
480
+ >>> # Compile and get jaxpr
481
+ >>> x = jnp.array([0.5, 0.5])
482
+ >>> sf.make_jaxpr(x)
483
+ >>>
484
+ >>> # Get states that are read/written
485
+ >>> cache_key = sf.get_arg_cache_key(x)
486
+ >>> states = sf.get_states_by_cache(cache_key)
487
+ >>> read_states = sf.get_read_states_by_cache(cache_key)
488
+ >>> write_states = sf.get_write_states_by_cache(cache_key)
489
+
490
+ Using with static arguments:
491
+
492
+ .. code-block:: python
493
+
494
+ >>> def g(x, n):
495
+ ... state.value = state.value ** n
496
+ ... return state.value
497
+ >>>
498
+ >>> sf_static = brainstate.transform.StatefulFunction(
499
+ ... g, static_argnums=(1,)
500
+ ... )
501
+ >>> sf_static.make_jaxpr(x, 2)
502
+
503
+ Automatic state management:
504
+
505
+ .. code-block:: python
506
+
507
+ >>> # Execute with automatic state handling
508
+ >>> result = sf.jaxpr_call_auto(x)
509
+ >>> print(state.value) # State is automatically updated
510
+
511
+ See Also
512
+ --------
513
+ make_jaxpr : Function to create jaxpr from a function.
514
+ brainstate.State : The state container class.
515
+
516
+ Notes
517
+ -----
518
+ This class maintains internal thread-safe caches for compiled jaxprs, output
519
+ shapes, and state traces. The cache size is bounded at 128 entries per cache
520
+ type. Use ``clear_cache()`` to manually clear the caches if needed.
521
+
522
+ State objects should not be passed as direct inputs or outputs to the wrapped
523
+ function. Instead, they should be accessed within the function body, and the
524
+ class will automatically track their usage.
525
+ """
526
+ __module__ = "brainstate.transform"
527
+
528
+ def __init__(
529
+ self,
530
+ fun: Callable,
531
+ static_argnums: Union[int, Iterable[int]] = (),
532
+ static_argnames: Union[str, Iterable[str]] = (),
533
+ axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
534
+ abstracted_axes: Optional[Any] = None,
535
+ name: Optional[str] = None,
536
+ return_only_write: bool = True,
537
+ ):
538
+ # explicit parameters
539
+ self.fun = fun
540
+ self.static_argnums = tuple() if static_argnums is None else _ensure_index_tuple(static_argnums)
541
+ self.static_argnames = tuple() if static_argnames is None else _ensure_str_tuple(static_argnames)
542
+ self.axis_env = axis_env
543
+ self.abstracted_axes = abstracted_axes
544
+ self.name = name
545
+ self.return_only_write = return_only_write
546
+
547
+ # implicit parameters - thread-safe bounded caches
548
+ self._cached_jaxpr = _BoundedCache(maxsize=128)
549
+ self._cached_out_shapes = _BoundedCache(maxsize=128)
550
+ self._cached_jaxpr_out_tree = _BoundedCache(maxsize=128)
551
+ self._cached_state_trace = _BoundedCache(maxsize=128)
552
+ self._cache_lock = threading.RLock()
553
+
554
+ def __pretty_repr_item__(self, k, v):
555
+ if k.startswith('_'):
556
+ return None
557
+ return k, v
558
+
559
+ def get_jaxpr_by_cache(self, cache_key: Hashable) -> ClosedJaxpr:
560
+ """
561
+ Read the JAX Jaxpr representation of the function.
562
+
563
+ Parameters
564
+ ----------
565
+ cache_key : Hashable
566
+ The hashable cache key for retrieving the compiled jaxpr.
567
+
568
+ Returns
569
+ -------
570
+ ClosedJaxpr
571
+ The JAX Jaxpr representation of the function.
572
+
573
+ Raises
574
+ ------
575
+ ValueError
576
+ If the function has not been compiled for the given cache key.
577
+ """
578
+ return self._cached_jaxpr.get(cache_key, raise_on_miss=True, error_context="JAX expression")
579
+
580
+ def get_jaxpr(self, *args, compile_if_miss: bool = True, **kwargs) -> ClosedJaxpr:
581
+ """
582
+ Read the JAX Jaxpr representation of the function for the given arguments.
583
+
584
+ Parameters
585
+ ----------
586
+ *args
587
+ The arguments to the function.
588
+ compile_if_miss : bool, optional
589
+ Whether to compile the function if the cache key is not found. Default is True.
590
+ **kwargs
591
+ The keyword arguments to the function.
592
+
593
+ Returns
594
+ -------
595
+ ClosedJaxpr
596
+ The JAX Jaxpr representation of the function.
597
+ """
598
+ cache_key = self.get_arg_cache_key(*args, **kwargs, compile_if_miss=compile_if_miss)
599
+ return self.get_jaxpr_by_cache(cache_key)
600
+
601
+ def get_out_shapes_by_cache(self, cache_key: Hashable) -> PyTree:
602
+ """
603
+ Read the output shapes of the function.
604
+
605
+ Parameters
606
+ ----------
607
+ cache_key : Hashable
608
+ The hashable cache key.
609
+
610
+ Returns
611
+ -------
612
+ PyTree
613
+ The output shapes of the function.
614
+
615
+ Raises
616
+ ------
617
+ ValueError
618
+ If the function has not been compiled for the given cache key.
619
+ """
620
+ return self._cached_out_shapes.get(cache_key, raise_on_miss=True, error_context="Output shapes")
621
+
622
+ def get_out_shapes(self, *args, compile_if_miss: bool = True, **kwargs) -> PyTree:
623
+ """
624
+ Read the output shapes of the function.
625
+
626
+ Parameters
627
+ ----------
628
+ *args
629
+ The arguments to the function.
630
+ compile_if_miss : bool, optional
631
+ Whether to compile the function if the cache key is not found. Default is True.
632
+ **kwargs
633
+ The keyword arguments to the function.
634
+
635
+ Returns
636
+ -------
637
+ PyTree
638
+ The output shapes of the function.
639
+ """
640
+ cache_key = self.get_arg_cache_key(*args, **kwargs, compile_if_miss=compile_if_miss)
641
+ return self.get_out_shapes_by_cache(cache_key)
642
+
643
+ def get_out_treedef_by_cache(self, cache_key: Hashable) -> PyTree:
644
+ """
645
+ Read the output tree definition of the function.
646
+
647
+ Parameters
648
+ ----------
649
+ cache_key : Hashable
650
+ The hashable cache key.
651
+
652
+ Returns
653
+ -------
654
+ PyTree
655
+ The output tree definition of the function.
656
+
657
+ Raises
658
+ ------
659
+ ValueError
660
+ If the function has not been compiled for the given cache key.
661
+ """
662
+ return self._cached_jaxpr_out_tree.get(cache_key, raise_on_miss=True, error_context="Output tree")
663
+
664
+ def get_out_treedef(self, *args, compile_if_miss: bool = True, **kwargs) -> PyTree:
665
+ """
666
+ Read the output tree of the function.
667
+
668
+ Parameters
669
+ ----------
670
+ *args
671
+ The arguments to the function.
672
+ compile_if_miss : bool, optional
673
+ Whether to compile the function if the cache key is not found. Default is True.
674
+ **kwargs
675
+ The keyword arguments to the function.
676
+
677
+ Returns
678
+ -------
679
+ PyTree
680
+ The output tree of the function.
681
+ """
682
+ cache_key = self.get_arg_cache_key(*args, **kwargs, compile_if_miss=compile_if_miss)
683
+ return self.get_out_treedef_by_cache(cache_key)
684
+
685
+ def get_state_trace_by_cache(self, cache_key: Hashable) -> StateTraceStack:
686
+ """
687
+ Read the state trace of the function.
688
+
689
+ Parameters
690
+ ----------
691
+ cache_key : Hashable
692
+ The hashable cache key.
693
+
694
+ Returns
695
+ -------
696
+ StateTraceStack
697
+ The state trace stack containing all tracked states.
698
+
699
+ Raises
700
+ ------
701
+ ValueError
702
+ If the function has not been compiled for the given cache key.
703
+ """
704
+ return self._cached_state_trace.get(cache_key, raise_on_miss=True, error_context="State trace")
705
+
706
+ def get_state_trace(self, *args, compile_if_miss: bool = True, **kwargs) -> StateTraceStack:
707
+ """
708
+ Read the state trace of the function.
709
+
710
+ Parameters
711
+ ----------
712
+ *args
713
+ The arguments to the function.
714
+ compile_if_miss : bool, optional
715
+ Whether to compile the function if the cache key is not found. Default is True.
716
+ **kwargs
717
+ The keyword arguments to the function.
718
+
719
+ Returns
720
+ -------
721
+ StateTraceStack
722
+ The state trace of the function.
723
+ """
724
+ cache_key = self.get_arg_cache_key(*args, **kwargs, compile_if_miss=compile_if_miss)
725
+ return self.get_state_trace_by_cache(cache_key)
726
+
727
+ def get_states_by_cache(self, cache_key: Hashable) -> Tuple[State, ...]:
728
+ """
729
+ Read the states that are accessed by the function.
730
+
731
+ Parameters
732
+ ----------
733
+ cache_key : Hashable
734
+ The hashable cache key.
735
+
736
+ Returns
737
+ -------
738
+ Tuple[State, ...]
739
+ The states that are read from or written to by the function.
740
+
741
+ Raises
742
+ ------
743
+ ValueError
744
+ If the function has not been compiled for the given cache key.
745
+ """
746
+ return tuple(self.get_state_trace_by_cache(cache_key).states)
747
+
748
+ def get_states(self, *args, compile_if_miss: bool = True, **kwargs) -> Tuple[State, ...]:
749
+ """
750
+ Compile the function, and get the states that are read and written by this function.
751
+
752
+ Parameters
753
+ ----------
754
+ *args
755
+ The arguments to the function.
756
+ compile_if_miss : bool, optional
757
+ Whether to compile the function if the cache key is not found. Default is True.
758
+ **kwargs
759
+ The keyword arguments to the function.
760
+
761
+ Returns
762
+ -------
763
+ Tuple[State, ...]
764
+ The states that are read and written by the function.
765
+ """
766
+ cache_key = self.get_arg_cache_key(*args, **kwargs, compile_if_miss=compile_if_miss)
767
+ return self.get_states_by_cache(cache_key)
768
+
769
+ def get_read_states_by_cache(self, cache_key: Hashable) -> Tuple[State, ...]:
770
+ """
771
+ Read the states that are read by the function.
772
+
773
+ Parameters
774
+ ----------
775
+ cache_key : Hashable
776
+ The hashable key.
777
+
778
+ Returns
779
+ -------
780
+ Tuple[State, ...]
781
+ The states that are read by the function.
782
+ """
783
+ return self.get_state_trace_by_cache(cache_key).get_read_states()
784
+
785
+ def get_read_states(self, *args, compile_if_miss: bool = True, **kwargs) -> Tuple[State, ...]:
786
+ """
787
+ Compile the function, and get the states that are read by this function.
788
+
789
+ Parameters
790
+ ----------
791
+ *args
792
+ The arguments to the function.
793
+ compile_if_miss : bool, optional
794
+ Whether to compile the function if the cache key is not found. Default is True.
795
+ **kwargs
796
+ The keyword arguments to the function.
797
+
798
+ Returns
799
+ -------
800
+ Tuple[State, ...]
801
+ The states that are read by the function.
802
+ """
803
+ cache_key = self.get_arg_cache_key(*args, **kwargs, compile_if_miss=compile_if_miss)
804
+ return self.get_read_states_by_cache(cache_key)
805
+
806
+ def get_write_states_by_cache(self, cache_key: Hashable) -> Tuple[State, ...]:
807
+ """
808
+ Read the states that are written by the function.
809
+
810
+ Parameters
811
+ ----------
812
+ cache_key : Hashable
813
+ The hashable cache key.
814
+
815
+ Returns
816
+ -------
817
+ Tuple[State, ...]
818
+ The states that are written by the function.
819
+ """
820
+ return self.get_state_trace_by_cache(cache_key).get_write_states()
821
+
822
+ def get_write_states(self, *args, compile_if_miss: bool = True, **kwargs) -> Tuple[State, ...]:
823
+ """
824
+ Compile the function, and get the states that are written by this function.
825
+
826
+ Parameters
827
+ ----------
828
+ *args
829
+ The arguments to the function.
830
+ compile_if_miss : bool, optional
831
+ Whether to compile the function if the cache key is not found. Default is True.
832
+ **kwargs
833
+ The keyword arguments to the function.
834
+
835
+ Returns
836
+ -------
837
+ Tuple[State, ...]
838
+ The states that are written by the function.
839
+ """
840
+ cache_key = self.get_arg_cache_key(*args, **kwargs, compile_if_miss=compile_if_miss)
841
+ return self.get_write_states_by_cache(cache_key)
842
+
843
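(A minimal sketch of the read/write distinction these accessors expose; the state names are illustrative.)

.. code-block:: python

    import brainstate
    import jax.numpy as jnp

    r = brainstate.State(jnp.array(1.0))    # only read below
    w = brainstate.State(jnp.array(0.0))    # written below

    def f(x):
        w.value = x + r.value
        return w.value

    sf = brainstate.transform.StatefulFunction(f)
    read_states = sf.get_read_states(jnp.array(2.0))     # includes r
    write_states = sf.get_write_states(jnp.array(2.0))   # includes w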
+ def _check_input_ouput(self, x):
844
+ if isinstance(x, State):
845
+ x.raise_error_with_source_info(
846
+ ValueError(
847
+ 'Inputs/outputs for brainstate transformations cannot be an instance of State. '
848
+ f'But we got {x}'
849
+ )
850
+ )
851
+
852
+ def get_arg_cache_key(self, *args, compile_if_miss: bool = False, **kwargs) -> hashabledict:
853
+ """
854
+ Compute the cache key for the given arguments.
855
+
856
+ This method separates static and dynamic arguments and creates a hashable
857
+ key that can be used to cache compiled jaxpr representations.
858
+
859
+ Parameters
860
+ ----------
861
+ *args
862
+ The positional arguments to the function.
863
+ compile_if_miss : bool, optional
864
+ Whether to compile the function if the cache key does not exist.
865
+ Default is False.
866
+ **kwargs
867
+ The keyword arguments to the function.
868
+
869
+ Returns
870
+ -------
871
+ hashabledict
872
+ A hashable dictionary containing the cache key with fields:
873
+ 'static_args', 'dyn_args', 'static_kwargs', 'dyn_kwargs'.
874
+
875
+ Examples
876
+ --------
877
+ .. code-block:: python
878
+
879
+ >>> import brainstate
880
+ >>> import jax.numpy as jnp
881
+ >>>
882
+ >>> def f(x, n):
883
+ ... return x ** n
884
+ >>>
885
+ >>> sf = brainstate.transform.StatefulFunction(
886
+ ... f, static_argnums=(1,)
887
+ ... )
888
+ >>> cache_key = sf.get_arg_cache_key(jnp.array([1.0, 2.0]), 2)
889
+ """
890
+ static_args, dyn_args = [], []
891
+ for i, arg in enumerate(args):
892
+ if i in self.static_argnums:
893
+ static_args.append(arg)
894
+ else:
895
+ dyn_args.append(arg)
896
+ dyn_args = jax.tree.map(shaped_abstractify, dyn_args)
897
+ static_kwargs, dyn_kwargs = [], []
898
+ for k, v in sorted(kwargs.items()):
899
+ if k in self.static_argnames:
900
+ static_kwargs.append((k, v))
901
+ else:
902
+ dyn_kwargs.append((k, jax.tree.map(shaped_abstractify, v)))
903
+
904
+ static_args = make_hashable(tuple(static_args))
905
+ dyn_args = make_hashable(tuple(dyn_args))
906
+ static_kwargs = make_hashable(static_kwargs)
907
+ dyn_kwargs = make_hashable(dyn_kwargs)
908
+
909
+ cache_key = hashabledict(
910
+ static_args=static_args,
911
+ dyn_args=dyn_args,
912
+ static_kwargs=static_kwargs,
913
+ dyn_kwargs=dyn_kwargs,
914
+ )
915
+
916
+ if cache_key not in self._cached_state_trace and compile_if_miss:
917
+ self.make_jaxpr(*args, **kwargs)
918
+
919
+ return cache_key
920
+
921
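(Because dynamic arguments are abstracted to shape/dtype while static arguments are compared by value, same-shaped inputs share a cache entry. A sketch, assuming ``import brainstate`` and ``import jax.numpy as jnp``.)

.. code-block:: python

    sf = brainstate.transform.StatefulFunction(lambda x, n: x ** n, static_argnums=(1,))
    k1 = sf.get_arg_cache_key(jnp.zeros(3), 2)
    k2 = sf.get_arg_cache_key(jnp.ones(3), 2)    # same shape/dtype, same static arg
    k3 = sf.get_arg_cache_key(jnp.zeros(3), 3)   # different static arg
    assert k1 == k2 and k1 != k3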
+ def clear_cache(self) -> None:
922
+ """
923
+ Clear all compilation caches.
924
+
925
+ This method removes all cached jaxprs, output shapes, output trees,
926
+ and state traces. Use this when you need to recompile the function
927
+ or free memory.
928
+
929
+ Examples
930
+ --------
931
+ .. code-block:: python
932
+
933
+ >>> import brainstate
934
+ >>> import jax.numpy as jnp
935
+ >>>
936
+ >>> def f(x):
937
+ ... return x * 2
938
+ >>>
939
+ >>> sf = brainstate.transform.StatefulFunction(f)
940
+ >>> sf.make_jaxpr(jnp.array([1.0, 2.0]))
941
+ >>> sf.clear_cache() # Clear all cached compilations
942
+ """
943
+ self._cached_jaxpr.clear()
944
+ self._cached_out_shapes.clear()
945
+ self._cached_jaxpr_out_tree.clear()
946
+ self._cached_state_trace.clear()
947
+
948
+ def __jax_v04_new_arg(self):
949
+ # Should be within the calling of ``jax.make_jaxpr()``
950
+ frame, trace = _jax_v04_new_jax_trace()
951
+ # Set the function to transform the new argument to a tracer
952
+ fn = functools.partial(_jax_v04_new_arg_fn, frame, trace)
953
+ return fn
954
+
955
+ def __jax_new_version_new_arg(self):
956
+ trace = jax.core.trace_ctx.trace
957
+
958
+ def wrapper(x):
959
+ if jax.__version_info__ < (0, 6, 1):
960
+ fn = lambda xx: trace.new_arg(shaped_abstractify(xx))
961
+ else:
962
+ fn = lambda xx: trace.new_arg(shaped_abstractify(xx), source_info=source_info_util.current())
963
+ return jax.tree.map(fn, x._value)
964
+
965
+ return wrapper
966
+
967
+ def _wrapped_fun_to_eval(
968
+ self,
969
+ cache_key,
970
+ static_kwargs: dict,
971
+ *args,
972
+ **dyn_kwargs,
973
+ ) -> Tuple[Any, Tuple[State, ...]]:
974
+ """
975
+ Internal wrapper that executes the function and tracks state operations.
976
+
977
+ This method wraps the original function to track which states are read
978
+ and written during execution. It is used internally during jaxpr compilation.
979
+
980
+ Parameters
981
+ ----------
982
+ cache_key
983
+ The cache key for storing the state trace.
984
+ static_kwargs : dict
985
+ Static keyword arguments that were separated out.
986
+ *args
987
+ The positional arguments to the function.
988
+ **dyn_kwargs
989
+ Dynamic keyword arguments to the function.
990
+
991
+ Returns
992
+ -------
993
+ tuple
994
+ A tuple of (output, state_values) where output is the function result
995
+ and state_values are the tracked state values (either all or write-only
996
+ depending on return_only_write setting).
997
+ """
998
+ # state trace
999
+ state_trace: StateTraceStack = StateTraceStack(self.name)
1000
+ if jax.__version_info__ < (0, 4, 36):
1001
+ state_trace.set_new_arg(self.__jax_v04_new_arg())
1002
+ else:
1003
+ state_trace.set_new_arg(self.__jax_new_version_new_arg())
1004
+ self._cached_state_trace.set(cache_key, state_trace)
1005
+ with state_trace:
1006
+ out = self.fun(*args, **dyn_kwargs, **static_kwargs)
1007
+ state_values = (
1008
+ state_trace.get_write_state_values(True)
1009
+ if self.return_only_write else
1010
+ state_trace.get_state_values()
1011
+ )
1012
+ state_trace.recovery_original_values()
1013
+
1014
+ # State instances are not allowed as function returns.
1015
+ # Check whether any states are being returned.
1016
+ jax.tree.map(self._check_input_ouput, out, is_leaf=lambda x: isinstance(x, State))
1017
+ return out, state_values
1018
+
1019
+ def make_jaxpr(self, *args, **kwargs):
1020
+ """
1021
+ Create the JAX Jaxpr representation given example arguments.
1022
+
1023
+ This method compiles the function with the given arguments and caches
1024
+ the resulting Jaxpr, output shapes, and state trace for later use.
1025
+
1026
+ Parameters
1027
+ ----------
1028
+ *args
1029
+ The arguments to the function.
1030
+ **kwargs
1031
+ The keyword arguments to the function.
1032
+
1033
+ Returns
1034
+ -------
1035
+ StatefulFunction
1036
+ Returns self for method chaining.
1037
+
1038
+ Raises
1039
+ ------
1040
+ TypeError
1041
+ If State objects are passed as arguments or returned from the function.
1042
+ """
1043
+
1044
+ # check input types
1045
+ jax.tree.map(self._check_input_ouput, (args, kwargs), is_leaf=lambda x: isinstance(x, State))
1046
+
1047
+ # static args
1048
+ cache_key = self.get_arg_cache_key(*args, **kwargs)
1049
+
1050
+ if cache_key not in self._cached_state_trace:
1051
+ try:
1052
+
1053
+ # jaxpr
1054
+ static_kwargs, dyn_kwargs = {}, {}
1055
+ for k, v in kwargs.items():
1056
+ if k in self.static_argnames:
1057
+ static_kwargs[k] = v
1058
+ else:
1059
+ dyn_kwargs[k] = v
1060
+ jaxpr, (out_shapes, state_shapes) = _make_jaxpr(
1061
+ functools.partial(
1062
+ self._wrapped_fun_to_eval,
1063
+ cache_key,
1064
+ static_kwargs,
1065
+ ),
1066
+ static_argnums=self.static_argnums,
1067
+ axis_env=self.axis_env,
1068
+ return_shape=True,
1069
+ abstracted_axes=self.abstracted_axes,
1070
+ )(*args, **dyn_kwargs)
1071
+
1072
+ # returns
1073
+ self._cached_jaxpr_out_tree.set(cache_key, jax.tree.structure((out_shapes, state_shapes)))
1074
+ self._cached_out_shapes.set(cache_key, (out_shapes, state_shapes))
1075
+ self._cached_jaxpr.set(cache_key, jaxpr)
1076
+
1077
+ except Exception as e:
1078
+ # Clean up partial cache entries on error
1079
+ self._cached_state_trace.pop(cache_key, None)
1080
+ self._cached_out_shapes.pop(cache_key, None)
1081
+ self._cached_jaxpr.pop(cache_key, None)
1082
+ raise e
1083
+
1084
+ return self
1085
+
1086
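(After compilation, the cached artifacts can be inspected. A minimal sketch, assuming ``import brainstate`` and ``import jax.numpy as jnp``.)

.. code-block:: python

    state = brainstate.State(jnp.zeros(2))

    def f(x):
        state.value = state.value + x
        return state.value * 2

    sf = brainstate.transform.StatefulFunction(f).make_jaxpr(jnp.ones(2))
    print(sf.get_jaxpr(jnp.ones(2)))       # ClosedJaxpr with the state as an extra input/output
    print(sf.get_out_shapes(jnp.ones(2)))  # (function output shapes, state output shapes)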
+ def jaxpr_call(self, state_vals, *args, **kwargs) -> Any:
1087
+ """
1088
+ Call the function at the JAX Jaxpr level.
1089
+
1090
+ This method evaluates the compiled Jaxpr with the provided state values
1091
+ and arguments, returning updated state values and function outputs.
1092
+
1093
+ Parameters
1094
+ ----------
1095
+ state_vals : Sequence
1096
+ The current state values.
1097
+ *args
1098
+ The arguments to the function.
1099
+ **kwargs
1100
+ The keyword arguments to the function.
1101
+
1102
+ Returns
1103
+ -------
1104
+ tuple
1105
+ A tuple of (new_state_vals, out) where new_state_vals are the
1106
+ updated state values and out is the function output.
1107
+
1108
+ Raises
1109
+ ------
1110
+ ValueError
1111
+ If the number of state values doesn't match the expected number.
1112
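+
+ Examples
+ --------
+ A minimal sketch, assuming ``sf`` was already compiled via
+ :meth:`make_jaxpr` for the same arguments (``x`` is illustrative):
+
+ .. code-block:: python
+
+ >>> key = sf.get_arg_cache_key(x)
+ >>> states = sf.get_states_by_cache(key)
+ >>> new_vals, out = sf.jaxpr_call([st.value for st in states], x)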
+ """
1113
+ # state checking
1114
+ cache_key = self.get_arg_cache_key(*args, **kwargs)
1115
+ states: Sequence[State] = self.get_states_by_cache(cache_key)
1116
+ if len(state_vals) != len(states):
1117
+ raise ValueError(f'State length mismatch: expected {len(states)} states, got {len(state_vals)}')
1118
+
1119
+ # parameters
1120
+ kwargs = {k: v for k, v in kwargs.items() if k not in self.static_argnames} # remove static kwargs
1121
+ args = tuple(args[i] for i in range(len(args)) if i not in self.static_argnums)
1122
+ args = jax.tree.flatten((args, kwargs, state_vals))[0]
1123
+
1124
+ # Call the function at the jaxpr level.
1125
+ # Note that this evaluation always returns the state values
1126
+ # that are both written and read by the function.
1127
+ closed_jaxpr = self.get_jaxpr_by_cache(cache_key)
1128
+ out_treedef = self.get_out_treedef_by_cache(cache_key)
1129
+ jaxpr_outs = jax.core.eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)
1130
+
1131
+ # output processing
1132
+ out, new_state_vals = out_treedef.unflatten(jaxpr_outs)
1133
+ if len(new_state_vals) != len(state_vals):
1134
+ raise ValueError(f'State length mismatch in output: expected '
1135
+ f'{len(state_vals)} states, got {len(new_state_vals)}')
1136
+ return new_state_vals, out
1137
+
1138
+ def get_cache_stats(self) -> Dict[str, Any]:
1139
+ """
1140
+ Get comprehensive cache statistics for all internal caches.
1141
+
1142
+ Returns
1143
+ -------
1144
+ dict
1145
+ A dictionary with statistics for each cache including size, hits,
1146
+ misses, and hit rates. Keys are 'jaxpr_cache', 'out_shapes_cache',
1147
+ 'jaxpr_out_tree_cache', and 'state_trace_cache'.
1148
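+
+ Examples
+ --------
+ A usage sketch (``sf`` is assumed to be a ``StatefulFunction``):
+
+ .. code-block:: python
+
+ >>> stats = sf.get_cache_stats()
+ >>> sorted(stats.keys())
+ ['jaxpr_cache', 'jaxpr_out_tree_cache', 'out_shapes_cache', 'state_trace_cache']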
+ """
1149
+ return {
1150
+ 'jaxpr_cache': self._cached_jaxpr.get_stats(),
1151
+ 'out_shapes_cache': self._cached_out_shapes.get_stats(),
1152
+ 'jaxpr_out_tree_cache': self._cached_jaxpr_out_tree.get_stats(),
1153
+ 'state_trace_cache': self._cached_state_trace.get_stats(),
1154
+ }
1155
+
1156
+ def validate_states(self, cache_key: Hashable) -> bool:
1157
+ """
1158
+ Validate that all tracked states for a given cache key are still valid.
1159
+
1160
+ Parameters
1161
+ ----------
1162
+ cache_key : Hashable
1163
+ The cache key to validate states for.
1164
+
1165
+ Returns
1166
+ -------
1167
+ bool
1168
+ True if all states are valid.
1169
+
1170
+ Raises
1171
+ ------
1172
+ ValueError
1173
+ If any states are invalid or missing required attributes.
1174
+ """
1175
+ state_trace = self.get_state_trace_by_cache(cache_key)
1176
+ invalid_states = []
1177
+ for i, state in enumerate(state_trace.states):
1178
+ if not hasattr(state, 'value'):
1179
+ invalid_states.append((i, state))
1180
+
1181
+ if invalid_states:
1182
+ raise ValueError(
1183
+ f"Found {len(invalid_states)} invalid states at indices: "
1184
+ f"{[idx for idx, _ in invalid_states]}. "
1185
+ f"States must have a 'value' attribute."
1186
+ )
1187
+ return True
1188
+
1189
+ def validate_all_states(self) -> Dict[Any, bool]:
1190
+ """
1191
+ Validate states for all cached compilations.
1192
+
1193
+ Returns
1194
+ -------
1195
+ dict
1196
+ A dictionary mapping cache keys to validation results. Each value
1197
+ is either True (valid) or an error message string (invalid).
1198
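+
+ Examples
+ --------
+ A usage sketch (``sf`` is assumed to be a compiled ``StatefulFunction``):
+
+ .. code-block:: python
+
+ >>> results = sf.validate_all_states()
+ >>> all(ok is True for ok in results.values()) # False if any entry holds an error string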
+ """
1199
+ results = {}
1200
+ for cache_key in self._cached_state_trace.keys():
1201
+ try:
1202
+ results[cache_key] = self.validate_states(cache_key)
1203
+ except ValueError as e:
1204
+ results[cache_key] = str(e)
1205
+ return results
1206
+
1207
+ def jaxpr_call_auto(self, *args, **kwargs) -> Any:
1208
+ """
1209
+ Execute the function at the jaxpr level with automatic state management.
1210
+
1211
+ This method automatically retrieves current state values, executes the
1212
+ jaxpr-compiled function, and updates the states with the new values.
1213
+ It provides a convenient interface that handles all state management
1214
+ automatically.
1215
+
1216
+ Parameters
1217
+ ----------
1218
+ *args
1219
+ The positional arguments to the function.
1220
+ **kwargs
1221
+ The keyword arguments to the function.
1222
+
1223
+ Returns
1224
+ -------
1225
+ Any
1226
+ The output of the function.
1227
+
1228
+ Examples
1229
+ --------
1230
+ .. code-block:: python
1231
+
1232
+ >>> import brainstate
1233
+ >>> import jax.numpy as jnp
1234
+ >>>
1235
+ >>> state = brainstate.State(jnp.array([1.0, 2.0]))
1236
+ >>>
1237
+ >>> def f(x):
1238
+ ... state.value += x
1239
+ ... return state.value * 2
1240
+ >>>
1241
+ >>> sf = brainstate.transform.StatefulFunction(f)
1242
+ >>> x = jnp.array([0.5, 0.5])
1243
+ >>> sf.make_jaxpr(x)
1244
+ >>>
1245
+ >>> # Automatic state management
1246
+ >>> result = sf.jaxpr_call_auto(x)
1247
+ >>> # or
1248
+ >>> result = sf(x)
1249
+ >>> print(state.value) # State is automatically updated
1250
+ """
1251
+ state_trace = self.get_state_trace_by_cache(self.get_arg_cache_key(*args, **kwargs, compile_if_miss=True))
1252
+ all_read_state_vals = state_trace.get_read_state_values(True)
1253
+ state_vals, out = self.jaxpr_call(state_trace.get_state_values(), *args, **kwargs)
1254
+ state_trace.assign_state_vals_v2(all_read_state_vals, state_vals)
1255
+ return out
1256
+
1257
+ def __call__(self, *args, **kwargs):
1258
+ return self.jaxpr_call_auto(*args, **kwargs)
1259
+
1260
+
1261
+ @set_module_as("brainstate.transform")
1262
+ def make_jaxpr(
1263
+ fun: Callable,
1264
+ static_argnums: Union[int, Iterable[int]] = (),
1265
+ static_argnames: Union[str, Iterable[str]] = (),
1266
+ axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
1267
+ return_shape: bool = False,
1268
+ abstracted_axes: Optional[Any] = None,
1269
+ return_only_write: bool = False,
1270
+ ) -> Callable[
1271
+ ...,
1272
+ (Tuple[ClosedJaxpr, Tuple[State, ...]] |
1273
+ Tuple[ClosedJaxpr, Tuple[State, ...], PyTree])
1274
+ ]:
1275
+ """
1276
+ Creates a function that produces its jaxpr given example args.
1277
+
1278
+ A ``jaxpr`` is JAX's intermediate representation for program traces. The
1279
+ ``jaxpr`` language is based on the simply-typed first-order lambda calculus
1280
+ with let-bindings. :py:func:`make_jaxpr` adapts a function to return its
1281
+ ``jaxpr``, which we can inspect to understand what JAX is doing internally.
1282
+ The ``jaxpr`` returned is a trace of ``fun`` abstracted to
1283
+ :py:class:`ShapedArray` level. Other levels of abstraction exist internally.
1284
+
1285
+ Parameters
1286
+ ----------
1287
+ fun : callable
1288
+ The function whose ``jaxpr`` is to be computed. Its positional
1289
+ arguments and return value should be arrays, scalars, or standard Python
1290
+ containers (tuple/list/dict) thereof.
1291
+ static_argnums : int or iterable of int, optional
1292
+ See the :py:func:`jax.jit` docstring.
1293
+ static_argnames : str or iterable of str, optional
1294
+ See the :py:func:`jax.jit` docstring.
1295
+ axis_env : sequence of tuple, optional
1296
+ A sequence of pairs where the first element is an axis
1297
+ name and the second element is a positive integer representing the size of
1298
+ the mapped axis with that name. This parameter is useful when lowering
1299
+ functions that involve parallel communication collectives, and it
1300
+ specifies the axis name/size environment that would be set up by
1301
+ applications of :py:func:`jax.pmap`.
1302
+ return_shape : bool, default False
1303
+ If ``True``, the wrapped function returns a triple: the ``ClosedJaxpr``,
1304
+ the tuple of traced ``State`` objects, and a pytree with the same
1305
+ structure as the output of ``fun`` whose leaves are objects with
1306
+ ``shape``, ``dtype``, and ``named_shape`` attributes representing the
1307
+ corresponding types of the output leaves.
1309
+ abstracted_axes : pytree, optional
1310
+ A pytree with the same structure as the input
1311
+ arguments to ``fun``. The leaves of the pytree can be either None or a
1312
+ dict with axis names as keys and integers as values. If the leaf is None,
1313
+ then the corresponding axis is not abstracted. If the leaf is a dict, then
1314
+ the corresponding axis is abstracted, and the dict specifies the axis name
1315
+ and size. The abstracted axes are used to infer the input type of the
1316
+ function. If None, then all axes are abstracted.
1317
+ return_only_write : bool, default False
1318
+ If True, only return states that were written to during execution
1319
+ (not just read). This can reduce memory usage when you only care
1320
+ about modified states.
1321
+
1322
+ Returns
1323
+ -------
1324
+ callable
1325
+ A wrapped version of ``fun`` that, when applied to example arguments,
1326
+ returns a pair of the ``ClosedJaxpr`` representation of ``fun`` on those
1327
+ arguments and the tuple of ``State`` objects traced during execution. If
1328
+ ``return_shape`` is ``True``, the returned function instead returns a
1329
+ triple whose additional last element is a pytree representing the
1330
+ structure, shapes, and dtypes of the output of ``fun``.
1331
+
1332
+ Examples
1333
+ --------
1334
+ Basic usage:
1335
+
1336
+ .. code-block:: python
1337
+
1338
+ >>> import jax
1339
+ >>> import brainstate
1340
+ >>> import jax.numpy as jnp
1341
+ >>>
1342
+ >>> def f(x):
1343
+ ... return jnp.sin(jnp.cos(x))
1344
+ >>>
1345
+ >>> # Create jaxpr maker
1346
+ >>> jaxpr_maker = brainstate.transform.make_jaxpr(f)
1347
+ >>> jaxpr, states = jaxpr_maker(3.0)
1348
+
1349
+ With gradient:
1350
+
1351
+ .. code-block:: python
1352
+
1353
+ >>> jaxpr_grad_maker = brainstate.transform.make_jaxpr(jax.grad(f))
1354
+ >>> jaxpr, states = jaxpr_grad_maker(3.0)
1355
+
1356
+ With shape information:
1357
+
1358
+ .. code-block:: python
1359
+
1360
+ >>> jaxpr_maker_with_shape = brainstate.transform.make_jaxpr(f, return_shape=True)
1361
+ >>> jaxpr, states, shapes = jaxpr_maker_with_shape(3.0)
1362
+
1363
+ With stateful function:
1364
+
1365
+ .. code-block:: python
1366
+
1367
+ >>> state = brainstate.State(jnp.array([1.0, 2.0]))
1368
+ >>>
1369
+ >>> def stateful_f(x):
1370
+ ... state.value += x
1371
+ ... return state.value
1372
+ >>>
1373
+ >>> jaxpr_maker = brainstate.transform.make_jaxpr(stateful_f)
1374
+ >>> jaxpr, states = jaxpr_maker(jnp.array([0.5, 0.5]))
1375
+ """
1376
+
1377
+ stateful_fun = StatefulFunction(
1378
+ fun,
1379
+ static_argnums=static_argnums,
1380
+ static_argnames=static_argnames,
1381
+ axis_env=axis_env,
1382
+ abstracted_axes=abstracted_axes,
1383
+ return_only_write=return_only_write,
1384
+ name='make_jaxpr'
1385
+ )
1386
+
1387
+ @wraps(fun)
1388
+ def make_jaxpr_f(*args, **kwargs):
1389
+ stateful_fun.make_jaxpr(*args, **kwargs)
1390
+ cache_key = stateful_fun.get_arg_cache_key(*args, **kwargs)
1391
+ if return_shape:
1392
+ return (
1393
+ stateful_fun.get_jaxpr_by_cache(cache_key),
1394
+ stateful_fun.get_states_by_cache(cache_key),
1395
+ stateful_fun.get_out_shapes_by_cache(cache_key)[0]
1396
+ )
1397
+ else:
1398
+ return (
1399
+ stateful_fun.get_jaxpr_by_cache(cache_key),
1400
+ stateful_fun.get_states_by_cache(cache_key)
1401
+ )
1402
+
1403
+ # wrapped jaxpr builder function
1404
+ make_jaxpr_f.__module__ = "brainstate.transform"
1405
+ if hasattr(fun, "__qualname__"):
1406
+ make_jaxpr_f.__qualname__ = f"make_jaxpr({fun.__qualname__})"
1407
+ if hasattr(fun, "__name__"):
1408
+ make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
1409
+ return make_jaxpr_f
1410
+
1411
+
1412
+ class StatefulMapping(StatefulFunction):
1413
+ """
1414
+ Vectorized wrapper that preserves BrainState state semantics during mapping.
1415
+
1416
+ ``StatefulMapping`` extends JAX mapping transforms (such as :func:`jax.vmap`
1417
+ and :func:`jax.pmap`) with awareness of :class:`~brainstate.State`
1418
+ instances. It tracks state reads and writes across the mapped axis,
1419
+ handles random numbers deterministically per lane, and writes state
1420
+ updates back after each batched execution. The helper is typically constructed by
1421
+ :func:`brainstate.transform.vmap` or :func:`brainstate.transform.pmap`, but
1422
+ it can also be instantiated directly for custom mapping primitives.
1423
+
1424
+ Parameters
1425
+ ----------
1426
+ fun : callable
1427
+ The callable to be wrapped. It may close over
1428
+ :class:`~brainstate.State` objects that should be tracked during the
1429
+ mapping transform.
1430
+ in_axes : int, tuple of int, or None, default 0
1431
+ Alignment of the mapped axis per positional argument, following the
1432
+ semantics of :func:`jax.vmap`. Arguments mapped with ``None`` are treated
1433
+ as static.
1434
+ out_axes : int, tuple of int, or None, default 0
1435
+ Placement of the mapped axis in the return value, consistent with JAX
1436
+ mapping primitives.
1437
+ state_in_axes : dict[AxisName, Filter] or Filter, optional
1438
+ Specification of input states that participate in the mapped axis. A
1439
+ dictionary maps axis identifiers to :mod:`brainstate.util.filter`
1440
+ predicates; passing a single filter applies it to axis ``0``. Values are
1441
+ normalized via :func:`brainstate.util.filter.to_predicate`.
1442
+ state_out_axes : dict[AxisName, Filter] or Filter, optional
1443
+ Specification of state outputs to scatter back along the mapped axis.
1444
+ Uses the same semantics and normalization as ``state_in_axes``.
1445
+ unexpected_out_state_mapping : {'raise', 'warn', 'ignore'}, default 'raise'
1446
+ Strategy for handling states written during the mapped call that are not
1447
+ captured by ``state_out_axes``.
1448
+ axis_size : int, optional
1449
+ Explicit size of the mapped axis. When omitted, the size is inferred
1450
+ from the mapped arguments.
1451
+ axis_name : hashable, optional
1452
+ Name for the mapped axis so that collective primitives can target it.
1453
+ name : str, optional
1454
+ Human-readable identifier for diagnostics and debugging.
1455
+ mapping_fn : callable, default ``jax.vmap``
1456
+ Mapping primitive that executes ``fun``. The callable must accept the
1457
+ ``in_axes`` and ``out_axes`` keyword arguments used by :func:`jax.vmap`.
1458
+
1459
+ Attributes
1460
+ ----------
1461
+ origin_fun : callable
1462
+ Original Python callable wrapped by the mapping helper.
1463
+ in_axes : int, tuple of int, or None
1464
+ Mapping specification for positional arguments.
1465
+ out_axes : int, tuple of int, or None
1466
+ Mapping specification for the return value.
1467
+ state_in_axes : dict[AxisName, Predicate]
1468
+ Normalized predicates describing which states to batch on input.
1469
+ state_out_axes : dict[AxisName, Predicate]
1470
+ Normalized predicates describing which states to batch on output.
1471
+ axis_size : int or None
1472
+ Size of the mapped axis, if explicitly provided.
1473
+ axis_name : hashable or None
1474
+ Axis identifier forwarded to collective primitives.
1475
+ mapping_fn : callable
1476
+ Mapping primitive responsible for executing ``fun``.
1477
+
1478
+ Raises
1479
+ ------
1480
+ TypeError
1481
+ If ``in_axes`` has an unsupported type.
1482
+ ValueError
1483
+ If batch dimensions are inconsistent or cannot be inferred.
1484
+ RuntimeError
1485
+ If tracing or executing the mapped function fails.
1486
+
1487
+ Notes
1488
+ -----
1489
+ Random states (for example :class:`~brainstate.RandomState`) encountered
1490
+ during execution are automatically split along the mapped axis and restored
1491
+ afterwards; this behaviour cannot be disabled. The wrapper caches inferred
1492
+ state placements, batch sizes, and trace stacks keyed by abstract argument
1493
+ signatures so repeated calls with the same structure avoid re-tracing.
1494
+
1495
+ Examples
1496
+ --------
1497
+ .. code-block:: python
1498
+
1499
+ >>> import brainstate
1500
+ >>> import jax.numpy as jnp
1501
+ >>> from brainstate.util.filter import OfType
1502
+ >>>
1503
+ >>> counter = brainstate.ShortTermState(jnp.array(0.0))
1504
+ >>>
1505
+ >>> def accumulate(x):
1506
+ ... counter.value = counter.value + x
1507
+ ... return counter.value
1508
+ >>>
1509
+ >>> batched_accumulate = brainstate.transform.StatefulMapping(
1510
+ ... accumulate,
1511
+ ... in_axes=0,
1512
+ ... out_axes=0,
1513
+ ... state_in_axes={0: OfType(brainstate.ShortTermState)},
1514
+ ... state_out_axes={0: OfType(brainstate.ShortTermState)},
1515
+ ... name="batched_accumulate",
1516
+ ... )
1517
+ >>>
1518
+ >>> xs = jnp.ones((3,))
1519
+ >>> batched_accumulate(xs)
1520
+ Array([1., 2., 3.], dtype=float32)
1521
+ >>> counter.value
1522
+ Array(3., dtype=float32)
1523
+
1524
+ See Also
1525
+ --------
1526
+ brainstate.transform.vmap : Convenience API returning a ``StatefulMapping``.
1527
+ brainstate.transform.pmap : Device-mapped variant aware of BrainState states.
1528
+ """
1529
+ __module__ = "brainstate.transform"
1530
+
1531
+ def __init__(
1532
+ self,
1533
+ fun: Callable,
1534
+ in_axes: Union[int, Tuple[int, ...], None] = 0,
1535
+ out_axes: Union[int, Tuple[int, ...], None] = 0,
1536
+ state_in_axes: Optional[Union[Dict[AxisName, Filter], Filter]] = None,
1537
+ state_out_axes: Optional[Union[Dict[AxisName, Filter], Filter]] = None,
1538
+ unexpected_out_state_mapping: str = 'raise',
1539
+ # JIT specific parameters
1540
+ static_argnums: Union[int, Iterable[int]] = (),
1541
+ static_argnames: Union[str, Iterable[str]] = (),
1542
+ axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
1543
+ abstracted_axes: Optional[Any] = None,
1544
+ return_only_write: bool = True,
1545
+ # mapping specific parameters
1546
+ axis_size: Optional[int] = None,
1547
+ axis_name: AxisName | None = None,
1548
+ name: Optional[str] = None,
1549
+ # mapping function
1550
+ mapping_fn: Callable = jax.vmap,
1551
+ ):
1552
+ super().__init__(
1553
+ fun=self.__wrapped_fun,
1554
+ static_argnums=static_argnums,
1555
+ static_argnames=static_argnames,
1556
+ axis_env=axis_env,
1557
+ abstracted_axes=abstracted_axes,
1558
+ return_only_write=return_only_write,
1559
+ name=name,
1560
+ )
1561
+
1562
+ self.name = name
1563
+ self.origin_fun = fun
1564
+ self.in_axes = in_axes
1565
+ self.out_axes = out_axes
1566
+ if state_in_axes is None:
1567
+ state_in_axes = dict()
1568
+ elif not isinstance(state_in_axes, dict):
1569
+ state_in_axes = {0: to_predicate(state_in_axes)}
1570
+ state_in_axes = {k: to_predicate(v) for k, v in state_in_axes.items()} # type: ignore
1571
+ self.state_in_axes = state_in_axes
1572
+
1573
+ if state_out_axes is None:
1574
+ state_out_axes = dict()
1575
+ elif not isinstance(state_out_axes, dict):
1576
+ state_out_axes = {0: to_predicate(state_out_axes)}
1577
+ state_out_axes = {k: to_predicate(v) for k, v in state_out_axes.items()} # type: ignore
1578
+ self.state_out_axes = state_out_axes
1579
+
1580
+ self.axis_size = axis_size
1581
+ self.axis_name = axis_name
1582
+ self.mapping_fn = mapping_fn
1583
+ self.unexpected_out_state_mapping = unexpected_out_state_mapping
1584
+
1585
+ # Cache for discovered state-to-axis mappings
1586
+ self._cached_map_dim_to_in_states = _BoundedCache(maxsize=128)
1587
+ self._cached_map_dim_to_out_states = _BoundedCache(maxsize=128)
1588
+ self._cached_map_state_trace = _BoundedCache(maxsize=128)
1589
+ self._cached_map_batch_size = _BoundedCache(maxsize=128)
1590
+
1591
+ def __infer_batch_size(self, args, in_axes):
1592
+ def get_batch_size_from_arg(arg_, axis_):
1593
+ if axis_ is None:
1594
+ return None
1595
+
1596
+ def _get_size(arr):
1597
+ if not hasattr(arr, 'shape'):
1598
+ return None
1599
+ if arr.ndim == 0:
1600
+ return None
1601
+ ax = axis_ if axis_ >= 0 else arr.ndim + axis_
1602
+ if ax < 0 or ax >= arr.ndim:
1603
+ raise IndexError(f"Axis {ax} is out of bounds for array of shape {arr.shape}")
1604
+ return arr.shape[ax]
1605
+
1606
+ # Get all sizes from the pytree
1607
+ sizes = [s for s in jax.tree.leaves(jax.tree.map(_get_size, arg_)) if s is not None]
1608
+ return sizes[0] if sizes else None
1609
+
1610
+ batch_sizes = []
1611
+ if isinstance(in_axes, int):
1612
+ # All args batched along the same axis
1613
+ for arg in args:
1614
+ size = get_batch_size_from_arg(arg, in_axes)
1615
+ if size is not None:
1616
+ batch_sizes.append(size)
1617
+ elif isinstance(in_axes, (tuple, list)):
1618
+ # Different axes for different args
1619
+ if len(in_axes) != len(args):
1620
+ raise ValueError(
1621
+ f"Length of in_axes ({len(in_axes)}) must match number of arguments ({len(args)})"
1622
+ )
1623
+ for arg, axis in zip(args, in_axes):
1624
+ size = get_batch_size_from_arg(arg, axis)
1625
+ if size is not None:
1626
+ batch_sizes.append(size)
1627
+ elif in_axes is None:
1628
+ pass
1629
+ else:
1630
+ raise TypeError(f"Unsupported in_axes type: {type(in_axes)}")
1631
+
1632
+ if not batch_sizes:
1633
+ if self.axis_size is None:
1634
+ raise ValueError("Cannot infer batch size when axis_size is None")
1635
+ batch_sizes.append(self.axis_size)
1636
+
1637
+ # Check all batch sizes are consistent
1638
+ if not all(s == batch_sizes[0] for s in batch_sizes):
1639
+ raise ValueError(
1640
+ f"Inconsistent batch sizes found: {batch_sizes}. "
1641
+ f"All batched arguments must have the same size along their batch axes."
1642
+ )
1643
+
1644
+ return batch_sizes[0]
1645
+
1646
+ def __new_batch_arg(self, trace, batch_size: int, dim_to_states: dict):
1647
+ def wrapper(x):
1648
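+ # RandomState instances receive one pre-split key per batch lane
+ # (``self._rand_value``), so each mapped call draws independent random
+ # numbers; other matching states become batch tracers further below.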
+ if isinstance(x, RandomState):
1649
+ idx = lambda: BatchTracer(trace, make_iota(batch_size), 0, source_info_util.current())
1650
+ dim_to_states['random'].append(x)
1651
+ return to_elt(trace, idx, self._rand_value, 0)
1652
+ for dim, filter_ in self.state_in_axes.items():
1653
+ idx = lambda: BatchTracer(trace, make_iota(batch_size), dim, source_info_util.current())
1654
+ if filter_(tuple(), x):
1655
+ dim_to_states[dim].append(x)
1656
+ return jax.tree.map(lambda xx: to_elt(trace, idx, xx, dim), x._value)
1657
+ return x._value
1658
+
1659
+ return wrapper
1660
+
1661
+ def __find_batch_dim(self, st):
1662
+ leaves = jax.tree.leaves(st._value)
1663
+ batch_dims = set([leaf.batch_dim if isinstance(leaf, BatchTracer) else None for leaf in leaves])
1664
+ if len(batch_dims) != 1:
1665
+ raise ValueError(
1666
+ f"State {st} has inconsistent batch dimensions in its leaves: {batch_dims}. "
1667
+ "All leaves must have the same batch dimension."
1668
+ )
1669
+ dim = batch_dims.pop()
1670
+ return dim
1671
+
1672
+ def __fn_to_eval(self, cache_key, *new_args, **new_kwargs):
1673
+ # state trace
1674
+ trace = jax.core.trace_ctx.trace
1675
+ assert isinstance(trace, BatchTrace), f"Expected to be called within a BatchTrace context, but got {trace}"
1676
+ dim_to_in_states = defaultdict(list)
1677
+ state_trace = StateTraceStack(name=self.name)
1678
+ state_trace.set_new_arg(
1679
+ self.__new_batch_arg(trace, self._cached_map_batch_size.get(cache_key), dim_to_in_states)
1680
+ )
1681
+ self._cached_map_state_trace.set(cache_key, state_trace)
1682
+
1683
+ # call functions
1684
+ with state_trace:
1685
+ out_ = self.origin_fun(*new_args, **new_kwargs)
1686
+
1687
+ # cache vmapped in states
1688
+ self._cached_map_dim_to_in_states.set(cache_key, dim_to_in_states.copy())
1689
+ mapped_in_states = set([id(v) for vv in dim_to_in_states.values() for v in vv])
1690
+
1691
+ # vmapped out states
1692
+ out_states = defaultdict(list)
1693
+ out_states['random'] = [st for st in state_trace.states if isinstance(st, RandomState)]
1694
+ for st in state_trace.states:
1695
+ if isinstance(st, RandomState):
1696
+ continue
1697
+ find = False
1698
+ for dim, filter_ in self.state_out_axes.items():
1699
+ if filter_(tuple(), st):
1700
+ out_states[dim].append(st)
1701
+ find = True
1702
+ break
1703
+ if find:
1704
+ continue
1705
+ dim = self.__find_batch_dim(st)
1706
+ if dim is None or id(st) in mapped_in_states:
1707
+ out_states[dim].append(st)
1708
+ else:
1709
+ if self.unexpected_out_state_mapping == 'raise':
1710
+ st.raise_error_with_source_info(
1711
+ BatchAxisError(
1712
+ f'State\n {st} \n was not expected to be batched on output. '
1713
+ 'Please adjust state_out_axes or set unexpected_out_state_mapping to "warn" or "ignore".'
1714
+ )
1715
+ )
1716
+ elif self.unexpected_out_state_mapping == 'warn':
1717
+ warnings.warn(
1718
+ f'State\n {st} \n was not expected to be batched on output. '
1719
+ f'Please adjust state_out_axes or set unexpected_out_state_mapping to "ignore".',
1720
+ UserWarning,
1721
+ )
1722
+ out_states[dim].append(st)
1723
+ elif self.unexpected_out_state_mapping == 'ignore':
1724
+ out_states[dim].append(st)
1725
+ else:
1726
+ raise ValueError(
1727
+ 'Invalid value for unexpected_out_state_mapping: '
1728
+ f'{self.unexpected_out_state_mapping}. Must be "raise", "warn", or "ignore".'
1729
+ )
1730
+ self._cached_map_dim_to_out_states.set(cache_key, out_states)
1731
+
1732
+ def __eval(self, cache_key, *args, **kwargs):
1733
+ try:
1734
+ jax.vmap(
1735
+ functools.partial(self.__fn_to_eval, cache_key),
1736
+ in_axes=self.in_axes,
1737
+ out_axes=self.out_axes,
1738
+ axis_name=self.axis_name,
1739
+ axis_size=self.axis_size
1740
+ )(*args, **kwargs)
1741
+ self._cached_map_state_trace.get(cache_key).recovery_original_values()
1742
+ except Exception as e:
1743
+ if cache_key in self._cached_map_state_trace:
1744
+ self._cached_map_state_trace.get(cache_key).recovery_original_values()
1745
+ self._cached_map_state_trace.pop(cache_key, None)
1746
+ self._cached_map_dim_to_in_states.pop(cache_key, None)
1747
+ self._cached_map_dim_to_out_states.pop(cache_key, None)
1748
+ self._cached_map_batch_size.pop(cache_key, None)
1749
+ raise e
1750
+
1751
+ def __assign_vals_from_in_states(self, cache_key, rand_st, *other_st):
1752
+ in_states = self._cached_map_dim_to_in_states.get(cache_key)
1753
+ for st, val in zip(in_states['random'], rand_st):
1754
+ assert isinstance(st, RandomState)
1755
+ st.restore_value(val)
1756
+ for group, group_vals in zip([in_states[dim] for dim in in_states.keys() if dim != 'random'], other_st):
1757
+ for st, val in zip(group, group_vals):
1758
+ st.restore_value(val)
1759
+
1760
+ def __assign_vals_from_out_states(self, cache_key, rand_st, *other_st):
1761
+ out_states = self._cached_map_dim_to_out_states.get(cache_key)
1762
+ for st, val in zip(out_states['random'], rand_st):
1763
+ assert isinstance(st, RandomState)
1764
+ st.restore_value(val)
1765
+ for group, group_vals in zip([out_states[dim] for dim in out_states.keys() if dim != 'random'], other_st):
1766
+ for st, val in zip(group, group_vals):
1767
+ st.restore_value(val)
1768
+
1769
+ def __get_in_state_vals(self, cache_key: Hashable):
1770
+ in_states = self._cached_map_dim_to_in_states.get(cache_key)
1771
+ in_axes = []
1772
+ in_values = []
1773
+ for dim, states in in_states.items():
1774
+ if dim == 'random':
1775
+ continue
1776
+ in_axes.append(dim)
1777
+ in_values.append([st.value for st in states])
1778
+ return tuple(in_axes), in_values
1779
+
1780
+ def __get_out_state_vals(self, cache_key: Hashable):
1781
+ out_states = self._cached_map_dim_to_out_states.get(cache_key)
1782
+ out_axes = []
1783
+ out_values = []
1784
+ for dim, state in out_states.items():
1785
+ if dim == 'random':
1786
+ continue
1787
+ out_axes.append(dim)
1788
+ out_values.append([st.value for st in state])
1789
+ return tuple(out_axes), out_values
1790
+
1791
+ def __get_rand_state_vals(self, cache_key: Hashable):
1792
+ in_states = self._cached_map_dim_to_in_states.get(cache_key)
1793
+ batch_size = self._cached_map_batch_size.get(cache_key)
1794
+ rand_vals, rand_recover_vals = [], []
1795
+ for st in in_states['random']:
1796
+ assert isinstance(st, RandomState)
1797
+ rand_vals.append(st.split_key(batch_size))
1798
+ rand_recover_vals.append(st.value)
1799
+ return tuple(rand_vals), tuple(rand_recover_vals)
1800
+
1801
+ def __recover_rand_state_vals(self, cache_key: Hashable, rand_recover_vals):
1802
+ state_trace = self._cached_map_state_trace.get(cache_key)
1803
+ rand_states = [st for st in state_trace.states if isinstance(st, RandomState)]
1804
+ for st, val in zip(rand_states, rand_recover_vals):
1805
+ st.restore_value(val)
1806
+
1807
+ def __wrapped_fun(self, *args, **kwargs) -> Tuple[Any, Tuple[State, ...]]:
1808
+ if len(kwargs):
1809
+ raise NotImplementedError(
1810
+ 'StatefulMapping currently does not support keyword arguments.'
1811
+ )
1812
+
1813
+ batch_size = self.__infer_batch_size(args, self.in_axes)
1814
+ cache_key = self.get_arg_cache_key(*args, **kwargs)
1815
+ if cache_key not in self._cached_map_state_trace:
1816
+ self._rand_value = RandomState._batch_keys(batch_size)
1817
+ self._cached_map_batch_size.set(cache_key, batch_size)
1818
+ self.__eval(cache_key, *args, **kwargs)
1819
+
1820
+ def fn_to_map(origin_args, rand_st, *non_rand_st):
1821
+ self.__assign_vals_from_in_states(cache_key, rand_st, *non_rand_st)
1822
+ out = self.origin_fun(*origin_args)
1823
+ return out, *self.__get_out_state_vals(cache_key)[1]
1824
+
1825
+ in_axes, in_state_vals = self.__get_in_state_vals(cache_key)
1826
+ out_axes, out_state_vals = self.__get_out_state_vals(cache_key)
1827
+ rand_vals, rand_recover_vals = self.__get_rand_state_vals(cache_key)
1828
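+ # Positional arguments are mapped according to ``self.in_axes``; the
+ # pre-split random keys are mapped along axis 0 (or ``None`` when there
+ # are no random states); the remaining state values are mapped along
+ # their discovered axes.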
+ mapped_fn = self.mapping_fn(
1829
+ fn_to_map,
1830
+ in_axes=(self.in_axes, 0 if len(rand_vals) else None) + in_axes,
1831
+ out_axes=(self.out_axes,) + out_axes,
1832
+ axis_size=self.axis_size,
1833
+ axis_name=self.axis_name,
1834
+ )
1835
+ out_, *out_state_vals = mapped_fn(args, rand_vals, *in_state_vals)
1836
+ self.__assign_vals_from_out_states(cache_key, rand_recover_vals, *out_state_vals)
1837
+ return out_
1838
+
1839
+
1840
+ def _check_callable(fun):
1841
+ # In Python 3.10+, the only thing stopping us from supporting static methods
1842
+ # is that we can't take weak references to them, which the C++ JIT requires.
1843
+ if isinstance(fun, staticmethod):
1844
+ raise TypeError(f"staticmethod arguments are not supported, got {fun}")
1845
+ if not callable(fun):
1846
+ raise TypeError(f"Expected a callable value, got {fun}")
1847
+ if inspect.isgeneratorfunction(fun):
1848
+ raise TypeError(f"Expected a function, got a generator function: {fun}")
1849
+
1850
+
1851
+ def _broadcast_prefix(
1852
+ prefix_tree: Any,
1853
+ full_tree: Any,
1854
+ is_leaf: Callable[[Any], bool] | None = None
1855
+ ) -> list[Any]:
1856
+ # If prefix_tree is not a tree prefix of full_tree, this code can raise a
1857
+ # ValueError; use prefix_errors to find disagreements and raise more precise
1858
+ # error messages.
1859
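+ # For example (illustrative): _broadcast_prefix((0, 1), (('a', 'b'), 'c'))
+ # gives [0, 0, 1] -- each prefix leaf is repeated once per leaf of the
+ # corresponding subtree of full_tree.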
+ result = []
1860
+ num_leaves = lambda t: jax.tree.structure(t).num_leaves
1861
+ add_leaves = lambda x, subtree: result.extend([x] * num_leaves(subtree))
1862
+ jax.tree.map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf)
1863
+ return result
1864
+
1865
+
1866
+ def _flat_axes_specs(
1867
+ abstracted_axes, *args, **kwargs
1868
+ ) -> list[pe.AbstractedAxesSpec]:
1869
+ if kwargs:
1870
+ raise NotImplementedError
1871
+
1872
+ def ax_leaf(l):
1873
+ return ((isinstance(l, dict) and jax.tree_util.all_leaves(l.values())) or
1874
+ (isinstance(l, tuple) and jax.tree_util.all_leaves(l, lambda x: x is None)))
1875
+
1876
+ return _broadcast_prefix(abstracted_axes, args, ax_leaf)
1877
+
1878
+
1879
+ @transformation_with_aux
1880
+ def _flatten_fun(in_tree, *args_flat):
1881
+ py_args, py_kwargs = jax.tree.unflatten(in_tree, args_flat)
1882
+ ans = yield py_args, py_kwargs
1883
+ yield jax.tree.flatten(ans)
1884
+
1885
+
1886
+ def _make_jaxpr(
1887
+ fun: Callable,
1888
+ static_argnums: int | Iterable[int] = (),
1889
+ axis_env: Sequence[tuple[AxisName, int]] | None = None,
1890
+ return_shape: bool = False,
1891
+ abstracted_axes: Any | None = None,
1892
+ ) -> Callable[..., (ClosedJaxpr | tuple[ClosedJaxpr, Any])]:
1893
+ """
1894
+ Create a function that produces its jaxpr given example args (internal implementation).
1895
+
1896
+ This is an internal implementation function. Users should use the public
1897
+ ``make_jaxpr`` function instead.
1898
+
1899
+ Parameters
1900
+ ----------
1901
+ fun : Callable
1902
+ The function whose ``jaxpr`` is to be computed. Its positional
1903
+ arguments and return value should be arrays, scalars, or standard Python
1904
+ containers (tuple/list/dict) thereof.
1905
+ static_argnums : int or iterable of int, optional
1906
+ See the :py:func:`jax.jit` docstring.
1907
+ axis_env : sequence of tuple, optional
1908
+ A sequence of pairs where the first element is an axis
1909
+ name and the second element is a positive integer representing the size of
1910
+ the mapped axis with that name. This parameter is useful when lowering
1911
+ functions that involve parallel communication collectives, and it
1912
+ specifies the axis name/size environment that would be set up by
1913
+ applications of :py:func:`jax.pmap`.
1914
+ return_shape : bool, default False
1915
+ If ``True``, the wrapped function returns a pair where the first element
1916
+ is the ``ClosedJaxpr`` representation of ``fun`` and the second element
1917
+ is a pytree with the same structure as the output of ``fun`` and where
1918
+ the leaves are objects with ``shape``, ``dtype``, and ``named_shape``
1919
+ attributes representing the corresponding types of the output leaves.
1920
+ abstracted_axes : Any, optional
1921
+ Axes specifications for abstract interpretation.
1922
+
1923
+ Returns
1924
+ -------
1925
+ Callable
1926
+ A wrapped version of ``fun`` that when applied to example arguments returns
1927
+ a ``ClosedJaxpr`` representation of ``fun`` on those arguments. If the
1928
+ argument ``return_shape`` is ``True``, then the returned function instead
1929
+ returns a pair where the first element is the ``ClosedJaxpr``
1930
+ representation of ``fun`` and the second element is a pytree representing
1931
+ the structure, shape, dtypes, and named shapes of the output of ``fun``.
1932
+
1933
+ Notes
1934
+ -----
1935
+ A ``jaxpr`` is JAX's intermediate representation for program traces. The
1936
+ ``jaxpr`` language is based on the simply-typed first-order lambda calculus
1937
+ with let-bindings. This function adapts a function to return its
1938
+ ``jaxpr``, which we can inspect to understand what JAX is doing internally.
1939
+ The ``jaxpr`` returned is a trace of ``fun`` abstracted to
1940
+ :py:class:`ShapedArray` level. Other levels of abstraction exist internally.
1941
+
1942
+ Examples
1943
+ --------
1944
+ .. code-block:: python
1945
+
1946
+ >>> import jax
1947
+ >>>
1948
+ >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
1949
+ >>> print(f(3.0))
1950
+ -0.83602
1951
+ >>> _make_jaxpr(f)(3.0)
1952
+ { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
1953
+ >>> _make_jaxpr(jax.grad(f))(3.0)
1954
+ { lambda ; a:f32[]. let
1955
+ b:f32[] = cos a
1956
+ c:f32[] = sin a
1957
+ _:f32[] = sin b
1958
+ d:f32[] = cos b
1959
+ e:f32[] = mul 1.0 d
1960
+ f:f32[] = neg e
1961
+ g:f32[] = mul f c
1962
+ in (g,) }
1963
+ """
1964
+ _check_callable(fun)
1965
+ static_argnums = _ensure_index_tuple(static_argnums)
1966
+
1967
+ def _abstractify(args, kwargs):
1968
+ flat_args, in_tree = jax.tree.flatten((args, kwargs))
1969
+ if abstracted_axes is None:
1970
+ return map(shaped_abstractify, flat_args), in_tree, [True] * len(flat_args)
1971
+ else:
1972
+ axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
1973
+ in_type = pe.infer_lambda_input_type(axes_specs, flat_args)
1974
+ in_avals, keep_inputs = unzip2(in_type)
1975
+ return in_avals, in_tree, keep_inputs
1976
+
1977
+ @wraps(fun)
1978
+ @api_boundary
1979
+ def make_jaxpr_f(*args, **kwargs):
1980
+ f = wrap_init(fun, (), {}, 'brainstate.transform.make_jaxpr')
1981
+ if static_argnums:
1982
+ dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
1983
+ f, args = jax.api_util.argnums_partial(f, dyn_argnums, args)
1984
+ in_avals, in_tree, keep_inputs = _abstractify(args, kwargs)
1985
+ in_type = tuple(safe_zip(in_avals, keep_inputs))
1986
+ f, out_tree = _flatten_fun(f, in_tree)
1987
+ f = annotate(f, in_type)
1988
+ if jax.__version_info__ < (0, 5, 0):
1989
+ debug_info_ = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
1990
+ with ExitStack() as stack:
1991
+ if axis_env is not None:
1992
+ stack.enter_context(extend_axis_env_nd(axis_env))
1993
+ if jax.__version_info__ < (0, 5, 0):
1994
+ jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, debug_info=debug_info_)
1995
+ else:
1996
+ jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f)
1997
+ closed_jaxpr = ClosedJaxpr(jaxpr, consts)
1998
+ if return_shape:
1999
+ out_avals, _ = unzip2(out_type)
2000
+ out_shapes_flat = [jax.ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
2001
+ return closed_jaxpr, jax.tree.unflatten(out_tree(), out_shapes_flat)
2002
+ return closed_jaxpr
2003
+
2004
+ make_jaxpr_f.__module__ = "brainstate.transform"
2005
+ if hasattr(fun, "__qualname__"):
2006
+ make_jaxpr_f.__qualname__ = f"make_jaxpr({fun.__qualname__})"
2007
+ if hasattr(fun, "__name__"):
2008
+ make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
2009
+ return make_jaxpr_f
2010
+
2011
+
2012
+ def make_hashable(obj):
2013
+ """
2014
+ Convert a pytree into a hashable representation.
2015
+
2016
+ Parameters
2017
+ ----------
2018
+ obj : Any
2019
+ A pytree object (list, tuple, dict, set, or JAX pytree structure).
2020
+
2021
+ Returns
2022
+ -------
2023
+ Hashable
2024
+ A hashable representation of the input object. Lists become tuples,
2025
+ dicts become sorted tuples of key-value pairs, sets become frozensets,
2026
+ and other pytrees are flattened using JAX's tree utilities.
2027
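+
+ Examples
+ --------
+ A small sketch showing that dict insertion order does not affect the key:
+
+ .. code-block:: python
+
+ >>> key1 = make_hashable({'a': [1, 2], 'b': 3})
+ >>> key2 = make_hashable({'b': 3, 'a': [1, 2]})
+ >>> key1 == key2
+ True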
+ """
2028
+ if isinstance(obj, (list, tuple)):
2029
+ return tuple(make_hashable(item) for item in obj)
2030
+ elif isinstance(obj, dict):
2031
+ return tuple(sorted((k, make_hashable(v)) for k, v in obj.items()))
2032
+ elif isinstance(obj, set):
2033
+ return frozenset(make_hashable(item) for item in obj)
2034
+ else:
2035
+ # Use JAX's tree_util for any other pytree structures.
2037
+ try:
2038
+ leaves, treedef = jax.tree.flatten(obj)
2039
+ return treedef, tuple(leaves)
2040
+ except (TypeError, ValueError):
2041
+ # Assume obj is already hashable
2042
+ return obj
2043
+
2044
+
2045
+ class IdentitySet(MutableSet):
2046
+ """Set that compares objects by identity.
2047
+
2048
+ This is a set that compares objects by identity instead of equality. It is
2049
+ useful for storing objects that are not hashable or that should be compared
2050
+ by identity.
2051
+
2052
+ This is a mutable set, but it does not support the ``__hash__`` method and
2053
+ therefore cannot be used as a dictionary key or as an element of another set.
2054
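+
+ Examples
+ --------
+ .. code-block:: python
+
+ >>> a, b = [1], [1] # equal but distinct objects
+ >>> s = IdentitySet([a, b])
+ >>> len(s)
+ 2
+ >>> a in s
+ True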
+ """
2055
+
2056
+ def __init__(self, iterable=None):
2057
+ self._data = {}
2058
+ if iterable is not None:
2059
+ # ``MutableSet`` does not provide ``update``, so add items one by one.
+ for value in iterable:
+ self.add(value)
2060
+
2061
+ def __contains__(self, value):
2062
+ return id(value) in self._data
2063
+
2064
+ def __iter__(self):
2065
+ return iter(self._data.values())
2066
+
2067
+ def __len__(self):
2068
+ return len(self._data)
2069
+
2070
+ def add(self, value):
2071
+ self._data[id(value)] = value
2072
+
2073
+ def discard(self, value):
2074
+ self._data.pop(id(value), None)
2075
+
2076
+ def __repr__(self):
2077
+ return f"IdentitySet({list(repr(x) for x in self._data.values())})"
2078
+
2079
+ def __str__(self):
2080
+ return f"IdentitySet({list(str(x) for x in self._data.values())})"
2081
+
2082
+
2083
+ def constant_fold_jaxpr(jaxpr: Jaxpr):
2084
+ """
2085
+ Given a jaxpr, return a new jaxpr with all constant folding done.
2086
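+
+ A minimal sketch: fresh jaxprs rarely contain constant-only equations,
+ so the helper ``_partial_eval_jaxpr`` below is shown pinning one input
+ to a constant instead (names are illustrative):
+
+ .. code-block:: python
+
+ >>> import jax
+ >>> closed = jax.make_jaxpr(lambda x, y: (x * 2.0, y + 1.0))(1.0, 2.0)
+ >>> jaxpr = closed.jaxpr
+ >>> # Pin ``x``; the ``mul`` equation is folded into a literal output.
+ >>> folded = _partial_eval_jaxpr(jaxpr, {jaxpr.invars[0]: 1.0})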
+ """
2087
+ return _partial_eval_jaxpr(jaxpr, {})
2088
+
2089
+
2090
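+ # Broadcasts are deliberately never folded: evaluating them would
+ # materialize potentially large constant arrays inside the jaxpr.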
+ _constant_fold_blacklist = {'broadcast_in_dim', 'broadcast'}
2091
+
2092
+
2093
+ def _partial_eval_jaxpr(jaxpr, env):
2094
+ env = env.copy()
2095
+ new_eqns = []
2096
+
2097
+ def read(var):
2098
+ if isinstance(var, Literal):
2099
+ return var.val
2100
+ else:
2101
+ return env.get(var, None)
2102
+
2103
+ def read_or_self(var):
2104
+ out = read(var)
2105
+ if out is None:
2106
+ return var
2107
+ elif isinstance(out, Var):
2108
+ return out
2109
+ elif isinstance(out, Literal):
2110
+ return Literal(out.val, var.aval)
2111
+ else:
2112
+ assert not isinstance(out, Jaxpr)
2113
+ return Literal(out, var.aval)
2114
+
2115
+ for eqn in jaxpr.eqns:
2116
+ vals = [read(var) for var in eqn.invars]
2117
+ if eqn.primitive.name in _constant_fold_blacklist:
2118
+ new_eqns.append(eqn)
2119
+ elif all(val is not None for val in vals):
2120
+ # go ahead and eval it
2121
+ out = _eval_eqn(eqn, vals)
2122
+
2123
+ # two options: either it's a jaxpr result (partial eval) or it's a value or a list of values
2124
+ if isinstance(out, Jaxpr):
2125
+ # we need to inline this
2126
+ new_eqns.extend(out.eqns)
2127
+ out = out.outvars
2128
+ elif not isinstance(out, tuple) and not isinstance(out, list):
2129
+ out = (out,)
2130
+
2131
+ for var, val in zip(eqn.outvars, out):
2132
+ assert not isinstance(val, Jaxpr)
2133
+ if isinstance(val, Literal):
2134
+ env[var] = val.val
2135
+ else:
2136
+ env[var] = val
2137
+ else:
2138
+ new_eqns.append(eqn)
2139
+
2140
+ # now that we've evaluated everything, inline all the constants
2141
+ out_eqns = []
2142
+ for eqn in new_eqns:
2143
+ eqn = eqn.replace(invars=tuple(read_or_self(var) for var in eqn.invars))
2144
+ out_eqns.append(eqn)
2145
+
2146
+ invars_still_used = IdentitySet()
2147
+ for eqn in out_eqns:
2148
+ for var in eqn.invars:
2149
+ invars_still_used.add(var)
2150
+
2151
+ invars = tuple(var for var in jaxpr.invars if var in invars_still_used)
2152
+
2153
+ # sub in any constants for outvars
2154
+ outvars = tuple(read_or_self(var) for var in jaxpr.outvars)
2155
+
2156
+ return jaxpr.replace(eqns=out_eqns, outvars=outvars, invars=invars, debug_info=None)
2157
+
2158
+
2159
+ def _eval_eqn(eqn, vals) -> Union[Jaxpr, tuple, list, jax.Array]:
2160
+ if eqn.primitive.name == "closed_call":
2161
+ assert eqn.primitive.call_primitive
2162
+ assert not eqn.primitive.map_primitive
2163
+
2164
+ out = _partial_eval_jaxpr(
2165
+ eqn.params['call_jaxpr'].jaxpr,
2166
+ {
2167
+ var: val
2168
+ for var, val in
2169
+ zip(eqn.params['call_jaxpr'].jaxpr.invars, vals)
2170
+ }
2171
+ )
2172
+ else:
2173
+ # All other primitives (including ``scan``) are evaluated by binding directly.
2174
+ out = eqn.primitive.bind(*vals, **eqn.params)
2176
+ return out