brainstate 0.0.2.post20240825__py2.py3-none-any.whl → 0.0.2.post20240826__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_module.py +6 -6
- brainstate/random.py +5 -5
- brainstate/transform/__init__.py +16 -6
- brainstate/transform/_conditions.py +334 -0
- brainstate/transform/{_controls_test.py → _conditions_test.py} +35 -35
- brainstate/transform/_error_if.py +94 -0
- brainstate/transform/{_jit_error_test.py → _error_if_test.py} +4 -4
- brainstate/transform/_loop_collect_return.py +502 -0
- brainstate/transform/_loop_no_collection.py +170 -0
- brainstate/transform/_mapping.py +109 -0
- brainstate/transform/_unvmap.py +143 -0
- brainstate/typing.py +55 -1
- {brainstate-0.0.2.post20240825.dist-info → brainstate-0.0.2.post20240826.dist-info}/METADATA +2 -2
- {brainstate-0.0.2.post20240825.dist-info → brainstate-0.0.2.post20240826.dist-info}/RECORD +17 -13
- {brainstate-0.0.2.post20240825.dist-info → brainstate-0.0.2.post20240826.dist-info}/WHEEL +1 -1
- brainstate/transform/_control.py +0 -665
- brainstate/transform/_jit_error.py +0 -180
- {brainstate-0.0.2.post20240825.dist-info → brainstate-0.0.2.post20240826.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20240825.dist-info → brainstate-0.0.2.post20240826.dist-info}/top_level.txt +0 -0
brainstate/_module.py
CHANGED
@@ -60,7 +60,7 @@ from . import environ
|
|
60
60
|
from ._state import State, StateDictManager, visible_state_dict
|
61
61
|
from ._utils import set_module_as
|
62
62
|
from .mixin import Mixin, Mode, DelayedInit, JointTypes, Batching, UpdateReturn
|
63
|
-
from .transform import
|
63
|
+
from .transform import jit_error_if
|
64
64
|
from .typing import Size, ArrayLike, PyTree
|
65
65
|
from .util import unique_name, DictManager, get_unique_name
|
66
66
|
|
@@ -1212,7 +1212,7 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
|
|
1212
1212
|
raise ValueError(f'The request delay length should be less than the '
|
1213
1213
|
f'maximum delay {self.max_length - 1}. But we got {delay_len}')
|
1214
1214
|
|
1215
|
-
|
1215
|
+
jit_error_if(delay_step >= self.max_length, _check_delay, delay_step)
|
1216
1216
|
|
1217
1217
|
# rotation method
|
1218
1218
|
if self.delay_method == _DELAY_ROTATE:
|
@@ -1263,10 +1263,10 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
|
|
1263
1263
|
f'[{t_now - self.max_time - dt}, {t_now}], '
|
1264
1264
|
f'but we got {t_delay}')
|
1265
1265
|
|
1266
|
-
|
1267
|
-
|
1268
|
-
|
1269
|
-
|
1266
|
+
jit_error_if(jnp.logical_or(delay_time > current_time,
|
1267
|
+
delay_time < current_time - self.max_time - dt),
|
1268
|
+
_check_delay,
|
1269
|
+
current_time, delay_time)
|
1270
1270
|
|
1271
1271
|
diff = current_time - delay_time
|
1272
1272
|
float_time_step = diff / dt
|
brainstate/random.py
CHANGED
@@ -33,7 +33,7 @@ from jax import lax, core, dtypes
|
|
33
33
|
from brainstate import environ
|
34
34
|
from ._random_for_unit import uniform_for_unit, permutation_for_unit
|
35
35
|
from ._state import State
|
36
|
-
from .transform.
|
36
|
+
from .transform._error_if import jit_error_if
|
37
37
|
from .typing import DTypeLike, Size, SeedOrKey
|
38
38
|
|
39
39
|
__all__ = [
|
@@ -498,7 +498,7 @@ class RandomState(State):
|
|
498
498
|
bu.Quantity(scale).in_unit(unit).mantissa
|
499
499
|
)
|
500
500
|
|
501
|
-
|
501
|
+
jit_error_if(
|
502
502
|
bu.math.any(bu.math.logical_or(loc < lower - 2 * scale, loc > upper + 2 * scale)),
|
503
503
|
"mean is more than 2 std from [lower, upper] in truncated_normal. "
|
504
504
|
"The distribution of values may be incorrect."
|
@@ -549,7 +549,7 @@ class RandomState(State):
|
|
549
549
|
size: Optional[Size] = None,
|
550
550
|
key: Optional[SeedOrKey] = None):
|
551
551
|
p = _check_py_seq(p)
|
552
|
-
|
552
|
+
jit_error_if(jnp.any(jnp.logical_and(p < 0, p > 1)), self._check_p, p)
|
553
553
|
if size is None:
|
554
554
|
size = jnp.shape(p)
|
555
555
|
key = self.split_key() if key is None else _formalize_key(key)
|
@@ -592,7 +592,7 @@ class RandomState(State):
|
|
592
592
|
dtype: DTypeLike = None):
|
593
593
|
n = _check_py_seq(n)
|
594
594
|
p = _check_py_seq(p)
|
595
|
-
|
595
|
+
jit_error_if(jnp.any(jnp.logical_and(p < 0, p > 1)), self._check_p, p)
|
596
596
|
if size is None:
|
597
597
|
size = jnp.broadcast_shapes(jnp.shape(n), jnp.shape(p))
|
598
598
|
key = self.split_key() if key is None else _formalize_key(key)
|
@@ -656,7 +656,7 @@ class RandomState(State):
|
|
656
656
|
key = self.split_key() if key is None else _formalize_key(key)
|
657
657
|
n = _check_py_seq(n)
|
658
658
|
pvals = _check_py_seq(pvals)
|
659
|
-
|
659
|
+
jit_error_if(jnp.sum(pvals[:-1]) > 1., self._check_p2, pvals)
|
660
660
|
if isinstance(n, jax.core.Tracer):
|
661
661
|
raise ValueError("The total count parameter `n` should not be a jax abstract array.")
|
662
662
|
size = _size2shape(size)
|
brainstate/transform/__init__.py
CHANGED
@@ -19,17 +19,27 @@ This module contains the functions for the transformation of the brain data.
|
|
19
19
|
|
20
20
|
from ._autograd import *
|
21
21
|
from ._autograd import __all__ as _gradients_all
|
22
|
-
from .
|
23
|
-
from .
|
22
|
+
from ._conditions import *
|
23
|
+
from ._conditions import __all__ as _conditions_all
|
24
|
+
from ._error_if import *
|
25
|
+
from ._error_if import __all__ as _jit_error_all
|
24
26
|
from ._jit import *
|
25
27
|
from ._jit import __all__ as _jit_all
|
26
|
-
from .
|
27
|
-
from .
|
28
|
+
from ._loop_collect_return import *
|
29
|
+
from ._loop_collect_return import __all__ as _loops_all
|
30
|
+
from ._loop_no_collection import *
|
31
|
+
from ._loop_no_collection import __all__ as _loops_no_collection_all
|
28
32
|
from ._make_jaxpr import *
|
29
33
|
from ._make_jaxpr import __all__ as _make_jaxpr_all
|
34
|
+
from ._mapping import *
|
35
|
+
from ._mapping import __all__ as _mapping_all
|
30
36
|
from ._progress_bar import *
|
31
37
|
from ._progress_bar import __all__ as _progress_bar_all
|
32
38
|
|
33
|
-
__all__ = _gradients_all + _jit_error_all +
|
39
|
+
__all__ = (_gradients_all + _jit_error_all + _conditions_all + _loops_all +
|
40
|
+
_make_jaxpr_all + _jit_all + _progress_bar_all + _loops_no_collection_all +
|
41
|
+
_mapping_all)
|
34
42
|
|
35
|
-
del _gradients_all, _jit_error_all,
|
43
|
+
del (_gradients_all, _jit_error_all, _conditions_all, _loops_all,
|
44
|
+
_make_jaxpr_all, _jit_all, _progress_bar_all, _loops_no_collection_all,
|
45
|
+
_mapping_all)
|
@@ -0,0 +1,334 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from __future__ import annotations
|
17
|
+
|
18
|
+
import operator
|
19
|
+
from collections.abc import Callable, Sequence
|
20
|
+
from functools import wraps, reduce
|
21
|
+
|
22
|
+
import jax
|
23
|
+
import jax.numpy as jnp
|
24
|
+
import numpy as np
|
25
|
+
|
26
|
+
from brainstate._utils import set_module_as
|
27
|
+
from ._error_if import jit_error_if
|
28
|
+
from ._make_jaxpr import StatefulFunction, _assign_state_values
|
29
|
+
|
30
|
+
__all__ = [
|
31
|
+
'cond', 'switch', 'ifelse',
|
32
|
+
]
|
33
|
+
|
34
|
+
|
35
|
+
def _wrapped_fun(stateful_fun: StatefulFunction, states, return_states=True):
  """
  Adapt a ``StatefulFunction`` into a pure branch usable by ``jax.lax.cond``/``switch``.

  The returned callable receives the current values of ``states`` as its
  first argument. It writes those values into the states and then evaluates
  ``stateful_fun.jaxpr_call_auto(*operands)``.

  Args:
    stateful_fun: The wrapped function. In this module, callers build its
      jaxpr (``make_jaxpr``) before wrapping it.
    states: The states to thread through the branch, in a fixed order.
    return_states: If True, the branch also returns the states' values read
      back after the call.

  Returns:
    A function ``(state_vals, *operands)``. It returns
    ``(tuple_of_state_values, out)`` when ``return_states`` is True, and only
    ``out`` otherwise.
  """

  @wraps(stateful_fun.fun)
  def wrapped_branch(state_vals, *operands):
    # One incoming value per state, in the same order as ``states``.
    assert len(states) == len(state_vals)
    for state, incoming in zip(states, state_vals):
      state.value = incoming
    result = stateful_fun.jaxpr_call_auto(*operands)
    if not return_states:
      return result
    # Read the states back so the branch's updates are carried out as outputs.
    return tuple(state.value for state in states), result

  return wrapped_branch
|
47
|
+
|
48
|
+
|
49
|
+
@set_module_as('brainstate.transform')
def cond(pred, true_fun: Callable, false_fun: Callable, *operands):
  """
  Conditionally apply ``true_fun`` or ``false_fun``.

  Provided arguments are correctly typed, ``cond()`` has equivalent
  semantics to this Python implementation, where ``pred`` must be a
  scalar type::

    def cond(pred, true_fun, false_fun, *operands):
      if pred:
        return true_fun(*operands)
      else:
        return false_fun(*operands)


  In contrast with :func:`jax.lax.select`, using ``cond`` indicates that only one of
  the two branches is executed (up to compiler rewrites and optimizations).
  However, when transformed with :func:`~jax.vmap` to operate over a batch of
  predicates, ``cond`` is converted to :func:`~jax.lax.select`.

  Every ``State`` used by either branch is threaded through
  :func:`jax.lax.cond` as an extra operand. The values produced by the
  selected branch are written back to those states afterwards.

  Args:
    pred: Boolean scalar type, indicating which branch function to apply.
      Integer and floating-point scalars are also accepted and treated as
      ``pred != 0``.
    true_fun: Function (A -> B), to be applied if ``pred`` is True.
    false_fun: Function (A -> B), to be applied if ``pred`` is False.
    operands: Operands (A) input to either branch depending on ``pred``. The
      type can be a scalar, array, or any pytree (nested Python tuple/list/dict)
      thereof.

  Returns:
    Value (B) of either ``true_fun(*operands)`` or ``false_fun(*operands)``,
    depending on the value of ``pred``. The type can be a scalar, array, or any
    pytree (nested Python tuple/list/dict) thereof.

  Raises:
    TypeError: If a branch is not callable, or if ``pred`` is ``None``, is not
      a scalar, or has neither a boolean nor a numeric dtype.
  """
  if not (callable(true_fun) and callable(false_fun)):
    raise TypeError("true_fun and false_fun arguments should be callable.")

  if pred is None:
    raise TypeError("cond predicate is None")
  if isinstance(pred, Sequence) or np.ndim(pred) != 0:
    raise TypeError(f"Pred must be a scalar, got {pred} of " +
                    (f"type {type(pred)}" if isinstance(pred, Sequence)
                     else f"shape {np.shape(pred)}."))

  # check pred: booleans pass through, numbers are converted to ``pred != 0``
  try:
    pred_dtype = jax.dtypes.result_type(pred)
  except TypeError as err:
    raise TypeError("Pred type must be either boolean or number, got {}.".format(pred)) from err
  if pred_dtype.kind != 'b':
    if pred_dtype.kind in 'iuf':
      pred = pred != 0
    else:
      raise TypeError("Pred type must be either boolean or number, got {}.".format(pred_dtype))

  # not jit: with a concrete predicate, dispatch directly in Python
  if jax.config.jax_disable_jit and isinstance(jax.core.get_aval(pred), jax.core.ConcreteArray):
    if pred:
      return true_fun(*operands)
    else:
      return false_fun(*operands)

  # evaluate jaxpr (collects the states each branch touches)
  true_fun_wrap = StatefulFunction(true_fun).make_jaxpr(*operands)
  false_fun_wrap = StatefulFunction(false_fun).make_jaxpr(*operands)

  # Deduplicate while keeping first-seen order. Iterating a ``set`` orders the
  # states by hash (object identity), so the carried operands could be
  # ordered differently from one run to the next.
  all_states = tuple(dict.fromkeys((*true_fun_wrap.get_states(), *false_fun_wrap.get_states())))
  true_fun = _wrapped_fun(true_fun_wrap, all_states)
  false_fun = _wrapped_fun(false_fun_wrap, all_states)

  # operands: current state values are passed in front of the user operands
  operands = ([st.value for st in all_states],) + operands

  # cond, then write the selected branch's state values back
  state_vals, out = jax.lax.cond(pred, true_fun, false_fun, *operands)
  _assign_state_values(all_states, state_vals)
  return out
|
168
|
+
|
169
|
+
|
170
|
+
@set_module_as('brainstate.transform')
def switch(index, branches: Sequence[Callable], *operands):
  """
  Apply exactly one of ``branches`` given by ``index``.

  If ``index`` is out of bounds, it is clamped to within bounds.

  Has the semantics of the following Python::

    def switch(index, branches, *operands):
      index = clamp(0, index, len(branches) - 1)
      return branches[index](*operands)

  Internally this wraps XLA's `Conditional
  <https://www.tensorflow.org/xla/operation_semantics#conditional>`_
  operator. However, when transformed with :func:`~jax.vmap` to operate over a
  batch of indices, ``switch`` is converted to :func:`~jax.lax.select`.

  Every ``State`` used by any branch is threaded through
  :func:`jax.lax.switch` as an extra operand. The values produced by the
  selected branch are written back to those states afterwards.

  Args:
    index: Integer scalar type, indicating which branch function to apply.
    branches: Sequence of functions (A -> B) to be applied based on ``index``.
    operands: Operands (A) input to whichever branch is applied.

  Returns:
    Value (B) of ``branch(*operands)`` for the branch that was selected based
    on ``index``.

  Raises:
    TypeError: If a branch is not callable, or ``index`` is not an integer scalar.
    ValueError: If ``branches`` is empty.
  """
  # Materialize first. Validating a one-shot iterable before converting it
  # would exhaust it and make the branch sequence look empty.
  branches = tuple(branches)
  if not all(callable(branch) for branch in branches):
    raise TypeError("branches argument should be a sequence of callables.")

  # check index
  if len(np.shape(index)) != 0:
    raise TypeError(f"Branch index must be scalar, got {index} of shape {np.shape(index)}.")
  try:
    index_dtype = jax.dtypes.result_type(index)
  except TypeError as err:
    msg = f"Index type must be an integer, got {index}."
    raise TypeError(msg) from err
  if index_dtype.kind not in 'iu':
    raise TypeError(f"Index type must be an integer, got {index} as {index_dtype}")

  # format branches
  if len(branches) == 0:
    raise ValueError("Empty branch sequence")
  elif len(branches) == 1:
    return branches[0](*operands)

  # format index: int32, clamped into [0, len(branches) - 1]
  index = jax.lax.convert_element_type(index, np.int32)
  lo = np.array(0, np.int32)
  hi = np.array(len(branches) - 1, np.int32)
  index = jax.lax.clamp(lo, index, hi)

  # not jit: with a concrete index, dispatch directly in Python
  if jax.config.jax_disable_jit and isinstance(jax.core.get_aval(index), jax.core.ConcreteArray):
    return branches[int(index)](*operands)

  # evaluate jaxpr (collects the states each branch touches)
  wrapped_branches = [StatefulFunction(branch) for branch in branches]
  for wrapped_branch in wrapped_branches:
    wrapped_branch.make_jaxpr(*operands)

  # Deduplicate while keeping first-seen order, so the carried operands are
  # ordered deterministically (a ``set`` orders them by object-identity hash).
  all_states = tuple(dict.fromkeys(
    st for wrapped_branch in wrapped_branches for st in wrapped_branch.get_states()
  ))
  branches = tuple(_wrapped_fun(wrapped_branch, all_states) for wrapped_branch in wrapped_branches)

  # operands: current state values are passed in front of the user operands
  operands = ([st.value for st in all_states],) + operands

  # switch, then write the selected branch's state values back
  state_vals, out = jax.lax.switch(index, branches, *operands)
  _assign_state_values(all_states, state_vals)
  return out
|
274
|
+
|
275
|
+
|
276
|
+
@set_module_as('brainstate.transform')
def ifelse(conditions, branches, *operands, check_cond: bool = True):
  """
  ``If-else`` control flow that reads like native Pythonic programming.

  ``branches[i]`` is applied when ``conditions[i]`` is True. If ``branches``
  has one more element than ``conditions``, the last branch is an implicit
  ``else`` branch that is applied when no condition is True.

  Examples
  --------

  >>> import brainstate as bst
  >>> def f(a):
  ...   # The conditions overlap (``a > 10`` implies ``a > 5``), so they are
  ...   # used as an ``if/elif`` chain with ``check_cond=False``.
  ...   return bst.transform.ifelse(conditions=[a > 10, a > 5, a > 2, a > 0],
  ...                               branches=[lambda: 1,
  ...                                         lambda: 2,
  ...                                         lambda: 3,
  ...                                         lambda: 4,
  ...                                         lambda: 5],
  ...                               check_cond=False)
  >>> int(f(1))
  4
  >>> int(f(0))
  5

  Parameters
  ----------
  conditions: sequence of bool, Array
    The boolean conditions, one per branch (optionally excluding a trailing
    ``else`` branch).
  branches: sequence of callable
    The branch functions. Either ``len(branches) == len(conditions)``, or
    ``len(branches) == len(conditions) + 1``, in which case the last branch
    acts as the ``else`` branch. Each branch is called as ``branch(*operands)``.
    If only one branch is given, it is applied unconditionally.
  *operands: optional, Any
    The operands passed to each branch.
  check_cond: bool
    Whether to check that exactly one condition (counting the implicit
    ``else``) is True. Default is True. When False, the first True condition
    wins, and if none is True the last branch is applied.

  Returns
  -------
  res: Any
    The result of the selected branch.
  """
  # Materialize first, so a one-shot iterable is not exhausted by the callable check.
  branches = tuple(branches)
  if not all(callable(branch) for branch in branches):
    raise TypeError("branches argument should be a sequence of callables.")

  # format branches
  if len(branches) == 0:
    raise ValueError("Empty branch sequence")
  elif len(branches) == 1:
    return branches[0](*operands)

  num_conditions = len(conditions)
  has_else = num_conditions + 1 == len(branches)
  if not has_else and num_conditions != len(branches):
    raise ValueError("The number of conditions should be equal to the number of branches, "
                     "or one less when the last branch is the `else` branch.")

  # format index
  conditions = jnp.asarray(conditions, np.int32)
  if has_else:
    # The implicit ``else`` holds exactly when no explicit condition does.
    conditions = jnp.append(conditions, jnp.asarray(jnp.sum(conditions) == 0, np.int32))
  if check_cond:
    jit_error_if(jnp.sum(conditions) != 1, "Only one condition can be True. But got {}.", err_arg=conditions)
  # First True condition; falls back to the last branch if none is True.
  index = jnp.where(conditions, size=1, fill_value=len(conditions) - 1)[0][0]
  return switch(index, branches, *operands)
|
@@ -18,20 +18,20 @@ import unittest
|
|
18
18
|
import jax
|
19
19
|
import jax.numpy as jnp
|
20
20
|
|
21
|
-
import brainstate as
|
21
|
+
import brainstate as bst
|
22
22
|
|
23
23
|
|
24
24
|
class TestCond(unittest.TestCase):
|
25
25
|
def test1(self):
|
26
|
-
|
27
|
-
|
28
|
-
|
26
|
+
bst.random.seed(1)
|
27
|
+
bst.transform.cond(True, lambda: bst.random.random(10), lambda: bst.random.random(10))
|
28
|
+
bst.transform.cond(False, lambda: bst.random.random(10), lambda: bst.random.random(10))
|
29
29
|
|
30
30
|
def test2(self):
|
31
|
-
st1 =
|
32
|
-
st2 =
|
33
|
-
st3 =
|
34
|
-
st4 =
|
31
|
+
st1 = bst.State(bst.random.rand(10))
|
32
|
+
st2 = bst.State(bst.random.rand(2))
|
33
|
+
st3 = bst.State(bst.random.rand(5))
|
34
|
+
st4 = bst.State(bst.random.rand(2, 10))
|
35
35
|
|
36
36
|
def true_fun(x):
|
37
37
|
st1.value = st2.value @ st4.value + x
|
@@ -39,7 +39,7 @@ class TestCond(unittest.TestCase):
|
|
39
39
|
def false_fun(x):
|
40
40
|
st3.value = (st3.value + 1.) * x
|
41
41
|
|
42
|
-
|
42
|
+
bst.transform.cond(True, true_fun, false_fun, 2.)
|
43
43
|
assert not isinstance(st1.value, jax.core.Tracer)
|
44
44
|
assert not isinstance(st2.value, jax.core.Tracer)
|
45
45
|
assert not isinstance(st3.value, jax.core.Tracer)
|
@@ -65,7 +65,7 @@ class TestSwitch(unittest.TestCase):
|
|
65
65
|
return branches[2](x)
|
66
66
|
|
67
67
|
def cfun(x):
|
68
|
-
return
|
68
|
+
return bst.transform.switch(x, branches, x)
|
69
69
|
|
70
70
|
self.assertEqual(fun(-1), cfun(-1))
|
71
71
|
self.assertEqual(fun(0), cfun(0))
|
@@ -89,7 +89,7 @@ class TestSwitch(unittest.TestCase):
|
|
89
89
|
return branches[i](x, x)
|
90
90
|
|
91
91
|
def cfun(x):
|
92
|
-
return
|
92
|
+
return bst.transform.switch(x, branches, x, x)
|
93
93
|
|
94
94
|
self.assertEqual(fun(-1), cfun(-1))
|
95
95
|
self.assertEqual(fun(0), cfun(0))
|
@@ -122,13 +122,13 @@ class TestSwitch(unittest.TestCase):
|
|
122
122
|
branches3 = branches2 + [lambda x: jnp.sin(x) + jnp.cos(x)] # requires one more residual slot
|
123
123
|
|
124
124
|
def fun1(x, i):
|
125
|
-
return
|
125
|
+
return bst.transform.switch(i + 1, branches1, x)
|
126
126
|
|
127
127
|
def fun2(x, i):
|
128
|
-
return
|
128
|
+
return bst.transform.switch(i + 1, branches2, x)
|
129
129
|
|
130
130
|
def fun3(x, i):
|
131
|
-
return
|
131
|
+
return bst.transform.switch(i + 1, branches3, x)
|
132
132
|
|
133
133
|
fwd1, bwd1 = get_conds(fun1)
|
134
134
|
fwd2, bwd2 = get_conds(fun2)
|
@@ -148,7 +148,7 @@ class TestSwitch(unittest.TestCase):
|
|
148
148
|
|
149
149
|
def testOneBranchSwitch(self):
|
150
150
|
branch = lambda x: -x
|
151
|
-
f = lambda i, x:
|
151
|
+
f = lambda i, x: bst.transform.switch(i, [branch], x)
|
152
152
|
x = 7.
|
153
153
|
self.assertEqual(f(-1, x), branch(x))
|
154
154
|
self.assertEqual(f(0, x), branch(x))
|
@@ -166,12 +166,12 @@ class TestSwitch(unittest.TestCase):
|
|
166
166
|
class TestIfElse(unittest.TestCase):
|
167
167
|
def test1(self):
|
168
168
|
def f(a):
|
169
|
-
return
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
169
|
+
return bst.transform.ifelse(conditions=[a < 0,
|
170
|
+
a >= 0 and a < 2,
|
171
|
+
a >= 2 and a < 5,
|
172
|
+
a >= 5 and a < 10,
|
173
|
+
a >= 10],
|
174
|
+
branches=[lambda: 1,
|
175
175
|
lambda: 2,
|
176
176
|
lambda: 3,
|
177
177
|
lambda: 4,
|
@@ -183,38 +183,38 @@ class TestIfElse(unittest.TestCase):
|
|
183
183
|
|
184
184
|
def test_vmap(self):
|
185
185
|
def f(operands):
|
186
|
-
f = lambda a:
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
186
|
+
f = lambda a: bst.transform.ifelse([a > 10,
|
187
|
+
jnp.logical_and(a <= 10, a > 5),
|
188
|
+
jnp.logical_and(a <= 5, a > 2),
|
189
|
+
jnp.logical_and(a <= 2, a > 0),
|
190
|
+
a <= 0],
|
191
|
+
[lambda _: 1,
|
192
192
|
lambda _: 2,
|
193
193
|
lambda _: 3,
|
194
194
|
lambda _: 4,
|
195
195
|
lambda _: 5, ],
|
196
|
-
|
196
|
+
a)
|
197
197
|
return jax.vmap(f)(operands)
|
198
198
|
|
199
|
-
r = f(
|
199
|
+
r = f(bst.random.randint(-20, 20, 200))
|
200
200
|
self.assertTrue(r.size == 200)
|
201
201
|
|
202
202
|
def test_grad1(self):
|
203
203
|
def F2(x):
|
204
|
-
return
|
205
|
-
|
206
|
-
|
204
|
+
return bst.transform.ifelse((x >= 10, x < 10),
|
205
|
+
[lambda x: x, lambda x: x ** 2, ],
|
206
|
+
x)
|
207
207
|
|
208
208
|
self.assertTrue(jax.grad(F2)(9.0) == 18.)
|
209
209
|
self.assertTrue(jax.grad(F2)(11.0) == 1.)
|
210
210
|
|
211
211
|
def test_grad2(self):
|
212
212
|
def F3(x):
|
213
|
-
return
|
214
|
-
|
213
|
+
return bst.transform.ifelse((x >= 10, jnp.logical_and(x >= 0, x < 10), x < 0),
|
214
|
+
[lambda x: x,
|
215
215
|
lambda x: x ** 2,
|
216
216
|
lambda x: x ** 4, ],
|
217
|
-
|
217
|
+
x)
|
218
218
|
|
219
219
|
self.assertTrue(jax.grad(F3)(9.0) == 18.)
|
220
220
|
self.assertTrue(jax.grad(F3)(11.0) == 1.)
|