brainstate 0.0.2.post20240824__py2.py3-none-any.whl → 0.0.2.post20240826__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_module.py +6 -6
- brainstate/environ.py +54 -5
- brainstate/random.py +5 -5
- brainstate/transform/__init__.py +16 -6
- brainstate/transform/_conditions.py +334 -0
- brainstate/transform/{_controls_test.py → _conditions_test.py} +35 -35
- brainstate/transform/_error_if.py +94 -0
- brainstate/transform/{_jit_error_test.py → _error_if_test.py} +4 -4
- brainstate/transform/_loop_collect_return.py +502 -0
- brainstate/transform/_loop_no_collection.py +170 -0
- brainstate/transform/_mapping.py +109 -0
- brainstate/transform/_unvmap.py +143 -0
- brainstate/typing.py +55 -1
- {brainstate-0.0.2.post20240824.dist-info → brainstate-0.0.2.post20240826.dist-info}/METADATA +4 -4
- {brainstate-0.0.2.post20240824.dist-info → brainstate-0.0.2.post20240826.dist-info}/RECORD +18 -14
- brainstate/transform/_control.py +0 -665
- brainstate/transform/_jit_error.py +0 -180
- {brainstate-0.0.2.post20240824.dist-info → brainstate-0.0.2.post20240826.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20240824.dist-info → brainstate-0.0.2.post20240826.dist-info}/WHEEL +0 -0
- {brainstate-0.0.2.post20240824.dist-info → brainstate-0.0.2.post20240826.dist-info}/top_level.txt +0 -0
@@ -18,20 +18,20 @@ import unittest
|
|
18
18
|
import jax
|
19
19
|
import jax.numpy as jnp
|
20
20
|
|
21
|
-
import brainstate as
|
21
|
+
import brainstate as bst
|
22
22
|
|
23
23
|
|
24
24
|
class TestCond(unittest.TestCase):
|
25
25
|
def test1(self):
|
26
|
-
|
27
|
-
|
28
|
-
|
26
|
+
bst.random.seed(1)
|
27
|
+
bst.transform.cond(True, lambda: bst.random.random(10), lambda: bst.random.random(10))
|
28
|
+
bst.transform.cond(False, lambda: bst.random.random(10), lambda: bst.random.random(10))
|
29
29
|
|
30
30
|
def test2(self):
|
31
|
-
st1 =
|
32
|
-
st2 =
|
33
|
-
st3 =
|
34
|
-
st4 =
|
31
|
+
st1 = bst.State(bst.random.rand(10))
|
32
|
+
st2 = bst.State(bst.random.rand(2))
|
33
|
+
st3 = bst.State(bst.random.rand(5))
|
34
|
+
st4 = bst.State(bst.random.rand(2, 10))
|
35
35
|
|
36
36
|
def true_fun(x):
|
37
37
|
st1.value = st2.value @ st4.value + x
|
@@ -39,7 +39,7 @@ class TestCond(unittest.TestCase):
|
|
39
39
|
def false_fun(x):
|
40
40
|
st3.value = (st3.value + 1.) * x
|
41
41
|
|
42
|
-
|
42
|
+
bst.transform.cond(True, true_fun, false_fun, 2.)
|
43
43
|
assert not isinstance(st1.value, jax.core.Tracer)
|
44
44
|
assert not isinstance(st2.value, jax.core.Tracer)
|
45
45
|
assert not isinstance(st3.value, jax.core.Tracer)
|
@@ -65,7 +65,7 @@ class TestSwitch(unittest.TestCase):
|
|
65
65
|
return branches[2](x)
|
66
66
|
|
67
67
|
def cfun(x):
|
68
|
-
return
|
68
|
+
return bst.transform.switch(x, branches, x)
|
69
69
|
|
70
70
|
self.assertEqual(fun(-1), cfun(-1))
|
71
71
|
self.assertEqual(fun(0), cfun(0))
|
@@ -89,7 +89,7 @@ class TestSwitch(unittest.TestCase):
|
|
89
89
|
return branches[i](x, x)
|
90
90
|
|
91
91
|
def cfun(x):
|
92
|
-
return
|
92
|
+
return bst.transform.switch(x, branches, x, x)
|
93
93
|
|
94
94
|
self.assertEqual(fun(-1), cfun(-1))
|
95
95
|
self.assertEqual(fun(0), cfun(0))
|
@@ -122,13 +122,13 @@ class TestSwitch(unittest.TestCase):
|
|
122
122
|
branches3 = branches2 + [lambda x: jnp.sin(x) + jnp.cos(x)] # requires one more residual slot
|
123
123
|
|
124
124
|
def fun1(x, i):
|
125
|
-
return
|
125
|
+
return bst.transform.switch(i + 1, branches1, x)
|
126
126
|
|
127
127
|
def fun2(x, i):
|
128
|
-
return
|
128
|
+
return bst.transform.switch(i + 1, branches2, x)
|
129
129
|
|
130
130
|
def fun3(x, i):
|
131
|
-
return
|
131
|
+
return bst.transform.switch(i + 1, branches3, x)
|
132
132
|
|
133
133
|
fwd1, bwd1 = get_conds(fun1)
|
134
134
|
fwd2, bwd2 = get_conds(fun2)
|
@@ -148,7 +148,7 @@ class TestSwitch(unittest.TestCase):
|
|
148
148
|
|
149
149
|
def testOneBranchSwitch(self):
|
150
150
|
branch = lambda x: -x
|
151
|
-
f = lambda i, x:
|
151
|
+
f = lambda i, x: bst.transform.switch(i, [branch], x)
|
152
152
|
x = 7.
|
153
153
|
self.assertEqual(f(-1, x), branch(x))
|
154
154
|
self.assertEqual(f(0, x), branch(x))
|
@@ -166,12 +166,12 @@ class TestSwitch(unittest.TestCase):
|
|
166
166
|
class TestIfElse(unittest.TestCase):
|
167
167
|
def test1(self):
|
168
168
|
def f(a):
|
169
|
-
return
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
169
|
+
return bst.transform.ifelse(conditions=[a < 0,
|
170
|
+
a >= 0 and a < 2,
|
171
|
+
a >= 2 and a < 5,
|
172
|
+
a >= 5 and a < 10,
|
173
|
+
a >= 10],
|
174
|
+
branches=[lambda: 1,
|
175
175
|
lambda: 2,
|
176
176
|
lambda: 3,
|
177
177
|
lambda: 4,
|
@@ -183,38 +183,38 @@ class TestIfElse(unittest.TestCase):
|
|
183
183
|
|
184
184
|
def test_vmap(self):
|
185
185
|
def f(operands):
|
186
|
-
f = lambda a:
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
186
|
+
f = lambda a: bst.transform.ifelse([a > 10,
|
187
|
+
jnp.logical_and(a <= 10, a > 5),
|
188
|
+
jnp.logical_and(a <= 5, a > 2),
|
189
|
+
jnp.logical_and(a <= 2, a > 0),
|
190
|
+
a <= 0],
|
191
|
+
[lambda _: 1,
|
192
192
|
lambda _: 2,
|
193
193
|
lambda _: 3,
|
194
194
|
lambda _: 4,
|
195
195
|
lambda _: 5, ],
|
196
|
-
|
196
|
+
a)
|
197
197
|
return jax.vmap(f)(operands)
|
198
198
|
|
199
|
-
r = f(
|
199
|
+
r = f(bst.random.randint(-20, 20, 200))
|
200
200
|
self.assertTrue(r.size == 200)
|
201
201
|
|
202
202
|
def test_grad1(self):
|
203
203
|
def F2(x):
|
204
|
-
return
|
205
|
-
|
206
|
-
|
204
|
+
return bst.transform.ifelse((x >= 10, x < 10),
|
205
|
+
[lambda x: x, lambda x: x ** 2, ],
|
206
|
+
x)
|
207
207
|
|
208
208
|
self.assertTrue(jax.grad(F2)(9.0) == 18.)
|
209
209
|
self.assertTrue(jax.grad(F2)(11.0) == 1.)
|
210
210
|
|
211
211
|
def test_grad2(self):
|
212
212
|
def F3(x):
|
213
|
-
return
|
214
|
-
|
213
|
+
return bst.transform.ifelse((x >= 10, jnp.logical_and(x >= 0, x < 10), x < 0),
|
214
|
+
[lambda x: x,
|
215
215
|
lambda x: x ** 2,
|
216
216
|
lambda x: x ** 4, ],
|
217
|
-
|
217
|
+
x)
|
218
218
|
|
219
219
|
self.assertTrue(jax.grad(F3)(9.0) == 18.)
|
220
220
|
self.assertTrue(jax.grad(F3)(11.0) == 1.)
|
@@ -0,0 +1,94 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from __future__ import annotations
|
17
|
+
|
18
|
+
import functools
|
19
|
+
from functools import partial
|
20
|
+
from typing import Callable, Union
|
21
|
+
|
22
|
+
import jax
|
23
|
+
|
24
|
+
from brainstate._utils import set_module_as
|
25
|
+
from ._unvmap import unvmap
|
26
|
+
|
27
|
+
__all__ = [
|
28
|
+
'jit_error_if',
|
29
|
+
]
|
30
|
+
|
31
|
+
|
32
|
+
def _err_jit_true_branch(err_fun, args, kwargs):
|
33
|
+
jax.debug.callback(err_fun, *args, **kwargs)
|
34
|
+
|
35
|
+
|
36
|
+
def _err_jit_false_branch(args, kwargs):
|
37
|
+
pass
|
38
|
+
|
39
|
+
|
40
|
+
def _error_msg(msg, *arg, **kwargs):
|
41
|
+
if len(arg):
|
42
|
+
msg = msg % arg
|
43
|
+
if len(kwargs):
|
44
|
+
msg = msg.format(**kwargs)
|
45
|
+
raise ValueError(msg)
|
46
|
+
|
47
|
+
|
48
|
+
@set_module_as('brainstate.transform')
def jit_error_if(
    pred,
    error: Union[Callable, str],
    *err_args,
    **err_kwargs,
):
    """
    Check errors in a jit function.

    The check is expressed as a ``jax.lax.cond``: when the predicate holds,
    ``error`` is invoked through ``jax.debug.callback`` with ``err_args`` and
    ``err_kwargs``; otherwise nothing happens.  ``pred`` and every leaf of the
    arguments are first passed through ``unvmap`` (``op='any'`` for the
    predicate, ``op='none'`` for the arguments).

    Examples
    --------

    ``error`` can be a function which receives the arguments passed from the
    JIT variables and raises an error.

    >>> def error(x):
    ...    raise ValueError(f'error {x}')
    >>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
    >>> jit_error_if(x.sum() < 5., error, x)

    Or, it can be a simple string message.

    >>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
    >>> jit_error_if(x.sum() < 5., "Error: the sum is less than 5. Got {s}", s=x.sum())


    Parameters
    ----------
    pred: bool, Array
      The boolean prediction.
    error: callable, str
      The error function, which raises errors, or a string indicating the
      error message. A string is formatted with ``err_args`` (``%`` style) and
      ``err_kwargs`` (``str.format`` style) and raised as a ``ValueError``.
    err_args:
      The positional arguments which are passed into ``error``.
    err_kwargs:
      The keyword arguments which are passed into ``error``.

    Raises
    ------
    TypeError
      If ``error`` is neither a string nor a callable.
    """
    if isinstance(error, str):
        error = partial(_error_msg, error)
    elif not callable(error):
        # Fail early with a clear message instead of inside the debug callback.
        raise TypeError(
            f'"error" must be a callable or a string, but got {type(error).__name__}.'
        )

    # NOTE(review): unvmap is defined in ._unvmap; presumably it lifts values
    # out of a vmap batch so the check runs once per call -- confirm there.
    strip_batch = partial(unvmap, op='none')
    jax.lax.cond(
        unvmap(pred, op='any'),
        partial(_err_jit_true_branch, error),
        _err_jit_false_branch,
        jax.tree.map(strip_batch, err_args),
        jax.tree.map(strip_batch, err_kwargs),
    )
|
@@ -25,18 +25,18 @@ import brainstate as bst
|
|
25
25
|
class TestJitError(unittest.TestCase):
|
26
26
|
def test1(self):
|
27
27
|
with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
|
28
|
-
bst.transform.
|
28
|
+
bst.transform.jit_error_if(True, 'error')
|
29
29
|
|
30
30
|
def err_f(x):
|
31
31
|
raise ValueError(f'error: {x}')
|
32
32
|
|
33
33
|
with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
|
34
|
-
bst.transform.
|
34
|
+
bst.transform.jit_error_if(True, err_f, 1.)
|
35
35
|
|
36
36
|
def test_vmap(self):
|
37
37
|
|
38
38
|
def f(x):
|
39
|
-
bst.transform.
|
39
|
+
bst.transform.jit_error_if(x, 'error: {x}', x=x)
|
40
40
|
|
41
41
|
jax.vmap(f)(jnp.array([False, False, False]))
|
42
42
|
with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
|
@@ -45,7 +45,7 @@ class TestJitError(unittest.TestCase):
|
|
45
45
|
def test_vmap_vmap(self):
|
46
46
|
|
47
47
|
def f(x):
|
48
|
-
bst.transform.
|
48
|
+
bst.transform.jit_error_if(x, 'error: {x}', x=x)
|
49
49
|
|
50
50
|
jax.vmap(jax.vmap(f))(jnp.array([[False, False, False],
|
51
51
|
[False, False, False]]))
|