brainstate 0.0.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +45 -0
- brainstate/_module.py +1466 -0
- brainstate/_module_test.py +133 -0
- brainstate/_state.py +378 -0
- brainstate/_state_test.py +41 -0
- brainstate/_utils.py +21 -0
- brainstate/environ.py +375 -0
- brainstate/functional/__init__.py +25 -0
- brainstate/functional/_activations.py +754 -0
- brainstate/functional/_normalization.py +69 -0
- brainstate/functional/_spikes.py +90 -0
- brainstate/init/__init__.py +26 -0
- brainstate/init/_base.py +36 -0
- brainstate/init/_generic.py +175 -0
- brainstate/init/_random_inits.py +489 -0
- brainstate/init/_regular_inits.py +109 -0
- brainstate/math/__init__.py +21 -0
- brainstate/math/_einops.py +787 -0
- brainstate/math/_einops_parsing.py +169 -0
- brainstate/math/_einops_parsing_test.py +126 -0
- brainstate/math/_einops_test.py +346 -0
- brainstate/math/_misc.py +298 -0
- brainstate/math/_misc_test.py +58 -0
- brainstate/mixin.py +373 -0
- brainstate/mixin_test.py +73 -0
- brainstate/nn/__init__.py +68 -0
- brainstate/nn/_base.py +248 -0
- brainstate/nn/_connections.py +686 -0
- brainstate/nn/_dynamics.py +406 -0
- brainstate/nn/_elementwise.py +1437 -0
- brainstate/nn/_misc.py +132 -0
- brainstate/nn/_normalizations.py +389 -0
- brainstate/nn/_others.py +100 -0
- brainstate/nn/_poolings.py +1228 -0
- brainstate/nn/_poolings_test.py +231 -0
- brainstate/nn/_projection/__init__.py +32 -0
- brainstate/nn/_projection/_align_post.py +528 -0
- brainstate/nn/_projection/_align_pre.py +599 -0
- brainstate/nn/_projection/_delta.py +241 -0
- brainstate/nn/_projection/_utils.py +17 -0
- brainstate/nn/_projection/_vanilla.py +101 -0
- brainstate/nn/_rate_rnns.py +393 -0
- brainstate/nn/_readout.py +130 -0
- brainstate/nn/_synouts.py +166 -0
- brainstate/nn/functional/__init__.py +25 -0
- brainstate/nn/functional/_activations.py +754 -0
- brainstate/nn/functional/_normalization.py +69 -0
- brainstate/nn/functional/_spikes.py +90 -0
- brainstate/nn/init/__init__.py +26 -0
- brainstate/nn/init/_base.py +36 -0
- brainstate/nn/init/_generic.py +175 -0
- brainstate/nn/init/_random_inits.py +489 -0
- brainstate/nn/init/_regular_inits.py +109 -0
- brainstate/nn/surrogate.py +1740 -0
- brainstate/optim/__init__.py +23 -0
- brainstate/optim/_lr_scheduler.py +486 -0
- brainstate/optim/_lr_scheduler_test.py +36 -0
- brainstate/optim/_sgd_optimizer.py +1148 -0
- brainstate/random.py +5148 -0
- brainstate/random_test.py +576 -0
- brainstate/surrogate.py +1740 -0
- brainstate/transform/__init__.py +36 -0
- brainstate/transform/_autograd.py +585 -0
- brainstate/transform/_autograd_test.py +1183 -0
- brainstate/transform/_control.py +665 -0
- brainstate/transform/_controls_test.py +220 -0
- brainstate/transform/_jit.py +239 -0
- brainstate/transform/_jit_error.py +158 -0
- brainstate/transform/_jit_test.py +102 -0
- brainstate/transform/_make_jaxpr.py +573 -0
- brainstate/transform/_make_jaxpr_test.py +133 -0
- brainstate/transform/_progress_bar.py +113 -0
- brainstate/typing.py +69 -0
- brainstate/util.py +747 -0
- brainstate-0.0.1.dist-info/LICENSE +202 -0
- brainstate-0.0.1.dist-info/METADATA +101 -0
- brainstate-0.0.1.dist-info/RECORD +79 -0
- brainstate-0.0.1.dist-info/WHEEL +6 -0
- brainstate-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,220 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import unittest
|
17
|
+
|
18
|
+
import jax
|
19
|
+
import jax.numpy as jnp
|
20
|
+
|
21
|
+
import brainstate as bc
|
22
|
+
|
23
|
+
|
24
|
+
class TestCond(unittest.TestCase):
    """Tests for ``bc.transform.cond`` with and without state updates."""

    def test1(self):
        # Both branches draw from the global random state; ``cond`` must trace
        # both of them and return the output of the selected branch.
        bc.random.seed(1)
        r = bc.transform.cond(True, lambda: bc.random.random(10), lambda: bc.random.random(10))
        self.assertEqual(r.shape, (10,))
        r = bc.transform.cond(False, lambda: bc.random.random(10), lambda: bc.random.random(10))
        self.assertEqual(r.shape, (10,))

    def test2(self):
        st1 = bc.State(bc.random.rand(10))
        st2 = bc.State(bc.random.rand(2))
        st3 = bc.State(bc.random.rand(5))
        st4 = bc.State(bc.random.rand(2, 10))

        def true_fun(x):
            st1.value = st2.value @ st4.value + x

        def false_fun(x):
            st3.value = (st3.value + 1.) * x

        # Reference values computed eagerly, before the transformed call.
        expected_st1 = st2.value @ st4.value + 2.
        old_st3 = st3.value

        bc.transform.cond(True, true_fun, false_fun, 2.)

        # Values written back by ``cond`` must be concrete arrays, not tracers.
        for st in (st1, st2, st3, st4):
            self.assertNotIsInstance(st.value, jax.core.Tracer)

        # Only the selected (true) branch's state update takes effect.
        self.assertTrue(jnp.allclose(st1.value, expected_st1))
        self.assertTrue(jnp.allclose(st3.value, old_st3))
|
47
|
+
|
48
|
+
|
49
|
+
class TestSwitch(unittest.TestCase):
    """Tests for ``bc.transform.switch``, mirroring ``jax.lax.switch`` tests."""

    def testSwitch(self):
        def doubling_branch(v):
            first = jax.lax.mul(2, v)
            return first, jax.lax.mul(2, first)

        branches = [lambda v: (v, v), doubling_branch, lambda v: (v, -v)]

        def reference(v):
            # Plain-Python dispatch with the same index clamping as ``switch``.
            if v <= 0:
                idx = 0
            elif v == 1:
                idx = 1
            else:
                idx = 2
            return branches[idx](v)

        def via_switch(v):
            return bc.transform.switch(v, branches, v)

        for candidate in (via_switch, jax.jit(via_switch)):
            for v in (-1, 0, 1, 2, 3):
                self.assertEqual(reference(v), candidate(v))

    def testSwitchMultiOperands(self):
        branches = [jax.lax.add, jax.lax.mul]

        def reference(v):
            return branches[1 if v > 0 else 0](v, v)

        def via_switch(v):
            return bc.transform.switch(v, branches, v, v)

        for candidate in (via_switch, jax.jit(via_switch)):
            for v in (-1, 0, 1, 2):
                self.assertEqual(reference(v), candidate(v))

    def testSwitchResidualsMerge(self):
        def cond_eqns(f):
            # The forward and backward ``cond`` equations of grad(f)'s jaxpr.
            closed = jax.make_jaxpr(jax.grad(f))(0., 0)
            return [e for e in closed.jaxpr.eqns if e.primitive.name == 'cond']

        def uniform_len(eqn, attr):
            # All branches of one ``cond`` must agree on their in/out var count.
            sizes = {len(getattr(b.jaxpr, attr)) for b in eqn.params['branches']}
            assert len(sizes) == 1
            return sizes.pop()

        def make_fun(bs):
            def fun(x, i):
                return bc.transform.switch(i + 1, bs, x)

            return fun

        branches1 = [lambda x: jnp.sin(x),
                     lambda x: jnp.cos(x)]  # branch residuals overlap, should be reused
        branches2 = branches1 + [lambda x: jnp.sinh(x)]  # another overlapping residual, expect reuse
        branches3 = branches2 + [lambda x: jnp.sin(x) + jnp.cos(x)]  # requires one more residual slot

        fwd1, bwd1 = cond_eqns(make_fun(branches1))
        fwd2, bwd2 = cond_eqns(make_fun(branches2))
        fwd3, bwd3 = cond_eqns(make_fun(branches3))

        n_out1, n_out2, n_out3 = (uniform_len(e, 'outvars') for e in (fwd1, fwd2, fwd3))
        assert n_out1 == n_out2
        assert n_out3 == n_out2 + 1

        n_in1, n_in2, n_in3 = (uniform_len(e, 'invars') for e in (bwd1, bwd2, bwd3))
        assert n_in1 == n_in2
        assert n_in3 == n_in2 + 1

    def testOneBranchSwitch(self):
        def negate(v):
            return -v

        def f(i, v):
            return bc.transform.switch(i, [negate], v)

        x = 7.
        for g in (f, jax.jit(f), jax.jit(f, static_argnums=0)):
            for i in (-1, 0, 1):
                self.assertEqual(g(i, x), negate(x))
|
164
|
+
|
165
|
+
|
166
|
+
class TestIfElse(unittest.TestCase):
    """Tests for ``bc.transform.ifelse``."""

    def test1(self):
        def f(a):
            return bc.transform.ifelse(conditions=[a < 0,
                                                   0 <= a < 2,
                                                   2 <= a < 5,
                                                   5 <= a < 10,
                                                   a >= 10],
                                       branches=[lambda: 1,
                                                 lambda: 2,
                                                 lambda: 3,
                                                 lambda: 4,
                                                 lambda: 5])

        self.assertTrue(f(3) == 3)
        self.assertTrue(f(1) == 2)
        self.assertTrue(f(-1) == 1)
        # Cover the remaining branches and the interval boundaries.
        self.assertTrue(f(0) == 2)
        self.assertTrue(f(7) == 4)
        self.assertTrue(f(10) == 5)
        self.assertTrue(f(12) == 5)

    def test_vmap(self):
        def classify(a):
            return bc.transform.ifelse([a > 10,
                                        jnp.logical_and(a <= 10, a > 5),
                                        jnp.logical_and(a <= 5, a > 2),
                                        jnp.logical_and(a <= 2, a > 0),
                                        a <= 0],
                                       [lambda _: 1,
                                        lambda _: 2,
                                        lambda _: 3,
                                        lambda _: 4,
                                        lambda _: 5, ],
                                       a)

        inputs = bc.random.randint(-20, 20, 200)
        r = jax.vmap(classify)(inputs)
        self.assertTrue(r.size == 200)

        # Reference: a vectorized select over the same mutually exclusive conditions.
        expected = jnp.select([inputs > 10,
                               jnp.logical_and(inputs <= 10, inputs > 5),
                               jnp.logical_and(inputs <= 5, inputs > 2),
                               jnp.logical_and(inputs <= 2, inputs > 0),
                               inputs <= 0],
                              [jnp.full_like(inputs, k) for k in range(1, 6)])
        self.assertTrue(jnp.array_equal(r, expected))

    def test_grad1(self):
        def F2(x):
            return bc.transform.ifelse((x >= 10, x < 10),
                                       [lambda x: x, lambda x: x ** 2, ],
                                       x)

        self.assertTrue(jax.grad(F2)(9.0) == 18.)
        self.assertTrue(jax.grad(F2)(11.0) == 1.)

    def test_grad2(self):
        def F3(x):
            return bc.transform.ifelse((x >= 10, jnp.logical_and(x >= 0, x < 10), x < 0),
                                       [lambda x: x,
                                        lambda x: x ** 2,
                                        lambda x: x ** 4, ],
                                       x)

        self.assertTrue(jax.grad(F3)(9.0) == 18.)
        self.assertTrue(jax.grad(F3)(11.0) == 1.)
        # Third branch: d/dx x**4 = 4 x**3, which is -32 at x = -2.
        self.assertTrue(jax.grad(F3)(-2.0) == -32.)
|
@@ -0,0 +1,239 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from __future__ import annotations
|
17
|
+
|
18
|
+
import functools
|
19
|
+
from collections.abc import Iterable, Sequence
|
20
|
+
from typing import (Any, Callable, Union)
|
21
|
+
|
22
|
+
import jax
|
23
|
+
from jax._src import sharding_impls
|
24
|
+
from jax.lib import xla_client as xc
|
25
|
+
|
26
|
+
from ._make_jaxpr import StatefulFunction, _ensure_index_tuple, _assign_state_values
|
27
|
+
from brainstate._utils import set_module_as
|
28
|
+
|
29
|
+
# Public API of this module: only the ``jit`` transform is exported.
__all__ = ['jit']
|
30
|
+
|
31
|
+
|
32
|
+
class JittedFunction(Callable):
    """
    A wrapped version of ``fun``, set up for just-in-time compilation.

    Within this module the class is never instantiated. It serves as the type
    of the plain function returned by ``jit``, which gets the attributes
    declared below attached to it at runtime.
    """
    origin_fun: Callable  # the user's original, un-jitted function
    stateful_fun: StatefulFunction  # wrapper used to discover the states ``fun`` touches
    jitted_fun: jax.stages.Wrapped  # ``jax.jit`` of the state-aware call; takes state values first
    clear_cache: Callable  # clears both the stateful-function cache and the jax.jit cache
|
40
|
+
|
41
|
+
|
42
|
+
def _get_jitted_fun(
    fun: Callable,
    in_shardings,
    out_shardings,
    static_argnums,
    donate_argnums,
    donate_argnames,
    keep_unused,
    device,
    backend,
    inline,
    abstracted_axes,
    **kwargs
) -> JittedFunction:
    """Build a :class:`JittedFunction` around ``fun``.

    ``fun`` is wrapped in a :class:`StatefulFunction` so that the states it reads
    and writes can be collected. ``jax.jit`` is then applied to
    ``StatefulFunction.jaxpr_call``, whose first positional argument is the list
    of state values, followed by the user's own arguments.

    Args:
      fun: The user function to compile.
      static_argnums: Indices of static arguments, relative to ``fun``'s
        positional arguments.
      donate_argnums: Indices of donated arguments, relative to ``fun``'s
        positional arguments.
      in_shardings, out_shardings, donate_argnames, keep_unused, device,
      backend, inline, abstracted_axes, **kwargs: Forwarded to ``jax.jit``.
        NOTE(review): ``in_shardings`` is forwarded unchanged, so a per-argument
        tuple lines up with ``jaxpr_call``'s arguments (state values first), not
        with ``fun``'s. Confirm whether it should be shifted too.

    Returns:
      A callable with ``origin_fun``, ``stateful_fun``, ``jitted_fun`` and
      ``clear_cache`` attributes attached.
    """
    static_argnums = _ensure_index_tuple(tuple() if static_argnums is None else static_argnums)
    # ``jaxpr_call`` receives the state values as its first positional argument,
    # so every user-facing positional index has to be shifted by one.
    jax_static_argnums = tuple(i + 1 for i in static_argnums)
    if donate_argnums is None:
        jax_donate_argnums = None
    else:
        jax_donate_argnums = tuple(i + 1 for i in _ensure_index_tuple(donate_argnums))

    # TODO: add to cache stack for clear_cache
    stateful_fun = StatefulFunction(fun, static_argnums=static_argnums, abstracted_axes=abstracted_axes,
                                    cache_type='jit')
    jit_fun = jax.jit(stateful_fun.jaxpr_call,
                      static_argnums=jax_static_argnums,
                      donate_argnums=jax_donate_argnums,
                      donate_argnames=donate_argnames,
                      keep_unused=keep_unused,
                      device=device,
                      backend=backend,
                      inline=inline,
                      in_shardings=in_shardings,
                      out_shardings=out_shardings,
                      abstracted_axes=abstracted_axes,
                      **kwargs)

    @functools.wraps(stateful_fun.fun)
    def jitted_fun(*args, **params):
        if jax.config.jax_disable_jit:
            return stateful_fun.fun(*args, **params)
        # The call-time keyword arguments (``params``) select the traced program.
        # The ``jax.jit`` options in ``kwargs`` must not be passed here.
        states = stateful_fun.compile_and_get_states_by_static_args(*args, **params)
        state_vals, outs = jit_fun([st.value for st in states], *args, **params)
        # Write the updated state values produced by the compiled call back.
        _assign_state_values(states, state_vals)
        return outs

    def clear_cache():
        # clear the cache of the stateful function
        stateful_fun.clear_cache()
        # clear the cache of the jitted function
        jit_fun.clear_cache()

    jitted_fun: JittedFunction
    # the original function
    jitted_fun.origin_fun = stateful_fun.fun
    # the stateful function for extracting states
    jitted_fun.stateful_fun = stateful_fun
    # the jitted function
    jitted_fun.jitted_fun = jit_fun
    # clear cache
    jitted_fun.clear_cache = clear_cache

    return jitted_fun
|
98
|
+
|
99
|
+
|
100
|
+
@set_module_as('brainstate.transform')
def jit(
    fun: Callable = None,
    in_shardings=sharding_impls.UNSPECIFIED,
    out_shardings=sharding_impls.UNSPECIFIED,
    static_argnums: int | Sequence[int] | None = None,
    donate_argnums: int | Sequence[int] | None = None,
    donate_argnames: str | Iterable[str] | None = None,
    keep_unused: bool = False,
    device: xc.Device | None = None,
    backend: str | None = None,
    inline: bool = False,
    abstracted_axes: Any | None = None,
    **kwargs
) -> Union[JittedFunction, Callable[[Callable], JittedFunction]]:
    """
    Sets up ``fun`` for just-in-time compilation with XLA.

    States read or written by ``fun`` are collected automatically and passed to
    the compiled computation. Their new values are written back after each call.

    Unlike :py:func:`jax.jit`, ``static_argnames`` is not supported: static
    arguments can only be selected by position with ``static_argnums``.

    ``jit`` can be called directly as ``jit(fun, ...)``. With ``fun=None`` it acts
    as a decorator factory, ``@jit(...)``.

    Args:
      fun: Function to be jitted. If ``None``, a decorator is returned that applies
        ``jit`` with the remaining arguments.
      in_shardings: Pytree of resource assignment specifications, forwarded to
        :py:func:`jax.jit`. A pytree prefix (e.g. one value in place of a whole
        subtree) broadcasts its leaves to every value in that subtree.

        The ``in_shardings`` argument is optional. JAX will infer the shardings
        from the input :py:class:`jax.Array`'s and defaults to replicating the input
        if the sharding cannot be inferred.

        The valid resource assignment specifications are:
          - :py:class:`XLACompatibleSharding`, which decides how the value
            will be partitioned. With this, using a mesh context manager is not
            required.
          - :py:obj:`None`, which gives JAX the freedom to choose whatever sharding
            it wants.
            For in_shardings, JAX will mark it as replicated, but this behavior
            can change in the future.
            For out_shardings, the XLA GSPMD partitioner determines the
            output shardings.

        The size of every dimension has to be a multiple of the total number of
        resources assigned to it.

        NOTE: the compiled callable receives the list of state values as its first
        positional argument, before ``fun``'s own arguments, and this value is
        forwarded to it unchanged.
      out_shardings: Like ``in_shardings``, but specifies resource assignment for
        function outputs. If not specified, :py:func:`jax.jit` uses GSPMD's
        sharding propagation to determine the output sharding.
      static_argnums: An optional int or collection of ints that specify which
        positional arguments of ``fun`` to treat as static (compile-time
        constant). Operations that only depend on static arguments are
        constant-folded in Python (during tracing), so the corresponding
        argument values can be any Python object.

        Static arguments should be hashable, meaning both ``__hash__`` and
        ``__eq__`` are implemented, and immutable. Calling the jitted function
        with different values for these constants triggers recompilation.
        Arguments that are not arrays or containers thereof must be marked as
        static. If not provided, no arguments are treated as static.
      donate_argnums: Specify which positional argument buffers of ``fun`` are
        "donated" to the computation. It is safe to donate argument buffers if
        you no longer need them once the computation has finished. In some
        cases XLA can use donated buffers to reduce the memory needed for a
        computation, for example by recycling an input buffer to store a result.
        You should not reuse buffers that you donate to a computation; JAX
        raises an error if you try to. By default, no argument buffers are
        donated.

        For more details on buffer donation see the
        `FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
      donate_argnames: An optional string or collection of strings specifying
        which named arguments are donated to the computation. Forwarded to
        :py:func:`jax.jit`.
      keep_unused: If `False` (the default), arguments that JAX determines to be
        unused by `fun` *may* be dropped from the resulting compiled XLA
        executables. Such arguments are neither transferred to the device nor
        provided to the underlying executable. If `True`, unused arguments are
        not pruned.
      device: This is an experimental feature and the API is likely to change.
        Optional, the Device the jitted function will run on. (Available devices
        can be retrieved via :py:func:`jax.devices`.) The default is inherited
        from XLA's DeviceAssignment logic and is usually ``jax.devices()[0]``.
      backend: This is an experimental feature and the API is likely to change.
        Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or
        ``'tpu'``.
      inline: Specify whether this function should be inlined into enclosing
        jaxprs. Default False.
      abstracted_axes: Experimental dynamic-shape specification. It is passed both
        to the internal :py:class:`StatefulFunction` (used to discover states) and
        to :py:func:`jax.jit`.
      **kwargs: Any other keyword arguments, forwarded unchanged to
        :py:func:`jax.jit`.

    Returns:
      A wrapped version of ``fun``, set up for just-in-time compilation. If
      ``fun`` is ``None``, a decorator producing such a wrapper is returned
      instead. The wrapper is a :py:class:`JittedFunction` that can be called with
      the same arguments as ``fun`` and has the following attributes and methods:

      - ``stateful_fun``: the stateful function for extracting states, an instance of :py:class:`StatefulFunction`.
      - ``origin_fun(*args, **kwargs)``: the original function.
      - ``jitted_fun``: the underlying ``jax.jit``-compiled callable, which takes
        the list of state values as its first argument.
      - ``clear_cache()``: clear the caches of the stateful and the jitted function.
    """

    def wrapper(fun_again: Callable) -> JittedFunction:
        return _get_jitted_fun(fun_again, in_shardings, out_shardings, static_argnums,
                               donate_argnums, donate_argnames, keep_unused,
                               device, backend, inline, abstracted_axes, **kwargs)

    # ``jit(...)`` without a function is used as a decorator factory.
    if fun is None:
        return wrapper
    return wrapper(fun)
|
@@ -0,0 +1,158 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from __future__ import annotations
|
17
|
+
|
18
|
+
from functools import wraps, partial
|
19
|
+
from typing import Callable, Union
|
20
|
+
|
21
|
+
import jax
|
22
|
+
from jax import numpy as jnp
|
23
|
+
from jax.core import Primitive, ShapedArray
|
24
|
+
from jax.interpreters import batching, mlir, xla
|
25
|
+
from jax.lax import cond
|
26
|
+
|
27
|
+
from brainstate._utils import set_module_as
|
28
|
+
|
29
|
+
# Public API of this module. ``remove_vmap`` is also decorated for the
# ``brainstate.transform`` namespace, but it is not listed here.
__all__ = [
    'jit_error',
]
|
32
|
+
|
33
|
+
|
34
|
+
@set_module_as('brainstate.transform')
def remove_vmap(x, op='any'):
    """Reduce ``x`` to one boolean that is not batched by ``jax.vmap``.

    Args:
      x: Array to reduce.
      op: ``'any'`` reduces with ``jnp.any``; ``'all'`` reduces with ``jnp.all``.

    Returns:
      A boolean scalar. Under ``jax.vmap`` the reduction also spans the mapped
      axis, and the result is reported as not mapped.

    Raises:
      ValueError: If ``op`` is neither ``'any'`` nor ``'all'``.
    """
    if op not in ('any', 'all'):
        raise ValueError(f'Do not support type: {op}')
    reducer = _any_without_vmap if op == 'any' else _all_without_vmap
    return reducer(x)
|
42
|
+
|
43
|
+
|
44
|
+
# Primitive computing ``jnp.any`` whose batching rule reduces over the batched
# array as a whole (mapped axis included) and returns an unbatched scalar.
_any_no_vmap_prim = Primitive('any_no_vmap')


def _any_without_vmap(x):
    """Bind ``_any_no_vmap_prim`` to ``x``."""
    return _any_no_vmap_prim.bind(x)


def _any_without_vmap_imp(x):
    """Concrete implementation: ``jnp.any`` over all elements of ``x``."""
    return jnp.any(x)


def _any_without_vmap_abs(x):
    """Abstract evaluation: the output is always a boolean scalar."""
    return ShapedArray(shape=(), dtype=jnp.bool_)


def _any_without_vmap_batch(x, batch_axes):
    """Batching rule.

    ``x`` is the tuple of batched operands (a single array here). The whole
    batched array is reduced, and the result is marked ``batching.not_mapped``,
    so it leaves ``vmap`` as one unbatched value. ``batch_axes`` is ignored.
    """
    (x,) = x
    return _any_without_vmap(x), batching.not_mapped


_any_no_vmap_prim.def_impl(_any_without_vmap_imp)
_any_no_vmap_prim.def_abstract_eval(_any_without_vmap_abs)
batching.primitive_batchers[_any_no_vmap_prim] = _any_without_vmap_batch
# Register the legacy XLA translation only when this JAX version still provides
# ``xla.lower_fun``; the MLIR lowering below is always registered.
if hasattr(xla, "lower_fun"):
    xla.register_translation(_any_no_vmap_prim,
                             xla.lower_fun(_any_without_vmap_imp, multiple_results=False, new_style=True))
mlir.register_lowering(_any_no_vmap_prim, mlir.lower_fun(_any_without_vmap_imp, multiple_results=False))
|
71
|
+
|
72
|
+
# Primitive computing ``jnp.all`` whose batching rule reduces over the batched
# array as a whole (mapped axis included) and returns an unbatched scalar.
_all_no_vmap_prim = Primitive('all_no_vmap')


def _all_without_vmap(x):
    """Bind ``_all_no_vmap_prim`` to ``x``."""
    return _all_no_vmap_prim.bind(x)


def _all_without_vmap_imp(x):
    """Concrete implementation: ``jnp.all`` over all elements of ``x``."""
    return jnp.all(x)


def _all_without_vmap_abs(x):
    """Abstract evaluation: the output is always a boolean scalar."""
    return ShapedArray(shape=(), dtype=jnp.bool_)


def _all_without_vmap_batch(x, batch_axes):
    """Batching rule.

    ``x`` is the tuple of batched operands (a single array here). The whole
    batched array is reduced, and the result is marked ``batching.not_mapped``,
    so it leaves ``vmap`` as one unbatched value. ``batch_axes`` is ignored.
    """
    (x,) = x
    return _all_without_vmap(x), batching.not_mapped


_all_no_vmap_prim.def_impl(_all_without_vmap_imp)
_all_no_vmap_prim.def_abstract_eval(_all_without_vmap_abs)
batching.primitive_batchers[_all_no_vmap_prim] = _all_without_vmap_batch
# Register the legacy XLA translation only when this JAX version still provides
# ``xla.lower_fun``; the MLIR lowering below is always registered.
if hasattr(xla, "lower_fun"):
    xla.register_translation(_all_no_vmap_prim,
                             xla.lower_fun(_all_without_vmap_imp, multiple_results=False, new_style=True))
mlir.register_lowering(_all_no_vmap_prim, mlir.lower_fun(_all_without_vmap_imp, multiple_results=False))
|
99
|
+
|
100
|
+
|
101
|
+
def _err_jit_true_branch(err_fun, x):
|
102
|
+
jax.debug.callback(err_fun, x)
|
103
|
+
return
|
104
|
+
|
105
|
+
|
106
|
+
def _err_jit_false_branch(x):
|
107
|
+
return
|
108
|
+
|
109
|
+
|
110
|
+
def _cond(err_fun, pred, err_arg):
    """Invoke ``err_fun(err_arg)`` through a debug callback when ``pred`` is true.

    Args:
      err_fun: Error-reporting function, called with ``err_arg``.
      pred: Predicate passed to ``jax.lax.cond``.
      err_arg: Operand forwarded to whichever branch runs.
    """

    @wraps(err_fun)
    def _forward(*args):
        # Forward to ``err_fun`` and discard whatever it returns.
        err_fun(*args)

    on_error = partial(_err_jit_true_branch, _forward)
    cond(pred, on_error, _err_jit_false_branch, err_arg)
|
119
|
+
|
120
|
+
|
121
|
+
def _error_msg(msg, *arg):
|
122
|
+
if len(arg) == 0:
|
123
|
+
raise ValueError(msg)
|
124
|
+
else:
|
125
|
+
raise ValueError(msg.format(arg))
|
126
|
+
|
127
|
+
|
128
|
+
@set_module_as('brainstate.transform')
def jit_error(pred, err_fun: Union[Callable, str], err_arg=None, scope: str = 'any'):
    """Raise an error from inside jitted code when ``pred`` holds.

    >>> def error(arg):
    ...     raise ValueError(f'error {arg}')
    >>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
    >>> jit_error(x.sum() < 5., error, err_arg=x)

    Parameters
    ----------
    pred: bool, Array
      The boolean predicate; the error is reported when it is true.
    err_fun: callable, str
      The error function, which raises the error. A string is turned into a
      function that raises ``ValueError`` with that message, formatted using
      ``err_arg``.
    err_arg: any
      The argument passed into ``err_fun``.
    scope: str
      How ``pred`` is reduced before use. Can be None (use ``pred`` unchanged),
      'all' or 'any' (reduce ``pred``, including any ``vmap``-ped axis, to one
      boolean with ``jnp.all`` / ``jnp.any``).
    """
    if isinstance(err_fun, str):
        err_fun = partial(_error_msg, err_fun)
    if scope in ('all', 'any'):
        pred = remove_vmap(pred, scope)
    elif scope is not None:
        raise ValueError(f"Unknown scope: {scope}")
    _cond(err_fun, pred, err_arg)
|