brainstate 0.0.2.post20240825__py2.py3-none-any.whl → 0.0.2.post20240910__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
brainstate/_module.py CHANGED
@@ -60,7 +60,7 @@ from . import environ
60
60
  from ._state import State, StateDictManager, visible_state_dict
61
61
  from ._utils import set_module_as
62
62
  from .mixin import Mixin, Mode, DelayedInit, JointTypes, Batching, UpdateReturn
63
- from .transform import jit_error
63
+ from .transform import jit_error_if
64
64
  from .typing import Size, ArrayLike, PyTree
65
65
  from .util import unique_name, DictManager, get_unique_name
66
66
 
@@ -1212,7 +1212,7 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
1212
1212
  raise ValueError(f'The request delay length should be less than the '
1213
1213
  f'maximum delay {self.max_length - 1}. But we got {delay_len}')
1214
1214
 
1215
- jit_error(delay_step >= self.max_length, _check_delay, delay_step)
1215
+ jit_error_if(delay_step >= self.max_length, _check_delay, delay_step)
1216
1216
 
1217
1217
  # rotation method
1218
1218
  if self.delay_method == _DELAY_ROTATE:
@@ -1263,10 +1263,10 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
1263
1263
  f'[{t_now - self.max_time - dt}, {t_now}], '
1264
1264
  f'but we got {t_delay}')
1265
1265
 
1266
- jit_error(jnp.logical_or(delay_time > current_time,
1267
- delay_time < current_time - self.max_time - dt),
1268
- _check_delay,
1269
- current_time, delay_time)
1266
+ jit_error_if(jnp.logical_or(delay_time > current_time,
1267
+ delay_time < current_time - self.max_time - dt),
1268
+ _check_delay,
1269
+ current_time, delay_time)
1270
1270
 
1271
1271
  diff = current_time - delay_time
1272
1272
  float_time_step = diff / dt
@@ -1597,6 +1597,6 @@ def _get_delay(delay_time, delay_step):
1597
1597
  delay_time = delay_step * environ.get_dt()
1598
1598
  else:
1599
1599
  assert delay_step is None, '"delay_step" should be None if "delay_time" is given.'
1600
- assert isinstance(delay_time, (int, float))
1600
+ # assert isinstance(delay_time, (int, float))
1601
1601
  delay_step = math.ceil(delay_time / environ.get_dt())
1602
1602
  return delay_time, delay_step
brainstate/_state.py CHANGED
@@ -234,6 +234,8 @@ class ShapeDtype:
234
234
  def __init__(self, shape, dtype):
235
235
  self.shape = shape
236
236
  self.dtype = dtype
237
+ self.ndim = len(shape)
238
+ self.size = np.prod(shape)
237
239
 
238
240
  def __repr__(self):
239
241
  return f'{self.dtype}{list(self.shape)}'
@@ -15,6 +15,8 @@
15
15
 
16
16
  from typing import Optional, Union
17
17
 
18
+
19
+ import brainunit as u
18
20
  from brainstate._module import (register_delay_of_target,
19
21
  Projection,
20
22
  Module,
@@ -278,11 +280,19 @@ class FullProjAlignPostMg(Projection):
278
280
  self.comm = comm
279
281
 
280
282
  # delay initialization
281
- if delay is not None and delay > 0.:
282
- delay_cls = register_delay_of_target(pre)
283
- delay_cls.register_entry(self.name, delay)
284
- self.delay = delay_cls
285
- self.has_delay = True
283
+ if delay is not None:
284
+ if isinstance(delay, u.Quantity):
285
+ has_delay = delay.mantissa > 0.
286
+ else:
287
+ has_delay = delay > 0.
288
+ if has_delay:
289
+ delay_cls = register_delay_of_target(pre)
290
+ delay_cls.register_entry(self.name, delay)
291
+ self.delay = delay_cls
292
+ self.has_delay = True
293
+ else:
294
+ self.delay = None
295
+ self.has_delay = False
286
296
  else:
287
297
  self.delay = None
288
298
  self.has_delay = False
@@ -502,11 +512,19 @@ class FullProjAlignPost(Projection):
502
512
  self.out = out
503
513
 
504
514
  # delay initialization
505
- if delay is not None and delay > 0.:
506
- delay_cls = register_delay_of_target(pre)
507
- delay_cls.register_entry(self.name, delay)
508
- self.delay = delay_cls
509
- self.has_delay = True
515
+ if delay is not None:
516
+ if isinstance(delay, u.Quantity):
517
+ has_delay = delay.mantissa > 0.
518
+ else:
519
+ has_delay = delay > 0.
520
+ if has_delay:
521
+ delay_cls = register_delay_of_target(pre)
522
+ delay_cls.register_entry(self.name, delay)
523
+ self.delay = delay_cls
524
+ self.has_delay = True
525
+ else:
526
+ self.delay = None
527
+ self.has_delay = False
510
528
  else:
511
529
  self.delay = None
512
530
  self.has_delay = False
brainstate/random.py CHANGED
@@ -33,7 +33,7 @@ from jax import lax, core, dtypes
33
33
  from brainstate import environ
34
34
  from ._random_for_unit import uniform_for_unit, permutation_for_unit
35
35
  from ._state import State
36
- from .transform._jit_error import jit_error
36
+ from .transform._error_if import jit_error_if
37
37
  from .typing import DTypeLike, Size, SeedOrKey
38
38
 
39
39
  __all__ = [
@@ -498,7 +498,7 @@ class RandomState(State):
498
498
  bu.Quantity(scale).in_unit(unit).mantissa
499
499
  )
500
500
 
501
- jit_error(
501
+ jit_error_if(
502
502
  bu.math.any(bu.math.logical_or(loc < lower - 2 * scale, loc > upper + 2 * scale)),
503
503
  "mean is more than 2 std from [lower, upper] in truncated_normal. "
504
504
  "The distribution of values may be incorrect."
@@ -549,7 +549,7 @@ class RandomState(State):
549
549
  size: Optional[Size] = None,
550
550
  key: Optional[SeedOrKey] = None):
551
551
  p = _check_py_seq(p)
552
- jit_error(jnp.any(jnp.logical_and(p < 0, p > 1)), self._check_p, p)
552
+ jit_error_if(jnp.any(jnp.logical_and(p < 0, p > 1)), self._check_p, p)
553
553
  if size is None:
554
554
  size = jnp.shape(p)
555
555
  key = self.split_key() if key is None else _formalize_key(key)
@@ -592,7 +592,7 @@ class RandomState(State):
592
592
  dtype: DTypeLike = None):
593
593
  n = _check_py_seq(n)
594
594
  p = _check_py_seq(p)
595
- jit_error(jnp.any(jnp.logical_and(p < 0, p > 1)), self._check_p, p)
595
+ jit_error_if(jnp.any(jnp.logical_and(p < 0, p > 1)), self._check_p, p)
596
596
  if size is None:
597
597
  size = jnp.broadcast_shapes(jnp.shape(n), jnp.shape(p))
598
598
  key = self.split_key() if key is None else _formalize_key(key)
@@ -656,7 +656,7 @@ class RandomState(State):
656
656
  key = self.split_key() if key is None else _formalize_key(key)
657
657
  n = _check_py_seq(n)
658
658
  pvals = _check_py_seq(pvals)
659
- jit_error(jnp.sum(pvals[:-1]) > 1., self._check_p2, pvals)
659
+ jit_error_if(jnp.sum(pvals[:-1]) > 1., self._check_p2, pvals)
660
660
  if isinstance(n, jax.core.Tracer):
661
661
  raise ValueError("The total count parameter `n` should not be a jax abstract array.")
662
662
  size = _size2shape(size)
brainstate/surrogate.py CHANGED
@@ -158,6 +158,9 @@ class Sigmoid(Surrogate):
158
158
  def __repr__(self):
159
159
  return f'{self.__class__.__name__}(alpha={self.alpha})'
160
160
 
161
+ def __hash__(self):
162
+ return hash((self.__class__, self.alpha))
163
+
161
164
 
162
165
  def sigmoid(
163
166
  x: jax.Array,
@@ -243,6 +246,9 @@ class PiecewiseQuadratic(Surrogate):
243
246
  def __repr__(self):
244
247
  return f'{self.__class__.__name__}(alpha={self.alpha})'
245
248
 
249
+ def __hash__(self):
250
+ return hash((self.__class__, self.alpha))
251
+
246
252
 
247
253
  def piecewise_quadratic(
248
254
  x: jax.Array,
@@ -339,6 +345,9 @@ class PiecewiseExp(Surrogate):
339
345
  def __repr__(self):
340
346
  return f'{self.__class__.__name__}(alpha={self.alpha})'
341
347
 
348
+ def __hash__(self):
349
+ return hash((self.__class__, self.alpha))
350
+
342
351
 
343
352
  def piecewise_exp(
344
353
  x: jax.Array,
@@ -426,6 +435,9 @@ class SoftSign(Surrogate):
426
435
  def __repr__(self):
427
436
  return f'{self.__class__.__name__}(alpha={self.alpha})'
428
437
 
438
+ def __hash__(self):
439
+ return hash((self.__class__, self.alpha))
440
+
429
441
 
430
442
  def soft_sign(
431
443
  x: jax.Array,
@@ -508,6 +520,9 @@ class Arctan(Surrogate):
508
520
  def __repr__(self):
509
521
  return f'{self.__class__.__name__}(alpha={self.alpha})'
510
522
 
523
+ def __hash__(self):
524
+ return hash((self.__class__, self.alpha))
525
+
511
526
 
512
527
  def arctan(
513
528
  x: jax.Array,
@@ -589,6 +604,9 @@ class NonzeroSignLog(Surrogate):
589
604
  def __repr__(self):
590
605
  return f'{self.__class__.__name__}(alpha={self.alpha})'
591
606
 
607
+ def __hash__(self):
608
+ return hash((self.__class__, self.alpha))
609
+
592
610
 
593
611
  def nonzero_sign_log(
594
612
  x: jax.Array,
@@ -683,6 +701,9 @@ class ERF(Surrogate):
683
701
  def __repr__(self):
684
702
  return f'{self.__class__.__name__}(alpha={self.alpha})'
685
703
 
704
+ def __hash__(self):
705
+ return hash((self.__class__, self.alpha))
706
+
686
707
 
687
708
  def erf(
688
709
  x: jax.Array,
@@ -780,6 +801,9 @@ class PiecewiseLeakyRelu(Surrogate):
780
801
  def __repr__(self):
781
802
  return f'{self.__class__.__name__}(c={self.c}, w={self.w})'
782
803
 
804
+ def __hash__(self):
805
+ return hash((self.__class__, self.c, self.w))
806
+
783
807
 
784
808
  def piecewise_leaky_relu(
785
809
  x: jax.Array,
@@ -898,6 +922,9 @@ class SquarewaveFourierSeries(Surrogate):
898
922
  def __repr__(self):
899
923
  return f'{self.__class__.__name__}(n={self.n}, t_period={self.t_period})'
900
924
 
925
+ def __hash__(self):
926
+ return hash((self.__class__, self.n, self.t_period))
927
+
901
928
 
902
929
  def squarewave_fourier_series(
903
930
  x: jax.Array,
@@ -988,6 +1015,9 @@ class S2NN(Surrogate):
988
1015
  def __repr__(self):
989
1016
  return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta}, epsilon={self.epsilon})'
990
1017
 
1018
+ def __hash__(self):
1019
+ return hash((self.__class__, self.alpha, self.beta, self.epsilon))
1020
+
991
1021
 
992
1022
  def s2nn(
993
1023
  x: jax.Array,
@@ -1089,6 +1119,9 @@ class QPseudoSpike(Surrogate):
1089
1119
  def __repr__(self):
1090
1120
  return f'{self.__class__.__name__}(alpha={self.alpha})'
1091
1121
 
1122
+ def __hash__(self):
1123
+ return hash((self.__class__, self.alpha))
1124
+
1092
1125
 
1093
1126
  def q_pseudo_spike(
1094
1127
  x: jax.Array,
@@ -1178,6 +1211,9 @@ class LeakyRelu(Surrogate):
1178
1211
  def __repr__(self):
1179
1212
  return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta})'
1180
1213
 
1214
+ def __hash__(self):
1215
+ return hash((self.__class__, self.alpha, self.beta))
1216
+
1181
1217
 
1182
1218
  def leaky_relu(
1183
1219
  x: jax.Array,
@@ -1277,6 +1313,9 @@ class LogTailedRelu(Surrogate):
1277
1313
  def __repr__(self):
1278
1314
  return f'{self.__class__.__name__}(alpha={self.alpha})'
1279
1315
 
1316
+ def __hash__(self):
1317
+ return hash((self.__class__, self.alpha))
1318
+
1280
1319
 
1281
1320
  def log_tailed_relu(
1282
1321
  x: jax.Array,
@@ -1368,6 +1407,9 @@ class ReluGrad(Surrogate):
1368
1407
  def __repr__(self):
1369
1408
  return f'{self.__class__.__name__}(alpha={self.alpha}, width={self.width})'
1370
1409
 
1410
+ def __hash__(self):
1411
+ return hash((self.__class__, self.alpha, self.width))
1412
+
1371
1413
 
1372
1414
  def relu_grad(
1373
1415
  x: jax.Array,
@@ -1446,6 +1488,9 @@ class GaussianGrad(Surrogate):
1446
1488
  def __repr__(self):
1447
1489
  return f'{self.__class__.__name__}(alpha={self.alpha}, sigma={self.sigma})'
1448
1490
 
1491
+ def __hash__(self):
1492
+ return hash((self.__class__, self.alpha, self.sigma))
1493
+
1449
1494
 
1450
1495
  def gaussian_grad(
1451
1496
  x: jax.Array,
@@ -1530,6 +1575,9 @@ class MultiGaussianGrad(Surrogate):
1530
1575
  def __repr__(self):
1531
1576
  return f'{self.__class__.__name__}(h={self.h}, s={self.s}, sigma={self.sigma}, scale={self.scale})'
1532
1577
 
1578
+ def __hash__(self):
1579
+ return hash((self.__class__, self.h, self.s, self.sigma, self.scale))
1580
+
1533
1581
 
1534
1582
  def multi_gaussian_grad(
1535
1583
  x: jax.Array,
@@ -1615,6 +1663,9 @@ class InvSquareGrad(Surrogate):
1615
1663
  def __repr__(self):
1616
1664
  return f'{self.__class__.__name__}(alpha={self.alpha})'
1617
1665
 
1666
+ def __hash__(self):
1667
+ return hash((self.__class__, self.alpha))
1668
+
1618
1669
 
1619
1670
  def inv_square_grad(
1620
1671
  x: jax.Array,
@@ -1685,6 +1736,9 @@ class SlayerGrad(Surrogate):
1685
1736
  def __repr__(self):
1686
1737
  return f'{self.__class__.__name__}(alpha={self.alpha})'
1687
1738
 
1739
+ def __hash__(self):
1740
+ return hash((self.__class__, self.alpha))
1741
+
1688
1742
 
1689
1743
  def slayer_grad(
1690
1744
  x: jax.Array,
@@ -19,17 +19,27 @@ This module contains the functions for the transformation of the brain data.
19
19
 
20
20
  from ._autograd import *
21
21
  from ._autograd import __all__ as _gradients_all
22
- from ._control import *
23
- from ._control import __all__ as _controls_all
22
+ from ._conditions import *
23
+ from ._conditions import __all__ as _conditions_all
24
+ from ._error_if import *
25
+ from ._error_if import __all__ as _jit_error_all
24
26
  from ._jit import *
25
27
  from ._jit import __all__ as _jit_all
26
- from ._jit_error import *
27
- from ._jit_error import __all__ as _jit_error_all
28
+ from ._loop_collect_return import *
29
+ from ._loop_collect_return import __all__ as _loops_all
30
+ from ._loop_no_collection import *
31
+ from ._loop_no_collection import __all__ as _loops_no_collection_all
28
32
  from ._make_jaxpr import *
29
33
  from ._make_jaxpr import __all__ as _make_jaxpr_all
34
+ from ._mapping import *
35
+ from ._mapping import __all__ as _mapping_all
30
36
  from ._progress_bar import *
31
37
  from ._progress_bar import __all__ as _progress_bar_all
32
38
 
33
- __all__ = _gradients_all + _jit_error_all + _controls_all + _make_jaxpr_all + _jit_all + _progress_bar_all
39
+ __all__ = (_gradients_all + _jit_error_all + _conditions_all + _loops_all +
40
+ _make_jaxpr_all + _jit_all + _progress_bar_all + _loops_no_collection_all +
41
+ _mapping_all)
34
42
 
35
- del _gradients_all, _jit_error_all, _controls_all, _make_jaxpr_all, _jit_all, _progress_bar_all
43
+ del (_gradients_all, _jit_error_all, _conditions_all, _loops_all,
44
+ _make_jaxpr_all, _jit_all, _progress_bar_all, _loops_no_collection_all,
45
+ _mapping_all)
@@ -0,0 +1,334 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from __future__ import annotations
17
+
18
+ import operator
19
+ from collections.abc import Callable, Sequence
20
+ from functools import wraps, reduce
21
+
22
+ import jax
23
+ import jax.numpy as jnp
24
+ import numpy as np
25
+
26
+ from brainstate._utils import set_module_as
27
+ from ._error_if import jit_error_if
28
+ from ._make_jaxpr import StatefulFunction, _assign_state_values
29
+
30
+ __all__ = [
31
+ 'cond', 'switch', 'ifelse',
32
+ ]
33
+
34
+
35
def _wrapped_fun(stateful_fun: StatefulFunction, states, return_states=True):
    """
    Build a branch function that threads state values explicitly.

    The returned callable takes the values of ``states`` as its first
    positional argument, writes them into the corresponding state objects,
    and then runs ``stateful_fun.jaxpr_call_auto`` on the remaining operands.

    Args:
      stateful_fun: Object exposing ``fun`` (used for ``functools.wraps``)
        and ``jaxpr_call_auto`` (invoked with the operands).
      states: Sized collection of objects with a writable ``value`` attribute.
      return_states: When True, the branch returns
        ``(tuple_of_state_values, output)``; otherwise only ``output``.

    Returns:
      The wrapped branch function ``(state_vals, *operands) -> result``.
    """

    @wraps(stateful_fun.fun)
    def wrapped_branch(state_vals, *operands):
        # One incoming value is expected for every tracked state.
        assert len(states) == len(state_vals)

        # Load the incoming values into the state objects before the call.
        for state, new_value in zip(states, state_vals):
            state.value = new_value

        result = stateful_fun.jaxpr_call_auto(*operands)
        if not return_states:
            return result

        # Read the (possibly updated) state values back out after the call.
        return tuple(state.value for state in states), result

    return wrapped_branch
47
+
48
+
49
@set_module_as('brainstate.transform')
def cond(pred, true_fun: Callable, false_fun: Callable, *operands):
    """
    Run ``true_fun`` or ``false_fun`` on ``operands`` depending on ``pred``.

    When the arguments are correctly typed, the behaviour corresponds to the
    following Python, where ``pred`` must be a scalar::

      def cond(pred, true_fun, false_fun, *operands):
        if pred:
          return true_fun(*operands)
        else:
          return false_fun(*operands)

    Unlike :func:`jax.lax.select`, ``cond`` expresses that only one of the two
    branches is executed (up to compiler rewrites and optimizations). Under
    :func:`~jax.vmap` over a batch of predicates, ``cond`` is converted to
    :func:`~jax.lax.select`.

    Both branches are traced through ``StatefulFunction``; the values of every
    state reported by either branch are passed into :func:`jax.lax.cond` as an
    extra leading operand and written back to the states afterwards.

    Args:
      pred: Scalar boolean (or integer/float, which is compared against zero)
        selecting the branch.
      true_fun: Function (A -> B) applied when ``pred`` is True.
      false_fun: Function (A -> B) applied when ``pred`` is False.
      operands: Operands (A) handed to the chosen branch. Can be a scalar,
        array, or any pytree (nested Python tuple/list/dict) thereof.

    Returns:
      Value (B) of ``true_fun(*operands)`` or ``false_fun(*operands)``,
      depending on ``pred``. Can be a scalar, array, or any pytree thereof.

    Raises:
      TypeError: If a branch is not callable, if ``pred`` is None, is not a
        scalar, or has a dtype that is neither boolean nor numeric.
    """
    if not callable(true_fun) or not callable(false_fun):
        raise TypeError("true_fun and false_fun arguments should be callable.")

    if pred is None:
        raise TypeError("cond predicate is None")

    # Python sequences are rejected outright; everything else must be 0-d.
    pred_is_sequence = isinstance(pred, Sequence)
    if pred_is_sequence or np.ndim(pred) != 0:
        if pred_is_sequence:
            detail = f"type {type(pred)}"
        else:
            detail = f"shape {np.shape(pred)}."
        raise TypeError(f"Pred must be a scalar, got {pred} of " + detail)

    # Normalize the predicate to a boolean.
    try:
        pred_dtype = jax.dtypes.result_type(pred)
    except TypeError as err:
        raise TypeError("Pred type must be either boolean or number, got {}.".format(pred)) from err
    dtype_kind = pred_dtype.kind
    if dtype_kind in 'iuf':
        # Numeric predicates are truthy when non-zero.
        pred = pred != 0
    elif dtype_kind != 'b':
        raise TypeError("Pred type must be either boolean or number, got {}.".format(pred_dtype))

    # With JIT disabled and a concrete predicate, just call the branch directly.
    if jax.config.jax_disable_jit and isinstance(jax.core.get_aval(pred), jax.core.ConcreteArray):
        chosen = true_fun if pred else false_fun
        return chosen(*operands)

    # Trace both branches so the states they touch are known.
    traced_true = StatefulFunction(true_fun).make_jaxpr(*operands)
    traced_false = StatefulFunction(false_fun).make_jaxpr(*operands)

    # Union of the states used by either branch; both wrappers share it so
    # that the two branches have matching inputs and outputs.
    all_states = tuple(set(traced_true.get_states() + traced_false.get_states()))
    true_branch = _wrapped_fun(traced_true, all_states)
    false_branch = _wrapped_fun(traced_false, all_states)

    # Current state values travel as the first operand.
    full_operands = ([st.value for st in all_states],) + operands

    state_vals, out = jax.lax.cond(pred, true_branch, false_branch, *full_operands)
    _assign_state_values(all_states, state_vals)
    return out
127
+
128
+ # ops, ops_tree = jax.tree.flatten(operands)
129
+ # linear_ops = [False] * len(ops)
130
+ # ops_avals = tuple(jax.util.safe_map(_abstractify, ops))
131
+ #
132
+ # # true and false jaxprs
133
+ # jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
134
+ # (true_fun, false_fun), ops_tree, ops_avals, 'cond')
135
+ # if any(isinstance(op_aval, state.AbstractRef) for op_aval in ops_avals):
136
+ # raise ValueError("Cannot pass `Ref`s into `cond`.")
137
+ # true_jaxpr, false_jaxpr = jaxprs
138
+ # out_tree, false_out_tree = out_trees
139
+ # if any(isinstance(out_aval, state.AbstractRef) for out_aval in true_jaxpr.out_avals + false_jaxpr.out_avals):
140
+ # raise ValueError("Cannot return `Ref`s from `cond`.")
141
+ #
142
+ # _check_tree_and_avals("true_fun and false_fun output",
143
+ # out_tree, true_jaxpr.out_avals,
144
+ # false_out_tree, false_jaxpr.out_avals)
145
+ # joined_effects = jax.core.join_effects(true_jaxpr.effects, false_jaxpr.effects)
146
+ # disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects)
147
+ # if disallowed_effects:
148
+ # raise NotImplementedError(f'Effects not supported in `cond`: {disallowed_effects}')
149
+ #
150
+ # # replace jaxpr effects
151
+ # index = jax.lax.convert_element_type(pred, np.int32)
152
+ # if joined_effects:
153
+ # # Raise index in case of effects to allow data-dependence-based discharging
154
+ # # of those effects (even if they don't have an explicit data dependence).
155
+ # index = jax.core.raise_as_much_as_possible(index)
156
+ # false_jaxpr = _replace_jaxpr_effects(false_jaxpr, joined_effects)
157
+ # true_jaxpr = _replace_jaxpr_effects(true_jaxpr, joined_effects)
158
+ #
159
+ # # bind
160
+ # linear = [False] * len(consts) + linear_ops
161
+ # cond_outs = jax.lax.cond_p.bind(index, *consts, *ops, branches=(false_jaxpr, true_jaxpr), linear=tuple(linear))
162
+ #
163
+ # # outputs
164
+ # st_vals, out = jax.tree.unflatten(out_tree, cond_outs)
165
+ # for st, val in zip(all_states, st_vals):
166
+ # st.value = val
167
+ # return out
168
+
169
+
170
@set_module_as('brainstate.transform')
def switch(index, branches: Sequence[Callable], *operands):
    """
    Apply exactly one of ``branches`` given by ``index``.

    If ``index`` is out of bounds, it is clamped to within bounds.

    Has the semantics of the following Python::

      def switch(index, branches, *operands):
        index = clamp(0, index, len(branches) - 1)
        return branches[index](*operands)

    Internally this wraps XLA's `Conditional
    <https://www.tensorflow.org/xla/operation_semantics#conditional>`_
    operator. However, when transformed with :func:`~jax.vmap` to operate over a
    batch of predicates, ``cond`` is converted to :func:`~jax.lax.select`.

    Every branch is traced through ``StatefulFunction``; the values of all
    states reported by any branch are passed into :func:`jax.lax.switch` as an
    extra leading operand and written back to the states afterwards.

    Args:
      index: Integer scalar type, indicating which branch function to apply.
      branches: Sequence of functions (A -> B) to be applied based on ``index``.
      operands: Operands (A) input to whichever branch is applied.

    Returns:
      Value (B) of ``branch(*operands)`` for the branch that was selected based
      on ``index``.

    Raises:
      TypeError: If any branch is not callable, or if ``index`` is not a
        scalar of integer dtype.
      ValueError: If ``branches`` is empty.
    """
    # Materialize first so that a one-shot iterable is not exhausted by the
    # callable check below and then misreported as an empty sequence.
    branches = tuple(branches)
    if not all(callable(branch) for branch in branches):
        raise TypeError("branches argument should be a sequence of callables.")

    # check index
    if len(np.shape(index)) != 0:
        raise TypeError(f"Branch index must be scalar, got {index} of shape {np.shape(index)}.")
    try:
        index_dtype = jax.dtypes.result_type(index)
    except TypeError as err:
        msg = f"Index type must be an integer, got {index}."
        raise TypeError(msg) from err
    if index_dtype.kind not in 'iu':
        raise TypeError(f"Index type must be an integer, got {index} as {index_dtype}")

    # trivial branch counts
    if len(branches) == 0:
        raise ValueError("Empty branch sequence")
    elif len(branches) == 1:
        return branches[0](*operands)

    # clamp the index into [0, len(branches) - 1]
    index = jax.lax.convert_element_type(index, np.int32)
    lo = np.array(0, np.int32)
    hi = np.array(len(branches) - 1, np.int32)
    index = jax.lax.clamp(lo, index, hi)

    # With JIT disabled and a concrete index, call the branch directly.
    # ``jax.core.get_aval`` is the same accessor ``cond`` uses; the previous
    # ``jax.core.core.get_aval`` spelling failed with AttributeError here.
    if jax.config.jax_disable_jit and isinstance(jax.core.get_aval(index), jax.core.ConcreteArray):
        return branches[int(index)](*operands)

    # Trace every branch so the states it touches are known.
    wrapped_branches = [StatefulFunction(branch) for branch in branches]
    for wrapped_branch in wrapped_branches:
        wrapped_branch.make_jaxpr(*operands)

    # Union of the states used by any branch; all wrappers share it so that
    # every branch has matching inputs and outputs.
    all_states = tuple(set(reduce(operator.add, [wrapped_branch.get_states() for wrapped_branch in wrapped_branches])))
    branches = tuple(_wrapped_fun(wrapped_branch, all_states) for wrapped_branch in wrapped_branches)

    # Current state values travel as the first operand.
    operands = ([st.value for st in all_states],) + operands

    # switch
    state_vals, out = jax.lax.switch(index, branches, *operands)
    _assign_state_values(all_states, state_vals)
    return out
245
+
246
+ # ops, ops_tree = jax.tree.flatten(operands)
247
+ # ops_avals = tuple(jax.util.safe_map(_abstractify, ops))
248
+ #
249
+ # # true jaxprs
250
+ # jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
251
+ # branches, ops_tree, ops_avals, primitive_name='switch')
252
+ # for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])):
253
+ # _check_tree_and_avals(f"branch 0 and {i + 1} outputs",
254
+ # out_trees[0], jaxprs[0].out_avals,
255
+ # out_tree, jaxpr.out_avals)
256
+ # joined_effects = jax.core.join_effects(*(jaxpr.effects for jaxpr in jaxprs))
257
+ # disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects)
258
+ # if disallowed_effects:
259
+ # raise NotImplementedError(f'Effects not supported in `switch`: {disallowed_effects}')
260
+ # if joined_effects:
261
+ # # Raise index in case of effects to allow data-dependence-based discharging
262
+ # # of those effects (even if they don't have an explicit data dependence).
263
+ # index = jax.core.raise_as_much_as_possible(index)
264
+ #
265
+ # # bind
266
+ # linear = (False,) * (len(consts) + len(ops))
267
+ # cond_outs = jax.lax.cond_p.bind(index, *consts, *ops, branches=tuple(jaxprs), linear=linear)
268
+ #
269
+ # # outputs
270
+ # st_vals, out = jax.tree.unflatten(out_trees[0], cond_outs)
271
+ # for st, val in zip(all_states, st_vals):
272
+ # st.value = val
273
+ # return out
274
+
275
+
276
@set_module_as('brainstate.transform')
def ifelse(conditions, branches, *operands, check_cond: bool = True):
    """
    Select one branch out of several, based on a list of conditions.

    ``conditions[i]`` guards ``branches[i]``; the branch whose condition is
    True is called as ``branch(*operands)`` through :func:`switch`.

    Examples
    --------

    The conditions below are mutually exclusive, so exactly one holds for
    any ``a``:

    >>> import brainstate as bst
    >>> def sign(a):
    >>>    return bst.transform.ifelse(conditions=[a < 0, a == 0, a > 0],
    >>>                                branches=[lambda: -1,
    >>>                                          lambda: 0,
    >>>                                          lambda: 1])
    >>> sign(3)   # third branch is selected
    >>> sign(0)   # second branch is selected

    Parameters
    ----------
    conditions: sequence of bool, Array
      The boolean conditions, one per branch. With ``check_cond=True``
      exactly one of them must be True.
    branches: sequence of callable
      The branch functions. ``len(branches)`` must equal
      ``len(conditions)``. Each branch is called with ``*operands``.
      A single branch is called directly without consulting ``conditions``.
    *operands: optional, Any
      The operands passed to the selected branch.
    check_cond: bool
      Whether to raise an error (via ``jit_error_if``) unless exactly one
      condition is True. Default is True. When disabled, the first True
      condition wins and the last branch is used if none is True.

    Returns
    -------
    res: Any
      The result of the selected branch.

    Raises
    ------
    TypeError
      If any element of ``branches`` is not callable.
    ValueError
      If ``branches`` is empty, or if the number of conditions differs from
      the number of branches.
    """
    # Materialize first so that a one-shot iterable is not exhausted by the
    # callable check below and then misreported as an empty sequence.
    branches = tuple(branches)
    if not all(callable(branch) for branch in branches):
        raise TypeError("branches argument should be a sequence of callables.")

    # trivial branch counts
    if len(branches) == 0:
        raise ValueError("Empty branch sequence")
    elif len(branches) == 1:
        return branches[0](*operands)
    if len(conditions) != len(branches):
        raise ValueError("The number of conditions should be equal to the number of branches.")

    # format index
    conditions = jnp.asarray(conditions, np.int32)
    if check_cond:
        jit_error_if(jnp.sum(conditions) != 1, "Only one condition can be True. But got {}.", err_arg=conditions)
    # Index of the first non-zero condition; falls back to the last branch
    # when no condition is set.
    index = jnp.where(conditions, size=1, fill_value=len(conditions) - 1)[0][0]
    return switch(index, branches, *operands)