brainstate 0.0.2.post20240824__py2.py3-none-any.whl → 0.0.2.post20240826__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -18,20 +18,20 @@ import unittest
18
18
  import jax
19
19
  import jax.numpy as jnp
20
20
 
21
- import brainstate as bc
21
+ import brainstate as bst
22
22
 
23
23
 
24
24
  class TestCond(unittest.TestCase):
25
25
  def test1(self):
26
- bc.random.seed(1)
27
- bc.transform.cond(True, lambda: bc.random.random(10), lambda: bc.random.random(10))
28
- bc.transform.cond(False, lambda: bc.random.random(10), lambda: bc.random.random(10))
26
+ bst.random.seed(1)
27
+ bst.transform.cond(True, lambda: bst.random.random(10), lambda: bst.random.random(10))
28
+ bst.transform.cond(False, lambda: bst.random.random(10), lambda: bst.random.random(10))
29
29
 
30
30
  def test2(self):
31
- st1 = bc.State(bc.random.rand(10))
32
- st2 = bc.State(bc.random.rand(2))
33
- st3 = bc.State(bc.random.rand(5))
34
- st4 = bc.State(bc.random.rand(2, 10))
31
+ st1 = bst.State(bst.random.rand(10))
32
+ st2 = bst.State(bst.random.rand(2))
33
+ st3 = bst.State(bst.random.rand(5))
34
+ st4 = bst.State(bst.random.rand(2, 10))
35
35
 
36
36
  def true_fun(x):
37
37
  st1.value = st2.value @ st4.value + x
@@ -39,7 +39,7 @@ class TestCond(unittest.TestCase):
39
39
  def false_fun(x):
40
40
  st3.value = (st3.value + 1.) * x
41
41
 
42
- bc.transform.cond(True, true_fun, false_fun, 2.)
42
+ bst.transform.cond(True, true_fun, false_fun, 2.)
43
43
  assert not isinstance(st1.value, jax.core.Tracer)
44
44
  assert not isinstance(st2.value, jax.core.Tracer)
45
45
  assert not isinstance(st3.value, jax.core.Tracer)
@@ -65,7 +65,7 @@ class TestSwitch(unittest.TestCase):
65
65
  return branches[2](x)
66
66
 
67
67
  def cfun(x):
68
- return bc.transform.switch(x, branches, x)
68
+ return bst.transform.switch(x, branches, x)
69
69
 
70
70
  self.assertEqual(fun(-1), cfun(-1))
71
71
  self.assertEqual(fun(0), cfun(0))
@@ -89,7 +89,7 @@ class TestSwitch(unittest.TestCase):
89
89
  return branches[i](x, x)
90
90
 
91
91
  def cfun(x):
92
- return bc.transform.switch(x, branches, x, x)
92
+ return bst.transform.switch(x, branches, x, x)
93
93
 
94
94
  self.assertEqual(fun(-1), cfun(-1))
95
95
  self.assertEqual(fun(0), cfun(0))
@@ -122,13 +122,13 @@ class TestSwitch(unittest.TestCase):
122
122
  branches3 = branches2 + [lambda x: jnp.sin(x) + jnp.cos(x)] # requires one more residual slot
123
123
 
124
124
  def fun1(x, i):
125
- return bc.transform.switch(i + 1, branches1, x)
125
+ return bst.transform.switch(i + 1, branches1, x)
126
126
 
127
127
  def fun2(x, i):
128
- return bc.transform.switch(i + 1, branches2, x)
128
+ return bst.transform.switch(i + 1, branches2, x)
129
129
 
130
130
  def fun3(x, i):
131
- return bc.transform.switch(i + 1, branches3, x)
131
+ return bst.transform.switch(i + 1, branches3, x)
132
132
 
133
133
  fwd1, bwd1 = get_conds(fun1)
134
134
  fwd2, bwd2 = get_conds(fun2)
@@ -148,7 +148,7 @@ class TestSwitch(unittest.TestCase):
148
148
 
149
149
  def testOneBranchSwitch(self):
150
150
  branch = lambda x: -x
151
- f = lambda i, x: bc.transform.switch(i, [branch], x)
151
+ f = lambda i, x: bst.transform.switch(i, [branch], x)
152
152
  x = 7.
153
153
  self.assertEqual(f(-1, x), branch(x))
154
154
  self.assertEqual(f(0, x), branch(x))
@@ -166,12 +166,12 @@ class TestSwitch(unittest.TestCase):
166
166
  class TestIfElse(unittest.TestCase):
167
167
  def test1(self):
168
168
  def f(a):
169
- return bc.transform.ifelse(conditions=[a < 0,
170
- a >= 0 and a < 2,
171
- a >= 2 and a < 5,
172
- a >= 5 and a < 10,
173
- a >= 10],
174
- branches=[lambda: 1,
169
+ return bst.transform.ifelse(conditions=[a < 0,
170
+ a >= 0 and a < 2,
171
+ a >= 2 and a < 5,
172
+ a >= 5 and a < 10,
173
+ a >= 10],
174
+ branches=[lambda: 1,
175
175
  lambda: 2,
176
176
  lambda: 3,
177
177
  lambda: 4,
@@ -183,38 +183,38 @@ class TestIfElse(unittest.TestCase):
183
183
 
184
184
  def test_vmap(self):
185
185
  def f(operands):
186
- f = lambda a: bc.transform.ifelse([a > 10,
187
- jnp.logical_and(a <= 10, a > 5),
188
- jnp.logical_and(a <= 5, a > 2),
189
- jnp.logical_and(a <= 2, a > 0),
190
- a <= 0],
191
- [lambda _: 1,
186
+ f = lambda a: bst.transform.ifelse([a > 10,
187
+ jnp.logical_and(a <= 10, a > 5),
188
+ jnp.logical_and(a <= 5, a > 2),
189
+ jnp.logical_and(a <= 2, a > 0),
190
+ a <= 0],
191
+ [lambda _: 1,
192
192
  lambda _: 2,
193
193
  lambda _: 3,
194
194
  lambda _: 4,
195
195
  lambda _: 5, ],
196
- a)
196
+ a)
197
197
  return jax.vmap(f)(operands)
198
198
 
199
- r = f(bc.random.randint(-20, 20, 200))
199
+ r = f(bst.random.randint(-20, 20, 200))
200
200
  self.assertTrue(r.size == 200)
201
201
 
202
202
  def test_grad1(self):
203
203
  def F2(x):
204
- return bc.transform.ifelse((x >= 10, x < 10),
205
- [lambda x: x, lambda x: x ** 2, ],
206
- x)
204
+ return bst.transform.ifelse((x >= 10, x < 10),
205
+ [lambda x: x, lambda x: x ** 2, ],
206
+ x)
207
207
 
208
208
  self.assertTrue(jax.grad(F2)(9.0) == 18.)
209
209
  self.assertTrue(jax.grad(F2)(11.0) == 1.)
210
210
 
211
211
  def test_grad2(self):
212
212
  def F3(x):
213
- return bc.transform.ifelse((x >= 10, jnp.logical_and(x >= 0, x < 10), x < 0),
214
- [lambda x: x,
213
+ return bst.transform.ifelse((x >= 10, jnp.logical_and(x >= 0, x < 10), x < 0),
214
+ [lambda x: x,
215
215
  lambda x: x ** 2,
216
216
  lambda x: x ** 4, ],
217
- x)
217
+ x)
218
218
 
219
219
  self.assertTrue(jax.grad(F3)(9.0) == 18.)
220
220
  self.assertTrue(jax.grad(F3)(11.0) == 1.)
@@ -0,0 +1,94 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from __future__ import annotations
17
+
18
+ import functools
19
+ from functools import partial
20
+ from typing import Callable, Union
21
+
22
+ import jax
23
+
24
+ from brainstate._utils import set_module_as
25
+ from ._unvmap import unvmap
26
+
# Public API of this module: only the jit-compatible runtime check is exported.
__all__ = ['jit_error_if']
30
+
31
+
32
+ def _err_jit_true_branch(err_fun, args, kwargs):
33
+ jax.debug.callback(err_fun, *args, **kwargs)
34
+
35
+
36
+ def _err_jit_false_branch(args, kwargs):
37
+ pass
38
+
39
+
40
+ def _error_msg(msg, *arg, **kwargs):
41
+ if len(arg):
42
+ msg = msg % arg
43
+ if len(kwargs):
44
+ msg = msg.format(**kwargs)
45
+ raise ValueError(msg)
46
+
47
+
@set_module_as('brainstate.transform')
def jit_error_if(
    pred,
    error: Union[Callable, str],
    *err_args,
    **err_kwargs,
):
    """
    Raise an error from inside a jitted function when ``pred`` is true.

    The check is staged with :func:`jax.lax.cond`. When ``pred`` is true, the
    taken branch calls ``error(*err_args, **err_kwargs)`` through
    :func:`jax.debug.callback`. The exception is therefore raised by the host
    callback when the computation runs, not at trace time.

    ``pred`` is reduced with ``unvmap(pred, op='any')``. Under ``jax.vmap`` the
    error is presumably triggered if any mapped element is true; confirm
    against ``unvmap``.

    Examples
    --------

    Pass a function that receives the arguments and raises an error:

    >>> def error(x):
    ...     raise ValueError(f'error {x}')
    >>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
    >>> jit_error_if(x.sum() < 5., error, x)

    Or pass a message string. Positional ``err_args`` are substituted with
    ``%``-formatting, and keyword ``err_kwargs`` with :meth:`str.format`. The
    message is raised as a ``ValueError``.

    >>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
    >>> jit_error_if(x.sum() < 5., "Error: the sum is less than 5. Got {s}", s=x.sum())

    Parameters
    ----------
    pred: bool, Array
        The boolean predicate; the error is raised when it is true.
    error: callable, str
        A function that raises an error, or a string with the error message.
    err_args:
        Positional arguments passed to ``error`` when the error is raised.
    err_kwargs:
        Keyword arguments passed to ``error`` when the error is raised.

    Raises
    ------
    TypeError
        Immediately, if ``error`` is neither a string nor a callable.
    """
    if isinstance(error, str):
        # Turn the message into a callable that formats it and raises ValueError.
        error = partial(_error_msg, error)
    elif not callable(error):
        # Fail early with a clear message instead of deep inside JAX tracing.
        raise TypeError(
            f'`error` must be a callable or a str, but got {type(error)}.'
        )

    # Reduce the predicate to a single decision with op='any'. Pass every
    # argument leaf through unvmap with op='none' before handing it to the
    # callback (exact semantics live in `unvmap`).
    unbatch = partial(unvmap, op='none')
    jax.lax.cond(
        unvmap(pred, op='any'),
        partial(_err_jit_true_branch, error),
        _err_jit_false_branch,
        jax.tree.map(unbatch, err_args),
        jax.tree.map(unbatch, err_kwargs),
    )
@@ -25,18 +25,18 @@ import brainstate as bst
25
25
  class TestJitError(unittest.TestCase):
26
26
  def test1(self):
27
27
  with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
28
- bst.transform.jit_error(True, 'error')
28
+ bst.transform.jit_error_if(True, 'error')
29
29
 
30
30
  def err_f(x):
31
31
  raise ValueError(f'error: {x}')
32
32
 
33
33
  with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
34
- bst.transform.jit_error(True, err_f, 1.)
34
+ bst.transform.jit_error_if(True, err_f, 1.)
35
35
 
36
36
  def test_vmap(self):
37
37
 
38
38
  def f(x):
39
- bst.transform.jit_error(x, 'error: {x}', x=x)
39
+ bst.transform.jit_error_if(x, 'error: {x}', x=x)
40
40
 
41
41
  jax.vmap(f)(jnp.array([False, False, False]))
42
42
  with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
@@ -45,7 +45,7 @@ class TestJitError(unittest.TestCase):
45
45
  def test_vmap_vmap(self):
46
46
 
47
47
  def f(x):
48
- bst.transform.jit_error(x, 'error: {x}', x=x)
48
+ bst.transform.jit_error_if(x, 'error: {x}', x=x)
49
49
 
50
50
  jax.vmap(jax.vmap(f))(jnp.array([[False, False, False],
51
51
  [False, False, False]]))