brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (175)
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1356 -1321
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,259 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from __future__ import annotations
17
+
18
+ from collections.abc import Callable, Sequence
19
+
20
+ import jax
21
+ import jax.numpy as jnp
22
+ import numpy as np
23
+
24
+ from brainstate._utils import set_module_as
25
+ from ._error_if import jit_error_if
26
+ from ._make_jaxpr import StatefulFunction
27
+ from ._util import wrap_single_fun_in_multi_branches, write_back_state_values
28
+
29
# Public names exported by this module.
__all__ = [
    'cond', 'switch', 'ifelse',
]
32
+
33
+
34
@set_module_as('brainstate.compile')
def cond(pred, true_fun: Callable, false_fun: Callable, *operands):
    """
    Conditionally apply ``true_fun`` or ``false_fun``.

    Provided arguments are correctly typed, ``cond()`` has equivalent
    semantics to this Python implementation, where ``pred`` must be a
    scalar type::

      def cond(pred, true_fun, false_fun, *operands):
        if pred:
          return true_fun(*operands)
        else:
          return false_fun(*operands)


    In contrast with :func:`jax.lax.select`, using ``cond`` indicates that only one of
    the two branches is executed (up to compiler rewrites and optimizations).
    However, when transformed with :func:`~jax.vmap` to operate over a batch of
    predicates, ``cond`` is converted to :func:`~jax.lax.select`.

    State values written by either branch are written back after the
    conditional; states read by the branches are restored.

    Args:
      pred: Boolean scalar type, indicating which branch function to apply.
        Integer and floating scalars are accepted and treated as ``pred != 0``.
      true_fun: Function (A -> B), to be applied if ``pred`` is True.
      false_fun: Function (A -> B), to be applied if ``pred`` is False.
      operands: Operands (A) input to either branch depending on ``pred``. The
        type can be a scalar, array, or any pytree (nested Python tuple/list/dict)
        thereof.

    Returns:
      Value (B) of either ``true_fun(*operands)`` or ``false_fun(*operands)``,
      depending on the value of ``pred``. The type can be a scalar, array, or any
      pytree (nested Python tuple/list/dict) thereof.

    Raises:
      TypeError: If a branch is not callable, or if ``pred`` is ``None``, is not
        a scalar, or does not have a boolean/integer/floating dtype.
    """
    if not (callable(true_fun) and callable(false_fun)):
        raise TypeError("true_fun and false_fun arguments should be callable.")

    if pred is None:
        raise TypeError("cond predicate is None")
    if isinstance(pred, Sequence) or np.ndim(pred) != 0:
        raise TypeError(f"Pred must be a scalar, got {pred} of " +
                        (f"type {type(pred)}" if isinstance(pred, Sequence)
                         else f"shape {np.shape(pred)}."))

    # check pred
    try:
        pred_dtype = jax.dtypes.result_type(pred)
    except TypeError as err:
        raise TypeError("Pred type must be either boolean or number, got {}.".format(pred)) from err
    if pred_dtype.kind != 'b':
        if pred_dtype.kind in 'iuf':
            # Numeric predicates follow Python truthiness.
            pred = pred != 0
        else:
            raise TypeError("Pred type must be either boolean or number, got {}.".format(pred_dtype))

    # Eager path: with JIT disabled and a concrete predicate, just call the
    # selected branch in Python. Test for tracers directly instead of using
    # ``jax.core.ConcreteArray``, which recent JAX versions no longer provide.
    if jax.config.jax_disable_jit and not isinstance(pred, jax.core.Tracer):
        if pred:
            return true_fun(*operands)
        else:
            return false_fun(*operands)

    # evaluate jaxpr: trace both branches to discover the states they touch
    with jax.ensure_compile_time_eval():
        stateful_true = StatefulFunction(true_fun).make_jaxpr(*operands)
        stateful_false = StatefulFunction(false_fun).make_jaxpr(*operands)

    # state trace and state values (union of both branches)
    state_trace = stateful_true.get_state_trace() + stateful_false.get_state_trace()
    read_state_vals = state_trace.get_read_state_values(True)
    write_state_vals = state_trace.get_write_state_values(True)

    # wrap the functions so both branches share one state signature
    true_fun = wrap_single_fun_in_multi_branches(stateful_true, state_trace, read_state_vals, True)
    false_fun = wrap_single_fun_in_multi_branches(stateful_false, state_trace, read_state_vals, True)

    # cond
    write_state_vals, out = jax.lax.cond(pred, true_fun, false_fun, write_state_vals, *operands)

    # assign the written state values and restore the read state values
    write_back_state_values(state_trace, read_state_vals, write_state_vals)
    return out
116
+
117
+
118
@set_module_as('brainstate.compile')
def switch(index, branches: Sequence[Callable], *operands):
    """
    Apply exactly one of ``branches`` given by ``index``.

    If ``index`` is out of bounds, it is clamped to within bounds.

    Has the semantics of the following Python::

      def switch(index, branches, *operands):
        index = clamp(0, index, len(branches) - 1)
        return branches[index](*operands)

    Internally this wraps XLA's `Conditional
    <https://www.tensorflow.org/xla/operation_semantics#conditional>`_
    operator. However, when transformed with :func:`~jax.vmap` to operate over a
    batch of indices, ``switch`` is converted to :func:`~jax.lax.select`.

    Args:
      index: Integer scalar type, indicating which branch function to apply.
      branches: Sequence of functions (A -> B) to be applied based on ``index``.
        Any iterable of callables is accepted; it is materialized into a tuple.
      operands: Operands (A) input to whichever branch is applied.

    Returns:
      Value (B) of ``branch(*operands)`` for the branch that was selected based
      on ``index``.

    Raises:
      TypeError: If a branch is not callable, or ``index`` is not an integer scalar.
      ValueError: If ``branches`` is empty.
    """
    # Materialize first: checking callability on a one-shot iterable (e.g. a
    # generator) would otherwise exhaust it before it is converted to a tuple.
    branches = tuple(branches)

    # check branches
    if not all(callable(branch) for branch in branches):
        raise TypeError("branches argument should be a sequence of callables.")

    # check index
    if len(np.shape(index)) != 0:
        raise TypeError(f"Branch index must be scalar, got {index} of shape {np.shape(index)}.")
    try:
        index_dtype = jax.dtypes.result_type(index)
    except TypeError as err:
        msg = f"Index type must be an integer, got {index}."
        raise TypeError(msg) from err
    if index_dtype.kind not in 'iu':
        raise TypeError(f"Index type must be an integer, got {index} as {index_dtype}")

    # format branches
    if len(branches) == 0:
        raise ValueError("Empty branch sequence")
    elif len(branches) == 1:
        return branches[0](*operands)

    # format index: clamp into [0, len(branches) - 1]
    index = jax.lax.convert_element_type(index, np.int32)
    lo = np.array(0, np.int32)
    hi = np.array(len(branches) - 1, np.int32)
    index = jax.lax.clamp(lo, index, hi)

    # Eager path: with JIT disabled and a concrete index, call the branch in
    # Python. A direct tracer check replaces ``jax.core.ConcreteArray`` (and the
    # inconsistent ``jax.core.core.get_aval`` lookup), neither of which is
    # reliably available across JAX versions.
    if jax.config.jax_disable_jit and not isinstance(index, jax.core.Tracer):
        return branches[int(index)](*operands)

    # evaluate jaxpr: trace every branch to discover the states they touch
    with jax.ensure_compile_time_eval():
        wrapped_branches = [StatefulFunction(branch) for branch in branches]
        for wrapped_branch in wrapped_branches:
            wrapped_branch.make_jaxpr(*operands)

    # wrap the functions so all branches share one state signature
    state_trace = wrapped_branches[0].get_state_trace() + wrapped_branches[1].get_state_trace()
    state_trace.merge(*[wrapped_branch.get_state_trace() for wrapped_branch in wrapped_branches[2:]])
    read_state_vals = state_trace.get_read_state_values(True)
    write_state_vals = state_trace.get_write_state_values(True)
    branches = [
        wrap_single_fun_in_multi_branches(wrapped_branch, state_trace, read_state_vals, True)
        for wrapped_branch in wrapped_branches
    ]

    # switch
    write_state_vals, out = jax.lax.switch(index, branches, write_state_vals, *operands)

    # write back state values or restore them
    write_back_state_values(state_trace, read_state_vals, write_state_vals)
    return out
199
+
200
+
201
@set_module_as('brainstate.compile')
def ifelse(conditions, branches, *operands, check_cond: bool = True):
    """
    ``If-else`` control flow that reads like native Pythonic programming.

    The branch whose condition is True is selected and called as
    ``branch(*operands)``.

    Examples
    --------

    >>> import brainstate as bst
    >>> import jax.numpy as jnp
    >>> def f(a):
    ...     return bst.compile.ifelse(
    ...         conditions=[a > 10,
    ...                     jnp.logical_and(a <= 10, a > 5),
    ...                     jnp.logical_and(a <= 5, a > 2),
    ...                     jnp.logical_and(a <= 2, a > 0),
    ...                     a <= 0],
    ...         branches=[lambda: 1,
    ...                   lambda: 2,
    ...                   lambda: 3,
    ...                   lambda: 4,
    ...                   lambda: 5])
    >>> int(f(1))
    4
    >>> int(f(0))
    5

    Parameters
    ----------
    conditions: sequence of bool, Array
        The boolean conditions, one per branch. Exactly one of them is
        expected to be True.
    branches: sequence of callable
        The branch functions. Every element must be callable, and
        ``len(branches) == len(conditions)`` must hold. If there is only one
        branch, it is called directly and ``conditions`` is not inspected.
    *operands: optional, Any
        The operands passed to the selected branch, as ``branch(*operands)``.
    check_cond: bool
        Whether to check at runtime that exactly one condition is True.
        Default is True. When disabled, the first True condition wins, and the
        last branch is used if no condition is True.

    Returns
    -------
    res: Any
        The results of the control flow.

    Raises
    ------
    TypeError
        If a branch is not callable.
    ValueError
        If ``branches`` is empty or its length differs from ``conditions``.
    """
    # Materialize first so a one-shot iterable is not exhausted by the
    # callability check below.
    branches = tuple(branches)

    # check branches
    if not all(callable(branch) for branch in branches):
        raise TypeError("branches argument should be a sequence of callables.")

    # format branches
    if len(branches) == 0:
        raise ValueError("Empty branch sequence")
    elif len(branches) == 1:
        return branches[0](*operands)
    if len(conditions) != len(branches):
        raise ValueError("The number of conditions should be equal to the number of branches.")

    # format index
    conditions = jnp.asarray(conditions, np.int32)
    if check_cond:
        # ``err_arg`` is passed as a keyword, and the string message is rendered
        # with ``str.format(**kwargs)``; it therefore needs a named placeholder
        # (a bare ``{}`` would raise IndexError instead of the intended error).
        jit_error_if(jnp.sum(conditions) != 1,
                     "Only one condition can be True. But got {err_arg}.",
                     err_arg=conditions)
    # Index of the first True condition; falls back to the last branch if none.
    index = jnp.where(conditions, size=1, fill_value=len(conditions) - 1)[0][0]
    return switch(index, branches, *operands)
@@ -0,0 +1,221 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from __future__ import annotations
16
+
17
+ import unittest
18
+
19
+ import jax
20
+ import jax.numpy as jnp
21
+
22
+ import brainstate as bst
23
+
24
+
25
class TestCond(unittest.TestCase):
    """Tests for ``bst.compile.cond``."""

    def test1(self):
        # Branches that draw random numbers must work for both predicate values.
        bst.random.seed(1)
        bst.compile.cond(True, lambda: bst.random.random(10), lambda: bst.random.random(10))
        bst.compile.cond(False, lambda: bst.random.random(10), lambda: bst.random.random(10))

    def test2(self):
        st1 = bst.State(bst.random.rand(10))
        st2 = bst.State(bst.random.rand(2))
        st3 = bst.State(bst.random.rand(5))
        st4 = bst.State(bst.random.rand(2, 10))

        def true_fun(x):
            st1.value = st2.value @ st4.value + x

        def false_fun(x):
            st3.value = (st3.value + 1.) * x

        # Reference values computed before the conditional runs.
        expected_st1 = st2.value @ st4.value + 2.
        st3_before = st3.value

        bst.compile.cond(True, true_fun, false_fun, 2.)

        # No tracer may leak out of ``cond`` into the states.
        assert not isinstance(st1.value, jax.core.Tracer)
        assert not isinstance(st2.value, jax.core.Tracer)
        assert not isinstance(st3.value, jax.core.Tracer)
        assert not isinstance(st4.value, jax.core.Tracer)

        # Only the selected branch's writes take effect.
        self.assertTrue(jnp.allclose(st1.value, expected_st1))
        self.assertTrue(jnp.allclose(st3.value, st3_before))
48
+
49
+
50
class TestSwitch(unittest.TestCase):
    """Tests for ``bst.compile.switch``."""

    def testSwitch(self):
        # ``fun`` is a plain-Python reference that clamps the index into
        # [0, 2] by hand; ``cfun`` uses ``switch`` with ``x`` as both the
        # index and the operand.
        def branch(x):
            y = jax.lax.mul(2, x)
            return y, jax.lax.mul(2, y)

        branches = [lambda x: (x, x),
                    branch,
                    lambda x: (x, -x)]

        def fun(x):
            if x <= 0:
                return branches[0](x)
            elif x == 1:
                return branches[1](x)
            else:
                return branches[2](x)

        def cfun(x):
            return bst.compile.switch(x, branches, x)

        # -1 and 3 are out of range and exercise index clamping.
        self.assertEqual(fun(-1), cfun(-1))
        self.assertEqual(fun(0), cfun(0))
        self.assertEqual(fun(1), cfun(1))
        self.assertEqual(fun(2), cfun(2))
        self.assertEqual(fun(3), cfun(3))

        # Repeat with a traced index under ``jax.jit``.
        cfun = jax.jit(cfun)

        self.assertEqual(fun(-1), cfun(-1))
        self.assertEqual(fun(0), cfun(0))
        self.assertEqual(fun(1), cfun(1))
        self.assertEqual(fun(2), cfun(2))
        self.assertEqual(fun(3), cfun(3))

    def testSwitchMultiOperands(self):
        # Branches take two operands; ``switch`` forwards all of ``*operands``.
        branches = [jax.lax.add, jax.lax.mul]

        def fun(x):
            i = 0 if x <= 0 else 1
            return branches[i](x, x)

        def cfun(x):
            return bst.compile.switch(x, branches, x, x)

        self.assertEqual(fun(-1), cfun(-1))
        self.assertEqual(fun(0), cfun(0))
        self.assertEqual(fun(1), cfun(1))
        self.assertEqual(fun(2), cfun(2))
        cfun = jax.jit(cfun)
        self.assertEqual(fun(-1), cfun(-1))
        self.assertEqual(fun(0), cfun(0))
        self.assertEqual(fun(1), cfun(1))
        self.assertEqual(fun(2), cfun(2))

    def testSwitchResidualsMerge(self):
        # Differentiating through ``switch`` produces ``cond`` equations in the
        # gradient jaxpr; this test inspects their input/output counts.
        def get_conds(fun):
            jaxpr = jax.make_jaxpr(jax.grad(fun))(0., 0)
            return [eqn for eqn in jaxpr.jaxpr.eqns if eqn.primitive.name == 'cond']

        def branch_invars_len(cond_eqn):
            # All branches of one ``cond`` must agree on their input count.
            lens = [len(jaxpr.jaxpr.invars) for jaxpr in cond_eqn.params['branches']]
            assert len(set(lens)) == 1
            return lens[0]

        def branch_outvars_len(cond_eqn):
            # All branches of one ``cond`` must agree on their output count.
            lens = [len(jaxpr.jaxpr.outvars) for jaxpr in cond_eqn.params['branches']]
            assert len(set(lens)) == 1
            return lens[0]

        branches1 = [lambda x: jnp.sin(x),
                     lambda x: jnp.cos(x)]  # branch residuals overlap, should be reused
        branches2 = branches1 + [lambda x: jnp.sinh(x)]  # another overlapping residual, expect reuse
        branches3 = branches2 + [lambda x: jnp.sin(x) + jnp.cos(x)]  # requires one more residual slot

        # ``i + 1`` with i = 0 selects branch 1.
        def fun1(x, i):
            return bst.compile.switch(i + 1, branches1, x)

        def fun2(x, i):
            return bst.compile.switch(i + 1, branches2, x)

        def fun3(x, i):
            return bst.compile.switch(i + 1, branches3, x)

        # The unpacking requires exactly two ``cond`` equations per gradient
        # jaxpr, named forward (fwd) and backward (bwd) here.
        fwd1, bwd1 = get_conds(fun1)
        fwd2, bwd2 = get_conds(fun2)
        fwd3, bwd3 = get_conds(fun3)

        # Adding an overlapping branch keeps the slot count; branches3 adds one.
        fwd1_num_out = branch_outvars_len(fwd1)
        fwd2_num_out = branch_outvars_len(fwd2)
        fwd3_num_out = branch_outvars_len(fwd3)
        assert fwd1_num_out == fwd2_num_out
        assert fwd3_num_out == fwd2_num_out + 1

        bwd1_num_in = branch_invars_len(bwd1)
        bwd2_num_in = branch_invars_len(bwd2)
        bwd3_num_in = branch_invars_len(bwd3)
        assert bwd1_num_in == bwd2_num_in
        assert bwd3_num_in == bwd2_num_in + 1

    def testOneBranchSwitch(self):
        # With a single branch the index value is irrelevant, whether it is a
        # Python int, traced under ``jit``, or passed as a static argument.
        branch = lambda x: -x
        f = lambda i, x: bst.compile.switch(i, [branch], x)
        x = 7.
        self.assertEqual(f(-1, x), branch(x))
        self.assertEqual(f(0, x), branch(x))
        self.assertEqual(f(1, x), branch(x))
        cf = jax.jit(f)
        self.assertEqual(cf(-1, x), branch(x))
        self.assertEqual(cf(0, x), branch(x))
        self.assertEqual(cf(1, x), branch(x))
        cf = jax.jit(f, static_argnums=0)
        self.assertEqual(cf(-1, x), branch(x))
        self.assertEqual(cf(0, x), branch(x))
        self.assertEqual(cf(1, x), branch(x))
165
+
166
+
167
class TestIfElse(unittest.TestCase):
    """Tests for ``bst.compile.ifelse``."""

    def test1(self):
        def f(a):
            return bst.compile.ifelse(conditions=[a < 0,
                                                  a >= 0 and a < 2,
                                                  a >= 2 and a < 5,
                                                  a >= 5 and a < 10,
                                                  a >= 10],
                                      branches=[lambda: 1,
                                                lambda: 2,
                                                lambda: 3,
                                                lambda: 4,
                                                lambda: 5])

        self.assertTrue(f(3) == 3)
        self.assertTrue(f(1) == 2)
        self.assertTrue(f(-1) == 1)
        # Cover the remaining branches and the interval boundaries.
        self.assertTrue(f(0) == 2)
        self.assertTrue(f(7) == 4)
        self.assertTrue(f(10) == 5)

    def test_vmap(self):
        def f(operands):
            g = lambda a: bst.compile.ifelse([a > 10,
                                              jnp.logical_and(a <= 10, a > 5),
                                              jnp.logical_and(a <= 5, a > 2),
                                              jnp.logical_and(a <= 2, a > 0),
                                              a <= 0],
                                             [lambda _: 1,
                                              lambda _: 2,
                                              lambda _: 3,
                                              lambda _: 4,
                                              lambda _: 5, ],
                                             a)
            return jax.vmap(g)(operands)

        x = bst.random.randint(-20, 20, 200)
        r = f(x)
        self.assertTrue(r.size == 200)
        # Element-wise reference built from nested ``where`` selections.
        expected = jnp.where(x > 10, 1,
                             jnp.where(x > 5, 2,
                                       jnp.where(x > 2, 3,
                                                 jnp.where(x > 0, 4, 5))))
        self.assertTrue(bool(jnp.all(r == expected)))

    def test_grad1(self):
        def F2(x):
            return bst.compile.ifelse((x >= 10, x < 10),
                                      [lambda x: x, lambda x: x ** 2, ],
                                      x)

        self.assertTrue(jax.grad(F2)(9.0) == 18.)
        self.assertTrue(jax.grad(F2)(11.0) == 1.)

    def test_grad2(self):
        def F3(x):
            return bst.compile.ifelse((x >= 10, jnp.logical_and(x >= 0, x < 10), x < 0),
                                      [lambda x: x,
                                       lambda x: x ** 2,
                                       lambda x: x ** 4, ],
                                      x)

        self.assertTrue(jax.grad(F3)(9.0) == 18.)
        self.assertTrue(jax.grad(F3)(11.0) == 1.)
        # The x < 0 branch: d/dx x**4 = 4 x**3 = -4 at x = -1.
        self.assertTrue(jax.grad(F3)(-1.0) == -4.)
@@ -0,0 +1,94 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from __future__ import annotations
17
+
18
+ import functools
19
+ from functools import partial
20
+ from typing import Callable, Union
21
+
22
+ import jax
23
+
24
+ from brainstate._utils import set_module_as
25
+ from ._unvmap import unvmap
26
+
27
# Public names exported by this module.
__all__ = [
    'jit_error_if',
]
30
+
31
+
32
+ def _err_jit_true_branch(err_fun, args, kwargs):
33
+ jax.debug.callback(err_fun, *args, **kwargs)
34
+
35
+
36
+ def _err_jit_false_branch(args, kwargs):
37
+ pass
38
+
39
+
40
+ def _error_msg(msg, *arg, **kwargs):
41
+ if len(arg):
42
+ msg = msg % arg
43
+ if len(kwargs):
44
+ msg = msg.format(**kwargs)
45
+ raise ValueError(msg)
46
+
47
+
48
@set_module_as('brainstate.compile')
def jit_error_if(
    pred,
    error: Union[Callable, str],
    *err_args,
    **err_kwargs,
):
    """
    Raise an error from inside JIT-compiled code when ``pred`` holds.

    Examples
    --------

    ``error`` can be a function that receives the (traced) arguments and raises.

    >>> def error(x):
    >>>    raise ValueError(f'error {x}')
    >>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
    >>> jit_error_if(x.sum() < 5., error, x)

    Or it can be a plain string message.

    >>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
    >>> jit_error_if(x.sum() < 5., "Error: the sum is less than 5. Got {s}", s=x.sum())


    Parameters
    ----------
    pred: bool, Array
        The boolean predicate. Under ``jax.vmap`` it is reduced with
        ``unvmap(pred, op='any')``, so the error fires if any batch element holds.
    error: callable, str
        Either a function that raises the error, or a message string. A string
        is rendered with ``%``-formatting for ``err_args`` and ``str.format``
        for ``err_kwargs`` and then raised as ``ValueError``.
    err_args:
        Positional arguments passed to ``error``.
    err_kwargs:
        Keyword arguments passed to ``error``.
    """
    err_fun = partial(_error_msg, error) if isinstance(error, str) else error
    strip_batch = partial(unvmap, op='none')
    jax.lax.cond(
        unvmap(pred, op='any'),
        partial(_err_jit_true_branch, err_fun),
        _err_jit_false_branch,
        jax.tree.map(strip_batch, err_args),
        jax.tree.map(strip_batch, err_kwargs),
    )
@@ -0,0 +1,54 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from __future__ import annotations
17
+
18
+ import unittest
19
+
20
+ import jax
21
+ import jax.numpy as jnp
22
+ import jaxlib.xla_extension
23
+
24
+ import brainstate as bst
25
+
26
+
27
class TestJitError(unittest.TestCase):
    """Tests that ``bst.compile.jit_error_if`` surfaces errors as ``XlaRuntimeError``."""

    def test1(self):
        # String-message form.
        with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
            bst.compile.jit_error_if(True, 'error')

        # Callable form: the function receives the extra positional argument.
        def raise_value_error(x):
            raise ValueError(f'error: {x}')

        with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
            bst.compile.jit_error_if(True, raise_value_error, 1.)

    def test_vmap(self):
        def check(x):
            bst.compile.jit_error_if(x, 'error: {x}', x=x)

        batched = jax.vmap(check)
        batched(jnp.array([False, False, False]))
        # A single True element in the batch is enough to trigger the error.
        with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
            batched(jnp.array([True, False, False]))

    def test_vmap_vmap(self):
        def check(x):
            bst.compile.jit_error_if(x, 'error: {x}', x=x)

        doubly_batched = jax.vmap(jax.vmap(check))
        all_false = jnp.array([[False, False, False],
                               [False, False, False]])
        one_true = jnp.array([[False, False, False],
                              [True, False, False]])
        doubly_batched(all_false)
        with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
            doubly_batched(one_true)