brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184)
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
@@ -1,334 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from __future__ import annotations
17
-
18
- import operator
19
- from collections.abc import Callable, Sequence
20
- from functools import wraps, reduce
21
-
22
- import jax
23
- import jax.numpy as jnp
24
- import numpy as np
25
-
26
- from brainstate._utils import set_module_as
27
- from ._error_if import jit_error_if
28
- from ._make_jaxpr import StatefulFunction, _assign_state_values
29
-
30
# Public control-flow API exported by this module.
__all__ = [
    'cond',
    'switch',
    'ifelse',
]
33
-
34
-
35
- def _wrapped_fun(stateful_fun: StatefulFunction, states, return_states=True):
36
- @wraps(stateful_fun.fun)
37
- def wrapped_branch(state_vals, *operands):
38
- assert len(states) == len(state_vals)
39
- for st, val in zip(states, state_vals):
40
- st.value = val
41
- out = stateful_fun.jaxpr_call_auto(*operands)
42
- if return_states:
43
- return tuple(st.value for st in states), out
44
- return out
45
-
46
- return wrapped_branch
47
-
48
-
49
@set_module_as('brainstate.transform')
def cond(pred, true_fun: Callable, false_fun: Callable, *operands):
    """
    Conditionally apply ``true_fun`` or ``false_fun``.

    Provided arguments are correctly typed, ``cond()`` has equivalent
    semantics to this Python implementation, where ``pred`` must be a
    scalar type::

      def cond(pred, true_fun, false_fun, *operands):
        if pred:
          return true_fun(*operands)
        else:
          return false_fun(*operands)


    In contrast with :func:`jax.lax.select`, using ``cond`` indicates that only one of
    the two branches is executed (up to compiler rewrites and optimizations).
    However, when transformed with :func:`~jax.vmap` to operate over a batch of
    predicates, ``cond`` is converted to :func:`~jax.lax.select`.

    ``State`` objects used by either branch are collected by tracing both
    branches. Their current values are threaded through :func:`jax.lax.cond`,
    and the values produced by the executed branch are written back after the
    call.

    Args:
      pred: Boolean scalar type, indicating which branch function to apply.
        Integer and floating-point scalars are also accepted and interpreted
        as ``pred != 0``.
      true_fun: Function (A -> B), to be applied if ``pred`` is True.
      false_fun: Function (A -> B), to be applied if ``pred`` is False.
      operands: Operands (A) input to either branch depending on ``pred``. The
        type can be a scalar, array, or any pytree (nested Python tuple/list/dict)
        thereof.

    Returns:
      Value (B) of either ``true_fun(*operands)`` or ``false_fun(*operands)``,
      depending on the value of ``pred``. The type can be a scalar, array, or any
      pytree (nested Python tuple/list/dict) thereof.

    Raises:
      TypeError: If a branch is not callable, or if ``pred`` is ``None``, is not
        a scalar, or has a non-boolean, non-numeric dtype.
    """
    if not (callable(true_fun) and callable(false_fun)):
        raise TypeError("true_fun and false_fun arguments should be callable.")

    # pred must be a non-None scalar
    if pred is None:
        raise TypeError("cond predicate is None")
    if isinstance(pred, Sequence) or np.ndim(pred) != 0:
        raise TypeError(f"Pred must be a scalar, got {pred} of " +
                        (f"type {type(pred)}" if isinstance(pred, Sequence)
                         else f"shape {np.shape(pred)}."))

    # pred must be boolean; numeric predicates are converted to booleans
    try:
        pred_dtype = jax.dtypes.result_type(pred)
    except TypeError as err:
        raise TypeError("Pred type must be either boolean or number, got {}.".format(pred)) from err
    if pred_dtype.kind != 'b':
        if pred_dtype.kind in 'iuf':
            pred = pred != 0
        else:
            raise TypeError("Pred type must be either boolean or number, got {}.".format(pred_dtype))

    # JIT disabled and the predicate is concrete (not a tracer): run only the
    # selected branch eagerly. A tracer check is used because
    # ``jax.core.ConcreteArray`` is deprecated/removed in newer JAX releases.
    if jax.config.jax_disable_jit and not isinstance(pred, jax.core.Tracer):
        if pred:
            return true_fun(*operands)
        return false_fun(*operands)

    # Trace both branches to discover the states they read or write.
    true_fun_wrap = StatefulFunction(true_fun).make_jaxpr(*operands)
    false_fun_wrap = StatefulFunction(false_fun).make_jaxpr(*operands)

    # Deduplicate with ``dict.fromkeys`` (first-seen order) instead of ``set``.
    # This keeps the state order, and hence the operand layout of the traced
    # conditional, deterministic from run to run.
    all_states = tuple(dict.fromkeys(true_fun_wrap.get_states() + false_fun_wrap.get_states()))
    true_fun = _wrapped_fun(true_fun_wrap, all_states)
    false_fun = _wrapped_fun(false_fun_wrap, all_states)

    # The current state values are passed to each branch as its first operand.
    operands = ([st.value for st in all_states],) + operands

    state_vals, out = jax.lax.cond(pred, true_fun, false_fun, *operands)
    _assign_state_values(all_states, state_vals)
    return out
168
-
169
-
170
@set_module_as('brainstate.transform')
def switch(index, branches: Sequence[Callable], *operands):
    """
    Apply exactly one of ``branches`` given by ``index``.

    If ``index`` is out of bounds, it is clamped to within bounds.

    Has the semantics of the following Python::

      def switch(index, branches, *operands):
        index = clamp(0, index, len(branches) - 1)
        return branches[index](*operands)

    Internally this wraps XLA's `Conditional
    <https://www.tensorflow.org/xla/operation_semantics#conditional>`_
    operator. However, when transformed with :func:`~jax.vmap` to operate over a
    batch of predicates, the conditional is converted to :func:`~jax.lax.select`.

    ``State`` objects used by any branch are collected by tracing every branch.
    Their values are threaded through :func:`jax.lax.switch` and written back
    after the call.

    Args:
      index: Integer scalar type, indicating which branch function to apply.
      branches: Sequence (or any iterable) of functions (A -> B) to be applied
        based on ``index``.
      operands: Operands (A) input to whichever branch is applied.

    Returns:
      Value (B) of ``branch(*operands)`` for the branch that was selected based
      on ``index``.

    Raises:
      TypeError: If a branch is not callable or ``index`` is not an integer scalar.
      ValueError: If ``branches`` is empty.
    """
    # Materialize first so an iterator is not exhausted by the callable check.
    branches = tuple(branches)
    if not all(callable(branch) for branch in branches):
        raise TypeError("branches argument should be a sequence of callables.")

    # check index
    if len(np.shape(index)) != 0:
        raise TypeError(f"Branch index must be scalar, got {index} of shape {np.shape(index)}.")
    try:
        index_dtype = jax.dtypes.result_type(index)
    except TypeError as err:
        msg = f"Index type must be an integer, got {index}."
        raise TypeError(msg) from err
    if index_dtype.kind not in 'iu':
        raise TypeError(f"Index type must be an integer, got {index} as {index_dtype}")

    if len(branches) == 0:
        raise ValueError("Empty branch sequence")
    elif len(branches) == 1:
        return branches[0](*operands)

    # clamp the index into [0, len(branches) - 1]
    index = jax.lax.convert_element_type(index, np.int32)
    lo = np.array(0, np.int32)
    hi = np.array(len(branches) - 1, np.int32)
    index = jax.lax.clamp(lo, index, hi)

    # JIT disabled and the index is concrete (not a tracer): run only the
    # selected branch eagerly. This avoids relying on non-public JAX internals.
    if jax.config.jax_disable_jit and not isinstance(index, jax.core.Tracer):
        return branches[int(index)](*operands)

    # Trace every branch to discover the states they read or write.
    wrapped_branches = [StatefulFunction(branch) for branch in branches]
    for wrapped_branch in wrapped_branches:
        wrapped_branch.make_jaxpr(*operands)

    # Order-preserving deduplication keeps the state/operand layout deterministic.
    all_states = tuple(dict.fromkeys(
        st for wrapped_branch in wrapped_branches for st in wrapped_branch.get_states()
    ))
    branches = tuple(_wrapped_fun(wrapped_branch, all_states) for wrapped_branch in wrapped_branches)

    # The current state values are passed to each branch as its first operand.
    operands = ([st.value for st in all_states],) + operands

    state_vals, out = jax.lax.switch(index, branches, *operands)
    _assign_state_values(all_states, state_vals)
    return out
274
-
275
-
276
@set_module_as('brainstate.transform')
def ifelse(conditions, branches, *operands, check_cond: bool = True):
    """
    ``If-else`` control flow that reads like native Python.

    ``conditions`` and ``branches`` are paired by position. The branch whose
    condition is True is called as ``branch(*operands)``.

    Examples
    --------

    >>> import brainstate as bst
    >>> import jax.numpy as jnp
    >>> def f(a):
    ...     return bst.transform.ifelse(conditions=[a > 10,
    ...                                             jnp.logical_and(a <= 10, a > 5),
    ...                                             jnp.logical_and(a <= 5, a > 2),
    ...                                             jnp.logical_and(a <= 2, a > 0),
    ...                                             a <= 0],
    ...                                 branches=[lambda: 1,
    ...                                           lambda: 2,
    ...                                           lambda: 3,
    ...                                           lambda: 4,
    ...                                           lambda: 5])
    >>> f(1)
    4
    >>> f(0)
    5

    Parameters
    ----------
    conditions: sequence of bool, Array
        The boolean conditions, one per branch. Exactly one should be True.
    branches: sequence of callables
        The branches. ``len(branches) == len(conditions)`` is required, and
        every element must be callable. Each branch is called as
        ``branch(*operands)``.
    *operands: optional, Any
        The operands passed to the selected branch.
    check_cond: bool
        If True (default), raise at runtime (via ``jit_error_if``) when the
        number of True conditions is not exactly one. If False, the first True
        condition selects the branch, and the last branch is used when none is
        True.

    Returns
    -------
    res: Any
        The result of the selected branch.

    Raises
    ------
    TypeError
        If any branch is not callable.
    ValueError
        If ``branches`` is empty, or its length differs from ``conditions``.
    """
    # Materialize first so an iterator is not exhausted by the callable check.
    branches = tuple(branches)
    if not all(callable(branch) for branch in branches):
        raise TypeError("branches argument should be a sequence of callables.")

    if len(branches) == 0:
        raise ValueError("Empty branch sequence")
    elif len(branches) == 1:
        return branches[0](*operands)
    if len(conditions) != len(branches):
        raise ValueError("The number of conditions should be equal to the number of branches.")

    # format index
    conditions = jnp.asarray(conditions, np.int32)
    if check_cond:
        # ``jit_error_if`` formats string messages with ``str.format(**kwargs)``.
        # The placeholder must therefore be named after the keyword argument; a
        # bare ``{}`` raises IndexError instead of reporting the conditions.
        jit_error_if(jnp.sum(conditions) != 1,
                     "Only one condition can be True. But got {conditions}.",
                     conditions=conditions)
    # Index of the first True condition; falls back to the last branch when none is True.
    index = jnp.where(conditions, size=1, fill_value=len(conditions) - 1)[0][0]
    return switch(index, branches, *operands)
@@ -1,220 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- import unittest
17
-
18
- import jax
19
- import jax.numpy as jnp
20
-
21
- import brainstate as bst
22
-
23
-
24
class TestCond(unittest.TestCase):
    """Tests for ``bst.transform.cond`` with stateful branches."""

    def test1(self):
        bst.random.seed(1)
        res_true = bst.transform.cond(True, lambda: bst.random.random(10), lambda: bst.random.random(10))
        res_false = bst.transform.cond(False, lambda: bst.random.random(10), lambda: bst.random.random(10))
        # Both branches draw 10 samples, so either selection yields shape (10,).
        self.assertEqual(res_true.shape, (10,))
        self.assertEqual(res_false.shape, (10,))

    def test2(self):
        st1 = bst.State(bst.random.rand(10))
        st2 = bst.State(bst.random.rand(2))
        st3 = bst.State(bst.random.rand(5))
        st4 = bst.State(bst.random.rand(2, 10))
        st2_before, st3_before, st4_before = st2.value, st3.value, st4.value

        def true_fun(x):
            st1.value = st2.value @ st4.value + x

        def false_fun(x):
            st3.value = (st3.value + 1.) * x

        bst.transform.cond(True, true_fun, false_fun, 2.)
        # No tracer may leak out of the traced branches into the states.
        for st in (st1, st2, st3, st4):
            self.assertNotIsInstance(st.value, jax.core.Tracer)
        # Only the selected (true) branch's update is committed.
        self.assertTrue(jnp.allclose(st1.value, st2_before @ st4_before + 2.))
        self.assertTrue(jnp.allclose(st3.value, st3_before))
47
-
48
-
49
class TestSwitch(unittest.TestCase):
    """Compare ``bst.transform.switch`` with plain Python dispatch, eagerly and under ``jax.jit``."""

    def testSwitch(self):
        def branch(x):
            doubled = jax.lax.mul(2, x)
            return doubled, jax.lax.mul(2, doubled)

        branches = [lambda x: (x, x),
                    branch,
                    lambda x: (x, -x)]

        def fun(x):
            # Reference dispatch with the index clamped into range.
            if x <= 0:
                idx = 0
            elif x == 1:
                idx = 1
            else:
                idx = 2
            return branches[idx](x)

        def cfun(x):
            return bst.transform.switch(x, branches, x)

        for impl in (cfun, jax.jit(cfun)):
            for x in (-1, 0, 1, 2, 3):
                self.assertEqual(fun(x), impl(x))

    def testSwitchMultiOperands(self):
        branches = [jax.lax.add, jax.lax.mul]

        def fun(x):
            return branches[0 if x <= 0 else 1](x, x)

        def cfun(x):
            return bst.transform.switch(x, branches, x, x)

        for impl in (cfun, jax.jit(cfun)):
            for x in (-1, 0, 1, 2):
                self.assertEqual(fun(x), impl(x))

    def testSwitchResidualsMerge(self):
        def get_conds(fun):
            jaxpr = jax.make_jaxpr(jax.grad(fun))(0., 0)
            return [eqn for eqn in jaxpr.jaxpr.eqns if eqn.primitive.name == 'cond']

        def common_len(cond_eqn, attr):
            # Every branch of one cond must agree on its input/output arity.
            lens = {len(getattr(br.jaxpr, attr)) for br in cond_eqn.params['branches']}
            assert len(lens) == 1
            return lens.pop()

        branches1 = [lambda x: jnp.sin(x),
                     lambda x: jnp.cos(x)]  # branch residuals overlap, should be reused
        branches2 = branches1 + [lambda x: jnp.sinh(x)]  # another overlapping residual, expect reuse
        branches3 = branches2 + [lambda x: jnp.sin(x) + jnp.cos(x)]  # requires one more residual slot

        def make_fun(branch_list):
            def fun(x, i):
                return bst.transform.switch(i + 1, branch_list, x)

            return fun

        (fwd1, bwd1), (fwd2, bwd2), (fwd3, bwd3) = [
            get_conds(make_fun(b)) for b in (branches1, branches2, branches3)
        ]

        fwd_out = [common_len(eqn, 'outvars') for eqn in (fwd1, fwd2, fwd3)]
        assert fwd_out[0] == fwd_out[1]
        assert fwd_out[2] == fwd_out[1] + 1

        bwd_in = [common_len(eqn, 'invars') for eqn in (bwd1, bwd2, bwd3)]
        assert bwd_in[0] == bwd_in[1]
        assert bwd_in[2] == bwd_in[1] + 1

    def testOneBranchSwitch(self):
        branch = lambda x: -x
        f = lambda i, x: bst.transform.switch(i, [branch], x)
        x = 7.
        for impl in (f, jax.jit(f), jax.jit(f, static_argnums=0)):
            for i in (-1, 0, 1):
                self.assertEqual(impl(i, x), branch(x))
164
-
165
-
166
class TestIfElse(unittest.TestCase):
    """Tests for ``bst.transform.ifelse``: every branch, vmap, and gradients."""

    def test1(self):
        def f(a):
            return bst.transform.ifelse(conditions=[a < 0,
                                                    a >= 0 and a < 2,
                                                    a >= 2 and a < 5,
                                                    a >= 5 and a < 10,
                                                    a >= 10],
                                        branches=[lambda: 1,
                                                  lambda: 2,
                                                  lambda: 3,
                                                  lambda: 4,
                                                  lambda: 5])

        self.assertTrue(f(3) == 3)
        self.assertTrue(f(1) == 2)
        self.assertTrue(f(-1) == 1)
        # Cover the remaining two branches as well.
        self.assertTrue(f(7) == 4)
        self.assertTrue(f(20) == 5)

    def test_vmap(self):
        def f(operands):
            single = lambda a: bst.transform.ifelse([a > 10,
                                                     jnp.logical_and(a <= 10, a > 5),
                                                     jnp.logical_and(a <= 5, a > 2),
                                                     jnp.logical_and(a <= 2, a > 0),
                                                     a <= 0],
                                                    [lambda _: 1,
                                                     lambda _: 2,
                                                     lambda _: 3,
                                                     lambda _: 4,
                                                     lambda _: 5, ],
                                                    a)
            return jax.vmap(single)(operands)

        r = f(bst.random.randint(-20, 20, 200))
        self.assertTrue(r.size == 200)

    def test_grad1(self):
        def F2(x):
            return bst.transform.ifelse((x >= 10, x < 10),
                                        [lambda x: x, lambda x: x ** 2, ],
                                        x)

        self.assertTrue(jax.grad(F2)(9.0) == 18.)
        self.assertTrue(jax.grad(F2)(11.0) == 1.)

    def test_grad2(self):
        def F3(x):
            return bst.transform.ifelse((x >= 10, jnp.logical_and(x >= 0, x < 10), x < 0),
                                        [lambda x: x,
                                         lambda x: x ** 2,
                                         lambda x: x ** 4, ],
                                        x)

        self.assertTrue(jax.grad(F3)(9.0) == 18.)
        self.assertTrue(jax.grad(F3)(11.0) == 1.)
        # The x < 0 branch: d/dx x**4 = 4 * x**3 = -32 at x = -2.
        self.assertTrue(jax.grad(F3)(-2.0) == -32.)
@@ -1,94 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from __future__ import annotations
17
-
18
- import functools
19
- from functools import partial
20
- from typing import Callable, Union
21
-
22
- import jax
23
-
24
- from brainstate._utils import set_module_as
25
- from ._unvmap import unvmap
26
-
27
# Public API of this module.
__all__ = ['jit_error_if']
30
-
31
-
32
- def _err_jit_true_branch(err_fun, args, kwargs):
33
- jax.debug.callback(err_fun, *args, **kwargs)
34
-
35
-
36
- def _err_jit_false_branch(args, kwargs):
37
- pass
38
-
39
-
40
- def _error_msg(msg, *arg, **kwargs):
41
- if len(arg):
42
- msg = msg % arg
43
- if len(kwargs):
44
- msg = msg.format(**kwargs)
45
- raise ValueError(msg)
46
-
47
-
48
@set_module_as('brainstate.transform')
def jit_error_if(
    pred,
    error: Union[Callable, str],
    *err_args,
    **err_kwargs,
):
    """
    Raise an error from inside (possibly JIT-compiled) code when ``pred`` is True.

    Examples
    --------

    ``error`` can be a callable that receives the given arguments and raises:

    >>> def error(x):
    ...     raise ValueError(f'error {x}')
    >>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
    >>> jit_error_if(x.sum() < 5., error, x)

    or a message string. Positional arguments are interpolated with ``%`` and
    keyword arguments with ``str.format``, then a ``ValueError`` is raised:

    >>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
    >>> jit_error_if(x.sum() < 5., "Error: the sum is less than 5. Got {s}", s=x.sum())

    Parameters
    ----------
    pred: bool, Array
        The predicate. It is passed through ``unvmap(..., op='any')``, so under
        ``vmap`` the error fires if any element is True.
    error: callable, str
        A function that raises, or an error-message string.
    err_args:
        Positional arguments passed to ``error``.
    err_kwargs:
        Keyword arguments passed to ``error``.
    """
    if isinstance(error, str):
        error = partial(_error_msg, error)

    any_pred = unvmap(pred, op='any')
    # Arguments for the host callback are passed through ``unvmap(..., op='none')``.
    strip = partial(unvmap, op='none')
    callback_args = jax.tree.map(strip, err_args)
    callback_kwargs = jax.tree.map(strip, err_kwargs)

    jax.lax.cond(
        any_pred,
        partial(_err_jit_true_branch, error),
        _err_jit_false_branch,
        callback_args,
        callback_kwargs,
    )
@@ -1,55 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- import unittest
17
-
18
- import jax
19
- import jax.numpy as jnp
20
- import jaxlib.xla_extension
21
-
22
- import brainstate as bst
23
-
24
-
25
class TestJitError(unittest.TestCase):
    """``jit_error_if`` must surface as ``XlaRuntimeError`` when its predicate holds."""

    def test1(self):
        def err_f(x):
            raise ValueError(f'error: {x}')

        with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
            bst.transform.jit_error_if(True, 'error')
        with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
            bst.transform.jit_error_if(True, err_f, 1.)

    def test_vmap(self):
        def check(flag):
            bst.transform.jit_error_if(flag, 'error: {x}', x=flag)

        batched = jax.vmap(check)
        batched(jnp.array([False, False, False]))
        with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
            batched(jnp.array([True, False, False]))

    def test_vmap_vmap(self):
        def check(flag):
            bst.transform.jit_error_if(flag, 'error: {x}', x=flag)

        doubly_batched = jax.vmap(jax.vmap(check))
        doubly_batched(jnp.array([[False, False, False],
                                  [False, False, False]]))
        with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
            doubly_batched(jnp.array([[False, False, False],
                                      [True, False, False]]))
55
-