brainstate-0.1.8-py2.py3-none-any.whl → brainstate-0.1.10-py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +58 -51
- brainstate/_compatible_import.py +148 -148
- brainstate/_state.py +1605 -1663
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/augment/__init__.py +30 -30
- brainstate/augment/_autograd.py +778 -778
- brainstate/augment/_autograd_test.py +1289 -1289
- brainstate/augment/_eval_shape.py +99 -99
- brainstate/augment/_eval_shape_test.py +38 -38
- brainstate/augment/_mapping.py +1060 -1060
- brainstate/augment/_mapping_test.py +597 -597
- brainstate/augment/_random.py +151 -151
- brainstate/compile/__init__.py +38 -38
- brainstate/compile/_ad_checkpoint.py +204 -204
- brainstate/compile/_ad_checkpoint_test.py +49 -49
- brainstate/compile/_conditions.py +256 -256
- brainstate/compile/_conditions_test.py +220 -220
- brainstate/compile/_error_if.py +92 -92
- brainstate/compile/_error_if_test.py +52 -52
- brainstate/compile/_jit.py +346 -346
- brainstate/compile/_jit_test.py +143 -143
- brainstate/compile/_loop_collect_return.py +536 -536
- brainstate/compile/_loop_collect_return_test.py +58 -58
- brainstate/compile/_loop_no_collection.py +184 -184
- brainstate/compile/_loop_no_collection_test.py +50 -50
- brainstate/compile/_make_jaxpr.py +888 -888
- brainstate/compile/_make_jaxpr_test.py +156 -156
- brainstate/compile/_progress_bar.py +202 -202
- brainstate/compile/_unvmap.py +159 -159
- brainstate/compile/_util.py +147 -147
- brainstate/environ.py +563 -563
- brainstate/environ_test.py +62 -62
- brainstate/functional/__init__.py +27 -26
- brainstate/graph/__init__.py +29 -29
- brainstate/graph/_graph_node.py +244 -244
- brainstate/graph/_graph_node_test.py +73 -73
- brainstate/graph/_graph_operation.py +1738 -1738
- brainstate/graph/_graph_operation_test.py +563 -563
- brainstate/init/__init__.py +26 -26
- brainstate/init/_base.py +52 -52
- brainstate/init/_generic.py +244 -244
- brainstate/init/_random_inits.py +553 -553
- brainstate/init/_random_inits_test.py +149 -149
- brainstate/init/_regular_inits.py +105 -105
- brainstate/init/_regular_inits_test.py +50 -50
- brainstate/mixin.py +365 -363
- brainstate/mixin_test.py +77 -73
- brainstate/nn/__init__.py +135 -131
- brainstate/{functional → nn}/_activations.py +808 -813
- brainstate/{functional → nn}/_activations_test.py +331 -331
- brainstate/nn/_collective_ops.py +514 -514
- brainstate/nn/_collective_ops_test.py +43 -43
- brainstate/nn/_common.py +178 -178
- brainstate/nn/_conv.py +501 -501
- brainstate/nn/_conv_test.py +238 -238
- brainstate/nn/_delay.py +588 -502
- brainstate/nn/_delay_test.py +238 -184
- brainstate/nn/_dropout.py +426 -426
- brainstate/nn/_dropout_test.py +100 -100
- brainstate/nn/_dynamics.py +1343 -1343
- brainstate/nn/_dynamics_test.py +78 -78
- brainstate/nn/_elementwise.py +1119 -1119
- brainstate/nn/_elementwise_test.py +169 -169
- brainstate/nn/_embedding.py +58 -58
- brainstate/nn/_exp_euler.py +92 -92
- brainstate/nn/_exp_euler_test.py +35 -35
- brainstate/nn/_fixedprob.py +239 -239
- brainstate/nn/_fixedprob_test.py +114 -114
- brainstate/nn/_inputs.py +608 -608
- brainstate/nn/_linear.py +424 -424
- brainstate/nn/_linear_mv.py +83 -83
- brainstate/nn/_linear_mv_test.py +120 -120
- brainstate/nn/_linear_test.py +107 -107
- brainstate/nn/_ltp.py +28 -28
- brainstate/nn/_module.py +377 -377
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_neuron.py +705 -705
- brainstate/nn/_neuron_test.py +161 -161
- brainstate/nn/_normalizations.py +975 -918
- brainstate/nn/_normalizations_test.py +73 -73
- brainstate/{functional → nn}/_others.py +46 -46
- brainstate/nn/_poolings.py +1177 -1177
- brainstate/nn/_poolings_test.py +217 -217
- brainstate/nn/_projection.py +486 -486
- brainstate/nn/_rate_rnns.py +554 -554
- brainstate/nn/_rate_rnns_test.py +63 -63
- brainstate/nn/_readout.py +209 -209
- brainstate/nn/_readout_test.py +53 -53
- brainstate/nn/_stp.py +236 -236
- brainstate/nn/_synapse.py +505 -505
- brainstate/nn/_synapse_test.py +131 -131
- brainstate/nn/_synaptic_projection.py +423 -423
- brainstate/nn/_synouts.py +162 -162
- brainstate/nn/_synouts_test.py +57 -57
- brainstate/nn/_utils.py +89 -89
- brainstate/nn/metrics.py +388 -388
- brainstate/optim/__init__.py +38 -38
- brainstate/optim/_base.py +64 -64
- brainstate/optim/_lr_scheduler.py +448 -448
- brainstate/optim/_lr_scheduler_test.py +50 -50
- brainstate/optim/_optax_optimizer.py +152 -152
- brainstate/optim/_optax_optimizer_test.py +53 -53
- brainstate/optim/_sgd_optimizer.py +1104 -1104
- brainstate/random/__init__.py +24 -24
- brainstate/random/_rand_funs.py +3616 -3616
- brainstate/random/_rand_funs_test.py +567 -567
- brainstate/random/_rand_seed.py +210 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1409 -1409
- brainstate/random/_random_for_unit.py +52 -52
- brainstate/surrogate.py +1957 -1957
- brainstate/transform.py +23 -23
- brainstate/typing.py +304 -304
- brainstate/util/__init__.py +50 -50
- brainstate/util/caller.py +98 -98
- brainstate/util/error.py +55 -55
- brainstate/util/filter.py +469 -469
- brainstate/util/others.py +540 -540
- brainstate/util/pretty_pytree.py +945 -945
- brainstate/util/pretty_pytree_test.py +159 -159
- brainstate/util/pretty_repr.py +328 -328
- brainstate/util/pretty_table.py +2954 -2954
- brainstate/util/scaling.py +258 -258
- brainstate/util/struct.py +523 -523
- {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/METADATA +91 -99
- brainstate-0.1.10.dist-info/RECORD +130 -0
- {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/WHEEL +1 -1
- {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info/licenses}/LICENSE +202 -202
- brainstate/functional/_normalization.py +0 -81
- brainstate/functional/_spikes.py +0 -204
- brainstate-0.1.8.dist-info/RECORD +0 -132
- {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/top_level.txt +0 -0
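For orientation before the line-by-line diff of `brainstate/compile/_make_jaxpr.py` below, here is a minimal usage sketch of the `brainstate.compile.make_jaxpr` API defined in that file, adapted from the doctest embedded in its docstring (it assumes `brainstate` and `jax` are installed; the exact jaxpr printout may vary across versions):

```python
import jax
import brainstate

def f(x):
    # A stateless example function; make_jaxpr also records any brainstate
    # State objects read or written during tracing.
    return jax.numpy.sin(jax.numpy.cos(x))

# Returns the ClosedJaxpr plus the tuple of State objects the function touched
# (expected to be empty for this stateless f).
jaxpr, states = brainstate.compile.make_jaxpr(f)(3.0)
print(jaxpr)   # { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
print(states)  # ()
```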
@@ -1,888 +1,888 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
"""
|
17
|
-
This module implements how to create a JAX Jaxpr from a given function by considering the states that are read and
|
18
|
-
written by the function. These state transformations are foundational for the BrainCore library. These utilities
|
19
|
-
include two basic functions: `StatefulFunction` and `make_jaxpr`.
|
20
|
-
|
21
|
-
|
22
|
-
``StatefulFunction``
|
23
|
-
--------------------
|
24
|
-
|
25
|
-
The module provides a class called ``StatefulFunction`` that wraps a function and provides methods to get the
|
26
|
-
JAX Jaxpr, the output shapes, the states that are read and written by the function, and the output of the function.
|
27
|
-
The class provides the following methods:
|
28
|
-
|
29
|
-
- `make_jaxpr`: creates the JAX Jaxpr of the function.
|
30
|
-
- `jaxpr_call`: calls the function at the JAX Jaxpr level.
|
31
|
-
- `jaxpr_call_without_states`: calls the function at the JAX Jaxpr level without considering the states.
|
32
|
-
- `get_states`: returns the states that are read and written by the function.
|
33
|
-
- `get_read_states`: returns the states that are read by the function.
|
34
|
-
- `get_write_states`: returns the states that are written by the function.
|
35
|
-
- `get_static_args`: returns the static arguments from the arguments.
|
36
|
-
- `compile_and_get_states_by_static_args`: compiles the function and returns the states that are read and
|
37
|
-
written by the function.
|
38
|
-
- `get_jaxpr`: returns the JAX Jaxpr of the function.
|
39
|
-
- `get_out_shapes`: returns the output shapes of the function.
|
40
|
-
- `get_out_treedef`: returns the output tree of the function.
|
41
|
-
|
42
|
-
``make_jaxpr``
|
43
|
-
--------------
|
44
|
-
|
45
|
-
The module provides a function called `make_jaxpr` that creates a function that produces its JAX Jaxpr given example
|
46
|
-
arguments. The function returns a wrapped version of the function that when applied to example arguments returns a
|
47
|
-
`ClosedJaxpr` representation of the function on those arguments. If the argument `return_shape` is `True`, then the
|
48
|
-
returned function instead returns a pair where the first element is the `ClosedJaxpr` representation of the function
|
49
|
-
and the second element is a pytree representing the structure, shape, dtypes, and named shapes of the output of the
|
50
|
-
function.
|
51
|
-
|
52
|
-
"""
|
53
|
-
|
54
|
-
import functools
|
55
|
-
import inspect
|
56
|
-
import operator
|
57
|
-
from collections.abc import Hashable, Iterable, Sequence
|
58
|
-
from contextlib import ExitStack
|
59
|
-
from typing import Any, Callable, Tuple, Union, Dict, Optional
|
60
|
-
|
61
|
-
import jax
|
62
|
-
from jax._src import source_info_util
|
63
|
-
from jax._src.linear_util import annotate
|
64
|
-
from jax._src.traceback_util import api_boundary
|
65
|
-
from jax.api_util import shaped_abstractify
|
66
|
-
from jax.extend.linear_util import transformation_with_aux
|
67
|
-
from jax.interpreters import partial_eval as pe
|
68
|
-
|
69
|
-
from brainstate._compatible_import import (
|
70
|
-
ClosedJaxpr,
|
71
|
-
extend_axis_env_nd,
|
72
|
-
safe_map,
|
73
|
-
safe_zip,
|
74
|
-
unzip2,
|
75
|
-
wraps,
|
76
|
-
wrap_init,
|
77
|
-
)
|
78
|
-
from brainstate._state import State, StateTraceStack
|
79
|
-
from brainstate._utils import set_module_as
|
80
|
-
from brainstate.typing import PyTree
|
81
|
-
from brainstate.util import PrettyObject
|
82
|
-
|
83
|
-
AxisName = Hashable
|
84
|
-
|
85
|
-
__all__ = [
|
86
|
-
"StatefulFunction",
|
87
|
-
"make_jaxpr",
|
88
|
-
]
|
89
|
-
|
90
|
-
|
91
|
-
def _ensure_str(x: str) -> str:
|
92
|
-
if not isinstance(x, str):
|
93
|
-
raise TypeError(f"argument is not a string: {x}")
|
94
|
-
return x
|
95
|
-
|
96
|
-
|
97
|
-
def _ensure_index_tuple(x: Any) -> tuple[int, ...]:
|
98
|
-
"""Convert x to a tuple of indices."""
|
99
|
-
x = jax.core.concrete_or_error(None, x, "expected a static index or sequence of indices.")
|
100
|
-
try:
|
101
|
-
return (operator.index(x),)
|
102
|
-
except TypeError:
|
103
|
-
return tuple(safe_map(operator.index, x))
|
104
|
-
|
105
|
-
|
106
|
-
def _ensure_str_tuple(x: str | Iterable[str]) -> tuple[str, ...]:
|
107
|
-
"""Convert x to a tuple of strings."""
|
108
|
-
if isinstance(x, str):
|
109
|
-
return (x,)
|
110
|
-
else:
|
111
|
-
return tuple(safe_map(_ensure_str, x))
|
112
|
-
|
113
|
-
|
114
|
-
def _jax_v04_new_arg_fn(frame, trace, aval):
|
115
|
-
"""
|
116
|
-
Transform a new argument to a tracer.
|
117
|
-
|
118
|
-
Modified from jax.interpreters.partial_eval.DynamicJaxprTrace.new_arg()
|
119
|
-
|
120
|
-
Args:
|
121
|
-
frame: The frame.
|
122
|
-
trace: The trace.
|
123
|
-
aval: The abstract value.
|
124
|
-
|
125
|
-
Returns:
|
126
|
-
The tracer.
|
127
|
-
"""
|
128
|
-
tracer = pe.DynamicJaxprTracer(trace, aval, source_info_util.current())
|
129
|
-
frame.tracers.append(tracer)
|
130
|
-
frame.tracer_to_var[id(tracer)] = var = frame.newvar(aval)
|
131
|
-
frame.invars.append(var)
|
132
|
-
return tracer
|
133
|
-
|
134
|
-
|
135
|
-
def _jax_v04_new_jax_trace():
|
136
|
-
main = jax.core.thread_local_state.trace_state.trace_stack.stack[-1]
|
137
|
-
frame = main.jaxpr_stack[-1]
|
138
|
-
trace = pe.DynamicJaxprTrace(main, jax.core.cur_sublevel())
|
139
|
-
return frame, trace
|
140
|
-
|
141
|
-
|
142
|
-
def _jax_v04_new_arg():
|
143
|
-
# Should be within the calling of ``jax.make_jaxpr()``
|
144
|
-
frame, trace = _jax_v04_new_jax_trace()
|
145
|
-
# Set the function to transform the new argument to a tracer
|
146
|
-
fn = functools.partial(_jax_v04_new_arg_fn, frame, trace)
|
147
|
-
return fn
|
148
|
-
|
149
|
-
|
150
|
-
def _jax_new_version_new_arg():
|
151
|
-
trace = jax.core.trace_ctx.trace
|
152
|
-
|
153
|
-
def wrapper(x):
|
154
|
-
if jax.__version_info__ < (0, 6, 1):
|
155
|
-
return trace.new_arg(shaped_abstractify(x))
|
156
|
-
else:
|
157
|
-
return trace.new_arg(shaped_abstractify(x), source_info=source_info_util.current())
|
158
|
-
|
159
|
-
return wrapper
|
160
|
-
|
161
|
-
|
162
|
-
def _init_state_trace_stack(name) -> StateTraceStack:
|
163
|
-
state_trace: StateTraceStack = StateTraceStack(name=name)
|
164
|
-
|
165
|
-
if jax.__version_info__ < (0, 4, 36):
|
166
|
-
state_trace.set_new_arg(_jax_v04_new_arg())
|
167
|
-
else:
|
168
|
-
state_trace.set_new_arg(_jax_new_version_new_arg())
|
169
|
-
return state_trace
|
170
|
-
|
171
|
-
|
172
|
-
default_cache_key = ((), ())
|
173
|
-
|
174
|
-
|
175
|
-
class StatefulFunction(PrettyObject):
|
176
|
-
"""
|
177
|
-
A wrapper class for a function that collects the states that are read and written by the function. The states are
|
178
|
-
collected by the function and returned as a StateDictManager instance. The StateDictManager instance can be used to
|
179
|
-
manage the states in the JAX program. The class provides a function called `states` that returns the states
|
180
|
-
that are read and written by the function. The class provides a function called `to_state_manager` that returns
|
181
|
-
a StateDictManager instance that contains the states that are read and written by the function. The class provides
|
182
|
-
a function called `__call__` that wraps the function and returns the states that are read and written by the
|
183
|
-
function and the output of the function.
|
184
|
-
|
185
|
-
Args:
|
186
|
-
fun: The function whose ``jaxpr`` is to be computed. Its positional
|
187
|
-
arguments and return value should be arrays, scalars, or standard Python
|
188
|
-
containers (tuple/list/dict) thereof.
|
189
|
-
static_argnums: See the :py:func:`jax.jit` docstring.
|
190
|
-
static_argnames: See the :py:func:`jax.jit` docstring.
|
191
|
-
axis_env: Optional, a sequence of pairs where the first element is an axis
|
192
|
-
name and the second element is a positive integer representing the size of
|
193
|
-
the mapped axis with that name. This parameter is useful when lowering
|
194
|
-
functions that involve parallel communication collectives, and it
|
195
|
-
specifies the axis name/size environment that would be set up by
|
196
|
-
applications of :py:func:`jax.pmap`.
|
197
|
-
abstracted_axes: Optional, a pytree with the same structure as the input
|
198
|
-
arguments to ``fun``. The leaves of the pytree can be either None or a
|
199
|
-
dict with axis names as keys and integers as values. If the leaf is None,
|
200
|
-
then the corresponding axis is not abstracted. If the leaf is a dict, then
|
201
|
-
the corresponding axis is abstracted, and the dict specifies the axis name
|
202
|
-
and size. The abstracted axes are used to infer the input type of the
|
203
|
-
function. If None, then all axes are abstracted.
|
204
|
-
state_returns: Optional, a string or a tuple of strings. The default is
|
205
|
-
``('read', 'write')``. The strings specify the categories of states to be
|
206
|
-
returned by the wrapped function. The categories are ``'read'`` and
|
207
|
-
``'write'``. If the category is ``'read'``, then the wrapped function
|
208
|
-
returns the states that are read by the function. If the category is
|
209
|
-
``'write'``, then the wrapped function returns the states that are written
|
210
|
-
by the function. If the category is ``'read'`` and ``'write'``, then the
|
211
|
-
wrapped function returns both the read and write states.
|
212
|
-
|
213
|
-
"""
|
214
|
-
__module__ = "brainstate.compile"
|
215
|
-
|
216
|
-
def __init__(
|
217
|
-
self,
|
218
|
-
fun: Callable,
|
219
|
-
static_argnums: Union[int, Iterable[int]] = (),
|
220
|
-
static_argnames: Union[str, Iterable[str]] = (),
|
221
|
-
axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
|
222
|
-
abstracted_axes: Optional[Any] = None,
|
223
|
-
state_returns: Union[str, Tuple[str, ...]] = ('read', 'write'),
|
224
|
-
cache_type: Optional[str] = None,
|
225
|
-
name: Optional[str] = None,
|
226
|
-
):
|
227
|
-
# explicit parameters
|
228
|
-
self.fun = fun
|
229
|
-
self.static_argnums = tuple() if static_argnums is None else _ensure_index_tuple(static_argnums)
|
230
|
-
self.static_argnames = tuple() if static_argnames is None else _ensure_str_tuple(static_argnames)
|
231
|
-
self.axis_env = axis_env
|
232
|
-
self.abstracted_axes = abstracted_axes
|
233
|
-
self.state_returns = tuple(state_returns) if isinstance(state_returns, (tuple, list)) else (state_returns,)
|
234
|
-
assert cache_type in [None, 'jit'], f"Invalid cache type: {cache_type}"
|
235
|
-
self.name = name
|
236
|
-
|
237
|
-
# implicit parameters
|
238
|
-
self.cache_type = cache_type
|
239
|
-
self._cached_jaxpr: Dict[Any, ClosedJaxpr] = dict()
|
240
|
-
self._cached_out_shapes: Dict[Any, PyTree] = dict()
|
241
|
-
self._cached_jaxpr_out_tree: Dict[Any, PyTree] = dict()
|
242
|
-
self._cached_state_trace: Dict[Any, StateTraceStack] = dict()
|
243
|
-
|
244
|
-
def __pretty_repr_item__(self, k, v):
|
245
|
-
if k.startswith('_'):
|
246
|
-
return None
|
247
|
-
return k, v
|
248
|
-
|
249
|
-
def get_jaxpr(self, cache_key: Hashable = None) -> ClosedJaxpr:
|
250
|
-
"""
|
251
|
-
Read the JAX Jaxpr representation of the function.
|
252
|
-
|
253
|
-
Args:
|
254
|
-
cache_key: The hashable key.
|
255
|
-
|
256
|
-
Returns:
|
257
|
-
The JAX Jaxpr representation of the function.
|
258
|
-
"""
|
259
|
-
if cache_key is None:
|
260
|
-
cache_key = default_cache_key
|
261
|
-
if cache_key not in self._cached_jaxpr:
|
262
|
-
raise ValueError(f"the function is not called with the static arguments: {cache_key}")
|
263
|
-
return self._cached_jaxpr[cache_key]
|
264
|
-
|
265
|
-
def get_out_shapes(self, cache_key: Hashable = None) -> PyTree:
|
266
|
-
"""
|
267
|
-
Read the output shapes of the function.
|
268
|
-
|
269
|
-
Args:
|
270
|
-
cache_key: The hashable key.
|
271
|
-
|
272
|
-
Returns:
|
273
|
-
The output shapes of the function.
|
274
|
-
"""
|
275
|
-
if cache_key is None:
|
276
|
-
cache_key = default_cache_key
|
277
|
-
if cache_key not in self._cached_out_shapes:
|
278
|
-
raise ValueError(f"the function is not called with the static arguments: {cache_key}")
|
279
|
-
return self._cached_out_shapes[cache_key]
|
280
|
-
|
281
|
-
def get_out_treedef(self, cache_key: Hashable = None) -> PyTree:
|
282
|
-
"""
|
283
|
-
Read the output tree of the function.
|
284
|
-
|
285
|
-
Args:
|
286
|
-
cache_key: The hashable key.
|
287
|
-
|
288
|
-
Returns:
|
289
|
-
The output tree of the function.
|
290
|
-
"""
|
291
|
-
if cache_key is None:
|
292
|
-
cache_key = default_cache_key
|
293
|
-
if cache_key not in self._cached_jaxpr_out_tree:
|
294
|
-
raise ValueError(f"the function is not called with the static arguments: {cache_key}")
|
295
|
-
return self._cached_jaxpr_out_tree[cache_key]
|
296
|
-
|
297
|
-
def get_state_trace(self, cache_key: Hashable = None) -> StateTraceStack:
|
298
|
-
"""
|
299
|
-
Read the state trace of the function.
|
300
|
-
|
301
|
-
Args:
|
302
|
-
cache_key: The hashable key.
|
303
|
-
|
304
|
-
Returns:
|
305
|
-
The state trace of the function.
|
306
|
-
"""
|
307
|
-
if cache_key is None:
|
308
|
-
cache_key = default_cache_key
|
309
|
-
if cache_key not in self._cached_state_trace:
|
310
|
-
raise ValueError(f"the function is not called with the static arguments: {cache_key}")
|
311
|
-
return self._cached_state_trace[cache_key]
|
312
|
-
|
313
|
-
def get_states(self, cache_key: Hashable = None) -> Tuple[State, ...]:
|
314
|
-
"""
|
315
|
-
Read the states that are read and written by the function.
|
316
|
-
|
317
|
-
Args:
|
318
|
-
cache_key: The hashable key.
|
319
|
-
|
320
|
-
Returns:
|
321
|
-
The states that are read and written by the function.
|
322
|
-
"""
|
323
|
-
if cache_key is None:
|
324
|
-
cache_key = default_cache_key
|
325
|
-
return tuple(self.get_state_trace(cache_key).states)
|
326
|
-
|
327
|
-
def get_read_states(self, cache_key: Hashable = None) -> Tuple[State, ...]:
|
328
|
-
"""
|
329
|
-
Read the states that are read by the function.
|
330
|
-
|
331
|
-
Args:
|
332
|
-
cache_key: The hashable key.
|
333
|
-
|
334
|
-
Returns:
|
335
|
-
The states that are read by the function.
|
336
|
-
"""
|
337
|
-
if cache_key is None:
|
338
|
-
cache_key = default_cache_key
|
339
|
-
return self.get_state_trace(cache_key).get_read_states()
|
340
|
-
|
341
|
-
def get_write_states(self, cache_key: Hashable = None) -> Tuple[State, ...]:
|
342
|
-
"""
|
343
|
-
Read the states that are written by the function.
|
344
|
-
|
345
|
-
Args:
|
346
|
-
cache_key: The hashable key.
|
347
|
-
|
348
|
-
Returns:
|
349
|
-
The states that are written by the function.
|
350
|
-
"""
|
351
|
-
if cache_key is None:
|
352
|
-
cache_key = default_cache_key
|
353
|
-
return self.get_state_trace(cache_key).get_write_states()
|
354
|
-
|
355
|
-
def _check_input_ouput(self, x):
|
356
|
-
if isinstance(x, State):
|
357
|
-
x.raise_error_with_source_info(
|
358
|
-
ValueError(
|
359
|
-
'Inputs/outputs for brainstate transformations cannot be an instance of State. '
|
360
|
-
f'But we got {x}'
|
361
|
-
)
|
362
|
-
)
|
363
|
-
|
364
|
-
def get_arg_cache_key(self, *args, **kwargs) -> Tuple:
|
365
|
-
"""
|
366
|
-
Get the static arguments from the arguments.
|
367
|
-
|
368
|
-
Args:
|
369
|
-
*args: The arguments to the function.
|
370
|
-
**kwargs: The keyword arguments to the function.
|
371
|
-
|
372
|
-
Returns:
|
373
|
-
The static arguments and keyword arguments as a tuple.
|
374
|
-
"""
|
375
|
-
if self.cache_type == 'jit':
|
376
|
-
static_args, dyn_args = [], []
|
377
|
-
for i, arg in enumerate(args):
|
378
|
-
if i in self.static_argnums:
|
379
|
-
static_args.append(arg)
|
380
|
-
else:
|
381
|
-
dyn_args.append(arg)
|
382
|
-
dyn_args = jax.tree.map(shaped_abstractify, dyn_args)
|
383
|
-
static_kwargs, dyn_kwargs = [], []
|
384
|
-
for k, v in kwargs.items():
|
385
|
-
if k in self.static_argnames:
|
386
|
-
static_kwargs.append((k, v))
|
387
|
-
else:
|
388
|
-
dyn_kwargs.append((k, jax.tree.map(shaped_abstractify, v)))
|
389
|
-
|
390
|
-
static_args = make_hashable(tuple(static_args))
|
391
|
-
dyn_args = make_hashable(tuple(dyn_args))
|
392
|
-
static_kwargs = make_hashable(static_kwargs)
|
393
|
-
dyn_kwargs = make_hashable(dyn_kwargs)
|
394
|
-
|
395
|
-
cache_key = (static_args, dyn_args, static_kwargs, dyn_kwargs)
|
396
|
-
elif self.cache_type is None:
|
397
|
-
num_arg = len(args)
|
398
|
-
static_args = tuple(args[i] for i in self.static_argnums if i < num_arg)
|
399
|
-
static_kwargs = tuple((k, v) for k, v in kwargs.items() if k in self.static_argnames)
|
400
|
-
|
401
|
-
# Make everything hashable
|
402
|
-
static_args = make_hashable(static_args)
|
403
|
-
static_kwargs = make_hashable(static_kwargs)
|
404
|
-
|
405
|
-
cache_key = (static_args, static_kwargs)
|
406
|
-
else:
|
407
|
-
raise ValueError(f"Invalid cache type: {self.cache_type}")
|
408
|
-
|
409
|
-
return cache_key
|
410
|
-
|
411
|
-
def compile_function_and_get_states(self, *args, **kwargs) -> Tuple[State, ...]:
|
412
|
-
"""
|
413
|
-
Compile the function, and get the states that are read and written by this function.
|
414
|
-
|
415
|
-
Args:
|
416
|
-
*args: The arguments to the function.
|
417
|
-
**kwargs: The keyword arguments to the function.
|
418
|
-
|
419
|
-
Returns:
|
420
|
-
The states that are read and written by the function.
|
421
|
-
"""
|
422
|
-
cache_key = self.get_arg_cache_key(*args, **kwargs)
|
423
|
-
if cache_key not in self._cached_state_trace:
|
424
|
-
self.make_jaxpr(*args, **kwargs)
|
425
|
-
return self.get_states(cache_key)
|
426
|
-
|
427
|
-
def compile_function_and_get_state_trace(
|
428
|
-
self, *args, return_only_write: bool = False, **kwargs
|
429
|
-
) -> StateTraceStack:
|
430
|
-
"""
|
431
|
-
Compile the function, and get the states that are read and written by this function.
|
432
|
-
|
433
|
-
Args:
|
434
|
-
*args: The arguments to the function.
|
435
|
-
**kwargs: The keyword arguments to the function.
|
436
|
-
return_only_write: If True, only return the states that are written by the function.
|
437
|
-
|
438
|
-
Returns:
|
439
|
-
The state trace stack.
|
440
|
-
"""
|
441
|
-
cache_key = self.get_arg_cache_key(*args, **kwargs)
|
442
|
-
if cache_key not in self._cached_state_trace:
|
443
|
-
self.make_jaxpr(*args, **kwargs, return_only_write=return_only_write)
|
444
|
-
return self.get_state_trace(cache_key)
|
445
|
-
|
446
|
-
def clear_cache(self) -> None:
|
447
|
-
"""
|
448
|
-
Clear the compilation cache.
|
449
|
-
"""
|
450
|
-
self._cached_jaxpr.clear()
|
451
|
-
self._cached_out_shapes.clear()
|
452
|
-
self._cached_jaxpr_out_tree.clear()
|
453
|
-
self._cached_state_trace.clear()
|
454
|
-
|
455
|
-
def _wrapped_fun_to_eval(
|
456
|
-
self, cache_key, static_kwargs: dict, *args, return_only_write: bool = False, **dyn_kwargs,
|
457
|
-
) -> Tuple[Any, Tuple[State, ...]]:
|
458
|
-
"""
|
459
|
-
Wrap the function and return the states that are read and written by the function and the output of the function.
|
460
|
-
|
461
|
-
Args:
|
462
|
-
*args: The arguments to the function.
|
463
|
-
**kwargs: The keyword arguments to the function.
|
464
|
-
|
465
|
-
Returns:
|
466
|
-
A tuple of the states that are read and written by the function and the output of the function.
|
467
|
-
"""
|
468
|
-
# state trace
|
469
|
-
state_trace = _init_state_trace_stack(self.name)
|
470
|
-
self._cached_state_trace[cache_key] = state_trace
|
471
|
-
with state_trace:
|
472
|
-
out = self.fun(*args, **dyn_kwargs, **static_kwargs)
|
473
|
-
state_values = (
|
474
|
-
state_trace.get_write_state_values(True)
|
475
|
-
if return_only_write else
|
476
|
-
state_trace.get_state_values()
|
477
|
-
)
|
478
|
-
state_trace.recovery_original_values()
|
479
|
-
|
480
|
-
# State instance as functional returns is not allowed.
|
481
|
-
# Checking whether the states are returned.
|
482
|
-
jax.tree.map(self._check_input_ouput, out, is_leaf=lambda x: isinstance(x, State))
|
483
|
-
return out, state_values
|
484
|
-
|
485
|
-
def make_jaxpr(self, *args, return_only_write: bool = False, **kwargs):
|
486
|
-
"""Creates a function that produces its jaxpr given example args.
|
487
|
-
|
488
|
-
A ``ClosedJaxpr`` representation of ``fun`` on those arguments. If the
|
489
|
-
argument ``return_shape`` is ``True``, then the returned function instead
|
490
|
-
returns a pair where the first element is the ``ClosedJaxpr``
|
491
|
-
representation of ``fun`` and the second element is a pytree representing
|
492
|
-
the structure, shape, dtypes, and named shapes of the output of ``fun``.
|
493
|
-
|
494
|
-
Args:
|
495
|
-
*args: The arguments to the function.
|
496
|
-
**kwargs: The keyword arguments to the function.
|
497
|
-
return_only_write: If True, only return the states that are written by the function.
|
498
|
-
"""
|
499
|
-
|
500
|
-
# static args
|
501
|
-
cache_key = self.get_arg_cache_key(*args, **kwargs)
|
502
|
-
|
503
|
-
# check input types
|
504
|
-
jax.tree.map(self._check_input_ouput, (args, kwargs), is_leaf=lambda x: isinstance(x, State))
|
505
|
-
|
506
|
-
if cache_key not in self._cached_state_trace:
|
507
|
-
try:
|
508
|
-
# jaxpr
|
509
|
-
static_kwargs, dyn_kwargs = {}, {}
|
510
|
-
for k, v in kwargs.items():
|
511
|
-
if k in self.static_argnames:
|
512
|
-
static_kwargs[k] = v
|
513
|
-
else:
|
514
|
-
dyn_kwargs[k] = v
|
515
|
-
jaxpr, (out_shapes, state_shapes) = _make_jaxpr(
|
516
|
-
functools.partial(
|
517
|
-
self._wrapped_fun_to_eval,
|
518
|
-
cache_key,
|
519
|
-
static_kwargs,
|
520
|
-
return_only_write=return_only_write
|
521
|
-
),
|
522
|
-
static_argnums=self.static_argnums,
|
523
|
-
axis_env=self.axis_env,
|
524
|
-
return_shape=True,
|
525
|
-
abstracted_axes=self.abstracted_axes
|
526
|
-
)(*args, **dyn_kwargs)
|
527
|
-
# returns
|
528
|
-
self._cached_jaxpr_out_tree[cache_key] = jax.tree.structure((out_shapes, state_shapes))
|
529
|
-
self._cached_out_shapes[cache_key] = (out_shapes, state_shapes)
|
530
|
-
self._cached_jaxpr[cache_key] = jaxpr
|
531
|
-
|
532
|
-
except Exception as e:
|
533
|
-
try:
|
534
|
-
self._cached_state_trace.pop(cache_key)
|
535
|
-
except KeyError:
|
536
|
-
pass
|
537
|
-
raise e
|
538
|
-
|
539
|
-
return self
|
540
|
-
|
541
|
-
def jaxpr_call(self, state_vals, *args, **kwargs) -> Any:
|
542
|
-
"""
|
543
|
-
Call the function at the JAX Jaxpr level.
|
544
|
-
|
545
|
-
Args:
|
546
|
-
state_vals: The state values.
|
547
|
-
*args: The arguments to the function.
|
548
|
-
**kwargs: The keyword arguments to the function.
|
549
|
-
|
550
|
-
Returns:
|
551
|
-
State values and the function output.
|
552
|
-
"""
|
553
|
-
# state checking
|
554
|
-
cache_key = self.get_arg_cache_key(*args, **kwargs)
|
555
|
-
states: Sequence[State] = self.get_states(cache_key)
|
556
|
-
assert len(state_vals) == len(states), 'State length mismatch.'
|
557
|
-
|
558
|
-
# parameters
|
559
|
-
kwargs = {k: v for k, v in kwargs.items() if k not in self.static_argnames} # remove static kwargs
|
560
|
-
args = tuple(args[i] for i in range(len(args)) if i not in self.static_argnums)
|
561
|
-
args = jax.tree.flatten((args, kwargs, state_vals))[0]
|
562
|
-
|
563
|
-
# calling the function,
|
564
|
-
# note that this function always returns state values
|
565
|
-
# that both write and read by the function
|
566
|
-
closed_jaxpr = self.get_jaxpr(cache_key)
|
567
|
-
out_treedef = self.get_out_treedef(cache_key)
|
568
|
-
jaxpr_outs = jax.core.eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)
|
569
|
-
|
570
|
-
# output processing
|
571
|
-
out, new_state_vals = out_treedef.unflatten(jaxpr_outs)
|
572
|
-
assert len(new_state_vals) == len(state_vals), 'State length mismatch.'
|
573
|
-
return new_state_vals, out
|
574
|
-
|
575
|
-
def jaxpr_call_auto(self, *args, **kwargs) -> Any:
|
576
|
-
"""
|
577
|
-
Call the function at the JAX Jaxpr level with automatic state management.
|
578
|
-
|
579
|
-
Args:
|
580
|
-
*args: The arguments to the function.
|
581
|
-
**kwargs: The keyword arguments to the function.
|
582
|
-
|
583
|
-
Returns:
|
584
|
-
The output of the function.
|
585
|
-
"""
|
586
|
-
state_trace = self.get_state_trace(self.get_arg_cache_key(*args, **kwargs))
|
587
|
-
state_vals, out = self.jaxpr_call([st.value for st in state_trace.states], *args, **kwargs)
|
588
|
-
state_trace.assign_state_vals(state_vals)
|
589
|
-
return out
|
590
|
-
|
591
|
-
|
592
|
-
@set_module_as("brainstate.compile")
|
593
|
-
def make_jaxpr(
|
594
|
-
fun: Callable,
|
595
|
-
static_argnums: Union[int, Iterable[int]] = (),
|
596
|
-
static_argnames: Union[str, Iterable[str]] = (),
|
597
|
-
axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
|
598
|
-
return_shape: bool = False,
|
599
|
-
abstracted_axes: Optional[Any] = None,
|
600
|
-
state_returns: Union[str, Tuple[str, ...]] = ('read', 'write')
|
601
|
-
) -> Callable[
|
602
|
-
...,
|
603
|
-
(Tuple[ClosedJaxpr, Tuple[State, ...]] |
|
604
|
-
Tuple[ClosedJaxpr, Tuple[State, ...], PyTree])
|
605
|
-
]:
|
606
|
-
"""
|
607
|
-
Creates a function that produces its jaxpr given example args.
|
608
|
-
|
609
|
-
Args:
|
610
|
-
fun: The function whose ``jaxpr`` is to be computed. Its positional
|
611
|
-
arguments and return value should be arrays, scalars, or standard Python
|
612
|
-
containers (tuple/list/dict) thereof.
|
613
|
-
static_argnums: See the :py:func:`jax.jit` docstring.
|
614
|
-
static_argnames: See the :py:func:`jax.jit` docstring.
|
615
|
-
axis_env: Optional, a sequence of pairs where the first element is an axis
|
616
|
-
name and the second element is a positive integer representing the size of
|
617
|
-
the mapped axis with that name. This parameter is useful when lowering
|
618
|
-
functions that involve parallel communication collectives, and it
|
619
|
-
specifies the axis name/size environment that would be set up by
|
620
|
-
applications of :py:func:`jax.pmap`.
|
621
|
-
return_shape: Optional boolean, defaults to ``False``. If ``True``, the
|
622
|
-
wrapped function returns a pair where the first element is the XLA
|
623
|
-
computation and the second element is a pytree with the same structure as
|
624
|
-
the output of ``fun`` and where the leaves are objects with ``shape``,
|
625
|
-
``dtype``, and ``named_shape`` attributes representing the corresponding
|
626
|
-
types of the output leaves.
|
627
|
-
abstracted_axes: Optional, a pytree with the same structure as the input
|
628
|
-
arguments to ``fun``. The leaves of the pytree can be either None or a
|
629
|
-
dict with axis names as keys and integers as values. If the leaf is None,
|
630
|
-
then the corresponding axis is not abstracted. If the leaf is a dict, then
|
631
|
-
the corresponding axis is abstracted, and the dict specifies the axis name
|
632
|
-
and size. The abstracted axes are used to infer the input type of the
|
633
|
-
function. If None, then all axes are abstracted.
|
634
|
-
state_returns: Optional, a string or a tuple of strings. The default is
|
635
|
-
``('read', 'write')``. The strings specify the categories of states to be
|
636
|
-
returned by the wrapped function. The categories are ``'read'`` and
|
637
|
-
``'write'``. If the category is ``'read'``, then the wrapped function
|
638
|
-
returns the states that are read by the function. If the category is
|
639
|
-
``'write'``, then the wrapped function returns the states that are written
|
640
|
-
by the function. If the category is ``'read'`` and ``'write'``, then the
|
641
|
-
wrapped function returns both the read and write states.
|
642
|
-
|
643
|
-
|
644
|
-
Returns:
|
645
|
-
A wrapped version of ``fun`` that when applied to example arguments returns
|
646
|
-
a ``ClosedJaxpr`` representation of ``fun`` on those arguments. If the
|
647
|
-
argument ``return_shape`` is ``True``, then the returned function instead
|
648
|
-
returns a pair where the first element is the ``ClosedJaxpr``
|
649
|
-
representation of ``fun`` and the second element is a pytree representing
|
650
|
-
the structure, shape, dtypes, and named shapes of the output of ``fun``.
|
651
|
-
|
652
|
-
A ``jaxpr`` is JAX's intermediate representation for program traces. The
|
653
|
-
``jaxpr`` language is based on the simply-typed first-order lambda calculus
|
654
|
-
with let-bindings. :py:func:`make_jaxpr` adapts a function to return its
|
655
|
-
``jaxpr``, which we can inspect to understand what JAX is doing internally.
|
656
|
-
The ``jaxpr`` returned is a trace of ``fun`` abstracted to
|
657
|
-
:py:class:`ShapedArray` level. Other levels of abstraction exist internally.
|
658
|
-
|
659
|
-
We do not describe the semantics of the ``jaxpr`` language in detail here, but
|
660
|
-
instead give a few examples.
|
661
|
-
|
662
|
-
>>> import jax
|
663
|
-
>>> import brainstate as brainstate
|
664
|
-
>>>
|
665
|
-
>>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
|
666
|
-
>>> print(f(3.0))
|
667
|
-
-0.83602
|
668
|
-
>>> jaxpr, states = brainstate.compile.make_jaxpr(f)(3.0)
|
669
|
-
>>> jaxpr
|
670
|
-
{ lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
|
671
|
-
>>> jaxpr, states = brainstate.compile.make_jaxpr(jax.grad(f))(3.0)
|
672
|
-
>>> jaxpr
|
673
|
-
{ lambda ; a:f32[]. let
|
674
|
-
b:f32[] = cos a
|
675
|
-
c:f32[] = sin a
|
676
|
-
_:f32[] = sin b
|
677
|
-
d:f32[] = cos b
|
678
|
-
e:f32[] = mul 1.0 d
|
679
|
-
f:f32[] = neg e
|
680
|
-
g:f32[] = mul f c
|
681
|
-
in (g,) }
|
682
|
-
"""
|
683
|
-
|
684
|
-
stateful_fun = StatefulFunction(
|
685
|
-
fun,
|
686
|
-
static_argnums=static_argnums,
|
687
|
-
static_argnames=static_argnames,
|
688
|
-
axis_env=axis_env,
|
689
|
-
abstracted_axes=abstracted_axes,
|
690
|
-
state_returns=state_returns,
|
691
|
-
name='make_jaxpr'
|
692
|
-
)
|
693
|
-
|
694
|
-
@wraps(fun)
|
695
|
-
def make_jaxpr_f(*args, **kwargs):
|
696
|
-
stateful_fun.make_jaxpr(*args, **kwargs)
|
697
|
-
cache_key = stateful_fun.get_arg_cache_key(*args, **kwargs)
|
698
|
-
if return_shape:
|
699
|
-
return (stateful_fun.get_jaxpr(cache_key),
|
700
|
-
stateful_fun.get_states(cache_key),
|
701
|
-
stateful_fun.get_out_shapes(cache_key)[0])
|
702
|
-
else:
|
703
|
-
return (stateful_fun.get_jaxpr(cache_key),
|
704
|
-
stateful_fun.get_states(cache_key))
|
705
|
-
|
706
|
-
# wrapped jaxpr builder function
|
707
|
-
make_jaxpr_f.__module__ = "brainstate.compile"
|
708
|
-
if hasattr(fun, "__qualname__"):
|
709
|
-
make_jaxpr_f.__qualname__ = f"make_jaxpr({fun.__qualname__})"
|
710
|
-
if hasattr(fun, "__name__"):
|
711
|
-
make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
|
712
|
-
return make_jaxpr_f
|
713
|
-
|
714
|
-
|
715
|
-
def _check_callable(fun):
|
716
|
-
# In Python 3.10+, the only thing stopping us from supporting staticmethods
|
717
|
-
# is that we can't take weak references to them, which the C++ JIT requires.
|
718
|
-
if isinstance(fun, staticmethod):
|
719
|
-
raise TypeError(f"staticmethod arguments are not supported, got {fun}")
|
720
|
-
if not callable(fun):
|
721
|
-
raise TypeError(f"Expected a callable value, got {fun}")
|
722
|
-
if inspect.isgeneratorfunction(fun):
|
723
|
-
raise TypeError(f"Expected a function, got a generator function: {fun}")
|
724
|
-
|
725
|
-
|
726
|
-
def _broadcast_prefix(
|
727
|
-
prefix_tree: Any,
|
728
|
-
full_tree: Any,
|
729
|
-
is_leaf: Callable[[Any], bool] | None = None
|
730
|
-
) -> list[Any]:
|
731
|
-
# If prefix_tree is not a tree prefix of full_tree, this code can raise a
|
732
|
-
# ValueError; use prefix_errors to find disagreements and raise more precise
|
733
|
-
# error messages.
|
734
|
-
result = []
|
735
|
-
num_leaves = lambda t: jax.tree.structure(t).num_leaves
|
736
|
-
add_leaves = lambda x, subtree: result.extend([x] * num_leaves(subtree))
|
737
|
-
jax.tree.map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf)
|
738
|
-
return result
|
739
|
-
|
740
|
-
|
741
|
-
def _flat_axes_specs(
|
742
|
-
abstracted_axes, *args, **kwargs
|
743
|
-
) -> list[pe.AbstractedAxesSpec]:
|
744
|
-
if kwargs:
|
745
|
-
raise NotImplementedError
|
746
|
-
|
747
|
-
def ax_leaf(l):
|
748
|
-
return (isinstance(l, dict) and jax.tree_util.all_leaves(l.values()) or
|
749
|
-
isinstance(l, tuple) and jax.tree_util.all_leaves(l, lambda x: x is None))
|
750
|
-
|
751
|
-
return _broadcast_prefix(abstracted_axes, args, ax_leaf)
|
752
|
-
|
753
|
-
|
754
|
-
@transformation_with_aux
|
755
|
-
def _flatten_fun(in_tree, *args_flat):
|
756
|
-
py_args, py_kwargs = jax.tree.unflatten(in_tree, args_flat)
|
757
|
-
ans = yield py_args, py_kwargs
|
758
|
-
yield jax.tree.flatten(ans)
|
759
|
-
|
760
|
-
|
761
|
-
def _make_jaxpr(
|
762
|
-
fun: Callable,
|
763
|
-
static_argnums: int | Iterable[int] = (),
|
764
|
-
axis_env: Sequence[tuple[AxisName, int]] | None = None,
|
765
|
-
return_shape: bool = False,
|
766
|
-
abstracted_axes: Any | None = None,
|
767
|
-
) -> Callable[..., (ClosedJaxpr | tuple[ClosedJaxpr, Any])]:
|
768
|
-
"""Creates a function that produces its jaxpr given example args.
|
769
|
-
|
770
|
-
Args:
|
771
|
-
fun: The function whose ``jaxpr`` is to be computed. Its positional
|
772
|
-
arguments and return value should be arrays, scalars, or standard Python
|
773
|
-
containers (tuple/list/dict) thereof.
|
774
|
-
static_argnums: See the :py:func:`jax.jit` docstring.
|
775
|
-
axis_env: Optional, a sequence of pairs where the first element is an axis
|
776
|
-
name and the second element is a positive integer representing the size of
|
777
|
-
the mapped axis with that name. This parameter is useful when lowering
|
778
|
-
functions that involve parallel communication collectives, and it
|
779
|
-
specifies the axis name/size environment that would be set up by
|
780
|
-
applications of :py:func:`jax.pmap`.
|
781
|
-
return_shape: Optional boolean, defaults to ``False``. If ``True``, the
|
782
|
-
wrapped function returns a pair where the first element is the
|
783
|
-
``ClosedJaxpr`` representation of ``fun`` and the second element is a
|
784
|
-
pytree with the same structure as the output of ``fun`` and where the
|
785
|
-
leaves are objects with ``shape``, ``dtype``, and ``named_shape``
|
786
|
-
attributes representing the corresponding types of the output leaves.
|
787
|
-
|
788
|
-
Returns:
|
789
|
-
A wrapped version of ``fun`` that when applied to example arguments returns
|
790
|
-
a ``ClosedJaxpr`` representation of ``fun`` on those arguments. If the
|
791
|
-
argument ``return_shape`` is ``True``, then the returned function instead
|
792
|
-
returns a pair where the first element is the ``ClosedJaxpr``
|
793
|
-
representation of ``fun`` and the second element is a pytree representing
|
794
|
-
the structure, shape, dtypes, and named shapes of the output of ``fun``.
|
795
|
-
|
796
|
-
A ``jaxpr`` is JAX's intermediate representation for program traces. The
|
797
|
-
``jaxpr`` language is based on the simply-typed first-order lambda calculus
|
798
|
-
with let-bindings. :py:func:`make_jaxpr` adapts a function to return its
|
799
|
-
``jaxpr``, which we can inspect to understand what JAX is doing internally.
|
800
|
-
The ``jaxpr`` returned is a trace of ``fun`` abstracted to
|
801
|
-
:py:class:`ShapedArray` level. Other levels of abstraction exist internally.
|
802
|
-
|
803
|
-
We do not describe the semantics of the ``jaxpr`` language in detail here, but
|
804
|
-
instead give a few examples.
|
805
|
-
|
806
|
-
>>> import jax
|
807
|
-
>>>
|
808
|
-
>>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
|
809
|
-
>>> print(f(3.0))
|
810
|
-
-0.83602
|
811
|
-
>>> _make_jaxpr(f)(3.0)
|
812
|
-
{ lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
|
813
|
-
>>> _make_jaxpr(jax.grad(f))(3.0)
|
814
|
-
{ lambda ; a:f32[]. let
|
815
|
-
b:f32[] = cos a
|
816
|
-
c:f32[] = sin a
|
817
|
-
_:f32[] = sin b
|
818
|
-
d:f32[] = cos b
|
819
|
-
e:f32[] = mul 1.0 d
|
820
|
-
f:f32[] = neg e
|
821
|
-
g:f32[] = mul f c
|
822
|
-
in (g,) }
|
823
|
-
"""
|
824
|
-
_check_callable(fun)
|
825
|
-
static_argnums = _ensure_index_tuple(static_argnums)
|
826
|
-
|
827
|
-
def _abstractify(args, kwargs):
|
828
|
-
flat_args, in_tree = jax.tree.flatten((args, kwargs))
|
829
|
-
if abstracted_axes is None:
|
830
|
-
return map(shaped_abstractify, flat_args), in_tree, [True] * len(flat_args)
|
831
|
-
else:
|
832
|
-
axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
|
833
|
-
in_type = pe.infer_lambda_input_type(axes_specs, flat_args)
|
834
|
-
in_avals, keep_inputs = unzip2(in_type)
|
835
|
-
return in_avals, in_tree, keep_inputs
|
836
|
-
|
837
|
-
@wraps(fun)
|
838
|
-
@api_boundary
|
839
|
-
def make_jaxpr_f(*args, **kwargs):
|
840
|
-
f = wrap_init(fun, (), {}, 'brainstate.compile.make_jaxpr')
|
841
|
-
if static_argnums:
|
842
|
-
dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
|
843
|
-
f, args = jax.api_util.argnums_partial(f, dyn_argnums, args)
|
844
|
-
in_avals, in_tree, keep_inputs = _abstractify(args, kwargs)
|
845
|
-
in_type = tuple(safe_zip(in_avals, keep_inputs))
|
846
|
-
f, out_tree = _flatten_fun(f, in_tree)
|
847
|
-
f = annotate(f, in_type)
|
848
|
-
if jax.__version_info__ < (0, 5, 0):
|
849
|
-
debug_info_ = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
|
850
|
-
with ExitStack() as stack:
|
851
|
-
if axis_env is not None:
|
852
|
-
stack.enter_context(extend_axis_env_nd(axis_env))
|
853
|
-
if jax.__version_info__ < (0, 5, 0):
|
854
|
-
jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, debug_info=debug_info_)
|
855
|
-
else:
|
856
|
-
jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f)
|
857
|
-
closed_jaxpr = ClosedJaxpr(jaxpr, consts)
|
858
|
-
if return_shape:
|
859
|
-
out_avals, _ = unzip2(out_type)
|
860
|
-
out_shapes_flat = [jax.ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
|
861
|
-
return closed_jaxpr, jax.tree.unflatten(out_tree(), out_shapes_flat)
|
862
|
-
return closed_jaxpr
|
863
|
-
|
864
|
-
make_jaxpr_f.__module__ = "brainstate.compile"
|
865
|
-
if hasattr(fun, "__qualname__"):
|
866
|
-
make_jaxpr_f.__qualname__ = f"make_jaxpr({fun.__qualname__})"
|
867
|
-
if hasattr(fun, "__name__"):
|
868
|
-
make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
|
869
|
-
return make_jaxpr_f
|
870
|
-
|
871
|
-
|
872
|
-
def make_hashable(obj):
|
873
|
-
"""Convert a pytree into a hashable representation."""
|
874
|
-
if isinstance(obj, (list, tuple)):
|
875
|
-
return tuple(make_hashable(item) for item in obj)
|
876
|
-
elif isinstance(obj, dict):
|
877
|
-
return tuple(sorted((k, make_hashable(v)) for k, v in obj.items()))
|
878
|
-
elif isinstance(obj, set):
|
879
|
-
return frozenset(make_hashable(item) for item in obj)
|
880
|
-
else:
|
881
|
-
# # Use JAX's tree_util for any other pytree structures
|
882
|
-
# try:
|
883
|
-
# leaves, treedef = jax.tree_util.tree_flatten(obj)
|
884
|
-
# hashable_leaves = tuple(make_hashable(leaf) for leaf in leaves)
|
885
|
-
# return (str(treedef), hashable_leaves)
|
886
|
-
# except:
|
887
|
-
# # Assume obj is already hashable
|
888
|
-
return obj
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""
|
17
|
+
This module implements how to create a JAX Jaxpr from a given function by considering the states that are read and
|
18
|
+
written by the function. These state transformations are foundational for the BrainCore library. These utilities
|
19
|
+
include two basic functions: `StatefulFunction` and `make_jaxpr`.
|
20
|
+
|
21
|
+
|
22
|
+
``StatefulFunction``
|
23
|
+
--------------------
|
24
|
+
|
25
|
+
The module provides a class called ``StatefulFunction`` that wraps a function and provides methods to get the
|
26
|
+
JAX Jaxpr, the output shapes, the states that are read and written by the function, and the output of the function.
|
27
|
+
The class provides the following methods:
|
28
|
+
|
29
|
+
- `make_jaxpr`: creates the JAX Jaxpr of the function.
|
30
|
+
- `jaxpr_call`: calls the function at the JAX Jaxpr level.
|
31
|
+
- `jaxpr_call_without_states`: calls the function at the JAX Jaxpr level without considering the states.
|
32
|
+
- `get_states`: returns the states that are read and written by the function.
|
33
|
+
- `get_read_states`: returns the states that are read by the function.
|
34
|
+
- `get_write_states`: returns the states that are written by the function.
|
35
|
+
- `get_static_args`: returns the static arguments from the arguments.
|
36
|
+
- `compile_and_get_states_by_static_args`: compiles the function and returns the states that are read and
|
37
|
+
written by the function.
|
38
|
+
- `get_jaxpr`: returns the JAX Jaxpr of the function.
|
39
|
+
- `get_out_shapes`: returns the output shapes of the function.
|
40
|
+
- `get_out_treedef`: returns the output tree of the function.
|
41
|
+
|
42
|
+
``make_jaxpr``
|
43
|
+
--------------
|
44
|
+
|
45
|
+
The module provides a function called `make_jaxpr` that creates a function that produces its JAX Jaxpr given example
|
46
|
+
arguments. The function returns a wrapped version of the function that when applied to example arguments returns a
|
47
|
+
`ClosedJaxpr` representation of the function on those arguments. If the argument `return_shape` is `True`, then the
|
48
|
+
returned function instead returns a pair where the first element is the `ClosedJaxpr` representation of the function
|
49
|
+
and the second element is a pytree representing the structure, shape, dtypes, and named shapes of the output of the
|
50
|
+
function.
|
51
|
+
|
52
|
+
"""
|
53
|
+
|
54
|
+
import functools
|
55
|
+
import inspect
|
56
|
+
import operator
|
57
|
+
from collections.abc import Hashable, Iterable, Sequence
|
58
|
+
from contextlib import ExitStack
|
59
|
+
from typing import Any, Callable, Tuple, Union, Dict, Optional
|
60
|
+
|
61
|
+
import jax
|
62
|
+
from jax._src import source_info_util
|
63
|
+
from jax._src.linear_util import annotate
|
64
|
+
from jax._src.traceback_util import api_boundary
|
65
|
+
from jax.api_util import shaped_abstractify
|
66
|
+
from jax.extend.linear_util import transformation_with_aux
|
67
|
+
from jax.interpreters import partial_eval as pe
|
68
|
+
|
69
|
+
from brainstate._compatible_import import (
|
70
|
+
ClosedJaxpr,
|
71
|
+
extend_axis_env_nd,
|
72
|
+
safe_map,
|
73
|
+
safe_zip,
|
74
|
+
unzip2,
|
75
|
+
wraps,
|
76
|
+
wrap_init,
|
77
|
+
)
|
78
|
+
from brainstate._state import State, StateTraceStack
|
79
|
+
from brainstate._utils import set_module_as
|
80
|
+
from brainstate.typing import PyTree
|
81
|
+
from brainstate.util import PrettyObject
|
82
|
+
|
83
|
+
AxisName = Hashable
|
84
|
+
|
85
|
+
__all__ = [
|
86
|
+
"StatefulFunction",
|
87
|
+
"make_jaxpr",
|
88
|
+
]
|
89
|
+
|
90
|
+
|
91
|
+
def _ensure_str(x: str) -> str:
|
92
|
+
if not isinstance(x, str):
|
93
|
+
raise TypeError(f"argument is not a string: {x}")
|
94
|
+
return x
|
95
|
+
|
96
|
+
|
97
|
+
def _ensure_index_tuple(x: Any) -> tuple[int, ...]:
|
98
|
+
"""Convert x to a tuple of indices."""
|
99
|
+
x = jax.core.concrete_or_error(None, x, "expected a static index or sequence of indices.")
|
100
|
+
try:
|
101
|
+
return (operator.index(x),)
|
102
|
+
except TypeError:
|
103
|
+
return tuple(safe_map(operator.index, x))
|
104
|
+
|
105
|
+
|
106
|
+
def _ensure_str_tuple(x: str | Iterable[str]) -> tuple[str, ...]:
|
107
|
+
"""Convert x to a tuple of strings."""
|
108
|
+
if isinstance(x, str):
|
109
|
+
return (x,)
|
110
|
+
else:
|
111
|
+
return tuple(safe_map(_ensure_str, x))
|
112
|
+
|
113
|
+
|
114
|
+
def _jax_v04_new_arg_fn(frame, trace, aval):
|
115
|
+
"""
|
116
|
+
Transform a new argument to a tracer.
|
117
|
+
|
118
|
+
Modified from jax.interpreters.partial_eval.DynamicJaxprTrace.new_arg()
|
119
|
+
|
120
|
+
Args:
|
121
|
+
frame: The frame.
|
122
|
+
trace: The trace.
|
123
|
+
aval: The abstract value.
|
124
|
+
|
125
|
+
Returns:
|
126
|
+
The tracer.
|
127
|
+
"""
|
128
|
+
tracer = pe.DynamicJaxprTracer(trace, aval, source_info_util.current())
|
129
|
+
frame.tracers.append(tracer)
|
130
|
+
frame.tracer_to_var[id(tracer)] = var = frame.newvar(aval)
|
131
|
+
frame.invars.append(var)
|
132
|
+
return tracer
|
133
|
+
|
134
|
+
|
135
|
+
def _jax_v04_new_jax_trace():
|
136
|
+
main = jax.core.thread_local_state.trace_state.trace_stack.stack[-1]
|
137
|
+
frame = main.jaxpr_stack[-1]
|
138
|
+
trace = pe.DynamicJaxprTrace(main, jax.core.cur_sublevel())
|
139
|
+
return frame, trace
|
140
|
+
|
141
|
+
|
142
|
+
def _jax_v04_new_arg():
|
143
|
+
# Should be within the calling of ``jax.make_jaxpr()``
|
144
|
+
frame, trace = _jax_v04_new_jax_trace()
|
145
|
+
# Set the function to transform the new argument to a tracer
|
146
|
+
fn = functools.partial(_jax_v04_new_arg_fn, frame, trace)
|
147
|
+
return fn
|
148
|
+
|
149
|
+
|
150
|
+
def _jax_new_version_new_arg():
|
151
|
+
trace = jax.core.trace_ctx.trace
|
152
|
+
|
153
|
+
def wrapper(x):
|
154
|
+
if jax.__version_info__ < (0, 6, 1):
|
155
|
+
return trace.new_arg(shaped_abstractify(x))
|
156
|
+
else:
|
157
|
+
return trace.new_arg(shaped_abstractify(x), source_info=source_info_util.current())
|
158
|
+
|
159
|
+
return wrapper
|
160
|
+
|
161
|
+
|
162
|
+
def _init_state_trace_stack(name) -> StateTraceStack:
|
163
|
+
state_trace: StateTraceStack = StateTraceStack(name=name)
|
164
|
+
|
165
|
+
if jax.__version_info__ < (0, 4, 36):
|
166
|
+
state_trace.set_new_arg(_jax_v04_new_arg())
|
167
|
+
else:
|
168
|
+
state_trace.set_new_arg(_jax_new_version_new_arg())
|
169
|
+
return state_trace
|
170
|
+
|
171
|
+
|
172
|
+
default_cache_key = ((), ())
|
173
|
+
|
174
|
+
|
175
|
+
class StatefulFunction(PrettyObject):
|
176
|
+
"""
|
177
|
+
A wrapper class for a function that collects the states that are read and written by the function. The states are
|
178
|
+
collected by the function and returned as a StateDictManager instance. The StateDictManager instance can be used to
|
179
|
+
manage the states in the JAX program. The class provides a function called `states` that returns the states
|
180
|
+
that are read and written by the function. The class provides a function called `to_state_manager` that returns
|
181
|
+
a StateDictManager instance that contains the states that are read and written by the function. The class provides
|
182
|
+
a function called `__call__` that wraps the function and returns the states that are read and written by the
|
183
|
+
function and the output of the function.
|
184
|
+
|
185
|
+
Args:
|
186
|
+
fun: The function whose ``jaxpr`` is to be computed. Its positional
|
187
|
+
arguments and return value should be arrays, scalars, or standard Python
|
188
|
+
containers (tuple/list/dict) thereof.
|
189
|
+
static_argnums: See the :py:func:`jax.jit` docstring.
|
190
|
+
static_argnames: See the :py:func:`jax.jit` docstring.
|
191
|
+
axis_env: Optional, a sequence of pairs where the first element is an axis
|
192
|
+
name and the second element is a positive integer representing the size of
|
193
|
+
the mapped axis with that name. This parameter is useful when lowering
|
194
|
+
functions that involve parallel communication collectives, and it
|
195
|
+
specifies the axis name/size environment that would be set up by
|
196
|
+
applications of :py:func:`jax.pmap`.
|
197
|
+
abstracted_axes: Optional, a pytree with the same structure as the input
|
198
|
+
arguments to ``fun``. The leaves of the pytree can be either None or a
|
199
|
+
dict with axis names as keys and integers as values. If the leaf is None,
|
200
|
+
then the corresponding axis is not abstracted. If the leaf is a dict, then
|
201
|
+
the corresponding axis is abstracted, and the dict specifies the axis name
|
202
|
+
and size. The abstracted axes are used to infer the input type of the
|
203
|
+
function. If None, then all axes are abstracted.
|
204
|
+
state_returns: Optional, a string or a tuple of strings. The default is
|
205
|
+
``('read', 'write')``. The strings specify the categories of states to be
|
206
|
+
returned by the wrapped function. The categories are ``'read'`` and
|
207
|
+
``'write'``. If the category is ``'read'``, then the wrapped function
|
208
|
+
returns the states that are read by the function. If the category is
|
209
|
+
``'write'``, then the wrapped function returns the states that are written
|
210
|
+
by the function. If the category is ``'read'`` and ``'write'``, then the
|
211
|
+
wrapped function returns both the read and write states.
|
212
|
+
|
213
|
+
"""
|
214
|
+
    __module__ = "brainstate.compile"

    def __init__(
        self,
        fun: Callable,
        static_argnums: Union[int, Iterable[int]] = (),
        static_argnames: Union[str, Iterable[str]] = (),
        axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
        abstracted_axes: Optional[Any] = None,
        state_returns: Union[str, Tuple[str, ...]] = ('read', 'write'),
        cache_type: Optional[str] = None,
        name: Optional[str] = None,
    ):
        # explicit parameters
        self.fun = fun
        self.static_argnums = tuple() if static_argnums is None else _ensure_index_tuple(static_argnums)
        self.static_argnames = tuple() if static_argnames is None else _ensure_str_tuple(static_argnames)
        self.axis_env = axis_env
        self.abstracted_axes = abstracted_axes
        self.state_returns = tuple(state_returns) if isinstance(state_returns, (tuple, list)) else (state_returns,)
        assert cache_type in [None, 'jit'], f"Invalid cache type: {cache_type}"
        self.name = name

        # implicit parameters
        self.cache_type = cache_type
        self._cached_jaxpr: Dict[Any, ClosedJaxpr] = dict()
        self._cached_out_shapes: Dict[Any, PyTree] = dict()
        self._cached_jaxpr_out_tree: Dict[Any, PyTree] = dict()
        self._cached_state_trace: Dict[Any, StateTraceStack] = dict()

    def __pretty_repr_item__(self, k, v):
        if k.startswith('_'):
            return None
        return k, v

    def get_jaxpr(self, cache_key: Hashable = None) -> ClosedJaxpr:
        """
        Read the JAX Jaxpr representation of the function.

        Args:
          cache_key: The hashable key.

        Returns:
          The JAX Jaxpr representation of the function.
        """
        if cache_key is None:
            cache_key = default_cache_key
        if cache_key not in self._cached_jaxpr:
            raise ValueError(f"the function is not called with the static arguments: {cache_key}")
        return self._cached_jaxpr[cache_key]

    def get_out_shapes(self, cache_key: Hashable = None) -> PyTree:
        """
        Read the output shapes of the function.

        Args:
          cache_key: The hashable key.

        Returns:
          The output shapes of the function.
        """
        if cache_key is None:
            cache_key = default_cache_key
        if cache_key not in self._cached_out_shapes:
            raise ValueError(f"the function is not called with the static arguments: {cache_key}")
        return self._cached_out_shapes[cache_key]

    def get_out_treedef(self, cache_key: Hashable = None) -> PyTree:
        """
        Read the output tree of the function.

        Args:
          cache_key: The hashable key.

        Returns:
          The output tree of the function.
        """
        if cache_key is None:
            cache_key = default_cache_key
        if cache_key not in self._cached_jaxpr_out_tree:
            raise ValueError(f"the function is not called with the static arguments: {cache_key}")
        return self._cached_jaxpr_out_tree[cache_key]

    def get_state_trace(self, cache_key: Hashable = None) -> StateTraceStack:
        """
        Read the state trace of the function.

        Args:
          cache_key: The hashable key.

        Returns:
          The state trace of the function.
        """
        if cache_key is None:
            cache_key = default_cache_key
        if cache_key not in self._cached_state_trace:
            raise ValueError(f"the function is not called with the static arguments: {cache_key}")
        return self._cached_state_trace[cache_key]

    def get_states(self, cache_key: Hashable = None) -> Tuple[State, ...]:
        """
        Read the states that are read and written by the function.

        Args:
          cache_key: The hashable key.

        Returns:
          The states that are read and written by the function.
        """
        if cache_key is None:
            cache_key = default_cache_key
        return tuple(self.get_state_trace(cache_key).states)

    def get_read_states(self, cache_key: Hashable = None) -> Tuple[State, ...]:
        """
        Read the states that are read by the function.

        Args:
          cache_key: The hashable key.

        Returns:
          The states that are read by the function.
        """
        if cache_key is None:
            cache_key = default_cache_key
        return self.get_state_trace(cache_key).get_read_states()

    def get_write_states(self, cache_key: Hashable = None) -> Tuple[State, ...]:
        """
        Read the states that are written by the function.

        Args:
          cache_key: The hashable key.

        Returns:
          The states that are written by the function.
        """
        if cache_key is None:
            cache_key = default_cache_key
        return self.get_state_trace(cache_key).get_write_states()

    def _check_input_ouput(self, x):
        if isinstance(x, State):
            x.raise_error_with_source_info(
                ValueError(
                    'Inputs/outputs for brainstate transformations cannot be an instance of State. '
                    f'But we got {x}'
                )
            )

    def get_arg_cache_key(self, *args, **kwargs) -> Tuple:
        """
        Build the hashable cache key for the given call arguments.

        Args:
          *args: The arguments to the function.
          **kwargs: The keyword arguments to the function.

        Returns:
          A hashable tuple built from the static arguments and keyword arguments
          (and, when ``cache_type='jit'``, the abstracted dynamic arguments as well).
        """
        if self.cache_type == 'jit':
            static_args, dyn_args = [], []
            for i, arg in enumerate(args):
                if i in self.static_argnums:
                    static_args.append(arg)
                else:
                    dyn_args.append(arg)
            dyn_args = jax.tree.map(shaped_abstractify, dyn_args)
            static_kwargs, dyn_kwargs = [], []
            for k, v in kwargs.items():
                if k in self.static_argnames:
                    static_kwargs.append((k, v))
                else:
                    dyn_kwargs.append((k, jax.tree.map(shaped_abstractify, v)))

            static_args = make_hashable(tuple(static_args))
            dyn_args = make_hashable(tuple(dyn_args))
            static_kwargs = make_hashable(static_kwargs)
            dyn_kwargs = make_hashable(dyn_kwargs)

            cache_key = (static_args, dyn_args, static_kwargs, dyn_kwargs)
        elif self.cache_type is None:
            num_arg = len(args)
            static_args = tuple(args[i] for i in self.static_argnums if i < num_arg)
            static_kwargs = tuple((k, v) for k, v in kwargs.items() if k in self.static_argnames)

            # Make everything hashable
            static_args = make_hashable(static_args)
            static_kwargs = make_hashable(static_kwargs)

            cache_key = (static_args, static_kwargs)
        else:
            raise ValueError(f"Invalid cache type: {self.cache_type}")

        return cache_key

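    # Illustrative sketch (not part of the library): how the cache key behaves for the default
    # ``cache_type=None``. Only static arguments enter the key, so calls that differ only in
    # dynamic argument values share one cache entry. The names below are made up for the example.
    #
    # >>> import jax.numpy as jnp
    # >>> sf = StatefulFunction(lambda x, n: x * n, static_argnums=(1,))
    # >>> sf.get_arg_cache_key(jnp.ones(3), 2) == sf.get_arg_cache_key(jnp.zeros(3), 2)
    # True
    # >>> sf.get_arg_cache_key(jnp.ones(3), 2) == sf.get_arg_cache_key(jnp.ones(3), 3)
    # False
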
    def compile_function_and_get_states(self, *args, **kwargs) -> Tuple[State, ...]:
        """
        Compile the function, and get the states that are read and written by this function.

        Args:
          *args: The arguments to the function.
          **kwargs: The keyword arguments to the function.

        Returns:
          The states that are read and written by the function.
        """
        cache_key = self.get_arg_cache_key(*args, **kwargs)
        if cache_key not in self._cached_state_trace:
            self.make_jaxpr(*args, **kwargs)
        return self.get_states(cache_key)

    def compile_function_and_get_state_trace(
        self, *args, return_only_write: bool = False, **kwargs
    ) -> StateTraceStack:
        """
        Compile the function, and get the states that are read and written by this function.

        Args:
          *args: The arguments to the function.
          return_only_write: If True, only return the states that are written by the function.
          **kwargs: The keyword arguments to the function.

        Returns:
          The state trace stack.
        """
        cache_key = self.get_arg_cache_key(*args, **kwargs)
        if cache_key not in self._cached_state_trace:
            self.make_jaxpr(*args, **kwargs, return_only_write=return_only_write)
        return self.get_state_trace(cache_key)

    def clear_cache(self) -> None:
        """
        Clear the compilation cache.
        """
        self._cached_jaxpr.clear()
        self._cached_out_shapes.clear()
        self._cached_jaxpr_out_tree.clear()
        self._cached_state_trace.clear()

    def _wrapped_fun_to_eval(
        self, cache_key, static_kwargs: dict, *args, return_only_write: bool = False, **dyn_kwargs,
    ) -> Tuple[Any, Tuple[State, ...]]:
        """
        Wrap the function so that it returns both its output and the values of the states it touches.

        Args:
          cache_key: The cache key under which the resulting state trace is stored.
          static_kwargs: The static keyword arguments, passed to the function as-is.
          *args: The positional arguments to the function.
          return_only_write: If True, only collect the values of the states written by the function.
          **dyn_kwargs: The dynamic keyword arguments to the function.

        Returns:
          A tuple of the function output and the collected state values.
        """
        # state trace
        state_trace = _init_state_trace_stack(self.name)
        self._cached_state_trace[cache_key] = state_trace
        with state_trace:
            out = self.fun(*args, **dyn_kwargs, **static_kwargs)
            state_values = (
                state_trace.get_write_state_values(True)
                if return_only_write else
                state_trace.get_state_values()
            )
        state_trace.recovery_original_values()

        # Returning State instances from the function is not allowed,
        # so check whether any states appear in the output.
        jax.tree.map(self._check_input_ouput, out, is_leaf=lambda x: isinstance(x, State))
        return out, state_values

    def make_jaxpr(self, *args, return_only_write: bool = False, **kwargs):
        """Trace the function with the given example arguments and cache the result.

        The ``ClosedJaxpr`` representation of the wrapped function, together with its output
        shapes, output pytree structure, and state trace, is stored in the compilation cache
        under the cache key derived from the static arguments.

        Args:
          *args: The arguments to the function.
          **kwargs: The keyword arguments to the function.
          return_only_write: If True, only return the states that are written by the function.
        """

        # static args
        cache_key = self.get_arg_cache_key(*args, **kwargs)

        # check input types
        jax.tree.map(self._check_input_ouput, (args, kwargs), is_leaf=lambda x: isinstance(x, State))

        if cache_key not in self._cached_state_trace:
            try:
                # jaxpr
                static_kwargs, dyn_kwargs = {}, {}
                for k, v in kwargs.items():
                    if k in self.static_argnames:
                        static_kwargs[k] = v
                    else:
                        dyn_kwargs[k] = v
                jaxpr, (out_shapes, state_shapes) = _make_jaxpr(
                    functools.partial(
                        self._wrapped_fun_to_eval,
                        cache_key,
                        static_kwargs,
                        return_only_write=return_only_write
                    ),
                    static_argnums=self.static_argnums,
                    axis_env=self.axis_env,
                    return_shape=True,
                    abstracted_axes=self.abstracted_axes
                )(*args, **dyn_kwargs)
                # returns
                self._cached_jaxpr_out_tree[cache_key] = jax.tree.structure((out_shapes, state_shapes))
                self._cached_out_shapes[cache_key] = (out_shapes, state_shapes)
                self._cached_jaxpr[cache_key] = jaxpr

            except Exception as e:
                try:
                    self._cached_state_trace.pop(cache_key)
                except KeyError:
                    pass
                raise e

        return self

    def jaxpr_call(self, state_vals, *args, **kwargs) -> Any:
        """
        Call the function at the JAX Jaxpr level.

        Args:
          state_vals: The state values.
          *args: The arguments to the function.
          **kwargs: The keyword arguments to the function.

        Returns:
          State values and the function output.
        """
        # state checking
        cache_key = self.get_arg_cache_key(*args, **kwargs)
        states: Sequence[State] = self.get_states(cache_key)
        assert len(state_vals) == len(states), 'State length mismatch.'

        # parameters
        kwargs = {k: v for k, v in kwargs.items() if k not in self.static_argnames}  # remove static kwargs
        args = tuple(args[i] for i in range(len(args)) if i not in self.static_argnums)
        args = jax.tree.flatten((args, kwargs, state_vals))[0]

        # Call the function. Note that the traced function always returns the values
        # of both the read and the written states.
        closed_jaxpr = self.get_jaxpr(cache_key)
        out_treedef = self.get_out_treedef(cache_key)
        jaxpr_outs = jax.core.eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)

        # output processing
        out, new_state_vals = out_treedef.unflatten(jaxpr_outs)
        assert len(new_state_vals) == len(state_vals), 'State length mismatch.'
        return new_state_vals, out

    def jaxpr_call_auto(self, *args, **kwargs) -> Any:
        """
        Call the function at the JAX Jaxpr level with automatic state management.

        Args:
          *args: The arguments to the function.
          **kwargs: The keyword arguments to the function.

        Returns:
          The output of the function.
        """
        state_trace = self.get_state_trace(self.get_arg_cache_key(*args, **kwargs))
        state_vals, out = self.jaxpr_call([st.value for st in state_trace.states], *args, **kwargs)
        state_trace.assign_state_vals(state_vals)
        return out


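# Illustrative sketch (not part of the module): a typical way to drive StatefulFunction by hand.
# It assumes a ``brainstate.State`` is read and written inside the traced function; the variable
# names below are made up for the example.
#
# >>> import brainstate
# >>> import jax.numpy as jnp
# >>> counter = brainstate.State(jnp.array(0.))
# >>> def step(x):
# ...     counter.value = counter.value + x
# ...     return counter.value * 2.
# >>> sf = StatefulFunction(step).make_jaxpr(1.)
# >>> key = sf.get_arg_cache_key(1.)
# >>> sf.get_states(key)        # the states traced under this cache key
# >>> sf.jaxpr_call_auto(1.)    # evaluate the cached jaxpr and write state values back

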
@set_module_as("brainstate.compile")
def make_jaxpr(
    fun: Callable,
    static_argnums: Union[int, Iterable[int]] = (),
    static_argnames: Union[str, Iterable[str]] = (),
    axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
    return_shape: bool = False,
    abstracted_axes: Optional[Any] = None,
    state_returns: Union[str, Tuple[str, ...]] = ('read', 'write')
) -> Callable[
    ...,
    (Tuple[ClosedJaxpr, Tuple[State, ...]] |
     Tuple[ClosedJaxpr, Tuple[State, ...], PyTree])
]:
    """
    Creates a function that produces its jaxpr given example args.

    Args:
      fun: The function whose ``jaxpr`` is to be computed. Its positional
        arguments and return value should be arrays, scalars, or standard Python
        containers (tuple/list/dict) thereof.
      static_argnums: See the :py:func:`jax.jit` docstring.
      static_argnames: See the :py:func:`jax.jit` docstring.
      axis_env: Optional, a sequence of pairs where the first element is an axis
        name and the second element is a positive integer representing the size of
        the mapped axis with that name. This parameter is useful when lowering
        functions that involve parallel communication collectives, and it
        specifies the axis name/size environment that would be set up by
        applications of :py:func:`jax.pmap`.
      return_shape: Optional boolean, defaults to ``False``. If ``True``, the
        wrapped function returns a pair where the first element is the XLA
        computation and the second element is a pytree with the same structure as
        the output of ``fun`` and where the leaves are objects with ``shape``,
        ``dtype``, and ``named_shape`` attributes representing the corresponding
        types of the output leaves.
      abstracted_axes: Optional, a pytree with the same structure as the input
        arguments to ``fun``. The leaves of the pytree can be either None or a
        dict with axis names as keys and integers as values. If the leaf is None,
        then the corresponding axis is not abstracted. If the leaf is a dict, then
        the corresponding axis is abstracted, and the dict specifies the axis name
        and size. The abstracted axes are used to infer the input type of the
        function. If None, then all axes are abstracted.
      state_returns: Optional, a string or a tuple of strings. The default is
        ``('read', 'write')``. The strings specify the categories of states to be
        returned by the wrapped function. The categories are ``'read'`` and
        ``'write'``. If the category is ``'read'``, then the wrapped function
        returns the states that are read by the function. If the category is
        ``'write'``, then the wrapped function returns the states that are written
        by the function. If the category is ``'read'`` and ``'write'``, then the
        wrapped function returns both the read and write states.

    Returns:
      A wrapped version of ``fun`` that when applied to example arguments returns
      a ``ClosedJaxpr`` representation of ``fun`` on those arguments. If the
      argument ``return_shape`` is ``True``, then the returned function instead
      returns a pair where the first element is the ``ClosedJaxpr``
      representation of ``fun`` and the second element is a pytree representing
      the structure, shape, dtypes, and named shapes of the output of ``fun``.

    A ``jaxpr`` is JAX's intermediate representation for program traces. The
    ``jaxpr`` language is based on the simply-typed first-order lambda calculus
    with let-bindings. :py:func:`make_jaxpr` adapts a function to return its
    ``jaxpr``, which we can inspect to understand what JAX is doing internally.
    The ``jaxpr`` returned is a trace of ``fun`` abstracted to
    :py:class:`ShapedArray` level. Other levels of abstraction exist internally.

    We do not describe the semantics of the ``jaxpr`` language in detail here, but
    instead give a few examples.

    >>> import jax
    >>> import brainstate
    >>>
    >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
    >>> print(f(3.0))
    -0.83602
    >>> jaxpr, states = brainstate.compile.make_jaxpr(f)(3.0)
    >>> jaxpr
    { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
    >>> jaxpr, states = brainstate.compile.make_jaxpr(jax.grad(f))(3.0)
    >>> jaxpr
    { lambda ; a:f32[]. let
        b:f32[] = cos a
        c:f32[] = sin a
        _:f32[] = sin b
        d:f32[] = cos b
        e:f32[] = mul 1.0 d
        f:f32[] = neg e
        g:f32[] = mul f c
      in (g,) }
    """

    stateful_fun = StatefulFunction(
        fun,
        static_argnums=static_argnums,
        static_argnames=static_argnames,
        axis_env=axis_env,
        abstracted_axes=abstracted_axes,
        state_returns=state_returns,
        name='make_jaxpr'
    )

    @wraps(fun)
    def make_jaxpr_f(*args, **kwargs):
        stateful_fun.make_jaxpr(*args, **kwargs)
        cache_key = stateful_fun.get_arg_cache_key(*args, **kwargs)
        if return_shape:
            return (stateful_fun.get_jaxpr(cache_key),
                    stateful_fun.get_states(cache_key),
                    stateful_fun.get_out_shapes(cache_key)[0])
        else:
            return (stateful_fun.get_jaxpr(cache_key),
                    stateful_fun.get_states(cache_key))

    # wrapped jaxpr builder function
    make_jaxpr_f.__module__ = "brainstate.compile"
    if hasattr(fun, "__qualname__"):
        make_jaxpr_f.__qualname__ = f"make_jaxpr({fun.__qualname__})"
    if hasattr(fun, "__name__"):
        make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
    return make_jaxpr_f


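# Illustrative sketch (not part of the module): when the traced function touches a
# ``brainstate.State``, the state shows up in the second element of the returned pair.
# The variable names below are made up for the example.
#
# >>> import brainstate
# >>> import jax.numpy as jnp
# >>> weight = brainstate.State(jnp.ones(3))
# >>> def g(x):
# ...     return (weight.value * x).sum()
# >>> jaxpr, states = brainstate.compile.make_jaxpr(g)(jnp.ones(3))
# >>> len(states)
# 1
# >>> jaxpr, states, shapes = brainstate.compile.make_jaxpr(g, return_shape=True)(jnp.ones(3))

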
def _check_callable(fun):
    # In Python 3.10+, the only thing stopping us from supporting staticmethods
    # is that we can't take weak references to them, which the C++ JIT requires.
    if isinstance(fun, staticmethod):
        raise TypeError(f"staticmethod arguments are not supported, got {fun}")
    if not callable(fun):
        raise TypeError(f"Expected a callable value, got {fun}")
    if inspect.isgeneratorfunction(fun):
        raise TypeError(f"Expected a function, got a generator function: {fun}")


def _broadcast_prefix(
    prefix_tree: Any,
    full_tree: Any,
    is_leaf: Callable[[Any], bool] | None = None
) -> list[Any]:
    # If prefix_tree is not a tree prefix of full_tree, this code can raise a
    # ValueError; use prefix_errors to find disagreements and raise more precise
    # error messages.
    result = []
    num_leaves = lambda t: jax.tree.structure(t).num_leaves
    add_leaves = lambda x, subtree: result.extend([x] * num_leaves(subtree))
    jax.tree.map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf)
    return result


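# Illustrative sketch (not part of the module): _broadcast_prefix repeats each leaf of the prefix
# tree once per leaf of the matching subtree in the full tree; this is how per-argument axis specs
# get expanded. The literal values below are made up for the example.
#
# >>> _broadcast_prefix((0, 1), ((1, 2), 3))
# [0, 0, 1]
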
def _flat_axes_specs(
    abstracted_axes, *args, **kwargs
) -> list[pe.AbstractedAxesSpec]:
    if kwargs:
        raise NotImplementedError

    def ax_leaf(l):
        return (isinstance(l, dict) and jax.tree_util.all_leaves(l.values()) or
                isinstance(l, tuple) and jax.tree_util.all_leaves(l, lambda x: x is None))

    return _broadcast_prefix(abstracted_axes, args, ax_leaf)


@transformation_with_aux
def _flatten_fun(in_tree, *args_flat):
    py_args, py_kwargs = jax.tree.unflatten(in_tree, args_flat)
    ans = yield py_args, py_kwargs
    yield jax.tree.flatten(ans)


def _make_jaxpr(
    fun: Callable,
    static_argnums: int | Iterable[int] = (),
    axis_env: Sequence[tuple[AxisName, int]] | None = None,
    return_shape: bool = False,
    abstracted_axes: Any | None = None,
) -> Callable[..., (ClosedJaxpr | tuple[ClosedJaxpr, Any])]:
    """Creates a function that produces its jaxpr given example args.

    Args:
      fun: The function whose ``jaxpr`` is to be computed. Its positional
        arguments and return value should be arrays, scalars, or standard Python
        containers (tuple/list/dict) thereof.
      static_argnums: See the :py:func:`jax.jit` docstring.
      axis_env: Optional, a sequence of pairs where the first element is an axis
        name and the second element is a positive integer representing the size of
        the mapped axis with that name. This parameter is useful when lowering
        functions that involve parallel communication collectives, and it
        specifies the axis name/size environment that would be set up by
        applications of :py:func:`jax.pmap`.
      return_shape: Optional boolean, defaults to ``False``. If ``True``, the
        wrapped function returns a pair where the first element is the
        ``ClosedJaxpr`` representation of ``fun`` and the second element is a
        pytree with the same structure as the output of ``fun`` and where the
        leaves are objects with ``shape``, ``dtype``, and ``named_shape``
        attributes representing the corresponding types of the output leaves.

    Returns:
      A wrapped version of ``fun`` that when applied to example arguments returns
      a ``ClosedJaxpr`` representation of ``fun`` on those arguments. If the
      argument ``return_shape`` is ``True``, then the returned function instead
      returns a pair where the first element is the ``ClosedJaxpr``
      representation of ``fun`` and the second element is a pytree representing
      the structure, shape, dtypes, and named shapes of the output of ``fun``.

    A ``jaxpr`` is JAX's intermediate representation for program traces. The
    ``jaxpr`` language is based on the simply-typed first-order lambda calculus
    with let-bindings. :py:func:`make_jaxpr` adapts a function to return its
    ``jaxpr``, which we can inspect to understand what JAX is doing internally.
    The ``jaxpr`` returned is a trace of ``fun`` abstracted to
    :py:class:`ShapedArray` level. Other levels of abstraction exist internally.

    We do not describe the semantics of the ``jaxpr`` language in detail here, but
    instead give a few examples.

    >>> import jax
    >>>
    >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
    >>> print(f(3.0))
    -0.83602
    >>> _make_jaxpr(f)(3.0)
    { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
    >>> _make_jaxpr(jax.grad(f))(3.0)
    { lambda ; a:f32[]. let
        b:f32[] = cos a
        c:f32[] = sin a
        _:f32[] = sin b
        d:f32[] = cos b
        e:f32[] = mul 1.0 d
        f:f32[] = neg e
        g:f32[] = mul f c
      in (g,) }
    """
    _check_callable(fun)
    static_argnums = _ensure_index_tuple(static_argnums)

    def _abstractify(args, kwargs):
        flat_args, in_tree = jax.tree.flatten((args, kwargs))
        if abstracted_axes is None:
            return map(shaped_abstractify, flat_args), in_tree, [True] * len(flat_args)
        else:
            axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
            in_type = pe.infer_lambda_input_type(axes_specs, flat_args)
            in_avals, keep_inputs = unzip2(in_type)
            return in_avals, in_tree, keep_inputs

    @wraps(fun)
    @api_boundary
    def make_jaxpr_f(*args, **kwargs):
        f = wrap_init(fun, (), {}, 'brainstate.compile.make_jaxpr')
        if static_argnums:
            dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
            f, args = jax.api_util.argnums_partial(f, dyn_argnums, args)
        in_avals, in_tree, keep_inputs = _abstractify(args, kwargs)
        in_type = tuple(safe_zip(in_avals, keep_inputs))
        f, out_tree = _flatten_fun(f, in_tree)
        f = annotate(f, in_type)
        if jax.__version_info__ < (0, 5, 0):
            debug_info_ = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
        with ExitStack() as stack:
            if axis_env is not None:
                stack.enter_context(extend_axis_env_nd(axis_env))
            if jax.__version_info__ < (0, 5, 0):
                jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, debug_info=debug_info_)
            else:
                jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f)
        closed_jaxpr = ClosedJaxpr(jaxpr, consts)
        if return_shape:
            out_avals, _ = unzip2(out_type)
            out_shapes_flat = [jax.ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
            return closed_jaxpr, jax.tree.unflatten(out_tree(), out_shapes_flat)
        return closed_jaxpr

    make_jaxpr_f.__module__ = "brainstate.compile"
    if hasattr(fun, "__qualname__"):
        make_jaxpr_f.__qualname__ = f"make_jaxpr({fun.__qualname__})"
    if hasattr(fun, "__name__"):
        make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
    return make_jaxpr_f


def make_hashable(obj):
    """Convert a pytree into a hashable representation."""
    if isinstance(obj, (list, tuple)):
        return tuple(make_hashable(item) for item in obj)
    elif isinstance(obj, dict):
        return tuple(sorted((k, make_hashable(v)) for k, v in obj.items()))
    elif isinstance(obj, set):
        return frozenset(make_hashable(item) for item in obj)
    else:
        # # Use JAX's tree_util for any other pytree structures
        # try:
        #     leaves, treedef = jax.tree_util.tree_flatten(obj)
        #     hashable_leaves = tuple(make_hashable(leaf) for leaf in leaves)
        #     return (str(treedef), hashable_leaves)
        # except:
        #     # Assume obj is already hashable
        return obj
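
# Illustrative sketch (not part of the module): make_hashable turns the nested containers used in
# cache keys into hashable tuples/frozensets, leaving other objects untouched.
#
# >>> make_hashable({'b': [1, 2], 'a': {3}})
# (('a', frozenset({3})), ('b', (1, 2)))
# >>> hash(make_hashable({'b': [1, 2], 'a': {3}})) is not None
# True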