brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.
- benchmark/COBA_2005.py +125 -0
- benchmark/CUBA_2005.py +149 -0
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +611 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/event/__init__.py +27 -0
- brainstate/event/_csr.py +316 -0
- brainstate/event/_csr_benchmark.py +14 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +708 -0
- brainstate/event/_fixed_probability_benchmark.py +128 -0
- brainstate/event/_fixed_probability_test.py +131 -0
- brainstate/event/_linear.py +359 -0
- brainstate/event/_linear_benckmark.py +82 -0
- brainstate/event/_linear_test.py +117 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/event/_xla_custom_op.py +312 -0
- brainstate/event/_xla_custom_op_test.py +55 -0
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +41 -0
- brainstate/nn/_interaction/_conv.py +499 -0
- brainstate/nn/_interaction/_conv_test.py +239 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_linear.py +582 -0
- brainstate/nn/_interaction/_linear_test.py +42 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +121 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1356 -1321
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
- brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
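The most consequential entries above are the package renames: `brainstate/{transform → compile}` moves the compilation-oriented transforms (`_jit.py`, `_conditions.py`, the loop primitives, `_make_jaxpr.py`) into a new `brainstate.compile` package, the gradient and mapping transforms split out into the new `brainstate.augment` package, event-driven operators move from `brainstate.nn.event` to a top-level `brainstate.event` package, and the flat `brainstate/util.py` becomes the `brainstate.util` package. A thin `brainstate/transform.py` module remains (see the rename from `nn/_projection/_utils.py` above), so the old attribute path may still be aliased. A minimal migration sketch for downstream code, assuming each new package re-exports its public names from its `__init__.py` (the re-export lists themselves are not shown in this diff):

  import jax
  import brainstate as bst

  def f(x):
    return jax.numpy.sin(jax.numpy.cos(x))

  # 0.0.2.post20241010 layout:
  #   jaxpr, states = bst.transform.make_jaxpr(f)(3.0)
  # 0.1.0.post20241122 layout, following the {transform -> compile} rename:
  jaxpr, states = bst.compile.make_jaxpr(f)(3.0)

The largest single deletion, `brainstate/transform/_make_jaxpr.py`, is reproduced in full below.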
--- a/brainstate/transform/_make_jaxpr.py
+++ /dev/null
@@ -1,739 +0,0 @@
-# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-"""
-This module implements how to create a JAX Jaxpr from a given function by considering the states that are read and
-written by the function. These state transformations are foundational for the BrainCore library. These utilities
-include two basic functions: `StatefulFunction` and `make_jaxpr`.
-
-
-``StatefulFunction``
---------------------
-
-The module provides a class called ``StatefulFunction`` that wraps a function and provides methods to get the
-JAX Jaxpr, the output shapes, the states that are read and written by the function, and the output of the function.
-The class provides the following methods:
-
-- `make_jaxpr`: creates the JAX Jaxpr of the function.
-- `jaxpr_call`: calls the function at the JAX Jaxpr level.
-- `jaxpr_call_without_states`: calls the function at the JAX Jaxpr level without considering the states.
-- `get_states`: returns the states that are read and written by the function.
-- `get_read_states`: returns the states that are read by the function.
-- `get_write_states`: returns the states that are written by the function.
-- `get_static_args`: returns the static arguments from the arguments.
-- `compile_and_get_states_by_static_args`: compiles the function and returns the states that are read and
-  written by the function.
-- `get_jaxpr`: returns the JAX Jaxpr of the function.
-- `get_out_shapes`: returns the output shapes of the function.
-- `get_out_treedef`: returns the output tree of the function.
-
-``make_jaxpr``
---------------
-
-The module provides a function called `make_jaxpr` that creates a function that produces its JAX Jaxpr given example
-arguments. The function returns a wrapped version of the function that when applied to example arguments returns a
-`ClosedJaxpr` representation of the function on those arguments. If the argument `return_shape` is `True`, then the
-returned function instead returns a pair where the first element is the `ClosedJaxpr` representation of the function
-and the second element is a pytree representing the structure, shape, dtypes, and named shapes of the output of the
-function.
-
-"""
-
-from __future__ import annotations
-
-import functools
-import inspect
-import operator
-from collections.abc import Hashable, Iterable, Sequence
-from contextlib import ExitStack
-from typing import Any, Callable, Tuple, Union, Dict, Optional
-
-import jax
-from jax._src import source_info_util
-from jax._src.linear_util import annotate
-from jax._src.traceback_util import api_boundary
-from jax.extend.linear_util import transformation_with_aux, wrap_init
-from jax.interpreters import partial_eval as pe
-from jax.interpreters.xla import abstractify
-from jax.util import wraps
-
-from brainstate._state import State, StateTrace
-from brainstate._utils import set_module_as
-from brainstate.typing import PyTree
-
-AxisName = Hashable
-
-__all__ = [
-  "StatefulFunction",
-  "make_jaxpr",
-]
-
-
-def _assign_state_values(states, state_vals) -> None:
-  """
-  Assign the state values to the states.
-
-  Args:
-    states: The states.
-    state_vals: The state values.
-  """
-  assert len(states) == len(state_vals), f'State length mismatch. {len(states)} != {len(state_vals)}.'
-  for st, val in zip(states, state_vals):
-    st.value = val
-
-
-def _ensure_index_tuple(x: Any) -> tuple[int, ...]:
-  """Convert x to a tuple of indices."""
-  x = jax.core.concrete_or_error(None, x, "expected a static index or sequence of indices.")
-  try:
-    return (operator.index(x),)
-  except TypeError:
-    return tuple(jax.util.safe_map(operator.index, x))
-
-
-def _new_arg(frame, trace, aval):
-  """
-  Transform a new argument to a tracer.
-
-  Modified from jax.interpreters.partial_eval.DynamicJaxprTrace.new_arg()
-
-  Args:
-    frame: The frame.
-    trace: The trace.
-    aval: The abstract value.
-
-  Returns:
-    The tracer.
-  """
-  tracer = pe.DynamicJaxprTracer(trace, aval, source_info_util.current())
-  frame.tracers.append(tracer)
-  frame.tracer_to_var[id(tracer)] = var = frame.newvar(aval)
-  frame.invars.append(var)
-  return tracer
-
-
-def wrapped_abstractify(x: Any) -> Any:
-  """
-  Abstractify the input.
-
-  Args:
-    x: The input.
-
-  Returns:
-    The abstractified input.
-  """
-  if isinstance(x, pe.DynamicJaxprTracer):
-    return jax.core.ShapedArray(x.aval.shape, x.aval.dtype, weak_type=x.aval.weak_type)
-  return abstractify(x)
-
-
-class StatefulFunction(object):
-  """
-  A wrapper class for a function that collects the states that are read and written by the function. The states are
-  collected by the function and returned as a StateDictManager instance. The StateDictManager instance can be used to
-  manage the states in the JAX program. The class provides a function called `states` that returns the states
-  that are read and written by the function. The class provides a function called `to_state_manager` that returns
-  a StateDictManager instance that contains the states that are read and written by the function. The class provides
-  a function called `__call__` that wraps the function and returns the states that are read and written by the
-  function and the output of the function.
-
-  Args:
-    fun: The function whose ``jaxpr`` is to be computed. Its positional
-      arguments and return value should be arrays, scalars, or standard Python
-      containers (tuple/list/dict) thereof.
-    static_argnums: See the :py:func:`jax.jit` docstring.
-    axis_env: Optional, a sequence of pairs where the first element is an axis
-      name and the second element is a positive integer representing the size of
-      the mapped axis with that name. This parameter is useful when lowering
-      functions that involve parallel communication collectives, and it
-      specifies the axis name/size environment that would be set up by
-      applications of :py:func:`jax.pmap`.
-    abstracted_axes: Optional, a pytree with the same structure as the input
-      arguments to ``fun``. The leaves of the pytree can be either None or a
-      dict with axis names as keys and integers as values. If the leaf is None,
-      then the corresponding axis is not abstracted. If the leaf is a dict, then
-      the corresponding axis is abstracted, and the dict specifies the axis name
-      and size. The abstracted axes are used to infer the input type of the
-      function. If None, then all axes are abstracted.
-    state_returns: Optional, a string or a tuple of strings. The default is
-      ``('read', 'write')``. The strings specify the categories of states to be
-      returned by the wrapped function. The categories are ``'read'`` and
-      ``'write'``. If the category is ``'read'``, then the wrapped function
-      returns the states that are read by the function. If the category is
-      ``'write'``, then the wrapped function returns the states that are written
-      by the function. If the category is ``'read'`` and ``'write'``, then the
-      wrapped function returns both the read and write states.
-
-  """
-  __module__ = "brainstate.transform"
-
-  def __init__(
-      self,
-      fun: Callable,
-      static_argnums: Union[int, Iterable[int]] = (),
-      axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
-      abstracted_axes: Optional[Any] = None,
-      state_returns: Union[str, Tuple[str, ...]] = ('read', 'write'),
-      cache_type: Optional[str] = None,
-  ):
-    # explicit parameters
-    self.fun = fun
-    self.static_argnums = _ensure_index_tuple(tuple() if static_argnums is None else static_argnums)
-    self.axis_env = axis_env
-    self.abstracted_axes = abstracted_axes
-    self.state_returns = tuple(state_returns) if isinstance(state_returns, (tuple, list)) else (state_returns,)
-    assert cache_type in [None, 'jit']
-    self.cache_type = cache_type
-
-    # implicit parameters
-    self._jaxpr: Dict[Any, jax.core.ClosedJaxpr] = dict()
-    self._out_shapes: Dict[Any, PyTree] = dict()
-    self._jaxpr_out_tree: Dict[Any, PyTree] = dict()
-    self._state_trace: Dict[Any, StateTrace] = dict()
-
-  def __repr__(self) -> str:
-    return (f"{self.__class__.__name__}({self.fun}, "
-            f"static_argnums={self.static_argnums}, "
-            f"axis_env={self.axis_env}, "
-            f"abstracted_axes={self.abstracted_axes}, "
-            f"state_returns={self.state_returns})")
-
-  def get_jaxpr(self, cache_key: Hashable = ()) -> jax.core.ClosedJaxpr:
-    """
-    Read the JAX Jaxpr representation of the function.
-
-    Args:
-      cache_key: The hashable key.
-
-    Returns:
-      The JAX Jaxpr representation of the function.
-    """
-    if cache_key not in self._jaxpr:
-      raise ValueError(f"the function is not called with the static arguments: {cache_key}")
-    return self._jaxpr[cache_key]
-
-  def get_out_shapes(self, cache_key: Hashable = ()) -> PyTree:
-    """
-    Read the output shapes of the function.
-
-    Args:
-      cache_key: The hashable key.
-
-    Returns:
-      The output shapes of the function.
-    """
-    if cache_key not in self._out_shapes:
-      raise ValueError(f"the function is not called with the static arguments: {cache_key}")
-    return self._out_shapes[cache_key]
-
-  def get_out_treedef(self, cache_key: Hashable = ()) -> PyTree:
-    """
-    Read the output tree of the function.
-
-    Args:
-      cache_key: The hashable key.
-
-    Returns:
-      The output tree of the function.
-    """
-    if cache_key not in self._jaxpr_out_tree:
-      raise ValueError(f"the function is not called with the static arguments: {cache_key}")
-    return self._jaxpr_out_tree[cache_key]
-
-  def get_states(self, cache_key: Hashable = ()) -> Tuple[State, ...]:
-    """
-    Read the states that are read and written by the function.
-
-    Args:
-      cache_key: The hashable key.
-
-    Returns:
-      The states that are read and written by the function.
-    """
-    if cache_key not in self._state_trace:
-      raise ValueError(f"the function is not called with the static arguments: {cache_key}")
-    return tuple(self._state_trace[cache_key].states)
-
-  def get_read_states(self, cache_key: Hashable = ()) -> Tuple[State, ...]:
-    """
-    Read the states that are read by the function.
-
-    Args:
-      cache_key: The hashable key.
-
-    Returns:
-      The states that are read by the function.
-    """
-    _state_trace = self._state_trace[cache_key]
-    return tuple([st for st, ty in zip(_state_trace.states, _state_trace.types) if ty == 'read'])
-
-  def get_write_states(self, cache_key: Hashable = ()) -> Tuple[State, ...]:
-    """
-    Read the states that are written by the function.
-
-    Args:
-      cache_key: The hashable key.
-
-    Returns:
-      The states that are written by the function.
-    """
-    state_trace = self._state_trace[cache_key]
-    return tuple([st for st, ty in zip(state_trace.states, state_trace.types) if ty == 'write'])
-
-  def get_arg_cache_key(self, *args, **kwargs) -> Tuple:
-    """
-    Get the static arguments from the arguments.
-
-    Args:
-      *args: The arguments to the function.
-
-    Returns:
-      The static arguments.
-    """
-    if self.cache_type == 'jit':
-      static_args, dyn_args = [], []
-      for i, arg in enumerate(args):
-        if i in self.static_argnums:
-          static_args.append(arg)
-        else:
-          dyn_args.append(arg)
-      dyn_args = jax.tree.map(wrapped_abstractify, jax.tree.leaves(dyn_args))
-      dyn_kwargs = jax.tree.map(wrapped_abstractify, jax.tree.leaves(kwargs))
-      return tuple([tuple(static_args), tuple(dyn_args), tuple(dyn_kwargs)])
-    elif self.cache_type is None:
-      num_arg = len(args)
-      return tuple(args[i] for i in self.static_argnums if i < num_arg)
-    else:
-      raise ValueError(f"Invalid cache type: {self.cache_type}")
-
-  def compile_and_get_states_by_static_args(self, *args, **kwargs) -> Tuple[State, ...]:
-    """
-    Get the states that are read and written by the function.
-
-    Args:
-      *args: The arguments to the function.
-      **kwargs: The keyword arguments to the function.
-
-    Returns:
-      The states that are read and written by the function.
-    """
-    cache_key = self.get_arg_cache_key(*args, **kwargs)
-    if cache_key not in self._state_trace:
-      self.make_jaxpr(*args, **kwargs)
-    return self.get_states(cache_key)
-
-  def clear_cache(self) -> None:
-    """
-    Clear the compilation cache.
-    """
-    self._jaxpr.clear()
-    self._out_shapes.clear()
-    self._jaxpr_out_tree.clear()
-    self._state_trace.clear()
-
-  @staticmethod
-  def _init_trace_and_newarg() -> StateTrace:
-    # Should be within the calling of ``jax.make_jaxpr()``
-    state_trace: StateTrace = StateTrace()
-    main = jax.core.thread_local_state.trace_state.trace_stack.stack[-1]
-    frame = main.jaxpr_stack[-1]
-    trace = pe.DynamicJaxprTrace(main, jax.core.cur_sublevel())
-    state_trace.set_new_arg(functools.partial(_new_arg, frame, trace))
-    return state_trace
-
-  def _wrapped_fun_to_eval(self, cache_key, *args, **kwargs) -> Tuple[Any, Tuple[State, ...]]:
-    """
-    Wrap the function and return the states that are read and written by the function and the output of the function.
-
-    Args:
-      *args: The arguments to the function.
-      **kwargs: The keyword arguments to the function.
-
-    Returns:
-      A tuple of the states that are read and written by the function and the output of the function.
-    """
-    # state trace
-    _state_trace = self._init_trace_and_newarg()
-    self._state_trace[cache_key] = _state_trace
-    with _state_trace:
-      out = self.fun(*args, **kwargs)
-      state_values = _state_trace.collect_values('read', 'write', check_val_tree=True)
-    _state_trace.recovery_original_values()
-
-    # State instance as functional returns is not allowed.
-    # Checking whether the states are returned.
-    for leaf in jax.tree.leaves(out):
-      if isinstance(leaf, State):
-        leaf._raise_error_with_source_info(ValueError(f"State object is not allowed to be returned: {leaf}"))
-    return out, state_values
-
-  def make_jaxpr(self, *args, **kwargs):
-    """Creates a function that produces its jaxpr given example args.
-
-    A ``ClosedJaxpr`` representation of ``fun`` on those arguments. If the
-    argument ``return_shape`` is ``True``, then the returned function instead
-    returns a pair where the first element is the ``ClosedJaxpr``
-    representation of ``fun`` and the second element is a pytree representing
-    the structure, shape, dtypes, and named shapes of the output of ``fun``.
-
-    Args:
-      *args: The arguments to the function.
-      **kwargs: The keyword arguments to the function.
-    """
-
-    # static args
-    cache_key = self.get_arg_cache_key(*args, **kwargs)
-
-    if cache_key not in self._state_trace:
-      try:
-        # jaxpr
-        jaxpr, (out_shapes, state_shapes) = _make_jaxpr(
-          functools.partial(self._wrapped_fun_to_eval, cache_key),
-          static_argnums=self.static_argnums,
-          axis_env=self.axis_env,
-          return_shape=True,
-          abstracted_axes=self.abstracted_axes
-        )(*args, **kwargs)
-
-        # returns
-        self._jaxpr_out_tree[cache_key] = jax.tree.structure((out_shapes, state_shapes))
-        self._out_shapes[cache_key] = (out_shapes, state_shapes)
-        self._jaxpr[cache_key] = jaxpr
-      except Exception as e:
-        try:
-          self._state_trace.pop(cache_key)
-        except KeyError:
-          pass
-        raise e
-
-    return self
-
-  def jaxpr_call(self, state_vals, *args, **kwargs) -> Any:
-    """
-    Call the function at the JAX Jaxpr level.
-
-    Args:
-      state_vals: The state values.
-      *args: The arguments to the function.
-      **kwargs: The keyword arguments to the function.
-
-    Returns:
-      State values and the function output.
-    """
-    # state checking
-    cache_key = self.get_arg_cache_key(*args, **kwargs)
-    states: Sequence[State] = self.get_states(cache_key)
-    assert len(state_vals) == len(states), 'State length mismatch.'
-    # # No need to check, because the make_jaxpr() has been checked whether the value's tree is correct.
-    # for val, st in zip(state_vals, states):  # check state's value tree structure
-    #   st._check_value_tree(val)
-
-    # parameters
-    args = tuple(args[i] for i in range(len(args)) if i not in self.static_argnums)
-    args = jax.tree.flatten((args, kwargs, state_vals))[0]
-
-    # calling the function
-    closed_jaxpr = self.get_jaxpr(cache_key)
-    out_treedef = self.get_out_treedef(cache_key)
-    jaxpr_outs = jax.core.eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)
-
-    # output processing
-    out, new_state_vals = out_treedef.unflatten(jaxpr_outs)
-    assert len(new_state_vals) == len(state_vals), 'State length mismatch.'
-    # # No need to check, because the make_jaxpr() has been checked whether the value's tree is correct.
-    # for val, st in zip(new_state_vals, states):  # check state's value tree structure
-    #   st._check_value_tree(val)
-    return new_state_vals, out
-
-  def jaxpr_call_auto(self, *args, **kwargs) -> Any:
-    """
-    Call the function at the JAX Jaxpr level with automatic state management.
-
-    Args:
-      *args: The arguments to the function.
-      **kwargs: The keyword arguments to the function.
-
-    Returns:
-      The output of the function.
-    """
-    cache_key = self.get_arg_cache_key(*args, **kwargs)
-    states = self.get_states(cache_key)
-    state_vals, out = self.jaxpr_call([st.value for st in states], *args, **kwargs)
-    for st, val in zip(states, state_vals):
-      st.value = val
-    return out
-
-
-@set_module_as("brainstate.transform")
-def make_jaxpr(
-    fun: Callable,
-    static_argnums: Union[int, Iterable[int]] = (),
-    axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
-    return_shape: bool = False,
-    abstracted_axes: Optional[Any] = None,
-    state_returns: Union[str, Tuple[str, ...]] = ('read', 'write')
-) -> Callable[..., (Tuple[jax.core.ClosedJaxpr, Tuple[State, ...]] |
-                    Tuple[jax.core.ClosedJaxpr, Tuple[State, ...], PyTree])]:
-  """
-  Creates a function that produces its jaxpr given example args.
-
-  Args:
-    fun: The function whose ``jaxpr`` is to be computed. Its positional
-      arguments and return value should be arrays, scalars, or standard Python
-      containers (tuple/list/dict) thereof.
-    static_argnums: See the :py:func:`jax.jit` docstring.
-    axis_env: Optional, a sequence of pairs where the first element is an axis
-      name and the second element is a positive integer representing the size of
-      the mapped axis with that name. This parameter is useful when lowering
-      functions that involve parallel communication collectives, and it
-      specifies the axis name/size environment that would be set up by
-      applications of :py:func:`jax.pmap`.
-    return_shape: Optional boolean, defaults to ``False``. If ``True``, the
-      wrapped function returns a pair where the first element is the XLA
-      computation and the second element is a pytree with the same structure as
-      the output of ``fun`` and where the leaves are objects with ``shape``,
-      ``dtype``, and ``named_shape`` attributes representing the corresponding
-      types of the output leaves.
-    abstracted_axes: Optional, a pytree with the same structure as the input
-      arguments to ``fun``. The leaves of the pytree can be either None or a
-      dict with axis names as keys and integers as values. If the leaf is None,
-      then the corresponding axis is not abstracted. If the leaf is a dict, then
-      the corresponding axis is abstracted, and the dict specifies the axis name
-      and size. The abstracted axes are used to infer the input type of the
-      function. If None, then all axes are abstracted.
-    state_returns: Optional, a string or a tuple of strings. The default is
-      ``('read', 'write')``. The strings specify the categories of states to be
-      returned by the wrapped function. The categories are ``'read'`` and
-      ``'write'``. If the category is ``'read'``, then the wrapped function
-      returns the states that are read by the function. If the category is
-      ``'write'``, then the wrapped function returns the states that are written
-      by the function. If the category is ``'read'`` and ``'write'``, then the
-      wrapped function returns both the read and write states.
-
-
-  Returns:
-    A wrapped version of ``fun`` that when applied to example arguments returns
-    a ``ClosedJaxpr`` representation of ``fun`` on those arguments. If the
-    argument ``return_shape`` is ``True``, then the returned function instead
-    returns a pair where the first element is the ``ClosedJaxpr``
-    representation of ``fun`` and the second element is a pytree representing
-    the structure, shape, dtypes, and named shapes of the output of ``fun``.
-
-  A ``jaxpr`` is JAX's intermediate representation for program traces. The
-  ``jaxpr`` language is based on the simply-typed first-order lambda calculus
-  with let-bindings. :py:func:`make_jaxpr` adapts a function to return its
-  ``jaxpr``, which we can inspect to understand what JAX is doing internally.
-  The ``jaxpr`` returned is a trace of ``fun`` abstracted to
-  :py:class:`ShapedArray` level. Other levels of abstraction exist internally.
-
-  We do not describe the semantics of the ``jaxpr`` language in detail here, but
-  instead give a few examples.
-
-  >>> import jax
-  >>> import brainstate as bst
-  >>>
-  >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
-  >>> print(f(3.0))
-  -0.83602
-  >>> jaxpr, states = bst.transform.make_jaxpr(f)(3.0)
-  >>> jaxpr
-  { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
-  >>> jaxpr, states = bst.transform.make_jaxpr(jax.grad(f))(3.0)
-  >>> jaxpr
-  { lambda ; a:f32[]. let
-      b:f32[] = cos a
-      c:f32[] = sin a
-      _:f32[] = sin b
-      d:f32[] = cos b
-      e:f32[] = mul 1.0 d
-      f:f32[] = neg e
-      g:f32[] = mul f c
-    in (g,) }
-  """
-
-  stateful_fun = StatefulFunction(fun, static_argnums, axis_env, abstracted_axes, state_returns)
-
-  @wraps(fun)
-  def make_jaxpr_f(*args, **kwargs):
-    stateful_fun.make_jaxpr(*args, **kwargs)
-    cache_key = stateful_fun.get_arg_cache_key(*args, **kwargs)
-    if return_shape:
-      return (stateful_fun.get_jaxpr(cache_key),
-              stateful_fun.get_states(cache_key),
-              stateful_fun.get_out_shapes(cache_key)[0])
-    else:
-      return (stateful_fun.get_jaxpr(cache_key),
-              stateful_fun.get_states(cache_key))
-
-  # wrapped jaxpr builder function
-  make_jaxpr_f.__module__ = "brainstate.transform"
-  if hasattr(fun, "__qualname__"):
-    make_jaxpr_f.__qualname__ = f"make_jaxpr({fun.__qualname__})"
-  if hasattr(fun, "__name__"):
-    make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
-  return make_jaxpr_f
-
-
-def _check_callable(fun):
-  # In Python 3.10+, the only thing stopping us from supporting staticmethods
-  # is that we can't take weak references to them, which the C++ JIT requires.
-  if isinstance(fun, staticmethod):
-    raise TypeError(f"staticmethod arguments are not supported, got {fun}")
-  if not callable(fun):
-    raise TypeError(f"Expected a callable value, got {fun}")
-  if inspect.isgeneratorfunction(fun):
-    raise TypeError(f"Expected a function, got a generator function: {fun}")
-
-
-def _broadcast_prefix(
-    prefix_tree: Any,
-    full_tree: Any,
-    is_leaf: Callable[[Any], bool] | None = None
-) -> list[Any]:
-  # If prefix_tree is not a tree prefix of full_tree, this code can raise a
-  # ValueError; use prefix_errors to find disagreements and raise more precise
-  # error messages.
-  result = []
-  num_leaves = lambda t: jax.tree.structure(t).num_leaves
-  add_leaves = lambda x, subtree: result.extend([x] * num_leaves(subtree))
-  jax.tree.map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf)
-  return result
-
-
-def _flat_axes_specs(
-    abstracted_axes, *args, **kwargs
-) -> list[pe.AbstractedAxesSpec]:
-  if kwargs:
-    raise NotImplementedError
-
-  def ax_leaf(l):
-    return (isinstance(l, dict) and jax.tree_util.all_leaves(l.values()) or
-            isinstance(l, tuple) and jax.tree_util.all_leaves(l, lambda x: x is None))
-
-  return _broadcast_prefix(abstracted_axes, args, ax_leaf)
-
-
-@transformation_with_aux
-def _flatten_fun(in_tree, *args_flat):
-  py_args, py_kwargs = jax.tree.unflatten(in_tree, args_flat)
-  ans = yield py_args, py_kwargs
-  yield jax.tree.flatten(ans)
-
-
-def _make_jaxpr(
-    fun: Callable,
-    static_argnums: int | Iterable[int] = (),
-    axis_env: Sequence[tuple[AxisName, int]] | None = None,
-    return_shape: bool = False,
-    abstracted_axes: Any | None = None,
-) -> Callable[..., (jax.core.ClosedJaxpr | tuple[jax.core.ClosedJaxpr, Any])]:
-  """Creates a function that produces its jaxpr given example args.
-
-  Args:
-    fun: The function whose ``jaxpr`` is to be computed. Its positional
-      arguments and return value should be arrays, scalars, or standard Python
-      containers (tuple/list/dict) thereof.
-    static_argnums: See the :py:func:`jax.jit` docstring.
-    axis_env: Optional, a sequence of pairs where the first element is an axis
-      name and the second element is a positive integer representing the size of
-      the mapped axis with that name. This parameter is useful when lowering
-      functions that involve parallel communication collectives, and it
-      specifies the axis name/size environment that would be set up by
-      applications of :py:func:`jax.pmap`.
-    return_shape: Optional boolean, defaults to ``False``. If ``True``, the
-      wrapped function returns a pair where the first element is the
-      ``ClosedJaxpr`` representation of ``fun`` and the second element is a
-      pytree with the same structure as the output of ``fun`` and where the
-      leaves are objects with ``shape``, ``dtype``, and ``named_shape``
-      attributes representing the corresponding types of the output leaves.
-
-  Returns:
-    A wrapped version of ``fun`` that when applied to example arguments returns
-    a ``ClosedJaxpr`` representation of ``fun`` on those arguments. If the
-    argument ``return_shape`` is ``True``, then the returned function instead
-    returns a pair where the first element is the ``ClosedJaxpr``
-    representation of ``fun`` and the second element is a pytree representing
-    the structure, shape, dtypes, and named shapes of the output of ``fun``.
-
-  A ``jaxpr`` is JAX's intermediate representation for program traces. The
-  ``jaxpr`` language is based on the simply-typed first-order lambda calculus
-  with let-bindings. :py:func:`make_jaxpr` adapts a function to return its
-  ``jaxpr``, which we can inspect to understand what JAX is doing internally.
-  The ``jaxpr`` returned is a trace of ``fun`` abstracted to
-  :py:class:`ShapedArray` level. Other levels of abstraction exist internally.
-
-  We do not describe the semantics of the ``jaxpr`` language in detail here, but
-  instead give a few examples.
-
-  >>> import jax
-  >>>
-  >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
-  >>> print(f(3.0))
-  -0.83602
-  >>> _make_jaxpr(f)(3.0)
-  { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
-  >>> _make_jaxpr(jax.grad(f))(3.0)
-  { lambda ; a:f32[]. let
-      b:f32[] = cos a
-      c:f32[] = sin a
-      _:f32[] = sin b
-      d:f32[] = cos b
-      e:f32[] = mul 1.0 d
-      f:f32[] = neg e
-      g:f32[] = mul f c
-    in (g,) }
-  """
-  _check_callable(fun)
-  static_argnums = _ensure_index_tuple(static_argnums)
-
-  def _abstractify(args, kwargs):
-    flat_args, in_tree = jax.tree.flatten((args, kwargs))
-    if abstracted_axes is None:
-      return map(jax.api_util.shaped_abstractify, flat_args), in_tree, [True] * len(flat_args)
-    else:
-      axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
-      in_type = pe.infer_lambda_input_type(axes_specs, flat_args)
-      in_avals, keep_inputs = jax.util.unzip2(in_type)
-      return in_avals, in_tree, keep_inputs
-
-  @wraps(fun)
-  @api_boundary
-  def make_jaxpr_f(*args, **kwargs):
-    f = wrap_init(fun)
-    if static_argnums:
-      dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
-      f, args = jax.api_util.argnums_partial(f, dyn_argnums, args)
-    in_avals, in_tree, keep_inputs = _abstractify(args, kwargs)
-    in_type = tuple(jax.util.safe_zip(in_avals, keep_inputs))
-    f, out_tree = _flatten_fun(f, in_tree)
-    f = annotate(f, in_type)
-    debug_info = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
-    with ExitStack() as stack:
-      for axis_name, size in axis_env or []:
-        stack.enter_context(jax.core.extend_axis_env(axis_name, size, None))
-      jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, debug_info=debug_info)
-    closed_jaxpr = jax.core.ClosedJaxpr(jaxpr, consts)
-    if return_shape:
-      out_avals, _ = jax.util.unzip2(out_type)
-      out_shapes_flat = [jax.ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
-      return closed_jaxpr, jax.tree.unflatten(out_tree(), out_shapes_flat)
-    return closed_jaxpr
-
-  make_jaxpr_f.__module__ = "brainstate.transform"
-  if hasattr(fun, "__qualname__"):
-    make_jaxpr_f.__qualname__ = f"make_jaxpr({fun.__qualname__})"
-  if hasattr(fun, "__name__"):
-    make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
-  return make_jaxpr_f
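For reference, the `StatefulFunction` workflow deleted above (and re-landed, extended, as `brainstate/compile/_make_jaxpr.py` with +756 lines in the file list) is driven as follows. This is a minimal sketch built from the deleted code's own API; the `bst.State` construction follows `brainstate/_state.py`, and the exact attribute paths are assumptions:

  import jax.numpy as jnp
  import brainstate as bst

  weight = bst.State(jnp.ones(3))  # a State that the traced function writes

  def step(x):
    weight.value = weight.value * x  # recorded by the StateTrace as a 'write' state
    return weight.value.sum()

  sf = bst.transform.StatefulFunction(step).make_jaxpr(jnp.ones(3))  # make_jaxpr returns self
  key = sf.get_arg_cache_key(jnp.ones(3))  # () here, since no static_argnums were given
  print(sf.get_states(key))                # the State instances collected during tracing
  print(sf.jaxpr_call_auto(jnp.ones(3)))   # evaluates the jaxpr and writes updated values back

Note the cache is keyed by static arguments: calling `make_jaxpr` again with the same static signature reuses the traced jaxpr, while `clear_cache()` discards it.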