brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/transform/_make_jaxpr.py
@@ -0,0 +1,2016 @@
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """
+ This module implements how to create a JAX Jaxpr from a given function by considering the states that are read and
+ written by the function. These state transformations are foundational for the brainstate library. These utilities
+ include two basic interfaces: ``StatefulFunction`` and ``make_jaxpr``.
+
+
+ ``StatefulFunction``
+ --------------------
+
+ The module provides a class called ``StatefulFunction`` that wraps a function and provides methods to get the
+ JAX Jaxpr, the output shapes, the states that are read and written by the function, and the output of the function.
+ The class provides the following methods:
+
+ - `make_jaxpr`: creates the JAX Jaxpr of the function.
+ - `jaxpr_call`: calls the function at the JAX Jaxpr level.
+ - `jaxpr_call_without_states`: calls the function at the JAX Jaxpr level without considering the states.
+ - `get_states`: returns the states that are read and written by the function.
+ - `get_read_states`: returns the states that are read by the function.
+ - `get_write_states`: returns the states that are written by the function.
+ - `get_static_args`: returns the static arguments from the arguments.
+ - `compile_and_get_states_by_static_args`: compiles the function and returns the states that are read and
+   written by the function.
+ - `get_jaxpr`: returns the JAX Jaxpr of the function.
+ - `get_out_shapes`: returns the output shapes of the function.
+ - `get_out_treedef`: returns the output tree of the function.
+
+ ``make_jaxpr``
+ --------------
+
+ The module provides a function called `make_jaxpr` that creates a function that produces its JAX Jaxpr given example
+ arguments. The returned wrapper, when applied to example arguments, returns a pair of the `ClosedJaxpr`
+ representation of the function on those arguments and the `State` objects discovered during tracing. If the
+ argument `return_shape` is `True`, the wrapper additionally returns a pytree representing the structure, shape,
+ dtypes, and named shapes of the output of the function.
+
+ """
+
+ import functools
+ import inspect
+ import operator
+ import threading
+ from collections import OrderedDict, defaultdict
+ from collections.abc import Hashable, Iterable, Sequence
+ from collections.abc import MutableSet
+ from contextlib import ExitStack
+ from typing import Any, Callable, Dict, Optional, Tuple, Union
+
+ import jax
+ import jax.numpy as jnp
+ from jax._src import source_info_util
+ from jax._src.linear_util import annotate
+ from jax._src.traceback_util import api_boundary
+ from jax._src.util import memoize
+ from jax.api_util import shaped_abstractify
+ from jax.extend.linear_util import transformation_with_aux
+ from jax.interpreters import partial_eval as pe
+
+ from brainstate._compatible_import import (
+     ClosedJaxpr, extend_axis_env_nd, safe_map, safe_zip, unzip2, wraps, wrap_init,
+     Literal, Var, Jaxpr, make_iota, to_elt, BatchTracer, BatchTrace,
+ )
+ from brainstate._state import State, StateTraceStack
+ from brainstate._utils import set_module_as
+ from brainstate.random import RandomState
+ from brainstate.typing import Filter, PyTree
+ from brainstate.util import PrettyObject
+ from brainstate.util.filter import to_predicate
+
+ AxisName = Hashable
+
+ __all__ = [
+     "StatefulFunction",
+     "make_jaxpr",
+     "StatefulMapping",
+ ]
+
+
+ class hashabledict(dict):
+     def __hash__(self):
+         return hash(tuple(sorted(self.items())))
+
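+ # A minimal usage sketch for ``hashabledict`` (illustrative; this helper is
+ # private to the module): equal dicts hash equally, so they can serve as
+ # dictionary/cache keys as long as their values are themselves hashable.
+ #
+ #     key = hashabledict(static_args=(2,), dyn_args=())
+ #     cache = {key: 'compiled'}
+ #     assert hashabledict(static_args=(2,), dyn_args=()) in cache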
+
+ class _BoundedCache:
+     """
+     A thread-safe LRU cache with bounded size.
+
+     This cache stores a limited number of items and evicts the least recently used item
+     when the cache reaches its maximum size. All operations are thread-safe.
+
+     Parameters
+     ----------
+     maxsize : int, default 128
+         Maximum number of items to store in the cache.
+     """
+
+     def __init__(self, maxsize: int = 128):
+         self._cache = OrderedDict()
+         self._maxsize = maxsize
+         self._lock = threading.RLock()
+         self._hits = 0
+         self._misses = 0
+
+     def get(
+         self,
+         key: Any,
+         default: Any = None,
+         raise_on_miss: bool = False,
+         error_context: str = "item"
+     ) -> Any:
+         """
+         Get an item from the cache.
+
+         Parameters
+         ----------
+         key : Any
+             The cache key.
+         default : Any, optional
+             The default value to return if the key is not found.
+         raise_on_miss : bool, optional
+             If True, raise a detailed ValueError when the key is not found.
+         error_context : str, optional
+             Context description for the error message (e.g., "Function", "JAX expression").
+
+         Returns
+         -------
+         Any
+             The cached value or the default value.
+
+         Raises
+         ------
+         ValueError
+             If raise_on_miss is True and the key is not found.
+         """
+         with self._lock:
+             if key in self._cache:
+                 self._cache.move_to_end(key)
+                 self._hits += 1
+                 return self._cache[key]
+             self._misses += 1
+
+             if raise_on_miss:
+                 available_keys = list(self._cache.keys())
+                 error_msg = [
+                     f"{error_context} not compiled for the requested cache key.",
+                     "",
+                     "Requested key:",
+                     f"  {key}",
+                     "",
+                     f"Available {len(available_keys)} keys:",
+                 ]
+                 if available_keys:
+                     for i, k in enumerate(available_keys, 1):
+                         error_msg.append(f"  [{i}] {k}")
+                 else:
+                     error_msg.append("  (none - not compiled yet)")
+                 error_msg.append("")
+                 error_msg.append("Call make_jaxpr() first with matching arguments.")
+                 raise ValueError("\n".join(error_msg))
+
+             return default
+
+     def set(self, key: Any, value: Any) -> None:
+         """
+         Set an item in the cache.
+
+         Parameters
+         ----------
+         key : Any
+             The cache key.
+         value : Any
+             The value to cache.
+
+         Raises
+         ------
+         ValueError
+             If the key already exists in the cache.
+         """
+         with self._lock:
+             if key in self._cache:
+                 raise ValueError(
+                     f"Cache key already exists: {key}. "
+                     f"Cannot overwrite existing cached value. "
+                     f"Clear the cache first if you need to recompile."
+                 )
+             if len(self._cache) >= self._maxsize:
+                 self._cache.popitem(last=False)
+             self._cache[key] = value
+
+     def pop(self, key: Any, default: Any = None) -> Any:
+         """
+         Remove and return an item from the cache.
+
+         Parameters
+         ----------
+         key : Any
+             The cache key to remove.
+         default : Any, optional
+             The default value to return if the key is not found.
+
+         Returns
+         -------
+         Any
+             The cached value or the default value if the key is not found.
+         """
+         with self._lock:
+             if key in self._cache:
+                 return self._cache.pop(key)
+             return default
+
+     def replace(self, key: Any, value: Any) -> None:
+         """
+         Replace an existing item in the cache.
+
+         Parameters
+         ----------
+         key : Any
+             The cache key to replace.
+         value : Any
+             The new value to cache.
+
+         Raises
+         ------
+         KeyError
+             If the key does not exist in the cache.
+         """
+         with self._lock:
+             if key not in self._cache:
+                 raise KeyError(
+                     f"Cache key does not exist: {key}. "
+                     f"Cannot replace non-existent cached value. "
+                     f"Use set() to add a new cache entry."
+                 )
+             self._cache[key] = value
+             self._cache.move_to_end(key)
+
+     def __contains__(self, key: Any) -> bool:
+         """
+         Check if a key exists in the cache.
+
+         Parameters
+         ----------
+         key : Any
+             The cache key to check.
+
+         Returns
+         -------
+         bool
+             True if the key exists in the cache, False otherwise.
+         """
+         with self._lock:
+             return key in self._cache
+
+     def __len__(self) -> int:
+         """
+         Get the number of items in the cache.
+
+         Returns
+         -------
+         int
+             The number of items currently in the cache.
+         """
+         with self._lock:
+             return len(self._cache)
+
+     def clear(self) -> None:
+         """
+         Clear all items from the cache and reset statistics.
+
+         This method removes all cached items and resets hit/miss counters to zero.
+         """
+         with self._lock:
+             self._cache.clear()
+             self._hits = 0
+             self._misses = 0
+
+     def keys(self):
+         """
+         Return all keys in the cache.
+
+         Returns
+         -------
+         list
+             A list of all keys currently in the cache.
+         """
+         with self._lock:
+             return list(self._cache.keys())
+
+     def get_stats(self) -> Dict[str, Any]:
+         """
+         Get cache statistics.
+
+         Returns
+         -------
+         dict
+             A dictionary with cache statistics including:
+
+             - 'size': Current number of items in cache
+             - 'maxsize': Maximum cache size
+             - 'hits': Number of cache hits
+             - 'misses': Number of cache misses
+             - 'hit_rate': Hit rate percentage (0-100)
+         """
+         with self._lock:
+             total = self._hits + self._misses
+             hit_rate = (self._hits / total * 100) if total > 0 else 0.0
+             return {
+                 'size': len(self._cache),
+                 'maxsize': self._maxsize,
+                 'hits': self._hits,
+                 'misses': self._misses,
+                 'hit_rate': hit_rate,
+             }
+
+
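+ # A minimal usage sketch of ``_BoundedCache`` (illustrative values):
+ #
+ #     cache = _BoundedCache(maxsize=2)
+ #     cache.set('a', 1)
+ #     cache.set('b', 2)
+ #     cache.get('a')     # hit -> 1; 'a' becomes most recently used
+ #     cache.set('c', 3)  # evicts 'b', the least recently used entry
+ #     cache.get_stats()  # {'size': 2, 'maxsize': 2, 'hits': 1, 'misses': 0, 'hit_rate': 100.0}
+
+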
+ def _ensure_str(x: str) -> str:
+     if not isinstance(x, str):
+         raise TypeError(f"argument is not a string: {x}")
+     return x
+
+
+ def _ensure_index_tuple(x: Any) -> tuple[int, ...]:
+     """Convert x to a tuple of indices."""
+     x = jax.core.concrete_or_error(None, x, "expected a static index or sequence of indices.")
+     try:
+         return (operator.index(x),)
+     except TypeError:
+         return tuple(safe_map(operator.index, x))
+
+
+ def _ensure_str_tuple(x: str | Iterable[str]) -> tuple[str, ...]:
+     """Convert x to a tuple of strings."""
+     if isinstance(x, str):
+         return (x,)
+     else:
+         return tuple(safe_map(_ensure_str, x))
+
+
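+ # Behavior sketch for the coercion helpers above (illustrative):
+ #
+ #     _ensure_index_tuple(1)          # -> (1,)
+ #     _ensure_index_tuple((0, 2))     # -> (0, 2)
+ #     _ensure_str_tuple('x')          # -> ('x',)
+ #     _ensure_str_tuple(['x', 'y'])   # -> ('x', 'y')
+
+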
+ def _jax_v04_new_arg_fn(frame, trace, aval):
+     """
+     Transform a new argument to a tracer.
+
+     Modified from jax.interpreters.partial_eval.DynamicJaxprTrace.new_arg()
+
+     Args:
+         frame: The frame.
+         trace: The trace.
+         aval: The abstract value.
+
+     Returns:
+         The tracer.
+     """
+     tracer = pe.DynamicJaxprTracer(trace, aval, source_info_util.current())
+     frame.tracers.append(tracer)
+     frame.tracer_to_var[id(tracer)] = var = frame.newvar(aval)
+     frame.invars.append(var)
+     return tracer
+
+
+ def _jax_v04_new_jax_trace():
+     main = jax.core.thread_local_state.trace_state.trace_stack.stack[-1]
+     frame = main.jaxpr_stack[-1]
+     trace = pe.DynamicJaxprTrace(main, jax.core.cur_sublevel())
+     return frame, trace
+
+
+ class StatefulFunction(PrettyObject):
+     """
+     A wrapper class for functions that tracks state reads and writes during execution.
+
+     This class wraps a function to enable state management in JAX programs by tracking
+     which states are read from and written to during function execution. It provides
+     methods to compile the function into JAX's intermediate representation (jaxpr),
+     inspect state usage, and execute the function with proper state handling.
+
+     When you define a function:
+
+     .. code-block:: python
+
+         >>> state = brainstate.State(1.)
+         >>> def f(x):
+         ...     # Your function logic here
+         ...     y = x * 2 + state.value
+         ...     state.value = y
+
+     Calling ``sf = StatefulFunction(f)`` creates a stateful version of ``f``. You can
+     then call it directly, and the call remains compatible with JIT:
+
+     .. code-block:: python
+
+         >>> sf = brainstate.transform.StatefulFunction(f)
+         >>> out = sf(x)  # Automatically compiles and executes
+
+     Parameters
+     ----------
+     fun : callable
+         The function whose ``jaxpr`` is to be computed. Its positional
+         arguments and return value should be arrays, scalars, or standard Python
+         containers (tuple/list/dict) thereof.
+     static_argnums : int or iterable of int, optional
+         Indices of positional arguments to treat as static (known at compile time).
+         See :py:func:`jax.jit` for details. Default is ().
+     static_argnames : str or iterable of str, optional
+         Names of keyword arguments to treat as static (known at compile time).
+         See :py:func:`jax.jit` for details. Default is ().
+     axis_env : sequence of tuple, optional
+         A sequence of pairs where the first element is an axis name and the second
+         element is a positive integer representing the size of the mapped axis with
+         that name. This parameter is useful when lowering functions that involve
+         parallel communication collectives, and it specifies the axis name/size
+         environment that would be set up by applications of :py:func:`jax.pmap`.
+         Default is None.
+     abstracted_axes : pytree, optional
+         A pytree with the same structure as the input arguments to ``fun``. The
+         leaves of the pytree can be either None or a dict with axis names as keys
+         and integers as values. If the leaf is None, then the corresponding axis
+         is not abstracted. If the leaf is a dict, then the corresponding axis is
+         abstracted, and the dict specifies the axis name and size. The abstracted
+         axes are used to infer the input type of the function. If None, then all
+         axes are abstracted. Default is None.
+     name : str, optional
+         Name for the stateful function. Default is None.
+     return_only_write : bool, optional
+         If True, only return states that were written to during execution
+         (not just read). This can reduce memory usage when you only care
+         about modified states. Default is True.
+
+     Attributes
+     ----------
+     fun : callable
+         The wrapped function.
+     static_argnums : tuple of int
+         Indices of static positional arguments.
+     static_argnames : tuple of str
+         Names of static keyword arguments.
+     axis_env : sequence of tuple or None
+         Axis environment for parallel operations.
+     abstracted_axes : pytree or None
+         Abstract axes specification.
+     name : str or None
+         Name identifier for the function.
+     return_only_write : bool
+         Whether to return only written states.
+
+     Examples
+     --------
+     Basic usage with state management:
+
+     .. code-block:: python
+
+         >>> import brainstate
+         >>> import jax.numpy as jnp
+         >>>
+         >>> # Create a state
+         >>> state = brainstate.State(jnp.array([1.0, 2.0]))
+         >>>
+         >>> def f(x):
+         ...     state.value += x
+         ...     return state.value * 2
+         >>>
+         >>> # Create a stateful function
+         >>> sf = brainstate.transform.StatefulFunction(f)
+         >>>
+         >>> # Compile and get jaxpr
+         >>> x = jnp.array([0.5, 0.5])
+         >>> sf.make_jaxpr(x)
+         >>>
+         >>> # Get states that are read/written
+         >>> cache_key = sf.get_arg_cache_key(x)
+         >>> states = sf.get_states_by_cache(cache_key)
+         >>> read_states = sf.get_read_states_by_cache(cache_key)
+         >>> write_states = sf.get_write_states_by_cache(cache_key)
+
+     Using with static arguments:
+
+     .. code-block:: python
+
+         >>> def g(x, n):
+         ...     state.value = state.value ** n
+         ...     return state.value
+         >>>
+         >>> sf_static = brainstate.transform.StatefulFunction(
+         ...     g, static_argnums=(1,)
+         ... )
+         >>> sf_static.make_jaxpr(x, 2)
+
+     Automatic state management:
+
+     .. code-block:: python
+
+         >>> # Execute with automatic state handling
+         >>> result = sf.jaxpr_call_auto(x)
+         >>> print(state.value)  # State is automatically updated
+
+     See Also
+     --------
+     make_jaxpr : Function to create jaxpr from a function.
+     brainstate.State : The state container class.
+
+     Notes
+     -----
+     This class maintains internal thread-safe caches for compiled jaxprs, output
+     shapes, and state traces. The cache size is bounded at 128 entries per cache
+     type. Use ``clear_cache()`` to manually clear the caches if needed.
+
+     State objects should not be passed as direct inputs or outputs to the wrapped
+     function. Instead, they should be accessed within the function body, and the
+     class will automatically track their usage.
+     """
+     __module__ = "brainstate.transform"
+
+     def __init__(
+         self,
+         fun: Callable,
+         static_argnums: Union[int, Iterable[int]] = (),
+         static_argnames: Union[str, Iterable[str]] = (),
+         axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
+         abstracted_axes: Optional[Any] = None,
+         name: Optional[str] = None,
+         return_only_write: bool = True,
+     ):
+         # explicit parameters
+         self.fun = fun
+         self.static_argnums = tuple() if static_argnums is None else _ensure_index_tuple(static_argnums)
+         self.static_argnames = tuple() if static_argnames is None else _ensure_str_tuple(static_argnames)
+         self.axis_env = axis_env
+         self.abstracted_axes = abstracted_axes
+         self.name = name
+         self.return_only_write = return_only_write
+
+         # implicit parameters - thread-safe bounded caches
+         self._cached_jaxpr = _BoundedCache(maxsize=128)
+         self._cached_out_shapes = _BoundedCache(maxsize=128)
+         self._cached_jaxpr_out_tree = _BoundedCache(maxsize=128)
+         self._cached_state_trace = _BoundedCache(maxsize=128)
+         self._cache_lock = threading.RLock()
+
+     def __pretty_repr_item__(self, k, v):
+         if k.startswith('_'):
+             return None
+         return k, v
+
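+     # Note (illustrative): because of ``__pretty_repr_item__``, the pretty
+     # representation of a StatefulFunction shows ``fun``, ``static_argnums``,
+     # etc., but omits all private ``_cached_*`` bounded caches.
+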
+     def get_jaxpr_by_cache(self, cache_key: Hashable) -> ClosedJaxpr:
+         """
+         Read the JAX Jaxpr representation of the function.
+
+         Parameters
+         ----------
+         cache_key : Hashable
+             The hashable cache key for retrieving the compiled jaxpr.
+
+         Returns
+         -------
+         ClosedJaxpr
+             The JAX Jaxpr representation of the function.
+
+         Raises
+         ------
+         ValueError
+             If the function has not been compiled for the given cache key.
+         """
+         return self._cached_jaxpr.get(cache_key, raise_on_miss=True, error_context="JAX expression")
+
+     def get_jaxpr(self, *args, compile_if_miss: bool = True, **kwargs) -> ClosedJaxpr:
+         """
+         Read the JAX Jaxpr representation of the function by calling with args.
+
+         Parameters
+         ----------
+         *args
+             The arguments to the function.
+         compile_if_miss : bool, optional
+             Whether to compile the function if the cache key is not found. Default is True.
+         **kwargs
+             The keyword arguments to the function.
+
+         Returns
+         -------
+         ClosedJaxpr
+             The JAX Jaxpr representation of the function.
+         """
+         cache_key = self.get_arg_cache_key(*args, **kwargs, compile_if_miss=compile_if_miss)
+         return self.get_jaxpr_by_cache(cache_key)
+
+     def get_out_shapes_by_cache(self, cache_key: Hashable) -> PyTree:
+         """
+         Read the output shapes of the function.
+
+         Parameters
+         ----------
+         cache_key : Hashable
+             The hashable cache key.
+
+         Returns
+         -------
+         PyTree
+             The output shapes of the function.
+
+         Raises
+         ------
+         ValueError
+             If the function has not been compiled for the given cache key.
+         """
+         return self._cached_out_shapes.get(cache_key, raise_on_miss=True, error_context="Output shapes")
+
+     def get_out_shapes(self, *args, compile_if_miss: bool = True, **kwargs) -> PyTree:
+         """
+         Read the output shapes of the function.
+
+         Parameters
+         ----------
+         *args
+             The arguments to the function.
+         compile_if_miss : bool, optional
+             Whether to compile the function if the cache key is not found. Default is True.
+         **kwargs
+             The keyword arguments to the function.
+
+         Returns
+         -------
+         PyTree
+             The output shapes of the function.
+         """
+         cache_key = self.get_arg_cache_key(*args, **kwargs, compile_if_miss=compile_if_miss)
+         return self.get_out_shapes_by_cache(cache_key)
+
+     def get_out_treedef_by_cache(self, cache_key: Hashable) -> PyTree:
+         """
+         Read the output tree definition of the function.
+
+         Parameters
+         ----------
+         cache_key : Hashable
+             The hashable cache key.
+
+         Returns
+         -------
+         PyTree
+             The output tree definition of the function.
+
+         Raises
+         ------
+         ValueError
+             If the function has not been compiled for the given cache key.
+         """
+         return self._cached_jaxpr_out_tree.get(cache_key, raise_on_miss=True, error_context="Output tree")
+
+     def get_out_treedef(self, *args, compile_if_miss: bool = True, **kwargs) -> PyTree:
+         """
+         Read the output tree of the function.
+
+         Parameters
+         ----------
+         *args
+             The arguments to the function.
+         compile_if_miss : bool, optional
+             Whether to compile the function if the cache key is not found. Default is True.
+         **kwargs
+             The keyword arguments to the function.
+
+         Returns
+         -------
+         PyTree
+             The output tree of the function.
+         """
+         cache_key = self.get_arg_cache_key(*args, **kwargs, compile_if_miss=compile_if_miss)
+         return self.get_out_treedef_by_cache(cache_key)
+
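+     # Sketch: the cached treedef mirrors the ``(out, state_values)`` pair that
+     # ``_wrapped_fun_to_eval`` returns, and is what ``jaxpr_call`` uses to
+     # unflatten raw jaxpr outputs (illustrative, assuming ``sf`` is compiled):
+     #
+     #     treedef = sf.get_out_treedef(x)
+     #     out, state_vals = treedef.unflatten(flat_jaxpr_outputs)
+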
+     def get_state_trace_by_cache(self, cache_key: Hashable) -> StateTraceStack:
+         """
+         Read the state trace of the function.
+
+         Parameters
+         ----------
+         cache_key : Hashable
+             The hashable cache key.
+
+         Returns
+         -------
+         StateTraceStack
+             The state trace stack containing all tracked states.
+
+         Raises
+         ------
+         ValueError
+             If the function has not been compiled for the given cache key.
+         """
+         return self._cached_state_trace.get(cache_key, raise_on_miss=True, error_context="State trace")
+
+     def get_state_trace(self, *args, compile_if_miss: bool = True, **kwargs) -> StateTraceStack:
+         """
+         Read the state trace of the function.
+
+         Parameters
+         ----------
+         *args
+             The arguments to the function.
+         compile_if_miss : bool, optional
+             Whether to compile the function if the cache key is not found. Default is True.
+         **kwargs
+             The keyword arguments to the function.
+
+         Returns
+         -------
+         StateTraceStack
+             The state trace of the function.
+         """
+         cache_key = self.get_arg_cache_key(*args, **kwargs, compile_if_miss=compile_if_miss)
+         return self.get_state_trace_by_cache(cache_key)
+
+     def get_states_by_cache(self, cache_key: Hashable) -> Tuple[State, ...]:
+         """
+         Read the states that are accessed by the function.
+
+         Parameters
+         ----------
+         cache_key : Hashable
+             The hashable cache key.
+
+         Returns
+         -------
+         Tuple[State, ...]
+             The states that are read from or written to by the function.
+
+         Raises
+         ------
+         ValueError
+             If the function has not been compiled for the given cache key.
+         """
+         return tuple(self.get_state_trace_by_cache(cache_key).states)
+
+     def get_states(self, *args, compile_if_miss: bool = True, **kwargs) -> Tuple[State, ...]:
+         """
+         Compile the function, and get the states that are read and written by this function.
+
+         Parameters
+         ----------
+         *args
+             The arguments to the function.
+         compile_if_miss : bool, optional
+             Whether to compile the function if the cache key is not found. Default is True.
+         **kwargs
+             The keyword arguments to the function.
+
+         Returns
+         -------
+         Tuple[State, ...]
+             The states that are read and written by the function.
+         """
+         cache_key = self.get_arg_cache_key(*args, **kwargs, compile_if_miss=compile_if_miss)
+         return self.get_states_by_cache(cache_key)
+
+     def get_read_states_by_cache(self, cache_key: Hashable) -> Tuple[State, ...]:
+         """
+         Read the states that are read by the function.
+
+         Parameters
+         ----------
+         cache_key : Hashable
+             The hashable key.
+
+         Returns
+         -------
+         Tuple[State, ...]
+             The states that are read by the function.
+         """
+         return self.get_state_trace_by_cache(cache_key).get_read_states()
+
+     def get_read_states(self, *args, compile_if_miss: bool = True, **kwargs) -> Tuple[State, ...]:
+         """
+         Compile the function, and get the states that are read by this function.
+
+         Parameters
+         ----------
+         *args
+             The arguments to the function.
+         compile_if_miss : bool, optional
+             Whether to compile the function if the cache key is not found. Default is True.
+         **kwargs
+             The keyword arguments to the function.
+
+         Returns
+         -------
+         Tuple[State, ...]
+             The states that are read by the function.
+         """
+         cache_key = self.get_arg_cache_key(*args, **kwargs, compile_if_miss=compile_if_miss)
+         return self.get_read_states_by_cache(cache_key)
+
+     def get_write_states_by_cache(self, cache_key: Hashable) -> Tuple[State, ...]:
+         """
+         Read the states that are written by the function.
+
+         Parameters
+         ----------
+         cache_key : Hashable
+             The hashable cache key.
+
+         Returns
+         -------
+         Tuple[State, ...]
+             The states that are written by the function.
+         """
+         return self.get_state_trace_by_cache(cache_key).get_write_states()
+
+     def get_write_states(self, *args, compile_if_miss: bool = True, **kwargs) -> Tuple[State, ...]:
+         """
+         Compile the function, and get the states that are written by this function.
+
+         Parameters
+         ----------
+         *args
+             The arguments to the function.
+         compile_if_miss : bool, optional
+             Whether to compile the function if the cache key is not found. Default is True.
+         **kwargs
+             The keyword arguments to the function.
+
+         Returns
+         -------
+         Tuple[State, ...]
+             The states that are written by the function.
+         """
+         cache_key = self.get_arg_cache_key(*args, **kwargs, compile_if_miss=compile_if_miss)
+         return self.get_write_states_by_cache(cache_key)
+
+     def _check_input_output(self, x):
+         if isinstance(x, State):
+             x.raise_error_with_source_info(
+                 ValueError(
+                     'Inputs/outputs for brainstate transformations cannot be an instance of State. '
+                     f'But we got {x}'
+                 )
+             )
+
+     def get_arg_cache_key(self, *args, compile_if_miss: bool = False, **kwargs) -> hashabledict:
+         """
+         Compute the cache key for the given arguments.
+
+         This method separates static and dynamic arguments and creates a hashable
+         key that can be used to cache compiled jaxpr representations.
+
+         Parameters
+         ----------
+         *args
+             The positional arguments to the function.
+         compile_if_miss : bool, optional
+             Whether to compile the function if the cache key does not exist.
+             Default is False.
+         **kwargs
+             The keyword arguments to the function.
+
+         Returns
+         -------
+         hashabledict
+             A hashable dictionary containing the cache key with fields:
+             'static_args', 'dyn_args', 'static_kwargs', 'dyn_kwargs'.
+
+         Examples
+         --------
+         .. code-block:: python
+
+             >>> import brainstate
+             >>> import jax.numpy as jnp
+             >>>
+             >>> def f(x, n):
+             ...     return x ** n
+             >>>
+             >>> sf = brainstate.transform.StatefulFunction(
+             ...     f, static_argnums=(1,)
+             ... )
+             >>> cache_key = sf.get_arg_cache_key(jnp.array([1.0, 2.0]), 2)
+         """
+         static_args, dyn_args = [], []
+         for i, arg in enumerate(args):
+             if i in self.static_argnums:
+                 static_args.append(arg)
+             else:
+                 dyn_args.append(arg)
+         dyn_args = jax.tree.map(shaped_abstractify, dyn_args)
+         static_kwargs, dyn_kwargs = [], []
+         for k, v in sorted(kwargs.items()):
+             if k in self.static_argnames:
+                 static_kwargs.append((k, v))
+             else:
+                 dyn_kwargs.append((k, jax.tree.map(shaped_abstractify, v)))
+
+         static_args = make_hashable(tuple(static_args))
+         dyn_args = make_hashable(tuple(dyn_args))
+         static_kwargs = make_hashable(static_kwargs)
+         dyn_kwargs = make_hashable(dyn_kwargs)
+
+         cache_key = hashabledict(
+             static_args=static_args,
+             dyn_args=dyn_args,
+             static_kwargs=static_kwargs,
+             dyn_kwargs=dyn_kwargs,
+         )
+
+         if cache_key not in self._cached_state_trace and compile_if_miss:
+             self.make_jaxpr(*args, **kwargs)
+
+         return cache_key
+
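+     # Note: dynamic arguments enter the key only through ``shaped_abstractify``,
+     # so two calls whose arrays share shape and dtype reuse one compilation
+     # (illustrative, with ``sf`` as in the class docstring):
+     #
+     #     key1 = sf.get_arg_cache_key(jnp.zeros(3))
+     #     key2 = sf.get_arg_cache_key(jnp.ones(3))
+     #     assert key1 == key2
+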
+     def clear_cache(self) -> None:
+         """
+         Clear all compilation caches.
+
+         This method removes all cached jaxprs, output shapes, output trees,
+         and state traces. Use this when you need to recompile the function
+         or free memory.
+
+         Examples
+         --------
+         .. code-block:: python
+
+             >>> import brainstate
+             >>> import jax.numpy as jnp
+             >>>
+             >>> def f(x):
+             ...     return x * 2
+             >>>
+             >>> sf = brainstate.transform.StatefulFunction(f)
+             >>> sf.make_jaxpr(jnp.array([1.0, 2.0]))
+             >>> sf.clear_cache()  # Clear all cached compilations
+         """
+         self._cached_jaxpr.clear()
+         self._cached_out_shapes.clear()
+         self._cached_jaxpr_out_tree.clear()
+         self._cached_state_trace.clear()
+
+     def __jax_v04_new_arg(self):
+         # Should be within the calling of ``jax.make_jaxpr()``
+         frame, trace = _jax_v04_new_jax_trace()
+         # Set the function to transform the new argument to a tracer
+         fn = functools.partial(_jax_v04_new_arg_fn, frame, trace)
+         return fn
+
+     def __jax_new_version_new_arg(self):
+         trace = jax.core.trace_ctx.trace
+
+         def wrapper(x):
+             if jax.__version_info__ < (0, 6, 1):
+                 fn = lambda xx: trace.new_arg(shaped_abstractify(xx))
+             else:
+                 fn = lambda xx: trace.new_arg(shaped_abstractify(xx), source_info=source_info_util.current())
+             return jax.tree.map(fn, x._value)
+
+         return wrapper
+
+     def _wrapped_fun_to_eval(
+         self,
+         cache_key,
+         static_kwargs: dict,
+         *args,
+         **dyn_kwargs,
+     ) -> Tuple[Any, Tuple[State, ...]]:
+         """
+         Internal wrapper that executes the function and tracks state operations.
+
+         This method wraps the original function to track which states are read
+         and written during execution. It is used internally during jaxpr compilation.
+
+         Parameters
+         ----------
+         cache_key
+             The cache key for storing the state trace.
+         static_kwargs : dict
+             Static keyword arguments that were separated out.
+         *args
+             The positional arguments to the function.
+         **dyn_kwargs
+             Dynamic keyword arguments to the function.
+
+         Returns
+         -------
+         tuple
+             A tuple of (output, state_values) where output is the function result
+             and state_values are the tracked state values (either all or write-only
+             depending on return_only_write setting).
+         """
+         # state trace
+         state_trace: StateTraceStack = StateTraceStack(self.name)
+         if jax.__version_info__ < (0, 4, 36):
+             state_trace.set_new_arg(self.__jax_v04_new_arg())
+         else:
+             state_trace.set_new_arg(self.__jax_new_version_new_arg())
+         self._cached_state_trace.set(cache_key, state_trace)
+         with state_trace:
+             out = self.fun(*args, **dyn_kwargs, **static_kwargs)
+             state_values = (
+                 state_trace.get_write_state_values(True)
+                 if self.return_only_write else
+                 state_trace.get_state_values()
+             )
+         state_trace.recovery_original_values()
+
+         # State instance as functional returns is not allowed.
+         # Checking whether the states are returned.
+         jax.tree.map(self._check_input_output, out, is_leaf=lambda x: isinstance(x, State))
+         return out, state_values
+
+     def make_jaxpr(self, *args, **kwargs):
+         """
+         Create the JAX Jaxpr representation given example arguments.
+
+         This method compiles the function with the given arguments and caches
+         the resulting Jaxpr, output shapes, and state trace for later use.
+
+         Parameters
+         ----------
+         *args
+             The arguments to the function.
+         **kwargs
+             The keyword arguments to the function.
+
+         Returns
+         -------
+         StatefulFunction
+             Returns self for method chaining.
+
+         Raises
+         ------
+         TypeError
+             If State objects are passed as arguments or returned from the function.
+         """
+
+         # check input types
+         jax.tree.map(self._check_input_output, (args, kwargs), is_leaf=lambda x: isinstance(x, State))
+
+         # static args
+         cache_key = self.get_arg_cache_key(*args, **kwargs)
+
+         if cache_key not in self._cached_state_trace:
+             try:
+
+                 # jaxpr
+                 static_kwargs, dyn_kwargs = {}, {}
+                 for k, v in kwargs.items():
+                     if k in self.static_argnames:
+                         static_kwargs[k] = v
+                     else:
+                         dyn_kwargs[k] = v
+                 jaxpr, (out_shapes, state_shapes) = _make_jaxpr(
+                     functools.partial(
+                         self._wrapped_fun_to_eval,
+                         cache_key,
+                         static_kwargs,
+                     ),
+                     static_argnums=self.static_argnums,
+                     axis_env=self.axis_env,
+                     return_shape=True,
+                     abstracted_axes=self.abstracted_axes,
+                 )(*args, **dyn_kwargs)
+
+                 # returns
+                 self._cached_jaxpr_out_tree.set(cache_key, jax.tree.structure((out_shapes, state_shapes)))
+                 self._cached_out_shapes.set(cache_key, (out_shapes, state_shapes))
+                 self._cached_jaxpr.set(cache_key, jaxpr)
+
+             except Exception as e:
+                 # Clean up partial cache entries on error
+                 self._cached_state_trace.pop(cache_key, None)
+                 self._cached_out_shapes.pop(cache_key, None)
+                 self._cached_jaxpr.pop(cache_key, None)
+                 raise e
+
+         return self
+
+     def jaxpr_call(self, state_vals, *args, **kwargs) -> Any:
+         """
+         Call the function at the JAX Jaxpr level.
+
+         This method evaluates the compiled Jaxpr with the provided state values
+         and arguments, returning updated state values and function outputs.
+
+         Parameters
+         ----------
+         state_vals : Sequence
+             The current state values.
+         *args
+             The arguments to the function.
+         **kwargs
+             The keyword arguments to the function.
+
+         Returns
+         -------
+         tuple
+             A tuple of (new_state_vals, out) where new_state_vals are the
+             updated state values and out is the function output.
+
+         Raises
+         ------
+         ValueError
+             If the number of state values doesn't match the expected number.
+         """
+         # state checking
+         cache_key = self.get_arg_cache_key(*args, **kwargs)
+         states: Sequence[State] = self.get_states_by_cache(cache_key)
+         if len(state_vals) != len(states):
+             raise ValueError(f'State length mismatch: expected {len(states)} states, got {len(state_vals)}')
+
+         # parameters
+         kwargs = {k: v for k, v in kwargs.items() if k not in self.static_argnames}  # remove static kwargs
+         args = tuple(args[i] for i in range(len(args)) if i not in self.static_argnums)
+         args = jax.tree.flatten((args, kwargs, state_vals))[0]
+
+         # calling the function,
+         # note that this function always returns state values
+         # that are both written and read by the function
+         closed_jaxpr = self.get_jaxpr_by_cache(cache_key)
+         out_treedef = self.get_out_treedef_by_cache(cache_key)
+         jaxpr_outs = jax.core.eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)
+
+         # output processing
+         out, new_state_vals = out_treedef.unflatten(jaxpr_outs)
+         if len(new_state_vals) != len(state_vals):
+             raise ValueError(f'State length mismatch in output: expected '
+                              f'{len(state_vals)} states, got {len(new_state_vals)}')
+         return new_state_vals, out
+
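+     # Sketch of manual state threading with ``jaxpr_call`` (illustrative;
+     # ``sf``, ``state`` and ``x`` follow the class docstring above):
+     #
+     #     vals = tuple(s.value for s in sf.get_states(x))
+     #     new_vals, out = sf.jaxpr_call(vals, x)   # the states themselves are untouched
+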
+     def get_cache_stats(self) -> Dict[str, Any]:
+         """
+         Get comprehensive cache statistics for all internal caches.
+
+         Returns
+         -------
+         dict
+             A dictionary with statistics for each cache including size, hits,
+             misses, and hit rates. Keys are 'jaxpr_cache', 'out_shapes_cache',
+             'jaxpr_out_tree_cache', and 'state_trace_cache'.
+         """
+         return {
+             'jaxpr_cache': self._cached_jaxpr.get_stats(),
+             'out_shapes_cache': self._cached_out_shapes.get_stats(),
+             'jaxpr_out_tree_cache': self._cached_jaxpr_out_tree.get_stats(),
+             'state_trace_cache': self._cached_state_trace.get_stats(),
+         }
+
+     def validate_states(self, cache_key: Hashable) -> bool:
+         """
+         Validate that all tracked states for a given cache key are still valid.
+
+         Parameters
+         ----------
+         cache_key : Hashable
+             The cache key to validate states for.
+
+         Returns
+         -------
+         bool
+             True if all states are valid.
+
+         Raises
+         ------
+         ValueError
+             If any states are invalid or missing required attributes.
+         """
+         state_trace = self.get_state_trace_by_cache(cache_key)
+         invalid_states = []
+         for i, state in enumerate(state_trace.states):
+             if not hasattr(state, 'value'):
+                 invalid_states.append((i, state))
+
+         if invalid_states:
+             raise ValueError(
+                 f"Found {len(invalid_states)} invalid states at indices: "
+                 f"{[idx for idx, _ in invalid_states]}. "
+                 f"States must have a 'value' attribute."
+             )
+         return True
+
+     def validate_all_states(self) -> Dict[Any, bool]:
+         """
+         Validate states for all cached compilations.
+
+         Returns
+         -------
+         dict
+             A dictionary mapping cache keys to validation results. Each value
+             is either True (valid) or an error message string (invalid).
+         """
+         results = {}
+         for cache_key in self._cached_state_trace.keys():
+             try:
+                 results[cache_key] = self.validate_states(cache_key)
+             except ValueError as e:
+                 results[cache_key] = str(e)
+         return results
+
+     def jaxpr_call_auto(self, *args, **kwargs) -> Any:
+         """
+         Execute the function at the jaxpr level with automatic state management.
+
+         This method automatically retrieves current state values, executes the
+         jaxpr-compiled function, and updates the states with the new values.
+         It provides a convenient interface that handles all state management
+         automatically.
+
+         Parameters
+         ----------
+         *args
+             The positional arguments to the function.
+         **kwargs
+             The keyword arguments to the function.
+
+         Returns
+         -------
+         Any
+             The output of the function.
+
+         Examples
+         --------
+         .. code-block:: python
+
+             >>> import brainstate
+             >>> import jax.numpy as jnp
+             >>>
+             >>> state = brainstate.State(jnp.array([1.0, 2.0]))
+             >>>
+             >>> def f(x):
+             ...     state.value += x
+             ...     return state.value * 2
+             >>>
+             >>> sf = brainstate.transform.StatefulFunction(f)
+             >>> x = jnp.array([0.5, 0.5])
+             >>> sf.make_jaxpr(x)
+             >>>
+             >>> # Automatic state management
+             >>> result = sf.jaxpr_call_auto(x)
+             >>> # or
+             >>> result = sf(x)
+             >>> print(state.value)  # State is automatically updated
+         """
+         state_trace = self.get_state_trace_by_cache(self.get_arg_cache_key(*args, **kwargs, compile_if_miss=True))
+         all_read_state_vals = state_trace.get_read_state_values(True)
+         state_vals, out = self.jaxpr_call(state_trace.get_state_values(), *args, **kwargs)
+         state_trace.assign_state_vals_v2(all_read_state_vals, state_vals)
+         return out
+
+     def __call__(self, *args, **kwargs):
+         return self.jaxpr_call_auto(*args, **kwargs)
+
+
+ @set_module_as("brainstate.transform")
+ def make_jaxpr(
+     fun: Callable,
+     static_argnums: Union[int, Iterable[int]] = (),
+     static_argnames: Union[str, Iterable[str]] = (),
+     axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
+     return_shape: bool = False,
+     abstracted_axes: Optional[Any] = None,
+     return_only_write: bool = False,
+ ) -> Callable[
+     ...,
+     (Tuple[ClosedJaxpr, Tuple[State, ...]] |
+      Tuple[ClosedJaxpr, Tuple[State, ...], PyTree])
+ ]:
+     """
+     Creates a function that produces its jaxpr given example args.
+
+     A ``jaxpr`` is JAX's intermediate representation for program traces. The
+     ``jaxpr`` language is based on the simply-typed first-order lambda calculus
+     with let-bindings. :py:func:`make_jaxpr` adapts a function to return its
+     ``jaxpr``, which we can inspect to understand what JAX is doing internally.
+     The ``jaxpr`` returned is a trace of ``fun`` abstracted to
+     :py:class:`ShapedArray` level. Other levels of abstraction exist internally.
+
+     Parameters
+     ----------
+     fun : callable
+         The function whose ``jaxpr`` is to be computed. Its positional
+         arguments and return value should be arrays, scalars, or standard Python
+         containers (tuple/list/dict) thereof.
+     static_argnums : int or iterable of int, optional
+         See the :py:func:`jax.jit` docstring.
+     static_argnames : str or iterable of str, optional
+         See the :py:func:`jax.jit` docstring.
+     axis_env : sequence of tuple, optional
+         A sequence of pairs where the first element is an axis
+         name and the second element is a positive integer representing the size of
+         the mapped axis with that name. This parameter is useful when lowering
+         functions that involve parallel communication collectives, and it
+         specifies the axis name/size environment that would be set up by
+         applications of :py:func:`jax.pmap`.
+     return_shape : bool, default False
+         If ``True``, the wrapped function additionally returns a pytree with the
+         same structure as the output of ``fun``, whose leaves are objects with
+         ``shape``, ``dtype``, and ``named_shape`` attributes representing the
+         corresponding types of the output leaves.
+     abstracted_axes : pytree, optional
+         A pytree with the same structure as the input
+         arguments to ``fun``. The leaves of the pytree can be either None or a
+         dict with axis names as keys and integers as values. If the leaf is None,
+         then the corresponding axis is not abstracted. If the leaf is a dict, then
+         the corresponding axis is abstracted, and the dict specifies the axis name
+         and size. The abstracted axes are used to infer the input type of the
+         function. If None, then all axes are abstracted.
+     return_only_write : bool, default False
+         If True, only return states that were written to during execution
+         (not just read). This can reduce memory usage when you only care
+         about modified states.
+
+     Returns
+     -------
+     callable
+         A wrapped version of ``fun`` that, when applied to example arguments,
+         returns a pair of the ``ClosedJaxpr`` representation of ``fun`` on those
+         arguments and the ``State`` objects used by ``fun``. If the argument
+         ``return_shape`` is ``True``, the returned function instead returns a
+         triple whose additional last element is a pytree representing the
+         structure, shape, dtypes, and named shapes of the output of ``fun``.
+
+     Examples
+     --------
+     Basic usage:
+
+     .. code-block:: python
+
+         >>> import jax
+         >>> import brainstate
+         >>> import jax.numpy as jnp
+         >>>
+         >>> def f(x):
+         ...     return jnp.sin(jnp.cos(x))
+         >>>
+         >>> # Create jaxpr maker
+         >>> jaxpr_maker = brainstate.transform.make_jaxpr(f)
+         >>> jaxpr, states = jaxpr_maker(3.0)
+
+     With gradient:
+
+     .. code-block:: python
+
+         >>> jaxpr_grad_maker = brainstate.transform.make_jaxpr(jax.grad(f))
+         >>> jaxpr, states = jaxpr_grad_maker(3.0)
+
+     With shape information:
+
+     .. code-block:: python
+
+         >>> jaxpr_maker_with_shape = brainstate.transform.make_jaxpr(f, return_shape=True)
+         >>> jaxpr, states, shapes = jaxpr_maker_with_shape(3.0)
+
+     With stateful function:
+
+     .. code-block:: python
+
+         >>> state = brainstate.State(jnp.array([1.0, 2.0]))
+         >>>
+         >>> def stateful_f(x):
+         ...     state.value += x
+         ...     return state.value
+         >>>
+         >>> jaxpr_maker = brainstate.transform.make_jaxpr(stateful_f)
+         >>> jaxpr, states = jaxpr_maker(jnp.array([0.5, 0.5]))
+     """
+
+     stateful_fun = StatefulFunction(
+         fun,
+         static_argnums=static_argnums,
+         static_argnames=static_argnames,
+         axis_env=axis_env,
+         abstracted_axes=abstracted_axes,
+         return_only_write=return_only_write,
+         name='make_jaxpr'
+     )
+
+     @wraps(fun)
+     def make_jaxpr_f(*args, **kwargs):
+         stateful_fun.make_jaxpr(*args, **kwargs)
+         cache_key = stateful_fun.get_arg_cache_key(*args, **kwargs)
+         if return_shape:
+             return (
+                 stateful_fun.get_jaxpr_by_cache(cache_key),
+                 stateful_fun.get_states_by_cache(cache_key),
+                 stateful_fun.get_out_shapes_by_cache(cache_key)[0]
+             )
+         else:
+             return (
+                 stateful_fun.get_jaxpr_by_cache(cache_key),
+                 stateful_fun.get_states_by_cache(cache_key)
+             )
+
+     # wrapped jaxpr builder function
+     make_jaxpr_f.__module__ = "brainstate.transform"
+     if hasattr(fun, "__qualname__"):
+         make_jaxpr_f.__qualname__ = f"make_jaxpr({fun.__qualname__})"
+     if hasattr(fun, "__name__"):
+         make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
+     return make_jaxpr_f
+
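+ # A short sketch of consuming the returned ``ClosedJaxpr``. The traced function
+ # yields ``(out, state_values)``, so the flat jaxpr outputs carry both parts;
+ # with no ``State`` involved the state part is empty (illustrative):
+ #
+ #     closed_jaxpr, states = make_jaxpr(lambda x: jnp.sin(x) * 2)(3.0)
+ #     flat = jax.core.eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, 3.0)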
1410
+
1411
+ class StatefulMapping(StatefulFunction):
1412
+ __module__ = "brainstate.transform"
1413
+
1414
+ def __init__(
1415
+ self,
1416
+ fun: Callable,
1417
+ in_axes: Union[int, Tuple[int, ...], None] = 0,
1418
+ out_axes: Union[int, Tuple[int, ...], None] = 0,
1419
+ state_in_axes: Optional[Union[Dict[AxisName, Filter], Filter]] = None,
1420
+ state_out_axes: Optional[Union[Dict[AxisName, Filter], Filter]] = None,
1421
+ # jit specific parameters
1422
+ static_argnums: Union[int, Iterable[int]] = (),
1423
+ static_argnames: Union[str, Iterable[str]] = (),
1424
+ axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
1425
+ abstracted_axes: Optional[Any] = None,
1426
+ # mapping specific parameters
1427
+ axis_size: Optional[int] = None,
1428
+ axis_name: AxisName | None = None,
1429
+ name: Optional[str] = None,
1430
+ # mapping function
1431
+ mapping_fn: Callable = jax.vmap,
1432
+ ):
1433
+ self.origin_fun = fun
1434
+ super().__init__(
1435
+ fun=self._wrapped_fun,
1436
+ static_argnums=static_argnums,
1437
+ static_argnames=static_argnames,
1438
+ axis_env=axis_env,
1439
+ abstracted_axes=abstracted_axes,
1440
+ name=name,
1441
+ return_only_write=False,
1442
+ )
1443
+ self.in_axes = in_axes
1444
+ self.out_axes = out_axes
1445
+ if state_in_axes is None:
1446
+ state_in_axes = dict()
1447
+ elif not isinstance(state_in_axes, dict):
1448
+ state_in_axes = {0: to_predicate(state_in_axes)}
1449
+ state_in_axes = {k: to_predicate(v) for k, v in state_in_axes.items()} # type: ignore
1450
+ self.state_in_axes = state_in_axes
1451
+
1452
+ if state_out_axes is None:
1453
+ state_out_axes = dict()
1454
+ elif not isinstance(state_out_axes, dict):
1455
+ state_out_axes = {0: to_predicate(state_out_axes)}
1456
+ state_out_axes = {k: to_predicate(v) for k, v in state_out_axes.items()} # type: ignore
1457
+ self.state_out_axes = state_out_axes
1458
+
1459
+ self.axis_size = axis_size
1460
+ self.axis_name = axis_name
1461
+ self.mapping_fn = mapping_fn
1462
+
1463
+ # Cache for discovered state-to-axis mappings
1464
+ self._cached_map_dim_to_in_states = _BoundedCache(maxsize=128)
1465
+ self._cached_map_dim_to_out_states = _BoundedCache(maxsize=128)
1466
+ self._cached_map_state_trace = _BoundedCache(maxsize=128)
1467
+ self._cached_map_batch_size = _BoundedCache(maxsize=128)
1468
+
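+     # Illustration (not from the source): ``state_in_axes`` maps a batch axis
+     # to a filter selecting which states are batched along it. For instance,
+     # ``StatefulMapping(f, state_in_axes={0: brainstate.ParamState})`` would
+     # batch parameter states along axis 0 (assuming ``brainstate.ParamState``
+     # is a valid filter here); non-matching states are handled as usual.
+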
+     def _infer_batch_size(self, args, in_axes):
+         if in_axes is None:
+             raise ValueError("Cannot infer batch size when in_axes is None")
+
+         batch_sizes = []
+
+         def get_batch_size_from_arg(arg_, axis_):
+             if axis_ is None:
+                 return None
+
+             def _get_size(arr):
+                 if not hasattr(arr, 'shape'):
+                     return None
+                 if arr.ndim == 0:
+                     return None
+                 ax = axis_ if axis_ >= 0 else arr.ndim + axis_
+                 if ax < 0 or ax >= arr.ndim:
+                     raise IndexError(f"Axis {ax} is out of bounds for array of shape {arr.shape}")
+                 return arr.shape[ax]
+
+             # Get all sizes from the pytree
+             sizes = [s for s in jax.tree.leaves(jax.tree.map(_get_size, arg_)) if s is not None]
+             return sizes[0] if sizes else None
+
+         if isinstance(in_axes, int):
+             # All args batched along the same axis
+             for arg in args:
+                 size = get_batch_size_from_arg(arg, in_axes)
+                 if size is not None:
+                     batch_sizes.append(size)
+         elif isinstance(in_axes, (tuple, list)):
+             # Different axes for different args
+             if len(in_axes) != len(args):
+                 raise ValueError(
+                     f"Length of in_axes ({len(in_axes)}) must match number of arguments ({len(args)})"
+                 )
+             for arg, axis in zip(args, in_axes):
+                 size = get_batch_size_from_arg(arg, axis)
+                 if size is not None:
+                     batch_sizes.append(size)
+         else:
+             raise TypeError(f"Unsupported in_axes type: {type(in_axes)}")
+
+         if not batch_sizes:
+             if self.axis_size is None:
+                 raise ValueError("Cannot infer batch size when axis_size is None")
+             batch_sizes.append(self.axis_size)
+
+         # Check all batch sizes are consistent
+         if not all(s == batch_sizes[0] for s in batch_sizes):
+             raise ValueError(
+                 f"Inconsistent batch sizes found: {batch_sizes}. "
+                 f"All batched arguments must have the same size along their batch axes."
+             )
+
+         return batch_sizes[0]
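+
+     # Illustration (not from the source): with ``in_axes=0`` and arguments of
+     # shape (8, 3) and (8, 5), the inferred batch size is 8; mixing arguments
+     # whose batch axes disagree (e.g. 8 vs 4) raises the ValueError above.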
+
+     def __new_batch_arg(self, batch_size: int, dim_to_states: dict):
+         trace = jax.core.trace_ctx.trace
+         assert isinstance(trace, BatchTrace), f"Expected to be called within a BatchTrace context, but got {trace}"
+
+         def wrapper(x):
+             if isinstance(x, RandomState):
+                 idx = memoize(lambda: BatchTracer(trace, make_iota(batch_size), 0, source_info_util.current()))
+                 dim_to_states['random'].append(x)
+                 return to_elt(trace, idx, jnp.ones((batch_size,) + x._value.shape, x._value.dtype), 0)
+             for dim, filter_ in self.state_in_axes.items():
+                 idx = memoize(lambda: BatchTracer(trace, make_iota(batch_size), dim, source_info_util.current()))
+                 if filter_(tuple(), x):
+                     dim_to_states[dim].append(x)
+                     return jax.tree.map(lambda xx: to_elt(trace, idx, xx, dim), x._value)
+             return x._value
+
+         return wrapper
+
+     def __eval(self, cache_key, *args, **kwargs):
+         def fn_to_eval(*new_args, **new_kwargs):
+             dim_to_in_states = defaultdict(list)
+             state_trace = StateTraceStack(name=self.name)
+             state_trace.set_new_arg(
+                 self.__new_batch_arg(self._cached_map_batch_size.get(cache_key), dim_to_in_states)
+             )
+             self._cached_map_state_trace.set(cache_key, state_trace)
+
+             # call the function
+             with state_trace:
+                 out_ = self.origin_fun(*new_args, **new_kwargs)
+
+             # cache
+             self._cached_map_dim_to_in_states.set(cache_key, dim_to_in_states)
+
+             # vmapped state values
+             out_states = defaultdict(list)
+             out_states['random'] = [st for st in state_trace.states if isinstance(st, RandomState)]
+             for st in state_trace.states:
+                 if not isinstance(st, RandomState):
+                     leaves = jax.tree.leaves(st._value)
+                     batch_dims = set([leaf.batch_dim if isinstance(leaf, BatchTracer) else None for leaf in leaves])
+                     if len(batch_dims) != 1:
+                         raise ValueError(
+                             f"State {st} has inconsistent batch dimensions in its leaves: {batch_dims}. "
+                             "All leaves must have the same batch dimension."
+                         )
+                     batch_dim = batch_dims.pop()
+                     out_states[batch_dim].append(st)
+             self._cached_map_dim_to_out_states.set(cache_key, out_states)
+
+         try:
+             jax.vmap(
+                 fn_to_eval,
+                 in_axes=self.in_axes,
+                 out_axes=self.out_axes,
+                 axis_name=self.axis_name,
+                 axis_size=self.axis_size
+             )(*args, **kwargs)
+             self._cached_map_state_trace.get(cache_key).recovery_original_values()
+         except Exception as e:
+             if cache_key in self._cached_map_state_trace:
+                 self._cached_map_state_trace.get(cache_key).recovery_original_values()
+             self._cached_map_state_trace.pop(cache_key, None)
+             self._cached_map_dim_to_in_states.pop(cache_key, None)
+             self._cached_map_dim_to_out_states.pop(cache_key, None)
+             self._cached_map_batch_size.pop(cache_key, None)
+             raise RuntimeError(f"Failed to evaluate {self}") from e
+
+     def __assign_vals_from_in_states(self, cache_key, rand_st, *other_st):
+         in_states = self._cached_map_dim_to_in_states.get(cache_key)
+         for st, val in zip(in_states['random'], rand_st):
+             assert isinstance(st, RandomState)
+             st.restore_value(val)
+         for group, group_vals in zip([in_states[dim] for dim in in_states.keys() if dim != 'random'], other_st):
+             for st, val in zip(group, group_vals):
+                 st.restore_value(val)
+
+     def __assign_vals_from_out_states(self, cache_key, rand_st, *other_st):
+         out_states = self._cached_map_dim_to_out_states.get(cache_key)
+         for st, val in zip(out_states['random'], rand_st):
+             assert isinstance(st, RandomState)
+             st.restore_value(val)
+         for group, group_vals in zip([out_states[dim] for dim in out_states.keys() if dim != 'random'], other_st):
+             for st, val in zip(group, group_vals):
+                 st.restore_value(val)
+
+     def __get_in_state_vals(self, cache_key: Hashable):
+         in_states = self._cached_map_dim_to_in_states.get(cache_key)
+         in_axes = []
+         in_values = []
+         for dim, states in in_states.items():
+             if dim == 'random':
+                 continue
+             in_axes.append(dim)
+             in_values.append([st.value for st in states])
+         return tuple(in_axes), in_values
+
+     def __get_out_state_vals(self, cache_key: Hashable):
+         out_states = self._cached_map_dim_to_out_states.get(cache_key)
+         out_axes = []
+         out_values = []
+         for dim, states in out_states.items():
+             if dim == 'random':
+                 continue
+             out_axes.append(dim)
+             out_values.append([st.value for st in states])
+         return tuple(out_axes), out_values
+
+     def __get_rand_state_vals(self, cache_key: Hashable):
+         in_states = self._cached_map_dim_to_in_states.get(cache_key)
+         batch_size = self._cached_map_batch_size.get(cache_key)
+         rand_vals, rand_recover_vals = [], []
+         for st in in_states['random']:
+             assert isinstance(st, RandomState)
+             rand_vals.append(st.split_key(batch_size))
+             rand_recover_vals.append(st.value)
+         return tuple(rand_vals), tuple(rand_recover_vals)
+
+     def __recover_rand_state_vals(self, cache_key: Hashable, rand_recover_vals):
+         state_trace = self._cached_map_state_trace.get(cache_key)
+         rand_states = [st for st in state_trace.states if isinstance(st, RandomState)]
+         for st, val in zip(rand_states, rand_recover_vals):
+             st.restore_value(val)
+
+     def _wrapped_fun(self, *args, **kwargs) -> Tuple[Any, Tuple[State, ...]]:
+         batch_size = self._infer_batch_size(args, self.in_axes)
+         cache_key = self.get_arg_cache_key(*args, **kwargs)
+         self._cached_map_batch_size.set(cache_key, batch_size)
+         if cache_key not in self._cached_map_state_trace:
+             self.__eval(cache_key, *args, **kwargs)
+
+         def fn_to_map(origin_args, rand_st, *non_rand_st):
+             self.__assign_vals_from_in_states(cache_key, rand_st, *non_rand_st)
+             out = self.origin_fun(*origin_args[0], **origin_args[1])
+             return out, *self.__get_out_state_vals(cache_key)[1]
+
+         in_axes, in_state_vals = self.__get_in_state_vals(cache_key)
+         out_axes, out_state_vals = self.__get_out_state_vals(cache_key)
+         rand_vals, rand_recover_vals = self.__get_rand_state_vals(cache_key)
+         mapped_fn = self.mapping_fn(
+             fn_to_map,
+             in_axes=(self.in_axes, 0) + in_axes,
+             out_axes=(self.out_axes,) + out_axes,
+             axis_size=self.axis_size,
+             axis_name=self.axis_name,
+         )
+         out_, *out_state_vals = mapped_fn((args, kwargs), rand_vals, *in_state_vals)
+         self.__assign_vals_from_out_states(cache_key, rand_recover_vals, *out_state_vals)
+         return out_
+
+
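+ # A minimal usage sketch (illustrative, not from the source; it assumes
+ # ``StatefulFunction`` instances are callable and that ``brainstate`` and
+ # ``jax.numpy as jnp`` are imported):
+ #
+ #     counter = brainstate.State(jnp.zeros(()))
+ #
+ #     def step(x):
+ #         counter.value += x.sum()
+ #         return 2.0 * x
+ #
+ #     mapped = StatefulMapping(step, in_axes=0, mapping_fn=jax.vmap)
+ #     ys = mapped(jnp.ones((4, 3)))  # batch of 4, counter kept in sync
+
+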
+ def _check_callable(fun):
+     # In Python 3.10+, the only thing stopping us from supporting static methods
+     # is that we can't take weak references to them, which the C++ JIT requires.
+     if isinstance(fun, staticmethod):
+         raise TypeError(f"staticmethod arguments are not supported, got {fun}")
+     if not callable(fun):
+         raise TypeError(f"Expected a callable value, got {fun}")
+     if inspect.isgeneratorfunction(fun):
+         raise TypeError(f"Expected a function, got a generator function: {fun}")
+
+
+ def _broadcast_prefix(
+     prefix_tree: Any,
+     full_tree: Any,
+     is_leaf: Callable[[Any], bool] | None = None
+ ) -> list[Any]:
+     # If prefix_tree is not a tree prefix of full_tree, this code can raise a
+     # ValueError; use prefix_errors to find disagreements and raise more precise
+     # error messages.
+     result = []
+     num_leaves = lambda t: jax.tree.structure(t).num_leaves
+     add_leaves = lambda x, subtree: result.extend([x] * num_leaves(subtree))
+     jax.tree.map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf)
+     return result
+
+
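+ # Illustration (not from the source): broadcasting the prefix ``(0, 1)`` over
+ # the full tree ``((x, y), {'a': z})`` yields ``[0, 0, 1]`` -- one entry per
+ # leaf of the full tree, each carrying its matching prefix value.
+
+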
+ def _flat_axes_specs(
+     abstracted_axes, *args, **kwargs
+ ) -> list[pe.AbstractedAxesSpec]:
+     if kwargs:
+         raise NotImplementedError
+
+     def ax_leaf(l):
+         return (isinstance(l, dict) and jax.tree_util.all_leaves(l.values()) or
+                 isinstance(l, tuple) and jax.tree_util.all_leaves(l, lambda x: x is None))
+
+     return _broadcast_prefix(abstracted_axes, args, ax_leaf)
+
+
+ @transformation_with_aux
+ def _flatten_fun(in_tree, *args_flat):
+     py_args, py_kwargs = jax.tree.unflatten(in_tree, args_flat)
+     ans = yield py_args, py_kwargs
+     yield jax.tree.flatten(ans)
+
+
+ def _make_jaxpr(
+     fun: Callable,
+     static_argnums: int | Iterable[int] = (),
+     axis_env: Sequence[tuple[AxisName, int]] | None = None,
+     return_shape: bool = False,
+     abstracted_axes: Any | None = None,
+ ) -> Callable[..., (ClosedJaxpr | tuple[ClosedJaxpr, Any])]:
+     """
+     Create a function that produces its jaxpr given example args (internal implementation).
+
+     This is an internal implementation function. Users should use the public
+     ``make_jaxpr`` function instead.
+
+     Parameters
+     ----------
+     fun : Callable
+         The function whose ``jaxpr`` is to be computed. Its positional
+         arguments and return value should be arrays, scalars, or standard Python
+         containers (tuple/list/dict) thereof.
+     static_argnums : int or iterable of int, optional
+         See the :py:func:`jax.jit` docstring.
+     axis_env : sequence of tuple, optional
+         A sequence of pairs where the first element is an axis name and the
+         second element is a positive integer representing the size of the
+         mapped axis with that name. This parameter is useful when lowering
+         functions that involve parallel communication collectives, and it
+         specifies the axis name/size environment that would be set up by
+         applications of :py:func:`jax.pmap`.
+     return_shape : bool, default False
+         If ``True``, the wrapped function returns a pair where the first element
+         is the ``ClosedJaxpr`` representation of ``fun`` and the second element
+         is a pytree with the same structure as the output of ``fun`` and where
+         the leaves are objects with ``shape``, ``dtype``, and ``named_shape``
+         attributes representing the corresponding types of the output leaves.
+     abstracted_axes : Any, optional
+         Axes specifications for abstract interpretation.
+
+     Returns
+     -------
+     Callable
+         A wrapped version of ``fun`` that when applied to example arguments returns
+         a ``ClosedJaxpr`` representation of ``fun`` on those arguments. If the
+         argument ``return_shape`` is ``True``, then the returned function instead
+         returns a pair where the first element is the ``ClosedJaxpr``
+         representation of ``fun`` and the second element is a pytree representing
+         the structure, shape, dtypes, and named shapes of the output of ``fun``.
+
+     Notes
+     -----
+     A ``jaxpr`` is JAX's intermediate representation for program traces. The
+     ``jaxpr`` language is based on the simply-typed first-order lambda calculus
+     with let-bindings. This function adapts a function to return its
+     ``jaxpr``, which we can inspect to understand what JAX is doing internally.
+     The ``jaxpr`` returned is a trace of ``fun`` abstracted to
+     :py:class:`ShapedArray` level. Other levels of abstraction exist internally.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import jax
+         >>>
+         >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
+         >>> print(f(3.0))
+         -0.83602
+         >>> _make_jaxpr(f)(3.0)
+         { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
+         >>> _make_jaxpr(jax.grad(f))(3.0)
+         { lambda ; a:f32[]. let
+             b:f32[] = cos a
+             c:f32[] = sin a
+             _:f32[] = sin b
+             d:f32[] = cos b
+             e:f32[] = mul 1.0 d
+             f:f32[] = neg e
+             g:f32[] = mul f c
+           in (g,) }
+     """
+     _check_callable(fun)
+     static_argnums = _ensure_index_tuple(static_argnums)
+
+     def _abstractify(args, kwargs):
+         flat_args, in_tree = jax.tree.flatten((args, kwargs))
+         if abstracted_axes is None:
+             return map(shaped_abstractify, flat_args), in_tree, [True] * len(flat_args)
+         else:
+             axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
+             in_type = pe.infer_lambda_input_type(axes_specs, flat_args)
+             in_avals, keep_inputs = unzip2(in_type)
+             return in_avals, in_tree, keep_inputs
+
+     @wraps(fun)
+     @api_boundary
+     def make_jaxpr_f(*args, **kwargs):
+         f = wrap_init(fun, (), {}, 'brainstate.transform.make_jaxpr')
+         if static_argnums:
+             dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
+             f, args = jax.api_util.argnums_partial(f, dyn_argnums, args)
+         in_avals, in_tree, keep_inputs = _abstractify(args, kwargs)
+         in_type = tuple(safe_zip(in_avals, keep_inputs))
+         f, out_tree = _flatten_fun(f, in_tree)
+         f = annotate(f, in_type)
+         if jax.__version_info__ < (0, 5, 0):
+             debug_info_ = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
+         with ExitStack() as stack:
+             if axis_env is not None:
+                 stack.enter_context(extend_axis_env_nd(axis_env))
+             if jax.__version_info__ < (0, 5, 0):
+                 jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, debug_info=debug_info_)
+             else:
+                 jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f)
+         closed_jaxpr = ClosedJaxpr(jaxpr, consts)
+         if return_shape:
+             out_avals, _ = unzip2(out_type)
+             out_shapes_flat = [jax.ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
+             return closed_jaxpr, jax.tree.unflatten(out_tree(), out_shapes_flat)
+         return closed_jaxpr
+
+     make_jaxpr_f.__module__ = "brainstate.transform"
+     if hasattr(fun, "__qualname__"):
+         make_jaxpr_f.__qualname__ = f"make_jaxpr({fun.__qualname__})"
+     if hasattr(fun, "__name__"):
+         make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
+     return make_jaxpr_f
+
+
+ def make_hashable(obj):
+     """
+     Convert a pytree into a hashable representation.
+
+     Parameters
+     ----------
+     obj : Any
+         A pytree object (list, tuple, dict, set, or JAX pytree structure).
+
+     Returns
+     -------
+     Hashable
+         A hashable representation of the input object. Lists become tuples,
+         dicts become sorted tuples of key-value pairs, sets become frozensets,
+         and other pytrees are flattened using JAX's tree utilities.
+     """
+     if isinstance(obj, (list, tuple)):
+         return tuple(make_hashable(item) for item in obj)
+     elif isinstance(obj, dict):
+         return tuple(sorted((k, make_hashable(v)) for k, v in obj.items()))
+     elif isinstance(obj, set):
+         return frozenset(make_hashable(item) for item in obj)
+     else:
+         # Use JAX's tree_util for any other pytree structures
+         try:
+             leaves, treedef = jax.tree.flatten(obj)
+             return treedef, tuple(leaves)
+         except (TypeError, ValueError):
+             # Assume obj is already hashable
+             return obj
+
+
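+ # Illustration (not from the source): nested containers collapse to hashable
+ # equivalents suitable for use as cache keys, e.g.
+ #
+ #     make_hashable({'b': [1, 2], 'a': {3}})
+ #     # -> (('a', frozenset({3})), ('b', (1, 2)))
+
+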
+ class IdentitySet(MutableSet):
+     """Set that compares objects by identity rather than equality.
+
+     Useful for storing objects that are not hashable or that should be
+     deduplicated by identity.
+
+     This is a mutable set, but it does not support the ``__hash__`` method and
+     therefore cannot be used as a dictionary key or as an element of another set.
+     """
+
+     def __init__(self, iterable=None):
+         self._data = {}
+         if iterable is not None:
+             # MutableSet does not provide ``update``, so add items explicitly
+             for value in iterable:
+                 self.add(value)
+
+     def __contains__(self, value):
+         return id(value) in self._data
+
+     def __iter__(self):
+         return iter(self._data.values())
+
+     def __len__(self):
+         return len(self._data)
+
+     def add(self, value):
+         self._data[id(value)] = value
+
+     def discard(self, value):
+         self._data.pop(id(value), None)
+
+     def __repr__(self):
+         return f"IdentitySet({list(repr(x) for x in self._data.values())})"
+
+     def __str__(self):
+         return f"IdentitySet({list(str(x) for x in self._data.values())})"
+
+
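+ # Illustration (not from the source): equal-but-distinct objects are kept
+ # separately because membership is tested with ``id`` rather than ``==``:
+ #
+ #     a, b = [1], [1]
+ #     s = IdentitySet([a])
+ #     (a in s, b in s)  # -> (True, False)
+
+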
+ def constant_fold_jaxpr(jaxpr: Jaxpr):
+     """
+     Given a jaxpr, return a new jaxpr with constant folding applied.
+     """
+     return _partial_eval_jaxpr(jaxpr, {})
+
+
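+ # Illustration (not from the source): folding walks the equations in order,
+ # evaluating any equation whose inputs are all known constants (literals or
+ # previously folded values) with ``primitive.bind`` and substituting the
+ # result as a literal; everything else, including the blacklisted broadcast
+ # primitives below, is kept symbolic.
+
+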
+ def _partial_eval_jaxpr(jaxpr, env):
+     env = env.copy()
+     new_eqns = []
+
+     def read(var):
+         if isinstance(var, Literal):
+             return var.val
+         else:
+             return env.get(var, None)
+
+     def read_or_self(var):
+         out = read(var)
+         if out is None:
+             return var
+         elif isinstance(out, Var):
+             return out
+         elif isinstance(out, Literal):
+             return Literal(out.val, var.aval)
+         else:
+             assert not isinstance(out, Jaxpr)
+             return Literal(out, var.aval)
+
+     for eqn in jaxpr.eqns:
+         vals = [read(var) for var in eqn.invars]
+         if eqn.primitive.name in _constant_fold_blacklist:
+             new_eqns.append(eqn)
+         elif all(val is not None for val in vals):
+             # go ahead and evaluate it
+             out = _eval_eqn(eqn, vals)
+
+             # two options: either it's a jaxpr result (partial eval) or it's a value or a list of values
+             if isinstance(out, Jaxpr):
+                 # we need to inline this
+                 new_eqns.extend(out.eqns)
+                 out = out.outvars
+             elif not isinstance(out, tuple) and not isinstance(out, list):
+                 out = (out,)
+
+             for var, val in zip(eqn.outvars, out):
+                 assert not isinstance(val, Jaxpr)
+                 if isinstance(val, Literal):
+                     env[var] = val.val
+                 else:
+                     env[var] = val
+         else:
+             new_eqns.append(eqn)
+
+     # now that we've evaluated everything we can, inline all the constants
+     out_eqns = []
+     for eqn in new_eqns:
+         eqn = eqn.replace(invars=tuple(read_or_self(var) for var in eqn.invars))
+         out_eqns.append(eqn)
+
+     invars_still_used = IdentitySet()
+     for eqn in out_eqns:
+         for var in eqn.invars:
+             invars_still_used.add(var)
+
+     invars = tuple(var for var in jaxpr.invars if var in invars_still_used)
+
+     # substitute any constants into the outvars
+     outvars = tuple(read_or_self(var) for var in jaxpr.outvars)
+
+     return jaxpr.replace(eqns=out_eqns, outvars=outvars, invars=invars, debug_info=None)
+
+
+ def _eval_eqn(eqn, vals) -> Union[Jaxpr, tuple, list, jax.Array]:
+     if eqn.primitive.name == "closed_call":
+         assert eqn.primitive.call_primitive
+         assert not eqn.primitive.map_primitive
+
+         out = _partial_eval_jaxpr(
+             eqn.params['call_jaxpr'].jaxpr,
+             {
+                 var: val
+                 for var, val in
+                 zip(eqn.params['call_jaxpr'].jaxpr.invars, vals)
+             }
+         )
+     else:
+         # all other primitives (including scan) are evaluated directly
+         out = eqn.primitive.bind(*vals, **eqn.params)
+     return out
+
+
+ _constant_fold_blacklist = {
+     'broadcast_in_dim',
+     'broadcast',
+ }
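+
+ # Rationale (inferred, not stated in the source): folding broadcasts eagerly
+ # would materialize large constant arrays inside the jaxpr, so they are kept
+ # symbolic even when their inputs are constant.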