brainstate 0.2.0__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +169 -169
- brainstate/_compatible_import.py +340 -340
- brainstate/_compatible_import_test.py +681 -681
- brainstate/_deprecation.py +210 -210
- brainstate/_deprecation_test.py +2319 -2319
- brainstate/_error.py +45 -45
- brainstate/_state.py +1652 -1652
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -1495
- brainstate/environ_test.py +1223 -1223
- brainstate/graph/__init__.py +22 -22
- brainstate/graph/_node.py +240 -240
- brainstate/graph/_node_test.py +589 -589
- brainstate/graph/_operation.py +1624 -1624
- brainstate/graph/_operation_test.py +1147 -1147
- brainstate/mixin.py +1433 -1433
- brainstate/mixin_test.py +1017 -1017
- brainstate/nn/__init__.py +137 -137
- brainstate/nn/_activations.py +1100 -1100
- brainstate/nn/_activations_test.py +354 -354
- brainstate/nn/_collective_ops.py +633 -633
- brainstate/nn/_collective_ops_test.py +774 -774
- brainstate/nn/_common.py +226 -226
- brainstate/nn/_common_test.py +154 -154
- brainstate/nn/_conv.py +2010 -2010
- brainstate/nn/_conv_test.py +849 -849
- brainstate/nn/_delay.py +575 -575
- brainstate/nn/_delay_test.py +243 -243
- brainstate/nn/_dropout.py +618 -618
- brainstate/nn/_dropout_test.py +477 -477
- brainstate/nn/_dynamics.py +1267 -1267
- brainstate/nn/_dynamics_test.py +67 -67
- brainstate/nn/_elementwise.py +1298 -1298
- brainstate/nn/_elementwise_test.py +829 -829
- brainstate/nn/_embedding.py +408 -408
- brainstate/nn/_embedding_test.py +156 -156
- brainstate/nn/_event_fixedprob.py +233 -233
- brainstate/nn/_event_fixedprob_test.py +115 -115
- brainstate/nn/_event_linear.py +83 -83
- brainstate/nn/_event_linear_test.py +121 -121
- brainstate/nn/_exp_euler.py +254 -254
- brainstate/nn/_exp_euler_test.py +377 -377
- brainstate/nn/_linear.py +744 -744
- brainstate/nn/_linear_test.py +475 -475
- brainstate/nn/_metrics.py +1070 -1070
- brainstate/nn/_metrics_test.py +611 -611
- brainstate/nn/_module.py +384 -384
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -1334
- brainstate/nn/_normalizations_test.py +699 -699
- brainstate/nn/_paddings.py +1020 -1020
- brainstate/nn/_paddings_test.py +722 -722
- brainstate/nn/_poolings.py +2239 -2239
- brainstate/nn/_poolings_test.py +952 -952
- brainstate/nn/_rnns.py +946 -946
- brainstate/nn/_rnns_test.py +592 -592
- brainstate/nn/_utils.py +216 -216
- brainstate/nn/_utils_test.py +401 -401
- brainstate/nn/init.py +809 -809
- brainstate/nn/init_test.py +180 -180
- brainstate/random/__init__.py +270 -270
- brainstate/random/_rand_funs.py +3938 -3938
- brainstate/random/_rand_funs_test.py +640 -640
- brainstate/random/_rand_seed.py +675 -675
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1617 -1617
- brainstate/random/_rand_state_test.py +551 -551
- brainstate/transform/__init__.py +59 -59
- brainstate/transform/_ad_checkpoint.py +176 -176
- brainstate/transform/_ad_checkpoint_test.py +49 -49
- brainstate/transform/_autograd.py +1025 -1025
- brainstate/transform/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -316
- brainstate/transform/_conditions_test.py +220 -220
- brainstate/transform/_error_if.py +94 -94
- brainstate/transform/_error_if_test.py +52 -52
- brainstate/transform/_eval_shape.py +145 -145
- brainstate/transform/_eval_shape_test.py +38 -38
- brainstate/transform/_jit.py +399 -399
- brainstate/transform/_jit_test.py +143 -143
- brainstate/transform/_loop_collect_return.py +675 -675
- brainstate/transform/_loop_collect_return_test.py +58 -58
- brainstate/transform/_loop_no_collection.py +283 -283
- brainstate/transform/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2016 -2016
- brainstate/transform/_make_jaxpr_test.py +1510 -1510
- brainstate/transform/_mapping.py +529 -529
- brainstate/transform/_mapping_test.py +194 -194
- brainstate/transform/_progress_bar.py +255 -255
- brainstate/transform/_random.py +171 -171
- brainstate/transform/_unvmap.py +256 -256
- brainstate/transform/_util.py +286 -286
- brainstate/typing.py +837 -837
- brainstate/typing_test.py +780 -780
- brainstate/util/__init__.py +27 -27
- brainstate/util/_others.py +1024 -1024
- brainstate/util/_others_test.py +962 -962
- brainstate/util/_pretty_pytree.py +1301 -1301
- brainstate/util/_pretty_pytree_test.py +675 -675
- brainstate/util/_pretty_repr.py +462 -462
- brainstate/util/_pretty_repr_test.py +696 -696
- brainstate/util/filter.py +945 -945
- brainstate/util/filter_test.py +911 -911
- brainstate/util/struct.py +910 -910
- brainstate/util/struct_test.py +602 -602
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -108
- brainstate-0.2.1.dist-info/RECORD +111 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
- brainstate-0.2.0.dist-info/RECORD +0 -111
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -1,220 +1,220 @@
|
|
1
|
-
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
import unittest
|
17
|
-
|
18
|
-
import jax
|
19
|
-
import jax.numpy as jnp
|
20
|
-
|
21
|
-
import brainstate
|
22
|
-
|
23
|
-
|
24
|
-
class TestCond(unittest.TestCase):
|
25
|
-
def test1(self):
|
26
|
-
brainstate.random.seed(1)
|
27
|
-
brainstate.compile.cond(True, lambda: brainstate.random.random(10), lambda: brainstate.random.random(10))
|
28
|
-
brainstate.compile.cond(False, lambda: brainstate.random.random(10), lambda: brainstate.random.random(10))
|
29
|
-
|
30
|
-
def test2(self):
|
31
|
-
st1 = brainstate.State(brainstate.random.rand(10))
|
32
|
-
st2 = brainstate.State(brainstate.random.rand(2))
|
33
|
-
st3 = brainstate.State(brainstate.random.rand(5))
|
34
|
-
st4 = brainstate.State(brainstate.random.rand(2, 10))
|
35
|
-
|
36
|
-
def true_fun(x):
|
37
|
-
st1.value = st2.value @ st4.value + x
|
38
|
-
|
39
|
-
def false_fun(x):
|
40
|
-
st3.value = (st3.value + 1.) * x
|
41
|
-
|
42
|
-
brainstate.compile.cond(True, true_fun, false_fun, 2.)
|
43
|
-
assert not isinstance(st1.value, jax.core.Tracer)
|
44
|
-
assert not isinstance(st2.value, jax.core.Tracer)
|
45
|
-
assert not isinstance(st3.value, jax.core.Tracer)
|
46
|
-
assert not isinstance(st4.value, jax.core.Tracer)
|
47
|
-
|
48
|
-
|
49
|
-
class TestSwitch(unittest.TestCase):
|
50
|
-
def testSwitch(self):
|
51
|
-
def branch(x):
|
52
|
-
y = jax.lax.mul(2, x)
|
53
|
-
return y, jax.lax.mul(2, y)
|
54
|
-
|
55
|
-
branches = [lambda x: (x, x),
|
56
|
-
branch,
|
57
|
-
lambda x: (x, -x)]
|
58
|
-
|
59
|
-
def fun(x):
|
60
|
-
if x <= 0:
|
61
|
-
return branches[0](x)
|
62
|
-
elif x == 1:
|
63
|
-
return branches[1](x)
|
64
|
-
else:
|
65
|
-
return branches[2](x)
|
66
|
-
|
67
|
-
def cfun(x):
|
68
|
-
return brainstate.compile.switch(x, branches, x)
|
69
|
-
|
70
|
-
self.assertEqual(fun(-1), cfun(-1))
|
71
|
-
self.assertEqual(fun(0), cfun(0))
|
72
|
-
self.assertEqual(fun(1), cfun(1))
|
73
|
-
self.assertEqual(fun(2), cfun(2))
|
74
|
-
self.assertEqual(fun(3), cfun(3))
|
75
|
-
|
76
|
-
cfun = jax.jit(cfun)
|
77
|
-
|
78
|
-
self.assertEqual(fun(-1), cfun(-1))
|
79
|
-
self.assertEqual(fun(0), cfun(0))
|
80
|
-
self.assertEqual(fun(1), cfun(1))
|
81
|
-
self.assertEqual(fun(2), cfun(2))
|
82
|
-
self.assertEqual(fun(3), cfun(3))
|
83
|
-
|
84
|
-
def testSwitchMultiOperands(self):
|
85
|
-
branches = [jax.lax.add, jax.lax.mul]
|
86
|
-
|
87
|
-
def fun(x):
|
88
|
-
i = 0 if x <= 0 else 1
|
89
|
-
return branches[i](x, x)
|
90
|
-
|
91
|
-
def cfun(x):
|
92
|
-
return brainstate.compile.switch(x, branches, x, x)
|
93
|
-
|
94
|
-
self.assertEqual(fun(-1), cfun(-1))
|
95
|
-
self.assertEqual(fun(0), cfun(0))
|
96
|
-
self.assertEqual(fun(1), cfun(1))
|
97
|
-
self.assertEqual(fun(2), cfun(2))
|
98
|
-
cfun = jax.jit(cfun)
|
99
|
-
self.assertEqual(fun(-1), cfun(-1))
|
100
|
-
self.assertEqual(fun(0), cfun(0))
|
101
|
-
self.assertEqual(fun(1), cfun(1))
|
102
|
-
self.assertEqual(fun(2), cfun(2))
|
103
|
-
|
104
|
-
def testSwitchResidualsMerge(self):
|
105
|
-
def get_conds(fun):
|
106
|
-
jaxpr = jax.make_jaxpr(jax.grad(fun))(0., 0)
|
107
|
-
return [eqn for eqn in jaxpr.jaxpr.eqns if eqn.primitive.name == 'cond']
|
108
|
-
|
109
|
-
def branch_invars_len(cond_eqn):
|
110
|
-
lens = [len(jaxpr.jaxpr.invars) for jaxpr in cond_eqn.params['branches']]
|
111
|
-
assert len(set(lens)) == 1
|
112
|
-
return lens[0]
|
113
|
-
|
114
|
-
def branch_outvars_len(cond_eqn):
|
115
|
-
lens = [len(jaxpr.jaxpr.outvars) for jaxpr in cond_eqn.params['branches']]
|
116
|
-
assert len(set(lens)) == 1
|
117
|
-
return lens[0]
|
118
|
-
|
119
|
-
branches1 = [lambda x: jnp.sin(x),
|
120
|
-
lambda x: jnp.cos(x)] # branch residuals overlap, should be reused
|
121
|
-
branches2 = branches1 + [lambda x: jnp.sinh(x)] # another overlapping residual, expect reuse
|
122
|
-
branches3 = branches2 + [lambda x: jnp.sin(x) + jnp.cos(x)] # requires one more residual slot
|
123
|
-
|
124
|
-
def fun1(x, i):
|
125
|
-
return brainstate.compile.switch(i + 1, branches1, x)
|
126
|
-
|
127
|
-
def fun2(x, i):
|
128
|
-
return brainstate.compile.switch(i + 1, branches2, x)
|
129
|
-
|
130
|
-
def fun3(x, i):
|
131
|
-
return brainstate.compile.switch(i + 1, branches3, x)
|
132
|
-
|
133
|
-
fwd1, bwd1 = get_conds(fun1)
|
134
|
-
fwd2, bwd2 = get_conds(fun2)
|
135
|
-
fwd3, bwd3 = get_conds(fun3)
|
136
|
-
|
137
|
-
fwd1_num_out = branch_outvars_len(fwd1)
|
138
|
-
fwd2_num_out = branch_outvars_len(fwd2)
|
139
|
-
fwd3_num_out = branch_outvars_len(fwd3)
|
140
|
-
assert fwd1_num_out == fwd2_num_out
|
141
|
-
assert fwd3_num_out == fwd2_num_out + 1
|
142
|
-
|
143
|
-
bwd1_num_in = branch_invars_len(bwd1)
|
144
|
-
bwd2_num_in = branch_invars_len(bwd2)
|
145
|
-
bwd3_num_in = branch_invars_len(bwd3)
|
146
|
-
assert bwd1_num_in == bwd2_num_in
|
147
|
-
assert bwd3_num_in == bwd2_num_in + 1
|
148
|
-
|
149
|
-
def testOneBranchSwitch(self):
|
150
|
-
branch = lambda x: -x
|
151
|
-
f = lambda i, x: brainstate.compile.switch(i, [branch], x)
|
152
|
-
x = 7.
|
153
|
-
self.assertEqual(f(-1, x), branch(x))
|
154
|
-
self.assertEqual(f(0, x), branch(x))
|
155
|
-
self.assertEqual(f(1, x), branch(x))
|
156
|
-
cf = jax.jit(f)
|
157
|
-
self.assertEqual(cf(-1, x), branch(x))
|
158
|
-
self.assertEqual(cf(0, x), branch(x))
|
159
|
-
self.assertEqual(cf(1, x), branch(x))
|
160
|
-
cf = jax.jit(f, static_argnums=0)
|
161
|
-
self.assertEqual(cf(-1, x), branch(x))
|
162
|
-
self.assertEqual(cf(0, x), branch(x))
|
163
|
-
self.assertEqual(cf(1, x), branch(x))
|
164
|
-
|
165
|
-
|
166
|
-
class TestIfElse(unittest.TestCase):
|
167
|
-
def test1(self):
|
168
|
-
def f(a):
|
169
|
-
return brainstate.compile.ifelse(conditions=[a < 0,
|
170
|
-
a >= 0 and a < 2,
|
171
|
-
a >= 2 and a < 5,
|
172
|
-
a >= 5 and a < 10,
|
173
|
-
a >= 10],
|
174
|
-
branches=[lambda: 1,
|
175
|
-
lambda: 2,
|
176
|
-
lambda: 3,
|
177
|
-
lambda: 4,
|
178
|
-
lambda: 5])
|
179
|
-
|
180
|
-
self.assertTrue(f(3) == 3)
|
181
|
-
self.assertTrue(f(1) == 2)
|
182
|
-
self.assertTrue(f(-1) == 1)
|
183
|
-
|
184
|
-
def test_vmap(self):
|
185
|
-
def f(operands):
|
186
|
-
f = lambda a: brainstate.compile.ifelse([a > 10,
|
187
|
-
jnp.logical_and(a <= 10, a > 5),
|
188
|
-
jnp.logical_and(a <= 5, a > 2),
|
189
|
-
jnp.logical_and(a <= 2, a > 0),
|
190
|
-
a <= 0],
|
191
|
-
[lambda _: 1,
|
192
|
-
lambda _: 2,
|
193
|
-
lambda _: 3,
|
194
|
-
lambda _: 4,
|
195
|
-
lambda _: 5, ],
|
196
|
-
a)
|
197
|
-
return jax.vmap(f)(operands)
|
198
|
-
|
199
|
-
r = f(brainstate.random.randint(-20, 20, 200))
|
200
|
-
self.assertTrue(r.size == 200)
|
201
|
-
|
202
|
-
def test_grad1(self):
|
203
|
-
def F2(x):
|
204
|
-
return brainstate.compile.ifelse((x >= 10, x < 10),
|
205
|
-
[lambda x: x, lambda x: x ** 2, ],
|
206
|
-
x)
|
207
|
-
|
208
|
-
self.assertTrue(jax.grad(F2)(9.0) == 18.)
|
209
|
-
self.assertTrue(jax.grad(F2)(11.0) == 1.)
|
210
|
-
|
211
|
-
def test_grad2(self):
|
212
|
-
def F3(x):
|
213
|
-
return brainstate.compile.ifelse((x >= 10, jnp.logical_and(x >= 0, x < 10), x < 0),
|
214
|
-
[lambda x: x,
|
215
|
-
lambda x: x ** 2,
|
216
|
-
lambda x: x ** 4, ],
|
217
|
-
x)
|
218
|
-
|
219
|
-
self.assertTrue(jax.grad(F3)(9.0) == 18.)
|
220
|
-
self.assertTrue(jax.grad(F3)(11.0) == 1.)
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import unittest
|
17
|
+
|
18
|
+
import jax
|
19
|
+
import jax.numpy as jnp
|
20
|
+
|
21
|
+
import brainstate
|
22
|
+
|
23
|
+
|
24
|
+
class TestCond(unittest.TestCase):
|
25
|
+
def test1(self):
|
26
|
+
brainstate.random.seed(1)
|
27
|
+
brainstate.compile.cond(True, lambda: brainstate.random.random(10), lambda: brainstate.random.random(10))
|
28
|
+
brainstate.compile.cond(False, lambda: brainstate.random.random(10), lambda: brainstate.random.random(10))
|
29
|
+
|
30
|
+
def test2(self):
|
31
|
+
st1 = brainstate.State(brainstate.random.rand(10))
|
32
|
+
st2 = brainstate.State(brainstate.random.rand(2))
|
33
|
+
st3 = brainstate.State(brainstate.random.rand(5))
|
34
|
+
st4 = brainstate.State(brainstate.random.rand(2, 10))
|
35
|
+
|
36
|
+
def true_fun(x):
|
37
|
+
st1.value = st2.value @ st4.value + x
|
38
|
+
|
39
|
+
def false_fun(x):
|
40
|
+
st3.value = (st3.value + 1.) * x
|
41
|
+
|
42
|
+
brainstate.compile.cond(True, true_fun, false_fun, 2.)
|
43
|
+
assert not isinstance(st1.value, jax.core.Tracer)
|
44
|
+
assert not isinstance(st2.value, jax.core.Tracer)
|
45
|
+
assert not isinstance(st3.value, jax.core.Tracer)
|
46
|
+
assert not isinstance(st4.value, jax.core.Tracer)
|
47
|
+
|
48
|
+
|
49
|
+
class TestSwitch(unittest.TestCase):
|
50
|
+
def testSwitch(self):
|
51
|
+
def branch(x):
|
52
|
+
y = jax.lax.mul(2, x)
|
53
|
+
return y, jax.lax.mul(2, y)
|
54
|
+
|
55
|
+
branches = [lambda x: (x, x),
|
56
|
+
branch,
|
57
|
+
lambda x: (x, -x)]
|
58
|
+
|
59
|
+
def fun(x):
|
60
|
+
if x <= 0:
|
61
|
+
return branches[0](x)
|
62
|
+
elif x == 1:
|
63
|
+
return branches[1](x)
|
64
|
+
else:
|
65
|
+
return branches[2](x)
|
66
|
+
|
67
|
+
def cfun(x):
|
68
|
+
return brainstate.compile.switch(x, branches, x)
|
69
|
+
|
70
|
+
self.assertEqual(fun(-1), cfun(-1))
|
71
|
+
self.assertEqual(fun(0), cfun(0))
|
72
|
+
self.assertEqual(fun(1), cfun(1))
|
73
|
+
self.assertEqual(fun(2), cfun(2))
|
74
|
+
self.assertEqual(fun(3), cfun(3))
|
75
|
+
|
76
|
+
cfun = jax.jit(cfun)
|
77
|
+
|
78
|
+
self.assertEqual(fun(-1), cfun(-1))
|
79
|
+
self.assertEqual(fun(0), cfun(0))
|
80
|
+
self.assertEqual(fun(1), cfun(1))
|
81
|
+
self.assertEqual(fun(2), cfun(2))
|
82
|
+
self.assertEqual(fun(3), cfun(3))
|
83
|
+
|
84
|
+
def testSwitchMultiOperands(self):
|
85
|
+
branches = [jax.lax.add, jax.lax.mul]
|
86
|
+
|
87
|
+
def fun(x):
|
88
|
+
i = 0 if x <= 0 else 1
|
89
|
+
return branches[i](x, x)
|
90
|
+
|
91
|
+
def cfun(x):
|
92
|
+
return brainstate.compile.switch(x, branches, x, x)
|
93
|
+
|
94
|
+
self.assertEqual(fun(-1), cfun(-1))
|
95
|
+
self.assertEqual(fun(0), cfun(0))
|
96
|
+
self.assertEqual(fun(1), cfun(1))
|
97
|
+
self.assertEqual(fun(2), cfun(2))
|
98
|
+
cfun = jax.jit(cfun)
|
99
|
+
self.assertEqual(fun(-1), cfun(-1))
|
100
|
+
self.assertEqual(fun(0), cfun(0))
|
101
|
+
self.assertEqual(fun(1), cfun(1))
|
102
|
+
self.assertEqual(fun(2), cfun(2))
|
103
|
+
|
104
|
+
def testSwitchResidualsMerge(self):
|
105
|
+
def get_conds(fun):
|
106
|
+
jaxpr = jax.make_jaxpr(jax.grad(fun))(0., 0)
|
107
|
+
return [eqn for eqn in jaxpr.jaxpr.eqns if eqn.primitive.name == 'cond']
|
108
|
+
|
109
|
+
def branch_invars_len(cond_eqn):
|
110
|
+
lens = [len(jaxpr.jaxpr.invars) for jaxpr in cond_eqn.params['branches']]
|
111
|
+
assert len(set(lens)) == 1
|
112
|
+
return lens[0]
|
113
|
+
|
114
|
+
def branch_outvars_len(cond_eqn):
|
115
|
+
lens = [len(jaxpr.jaxpr.outvars) for jaxpr in cond_eqn.params['branches']]
|
116
|
+
assert len(set(lens)) == 1
|
117
|
+
return lens[0]
|
118
|
+
|
119
|
+
branches1 = [lambda x: jnp.sin(x),
|
120
|
+
lambda x: jnp.cos(x)] # branch residuals overlap, should be reused
|
121
|
+
branches2 = branches1 + [lambda x: jnp.sinh(x)] # another overlapping residual, expect reuse
|
122
|
+
branches3 = branches2 + [lambda x: jnp.sin(x) + jnp.cos(x)] # requires one more residual slot
|
123
|
+
|
124
|
+
def fun1(x, i):
|
125
|
+
return brainstate.compile.switch(i + 1, branches1, x)
|
126
|
+
|
127
|
+
def fun2(x, i):
|
128
|
+
return brainstate.compile.switch(i + 1, branches2, x)
|
129
|
+
|
130
|
+
def fun3(x, i):
|
131
|
+
return brainstate.compile.switch(i + 1, branches3, x)
|
132
|
+
|
133
|
+
fwd1, bwd1 = get_conds(fun1)
|
134
|
+
fwd2, bwd2 = get_conds(fun2)
|
135
|
+
fwd3, bwd3 = get_conds(fun3)
|
136
|
+
|
137
|
+
fwd1_num_out = branch_outvars_len(fwd1)
|
138
|
+
fwd2_num_out = branch_outvars_len(fwd2)
|
139
|
+
fwd3_num_out = branch_outvars_len(fwd3)
|
140
|
+
assert fwd1_num_out == fwd2_num_out
|
141
|
+
assert fwd3_num_out == fwd2_num_out + 1
|
142
|
+
|
143
|
+
bwd1_num_in = branch_invars_len(bwd1)
|
144
|
+
bwd2_num_in = branch_invars_len(bwd2)
|
145
|
+
bwd3_num_in = branch_invars_len(bwd3)
|
146
|
+
assert bwd1_num_in == bwd2_num_in
|
147
|
+
assert bwd3_num_in == bwd2_num_in + 1
|
148
|
+
|
149
|
+
def testOneBranchSwitch(self):
|
150
|
+
branch = lambda x: -x
|
151
|
+
f = lambda i, x: brainstate.compile.switch(i, [branch], x)
|
152
|
+
x = 7.
|
153
|
+
self.assertEqual(f(-1, x), branch(x))
|
154
|
+
self.assertEqual(f(0, x), branch(x))
|
155
|
+
self.assertEqual(f(1, x), branch(x))
|
156
|
+
cf = jax.jit(f)
|
157
|
+
self.assertEqual(cf(-1, x), branch(x))
|
158
|
+
self.assertEqual(cf(0, x), branch(x))
|
159
|
+
self.assertEqual(cf(1, x), branch(x))
|
160
|
+
cf = jax.jit(f, static_argnums=0)
|
161
|
+
self.assertEqual(cf(-1, x), branch(x))
|
162
|
+
self.assertEqual(cf(0, x), branch(x))
|
163
|
+
self.assertEqual(cf(1, x), branch(x))
|
164
|
+
|
165
|
+
|
166
|
+
class TestIfElse(unittest.TestCase):
|
167
|
+
def test1(self):
|
168
|
+
def f(a):
|
169
|
+
return brainstate.compile.ifelse(conditions=[a < 0,
|
170
|
+
a >= 0 and a < 2,
|
171
|
+
a >= 2 and a < 5,
|
172
|
+
a >= 5 and a < 10,
|
173
|
+
a >= 10],
|
174
|
+
branches=[lambda: 1,
|
175
|
+
lambda: 2,
|
176
|
+
lambda: 3,
|
177
|
+
lambda: 4,
|
178
|
+
lambda: 5])
|
179
|
+
|
180
|
+
self.assertTrue(f(3) == 3)
|
181
|
+
self.assertTrue(f(1) == 2)
|
182
|
+
self.assertTrue(f(-1) == 1)
|
183
|
+
|
184
|
+
def test_vmap(self):
|
185
|
+
def f(operands):
|
186
|
+
f = lambda a: brainstate.compile.ifelse([a > 10,
|
187
|
+
jnp.logical_and(a <= 10, a > 5),
|
188
|
+
jnp.logical_and(a <= 5, a > 2),
|
189
|
+
jnp.logical_and(a <= 2, a > 0),
|
190
|
+
a <= 0],
|
191
|
+
[lambda _: 1,
|
192
|
+
lambda _: 2,
|
193
|
+
lambda _: 3,
|
194
|
+
lambda _: 4,
|
195
|
+
lambda _: 5, ],
|
196
|
+
a)
|
197
|
+
return jax.vmap(f)(operands)
|
198
|
+
|
199
|
+
r = f(brainstate.random.randint(-20, 20, 200))
|
200
|
+
self.assertTrue(r.size == 200)
|
201
|
+
|
202
|
+
def test_grad1(self):
|
203
|
+
def F2(x):
|
204
|
+
return brainstate.compile.ifelse((x >= 10, x < 10),
|
205
|
+
[lambda x: x, lambda x: x ** 2, ],
|
206
|
+
x)
|
207
|
+
|
208
|
+
self.assertTrue(jax.grad(F2)(9.0) == 18.)
|
209
|
+
self.assertTrue(jax.grad(F2)(11.0) == 1.)
|
210
|
+
|
211
|
+
def test_grad2(self):
|
212
|
+
def F3(x):
|
213
|
+
return brainstate.compile.ifelse((x >= 10, jnp.logical_and(x >= 0, x < 10), x < 0),
|
214
|
+
[lambda x: x,
|
215
|
+
lambda x: x ** 2,
|
216
|
+
lambda x: x ** 4, ],
|
217
|
+
x)
|
218
|
+
|
219
|
+
self.assertTrue(jax.grad(F3)(9.0) == 18.)
|
220
|
+
self.assertTrue(jax.grad(F3)(11.0) == 1.)
|
@@ -1,94 +1,94 @@
|
|
1
|
-
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
import functools
|
17
|
-
from functools import partial
|
18
|
-
from typing import Callable, Union
|
19
|
-
|
20
|
-
import jax
|
21
|
-
|
22
|
-
from brainstate._utils import set_module_as
|
23
|
-
from ._unvmap import unvmap
|
24
|
-
|
25
|
-
__all__ = [
|
26
|
-
'jit_error_if',
|
27
|
-
]
|
28
|
-
|
29
|
-
|
30
|
-
def _err_jit_true_branch(err_fun, args, kwargs):
|
31
|
-
jax.debug.callback(err_fun, *args, **kwargs)
|
32
|
-
|
33
|
-
|
34
|
-
def _err_jit_false_branch(args, kwargs):
|
35
|
-
pass
|
36
|
-
|
37
|
-
|
38
|
-
def _error_msg(msg, *arg, **kwargs):
|
39
|
-
if len(arg):
|
40
|
-
msg = msg % arg
|
41
|
-
if len(kwargs):
|
42
|
-
msg = msg.format(**kwargs)
|
43
|
-
raise ValueError(msg)
|
44
|
-
|
45
|
-
|
46
|
-
@set_module_as('brainstate.transform')
|
47
|
-
def jit_error_if(
|
48
|
-
pred,
|
49
|
-
error: Union[Callable, str],
|
50
|
-
*err_args,
|
51
|
-
**err_kwargs,
|
52
|
-
):
|
53
|
-
"""
|
54
|
-
Check errors in a jit function.
|
55
|
-
|
56
|
-
Parameters
|
57
|
-
----------
|
58
|
-
pred : bool or Array
|
59
|
-
The boolean prediction.
|
60
|
-
error : callable or str
|
61
|
-
The error function, which raise errors, or a string indicating the error message.
|
62
|
-
*err_args
|
63
|
-
The arguments which passed into the error function.
|
64
|
-
**err_kwargs
|
65
|
-
The keywords which passed into the error function.
|
66
|
-
|
67
|
-
Examples
|
68
|
-
--------
|
69
|
-
It can give a function which receive arguments that passed from the JIT variables and raise errors.
|
70
|
-
|
71
|
-
.. code-block:: python
|
72
|
-
|
73
|
-
>>> def error(x):
|
74
|
-
... raise ValueError(f'error {x}')
|
75
|
-
>>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
|
76
|
-
>>> jit_error_if(x.sum() < 5., error, x)
|
77
|
-
|
78
|
-
Or, it can be a simple string message.
|
79
|
-
|
80
|
-
.. code-block:: python
|
81
|
-
|
82
|
-
>>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
|
83
|
-
>>> jit_error_if(x.sum() < 5., "Error: the sum is less than 5. Got {s}", s=x.sum())
|
84
|
-
"""
|
85
|
-
if isinstance(error, str):
|
86
|
-
error = partial(_error_msg, error)
|
87
|
-
|
88
|
-
jax.lax.cond(
|
89
|
-
unvmap(pred, op='any'),
|
90
|
-
partial(_err_jit_true_branch, error),
|
91
|
-
_err_jit_false_branch,
|
92
|
-
jax.tree.map(functools.partial(unvmap, op='none'), err_args),
|
93
|
-
jax.tree.map(functools.partial(unvmap, op='none'), err_kwargs),
|
94
|
-
)
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import functools
|
17
|
+
from functools import partial
|
18
|
+
from typing import Callable, Union
|
19
|
+
|
20
|
+
import jax
|
21
|
+
|
22
|
+
from brainstate._utils import set_module_as
|
23
|
+
from ._unvmap import unvmap
|
24
|
+
|
25
|
+
__all__ = [
|
26
|
+
'jit_error_if',
|
27
|
+
]
|
28
|
+
|
29
|
+
|
30
|
+
def _err_jit_true_branch(err_fun, args, kwargs):
|
31
|
+
jax.debug.callback(err_fun, *args, **kwargs)
|
32
|
+
|
33
|
+
|
34
|
+
def _err_jit_false_branch(args, kwargs):
|
35
|
+
pass
|
36
|
+
|
37
|
+
|
38
|
+
def _error_msg(msg, *arg, **kwargs):
|
39
|
+
if len(arg):
|
40
|
+
msg = msg % arg
|
41
|
+
if len(kwargs):
|
42
|
+
msg = msg.format(**kwargs)
|
43
|
+
raise ValueError(msg)
|
44
|
+
|
45
|
+
|
46
|
+
@set_module_as('brainstate.transform')
|
47
|
+
def jit_error_if(
|
48
|
+
pred,
|
49
|
+
error: Union[Callable, str],
|
50
|
+
*err_args,
|
51
|
+
**err_kwargs,
|
52
|
+
):
|
53
|
+
"""
|
54
|
+
Check errors in a jit function.
|
55
|
+
|
56
|
+
Parameters
|
57
|
+
----------
|
58
|
+
pred : bool or Array
|
59
|
+
The boolean prediction.
|
60
|
+
error : callable or str
|
61
|
+
The error function, which raise errors, or a string indicating the error message.
|
62
|
+
*err_args
|
63
|
+
The arguments which passed into the error function.
|
64
|
+
**err_kwargs
|
65
|
+
The keywords which passed into the error function.
|
66
|
+
|
67
|
+
Examples
|
68
|
+
--------
|
69
|
+
It can give a function which receive arguments that passed from the JIT variables and raise errors.
|
70
|
+
|
71
|
+
.. code-block:: python
|
72
|
+
|
73
|
+
>>> def error(x):
|
74
|
+
... raise ValueError(f'error {x}')
|
75
|
+
>>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
|
76
|
+
>>> jit_error_if(x.sum() < 5., error, x)
|
77
|
+
|
78
|
+
Or, it can be a simple string message.
|
79
|
+
|
80
|
+
.. code-block:: python
|
81
|
+
|
82
|
+
>>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
|
83
|
+
>>> jit_error_if(x.sum() < 5., "Error: the sum is less than 5. Got {s}", s=x.sum())
|
84
|
+
"""
|
85
|
+
if isinstance(error, str):
|
86
|
+
error = partial(_error_msg, error)
|
87
|
+
|
88
|
+
jax.lax.cond(
|
89
|
+
unvmap(pred, op='any'),
|
90
|
+
partial(_err_jit_true_branch, error),
|
91
|
+
_err_jit_false_branch,
|
92
|
+
jax.tree.map(functools.partial(unvmap, op='none'), err_args),
|
93
|
+
jax.tree.map(functools.partial(unvmap, op='none'), err_kwargs),
|
94
|
+
)
|