brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +169 -58
  2. brainstate/_compatible_import.py +340 -148
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +45 -55
  7. brainstate/_state.py +1652 -1605
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -563
  11. brainstate/environ_test.py +1223 -62
  12. brainstate/graph/__init__.py +22 -29
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1433 -365
  18. brainstate/mixin_test.py +1017 -77
  19. brainstate/nn/__init__.py +137 -135
  20. brainstate/nn/_activations.py +1100 -808
  21. brainstate/nn/_activations_test.py +354 -331
  22. brainstate/nn/_collective_ops.py +633 -514
  23. brainstate/nn/_collective_ops_test.py +774 -43
  24. brainstate/nn/_common.py +226 -178
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +2010 -501
  27. brainstate/nn/_conv_test.py +849 -238
  28. brainstate/nn/_delay.py +575 -588
  29. brainstate/nn/_delay_test.py +243 -238
  30. brainstate/nn/_dropout.py +618 -426
  31. brainstate/nn/_dropout_test.py +477 -100
  32. brainstate/nn/_dynamics.py +1267 -1343
  33. brainstate/nn/_dynamics_test.py +67 -78
  34. brainstate/nn/_elementwise.py +1298 -1119
  35. brainstate/nn/_elementwise_test.py +830 -169
  36. brainstate/nn/_embedding.py +408 -58
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
  42. brainstate/nn/_exp_euler.py +254 -92
  43. brainstate/nn/_exp_euler_test.py +377 -35
  44. brainstate/nn/_linear.py +744 -424
  45. brainstate/nn/_linear_test.py +475 -107
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +384 -377
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -975
  51. brainstate/nn/_normalizations_test.py +699 -73
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +2239 -1177
  55. brainstate/nn/_poolings_test.py +953 -217
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +216 -89
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +809 -553
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
  62. brainstate/random/__init__.py +270 -24
  63. brainstate/random/_rand_funs.py +3938 -3616
  64. brainstate/random/_rand_funs_test.py +640 -567
  65. brainstate/random/_rand_seed.py +675 -210
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1409
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
  72. brainstate/{augment → transform}/_autograd.py +1025 -778
  73. brainstate/{augment → transform}/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +220 -220
  76. brainstate/{compile → transform}/_error_if.py +94 -92
  77. brainstate/{compile → transform}/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +38 -38
  80. brainstate/{compile → transform}/_jit.py +399 -346
  81. brainstate/{compile → transform}/_jit_test.py +143 -143
  82. brainstate/{compile → transform}/_loop_collect_return.py +675 -536
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
  84. brainstate/{compile → transform}/_loop_no_collection.py +283 -184
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +255 -202
  91. brainstate/{augment → transform}/_random.py +171 -151
  92. brainstate/{compile → transform}/_unvmap.py +256 -159
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +837 -304
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +27 -50
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +945 -469
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +910 -523
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/augment/__init__.py +0 -30
  111. brainstate/augment/_eval_shape.py +0 -99
  112. brainstate/augment/_mapping.py +0 -1060
  113. brainstate/augment/_mapping_test.py +0 -597
  114. brainstate/compile/__init__.py +0 -38
  115. brainstate/compile/_ad_checkpoint.py +0 -204
  116. brainstate/compile/_conditions.py +0 -256
  117. brainstate/compile/_make_jaxpr.py +0 -888
  118. brainstate/compile/_make_jaxpr_test.py +0 -156
  119. brainstate/compile/_util.py +0 -147
  120. brainstate/functional/__init__.py +0 -27
  121. brainstate/graph/_graph_node.py +0 -244
  122. brainstate/graph/_graph_node_test.py +0 -73
  123. brainstate/graph/_graph_operation_test.py +0 -563
  124. brainstate/init/__init__.py +0 -26
  125. brainstate/init/_base.py +0 -52
  126. brainstate/init/_generic.py +0 -244
  127. brainstate/init/_regular_inits.py +0 -105
  128. brainstate/init/_regular_inits_test.py +0 -50
  129. brainstate/nn/_inputs.py +0 -608
  130. brainstate/nn/_ltp.py +0 -28
  131. brainstate/nn/_neuron.py +0 -705
  132. brainstate/nn/_neuron_test.py +0 -161
  133. brainstate/nn/_others.py +0 -46
  134. brainstate/nn/_projection.py +0 -486
  135. brainstate/nn/_rate_rnns_test.py +0 -63
  136. brainstate/nn/_readout.py +0 -209
  137. brainstate/nn/_readout_test.py +0 -53
  138. brainstate/nn/_stp.py +0 -236
  139. brainstate/nn/_synapse.py +0 -505
  140. brainstate/nn/_synapse_test.py +0 -131
  141. brainstate/nn/_synaptic_projection.py +0 -423
  142. brainstate/nn/_synouts.py +0 -162
  143. brainstate/nn/_synouts_test.py +0 -57
  144. brainstate/nn/metrics.py +0 -388
  145. brainstate/optim/__init__.py +0 -38
  146. brainstate/optim/_base.py +0 -64
  147. brainstate/optim/_lr_scheduler.py +0 -448
  148. brainstate/optim/_lr_scheduler_test.py +0 -50
  149. brainstate/optim/_optax_optimizer.py +0 -152
  150. brainstate/optim/_optax_optimizer_test.py +0 -53
  151. brainstate/optim/_sgd_optimizer.py +0 -1104
  152. brainstate/random/_random_for_unit.py +0 -52
  153. brainstate/surrogate.py +0 -1957
  154. brainstate/transform.py +0 -23
  155. brainstate/util/caller.py +0 -98
  156. brainstate/util/others.py +0 -540
  157. brainstate/util/pretty_pytree.py +0 -945
  158. brainstate/util/pretty_pytree_test.py +0 -159
  159. brainstate/util/pretty_table.py +0 -2954
  160. brainstate/util/scaling.py +0 -258
  161. brainstate-0.1.10.dist-info/RECORD +0 -130
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
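
The headline change in 0.2.1 is structural: the `compile` and `augment` subpackages are consolidated into a single `transform` package, the `init` helpers are largely absorbed into `brainstate/nn/init.py`, and the `functional`, `optim`, and `surrogate` modules are dropped from the wheel. Two of the deleted files are expanded below: `brainstate/compile/_ad_checkpoint.py` (entry 115) and `brainstate/compile/_conditions.py` (entry 116). For callers, here is a minimal forward-compatible import sketch, assuming the names moved from `brainstate.compile` and `brainstate.augment` are re-exported from `brainstate.transform` in 0.2.x (this diff shows the moved files but not the re-export list):

    # Hedged migration sketch. The names jit, cond, switch, and grad appear in
    # the moved files above; any other re-exports are assumptions.
    try:
        # 0.2.x layout: brainstate/transform/{_jit,_conditions,_autograd}.py
        from brainstate.transform import jit, cond, switch, grad
    except ImportError:
        # 0.1.x layout: separate compile/ and augment/ subpackages
        from brainstate.compile import jit, cond, switch
        from brainstate.augment import grad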
brainstate/compile/_ad_checkpoint.py +0 -204
@@ -1,204 +0,0 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- import functools
- from typing import Callable, Tuple, Union
-
- import jax
-
- from brainstate.typing import Missing
- from ._make_jaxpr import StatefulFunction, _ensure_index_tuple
- from ._util import write_back_state_values
-
- __all__ = [
-     'checkpoint',
-     'remat'
- ]
-
-
- def checkpoint(
-     fun: Callable = Missing(),
-     *,
-     prevent_cse: bool = True,
-     policy: Callable[..., bool] | None = None,
-     static_argnums: int | Tuple[int, ...] = (),
- ) -> Union[Callable, Callable[[Callable], Callable]]:
-     """Make ``fun`` recompute internal linearization points when differentiated.
-
-     The :func:`jax.checkpoint` decorator, aliased to :func:`jax.remat`, provides a
-     way to trade off computation time and memory cost in the context of automatic
-     differentiation, especially with reverse-mode autodiff like :func:`jax.grad`
-     and :func:`jax.vjp` but also with :func:`jax.linearize`.
-
-     When differentiating a function in reverse-mode, by default all the
-     linearization points (e.g. inputs to elementwise nonlinear primitive
-     operations) are stored when evaluating the forward pass so that they can be
-     reused on the backward pass. This evaluation strategy can lead to a high
-     memory cost, or even to poor performance on hardware accelerators where memory
-     access is much more expensive than FLOPs.
-
-     An alternative evaluation strategy is for some of the linearization points to
-     be recomputed (i.e. rematerialized) rather than stored. This approach can
-     reduce memory usage at the cost of increased computation.
-
-     This function decorator produces a new version of ``fun`` which follows
-     the rematerialization strategy rather than the default store-everything
-     strategy. That is, it returns a new version of ``fun`` which, when
-     differentiated, doesn't store any of its intermediate linearization points.
-     Instead, these linearization points are recomputed from the function's saved
-     inputs.
-
-     See the examples below.
-
-     Args:
-         fun: Function for which the autodiff evaluation strategy is to be changed
-             from the default of storing all intermediate linearization points to
-             recomputing them. Its arguments and return value should be arrays,
-             scalars, or (nested) standard Python containers (tuple/list/dict) thereof.
-         prevent_cse: Optional, boolean keyword-only argument indicating whether to
-             prevent common subexpression elimination (CSE) optimizations in the HLO
-             generated from differentiation. This CSE prevention has costs because it
-             can foil other optimizations, and because it can incur high overheads on
-             some backends, especially GPU. The default is True because otherwise,
-             under a :func:`~jax.jit` or :func:`~jax.pmap`, CSE can defeat the purpose
-             of this decorator.
-             But in some settings, like when used inside a :func:`~jax.lax.scan`, this
-             CSE prevention mechanism is unnecessary, in which case ``prevent_cse`` can
-             be set to False.
-         static_argnums: Optional, int or sequence of ints, a keyword-only argument
-             indicating which argument values to specialize on for tracing and
-             caching purposes. Specifying arguments as static can avoid
-             ConcretizationTypeErrors when tracing, but at the cost of more retracing
-             overheads. See the example below.
-         policy: Optional, callable keyword-only argument. It should be one of the
-             attributes of ``jax.checkpoint_policies``. The callable takes as input a
-             type-level specification of a first-order primitive application and
-             returns a boolean indicating whether the corresponding output value(s) can
-             be saved as residuals (or instead must be recomputed in the (co)tangent
-             computation if needed).
-
-     Returns:
-         A function (callable) with the same input/output behavior as ``fun`` but
-         which, when differentiated using e.g. :func:`jax.grad`, :func:`jax.vjp`, or
-         :func:`jax.linearize`, recomputes rather than stores intermediate
-         linearization points, thus potentially saving memory at the cost of extra
-         computation.
-
-     Here is a simple example:
-
-     >>> import jax
-     >>> import jax.numpy as jnp
-
-     >>> @jax.checkpoint
-     ... def g(x):
-     ...     y = jnp.sin(x)
-     ...     z = jnp.sin(y)
-     ...     return z
-     ...
-     >>> jax.value_and_grad(g)(2.0)
-     (Array(0.78907233, dtype=float32, weak_type=True), Array(-0.2556391, dtype=float32, weak_type=True))
-
-     Here, the same value is produced whether or not the :func:`jax.checkpoint`
-     decorator is present. When the decorator is not present, the values
-     ``jnp.cos(2.0)`` and ``jnp.cos(jnp.sin(2.0))`` are computed on the forward
-     pass and are stored for use in the backward pass, because they are needed
-     on the backward pass and depend only on the primal inputs. When using
-     :func:`jax.checkpoint`, the forward pass will compute only the primal outputs
-     and only the primal inputs (``2.0``) will be stored for the backward pass.
-     At that time, the value ``jnp.sin(2.0)`` is recomputed, along with the values
-     ``jnp.cos(2.0)`` and ``jnp.cos(jnp.sin(2.0))``.
-
-     While :func:`jax.checkpoint` controls what values are stored from the
-     forward pass to be used on the backward pass, the total amount of memory
-     required to evaluate a function or its VJP depends on many additional internal
-     details of that function. Those details include which numerical primitives are
-     used, how they're composed, where jit and control flow primitives like scan
-     are used, and other factors.
-
-     The :func:`jax.checkpoint` decorator can be applied recursively to express
-     sophisticated autodiff rematerialization strategies. For example:
-
-     >>> def recursive_checkpoint(funs):
-     ...     if len(funs) == 1:
-     ...         return funs[0]
-     ...     elif len(funs) == 2:
-     ...         f1, f2 = funs
-     ...         return lambda x: f1(f2(x))
-     ...     else:
-     ...         f1 = recursive_checkpoint(funs[:len(funs)//2])
-     ...         f2 = recursive_checkpoint(funs[len(funs)//2:])
-     ...         return lambda x: f1(jax.checkpoint(f2)(x))
-     ...
-
-     If ``fun`` involves Python control flow that depends on argument values,
-     it may be necessary to use the ``static_argnums`` parameter. For example,
-     consider a boolean flag argument::
-
-         from functools import partial
-
-         @partial(jax.checkpoint, static_argnums=(1,))
-         def foo(x, is_training):
-             if is_training:
-                 ...
-             else:
-                 ...
-
-     Here, the use of ``static_argnums`` allows the ``if`` statement's condition
-     to depend on the value of ``is_training``. The cost of using
-     ``static_argnums`` is that it introduces re-tracing overheads across calls:
-     in the example, ``foo`` is re-traced every time it is called with a new value
-     of ``is_training``. In some situations, ``jax.ensure_compile_time_eval``
-     is needed as well::
-
-         @partial(jax.checkpoint, static_argnums=(1,))
-         def foo(x, y):
-             with jax.ensure_compile_time_eval():
-                 y_pos = y > 0
-             if y_pos:
-                 ...
-             else:
-                 ...
-
-     As an alternative to using ``static_argnums`` (and
-     ``jax.ensure_compile_time_eval``), it may be easier to compute some values
-     outside the :func:`jax.checkpoint`-decorated function and then close over them.
-     """
-     if isinstance(fun, Missing):
-         return lambda f: checkpoint(f, prevent_cse=prevent_cse, policy=policy, static_argnums=static_argnums)
-
-     static_argnums = _ensure_index_tuple(tuple() if static_argnums is None else static_argnums)
-     fun = StatefulFunction(fun, static_argnums=static_argnums, name='checkpoint')
-     checkpointed_fun = jax.checkpoint(
-         fun.jaxpr_call,
-         prevent_cse=prevent_cse,
-         policy=policy,
-         static_argnums=tuple(i + 1 for i in static_argnums)
-     )
-
-     @functools.wraps(fun.fun)
-     def remat_fun(*args, **params):
-         # compile the function and get the state trace
-         state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
-         read_state_vals = state_trace.get_read_state_values()
-         # call the checkpointed function
-         write_state_vals, outs = checkpointed_fun(state_trace.get_state_values(), *args, **params)
-         # write the state values back to the states
-         write_back_state_values(state_trace, read_state_vals, write_state_vals)
-         return outs
-
-     return remat_fun
-
-
- remat = checkpoint
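
The module above wrapped :func:`jax.checkpoint` so that reads and writes of brainstate ``State`` objects survive rematerialization: the wrapper traces the function with ``StatefulFunction``, threads the state values through the checkpointed call, and writes them back afterwards. A minimal usage sketch of the decorator as defined above; the 0.2.x import path is inferred from the new `brainstate/transform/_ad_checkpoint.py` in the file list, and the write-back behavior is as implemented in ``remat_fun``:

    import jax.numpy as jnp
    import brainstate

    s = brainstate.State(jnp.zeros(()))  # a State written inside the function

    @brainstate.transform.checkpoint  # 0.1.x: brainstate.compile.checkpoint (alias: remat)
    def g(x):
        y = jnp.sin(x)
        s.value = y          # state write, captured by the StatefulFunction trace
        return jnp.sin(y)

    out = g(2.0)             # linearization points are recomputed, not stored
    print(out, s.value)      # sin(sin(2.0)) and sin(2.0), written back after the call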
brainstate/compile/_conditions.py +0 -256
@@ -1,256 +0,0 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- from collections.abc import Callable, Sequence
-
- import jax
- import jax.numpy as jnp
- import numpy as np
-
- from brainstate._compatible_import import to_concrete_aval, Tracer
- from brainstate._utils import set_module_as
- from ._error_if import jit_error_if
- from ._make_jaxpr import StatefulFunction
- from ._util import wrap_single_fun_in_multi_branches, write_back_state_values
-
- __all__ = [
-     'cond', 'switch', 'ifelse',
- ]
-
-
- @set_module_as('brainstate.compile')
- def cond(pred, true_fun: Callable, false_fun: Callable, *operands):
-     """
-     Conditionally apply ``true_fun`` or ``false_fun``.
-
-     Provided the arguments are correctly typed, ``cond()`` has equivalent
-     semantics to this Python implementation, where ``pred`` must be a
-     scalar type::
-
-         def cond(pred, true_fun, false_fun, *operands):
-             if pred:
-                 return true_fun(*operands)
-             else:
-                 return false_fun(*operands)
-
-
-     In contrast with :func:`jax.lax.select`, using ``cond`` indicates that only one of
-     the two branches is executed (up to compiler rewrites and optimizations).
-     However, when transformed with :func:`~jax.vmap` to operate over a batch of
-     predicates, ``cond`` is converted to :func:`~jax.lax.select`.
-
-     Args:
-         pred: Boolean scalar type, indicating which branch function to apply.
-         true_fun: Function (A -> B), to be applied if ``pred`` is True.
-         false_fun: Function (A -> B), to be applied if ``pred`` is False.
-         operands: Operands (A) input to either branch depending on ``pred``. The
-             type can be a scalar, array, or any pytree (nested Python tuple/list/dict)
-             thereof.
-
-     Returns:
-         Value (B) of either ``true_fun(*operands)`` or ``false_fun(*operands)``,
-         depending on the value of ``pred``. The type can be a scalar, array, or any
-         pytree (nested Python tuple/list/dict) thereof.
-     """
-     if not (callable(true_fun) and callable(false_fun)):
-         raise TypeError("true_fun and false_fun arguments should be callable.")
-
-     if pred is None:
-         raise TypeError("cond predicate is None")
-     if isinstance(pred, Sequence) or np.ndim(pred) != 0:
-         raise TypeError(f"Pred must be a scalar, got {pred} of " +
-                         (f"type {type(pred)}" if isinstance(pred, Sequence)
-                          else f"shape {np.shape(pred)}."))
-
-     # check pred
-     try:
-         pred_dtype = jax.dtypes.result_type(pred)
-     except TypeError as err:
-         raise TypeError("Pred type must be either boolean or number, got {}.".format(pred)) from err
-     if pred_dtype.kind != 'b':
-         if pred_dtype.kind in 'iuf':
-             pred = pred != 0
-         else:
-             raise TypeError("Pred type must be either boolean or number, got {}.".format(pred_dtype))
-
-     # not jit
-     if jax.config.jax_disable_jit and not isinstance(to_concrete_aval(pred), Tracer):
-         if pred:
-             return true_fun(*operands)
-         else:
-             return false_fun(*operands)
-
-     # evaluate jaxpr
-     stateful_true = StatefulFunction(true_fun, name='cond:true').make_jaxpr(*operands)
-     stateful_false = StatefulFunction(false_fun, name='cond:false').make_jaxpr(*operands)
-
-     # state trace and state values
-     state_trace = stateful_true.get_state_trace() + stateful_false.get_state_trace()
-     read_state_vals = state_trace.get_read_state_values(True)
-     write_state_vals = state_trace.get_write_state_values(True)
-
-     # wrap the functions
-     true_fun = wrap_single_fun_in_multi_branches(stateful_true, state_trace, read_state_vals, True)
-     false_fun = wrap_single_fun_in_multi_branches(stateful_false, state_trace, read_state_vals, True)
-
-     # cond
-     write_state_vals, out = jax.lax.cond(pred, true_fun, false_fun, write_state_vals, *operands)
-
-     # assign the written state values and restore the read state values
-     write_back_state_values(state_trace, read_state_vals, write_state_vals)
-     return out
-
-
- @set_module_as('brainstate.compile')
- def switch(index, branches: Sequence[Callable], *operands):
-     """
-     Apply exactly one of ``branches`` given by ``index``.
-
-     If ``index`` is out of bounds, it is clamped to within bounds.
-
-     Has the semantics of the following Python::
-
-         def switch(index, branches, *operands):
-             index = clamp(0, index, len(branches) - 1)
-             return branches[index](*operands)
-
-     Internally this wraps XLA's `Conditional
-     <https://www.tensorflow.org/xla/operation_semantics#conditional>`_
-     operator. However, when transformed with :func:`~jax.vmap` to operate over a
-     batch of predicates, ``switch`` is converted to :func:`~jax.lax.select`.
-
-     Args:
-         index: Integer scalar type, indicating which branch function to apply.
-         branches: Sequence of functions (A -> B) to be applied based on ``index``.
-         operands: Operands (A) input to whichever branch is applied.
-
-     Returns:
-         Value (B) of ``branch(*operands)`` for the branch that was selected based
-         on ``index``.
-     """
-     # check branches
-     if not all(callable(branch) for branch in branches):
-         raise TypeError("branches argument should be a sequence of callables.")
-
-     # check index
-     if len(np.shape(index)) != 0:
-         raise TypeError(f"Branch index must be scalar, got {index} of shape {np.shape(index)}.")
-     try:
-         index_dtype = jax.dtypes.result_type(index)
-     except TypeError as err:
-         msg = f"Index type must be an integer, got {index}."
-         raise TypeError(msg) from err
-     if index_dtype.kind not in 'iu':
-         raise TypeError(f"Index type must be an integer, got {index} as {index_dtype}")
-
-     # format branches
-     branches = tuple(branches)
-     if len(branches) == 0:
-         raise ValueError("Empty branch sequence")
-     elif len(branches) == 1:
-         return branches[0](*operands)
-
-     # format index
-     index = jax.lax.convert_element_type(index, np.int32)
-     lo = np.array(0, np.int32)
-     hi = np.array(len(branches) - 1, np.int32)
-     index = jax.lax.clamp(lo, index, hi)
-
-     # not jit
-     if jax.config.jax_disable_jit and isinstance(jax.core.get_aval(index), jax.core.ConcreteArray):
-         return branches[int(index)](*operands)
-
-     # evaluate jaxpr
-     wrapped_branches = [StatefulFunction(branch, name='switch') for branch in branches]
-     for wrapped_branch in wrapped_branches:
-         wrapped_branch.make_jaxpr(*operands)
-
-     # wrap the functions
-     state_trace = wrapped_branches[0].get_state_trace() + wrapped_branches[1].get_state_trace()
-     state_trace.merge(*[wrapped_branch.get_state_trace() for wrapped_branch in wrapped_branches[2:]])
-     read_state_vals = state_trace.get_read_state_values(True)
-     write_state_vals = state_trace.get_write_state_values(True)
-     branches = [
-         wrap_single_fun_in_multi_branches(wrapped_branch, state_trace, read_state_vals, True)
-         for wrapped_branch in wrapped_branches
-     ]
-
-     # switch
-     write_state_vals, out = jax.lax.switch(index, branches, write_state_vals, *operands)
-
-     # write back state values or restore them
-     write_back_state_values(state_trace, read_state_vals, write_state_vals)
-     return out
-
-
- @set_module_as('brainstate.compile')
- def ifelse(conditions, branches, *operands, check_cond: bool = True):
-     """
-     ``If-else`` control flow that looks like native Pythonic programming.
-
-     Examples
-     --------
-
-     >>> import brainstate
-     >>> def f(a):
-     ...     return brainstate.compile.ifelse(conditions=[a > 10, a > 5, a > 2, a > 0],
-     ...                                      branches=[lambda: 1,
-     ...                                                lambda: 2,
-     ...                                                lambda: 3,
-     ...                                                lambda: 4,
-     ...                                                lambda: 5])
-     >>> f(1)
-     4
-     >>> f(0)
-     5
-
-     Parameters
-     ----------
-     conditions: bool, sequence of bool, Array
-         The boolean conditions.
-     branches: Any
-         The branches, which must contain at least two elements. Elements can be
-         functions, arrays, or numbers. The number of ``branches`` and
-         ``conditions`` has the relationship of ``len(branches) == len(conditions) + 1``.
-         Each branch receives the ``operands`` as arguments.
-     *operands: optional, Any
-         The operands for each branch.
-     check_cond: bool
-         Whether to check the conditions. Default is True.
-
-     Returns
-     -------
-     res: Any
-         The results of the control flow.
-     """
-     # check branches
-     if not all(callable(branch) for branch in branches):
-         raise TypeError("branches argument should be a sequence of callables.")
-
-     # format branches
-     branches = tuple(branches)
-     if len(branches) == 0:
-         raise ValueError("Empty branch sequence")
-     elif len(branches) == 1:
-         return branches[0](*operands)
-     if len(conditions) != len(branches):
-         raise ValueError("The number of conditions should be equal to the number of branches.")
-
-     # format index
-     conditions = jnp.asarray(conditions, np.int32)
-     if check_cond:
-         jit_error_if(jnp.sum(conditions) != 1, "Only one condition can be True. But got {}.", err_arg=conditions)
-     index = jnp.where(conditions, size=1, fill_value=len(conditions) - 1)[0][0]
-     return switch(index, branches, *operands)
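
For reference, a short usage sketch of the three branching helpers defined above; their 0.2.x home is inferred from the new `brainstate/transform/_conditions.py` in the file list, and the calling conventions follow the docstrings and checks in the deleted source:

    import jax.numpy as jnp
    import brainstate

    x = jnp.float32(3.0)

    # cond: run exactly one of two branches on a scalar boolean predicate
    y = brainstate.transform.cond(x > 0, lambda v: v + 1.0, lambda v: v - 1.0, x)   # 4.0

    # switch: run one of N branches picked by an integer index (clamped into range)
    z = brainstate.transform.switch(2, [lambda v: v, lambda v: 2.0 * v, lambda v: 3.0 * v], x)  # 9.0

    # ifelse: one-hot conditions select a branch; note that the length check in
    # the source above requires len(conditions) == len(branches)
    w = brainstate.transform.ifelse([x > 10.0, x <= 10.0], [lambda: 1.0, lambda: 2.0])  # 2.0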