brainstate 0.1.7__py2.py3-none-any.whl → 0.1.9__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +58 -51
- brainstate/_compatible_import.py +148 -148
- brainstate/_state.py +1605 -1663
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/augment/__init__.py +30 -30
- brainstate/augment/_autograd.py +778 -778
- brainstate/augment/_autograd_test.py +1289 -1289
- brainstate/augment/_eval_shape.py +99 -99
- brainstate/augment/_eval_shape_test.py +38 -38
- brainstate/augment/_mapping.py +1060 -1060
- brainstate/augment/_mapping_test.py +597 -597
- brainstate/augment/_random.py +151 -151
- brainstate/compile/__init__.py +38 -38
- brainstate/compile/_ad_checkpoint.py +204 -204
- brainstate/compile/_ad_checkpoint_test.py +49 -49
- brainstate/compile/_conditions.py +256 -256
- brainstate/compile/_conditions_test.py +220 -220
- brainstate/compile/_error_if.py +92 -92
- brainstate/compile/_error_if_test.py +52 -52
- brainstate/compile/_jit.py +346 -346
- brainstate/compile/_jit_test.py +143 -143
- brainstate/compile/_loop_collect_return.py +536 -536
- brainstate/compile/_loop_collect_return_test.py +58 -58
- brainstate/compile/_loop_no_collection.py +184 -184
- brainstate/compile/_loop_no_collection_test.py +50 -50
- brainstate/compile/_make_jaxpr.py +888 -888
- brainstate/compile/_make_jaxpr_test.py +156 -146
- brainstate/compile/_progress_bar.py +202 -202
- brainstate/compile/_unvmap.py +159 -159
- brainstate/compile/_util.py +147 -147
- brainstate/environ.py +563 -563
- brainstate/environ_test.py +62 -62
- brainstate/functional/__init__.py +27 -26
- brainstate/graph/__init__.py +29 -29
- brainstate/graph/_graph_node.py +244 -244
- brainstate/graph/_graph_node_test.py +73 -73
- brainstate/graph/_graph_operation.py +1738 -1738
- brainstate/graph/_graph_operation_test.py +563 -563
- brainstate/init/__init__.py +26 -26
- brainstate/init/_base.py +52 -52
- brainstate/init/_generic.py +244 -244
- brainstate/init/_random_inits.py +553 -553
- brainstate/init/_random_inits_test.py +149 -149
- brainstate/init/_regular_inits.py +105 -105
- brainstate/init/_regular_inits_test.py +50 -50
- brainstate/mixin.py +365 -363
- brainstate/mixin_test.py +77 -73
- brainstate/nn/__init__.py +135 -131
- brainstate/{functional → nn}/_activations.py +808 -813
- brainstate/{functional → nn}/_activations_test.py +331 -331
- brainstate/nn/_collective_ops.py +514 -514
- brainstate/nn/_collective_ops_test.py +43 -43
- brainstate/nn/_common.py +178 -178
- brainstate/nn/_conv.py +501 -501
- brainstate/nn/_conv_test.py +238 -238
- brainstate/nn/_delay.py +509 -470
- brainstate/nn/_delay_test.py +238 -0
- brainstate/nn/_dropout.py +426 -426
- brainstate/nn/_dropout_test.py +100 -100
- brainstate/nn/_dynamics.py +1343 -1361
- brainstate/nn/_dynamics_test.py +78 -78
- brainstate/nn/_elementwise.py +1119 -1120
- brainstate/nn/_elementwise_test.py +169 -169
- brainstate/nn/_embedding.py +58 -58
- brainstate/nn/_exp_euler.py +92 -92
- brainstate/nn/_exp_euler_test.py +35 -35
- brainstate/nn/_fixedprob.py +239 -239
- brainstate/nn/_fixedprob_test.py +114 -114
- brainstate/nn/_inputs.py +608 -608
- brainstate/nn/_linear.py +424 -424
- brainstate/nn/_linear_mv.py +83 -83
- brainstate/nn/_linear_mv_test.py +120 -120
- brainstate/nn/_linear_test.py +107 -107
- brainstate/nn/_ltp.py +28 -28
- brainstate/nn/_module.py +377 -377
- brainstate/nn/_module_test.py +40 -208
- brainstate/nn/_neuron.py +705 -705
- brainstate/nn/_neuron_test.py +161 -161
- brainstate/nn/_normalizations.py +975 -918
- brainstate/nn/_normalizations_test.py +73 -73
- brainstate/{functional → nn}/_others.py +46 -46
- brainstate/nn/_poolings.py +1177 -1177
- brainstate/nn/_poolings_test.py +217 -217
- brainstate/nn/_projection.py +486 -486
- brainstate/nn/_rate_rnns.py +554 -554
- brainstate/nn/_rate_rnns_test.py +63 -63
- brainstate/nn/_readout.py +209 -209
- brainstate/nn/_readout_test.py +53 -53
- brainstate/nn/_stp.py +236 -236
- brainstate/nn/_synapse.py +505 -505
- brainstate/nn/_synapse_test.py +131 -131
- brainstate/nn/_synaptic_projection.py +423 -423
- brainstate/nn/_synouts.py +162 -162
- brainstate/nn/_synouts_test.py +57 -57
- brainstate/nn/_utils.py +89 -89
- brainstate/nn/metrics.py +388 -388
- brainstate/optim/__init__.py +38 -38
- brainstate/optim/_base.py +64 -64
- brainstate/optim/_lr_scheduler.py +448 -448
- brainstate/optim/_lr_scheduler_test.py +50 -50
- brainstate/optim/_optax_optimizer.py +152 -152
- brainstate/optim/_optax_optimizer_test.py +53 -53
- brainstate/optim/_sgd_optimizer.py +1104 -1104
- brainstate/random/__init__.py +24 -24
- brainstate/random/_rand_funs.py +3616 -3616
- brainstate/random/_rand_funs_test.py +567 -567
- brainstate/random/_rand_seed.py +210 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1409 -1409
- brainstate/random/_random_for_unit.py +52 -52
- brainstate/surrogate.py +1957 -1957
- brainstate/transform.py +23 -23
- brainstate/typing.py +304 -304
- brainstate/util/__init__.py +50 -50
- brainstate/util/caller.py +98 -98
- brainstate/util/error.py +55 -55
- brainstate/util/filter.py +469 -469
- brainstate/util/others.py +540 -540
- brainstate/util/pretty_pytree.py +945 -945
- brainstate/util/pretty_pytree_test.py +159 -159
- brainstate/util/pretty_repr.py +328 -328
- brainstate/util/pretty_table.py +2954 -2954
- brainstate/util/scaling.py +258 -258
- brainstate/util/struct.py +523 -523
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/METADATA +91 -99
- brainstate-0.1.9.dist-info/RECORD +130 -0
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/WHEEL +1 -1
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info/licenses}/LICENSE +202 -202
- brainstate/functional/_normalization.py +0 -81
- brainstate/functional/_spikes.py +0 -204
- brainstate-0.1.7.dist-info/RECORD +0 -131
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/top_level.txt +0 -0
brainstate/compile/_jit.py
CHANGED
@@ -1,346 +1,346 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
import functools
|
17
|
-
from collections.abc import Iterable, Sequence
|
18
|
-
from typing import (Any, Callable, Union)
|
19
|
-
|
20
|
-
import jax
|
21
|
-
from jax._src import sharding_impls
|
22
|
-
|
23
|
-
from brainstate._compatible_import import Device
|
24
|
-
from brainstate._utils import set_module_as
|
25
|
-
from brainstate.typing import Missing
|
26
|
-
from ._make_jaxpr import StatefulFunction, _ensure_index_tuple
|
27
|
-
from ._util import write_back_state_values
|
28
|
-
|
29
|
-
__all__ = ['jit']
|
30
|
-
|
31
|
-
|
32
|
-
class JittedFunction(Callable):
    """
    A wrapped version of ``fun``, set up for just-in-time compilation.

    In this module the class is only used as a type annotation for the object
    returned by :py:func:`jit`: the returned object is a plain Python function
    to which the attributes listed below are attached.
    """
    origin_fun: Callable  # the original (un-jitted) function
    stateful_fun: StatefulFunction  # the stateful function used for extracting states
    jitted_fun: jax.stages.Wrapped  # the underlying ``jax.jit``-wrapped function
    clear_cache: Callable  # clear the caches of the stateful and jitted functions
    eval_shape: Callable  # evaluate the output shape of the jitted function
    compile: Callable  # lower and compile the jitted function for given arguments
    trace: Callable  # trace the jitted function for given arguments

    def __call__(self, *args, **kwargs):
        """Call the jitted function (declared here for typing purposes only)."""
        pass
|
46
|
-
|
47
|
-
|
48
|
-
def _get_jitted_fun(
    fun: Callable,
    in_shardings,
    out_shardings,
    static_argnums,
    donate_argnums,
    static_argnames,
    donate_argnames,
    keep_unused,
    device,
    backend,
    inline,
    abstracted_axes,
    **kwargs
) -> JittedFunction:
    """
    Build the jitted wrapper returned by :py:func:`jit`.

    ``fun`` is wrapped in a :py:class:`StatefulFunction`, and its
    ``jaxpr_call`` is passed to :py:func:`jax.jit`. ``jaxpr_call`` is invoked
    with the state values as its first positional argument. For that reason,
    every user-facing index in ``static_argnums`` / ``donate_argnums`` is
    shifted by one before being forwarded to :py:func:`jax.jit`.

    Args:
      fun: The function to be jitted.
      in_shardings, out_shardings, donate_argnames, keep_unused, device,
      backend, inline: Forwarded unchanged to :py:func:`jax.jit`.
      static_argnums, static_argnames, abstracted_axes: Forwarded to both
        :py:class:`StatefulFunction` and :py:func:`jax.jit`.
      donate_argnums: Forwarded (shifted by one) to :py:func:`jax.jit`.
      **kwargs: Extra keyword arguments forwarded to :py:func:`jax.jit`.

    Returns:
      A function with the attributes documented on :py:class:`JittedFunction`.
    """
    static_argnums = tuple() if static_argnums is None else _ensure_index_tuple(static_argnums)
    donate_argnums = tuple() if donate_argnums is None else _ensure_index_tuple(donate_argnums)
    fun = StatefulFunction(
        fun,
        static_argnums=static_argnums,
        static_argnames=static_argnames,
        abstracted_axes=abstracted_axes,
        cache_type='jit',
        name='jit'
    )
    jit_fun = jax.jit(
        fun.jaxpr_call,
        # shift by one: the first positional argument of ``jaxpr_call`` holds the state values
        static_argnums=tuple(i + 1 for i in static_argnums),
        static_argnames=static_argnames,
        donate_argnums=tuple(i + 1 for i in donate_argnums),
        donate_argnames=donate_argnames,
        keep_unused=keep_unused,
        device=device,
        backend=backend,
        inline=inline,
        in_shardings=in_shardings,
        out_shardings=out_shardings,
        abstracted_axes=abstracted_axes,
        **kwargs
    )

    @functools.wraps(fun.fun)
    def jitted_fun(*args, **params):
        # when JIT is globally disabled, run the original Python function eagerly
        if jax.config.jax_disable_jit:
            return fun.fun(*args, **params)

        # compile the function and get the state trace
        state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
        read_state_vals = state_trace.get_read_state_values(True)

        # call the jitted function; it returns the written state values and the outputs
        write_state_vals, outs = jit_fun(state_trace.get_state_values(), *args, **params)

        # write the state values back to the states
        write_back_state_values(state_trace, read_state_vals, write_state_vals)
        return outs

    def clear_cache():
        """
        Clear the cache of the stateful function and, if the jitted object
        exposes ``clear_cache``, the cache of the jitted function as well.
        """
        # clear the cache of the stateful function
        fun.clear_cache()
        try:
            # clear the cache of the jitted function
            jit_fun.clear_cache()
        except AttributeError:
            # best effort: the jitted object may not provide ``clear_cache``
            pass

    def eval_shape(*args, **kwargs):
        """Evaluate the output shape of the jitted function (not implemented yet).

        Raises:
          NotImplementedError: always. The call no longer fails with
            ``TypeError`` when arguments are passed.
        """
        raise NotImplementedError('`eval_shape` is not implemented for brainstate jitted functions yet.')

    def trace(*args, **kwargs):
        """Trace this function explicitly for the given arguments.

        A traced function is staged out of Python and translated to a jaxpr. It is
        ready for lowering but not yet lowered.

        Raises:
          NotImplementedError: always; tracing is not implemented yet.
        """
        raise NotImplementedError('`trace` is not implemented for brainstate jitted functions yet.')

    def compile(*args, **params):
        """Lower and compile this function explicitly for the given arguments.

        The function is first traced to discover the states it uses. The
        underlying :py:func:`jax.jit` function is then lowered to the compiler's
        input language and compiled. The read/write state values obtained from
        the state trace before lowering are passed to ``write_back_state_values``
        afterwards (presumably to restore the states after tracing).

        Returns:
          The result of ``jit_fun.lower(...).compile()`` (a compiled executable,
          not a ``Lowered`` object).
        """
        # compile the function and get the state trace
        state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
        read_state_vals = state_trace.get_read_state_values(replace_writen=True)
        write_state_vals = state_trace.get_write_state_values(replace_read=True)

        # lower and compile the model
        ret = jit_fun.lower(state_trace.get_state_values(), *args, **params).compile()

        # write the state values back to the states
        write_back_state_values(state_trace, read_state_vals, write_state_vals)
        return ret

    jitted_fun: JittedFunction

    # the original function
    jitted_fun.origin_fun = fun.fun

    # the stateful function for extracting states
    jitted_fun.stateful_fun = fun

    # the jitted function
    jitted_fun.jitted_fun = jit_fun

    # clear cache
    jitted_fun.clear_cache = clear_cache

    # evaluate the shape of the jitted function
    jitted_fun.eval_shape = eval_shape

    # compile the jitted function
    jitted_fun.compile = compile

    # trace the jitted function
    jitted_fun.trace = trace

    return jitted_fun
|
177
|
-
|
178
|
-
|
179
|
-
@set_module_as('brainstate.compile')
def jit(
    fun: Callable | Missing = Missing(),
    in_shardings=sharding_impls.UNSPECIFIED,
    out_shardings=sharding_impls.UNSPECIFIED,
    static_argnums: int | Sequence[int] | None = None,
    donate_argnums: int | Sequence[int] | None = None,
    static_argnames: str | Sequence[str] | None = None,
    donate_argnames: str | Iterable[str] | None = None,
    keep_unused: bool = False,
    device: Device | None = None,
    backend: str | None = None,
    inline: bool = False,
    abstracted_axes: Any | None = None,
    **kwargs
) -> Union[JittedFunction, Callable[[Callable], JittedFunction]]:
    """
    Sets up ``fun`` for just-in-time compilation with XLA.

    ``jit`` can be called directly (``jit(f, ...)``). It can also be used as a
    decorator factory (``@jit(...)``): when ``fun`` is omitted, a decorator is
    returned that applies the same options to the function it receives.

    Args:
      fun: Function to be jitted. If omitted, a decorator is returned instead.
      in_shardings: Pytree of structure matching that of arguments to ``fun``,
        with all actual arguments replaced by resource assignment specifications.
        It is also valid to specify a pytree prefix (e.g. one value in place of a
        whole subtree), in which case the leaves get broadcast to all values in
        that subtree.

        The ``in_shardings`` argument is optional. JAX will infer the shardings
        from the input :py:class:`jax.Array`'s and defaults to replicating the input
        if the sharding cannot be inferred.

        The valid resource assignment specifications are:
          - :py:class:`XLACompatibleSharding`, which will decide how the value
            will be partitioned. With this, using a mesh context manager is not
            required.
          - :py:obj:`None`, will give JAX the freedom to choose whatever sharding
            it wants.
            For in_shardings, JAX will mark it as replicated but this behavior
            can change in the future.
            For out_shardings, we will rely on the XLA GSPMD partitioner to
            determine the output shardings.

        The size of every dimension has to be a multiple of the total number of
        resources assigned to it. This is similar to pjit's in_shardings.
      out_shardings: Like ``in_shardings``, but specifies resource
        assignment for function outputs. This is similar to pjit's
        out_shardings.

        The ``out_shardings`` argument is optional. If not specified, :py:func:`jax.jit`
        will use GSPMD's sharding propagation to figure out what the sharding of the
        output(s) should be.
      static_argnums: An optional int or collection of ints that specify which
        positional arguments to treat as static (compile-time constant).
        Operations that only depend on static arguments will be constant-folded in
        Python (during tracing), and so the corresponding argument values can be
        any Python object.

        Static arguments should be hashable, meaning both ``__hash__`` and
        ``__eq__`` are implemented, and immutable. Calling the jitted function
        with different values for these constants will trigger recompilation.
        Arguments that are not arrays or containers thereof must be marked as
        static.

        If neither ``static_argnums`` nor ``static_argnames`` is provided, no
        arguments are treated as static. If ``static_argnums`` is not provided but
        ``static_argnames`` is, or vice versa, JAX uses
        :code:`inspect.signature(fun)` to find any positional arguments that
        correspond to ``static_argnames``
        (or vice versa). If both ``static_argnums`` and ``static_argnames`` are
        provided, ``inspect.signature`` is not used, and only actual
        parameters listed in either ``static_argnums`` or ``static_argnames`` will
        be treated as static.
      static_argnames: An optional string or collection of strings specifying
        which named arguments are treated as static (compile-time constant).
        Operations that only depend on static arguments will be constant-folded in
        Python (during tracing), and so the corresponding argument values can be
        any Python object.
      donate_argnums: Specify which positional argument buffers are "donated" to
        the computation. It is safe to donate argument buffers if you no longer
        need them once the computation has finished. In some cases XLA can make
        use of donated buffers to reduce the amount of memory needed to perform a
        computation, for example recycling one of your input buffers to store a
        result. You should not reuse buffers that you donate to a computation, JAX
        will raise an error if you try to. By default, no argument buffers are
        donated.

        If neither ``donate_argnums`` nor ``donate_argnames`` is provided, no
        arguments are donated. If ``donate_argnums`` is not provided but
        ``donate_argnames`` is, or vice versa, JAX uses
        :code:`inspect.signature(fun)` to find any positional arguments that
        correspond to ``donate_argnames``
        (or vice versa). If both ``donate_argnums`` and ``donate_argnames`` are
        provided, ``inspect.signature`` is not used, and only actual
        parameters listed in either ``donate_argnums`` or ``donate_argnames`` will
        be donated.

        For more details on buffer donation see the
        `FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
      donate_argnames: An optional string or collection of strings specifying
        which named arguments are donated to the computation. See the
        comment on ``donate_argnums`` for details. If not
        provided but ``donate_argnums`` is set, the default is based on calling
        ``inspect.signature(fun)`` to find corresponding named arguments.
      keep_unused: If `False` (the default), arguments that JAX determines to be
        unused by `fun` *may* be dropped from resulting compiled XLA executables.
        Such arguments will not be transferred to the device nor provided to the
        underlying executable. If `True`, unused arguments will not be pruned.
      device: This is an experimental feature and the API is likely to change.
        Optional, the Device the jitted function will run on. (Available devices
        can be retrieved via :py:func:`jax.devices`.) The default is inherited
        from XLA's DeviceAssignment logic and is usually to use
        ``jax.devices()[0]``.
      backend: This is an experimental feature and the API is likely to change.
        Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or
        ``'tpu'``.
      inline: Specify whether this function should be inlined into enclosing
        jaxprs (rather than being represented as an application of the xla_call
        primitive with its own subjaxpr). Default False.
      abstracted_axes: Experimental JAX option. It is passed unchanged to both
        :py:class:`StatefulFunction` and :py:func:`jax.jit`.
      **kwargs: Additional keyword arguments forwarded to :py:func:`jax.jit`.

    Returns:
      If ``fun`` is given, a wrapped version of ``fun`` set up for just-in-time
      compilation; otherwise, a decorator that produces such a wrapped function.
      The wrapped function is a :py:class:`JittedFunction` that can be called with
      the same arguments as ``fun`` and has the following attributes and methods:

      - ``stateful_fun``: the stateful function for extracting states, an instance
        of :py:class:`StatefulFunction`.
      - ``origin_fun(*args, **kwargs)``: the original function.
      - ``jitted_fun``: the underlying :py:func:`jax.jit`-wrapped function. It
        expects the state values as its first positional argument.
      - ``clear_cache()``: clear the cache of the jitted function.
      - ``compile(*args, **kwargs)``: lower and compile the function for the given
        arguments.
      - ``eval_shape`` / ``trace``: placeholders that currently raise
        ``NotImplementedError``.
    """

    # Collect the options once, so that the decorator form and the direct form
    # forward exactly the same arguments (by keyword) to ``_get_jitted_fun``.
    jit_options = dict(
        in_shardings=in_shardings,
        out_shardings=out_shardings,
        static_argnums=static_argnums,
        donate_argnums=donate_argnums,
        static_argnames=static_argnames,
        donate_argnames=donate_argnames,
        keep_unused=keep_unused,
        device=device,
        backend=backend,
        inline=inline,
        abstracted_axes=abstracted_axes,
        **kwargs
    )

    if isinstance(fun, Missing):
        # used as ``@jit(...)``: defer wrapping until the function is supplied
        def wrapper(fun_again: Callable) -> JittedFunction:
            return _get_jitted_fun(fun_again, **jit_options)

        return wrapper

    return _get_jitted_fun(fun, **jit_options)
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import functools
|
17
|
+
from collections.abc import Iterable, Sequence
|
18
|
+
from typing import (Any, Callable, Union)
|
19
|
+
|
20
|
+
import jax
|
21
|
+
from jax._src import sharding_impls
|
22
|
+
|
23
|
+
from brainstate._compatible_import import Device
|
24
|
+
from brainstate._utils import set_module_as
|
25
|
+
from brainstate.typing import Missing
|
26
|
+
from ._make_jaxpr import StatefulFunction, _ensure_index_tuple
|
27
|
+
from ._util import write_back_state_values
|
28
|
+
|
29
|
+
__all__ = ['jit']
|
30
|
+
|
31
|
+
|
32
|
+
class JittedFunction(Callable):
    """
    A wrapped version of ``fun``, set up for just-in-time compilation.

    In this module the class is only used as a type annotation for the object
    returned by :py:func:`jit`: the returned object is a plain Python function
    to which the attributes listed below are attached.
    """
    origin_fun: Callable  # the original (un-jitted) function
    stateful_fun: StatefulFunction  # the stateful function used for extracting states
    jitted_fun: jax.stages.Wrapped  # the underlying ``jax.jit``-wrapped function
    clear_cache: Callable  # clear the caches of the stateful and jitted functions
    eval_shape: Callable  # evaluate the output shape of the jitted function
    compile: Callable  # lower and compile the jitted function for given arguments
    trace: Callable  # trace the jitted function for given arguments

    def __call__(self, *args, **kwargs):
        """Call the jitted function (declared here for typing purposes only)."""
        pass
|
46
|
+
|
47
|
+
|
48
|
+
def _get_jitted_fun(
    fun: Callable,
    in_shardings,
    out_shardings,
    static_argnums,
    donate_argnums,
    static_argnames,
    donate_argnames,
    keep_unused,
    device,
    backend,
    inline,
    abstracted_axes,
    **kwargs
) -> JittedFunction:
    """
    Build the jitted wrapper returned by :py:func:`jit`.

    ``fun`` is wrapped in a :py:class:`StatefulFunction`, and its
    ``jaxpr_call`` is passed to :py:func:`jax.jit`. ``jaxpr_call`` is invoked
    with the state values as its first positional argument. For that reason,
    every user-facing index in ``static_argnums`` / ``donate_argnums`` is
    shifted by one before being forwarded to :py:func:`jax.jit`.

    Args:
      fun: The function to be jitted.
      in_shardings, out_shardings, donate_argnames, keep_unused, device,
      backend, inline: Forwarded unchanged to :py:func:`jax.jit`.
      static_argnums, static_argnames, abstracted_axes: Forwarded to both
        :py:class:`StatefulFunction` and :py:func:`jax.jit`.
      donate_argnums: Forwarded (shifted by one) to :py:func:`jax.jit`.
      **kwargs: Extra keyword arguments forwarded to :py:func:`jax.jit`.

    Returns:
      A function with the attributes documented on :py:class:`JittedFunction`.
    """
    static_argnums = tuple() if static_argnums is None else _ensure_index_tuple(static_argnums)
    donate_argnums = tuple() if donate_argnums is None else _ensure_index_tuple(donate_argnums)
    fun = StatefulFunction(
        fun,
        static_argnums=static_argnums,
        static_argnames=static_argnames,
        abstracted_axes=abstracted_axes,
        cache_type='jit',
        name='jit'
    )
    jit_fun = jax.jit(
        fun.jaxpr_call,
        # shift by one: the first positional argument of ``jaxpr_call`` holds the state values
        static_argnums=tuple(i + 1 for i in static_argnums),
        static_argnames=static_argnames,
        donate_argnums=tuple(i + 1 for i in donate_argnums),
        donate_argnames=donate_argnames,
        keep_unused=keep_unused,
        device=device,
        backend=backend,
        inline=inline,
        in_shardings=in_shardings,
        out_shardings=out_shardings,
        abstracted_axes=abstracted_axes,
        **kwargs
    )

    @functools.wraps(fun.fun)
    def jitted_fun(*args, **params):
        # when JIT is globally disabled, run the original Python function eagerly
        if jax.config.jax_disable_jit:
            return fun.fun(*args, **params)

        # compile the function and get the state trace
        state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
        read_state_vals = state_trace.get_read_state_values(True)

        # call the jitted function; it returns the written state values and the outputs
        write_state_vals, outs = jit_fun(state_trace.get_state_values(), *args, **params)

        # write the state values back to the states
        write_back_state_values(state_trace, read_state_vals, write_state_vals)
        return outs

    def clear_cache():
        """
        Clear the cache of the stateful function and, if the jitted object
        exposes ``clear_cache``, the cache of the jitted function as well.
        """
        # clear the cache of the stateful function
        fun.clear_cache()
        try:
            # clear the cache of the jitted function
            jit_fun.clear_cache()
        except AttributeError:
            # best effort: the jitted object may not provide ``clear_cache``
            pass

    def eval_shape(*args, **kwargs):
        """Evaluate the output shape of the jitted function (not implemented yet).

        Raises:
          NotImplementedError: always. The call no longer fails with
            ``TypeError`` when arguments are passed.
        """
        raise NotImplementedError('`eval_shape` is not implemented for brainstate jitted functions yet.')

    def trace(*args, **kwargs):
        """Trace this function explicitly for the given arguments.

        A traced function is staged out of Python and translated to a jaxpr. It is
        ready for lowering but not yet lowered.

        Raises:
          NotImplementedError: always; tracing is not implemented yet.
        """
        raise NotImplementedError('`trace` is not implemented for brainstate jitted functions yet.')

    def compile(*args, **params):
        """Lower and compile this function explicitly for the given arguments.

        The function is first traced to discover the states it uses. The
        underlying :py:func:`jax.jit` function is then lowered to the compiler's
        input language and compiled. The read/write state values obtained from
        the state trace before lowering are passed to ``write_back_state_values``
        afterwards (presumably to restore the states after tracing).

        Returns:
          The result of ``jit_fun.lower(...).compile()`` (a compiled executable,
          not a ``Lowered`` object).
        """
        # compile the function and get the state trace
        state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
        read_state_vals = state_trace.get_read_state_values(replace_writen=True)
        write_state_vals = state_trace.get_write_state_values(replace_read=True)

        # lower and compile the model
        ret = jit_fun.lower(state_trace.get_state_values(), *args, **params).compile()

        # write the state values back to the states
        write_back_state_values(state_trace, read_state_vals, write_state_vals)
        return ret

    jitted_fun: JittedFunction

    # the original function
    jitted_fun.origin_fun = fun.fun

    # the stateful function for extracting states
    jitted_fun.stateful_fun = fun

    # the jitted function
    jitted_fun.jitted_fun = jit_fun

    # clear cache
    jitted_fun.clear_cache = clear_cache

    # evaluate the shape of the jitted function
    jitted_fun.eval_shape = eval_shape

    # compile the jitted function
    jitted_fun.compile = compile

    # trace the jitted function
    jitted_fun.trace = trace

    return jitted_fun
|
177
|
+
|
178
|
+
|
179
|
+
@set_module_as('brainstate.compile')
|
180
|
+
def jit(
|
181
|
+
fun: Callable | Missing = Missing(),
|
182
|
+
in_shardings=sharding_impls.UNSPECIFIED,
|
183
|
+
out_shardings=sharding_impls.UNSPECIFIED,
|
184
|
+
static_argnums: int | Sequence[int] | None = None,
|
185
|
+
donate_argnums: int | Sequence[int] | None = None,
|
186
|
+
static_argnames: str | Sequence[str] | None = None,
|
187
|
+
donate_argnames: str | Iterable[str] | None = None,
|
188
|
+
keep_unused: bool = False,
|
189
|
+
device: Device | None = None,
|
190
|
+
backend: str | None = None,
|
191
|
+
inline: bool = False,
|
192
|
+
abstracted_axes: Any | None = None,
|
193
|
+
**kwargs
|
194
|
+
) -> Union[JittedFunction, Callable[[Callable], JittedFunction]]:
|
195
|
+
"""
|
196
|
+
Sets up ``fun`` for just-in-time compilation with XLA.
|
197
|
+
|
198
|
+
Args:
|
199
|
+
fun: Function to be jitted.
|
200
|
+
in_shardings: Pytree of structure matching that of arguments to ``fun``,
|
201
|
+
with all actual arguments replaced by resource assignment specifications.
|
202
|
+
It is also valid to specify a pytree prefix (e.g. one value in place of a
|
203
|
+
whole subtree), in which case the leaves get broadcast to all values in
|
204
|
+
that subtree.
|
205
|
+
|
206
|
+
The ``in_shardings`` argument is optional. JAX will infer the shardings
|
207
|
+
from the input :py:class:`jax.Array`'s and defaults to replicating the input
|
208
|
+
if the sharding cannot be inferred.
|
209
|
+
|
210
|
+
The valid resource assignment specifications are:
|
211
|
+
- :py:class:`XLACompatibleSharding`, which will decide how the value
|
212
|
+
will be partitioned. With this, using a mesh context manager is not
|
213
|
+
required.
|
214
|
+
- :py:obj:`None`, will give JAX the freedom to choose whatever sharding
|
215
|
+
it wants.
|
216
|
+
For in_shardings, JAX will mark is as replicated but this behavior
|
217
|
+
can change in the future.
|
218
|
+
For out_shardings, we will rely on the XLA GSPMD partitioner to
|
219
|
+
determine the output shardings.
|
220
|
+
|
221
|
+
The size of every dimension has to be a multiple of the total number of
|
222
|
+
resources assigned to it. This is similar to pjit's in_shardings.
|
223
|
+
out_shardings: Like ``in_shardings``, but specifies resource
|
224
|
+
assignment for function outputs. This is similar to pjit's
|
225
|
+
out_shardings.
|
226
|
+
|
227
|
+
The ``out_shardings`` argument is optional. If not specified, :py:func:`jax.jit`
|
228
|
+
will use GSPMD's sharding propagation to figure out what the sharding of the
|
229
|
+
output(s) should be.
|
230
|
+
static_argnums: An optional int or collection of ints that specify which
|
231
|
+
positional arguments to treat as static (compile-time constant).
|
232
|
+
Operations that only depend on static arguments will be constant-folded in
|
233
|
+
Python (during tracing), and so the corresponding argument values can be
|
234
|
+
any Python object.
|
235
|
+
|
236
|
+
Static arguments should be hashable, meaning both ``__hash__`` and
|
237
|
+
``__eq__`` are implemented, and immutable. Calling the jitted function
|
238
|
+
with different values for these constants will trigger recompilation.
|
239
|
+
Arguments that are not arrays or containers thereof must be marked as
|
240
|
+
static.
|
241
|
+
|
242
|
+
If neither ``static_argnums`` nor ``static_argnames`` is provided, no
|
243
|
+
arguments are treated as static. If ``static_argnums`` is not provided but
|
244
|
+
``static_argnames`` is, or vice versa, JAX uses
|
245
|
+
:code:`inspect.signature(fun)` to find any positional arguments that
|
246
|
+
correspond to ``static_argnames``
|
247
|
+
(or vice versa). If both ``static_argnums`` and ``static_argnames`` are
|
248
|
+
provided, ``inspect.signature`` is not used, and only actual
|
249
|
+
parameters listed in either ``static_argnums`` or ``static_argnames`` will
|
250
|
+
be treated as static.
|
251
|
+
static_argnames: An optional string or collection of strings specifying
|
252
|
+
which named arguments are treated as static (compile-time constant).
|
253
|
+
Operations that only depend on static arguments will be constant-folded in
|
254
|
+
Python (during tracing), and so the corresponding argument values can be
|
255
|
+
any Python object.
|
256
|
+
donate_argnums: Specify which positional argument buffers are "donated" to
|
257
|
+
the computation. It is safe to donate argument buffers if you no longer
|
258
|
+
need them once the computation has finished. In some cases XLA can make
|
259
|
+
use of donated buffers to reduce the amount of memory needed to perform a
|
260
|
+
computation, for example recycling one of your input buffers to store a
|
261
|
+
result. You should not reuse buffers that you donate to a computation, JAX
|
262
|
+
will raise an error if you try to. By default, no argument buffers are
|
263
|
+
donated.
|
264
|
+
|
265
|
+
If neither ``donate_argnums`` nor ``donate_argnames`` is provided, no
|
266
|
+
arguments are donated. If ``donate_argnums`` is not provided but
|
267
|
+
``donate_argnames`` is, or vice versa, JAX uses
|
268
|
+
:code:`inspect.signature(fun)` to find any positional arguments that
|
269
|
+
correspond to ``donate_argnames``
|
270
|
+
(or vice versa). If both ``donate_argnums`` and ``donate_argnames`` are
|
271
|
+
provided, ``inspect.signature`` is not used, and only actual
|
272
|
+
parameters listed in either ``donate_argnums`` or ``donate_argnames`` will
|
273
|
+
be donated.
|
274
|
+
|
275
|
+
For more details on buffer donation see the
|
276
|
+
`FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
|
277
|
+
donate_argnames: An optional string or collection of strings specifying
|
278
|
+
which named arguments are donated to the computation. See the
|
279
|
+
comment on ``donate_argnums`` for details. If not
|
280
|
+
provided but ``donate_argnums`` is set, the default is based on calling
|
281
|
+
``inspect.signature(fun)`` to find corresponding named arguments.
|
282
|
+
keep_unused: If `False` (the default), arguments that JAX determines to be
|
283
|
+
unused by `fun` *may* be dropped from resulting compiled XLA executables.
|
284
|
+
Such arguments will not be transferred to the device nor provided to the
|
285
|
+
underlying executable. If `True`, unused arguments will not be pruned.
|
286
|
+
device: This is an experimental feature and the API is likely to change.
|
287
|
+
Optional, the Device the jitted function will run on. (Available devices
|
288
|
+
can be retrieved via :py:func:`jax.devices`.) The default is inherited
|
289
|
+
from XLA's DeviceAssignment logic and is usually to use
|
290
|
+
``jax.devices()[0]``.
|
291
|
+
backend: This is an experimental feature and the API is likely to change.
|
292
|
+
Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or
|
293
|
+
``'tpu'``.
|
294
|
+
inline: Specify whether this function should be inlined into enclosing
|
295
|
+
jaxprs (rather than being represented as an application of the xla_call
|
296
|
+
primitive with its own subjaxpr). Default False.
|
297
|
+
abstracted_axes:
|
298
|
+
|
299
|
+
Returns:
|
300
|
+
A wrapped version of ``fun``, set up for just-in-time compilation.
|
301
|
+
The returned object is a :py:class:`JittedFunction` that can be called with the same arguments
|
302
|
+
and has the following attributes and methods:
|
303
|
+
|
304
|
+
- ``stateful_fun``: the stateful function for extracting states, an instance of :py:class:`StatefulFunction`.
|
305
|
+
- ``origin_fun(*args, **kwargs)``: the original function
|
306
|
+
- ``jitted_fun(*args, **kwargs)``: the jitted function
|
307
|
+
- ``clear_cache(*args, **kwargs)``: clear the cache of the jitted function
|
308
|
+
|
309
|
+
"""
|
310
|
+
|
311
|
+
if isinstance(fun, Missing):
|
312
|
+
def wrapper(fun_again: Callable) -> JittedFunction:
|
313
|
+
return _get_jitted_fun(
|
314
|
+
fun_again,
|
315
|
+
in_shardings=in_shardings,
|
316
|
+
out_shardings=out_shardings,
|
317
|
+
static_argnums=static_argnums,
|
318
|
+
donate_argnums=donate_argnums,
|
319
|
+
static_argnames=static_argnames,
|
320
|
+
donate_argnames=donate_argnames,
|
321
|
+
keep_unused=keep_unused,
|
322
|
+
device=device,
|
323
|
+
backend=backend,
|
324
|
+
inline=inline,
|
325
|
+
abstracted_axes=abstracted_axes,
|
326
|
+
**kwargs
|
327
|
+
)
|
328
|
+
|
329
|
+
return wrapper
|
330
|
+
|
331
|
+
else:
|
332
|
+
return _get_jitted_fun(
|
333
|
+
fun,
|
334
|
+
in_shardings,
|
335
|
+
out_shardings,
|
336
|
+
static_argnums,
|
337
|
+
donate_argnums,
|
338
|
+
static_argnames,
|
339
|
+
donate_argnames,
|
340
|
+
keep_unused,
|
341
|
+
device,
|
342
|
+
backend,
|
343
|
+
inline,
|
344
|
+
abstracted_axes,
|
345
|
+
**kwargs
|
346
|
+
)
|