brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +169 -58
- brainstate/_compatible_import.py +340 -148
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +45 -55
- brainstate/_state.py +1652 -1605
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -563
- brainstate/environ_test.py +1223 -62
- brainstate/graph/__init__.py +22 -29
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1433 -365
- brainstate/mixin_test.py +1017 -77
- brainstate/nn/__init__.py +137 -135
- brainstate/nn/_activations.py +1100 -808
- brainstate/nn/_activations_test.py +354 -331
- brainstate/nn/_collective_ops.py +633 -514
- brainstate/nn/_collective_ops_test.py +774 -43
- brainstate/nn/_common.py +226 -178
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +2010 -501
- brainstate/nn/_conv_test.py +849 -238
- brainstate/nn/_delay.py +575 -588
- brainstate/nn/_delay_test.py +243 -238
- brainstate/nn/_dropout.py +618 -426
- brainstate/nn/_dropout_test.py +477 -100
- brainstate/nn/_dynamics.py +1267 -1343
- brainstate/nn/_dynamics_test.py +67 -78
- brainstate/nn/_elementwise.py +1298 -1119
- brainstate/nn/_elementwise_test.py +830 -169
- brainstate/nn/_embedding.py +408 -58
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
- brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
- brainstate/nn/_exp_euler.py +254 -92
- brainstate/nn/_exp_euler_test.py +377 -35
- brainstate/nn/_linear.py +744 -424
- brainstate/nn/_linear_test.py +475 -107
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +384 -377
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -975
- brainstate/nn/_normalizations_test.py +699 -73
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +2239 -1177
- brainstate/nn/_poolings_test.py +953 -217
- brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +216 -89
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +809 -553
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
- brainstate/random/__init__.py +270 -24
- brainstate/random/_rand_funs.py +3938 -3616
- brainstate/random/_rand_funs_test.py +640 -567
- brainstate/random/_rand_seed.py +675 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1617 -1409
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
- brainstate/{augment → transform}/_autograd.py +1025 -778
- brainstate/{augment → transform}/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +220 -220
- brainstate/{compile → transform}/_error_if.py +94 -92
- brainstate/{compile → transform}/_error_if_test.py +52 -52
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +38 -38
- brainstate/{compile → transform}/_jit.py +399 -346
- brainstate/{compile → transform}/_jit_test.py +143 -143
- brainstate/{compile → transform}/_loop_collect_return.py +675 -536
- brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
- brainstate/{compile → transform}/_loop_no_collection.py +283 -184
- brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +255 -202
- brainstate/{augment → transform}/_random.py +171 -151
- brainstate/{compile → transform}/_unvmap.py +256 -159
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +837 -304
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +27 -50
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +945 -469
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +910 -523
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
- brainstate-0.2.1.dist-info/RECORD +111 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,2016 @@
|
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""
|
17
|
+
This module implements utilities for creating a JAX Jaxpr from a given function while tracking the states that are read and
|
18
|
+
written by the function. These state transformations are foundational for the BrainCore library. These utilities
|
19
|
+
include two basic building blocks: the `StatefulFunction` class and the `make_jaxpr` function.
|
20
|
+
|
21
|
+
|
22
|
+
``StatefulFunction``
|
23
|
+
--------------------
|
24
|
+
|
25
|
+
The module provides a class called ``StatefulFunction`` that wraps a function and provides methods to get the
|
26
|
+
JAX Jaxpr, the output shapes, the states that are read and written by the function, and the output of the function.
|
27
|
+
The class provides the following methods:
|
28
|
+
|
29
|
+
- `make_jaxpr`: creates the JAX Jaxpr of the function.
|
30
|
+
- `jaxpr_call`: calls the function at the JAX Jaxpr level.
|
31
|
+
- `jaxpr_call_without_states`: calls the function at the JAX Jaxpr level without considering the states.
|
32
|
+
- `get_states`: returns the states that are read and written by the function.
|
33
|
+
- `get_read_states`: returns the states that are read by the function.
|
34
|
+
- `get_write_states`: returns the states that are written by the function.
|
35
|
+
- `get_static_args`: returns the static arguments from the arguments.
|
36
|
+
- `compile_and_get_states_by_static_args`: compiles the function and returns the states that are read and
|
37
|
+
written by the function.
|
38
|
+
- `get_jaxpr`: returns the JAX Jaxpr of the function.
|
39
|
+
- `get_out_shapes`: returns the output shapes of the function.
|
40
|
+
- `get_out_treedef`: returns the output tree of the function.
|
41
|
+
|
42
|
+
``make_jaxpr``
|
43
|
+
--------------
|
44
|
+
|
45
|
+
The module provides a function called `make_jaxpr` that creates a function that produces its JAX Jaxpr given example
|
46
|
+
arguments. The function returns a wrapped version of the function that when applied to example arguments returns a
|
47
|
+
`ClosedJaxpr` representation of the function on those arguments. If the argument `return_shape` is `True`, then the
|
48
|
+
returned function instead returns a pair where the first element is the `ClosedJaxpr` representation of the function
|
49
|
+
and the second element is a pytree representing the structure, shape, dtypes, and named shapes of the output of the
|
50
|
+
function.
|
51
|
+
|
52
|
+
"""
|
53
|
+
|
54
|
+
import functools
|
55
|
+
import inspect
|
56
|
+
import operator
|
57
|
+
import threading
|
58
|
+
from collections import OrderedDict, defaultdict
|
59
|
+
from collections.abc import Hashable, Iterable, Sequence
|
60
|
+
from collections.abc import MutableSet
|
61
|
+
from contextlib import ExitStack
|
62
|
+
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
63
|
+
|
64
|
+
import jax
|
65
|
+
import jax.numpy as jnp
|
66
|
+
from jax._src import source_info_util
|
67
|
+
from jax._src.linear_util import annotate
|
68
|
+
from jax._src.traceback_util import api_boundary
|
69
|
+
from jax._src.util import memoize
|
70
|
+
from jax.api_util import shaped_abstractify
|
71
|
+
from jax.extend.linear_util import transformation_with_aux
|
72
|
+
from jax.interpreters import partial_eval as pe
|
73
|
+
|
74
|
+
from brainstate._compatible_import import (
|
75
|
+
ClosedJaxpr, extend_axis_env_nd, safe_map, safe_zip, unzip2, wraps, wrap_init,
|
76
|
+
Literal, Var, Jaxpr, make_iota, to_elt, BatchTracer, BatchTrace,
|
77
|
+
)
|
78
|
+
from brainstate._state import State, StateTraceStack
|
79
|
+
from brainstate._utils import set_module_as
|
80
|
+
from brainstate.random import RandomState
|
81
|
+
from brainstate.typing import Filter, PyTree
|
82
|
+
from brainstate.util import PrettyObject
|
83
|
+
from brainstate.util.filter import to_predicate
|
84
|
+
|
85
|
+
# Alias for axis names: any hashable value can serve as an axis name.
AxisName = Hashable

# Names exported by ``from ... import *`` for this module.
__all__ = [
    "StatefulFunction",
    "make_jaxpr",
    "StatefulMapping",
]
|
92
|
+
|
93
|
+
|
94
|
+
class hashabledict(dict):
    """
    A ``dict`` subclass that can be hashed, e.g. to be used inside a cache key.

    The hash is derived from the current items, so two equal dictionaries
    hash equally regardless of insertion order. All values must themselves be
    hashable. The dictionary remains mutable, so it must not be modified while
    it is stored in a hash-based container.
    """

    def __hash__(self):
        # ``frozenset`` gives an order-independent hash without sorting, so
        # keys of mutually incomparable types (e.g. ``1`` and ``'a'``) work.
        return hash(frozenset(self.items()))
|
97
|
+
|
98
|
+
|
99
|
+
class _BoundedCache:
    """
    A thread-safe LRU cache with bounded size.

    Items are stored in an ``OrderedDict`` ordered from least to most recently
    used. A successful :meth:`get` and :meth:`replace` move the entry to the
    most-recently-used end. When :meth:`set` inserts into a full cache, the
    least recently used entry is evicted first. Every public method acquires
    an internal re-entrant lock, so individual operations are thread-safe.

    Parameters
    ----------
    maxsize : int, default 128
        Maximum number of items to store in the cache. Must be at least 1.

    Raises
    ------
    ValueError
        If ``maxsize`` is smaller than 1.
    """

    def __init__(self, maxsize: int = 128):
        if maxsize < 1:
            # With maxsize <= 0, ``set`` would try to evict from an empty dict
            # and fail with an opaque KeyError; reject the value up front.
            raise ValueError(f"maxsize must be at least 1, got {maxsize}.")
        self._cache = OrderedDict()
        self._maxsize = maxsize
        self._lock = threading.RLock()
        self._hits = 0
        self._misses = 0

    def get(
        self,
        key: Any,
        default: Any = None,
        raise_on_miss: bool = False,
        error_context: str = "item"
    ) -> Any:
        """
        Get an item from the cache.

        A hit marks the entry as most recently used. Hits and misses are
        counted for :meth:`get_stats`.

        Parameters
        ----------
        key : Any
            The cache key.
        default : Any, optional
            The default value to return if the key is not found.
        raise_on_miss : bool, optional
            If True, raise a detailed ValueError when the key is not found.
        error_context : str, optional
            Context description for the error message (e.g., "Function", "JAX expression").

        Returns
        -------
        Any
            The cached value or the default value.

        Raises
        ------
        ValueError
            If raise_on_miss is True and the key is not found.
        """
        with self._lock:
            if key in self._cache:
                self._cache.move_to_end(key)
                self._hits += 1
                return self._cache[key]
            self._misses += 1

            if raise_on_miss:
                available_keys = list(self._cache.keys())
                error_msg = [
                    f"{error_context} not compiled for the requested cache key.",
                    "",
                    "Requested key:",
                    f" {key}",
                    "",
                    # Single braces so the actual number of keys is interpolated.
                    f"Available {len(available_keys)} keys:",
                ]
                if available_keys:
                    for i, k in enumerate(available_keys, 1):
                        error_msg.append(f" [{i}] {k}")
                else:
                    error_msg.append(" (none - not compiled yet)")
                error_msg.append("")
                error_msg.append("Call make_jaxpr() first with matching arguments.")
                raise ValueError("\n".join(error_msg))

            return default

    def set(self, key: Any, value: Any) -> None:
        """
        Insert a new item into the cache.

        If the cache is full, the least recently used item is evicted first.

        Parameters
        ----------
        key : Any
            The cache key.
        value : Any
            The value to cache.

        Raises
        ------
        ValueError
            If the key already exists in the cache.
        """
        with self._lock:
            if key in self._cache:
                raise ValueError(
                    f"Cache key already exists: {key}. "
                    "Cannot overwrite existing cached value. "
                    "Clear the cache first if you need to recompile."
                )
            if len(self._cache) >= self._maxsize:
                # OrderedDict front == least recently used entry.
                self._cache.popitem(last=False)
            self._cache[key] = value

    def pop(self, key: Any, default: Any = None) -> Any:
        """
        Remove and return an item from the cache.

        Parameters
        ----------
        key : Any
            The cache key to remove.
        default : Any, optional
            The default value to return if the key is not found.

        Returns
        -------
        Any
            The cached value or the default value if the key is not found.
        """
        with self._lock:
            if key in self._cache:
                return self._cache.pop(key)
            return default

    def replace(self, key: Any, value: Any) -> None:
        """
        Replace an existing item in the cache and mark it as most recently used.

        Parameters
        ----------
        key : Any
            The cache key to replace.
        value : Any
            The new value to cache.

        Raises
        ------
        KeyError
            If the key does not exist in the cache.
        """
        with self._lock:
            if key not in self._cache:
                raise KeyError(
                    f"Cache key does not exist: {key}. "
                    "Cannot replace non-existent cached value. "
                    "Use set() to add a new cache entry."
                )
            self._cache[key] = value
            self._cache.move_to_end(key)

    def __contains__(self, key: Any) -> bool:
        """
        Check if a key exists in the cache (does not affect LRU order or stats).

        Parameters
        ----------
        key : Any
            The cache key to check.

        Returns
        -------
        bool
            True if the key exists in the cache, False otherwise.
        """
        with self._lock:
            return key in self._cache

    def __len__(self) -> int:
        """
        Get the number of items in the cache.

        Returns
        -------
        int
            The number of items currently in the cache.
        """
        with self._lock:
            return len(self._cache)

    def clear(self) -> None:
        """
        Clear all items from the cache and reset statistics.

        This method removes all cached items and resets hit/miss counters to zero.
        """
        with self._lock:
            self._cache.clear()
            self._hits = 0
            self._misses = 0

    def keys(self):
        """
        Return all keys in the cache.

        Returns
        -------
        list
            A list of all keys currently in the cache, least recently used first.
        """
        with self._lock:
            return list(self._cache.keys())

    def get_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns
        -------
        dict
            A dictionary with cache statistics including:

            - 'size': Current number of items in cache
            - 'maxsize': Maximum cache size
            - 'hits': Number of cache hits
            - 'misses': Number of cache misses
            - 'hit_rate': Hit rate percentage (0-100)
        """
        with self._lock:
            total = self._hits + self._misses
            hit_rate = (self._hits / total * 100) if total > 0 else 0.0
            return {
                'size': len(self._cache),
                'maxsize': self._maxsize,
                'hits': self._hits,
                'misses': self._misses,
                'hit_rate': hit_rate,
            }
|
329
|
+
|
330
|
+
|
331
|
+
def _ensure_str(x: str) -> str:
    """Return ``x`` unchanged, raising ``TypeError`` if it is not a ``str``."""
    if isinstance(x, str):
        return x
    raise TypeError(f"argument is not a string: {x}")
|
335
|
+
|
336
|
+
|
337
|
+
def _ensure_index_tuple(x: Any) -> tuple[int, ...]:
    """
    Normalize ``x`` into a tuple of integer indices.

    ``x`` must be concrete (not a tracer). A single integer-like value becomes a
    one-element tuple; otherwise ``x`` is iterated and every element is
    converted with ``operator.index``.
    """
    concrete = jax.core.concrete_or_error(None, x, "expected a static index or sequence of indices.")
    try:
        single = operator.index(concrete)
    except TypeError:
        # Not a single integer: treat it as an iterable of integer-like values.
        return tuple(operator.index(item) for item in concrete)
    return (single,)
|
344
|
+
|
345
|
+
|
346
|
+
def _ensure_str_tuple(x: str | Iterable[str]) -> tuple[str, ...]:
|
347
|
+
"""Convert x to a tuple of strings."""
|
348
|
+
if isinstance(x, str):
|
349
|
+
return (x,)
|
350
|
+
else:
|
351
|
+
return tuple(safe_map(_ensure_str, x))
|
352
|
+
|
353
|
+
|
354
|
+
def _jax_v04_new_arg_fn(frame, trace, aval):
    """
    Create a new input tracer for ``aval`` and register it on ``frame``.

    Modified from ``jax.interpreters.partial_eval.DynamicJaxprTrace.new_arg()``.
    NOTE(review): the ``_jax_v04`` prefix suggests this relies on JAX 0.4.x
    internals (``frame.tracers``, ``frame.tracer_to_var``, ``frame.newvar``);
    confirm which JAX versions still provide them.

    Parameters
    ----------
    frame
        The jaxpr-building frame. Its ``tracers``, ``tracer_to_var`` and
        ``invars`` are updated in place.
    trace
        The trace the new tracer belongs to (presumably a
        ``pe.DynamicJaxprTrace``).
    aval
        The abstract value of the new argument.

    Returns
    -------
    pe.DynamicJaxprTracer
        The tracer representing the new argument.
    """
    tracer = pe.DynamicJaxprTracer(trace, aval, source_info_util.current())
    # Record the tracer on the frame and map its id to a freshly created jaxpr variable.
    frame.tracers.append(tracer)
    frame.tracer_to_var[id(tracer)] = var = frame.newvar(aval)
    # Appending to ``invars`` adds the variable as an input of the jaxpr being built.
    frame.invars.append(var)
    return tracer
|
373
|
+
|
374
|
+
|
375
|
+
def _jax_v04_new_jax_trace():
    """
    Return the innermost jaxpr frame and a new trace at the current sublevel.

    Reads JAX's thread-local trace stack and takes its last (innermost) main
    trace and that trace's last jaxpr frame. It then builds a
    ``pe.DynamicJaxprTrace`` on that main trace at ``jax.core.cur_sublevel()``.

    NOTE(review): these are private JAX APIs; judging by the ``_jax_v04``
    prefix they target JAX 0.4.x. The ``[-1]`` indexing presumably requires an
    active jaxpr trace when called. Confirm both.

    Returns
    -------
    tuple
        ``(frame, trace)``; presumably intended as the ``frame`` and ``trace``
        arguments of ``_jax_v04_new_arg_fn``.
    """
    # Innermost main trace on the thread-local trace stack.
    main = jax.core.thread_local_state.trace_state.trace_stack.stack[-1]
    frame = main.jaxpr_stack[-1]
    trace = pe.DynamicJaxprTrace(main, jax.core.cur_sublevel())
    return frame, trace
|
380
|
+
|
381
|
+
|
382
|
+
class StatefulFunction(PrettyObject):
|
383
|
+
"""
|
384
|
+
A wrapper class for functions that tracks state reads and writes during execution.
|
385
|
+
|
386
|
+
This class wraps a function to enable state management in JAX programs by tracking
|
387
|
+
which states are read from and written to during function execution. It provides
|
388
|
+
methods to compile the function into JAX's intermediate representation (jaxpr),
|
389
|
+
inspect state usage, and execute the function with proper state handling.
|
390
|
+
|
391
|
+
When you define a function:
|
392
|
+
|
393
|
+
.. code-block:: python
|
394
|
+
|
395
|
+
>>> state = brainstate.State(1.)
|
396
|
+
>>> def f(x):
|
397
|
+
... # Your function logic here
|
398
|
+
... y = x * 2 + state.value
|
399
|
+
... state.value = y
|
400
|
+
|
401
|
+
Calling ``sf = StatefulFunction(f)`` creates a stateful version of ``f``. You can
|
402
|
+
then call it directly; such calls are also compatible with JIT:
|
403
|
+
|
404
|
+
.. code-block:: python
|
405
|
+
|
406
|
+
>>> sf = brainstate.transform.StatefulFunction(f)
|
407
|
+
>>> out = sf(x) # Automatically compiles and executes
|
408
|
+
|
409
|
+
Parameters
|
410
|
+
----------
|
411
|
+
fun : callable
|
412
|
+
The function whose ``jaxpr`` is to be computed. Its positional
|
413
|
+
arguments and return value should be arrays, scalars, or standard Python
|
414
|
+
containers (tuple/list/dict) thereof.
|
415
|
+
static_argnums : int or iterable of int, optional
|
416
|
+
Indices of positional arguments to treat as static (known at compile time).
|
417
|
+
See :py:func:`jax.jit` for details. Default is ().
|
418
|
+
static_argnames : str or iterable of str, optional
|
419
|
+
Names of keyword arguments to treat as static (known at compile time).
|
420
|
+
See :py:func:`jax.jit` for details. Default is ().
|
421
|
+
axis_env : sequence of tuple, optional
|
422
|
+
A sequence of pairs where the first element is an axis name and the second
|
423
|
+
element is a positive integer representing the size of the mapped axis with
|
424
|
+
that name. This parameter is useful when lowering functions that involve
|
425
|
+
parallel communication collectives, and it specifies the axis name/size
|
426
|
+
environment that would be set up by applications of :py:func:`jax.pmap`.
|
427
|
+
Default is None.
|
428
|
+
abstracted_axes : pytree, optional
|
429
|
+
A pytree with the same structure as the input arguments to ``fun``. The
|
430
|
+
leaves of the pytree can be either None or a dict with axis names as keys
|
431
|
+
and integers as values. If the leaf is None, then the corresponding axis
|
432
|
+
is not abstracted. If the leaf is a dict, then the corresponding axis is
|
433
|
+
abstracted, and the dict specifies the axis name and size. The abstracted
|
434
|
+
axes are used to infer the input type of the function. If None, then all
|
435
|
+
axes are abstracted. Default is None.
|
436
|
+
name : str, optional
|
437
|
+
Name for the stateful function. Default is None.
|
438
|
+
return_only_write : bool, optional
|
439
|
+
If True, only return states that were written to during execution
|
440
|
+
(not just read). This can reduce memory usage when you only care
|
441
|
+
about modified states. Default is True.
|
442
|
+
|
443
|
+
Attributes
|
444
|
+
----------
|
445
|
+
fun : callable
|
446
|
+
The wrapped function.
|
447
|
+
static_argnums : tuple of int
|
448
|
+
Indices of static positional arguments.
|
449
|
+
static_argnames : tuple of str
|
450
|
+
Names of static keyword arguments.
|
451
|
+
axis_env : sequence of tuple or None
|
452
|
+
Axis environment for parallel operations.
|
453
|
+
abstracted_axes : pytree or None
|
454
|
+
Abstract axes specification.
|
455
|
+
name : str or None
|
456
|
+
Name identifier for the function.
|
457
|
+
return_only_write : bool
|
458
|
+
Whether to return only written states.
|
459
|
+
|
460
|
+
Examples
|
461
|
+
--------
|
462
|
+
Basic usage with state management:
|
463
|
+
|
464
|
+
.. code-block:: python
|
465
|
+
|
466
|
+
>>> import brainstate
|
467
|
+
>>> import jax.numpy as jnp
|
468
|
+
>>>
|
469
|
+
>>> # Create a state
|
470
|
+
>>> state = brainstate.State(jnp.array([1.0, 2.0]))
|
471
|
+
>>>
|
472
|
+
>>> def f(x):
|
473
|
+
... state.value += x
|
474
|
+
... return state.value * 2
|
475
|
+
>>>
|
476
|
+
>>> # Create a stateful function
|
477
|
+
>>> sf = brainstate.transform.StatefulFunction(f)
|
478
|
+
>>>
|
479
|
+
>>> # Compile and get jaxpr
|
480
|
+
>>> x = jnp.array([0.5, 0.5])
|
481
|
+
>>> sf.make_jaxpr(x)
|
482
|
+
>>>
|
483
|
+
>>> # Get states that are read/written
|
484
|
+
>>> cache_key = sf.get_arg_cache_key(x)
|
485
|
+
>>> states = sf.get_states_by_cache(cache_key)
|
486
|
+
>>> read_states = sf.get_read_states_by_cache(cache_key)
|
487
|
+
>>> write_states = sf.get_write_states_by_cache(cache_key)
|
488
|
+
|
489
|
+
Using with static arguments:
|
490
|
+
|
491
|
+
.. code-block:: python
|
492
|
+
|
493
|
+
>>> def g(x, n):
|
494
|
+
... state.value = state.value ** n
|
495
|
+
... return state.value
|
496
|
+
>>>
|
497
|
+
>>> sf_static = brainstate.transform.StatefulFunction(
|
498
|
+
... g, static_argnums=(1,)
|
499
|
+
... )
|
500
|
+
>>> sf_static.make_jaxpr(x, 2)
|
501
|
+
|
502
|
+
Automatic state management:
|
503
|
+
|
504
|
+
.. code-block:: python
|
505
|
+
|
506
|
+
>>> # Execute with automatic state handling
|
507
|
+
>>> result = sf.jaxpr_call_auto(x)
|
508
|
+
>>> print(state.value) # State is automatically updated
|
509
|
+
|
510
|
+
See Also
|
511
|
+
--------
|
512
|
+
make_jaxpr : Function to create jaxpr from a function.
|
513
|
+
brainstate.State : The state container class.
|
514
|
+
|
515
|
+
Notes
|
516
|
+
-----
|
517
|
+
This class maintains internal thread-safe caches for compiled jaxprs, output
|
518
|
+
shapes, and state traces. The cache size is bounded at 128 entries per cache
|
519
|
+
type. Use ``clear_cache()`` to manually clear the caches if needed.
|
520
|
+
|
521
|
+
State objects should not be passed as direct inputs or outputs to the wrapped
|
522
|
+
function. Instead, they should be accessed within the function body, and the
|
523
|
+
class will automatically track their usage.
|
524
|
+
"""
|
525
|
+
__module__ = "brainstate.transform"
|
526
|
+
|
527
|
+
def __init__(
    self,
    fun: Callable,
    static_argnums: Union[int, Iterable[int]] = (),
    static_argnames: Union[str, Iterable[str]] = (),
    axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
    abstracted_axes: Optional[Any] = None,
    name: Optional[str] = None,
    return_only_write: bool = True,
):
    """
    Store the wrapped function and its tracing options, and create the
    per-instance caches (each bounded to 128 entries).

    ``static_argnums`` / ``static_argnames`` of ``None`` are treated as empty.
    Otherwise they are normalized to tuples of ints / strings.
    """
    # -- user-facing configuration --
    self.fun = fun
    if static_argnums is None:
        self.static_argnums = tuple()
    else:
        self.static_argnums = _ensure_index_tuple(static_argnums)
    if static_argnames is None:
        self.static_argnames = tuple()
    else:
        self.static_argnames = _ensure_str_tuple(static_argnames)
    self.axis_env = axis_env
    self.abstracted_axes = abstracted_axes
    self.name = name
    self.return_only_write = return_only_write

    # -- internal, thread-safe bounded caches looked up by argument cache key --
    cache_size = 128
    self._cached_jaxpr = _BoundedCache(maxsize=cache_size)
    self._cached_out_shapes = _BoundedCache(maxsize=cache_size)
    self._cached_jaxpr_out_tree = _BoundedCache(maxsize=cache_size)
    self._cached_state_trace = _BoundedCache(maxsize=cache_size)
    self._cache_lock = threading.RLock()
|
552
|
+
|
553
|
+
def __pretty_repr_item__(self, k, v):
    """
    Select which attributes appear in the pretty representation.

    Returns ``None`` for private attributes (names starting with ``_``), which
    presumably tells ``PrettyObject`` to hide them. Otherwise returns the
    ``(k, v)`` pair unchanged.
    """
    is_private = k.startswith('_')
    return None if is_private else (k, v)
|
557
|
+
|
558
|
+
def get_jaxpr_by_cache(self, cache_key: Hashable) -> ClosedJaxpr:
    """
    Look up the compiled ``ClosedJaxpr`` stored under ``cache_key``.

    Parameters
    ----------
    cache_key : Hashable
        The key under which the jaxpr was cached.

    Returns
    -------
    ClosedJaxpr
        The JAX Jaxpr representation of the function.

    Raises
    ------
    ValueError
        If nothing has been compiled for ``cache_key``.
    """
    jaxpr_cache = self._cached_jaxpr
    return jaxpr_cache.get(
        cache_key,
        raise_on_miss=True,
        error_context="JAX expression",
    )
|
578
|
+
|
579
|
+
def get_jaxpr(self, *args, compile_if_miss: bool = True, **kwargs) -> ClosedJaxpr:
    """
    Return the ``ClosedJaxpr`` matching the given call arguments.

    Parameters
    ----------
    *args
        Positional arguments of the wrapped function.
    compile_if_miss : bool, optional
        Whether to compile the function if the cache key is not found.
        Forwarded to ``get_arg_cache_key``. Default is True.
    **kwargs
        Keyword arguments of the wrapped function.

    Returns
    -------
    ClosedJaxpr
        The JAX Jaxpr representation of the function.
    """
    key = self.get_arg_cache_key(*args, compile_if_miss=compile_if_miss, **kwargs)
    return self.get_jaxpr_by_cache(key)
|
599
|
+
|
600
|
+
def get_out_shapes_by_cache(self, cache_key: Hashable) -> PyTree:
    """
    Look up the cached output shapes stored under ``cache_key``.

    Parameters
    ----------
    cache_key : Hashable
        The key under which the compilation results were cached.

    Returns
    -------
    PyTree
        The output shapes of the function.

    Raises
    ------
    ValueError
        If nothing has been compiled for ``cache_key``.
    """
    shape_cache = self._cached_out_shapes
    return shape_cache.get(
        cache_key,
        raise_on_miss=True,
        error_context="Output shapes",
    )
|
620
|
+
|
621
|
+
def get_out_shapes(self, *args, compile_if_miss: bool = True, **kwargs) -> PyTree:
    """
    Return the output shapes matching the given call arguments.

    Parameters
    ----------
    *args
        Positional arguments of the wrapped function.
    compile_if_miss : bool, optional
        Whether to compile the function if the cache key is not found.
        Forwarded to ``get_arg_cache_key``. Default is True.
    **kwargs
        Keyword arguments of the wrapped function.

    Returns
    -------
    PyTree
        The output shapes of the function.
    """
    key = self.get_arg_cache_key(*args, compile_if_miss=compile_if_miss, **kwargs)
    return self.get_out_shapes_by_cache(key)
|
641
|
+
|
642
|
+
def get_out_treedef_by_cache(self, cache_key: Hashable) -> PyTree:
    """
    Look up the cached output tree definition for a compiled call.

    Parameters
    ----------
    cache_key : Hashable
        Key produced by :meth:`get_arg_cache_key`.

    Returns
    -------
    PyTree
        The tree structure (as returned by ``jax.tree.structure``) of the
        ``(out_shapes, state_shapes)`` pair recorded by :meth:`make_jaxpr`.
        It is used to unflatten the outputs of the compiled jaxpr.

    Raises
    ------
    ValueError
        If the function has not been compiled for the given cache key.
    """
    tree_cache = self._cached_jaxpr_out_tree
    return tree_cache.get(cache_key, raise_on_miss=True, error_context="Output tree")
|
662
|
+
|
663
|
+
def get_out_treedef(self, *args, compile_if_miss: bool = True, **kwargs) -> PyTree:
    """
    Return the cached output tree definition for the given call arguments.

    Parameters
    ----------
    *args
        Positional arguments, exactly as they would be passed to the function.
    compile_if_miss : bool, optional
        Compile the function first when no cache entry exists for these
        arguments. Default is True.
    **kwargs
        Keyword arguments, exactly as they would be passed to the function.

    Returns
    -------
    PyTree
        The tree structure of the ``(output, state_values)`` pair.
    """
    key = self.get_arg_cache_key(*args, compile_if_miss=compile_if_miss, **kwargs)
    return self.get_out_treedef_by_cache(key)
|
683
|
+
|
684
|
+
def get_state_trace_by_cache(self, cache_key: Hashable) -> StateTraceStack:
    """
    Look up the state trace recorded while compiling a call.

    Parameters
    ----------
    cache_key : Hashable
        Key produced by :meth:`get_arg_cache_key`.

    Returns
    -------
    StateTraceStack
        The trace stack holding every state touched by the function.

    Raises
    ------
    ValueError
        If the function has not been compiled for the given cache key.
    """
    trace_cache = self._cached_state_trace
    return trace_cache.get(cache_key, raise_on_miss=True, error_context="State trace")
|
704
|
+
|
705
|
+
def get_state_trace(self, *args, compile_if_miss: bool = True, **kwargs) -> StateTraceStack:
    """
    Return the state trace for the given call arguments.

    Parameters
    ----------
    *args
        Positional arguments, exactly as they would be passed to the function.
    compile_if_miss : bool, optional
        Compile the function first when no cache entry exists for these
        arguments. Default is True.
    **kwargs
        Keyword arguments, exactly as they would be passed to the function.

    Returns
    -------
    StateTraceStack
        The state trace of the function.
    """
    key = self.get_arg_cache_key(*args, compile_if_miss=compile_if_miss, **kwargs)
    return self.get_state_trace_by_cache(key)
|
725
|
+
|
726
|
+
def get_states_by_cache(self, cache_key: Hashable) -> Tuple[State, ...]:
    """
    Return every state the compiled function reads from or writes to.

    Parameters
    ----------
    cache_key : Hashable
        Key produced by :meth:`get_arg_cache_key`.

    Returns
    -------
    Tuple[State, ...]
        The traced states, in trace order.

    Raises
    ------
    ValueError
        If the function has not been compiled for the given cache key.
    """
    trace = self.get_state_trace_by_cache(cache_key)
    return tuple(trace.states)
|
746
|
+
|
747
|
+
def get_states(self, *args, compile_if_miss: bool = True, **kwargs) -> Tuple[State, ...]:
    """
    Return the states read and written by the function for these arguments.

    Parameters
    ----------
    *args
        Positional arguments, exactly as they would be passed to the function.
    compile_if_miss : bool, optional
        Compile the function first when no cache entry exists for these
        arguments. Default is True.
    **kwargs
        Keyword arguments, exactly as they would be passed to the function.

    Returns
    -------
    Tuple[State, ...]
        The states that are read and written by the function.
    """
    key = self.get_arg_cache_key(*args, compile_if_miss=compile_if_miss, **kwargs)
    return self.get_states_by_cache(key)
|
767
|
+
|
768
|
+
def get_read_states_by_cache(self, cache_key: Hashable) -> Tuple[State, ...]:
    """
    Return the states read by the compiled function.

    Parameters
    ----------
    cache_key : Hashable
        Key produced by :meth:`get_arg_cache_key`.

    Returns
    -------
    Tuple[State, ...]
        The states that are read by the function, as reported by the trace's
        ``get_read_states()``.
    """
    trace = self.get_state_trace_by_cache(cache_key)
    return trace.get_read_states()
|
783
|
+
|
784
|
+
def get_read_states(self, *args, compile_if_miss: bool = True, **kwargs) -> Tuple[State, ...]:
    """
    Return the states read by the function for these arguments.

    Parameters
    ----------
    *args
        Positional arguments, exactly as they would be passed to the function.
    compile_if_miss : bool, optional
        Compile the function first when no cache entry exists for these
        arguments. Default is True.
    **kwargs
        Keyword arguments, exactly as they would be passed to the function.

    Returns
    -------
    Tuple[State, ...]
        The states that are read by the function.
    """
    key = self.get_arg_cache_key(*args, compile_if_miss=compile_if_miss, **kwargs)
    return self.get_read_states_by_cache(key)
|
804
|
+
|
805
|
+
def get_write_states_by_cache(self, cache_key: Hashable) -> Tuple[State, ...]:
    """
    Return the states written by the compiled function.

    Parameters
    ----------
    cache_key : Hashable
        Key produced by :meth:`get_arg_cache_key`.

    Returns
    -------
    Tuple[State, ...]
        The states that are written by the function, as reported by the
        trace's ``get_write_states()``.
    """
    trace = self.get_state_trace_by_cache(cache_key)
    return trace.get_write_states()
|
820
|
+
|
821
|
+
def get_write_states(self, *args, compile_if_miss: bool = True, **kwargs) -> Tuple[State, ...]:
    """
    Return the states written by the function for these arguments.

    Parameters
    ----------
    *args
        Positional arguments, exactly as they would be passed to the function.
    compile_if_miss : bool, optional
        Compile the function first when no cache entry exists for these
        arguments. Default is True.
    **kwargs
        Keyword arguments, exactly as they would be passed to the function.

    Returns
    -------
    Tuple[State, ...]
        The states that are written by the function.
    """
    key = self.get_arg_cache_key(*args, compile_if_miss=compile_if_miss, **kwargs)
    return self.get_write_states_by_cache(key)
|
841
|
+
|
842
|
+
def _check_input_ouput(self, x):
    """
    Reject a ``State`` instance used as a transformation input or output.

    Non-``State`` leaves are accepted silently. For a ``State`` leaf, the
    error is reported through the state's ``raise_error_with_source_info``.
    """
    if not isinstance(x, State):
        return
    x.raise_error_with_source_info(
        ValueError(
            'Inputs/outputs for brainstate transformations cannot be an instance of State. '
            f'But we got {x}'
        )
    )
|
850
|
+
|
851
|
+
def get_arg_cache_key(self, *args, compile_if_miss: bool = False, **kwargs) -> hashabledict:
    """
    Build the hashable cache key describing a call's arguments.

    Static arguments (by position in ``self.static_argnums`` or by name in
    ``self.static_argnames``) enter the key verbatim. Every other argument
    is first passed through ``shaped_abstractify``, so the key presumably
    depends only on its abstract shape/dtype.

    Parameters
    ----------
    *args
        The positional arguments to the function.
    compile_if_miss : bool, optional
        When True and no state trace is cached for the resulting key, call
        :meth:`make_jaxpr` with these arguments. Default is False.
    **kwargs
        The keyword arguments to the function.

    Returns
    -------
    hashabledict
        A hashable dictionary with the fields ``'static_args'``,
        ``'dyn_args'``, ``'static_kwargs'`` and ``'dyn_kwargs'``.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>>
        >>> def f(x, n):
        ...     return x ** n
        >>>
        >>> sf = brainstate.transform.StatefulFunction(
        ...     f, static_argnums=(1,)
        ... )
        >>> cache_key = sf.get_arg_cache_key(jnp.array([1.0, 2.0]), 2)
    """
    # Positional arguments: static ones verbatim, dynamic ones abstracted.
    static_pos = [arg for idx, arg in enumerate(args) if idx in self.static_argnums]
    dynamic_pos = jax.tree.map(
        shaped_abstractify,
        [arg for idx, arg in enumerate(args) if idx not in self.static_argnums],
    )

    # Keyword arguments, visited in sorted-name order so that the key does
    # not depend on the order the caller wrote them in.
    static_named, dynamic_named = [], []
    for kw in sorted(kwargs):
        value = kwargs[kw]
        if kw in self.static_argnames:
            static_named.append((kw, value))
        else:
            dynamic_named.append((kw, jax.tree.map(shaped_abstractify, value)))

    cache_key = hashabledict(
        static_args=make_hashable(tuple(static_pos)),
        dyn_args=make_hashable(tuple(dynamic_pos)),
        static_kwargs=make_hashable(static_named),
        dyn_kwargs=make_hashable(dynamic_named),
    )

    if cache_key not in self._cached_state_trace and compile_if_miss:
        self.make_jaxpr(*args, **kwargs)

    return cache_key
|
919
|
+
|
920
|
+
def clear_cache(self) -> None:
    """
    Empty every compilation cache held by this function.

    The jaxpr, output-shape, output-tree and state-trace caches are all
    cleared, so the next call recompiles. This is also useful for freeing
    memory.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>>
        >>> def f(x):
        ...     return x * 2
        >>>
        >>> sf = brainstate.transform.StatefulFunction(f)
        >>> sf.make_jaxpr(jnp.array([1.0, 2.0]))
        >>> sf.clear_cache()  # Clear all cached compilations
    """
    for cache in (
        self._cached_jaxpr,
        self._cached_out_shapes,
        self._cached_jaxpr_out_tree,
        self._cached_state_trace,
    ):
        cache.clear()
|
946
|
+
|
947
|
+
def __jax_v04_new_arg(self):
    """
    Build the new-argument hook for JAX versions older than 0.4.36.

    Returns
    -------
    callable
        ``_jax_v04_new_arg_fn`` with the current frame and trace bound as
        its first two arguments.
    """
    # Must be called while ``jax.make_jaxpr()`` is tracing, so that the
    # frame and trace obtained here belong to the ongoing trace.
    current_frame, current_trace = _jax_v04_new_jax_trace()
    return functools.partial(_jax_v04_new_arg_fn, current_frame, current_trace)
|
953
|
+
|
954
|
+
def __jax_new_version_new_arg(self):
    """
    Build the new-argument hook for JAX 0.4.36 and later.

    The returned callable maps every leaf of a state's ``_value`` to a new
    argument of the trace that is active when this method is called.
    """
    trace = jax.core.trace_ctx.trace

    # The installed JAX version cannot change between calls, so the
    # per-leaf converter is selected once, up front.
    if jax.__version_info__ < (0, 6, 1):
        def to_new_arg(leaf):
            return trace.new_arg(shaped_abstractify(leaf))
    else:
        def to_new_arg(leaf):
            return trace.new_arg(shaped_abstractify(leaf), source_info=source_info_util.current())

    def wrapper(x):
        return jax.tree.map(to_new_arg, x._value)

    return wrapper
|
965
|
+
|
966
|
+
def _wrapped_fun_to_eval(
|
967
|
+
self,
|
968
|
+
cache_key,
|
969
|
+
static_kwargs: dict,
|
970
|
+
*args,
|
971
|
+
**dyn_kwargs,
|
972
|
+
) -> Tuple[Any, Tuple[State, ...]]:
|
973
|
+
"""
|
974
|
+
Internal wrapper that executes the function and tracks state operations.
|
975
|
+
|
976
|
+
This method wraps the original function to track which states are read
|
977
|
+
and written during execution. It is used internally during jaxpr compilation.
|
978
|
+
|
979
|
+
Parameters
|
980
|
+
----------
|
981
|
+
cache_key
|
982
|
+
The cache key for storing the state trace.
|
983
|
+
static_kwargs : dict
|
984
|
+
Static keyword arguments that were separated out.
|
985
|
+
*args
|
986
|
+
The positional arguments to the function.
|
987
|
+
**dyn_kwargs
|
988
|
+
Dynamic keyword arguments to the function.
|
989
|
+
|
990
|
+
Returns
|
991
|
+
-------
|
992
|
+
tuple
|
993
|
+
A tuple of (output, state_values) where output is the function result
|
994
|
+
and state_values are the tracked state values (either all or write-only
|
995
|
+
depending on return_only_write setting).
|
996
|
+
"""
|
997
|
+
# state trace
|
998
|
+
# A fresh StateTraceStack, named after this StatefulFunction, collects the
# states touched while ``self.fun`` runs.
state_trace: StateTraceStack = StateTraceStack(self.name)
|
999
|
+
# Install the hook that turns state values into new trace arguments
# (presumably invoked for states first seen during tracing). JAX releases
# older than 0.4.36 need the legacy frame/trace based helper.
if jax.__version_info__ < (0, 4, 36):
|
1000
|
+
state_trace.set_new_arg(self.__jax_v04_new_arg())
|
1001
|
+
else:
|
1002
|
+
state_trace.set_new_arg(self.__jax_new_version_new_arg())
|
1003
|
+
# Register the trace under this key before running the function;
# make_jaxpr() pops it again if tracing raises.
self._cached_state_trace.set(cache_key, state_trace)
|
1004
|
+
# Run the user function inside the trace context, which per the docstring
# records which states are read and written.
with state_trace:
|
1005
|
+
out = self.fun(*args, **dyn_kwargs, **static_kwargs)
|
1006
|
+
# Write-state values (read slots replaced) when return_only_write is set,
# otherwise the values of every traced state.
state_values = (
|
1007
|
+
state_trace.get_write_state_values(True)
|
1008
|
+
if self.return_only_write else
|
1009
|
+
state_trace.get_state_values()
|
1010
|
+
)
|
1011
|
+
# NOTE(review): presumably restores the pre-trace values of the traced
# states so tracers do not leak into them — confirm in StateTraceStack.
state_trace.recovery_original_values()
|
1012
|
+
|
1013
|
+
# State instance as functional returns is not allowed.
|
1014
|
+
# Checking whether the states are returned.
|
1015
|
+
jax.tree.map(self._check_input_ouput, out, is_leaf=lambda x: isinstance(x, State))
|
1016
|
+
return out, state_values
|
1017
|
+
|
1018
|
+
def make_jaxpr(self, *args, **kwargs):
    """
    Trace the function with example arguments and cache the result.

    On a cache miss, the function is traced through ``_make_jaxpr`` and the
    resulting ``ClosedJaxpr``, output shapes, output tree structure and
    state trace are stored under the key from :meth:`get_arg_cache_key`.
    On a hit, nothing is recomputed.

    Parameters
    ----------
    *args
        The arguments to the function.
    **kwargs
        The keyword arguments to the function. Names listed in
        ``self.static_argnames`` are treated as static.

    Returns
    -------
    StatefulFunction
        Returns self for method chaining.

    Raises
    ------
    ValueError
        If ``State`` objects are passed as arguments or returned from the
        function. The error is reported through
        ``State.raise_error_with_source_info``.
    """
    # ``State`` objects must never appear among the inputs.
    jax.tree.map(self._check_input_ouput, (args, kwargs), is_leaf=lambda x: isinstance(x, State))

    cache_key = self.get_arg_cache_key(*args, **kwargs)

    if cache_key not in self._cached_state_trace:
        try:
            # Static kwargs are bound via ``functools.partial``; only the
            # dynamic ones are traced.
            static_kwargs, dyn_kwargs = {}, {}
            for k, v in kwargs.items():
                if k in self.static_argnames:
                    static_kwargs[k] = v
                else:
                    dyn_kwargs[k] = v
            jaxpr, (out_shapes, state_shapes) = _make_jaxpr(
                functools.partial(
                    self._wrapped_fun_to_eval,
                    cache_key,
                    static_kwargs,
                ),
                static_argnums=self.static_argnums,
                axis_env=self.axis_env,
                return_shape=True,
                abstracted_axes=self.abstracted_axes,
            )(*args, **dyn_kwargs)

            # The state trace itself was stored by ``_wrapped_fun_to_eval``.
            self._cached_jaxpr_out_tree.set(cache_key, jax.tree.structure((out_shapes, state_shapes)))
            self._cached_out_shapes.set(cache_key, (out_shapes, state_shapes))
            self._cached_jaxpr.set(cache_key, jaxpr)

        except Exception:
            # Drop every partial entry for this key, including the output
            # tree, so that later lookups cannot see a half-populated cache.
            self._cached_state_trace.pop(cache_key, None)
            self._cached_jaxpr_out_tree.pop(cache_key, None)
            self._cached_out_shapes.pop(cache_key, None)
            self._cached_jaxpr.pop(cache_key, None)
            raise

    return self
|
1084
|
+
|
1085
|
+
def jaxpr_call(self, state_vals, *args, **kwargs) -> Any:
    """
    Evaluate the cached jaxpr with explicit state values.

    Parameters
    ----------
    state_vals : Sequence
        Current values of the traced states, in trace order.
    *args
        The arguments to the function. Positions in ``self.static_argnums``
        are dropped before evaluation.
    **kwargs
        The keyword arguments to the function. Names in
        ``self.static_argnames`` are dropped before evaluation.

    Returns
    -------
    tuple
        ``(new_state_vals, out)``: the updated state values and the function
        output.

    Raises
    ------
    ValueError
        If the number of state values given, or produced, does not match the
        number of traced states.
    """
    cache_key = self.get_arg_cache_key(*args, **kwargs)
    states: Sequence[State] = self.get_states_by_cache(cache_key)
    if len(state_vals) != len(states):
        raise ValueError(f'State length mismatch: expected {len(states)} states, got {len(state_vals)}')

    # Only the dynamic arguments are inputs of the jaxpr; the traced state
    # values follow them in the flattened input list.
    dynamic_kwargs = {k: v for k, v in kwargs.items() if k not in self.static_argnames}
    dynamic_args = tuple(a for i, a in enumerate(args) if i not in self.static_argnums)
    flat_inputs = jax.tree.leaves((dynamic_args, dynamic_kwargs, state_vals))

    # The jaxpr always returns the values of the states it reads and writes
    # alongside the function output.
    closed_jaxpr = self.get_jaxpr_by_cache(cache_key)
    out_treedef = self.get_out_treedef_by_cache(cache_key)
    flat_outputs = jax.core.eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *flat_inputs)

    out, new_state_vals = out_treedef.unflatten(flat_outputs)
    if len(new_state_vals) != len(state_vals):
        raise ValueError(f'State length mismatch in output: expected '
                         f'{len(state_vals)} states, got {len(new_state_vals)}')
    return new_state_vals, out
|
1136
|
+
|
1137
|
+
def get_cache_stats(self) -> Dict[str, Any]:
    """
    Collect the statistics of every internal cache.

    Returns
    -------
    dict
        Maps ``'jaxpr_cache'``, ``'out_shapes_cache'``,
        ``'jaxpr_out_tree_cache'`` and ``'state_trace_cache'`` to the result
        of the corresponding cache's ``get_stats()``.
    """
    labelled_caches = (
        ('jaxpr_cache', self._cached_jaxpr),
        ('out_shapes_cache', self._cached_out_shapes),
        ('jaxpr_out_tree_cache', self._cached_jaxpr_out_tree),
        ('state_trace_cache', self._cached_state_trace),
    )
    return {label: cache.get_stats() for label, cache in labelled_caches}
|
1154
|
+
|
1155
|
+
def validate_states(self, cache_key: Hashable) -> bool:
    """
    Check that every state traced under ``cache_key`` has a ``value`` attribute.

    Parameters
    ----------
    cache_key : Hashable
        The cache key whose traced states are checked.

    Returns
    -------
    bool
        Always True when no invalid state is found.

    Raises
    ------
    ValueError
        If any traced state lacks a ``value`` attribute. The message lists
        the trace indices of the offending states.
    """
    trace = self.get_state_trace_by_cache(cache_key)
    bad_indices = [idx for idx, st in enumerate(trace.states) if not hasattr(st, 'value')]
    if bad_indices:
        raise ValueError(
            f"Found {len(bad_indices)} invalid states at indices: "
            f"{bad_indices}. "
            f"States must have a 'value' attribute."
        )
    return True
|
1187
|
+
|
1188
|
+
def validate_all_states(self) -> Dict[Any, Union[bool, str]]:
    """
    Run :meth:`validate_states` for every cached compilation.

    Returns
    -------
    dict
        Maps each cache key to True when its states are valid, or to the
        ValueError message (a ``str``) when validation failed.
    """
    results: Dict[Any, Union[bool, str]] = {}
    # Iterate over a snapshot of the keys: validate_states() reads the state
    # trace cache, and a cache that reorders entries on access would
    # otherwise be mutated mid-iteration. NOTE(review): whether ``keys()``
    # already returns a copy depends on the cache implementation.
    for cache_key in list(self._cached_state_trace.keys()):
        try:
            results[cache_key] = self.validate_states(cache_key)
        except ValueError as e:
            results[cache_key] = str(e)
    return results
|
1205
|
+
|
1206
|
+
def jaxpr_call_auto(self, *args, **kwargs) -> Any:
    """
    Evaluate the compiled function and write the new state values back.

    The function is compiled on demand. Current state values are taken from
    the cached state trace and passed to :meth:`jaxpr_call`. The results are
    then assigned back to the states via the trace.

    Parameters
    ----------
    *args
        The positional arguments to the function.
    **kwargs
        The keyword arguments to the function.

    Returns
    -------
    Any
        The output of the function.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>>
        >>> state = brainstate.State(jnp.array([1.0, 2.0]))
        >>>
        >>> def f(x):
        ...     state.value += x
        ...     return state.value * 2
        >>>
        >>> sf = brainstate.transform.StatefulFunction(f)
        >>> x = jnp.array([0.5, 0.5])
        >>> sf.make_jaxpr(x)
        >>>
        >>> # Automatic state management
        >>> result = sf.jaxpr_call_auto(x)
        >>> # or, equivalently
        >>> result = sf(x)
        >>> print(state.value)  # State is automatically updated
    """
    key = self.get_arg_cache_key(*args, compile_if_miss=True, **kwargs)
    trace = self.get_state_trace_by_cache(key)
    # Captured before evaluation and handed back to the trace together with
    # the new values.
    read_vals = trace.get_read_state_values(True)
    current_vals = trace.get_state_values()
    new_vals, result = self.jaxpr_call(current_vals, *args, **kwargs)
    trace.assign_state_vals_v2(read_vals, new_vals)
    return result
|
1255
|
+
|
1256
|
+
def __call__(self, *args, **kwargs):
    """Call the function with automatic state management (see :meth:`jaxpr_call_auto`)."""
    result = self.jaxpr_call_auto(*args, **kwargs)
    return result
|
1258
|
+
|
1259
|
+
|
1260
|
+
@set_module_as("brainstate.transform")
def make_jaxpr(
    fun: Callable,
    static_argnums: Union[int, Iterable[int]] = (),
    static_argnames: Union[str, Iterable[str]] = (),
    axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
    return_shape: bool = False,
    abstracted_axes: Optional[Any] = None,
    return_only_write: bool = False,
) -> Callable[
    ...,
    (Tuple[ClosedJaxpr, Tuple[State, ...]] |
     Tuple[ClosedJaxpr, Tuple[State, ...], PyTree])
]:
    """
    Creates a function that produces its jaxpr given example args.

    A ``jaxpr`` is JAX's intermediate representation for program traces. The
    ``jaxpr`` language is based on the simply-typed first-order lambda calculus
    with let-bindings. :py:func:`make_jaxpr` adapts a function to return its
    ``jaxpr``, together with the ``State`` objects it touches, so we can
    inspect what JAX is doing internally. The ``jaxpr`` returned is a trace
    of ``fun`` abstracted to :py:class:`ShapedArray` level. Other levels of
    abstraction exist internally.

    Parameters
    ----------
    fun : callable
        The function whose ``jaxpr`` is to be computed. Its positional
        arguments and return value should be arrays, scalars, or standard Python
        containers (tuple/list/dict) thereof.
    static_argnums : int or iterable of int, optional
        See the :py:func:`jax.jit` docstring.
    static_argnames : str or iterable of str, optional
        See the :py:func:`jax.jit` docstring.
    axis_env : sequence of tuple, optional
        A sequence of pairs where the first element is an axis
        name and the second element is a positive integer representing the size of
        the mapped axis with that name. This parameter is useful when lowering
        functions that involve parallel communication collectives, and it
        specifies the axis name/size environment that would be set up by
        applications of :py:func:`jax.pmap`.
    return_shape : bool, default False
        If ``True``, the wrapped function additionally returns, as a third
        element, a pytree with the same structure as the output of ``fun``
        whose leaves describe the ``shape`` and ``dtype`` of the
        corresponding output leaves. State shapes are not included.
    abstracted_axes : pytree, optional
        A pytree with the same structure as the input
        arguments to ``fun``. The leaves of the pytree can be either None or a
        dict with axis names as keys and integers as values. If the leaf is None,
        then the corresponding axis is not abstracted. If the leaf is a dict, then
        the corresponding axis is abstracted, and the dict specifies the axis name
        and size. The abstracted axes are used to infer the input type of the
        function. If None, then all axes are abstracted.
    return_only_write : bool, default False
        If True, only return states that were written to during execution
        (not just read). This can reduce memory usage when you only care
        about modified states.

    Returns
    -------
    callable
        A wrapped version of ``fun``. When applied to example arguments, it
        returns ``(jaxpr, states)``, where ``jaxpr`` is the ``ClosedJaxpr``
        of ``fun`` on those arguments and ``states`` is the tuple of
        ``State`` objects traced during the call. If ``return_shape`` is
        ``True``, it returns ``(jaxpr, states, out_shapes)`` instead, where
        ``out_shapes`` describes the output of ``fun`` (see
        ``return_shape``).

    Examples
    --------
    Basic usage:

    .. code-block:: python

        >>> import jax
        >>> import brainstate
        >>> import jax.numpy as jnp
        >>>
        >>> def f(x):
        ...     return jnp.sin(jnp.cos(x))
        >>>
        >>> # Create jaxpr maker
        >>> jaxpr_maker = brainstate.transform.make_jaxpr(f)
        >>> jaxpr, states = jaxpr_maker(3.0)

    With gradient:

    .. code-block:: python

        >>> jaxpr_grad_maker = brainstate.transform.make_jaxpr(jax.grad(f))
        >>> jaxpr, states = jaxpr_grad_maker(3.0)

    With shape information:

    .. code-block:: python

        >>> jaxpr_maker_with_shape = brainstate.transform.make_jaxpr(f, return_shape=True)
        >>> jaxpr, states, shapes = jaxpr_maker_with_shape(3.0)

    With stateful function:

    .. code-block:: python

        >>> state = brainstate.State(jnp.array([1.0, 2.0]))
        >>>
        >>> def stateful_f(x):
        ...     state.value += x
        ...     return state.value
        >>>
        >>> jaxpr_maker = brainstate.transform.make_jaxpr(stateful_f)
        >>> jaxpr, states = jaxpr_maker(jnp.array([0.5, 0.5]))
    """

    stateful_fun = StatefulFunction(
        fun,
        static_argnums=static_argnums,
        static_argnames=static_argnames,
        axis_env=axis_env,
        abstracted_axes=abstracted_axes,
        return_only_write=return_only_write,
        name='make_jaxpr'
    )

    @wraps(fun)
    def make_jaxpr_f(*args, **kwargs):
        stateful_fun.make_jaxpr(*args, **kwargs)
        cache_key = stateful_fun.get_arg_cache_key(*args, **kwargs)
        jaxpr = stateful_fun.get_jaxpr_by_cache(cache_key)
        states = stateful_fun.get_states_by_cache(cache_key)
        if return_shape:
            # The cached shapes are ``(out_shapes, state_shapes)``; only the
            # function-output part is returned.
            out_shapes = stateful_fun.get_out_shapes_by_cache(cache_key)[0]
            return jaxpr, states, out_shapes
        return jaxpr, states

    # wrapped jaxpr builder function
    make_jaxpr_f.__module__ = "brainstate.transform"
    if hasattr(fun, "__qualname__"):
        make_jaxpr_f.__qualname__ = f"make_jaxpr({fun.__qualname__})"
    if hasattr(fun, "__name__"):
        make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
    return make_jaxpr_f
|
1409
|
+
|
1410
|
+
|
1411
|
+
class StatefulMapping(StatefulFunction):
|
1412
|
+
__module__ = "brainstate.transform"
|
1413
|
+
|
1414
|
+
def __init__(
    self,
    fun: Callable,
    in_axes: Union[int, Tuple[int, ...], None] = 0,
    out_axes: Union[int, Tuple[int, ...], None] = 0,
    state_in_axes: Optional[Union[Dict[AxisName, Filter], Filter]] = None,
    state_out_axes: Optional[Union[Dict[AxisName, Filter], Filter]] = None,
    # jit specific parameters
    static_argnums: Union[int, Iterable[int]] = (),
    static_argnames: Union[str, Iterable[str]] = (),
    axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
    abstracted_axes: Optional[Any] = None,
    # mapping specific parameters
    axis_size: Optional[int] = None,
    axis_name: AxisName | None = None,
    name: Optional[str] = None,
    # mapping function
    mapping_fn: Callable = jax.vmap,
):
    """
    Build a stateful mapping transformation around ``fun``.

    Parameters
    ----------
    fun : callable
        The function to map. It is stored as ``self.origin_fun``; the base
        :class:`StatefulFunction` is given ``self._wrapped_fun`` instead.
    in_axes, out_axes : int, tuple of int, or None
        Mapped axes of the positional inputs / outputs, stored as given.
    state_in_axes, state_out_axes : dict or filter, optional
        Which states are mapped along which axis. A dict maps an axis to a
        filter, a bare filter is shorthand for ``{0: filter}``, and ``None``
        means no entries. Every filter is normalized with ``to_predicate``.
    static_argnums, static_argnames, axis_env, abstracted_axes
        Forwarded to :class:`StatefulFunction`.
    axis_size : int, optional
        Explicit size of the mapped axis.
    axis_name : hashable, optional
        Name of the mapped axis.
    name : str, optional
        Name forwarded to :class:`StatefulFunction`.
    mapping_fn : callable, default ``jax.vmap``
        The mapping transformation to apply.
    """

    def _normalize_state_axes(state_axes):
        # Turn None / a bare filter / an {axis: filter} dict into {axis: predicate},
        # applying ``to_predicate`` exactly once per filter.
        if state_axes is None:
            return {}
        if not isinstance(state_axes, dict):
            state_axes = {0: state_axes}
        return {k: to_predicate(v) for k, v in state_axes.items()}

    self.origin_fun = fun
    super().__init__(
        fun=self._wrapped_fun,
        static_argnums=static_argnums,
        static_argnames=static_argnames,
        axis_env=axis_env,
        abstracted_axes=abstracted_axes,
        name=name,
        return_only_write=False,
    )
    self.in_axes = in_axes
    self.out_axes = out_axes
    self.state_in_axes = _normalize_state_axes(state_in_axes)
    self.state_out_axes = _normalize_state_axes(state_out_axes)

    self.axis_size = axis_size
    self.axis_name = axis_name
    self.mapping_fn = mapping_fn

    # Cache for discovered state-to-axis mappings
    self._cached_map_dim_to_in_states = _BoundedCache(maxsize=128)
    self._cached_map_dim_to_out_states = _BoundedCache(maxsize=128)
    self._cached_map_state_trace = _BoundedCache(maxsize=128)
    self._cached_map_batch_size = _BoundedCache(maxsize=128)
|
1468
|
+
|
1469
|
+
def _infer_batch_size(self, args, in_axes):
    """
    Infer the size of the mapped (batch) axis from the positional arguments.

    Parameters
    ----------
    args : tuple
        Positional arguments passed to the mapped function.
    in_axes : int, None, tuple or list
        Batch-axis specification. An ``int`` applies the same axis to every
        positional argument; a tuple/list gives one axis (or ``None``) per
        positional argument. ``None`` means no positional argument is mapped,
        in which case ``self.axis_size`` must be provided.

    Returns
    -------
    int
        The common batch size. When no argument leaf carries a batch axis,
        ``self.axis_size`` is returned.

    Raises
    ------
    ValueError
        If no batch size can be determined, if a tuple/list ``in_axes`` does
        not have one entry per positional argument, or if the batched
        arguments disagree on the batch size.
    IndexError
        If a batch axis is out of bounds for an array leaf.
    TypeError
        If ``in_axes`` has an unsupported type.
    """
    if in_axes is None:
        # Nothing is mapped positionally; only an explicit ``axis_size`` can
        # tell us the batch size (this mirrors ``jax.vmap``'s requirement).
        if self.axis_size is None:
            raise ValueError("Cannot infer batch size when in_axes is None")
        return self.axis_size

    def get_batch_size_from_arg(arg_, axis_):
        # Return the size along ``axis_`` of the first array-like leaf of
        # ``arg_`` (scalars and non-array leaves are ignored), or None.
        if axis_ is None:
            return None

        def _get_size(arr):
            if not hasattr(arr, 'shape'):
                return None
            ndim = len(arr.shape)
            if ndim == 0:
                return None
            ax = axis_ if axis_ >= 0 else ndim + axis_
            if ax < 0 or ax >= ndim:
                # Report the axis as the user wrote it, not the normalized one.
                raise IndexError(f"Axis {axis_} is out of bounds for array of shape {arr.shape}")
            return arr.shape[ax]

        # ``None`` results are empty subtrees, so ``leaves`` drops them.
        sizes = [s for s in jax.tree.leaves(jax.tree.map(_get_size, arg_)) if s is not None]
        return sizes[0] if sizes else None

    if isinstance(in_axes, int):
        # All args batched along the same axis
        per_arg_axes = [in_axes] * len(args)
    elif isinstance(in_axes, (tuple, list)):
        # Different axes for different args
        if len(in_axes) != len(args):
            raise ValueError(
                f"Length of in_axes ({len(in_axes)}) must match number of arguments ({len(args)})"
            )
        per_arg_axes = in_axes
    else:
        raise TypeError(f"Unsupported in_axes type: {type(in_axes)}")

    batch_sizes = []
    for arg, axis in zip(args, per_arg_axes):
        size = get_batch_size_from_arg(arg, axis)
        if size is not None:
            batch_sizes.append(size)

    if not batch_sizes:
        if self.axis_size is None:
            raise ValueError("Cannot infer batch size when axis_size is None")
        batch_sizes.append(self.axis_size)

    # Check all batch sizes are consistent
    if not all(s == batch_sizes[0] for s in batch_sizes):
        raise ValueError(
            f"Inconsistent batch sizes found: {batch_sizes}. "
            f"All batched arguments must have the same size along their batch axes."
        )

    return batch_sizes[0]
|
1525
|
+
|
1526
|
+
def __new_batch_arg(self, batch_size: int, dim_to_states: dict):
    """
    Build the ``set_new_arg`` hook that ``__eval`` installs on its ``StateTraceStack``.

    Must be called while a ``BatchTrace`` is the current JAX trace (this is
    asserted). The returned ``wrapper(x)`` receives a state ``x`` and returns
    the value to use for it inside the trace:

    - ``RandomState``: ``x`` is recorded under the ``'random'`` key and a
      placeholder of ones with shape ``(batch_size,) + x._value.shape`` is
      returned, marked as batched along axis 0.
    - otherwise, the first entry of ``self.state_in_axes`` whose filter
      accepts ``x`` decides the batch axis ``dim``; ``x`` is recorded under
      ``dim`` and every leaf of ``x._value`` is marked as batched along ``dim``.
    - otherwise ``x._value`` is returned unchanged (unbatched) and ``x`` is
      not recorded.

    Parameters
    ----------
    batch_size : int
        Size of the mapped axis.
    dim_to_states : dict
        Mutated in place: batch axis (or ``'random'``) -> list of states.
        ``wrapper`` appends to keys that may not exist yet, so a
        ``defaultdict(list)`` is expected (``__eval`` passes one).

    Returns
    -------
    Callable
        The ``wrapper`` described above.
    """
    trace = jax.core.trace_ctx.trace
    assert isinstance(trace, BatchTrace), f"Expected to be called within a BatchTrace context, but got {trace}"

    def wrapper(x):
        if isinstance(x, RandomState):
            idx = memoize(lambda: BatchTracer(trace, make_iota(batch_size), 0, source_info_util.current()))
            dim_to_states['random'].append(x)
            # Placeholder only; presumably the real per-element keys are fed in
            # later by the mapped call (see ``__get_rand_state_vals``).
            return to_elt(trace, idx, jnp.ones((batch_size,) + x._value.shape, x._value.dtype), 0)
        for dim, filter_ in self.state_in_axes.items():
            # NOTE(review): this lambda closes over the loop variable ``dim``;
            # it is only safe because ``idx`` is used (if at all) before the
            # loop advances, via the ``return`` below.
            # NOTE(review): the iota index is presumably 1-D, yet ``dim`` (not 0,
            # as in the RandomState branch) is given as its batch dimension —
            # confirm this is intended for dim != 0.
            idx = memoize(lambda: BatchTracer(trace, make_iota(batch_size), dim, source_info_util.current()))
            if filter_(tuple(), x):
                dim_to_states[dim].append(x)
                return jax.tree.map(lambda xx: to_elt(trace, idx, xx, dim), x._value)
        # No filter matched: the state is treated as unbatched.
        return x._value

    return wrapper
|
1543
|
+
|
1544
|
+
def __eval(self, cache_key, *args, **kwargs):
    """
    Trace ``origin_fun`` once under ``jax.vmap`` to discover the states it uses.

    The traced function's output is discarded (``fn_to_eval`` returns None);
    only its side effects on the caches matter. On success, the following
    entries are stored under ``cache_key``:

    - ``_cached_map_state_trace``: the ``StateTraceStack`` used for tracing.
    - ``_cached_map_dim_to_in_states``: input batch axis (or ``'random'``) ->
      states, filled by the ``__new_batch_arg`` hook as states are read.
    - ``_cached_map_dim_to_out_states``: batch axis of each traced state after
      the call (``None`` for unbatched ones) -> states; random states are
      grouped under ``'random'``.

    Afterwards the state trace's ``recovery_original_values()`` is called,
    presumably to remove tracers from the states.

    Raises
    ------
    RuntimeError
        Chained from whatever the tracing raised. Before raising, the
        original state values are recovered (if a trace was recorded) and all
        four per-key caches, including ``_cached_map_batch_size``, are cleared.
    """
    def fn_to_eval(*new_args, **new_kwargs):
        dim_to_in_states = defaultdict(list)
        state_trace = StateTraceStack(name=self.name)
        # Batch each state, according to ``state_in_axes``, when it is first read.
        state_trace.set_new_arg(
            self.__new_batch_arg(self._cached_map_batch_size.get(cache_key), dim_to_in_states)
        )
        # Cached before the call so the error path below can still restore values.
        self._cached_map_state_trace.set(cache_key, state_trace)

        # call functions
        with state_trace:
            out_ = self.origin_fun(*new_args, **new_kwargs)

        # cache
        self._cached_map_dim_to_in_states.set(cache_key, dim_to_in_states)

        # vmapped state values
        out_states = defaultdict(list)
        out_states['random'] = [st for st in state_trace.states if isinstance(st, RandomState)]
        for st in state_trace.states:
            if not isinstance(st, RandomState):
                # Group by the batch dim of the state's leaves after the call;
                # non-tracer leaves count as unbatched (``None``).
                leaves = jax.tree.leaves(st._value)
                batch_dims = set([leaf.batch_dim if isinstance(leaf, BatchTracer) else None for leaf in leaves])
                # NOTE(review): a state with no leaves yields an empty set and
                # also lands here, with a misleading "inconsistent" message.
                if len(batch_dims) != 1:
                    raise ValueError(
                        f"State {st} has inconsistent batch dimensions in its leaves: {batch_dims}. "
                        "All leaves must have the same batch dimension."
                    )
                batch_dim = batch_dims.pop()
                out_states[batch_dim].append(st)
        self._cached_map_dim_to_out_states.set(cache_key, out_states)

    try:
        jax.vmap(
            fn_to_eval,
            in_axes=self.in_axes,
            out_axes=self.out_axes,
            axis_name=self.axis_name,
            axis_size=self.axis_size
        )(*args, **kwargs)
        self._cached_map_state_trace.get(cache_key).recovery_original_values()
    except Exception as e:
        # Undo partial work so a later call re-runs discovery from scratch.
        if cache_key in self._cached_map_state_trace:
            self._cached_map_state_trace.get(cache_key).recovery_original_values()
        self._cached_map_state_trace.pop(cache_key, None)
        self._cached_map_dim_to_in_states.pop(cache_key, None)
        self._cached_map_dim_to_out_states.pop(cache_key, None)
        self._cached_map_batch_size.pop(cache_key, None)
        raise RuntimeError(f"Failed to evaluate {self}") from e
|
1593
|
+
|
1594
|
+
def __assign_vals_from_in_states(self, cache_key, rand_st, *other_st):
    """
    Write mapped input values into the cached input states.

    Parameters
    ----------
    cache_key : Hashable
        Key under which ``__eval`` stored the axis -> input-states mapping.
    rand_st : sequence
        Values for the ``'random'`` states, in order.
    *other_st : sequence of sequences
        One group of values per non-``'random'`` axis, in the mapping's key
        order; each group is paired element-wise with that axis' states.
    """
    groups = self._cached_map_dim_to_in_states.get(cache_key)
    random_pairs = list(zip(groups['random'], rand_st))
    for rand_state, new_value in random_pairs:
        assert isinstance(rand_state, RandomState)
        rand_state.restore_value(new_value)

    batched_groups = [groups[axis] for axis in groups.keys() if axis != 'random']
    for members, new_values in zip(batched_groups, other_st):
        for member, new_value in zip(members, new_values):
            member.restore_value(new_value)
|
1602
|
+
|
1603
|
+
def __assign_vals_from_out_states(self, cache_key, rand_st, *other_st):
    """
    Write mapped output values back into the cached output states.

    Parameters
    ----------
    cache_key : Hashable
        Key under which ``__eval`` stored the axis -> output-states mapping.
    rand_st : sequence
        Values for the ``'random'`` states, in order.
    *other_st : sequence of sequences
        One group of values per non-``'random'`` axis, in the mapping's key
        order; each group is paired element-wise with that axis' states.
    """
    groups = self._cached_map_dim_to_out_states.get(cache_key)
    random_pairs = list(zip(groups['random'], rand_st))
    for rand_state, new_value in random_pairs:
        assert isinstance(rand_state, RandomState)
        rand_state.restore_value(new_value)

    batched_groups = [groups[axis] for axis in groups.keys() if axis != 'random']
    for members, new_values in zip(batched_groups, other_st):
        for member, new_value in zip(members, new_values):
            member.restore_value(new_value)
|
1611
|
+
|
1612
|
+
def __get_in_state_vals(self, cache_key: Hashable):
    """
    Gather the current values of the cached input states, grouped by axis.

    Parameters
    ----------
    cache_key : Hashable
        Key under which ``__eval`` stored the axis -> input-states mapping.

    Returns
    -------
    tuple
        ``(axes, values)``: ``axes`` is a tuple of the mapping's
        non-``'random'`` keys in iteration order, and ``values`` is a list
        holding, per axis, the list of ``.value`` of that axis' states.
    """
    grouped = self._cached_map_dim_to_in_states.get(cache_key)
    batched = [(axis, members) for axis, members in grouped.items() if axis != 'random']
    axes = tuple(axis for axis, _ in batched)
    values = [[member.value for member in members] for _, members in batched]
    return axes, values
|
1622
|
+
|
1623
|
+
def __get_out_state_vals(self, cache_key: Hashable):
    """
    Gather the current values of the cached output states, grouped by axis.

    Parameters
    ----------
    cache_key : Hashable
        Key under which ``__eval`` stored the axis -> output-states mapping.

    Returns
    -------
    tuple
        ``(axes, values)``: ``axes`` is a tuple of the mapping's
        non-``'random'`` keys in iteration order (used as ``out_axes``), and
        ``values`` is a list holding, per axis, the ``.value`` of its states.
    """
    grouped = self._cached_map_dim_to_out_states.get(cache_key)
    axes, values = [], []
    for axis, members in grouped.items():
        if axis != 'random':
            axes.append(axis)
            values.append([member.value for member in members])
    return tuple(axes), values
|
1633
|
+
|
1634
|
+
def __get_rand_state_vals(self, cache_key: Hashable):
    """
    Split the key of every cached random input state for the mapped call.

    For each ``RandomState`` recorded under ``'random'``, ``split_key`` is
    called with the cached batch size; the state's ``.value`` read right
    after the split is kept as well.

    Returns
    -------
    tuple
        ``(split_keys, post_split_values)``, each a tuple in state order.
    """
    groups = self._cached_map_dim_to_in_states.get(cache_key)
    n_batch = self._cached_map_batch_size.get(cache_key)
    split_keys = []
    post_split_values = []
    for rand_state in groups['random']:
        assert isinstance(rand_state, RandomState)
        split_keys.append(rand_state.split_key(n_batch))
        post_split_values.append(rand_state.value)
    return tuple(split_keys), tuple(post_split_values)
|
1643
|
+
|
1644
|
+
def __recover_rand_state_vals(self, cache_key: Hashable, rand_recover_vals):
    """
    Restore saved values into the random states of the cached state trace.

    ``rand_recover_vals`` is paired, in order, with the ``RandomState``
    instances found in the cached trace's ``states``.
    """
    trace = self._cached_map_state_trace.get(cache_key)
    random_states = list(filter(lambda st: isinstance(st, RandomState), trace.states))
    for rand_state, saved in zip(random_states, rand_recover_vals):
        rand_state.restore_value(saved)
|
1649
|
+
|
1650
|
+
def _wrapped_fun(self, *args, **kwargs) -> Any:
    """
    Run ``origin_fun`` under ``self.mapping_fn`` with batched state handling.

    Steps:

    1. Infer the batch size from ``args``/``self.in_axes`` and cache it under
       the argument cache key.
    2. If no state trace is cached for that key, run ``__eval`` to discover
       which states are read/written and along which axes.
    3. Map ``fn_to_map``, which writes the mapped input-state values into the
       states, calls ``origin_fun`` and returns its output together with the
       output-state values. Random states receive keys from ``split_key``.
    4. Write the mapped output-state values back into the states, and restore
       random states to the values read right after splitting.

    Returns
    -------
    Any
        The mapped output of ``origin_fun``.
    """
    batch_size = self._infer_batch_size(args, self.in_axes)
    cache_key = self.get_arg_cache_key(*args, **kwargs)
    self._cached_map_batch_size.set(cache_key, batch_size)
    if cache_key not in self._cached_map_state_trace:
        self.__eval(cache_key, *args, **kwargs)

    def fn_to_map(origin_args, rand_st, *non_rand_st):
        self.__assign_vals_from_in_states(cache_key, rand_st, *non_rand_st)
        out = self.origin_fun(*origin_args[0], **origin_args[1])
        return out, *self.__get_out_state_vals(cache_key)[1]

    in_axes, in_state_vals = self.__get_in_state_vals(cache_key)
    # Called for its axes; it also reads the current output-state values.
    out_axes, _ = self.__get_out_state_vals(cache_key)
    rand_vals, rand_recover_vals = self.__get_rand_state_vals(cache_key)

    # The original call is passed as a single ``(args, kwargs)`` pair. A
    # per-argument sequence of axes describes ``args`` only (as in ``__eval``
    # and ``_infer_batch_size``); keyword arguments are mapped along axis 0,
    # which is also what ``jax.vmap`` does in ``__eval``. An int/None spec
    # still applies to the whole pair, exactly as before.
    if isinstance(self.in_axes, (tuple, list)):
        call_in_axes = (tuple(self.in_axes), 0)
    else:
        call_in_axes = self.in_axes

    mapped_fn = self.mapping_fn(
        fn_to_map,
        in_axes=(call_in_axes, 0) + in_axes,
        out_axes=(self.out_axes,) + out_axes,
        axis_size=self.axis_size,
        axis_name=self.axis_name,
    )
    out_, *out_state_vals = mapped_fn((args, kwargs), rand_vals, *in_state_vals)
    self.__assign_vals_from_out_states(cache_key, rand_recover_vals, *out_state_vals)
    return out_
|
1675
|
+
|
1676
|
+
|
1677
|
+
def _check_callable(fun):
    """
    Validate that ``fun`` can be traced as a plain function.

    Raises
    ------
    TypeError
        If ``fun`` is a ``staticmethod`` object, is not callable, or is a
        generator function. Checks run in that order.
    """
    # In Python 3.10+, the only thing stopping us from supporting static methods
    # is that we can't take weak references to them, which the C++ JIT requires.
    if isinstance(fun, staticmethod):
        problem = f"staticmethod arguments are not supported, got {fun}"
    elif not callable(fun):
        problem = f"Expected a callable value, got {fun}"
    elif inspect.isgeneratorfunction(fun):
        problem = f"Expected a function, got a generator function: {fun}"
    else:
        return
    raise TypeError(problem)
|
1686
|
+
|
1687
|
+
|
1688
|
+
def _broadcast_prefix(
|
1689
|
+
prefix_tree: Any,
|
1690
|
+
full_tree: Any,
|
1691
|
+
is_leaf: Callable[[Any], bool] | None = None
|
1692
|
+
) -> list[Any]:
|
1693
|
+
# If prefix_tree is not a tree prefix of full_tree, this code can raise a
|
1694
|
+
# ValueError; use prefix_errors to find disagreements and raise more precise
|
1695
|
+
# error messages.
|
1696
|
+
result = []
|
1697
|
+
num_leaves = lambda t: jax.tree.structure(t).num_leaves
|
1698
|
+
add_leaves = lambda x, subtree: result.extend([x] * num_leaves(subtree))
|
1699
|
+
jax.tree.map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf)
|
1700
|
+
return result
|
1701
|
+
|
1702
|
+
|
1703
|
+
def _flat_axes_specs(
    abstracted_axes, *args, **kwargs
) -> list[pe.AbstractedAxesSpec]:
    """
    Broadcast ``abstracted_axes`` to one spec per flattened positional leaf.

    A dict whose values are all leaves, or a tuple whose entries are leaves
    or ``None``, is treated as a single axes spec rather than a container.

    Raises
    ------
    NotImplementedError
        If any keyword arguments are given.
    """
    if kwargs:
        raise NotImplementedError

    def ax_leaf(node):
        if isinstance(node, dict) and jax.tree_util.all_leaves(node.values()):
            return True
        if isinstance(node, tuple):
            return jax.tree_util.all_leaves(node, lambda x: x is None)
        return False

    return _broadcast_prefix(abstracted_axes, args, ax_leaf)
|
1714
|
+
|
1715
|
+
|
1716
|
+
@transformation_with_aux
def _flatten_fun(in_tree, *args_flat):
    """
    Adapt a pytree-in/pytree-out function to flat leaf lists.

    The flat inputs are unflattened with ``in_tree`` into ``(args, kwargs)``
    and sent to the wrapped function; its result is flattened, and the
    ``(leaves, treedef)`` pair is yielded back.
    """
    positional, keyword = jax.tree.unflatten(in_tree, args_flat)
    result = yield positional, keyword
    yield jax.tree.flatten(result)
|
1721
|
+
|
1722
|
+
|
1723
|
+
def _make_jaxpr(
    fun: Callable,
    static_argnums: int | Iterable[int] = (),
    axis_env: Sequence[tuple[AxisName, int]] | None = None,
    return_shape: bool = False,
    abstracted_axes: Any | None = None,
) -> Callable[..., (ClosedJaxpr | tuple[ClosedJaxpr, Any])]:
    """
    Create a function that produces its jaxpr given example args (internal implementation).

    This is an internal implementation function. Users should use the public
    ``make_jaxpr`` function instead.

    Parameters
    ----------
    fun : Callable
        The function whose ``jaxpr`` is to be computed. Its positional
        arguments and return value should be arrays, scalars, or standard Python
        containers (tuple/list/dict) thereof. ``fun`` is validated up front:
        ``staticmethod`` objects, non-callables and generator functions are
        rejected with ``TypeError``.
    static_argnums : int or iterable of int, optional
        See the :py:func:`jax.jit` docstring. Positional arguments at these
        indices are excluded from the traced inputs.
    axis_env : sequence of tuple, optional
        A sequence of pairs where the first element is an axis
        name and the second element is a positive integer representing the size of
        the mapped axis with that name. This parameter is useful when lowering
        functions that involve parallel communication collectives, and it
        specifies the axis name/size environment that would be set up by
        applications of :py:func:`jax.pmap`.
    return_shape : bool, default False
        If ``True``, the wrapped function returns a pair where the first element
        is the ``ClosedJaxpr`` representation of ``fun`` and the second element
        is a pytree with the same structure as the output of ``fun`` whose
        leaves are ``jax.ShapeDtypeStruct`` objects carrying the ``shape`` and
        ``dtype`` of the corresponding output leaves.
    abstracted_axes : Any, optional
        Axes specifications for abstract interpretation. When given, keyword
        arguments are not supported (``NotImplementedError`` is raised).

    Returns
    -------
    Callable
        A wrapped version of ``fun`` that when applied to example arguments returns
        a ``ClosedJaxpr`` representation of ``fun`` on those arguments. If the
        argument ``return_shape`` is ``True``, then the returned function instead
        returns a pair where the first element is the ``ClosedJaxpr``
        representation of ``fun`` and the second element is a pytree of
        ``jax.ShapeDtypeStruct`` describing the structure, shapes and dtypes of
        the output of ``fun``.

    Notes
    -----
    A ``jaxpr`` is JAX's intermediate representation for program traces. The
    ``jaxpr`` language is based on the simply-typed first-order lambda calculus
    with let-bindings. This function adapts a function to return its
    ``jaxpr``, which we can inspect to understand what JAX is doing internally.
    The ``jaxpr`` returned is a trace of ``fun`` abstracted to
    :py:class:`ShapedArray` level. Other levels of abstraction exist internally.

    Examples
    --------
    .. code-block:: python

        >>> import jax
        >>>
        >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
        >>> print(f(3.0))
        -0.83602
        >>> _make_jaxpr(f)(3.0)
        { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
        >>> _make_jaxpr(jax.grad(f))(3.0)
        { lambda ; a:f32[]. let
            b:f32[] = cos a
            c:f32[] = sin a
            _:f32[] = sin b
            d:f32[] = cos b
            e:f32[] = mul 1.0 d
            f:f32[] = neg e
            g:f32[] = mul f c
          in (g,) }
    """
    _check_callable(fun)
    static_argnums = _ensure_index_tuple(static_argnums)

    def _abstractify(args, kwargs):
        # Flatten (args, kwargs) together and compute one abstract value per leaf.
        flat_args, in_tree = jax.tree.flatten((args, kwargs))
        if abstracted_axes is None:
            # Without abstracted axes every flattened input is kept.
            return map(shaped_abstractify, flat_args), in_tree, [True] * len(flat_args)
        else:
            axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
            in_type = pe.infer_lambda_input_type(axes_specs, flat_args)
            in_avals, keep_inputs = unzip2(in_type)
            return in_avals, in_tree, keep_inputs

    @wraps(fun)
    @api_boundary
    def make_jaxpr_f(*args, **kwargs):
        f = wrap_init(fun, (), {}, 'brainstate.transform.make_jaxpr')
        if static_argnums:
            # Drop static positional arguments so only dynamic ones are traced.
            dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
            f, args = jax.api_util.argnums_partial(f, dyn_argnums, args)
        in_avals, in_tree, keep_inputs = _abstractify(args, kwargs)
        in_type = tuple(safe_zip(in_avals, keep_inputs))
        f, out_tree = _flatten_fun(f, in_tree)
        f = annotate(f, in_type)
        # For JAX < 0.5.0 an explicit debug-info object is built and passed to
        # the tracer; newer versions are called without it.
        if jax.__version_info__ < (0, 5, 0):
            debug_info_ = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
        with ExitStack() as stack:
            # Optionally bind named axes (e.g. for collectives) while tracing.
            if axis_env is not None:
                stack.enter_context(extend_axis_env_nd(axis_env))
            if jax.__version_info__ < (0, 5, 0):
                jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, debug_info=debug_info_)
            else:
                jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f)
        closed_jaxpr = ClosedJaxpr(jaxpr, consts)
        if return_shape:
            out_avals, _ = unzip2(out_type)
            out_shapes_flat = [jax.ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
            return closed_jaxpr, jax.tree.unflatten(out_tree(), out_shapes_flat)
        return closed_jaxpr

    # Present the wrapper under the public module and a descriptive name.
    make_jaxpr_f.__module__ = "brainstate.transform"
    if hasattr(fun, "__qualname__"):
        make_jaxpr_f.__qualname__ = f"make_jaxpr({fun.__qualname__})"
    if hasattr(fun, "__name__"):
        make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
    return make_jaxpr_f
|
1847
|
+
|
1848
|
+
|
1849
|
+
def make_hashable(obj):
    """
    Convert a pytree into a hashable representation.

    Parameters
    ----------
    obj : Any
        A pytree object (list, tuple, dict, set, or JAX pytree structure).

    Returns
    -------
    Hashable
        A hashable representation of the input object. Lists and tuples become
        tuples, dicts become tuples of ``(key, value)`` pairs sorted by key,
        sets become frozensets, and other objects are flattened with
        ``jax.tree.flatten`` into ``(treedef, leaves_tuple)``. The result is
        only hashable if the leaves themselves are hashable.

    Notes
    -----
    Dict keys of mutually incomparable types (e.g. ``1`` and ``'a'``) are
    ordered by ``(type name, repr(key))`` instead of raising ``TypeError``.
    """
    if isinstance(obj, (list, tuple)):
        return tuple(make_hashable(item) for item in obj)
    elif isinstance(obj, dict):
        items = [(k, make_hashable(v)) for k, v in obj.items()]
        try:
            # Dict keys are distinct, so ordering by key alone matches
            # ordering the (key, value) pairs.
            items.sort(key=lambda kv: kv[0])
        except TypeError:
            # Keys of incomparable types: use a deterministic type-aware order.
            items.sort(key=lambda kv: (type(kv[0]).__qualname__, repr(kv[0])))
        return tuple(items)
    elif isinstance(obj, set):
        return frozenset(make_hashable(item) for item in obj)
    else:
        # Use JAX's tree_util for any other pytree structures
        try:
            leaves, treedef = jax.tree.flatten(obj)
            return treedef, tuple(leaves)
        except (TypeError, ValueError):
            # Assume obj is already hashable
            return obj
|
1880
|
+
|
1881
|
+
|
1882
|
+
class IdentitySet(MutableSet):
    """Set that compares objects by identity.

    This is a set that compares objects by identity instead of equality. It is
    useful for storing objects that are not hashable or that should be compared
    by identity.

    This is a mutable set, but it does not support the ``__hash__`` method and
    therefore cannot be used as a dictionary key or as an element of another set.

    Parameters
    ----------
    iterable : iterable, optional
        Initial elements, added with :meth:`update`.
    """

    def __init__(self, iterable=None):
        # id(value) -> value; holding the value keeps its id from being reused.
        self._data = {}
        if iterable is not None:
            self.update(iterable)

    def update(self, iterable):
        """Add every element of ``iterable`` (``MutableSet`` provides no ``update``)."""
        for value in iterable:
            self.add(value)

    def __contains__(self, value):
        return id(value) in self._data

    def __iter__(self):
        return iter(self._data.values())

    def __len__(self):
        return len(self._data)

    def add(self, value):
        self._data[id(value)] = value

    def discard(self, value):
        self._data.pop(id(value), None)

    def __repr__(self):
        return f"IdentitySet({list(repr(x) for x in self._data.values())})"

    def __str__(self):
        return f"IdentitySet({list(str(x) for x in self._data.values())})"
|
1918
|
+
|
1919
|
+
|
1920
|
+
def constant_fold_jaxpr(jaxpr: Jaxpr):
    """
    Return a constant-folded copy of ``jaxpr``.

    Folding starts with no known values, so only equations whose inputs are
    literals, or outputs of already-folded equations, are evaluated.
    Primitives listed in ``_constant_fold_blacklist`` are never evaluated.
    """
    no_known_values = {}
    return _partial_eval_jaxpr(jaxpr, no_known_values)
|
1925
|
+
|
1926
|
+
|
1927
|
+
def _partial_eval_jaxpr(jaxpr, env):
    """
    Constant-fold ``jaxpr`` given an initial environment of known values.

    Parameters
    ----------
    jaxpr : Jaxpr
        The (open) jaxpr to fold.
    env : dict
        Maps variables of ``jaxpr`` to known concrete values. It is copied,
        never mutated.

    Returns
    -------
    Jaxpr
        A new jaxpr in which:

        - every equation whose inputs are all known, and whose primitive is not
          in ``_constant_fold_blacklist``, has been evaluated;
        - a ``closed_call`` is replaced by the inlined equations of its folded
          body;
        - known inputs of the remaining equations, and known outputs, are
          replaced by ``Literal`` s;
        - input variables no longer referenced are dropped;
        - ``debug_info`` is cleared.
    """
    env = env.copy()
    new_eqns = []

    def read(var):
        if isinstance(var, Literal):
            return var.val
        else:
            return env.get(var, None)

    def read_or_self(var):
        out = read(var)
        if out is None:
            return var
        elif isinstance(out, Var):
            return out
        elif isinstance(out, Literal):
            return Literal(out.val, var.aval)
        else:
            assert not isinstance(out, Jaxpr)
            return Literal(out, var.aval)

    def is_constant(val):
        # A ``Var`` in ``env`` comes from inlining a partially folded call: it
        # is a symbolic value, not a constant, and must not be fed to ``bind``.
        return val is not None and not isinstance(val, Var)

    for eqn in jaxpr.eqns:
        vals = [read(var) for var in eqn.invars]
        if eqn.primitive.name in _constant_fold_blacklist:
            new_eqns.append(eqn)
        elif all(is_constant(val) for val in vals):
            # go ahead and eval it
            out = _eval_eqn(eqn, vals)

            # two options: either it's a jaxpr result (partial eval) or it's a value or a list of values
            if isinstance(out, Jaxpr):
                # we need to inline this
                new_eqns.extend(out.eqns)
                out = out.outvars
            elif not isinstance(out, (tuple, list)):
                out = (out,)

            for var, val in zip(eqn.outvars, out):
                assert not isinstance(val, Jaxpr)
                if isinstance(val, Literal):
                    env[var] = val.val
                else:
                    env[var] = val
        else:
            new_eqns.append(eqn)

    # now that we've eval everything, inline all the constants
    out_eqns = [
        eqn.replace(invars=tuple(read_or_self(var) for var in eqn.invars))
        for eqn in new_eqns
    ]

    # sub in any constants for outvars
    outvars = tuple(read_or_self(var) for var in jaxpr.outvars)

    # An input is still needed if an equation reads it OR it is returned
    # directly; otherwise ``outvars`` could reference a dropped input.
    invars_still_used = IdentitySet()
    for eqn in out_eqns:
        for var in eqn.invars:
            invars_still_used.add(var)
    for var in outvars:
        invars_still_used.add(var)

    invars = tuple(var for var in jaxpr.invars if var in invars_still_used)

    return jaxpr.replace(eqns=out_eqns, outvars=outvars, invars=invars, debug_info=None)
|
1991
|
+
|
1992
|
+
|
1993
|
+
def _eval_eqn(eqn, vals) -> Union[Jaxpr, tuple, list, jax.Array]:
    """
    Evaluate a single equation whose inputs are all known.

    Parameters
    ----------
    eqn : JaxprEqn
        The equation to evaluate.
    vals : sequence
        Concrete values for ``eqn.invars``, in order.

    Returns
    -------
    Jaxpr or tuple or list or jax.Array
        For ``closed_call``, the callee is constant-folded with
        ``_partial_eval_jaxpr`` and the resulting ``Jaxpr`` is returned so the
        caller can inline it. For every other primitive (including ``scan``),
        whatever ``eqn.primitive.bind`` returns.
    """
    if eqn.primitive.name == "closed_call":
        assert eqn.primitive.call_primitive
        assert not eqn.primitive.map_primitive

        closed = eqn.params['call_jaxpr']
        inner = closed.jaxpr
        # Seed the callee's constants as well as its inputs: otherwise any
        # inlined equation using a constvar would reference an undefined var.
        inner_env = dict(zip(inner.constvars, closed.consts))
        inner_env.update(zip(inner.invars, vals))
        return _partial_eval_jaxpr(inner, inner_env)

    return eqn.primitive.bind(*vals, **eqn.params)
|
2011
|
+
|
2012
|
+
|
2013
|
+
# Primitive names that ``_partial_eval_jaxpr`` never evaluates eagerly, even
# when all of their inputs are known; their equations are kept as-is.
# NOTE(review): presumably excluded so that broadcasting a small constant does
# not materialize a large array literal in the folded jaxpr — confirm.
_constant_fold_blacklist = {
    'broadcast',
    'broadcast_in_dim',
}
|