brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163) hide show
  1. brainstate/__init__.py +169 -58
  2. brainstate/_compatible_import.py +340 -148
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +45 -55
  7. brainstate/_state.py +1652 -1605
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -563
  11. brainstate/environ_test.py +1223 -62
  12. brainstate/graph/__init__.py +22 -29
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1433 -365
  18. brainstate/mixin_test.py +1017 -77
  19. brainstate/nn/__init__.py +137 -135
  20. brainstate/nn/_activations.py +1100 -808
  21. brainstate/nn/_activations_test.py +354 -331
  22. brainstate/nn/_collective_ops.py +633 -514
  23. brainstate/nn/_collective_ops_test.py +774 -43
  24. brainstate/nn/_common.py +226 -178
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +2010 -501
  27. brainstate/nn/_conv_test.py +849 -238
  28. brainstate/nn/_delay.py +575 -588
  29. brainstate/nn/_delay_test.py +243 -238
  30. brainstate/nn/_dropout.py +618 -426
  31. brainstate/nn/_dropout_test.py +477 -100
  32. brainstate/nn/_dynamics.py +1267 -1343
  33. brainstate/nn/_dynamics_test.py +67 -78
  34. brainstate/nn/_elementwise.py +1298 -1119
  35. brainstate/nn/_elementwise_test.py +830 -169
  36. brainstate/nn/_embedding.py +408 -58
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
  42. brainstate/nn/_exp_euler.py +254 -92
  43. brainstate/nn/_exp_euler_test.py +377 -35
  44. brainstate/nn/_linear.py +744 -424
  45. brainstate/nn/_linear_test.py +475 -107
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +384 -377
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -975
  51. brainstate/nn/_normalizations_test.py +699 -73
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +2239 -1177
  55. brainstate/nn/_poolings_test.py +953 -217
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +216 -89
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +809 -553
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
  62. brainstate/random/__init__.py +270 -24
  63. brainstate/random/_rand_funs.py +3938 -3616
  64. brainstate/random/_rand_funs_test.py +640 -567
  65. brainstate/random/_rand_seed.py +675 -210
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1409
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
  72. brainstate/{augment → transform}/_autograd.py +1025 -778
  73. brainstate/{augment → transform}/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +220 -220
  76. brainstate/{compile → transform}/_error_if.py +94 -92
  77. brainstate/{compile → transform}/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +38 -38
  80. brainstate/{compile → transform}/_jit.py +399 -346
  81. brainstate/{compile → transform}/_jit_test.py +143 -143
  82. brainstate/{compile → transform}/_loop_collect_return.py +675 -536
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
  84. brainstate/{compile → transform}/_loop_no_collection.py +283 -184
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +255 -202
  91. brainstate/{augment → transform}/_random.py +171 -151
  92. brainstate/{compile → transform}/_unvmap.py +256 -159
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +837 -304
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +27 -50
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +945 -469
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +910 -523
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/augment/__init__.py +0 -30
  111. brainstate/augment/_eval_shape.py +0 -99
  112. brainstate/augment/_mapping.py +0 -1060
  113. brainstate/augment/_mapping_test.py +0 -597
  114. brainstate/compile/__init__.py +0 -38
  115. brainstate/compile/_ad_checkpoint.py +0 -204
  116. brainstate/compile/_conditions.py +0 -256
  117. brainstate/compile/_make_jaxpr.py +0 -888
  118. brainstate/compile/_make_jaxpr_test.py +0 -156
  119. brainstate/compile/_util.py +0 -147
  120. brainstate/functional/__init__.py +0 -27
  121. brainstate/graph/_graph_node.py +0 -244
  122. brainstate/graph/_graph_node_test.py +0 -73
  123. brainstate/graph/_graph_operation_test.py +0 -563
  124. brainstate/init/__init__.py +0 -26
  125. brainstate/init/_base.py +0 -52
  126. brainstate/init/_generic.py +0 -244
  127. brainstate/init/_regular_inits.py +0 -105
  128. brainstate/init/_regular_inits_test.py +0 -50
  129. brainstate/nn/_inputs.py +0 -608
  130. brainstate/nn/_ltp.py +0 -28
  131. brainstate/nn/_neuron.py +0 -705
  132. brainstate/nn/_neuron_test.py +0 -161
  133. brainstate/nn/_others.py +0 -46
  134. brainstate/nn/_projection.py +0 -486
  135. brainstate/nn/_rate_rnns_test.py +0 -63
  136. brainstate/nn/_readout.py +0 -209
  137. brainstate/nn/_readout_test.py +0 -53
  138. brainstate/nn/_stp.py +0 -236
  139. brainstate/nn/_synapse.py +0 -505
  140. brainstate/nn/_synapse_test.py +0 -131
  141. brainstate/nn/_synaptic_projection.py +0 -423
  142. brainstate/nn/_synouts.py +0 -162
  143. brainstate/nn/_synouts_test.py +0 -57
  144. brainstate/nn/metrics.py +0 -388
  145. brainstate/optim/__init__.py +0 -38
  146. brainstate/optim/_base.py +0 -64
  147. brainstate/optim/_lr_scheduler.py +0 -448
  148. brainstate/optim/_lr_scheduler_test.py +0 -50
  149. brainstate/optim/_optax_optimizer.py +0 -152
  150. brainstate/optim/_optax_optimizer_test.py +0 -53
  151. brainstate/optim/_sgd_optimizer.py +0 -1104
  152. brainstate/random/_random_for_unit.py +0 -52
  153. brainstate/surrogate.py +0 -1957
  154. brainstate/transform.py +0 -23
  155. brainstate/util/caller.py +0 -98
  156. brainstate/util/others.py +0 -540
  157. brainstate/util/pretty_pytree.py +0 -945
  158. brainstate/util/pretty_pytree_test.py +0 -159
  159. brainstate/util/pretty_table.py +0 -2954
  160. brainstate/util/scaling.py +0 -258
  161. brainstate-0.1.10.dist-info/RECORD +0 -130
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -1,220 +1,220 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- import unittest
17
-
18
- import jax
19
- import jax.numpy as jnp
20
-
21
- import brainstate
22
-
23
-
24
- class TestCond(unittest.TestCase):
25
- def test1(self):
26
- brainstate.random.seed(1)
27
- brainstate.compile.cond(True, lambda: brainstate.random.random(10), lambda: brainstate.random.random(10))
28
- brainstate.compile.cond(False, lambda: brainstate.random.random(10), lambda: brainstate.random.random(10))
29
-
30
- def test2(self):
31
- st1 = brainstate.State(brainstate.random.rand(10))
32
- st2 = brainstate.State(brainstate.random.rand(2))
33
- st3 = brainstate.State(brainstate.random.rand(5))
34
- st4 = brainstate.State(brainstate.random.rand(2, 10))
35
-
36
- def true_fun(x):
37
- st1.value = st2.value @ st4.value + x
38
-
39
- def false_fun(x):
40
- st3.value = (st3.value + 1.) * x
41
-
42
- brainstate.compile.cond(True, true_fun, false_fun, 2.)
43
- assert not isinstance(st1.value, jax.core.Tracer)
44
- assert not isinstance(st2.value, jax.core.Tracer)
45
- assert not isinstance(st3.value, jax.core.Tracer)
46
- assert not isinstance(st4.value, jax.core.Tracer)
47
-
48
-
49
- class TestSwitch(unittest.TestCase):
50
- def testSwitch(self):
51
- def branch(x):
52
- y = jax.lax.mul(2, x)
53
- return y, jax.lax.mul(2, y)
54
-
55
- branches = [lambda x: (x, x),
56
- branch,
57
- lambda x: (x, -x)]
58
-
59
- def fun(x):
60
- if x <= 0:
61
- return branches[0](x)
62
- elif x == 1:
63
- return branches[1](x)
64
- else:
65
- return branches[2](x)
66
-
67
- def cfun(x):
68
- return brainstate.compile.switch(x, branches, x)
69
-
70
- self.assertEqual(fun(-1), cfun(-1))
71
- self.assertEqual(fun(0), cfun(0))
72
- self.assertEqual(fun(1), cfun(1))
73
- self.assertEqual(fun(2), cfun(2))
74
- self.assertEqual(fun(3), cfun(3))
75
-
76
- cfun = jax.jit(cfun)
77
-
78
- self.assertEqual(fun(-1), cfun(-1))
79
- self.assertEqual(fun(0), cfun(0))
80
- self.assertEqual(fun(1), cfun(1))
81
- self.assertEqual(fun(2), cfun(2))
82
- self.assertEqual(fun(3), cfun(3))
83
-
84
- def testSwitchMultiOperands(self):
85
- branches = [jax.lax.add, jax.lax.mul]
86
-
87
- def fun(x):
88
- i = 0 if x <= 0 else 1
89
- return branches[i](x, x)
90
-
91
- def cfun(x):
92
- return brainstate.compile.switch(x, branches, x, x)
93
-
94
- self.assertEqual(fun(-1), cfun(-1))
95
- self.assertEqual(fun(0), cfun(0))
96
- self.assertEqual(fun(1), cfun(1))
97
- self.assertEqual(fun(2), cfun(2))
98
- cfun = jax.jit(cfun)
99
- self.assertEqual(fun(-1), cfun(-1))
100
- self.assertEqual(fun(0), cfun(0))
101
- self.assertEqual(fun(1), cfun(1))
102
- self.assertEqual(fun(2), cfun(2))
103
-
104
- def testSwitchResidualsMerge(self):
105
- def get_conds(fun):
106
- jaxpr = jax.make_jaxpr(jax.grad(fun))(0., 0)
107
- return [eqn for eqn in jaxpr.jaxpr.eqns if eqn.primitive.name == 'cond']
108
-
109
- def branch_invars_len(cond_eqn):
110
- lens = [len(jaxpr.jaxpr.invars) for jaxpr in cond_eqn.params['branches']]
111
- assert len(set(lens)) == 1
112
- return lens[0]
113
-
114
- def branch_outvars_len(cond_eqn):
115
- lens = [len(jaxpr.jaxpr.outvars) for jaxpr in cond_eqn.params['branches']]
116
- assert len(set(lens)) == 1
117
- return lens[0]
118
-
119
- branches1 = [lambda x: jnp.sin(x),
120
- lambda x: jnp.cos(x)] # branch residuals overlap, should be reused
121
- branches2 = branches1 + [lambda x: jnp.sinh(x)] # another overlapping residual, expect reuse
122
- branches3 = branches2 + [lambda x: jnp.sin(x) + jnp.cos(x)] # requires one more residual slot
123
-
124
- def fun1(x, i):
125
- return brainstate.compile.switch(i + 1, branches1, x)
126
-
127
- def fun2(x, i):
128
- return brainstate.compile.switch(i + 1, branches2, x)
129
-
130
- def fun3(x, i):
131
- return brainstate.compile.switch(i + 1, branches3, x)
132
-
133
- fwd1, bwd1 = get_conds(fun1)
134
- fwd2, bwd2 = get_conds(fun2)
135
- fwd3, bwd3 = get_conds(fun3)
136
-
137
- fwd1_num_out = branch_outvars_len(fwd1)
138
- fwd2_num_out = branch_outvars_len(fwd2)
139
- fwd3_num_out = branch_outvars_len(fwd3)
140
- assert fwd1_num_out == fwd2_num_out
141
- assert fwd3_num_out == fwd2_num_out + 1
142
-
143
- bwd1_num_in = branch_invars_len(bwd1)
144
- bwd2_num_in = branch_invars_len(bwd2)
145
- bwd3_num_in = branch_invars_len(bwd3)
146
- assert bwd1_num_in == bwd2_num_in
147
- assert bwd3_num_in == bwd2_num_in + 1
148
-
149
- def testOneBranchSwitch(self):
150
- branch = lambda x: -x
151
- f = lambda i, x: brainstate.compile.switch(i, [branch], x)
152
- x = 7.
153
- self.assertEqual(f(-1, x), branch(x))
154
- self.assertEqual(f(0, x), branch(x))
155
- self.assertEqual(f(1, x), branch(x))
156
- cf = jax.jit(f)
157
- self.assertEqual(cf(-1, x), branch(x))
158
- self.assertEqual(cf(0, x), branch(x))
159
- self.assertEqual(cf(1, x), branch(x))
160
- cf = jax.jit(f, static_argnums=0)
161
- self.assertEqual(cf(-1, x), branch(x))
162
- self.assertEqual(cf(0, x), branch(x))
163
- self.assertEqual(cf(1, x), branch(x))
164
-
165
-
166
- class TestIfElse(unittest.TestCase):
167
- def test1(self):
168
- def f(a):
169
- return brainstate.compile.ifelse(conditions=[a < 0,
170
- a >= 0 and a < 2,
171
- a >= 2 and a < 5,
172
- a >= 5 and a < 10,
173
- a >= 10],
174
- branches=[lambda: 1,
175
- lambda: 2,
176
- lambda: 3,
177
- lambda: 4,
178
- lambda: 5])
179
-
180
- self.assertTrue(f(3) == 3)
181
- self.assertTrue(f(1) == 2)
182
- self.assertTrue(f(-1) == 1)
183
-
184
- def test_vmap(self):
185
- def f(operands):
186
- f = lambda a: brainstate.compile.ifelse([a > 10,
187
- jnp.logical_and(a <= 10, a > 5),
188
- jnp.logical_and(a <= 5, a > 2),
189
- jnp.logical_and(a <= 2, a > 0),
190
- a <= 0],
191
- [lambda _: 1,
192
- lambda _: 2,
193
- lambda _: 3,
194
- lambda _: 4,
195
- lambda _: 5, ],
196
- a)
197
- return jax.vmap(f)(operands)
198
-
199
- r = f(brainstate.random.randint(-20, 20, 200))
200
- self.assertTrue(r.size == 200)
201
-
202
- def test_grad1(self):
203
- def F2(x):
204
- return brainstate.compile.ifelse((x >= 10, x < 10),
205
- [lambda x: x, lambda x: x ** 2, ],
206
- x)
207
-
208
- self.assertTrue(jax.grad(F2)(9.0) == 18.)
209
- self.assertTrue(jax.grad(F2)(11.0) == 1.)
210
-
211
- def test_grad2(self):
212
- def F3(x):
213
- return brainstate.compile.ifelse((x >= 10, jnp.logical_and(x >= 0, x < 10), x < 0),
214
- [lambda x: x,
215
- lambda x: x ** 2,
216
- lambda x: x ** 4, ],
217
- x)
218
-
219
- self.assertTrue(jax.grad(F3)(9.0) == 18.)
220
- self.assertTrue(jax.grad(F3)(11.0) == 1.)
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import unittest
17
+
18
+ import jax
19
+ import jax.numpy as jnp
20
+
21
+ import brainstate
22
+
23
+
24
+ class TestCond(unittest.TestCase):
25
+ def test1(self):
26
+ brainstate.random.seed(1)
27
+ brainstate.compile.cond(True, lambda: brainstate.random.random(10), lambda: brainstate.random.random(10))
28
+ brainstate.compile.cond(False, lambda: brainstate.random.random(10), lambda: brainstate.random.random(10))
29
+
30
+ def test2(self):
31
+ st1 = brainstate.State(brainstate.random.rand(10))
32
+ st2 = brainstate.State(brainstate.random.rand(2))
33
+ st3 = brainstate.State(brainstate.random.rand(5))
34
+ st4 = brainstate.State(brainstate.random.rand(2, 10))
35
+
36
+ def true_fun(x):
37
+ st1.value = st2.value @ st4.value + x
38
+
39
+ def false_fun(x):
40
+ st3.value = (st3.value + 1.) * x
41
+
42
+ brainstate.compile.cond(True, true_fun, false_fun, 2.)
43
+ assert not isinstance(st1.value, jax.core.Tracer)
44
+ assert not isinstance(st2.value, jax.core.Tracer)
45
+ assert not isinstance(st3.value, jax.core.Tracer)
46
+ assert not isinstance(st4.value, jax.core.Tracer)
47
+
48
+
49
+ class TestSwitch(unittest.TestCase):
50
+ def testSwitch(self):
51
+ def branch(x):
52
+ y = jax.lax.mul(2, x)
53
+ return y, jax.lax.mul(2, y)
54
+
55
+ branches = [lambda x: (x, x),
56
+ branch,
57
+ lambda x: (x, -x)]
58
+
59
+ def fun(x):
60
+ if x <= 0:
61
+ return branches[0](x)
62
+ elif x == 1:
63
+ return branches[1](x)
64
+ else:
65
+ return branches[2](x)
66
+
67
+ def cfun(x):
68
+ return brainstate.compile.switch(x, branches, x)
69
+
70
+ self.assertEqual(fun(-1), cfun(-1))
71
+ self.assertEqual(fun(0), cfun(0))
72
+ self.assertEqual(fun(1), cfun(1))
73
+ self.assertEqual(fun(2), cfun(2))
74
+ self.assertEqual(fun(3), cfun(3))
75
+
76
+ cfun = jax.jit(cfun)
77
+
78
+ self.assertEqual(fun(-1), cfun(-1))
79
+ self.assertEqual(fun(0), cfun(0))
80
+ self.assertEqual(fun(1), cfun(1))
81
+ self.assertEqual(fun(2), cfun(2))
82
+ self.assertEqual(fun(3), cfun(3))
83
+
84
+ def testSwitchMultiOperands(self):
85
+ branches = [jax.lax.add, jax.lax.mul]
86
+
87
+ def fun(x):
88
+ i = 0 if x <= 0 else 1
89
+ return branches[i](x, x)
90
+
91
+ def cfun(x):
92
+ return brainstate.compile.switch(x, branches, x, x)
93
+
94
+ self.assertEqual(fun(-1), cfun(-1))
95
+ self.assertEqual(fun(0), cfun(0))
96
+ self.assertEqual(fun(1), cfun(1))
97
+ self.assertEqual(fun(2), cfun(2))
98
+ cfun = jax.jit(cfun)
99
+ self.assertEqual(fun(-1), cfun(-1))
100
+ self.assertEqual(fun(0), cfun(0))
101
+ self.assertEqual(fun(1), cfun(1))
102
+ self.assertEqual(fun(2), cfun(2))
103
+
104
+ def testSwitchResidualsMerge(self):
105
+ def get_conds(fun):
106
+ jaxpr = jax.make_jaxpr(jax.grad(fun))(0., 0)
107
+ return [eqn for eqn in jaxpr.jaxpr.eqns if eqn.primitive.name == 'cond']
108
+
109
+ def branch_invars_len(cond_eqn):
110
+ lens = [len(jaxpr.jaxpr.invars) for jaxpr in cond_eqn.params['branches']]
111
+ assert len(set(lens)) == 1
112
+ return lens[0]
113
+
114
+ def branch_outvars_len(cond_eqn):
115
+ lens = [len(jaxpr.jaxpr.outvars) for jaxpr in cond_eqn.params['branches']]
116
+ assert len(set(lens)) == 1
117
+ return lens[0]
118
+
119
+ branches1 = [lambda x: jnp.sin(x),
120
+ lambda x: jnp.cos(x)] # branch residuals overlap, should be reused
121
+ branches2 = branches1 + [lambda x: jnp.sinh(x)] # another overlapping residual, expect reuse
122
+ branches3 = branches2 + [lambda x: jnp.sin(x) + jnp.cos(x)] # requires one more residual slot
123
+
124
+ def fun1(x, i):
125
+ return brainstate.compile.switch(i + 1, branches1, x)
126
+
127
+ def fun2(x, i):
128
+ return brainstate.compile.switch(i + 1, branches2, x)
129
+
130
+ def fun3(x, i):
131
+ return brainstate.compile.switch(i + 1, branches3, x)
132
+
133
+ fwd1, bwd1 = get_conds(fun1)
134
+ fwd2, bwd2 = get_conds(fun2)
135
+ fwd3, bwd3 = get_conds(fun3)
136
+
137
+ fwd1_num_out = branch_outvars_len(fwd1)
138
+ fwd2_num_out = branch_outvars_len(fwd2)
139
+ fwd3_num_out = branch_outvars_len(fwd3)
140
+ assert fwd1_num_out == fwd2_num_out
141
+ assert fwd3_num_out == fwd2_num_out + 1
142
+
143
+ bwd1_num_in = branch_invars_len(bwd1)
144
+ bwd2_num_in = branch_invars_len(bwd2)
145
+ bwd3_num_in = branch_invars_len(bwd3)
146
+ assert bwd1_num_in == bwd2_num_in
147
+ assert bwd3_num_in == bwd2_num_in + 1
148
+
149
+ def testOneBranchSwitch(self):
150
+ branch = lambda x: -x
151
+ f = lambda i, x: brainstate.compile.switch(i, [branch], x)
152
+ x = 7.
153
+ self.assertEqual(f(-1, x), branch(x))
154
+ self.assertEqual(f(0, x), branch(x))
155
+ self.assertEqual(f(1, x), branch(x))
156
+ cf = jax.jit(f)
157
+ self.assertEqual(cf(-1, x), branch(x))
158
+ self.assertEqual(cf(0, x), branch(x))
159
+ self.assertEqual(cf(1, x), branch(x))
160
+ cf = jax.jit(f, static_argnums=0)
161
+ self.assertEqual(cf(-1, x), branch(x))
162
+ self.assertEqual(cf(0, x), branch(x))
163
+ self.assertEqual(cf(1, x), branch(x))
164
+
165
+
166
+ class TestIfElse(unittest.TestCase):
167
+ def test1(self):
168
+ def f(a):
169
+ return brainstate.compile.ifelse(conditions=[a < 0,
170
+ a >= 0 and a < 2,
171
+ a >= 2 and a < 5,
172
+ a >= 5 and a < 10,
173
+ a >= 10],
174
+ branches=[lambda: 1,
175
+ lambda: 2,
176
+ lambda: 3,
177
+ lambda: 4,
178
+ lambda: 5])
179
+
180
+ self.assertTrue(f(3) == 3)
181
+ self.assertTrue(f(1) == 2)
182
+ self.assertTrue(f(-1) == 1)
183
+
184
+ def test_vmap(self):
185
+ def f(operands):
186
+ f = lambda a: brainstate.compile.ifelse([a > 10,
187
+ jnp.logical_and(a <= 10, a > 5),
188
+ jnp.logical_and(a <= 5, a > 2),
189
+ jnp.logical_and(a <= 2, a > 0),
190
+ a <= 0],
191
+ [lambda _: 1,
192
+ lambda _: 2,
193
+ lambda _: 3,
194
+ lambda _: 4,
195
+ lambda _: 5, ],
196
+ a)
197
+ return jax.vmap(f)(operands)
198
+
199
+ r = f(brainstate.random.randint(-20, 20, 200))
200
+ self.assertTrue(r.size == 200)
201
+
202
+ def test_grad1(self):
203
+ def F2(x):
204
+ return brainstate.compile.ifelse((x >= 10, x < 10),
205
+ [lambda x: x, lambda x: x ** 2, ],
206
+ x)
207
+
208
+ self.assertTrue(jax.grad(F2)(9.0) == 18.)
209
+ self.assertTrue(jax.grad(F2)(11.0) == 1.)
210
+
211
+ def test_grad2(self):
212
+ def F3(x):
213
+ return brainstate.compile.ifelse((x >= 10, jnp.logical_and(x >= 0, x < 10), x < 0),
214
+ [lambda x: x,
215
+ lambda x: x ** 2,
216
+ lambda x: x ** 4, ],
217
+ x)
218
+
219
+ self.assertTrue(jax.grad(F3)(9.0) == 18.)
220
+ self.assertTrue(jax.grad(F3)(11.0) == 1.)
@@ -1,92 +1,94 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- import functools
17
- from functools import partial
18
- from typing import Callable, Union
19
-
20
- import jax
21
-
22
- from brainstate._utils import set_module_as
23
- from ._unvmap import unvmap
24
-
25
- __all__ = [
26
- 'jit_error_if',
27
- ]
28
-
29
-
30
- def _err_jit_true_branch(err_fun, args, kwargs):
31
- jax.debug.callback(err_fun, *args, **kwargs)
32
-
33
-
34
- def _err_jit_false_branch(args, kwargs):
35
- pass
36
-
37
-
38
- def _error_msg(msg, *arg, **kwargs):
39
- if len(arg):
40
- msg = msg % arg
41
- if len(kwargs):
42
- msg = msg.format(**kwargs)
43
- raise ValueError(msg)
44
-
45
-
46
- @set_module_as('brainstate.compile')
47
- def jit_error_if(
48
- pred,
49
- error: Union[Callable, str],
50
- *err_args,
51
- **err_kwargs,
52
- ):
53
- """
54
- Check errors in a jit function.
55
-
56
- Examples
57
- --------
58
-
59
- It can give a function which receive arguments that passed from the JIT variables and raise errors.
60
-
61
- >>> def error(x):
62
- >>> raise ValueError(f'error {x}')
63
- >>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
64
- >>> jit_error_if(x.sum() < 5., error, x)
65
-
66
- Or, it can be a simple string message.
67
-
68
- >>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
69
- >>> jit_error_if(x.sum() < 5., "Error: the sum is less than 5. Got {s}", s=x.sum())
70
-
71
-
72
- Parameters
73
- ----------
74
- pred: bool, Array
75
- The boolean prediction.
76
- error: callable, str
77
- The error function, which raise errors, or a string indicating the error message.
78
- err_args:
79
- The arguments which passed into `err_f`.
80
- err_kwargs:
81
- The keywords which passed into `err_f`.
82
- """
83
- if isinstance(error, str):
84
- error = partial(_error_msg, error)
85
-
86
- jax.lax.cond(
87
- unvmap(pred, op='any'),
88
- partial(_err_jit_true_branch, error),
89
- _err_jit_false_branch,
90
- jax.tree.map(functools.partial(unvmap, op='none'), err_args),
91
- jax.tree.map(functools.partial(unvmap, op='none'), err_kwargs),
92
- )
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import functools
17
+ from functools import partial
18
+ from typing import Callable, Union
19
+
20
+ import jax
21
+
22
+ from brainstate._utils import set_module_as
23
+ from ._unvmap import unvmap
24
+
25
+ __all__ = [
26
+ 'jit_error_if',
27
+ ]
28
+
29
+
30
+ def _err_jit_true_branch(err_fun, args, kwargs):
31
+ jax.debug.callback(err_fun, *args, **kwargs)
32
+
33
+
34
+ def _err_jit_false_branch(args, kwargs):
35
+ pass
36
+
37
+
38
+ def _error_msg(msg, *arg, **kwargs):
39
+ if len(arg):
40
+ msg = msg % arg
41
+ if len(kwargs):
42
+ msg = msg.format(**kwargs)
43
+ raise ValueError(msg)
44
+
45
+
46
+ @set_module_as('brainstate.transform')
47
+ def jit_error_if(
48
+ pred,
49
+ error: Union[Callable, str],
50
+ *err_args,
51
+ **err_kwargs,
52
+ ):
53
+ """
54
+ Check errors in a jit function.
55
+
56
+ Parameters
57
+ ----------
58
+ pred : bool or Array
59
+ The boolean prediction.
60
+ error : callable or str
61
+ The error function, which raise errors, or a string indicating the error message.
62
+ *err_args
63
+ The arguments which passed into the error function.
64
+ **err_kwargs
65
+ The keywords which passed into the error function.
66
+
67
+ Examples
68
+ --------
69
+ It can give a function which receive arguments that passed from the JIT variables and raise errors.
70
+
71
+ .. code-block:: python
72
+
73
+ >>> def error(x):
74
+ ... raise ValueError(f'error {x}')
75
+ >>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
76
+ >>> jit_error_if(x.sum() < 5., error, x)
77
+
78
+ Or, it can be a simple string message.
79
+
80
+ .. code-block:: python
81
+
82
+ >>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
83
+ >>> jit_error_if(x.sum() < 5., "Error: the sum is less than 5. Got {s}", s=x.sum())
84
+ """
85
+ if isinstance(error, str):
86
+ error = partial(_error_msg, error)
87
+
88
+ jax.lax.cond(
89
+ unvmap(pred, op='any'),
90
+ partial(_err_jit_true_branch, error),
91
+ _err_jit_false_branch,
92
+ jax.tree.map(functools.partial(unvmap, op='none'), err_args),
93
+ jax.tree.map(functools.partial(unvmap, op='none'), err_kwargs),
94
+ )