brainstate 0.0.2.post20241009__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +608 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/{nn/event → event}/__init__.py +6 -6
- brainstate/event/_csr.py +308 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +271 -0
- brainstate/event/_fixed_probability_test.py +128 -0
- brainstate/event/_linear.py +219 -0
- brainstate/event/_linear_test.py +112 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +32 -0
- brainstate/nn/_interaction/_connections.py +726 -0
- brainstate/nn/_interaction/_connections_test.py +254 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +103 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1360 -1318
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/util/_error.py +28 -0
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
- brainstate-0.1.0.dist-info/RECORD +135 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241009.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,259 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from __future__ import annotations
|
17
|
+
|
18
|
+
from collections.abc import Callable, Sequence
|
19
|
+
|
20
|
+
import jax
|
21
|
+
import jax.numpy as jnp
|
22
|
+
import numpy as np
|
23
|
+
|
24
|
+
from brainstate._utils import set_module_as
|
25
|
+
from ._error_if import jit_error_if
|
26
|
+
from ._make_jaxpr import StatefulFunction
|
27
|
+
from ._util import wrap_single_fun_in_multi_branches, write_back_state_values
|
28
|
+
|
29
|
+
__all__ = [
|
30
|
+
'cond', 'switch', 'ifelse',
|
31
|
+
]
|
32
|
+
|
33
|
+
|
34
|
+
@set_module_as('brainstate.compile')
def cond(pred, true_fun: Callable, false_fun: Callable, *operands):
    """
    Conditionally apply ``true_fun`` or ``false_fun``.

    For correctly typed arguments, ``cond()`` behaves like this plain Python
    function, where ``pred`` must be a scalar::

      def cond(pred, true_fun, false_fun, *operands):
        if pred:
          return true_fun(*operands)
        else:
          return false_fun(*operands)

    Unlike :func:`jax.lax.select`, ``cond`` expresses that only one of the two
    branches runs (up to compiler rewrites and optimizations). When batched with
    :func:`~jax.vmap` over a batch of predicates, ``cond`` is lowered to
    :func:`~jax.lax.select`.

    Both branches are traced up front to collect the states they read and
    write. Those state values are threaded through :func:`jax.lax.cond`, and
    written back afterwards via ``write_back_state_values``.

    Args:
      pred: Boolean scalar selecting the branch. Integer and floating scalars
        are accepted and interpreted as ``pred != 0``.
      true_fun: Function (A -> B), applied when ``pred`` is True.
      false_fun: Function (A -> B), applied when ``pred`` is False.
      operands: Operands (A) passed to the chosen branch. They may be scalars,
        arrays, or any pytree (nested Python tuple/list/dict) of them.

    Returns:
      Value (B) of ``true_fun(*operands)`` or ``false_fun(*operands)``,
      depending on ``pred``. It may be a scalar, array, or any pytree
      (nested Python tuple/list/dict) of them.

    Raises:
      TypeError: If a branch is not callable, or ``pred`` is None, is not a
        scalar, or has a non-boolean/non-numeric dtype.
    """
    # ---- branch validation -------------------------------------------------
    branches_callable = callable(true_fun) and callable(false_fun)
    if not branches_callable:
        raise TypeError("true_fun and false_fun arguments should be callable.")

    # ---- predicate shape validation ---------------------------------------
    if pred is None:
        raise TypeError("cond predicate is None")
    pred_is_sequence = isinstance(pred, Sequence)
    if pred_is_sequence or np.ndim(pred) != 0:
        detail = f"type {type(pred)}" if pred_is_sequence else f"shape {np.shape(pred)}."
        raise TypeError(f"Pred must be a scalar, got {pred} of " + detail)

    # ---- predicate dtype normalisation (numbers become `pred != 0`) --------
    try:
        kind_dtype = jax.dtypes.result_type(pred)
    except TypeError as err:
        raise TypeError("Pred type must be either boolean or number, got {}.".format(pred)) from err
    if kind_dtype.kind != 'b':
        if kind_dtype.kind not in 'iuf':
            raise TypeError("Pred type must be either boolean or number, got {}.".format(kind_dtype))
        pred = pred != 0

    # ---- eager path: JIT disabled and the predicate is concrete ------------
    if jax.config.jax_disable_jit and isinstance(jax.core.get_aval(pred), jax.core.ConcreteArray):
        chosen = true_fun if pred else false_fun
        return chosen(*operands)

    # ---- trace both branches to discover the states they touch -------------
    with jax.ensure_compile_time_eval():
        traced_true = StatefulFunction(true_fun).make_jaxpr(*operands)
        traced_false = StatefulFunction(false_fun).make_jaxpr(*operands)

    trace = traced_true.get_state_trace() + traced_false.get_state_trace()
    read_vals = trace.get_read_state_values(True)
    write_vals = trace.get_write_state_values(True)

    # wrap each branch so that it receives and returns the written state values
    on_true = wrap_single_fun_in_multi_branches(traced_true, trace, read_vals, True)
    on_false = wrap_single_fun_in_multi_branches(traced_false, trace, read_vals, True)

    new_write_vals, result = jax.lax.cond(pred, on_true, on_false, write_vals, *operands)

    # commit the written state values and restore the read-only ones
    write_back_state_values(trace, read_vals, new_write_vals)
    return result
|
116
|
+
|
117
|
+
|
118
|
+
@set_module_as('brainstate.compile')
def switch(index, branches: Sequence[Callable], *operands):
    """
    Apply exactly one of ``branches`` given by ``index``.

    If ``index`` is out of bounds, it is clamped to within bounds.

    Has the semantics of the following Python::

      def switch(index, branches, *operands):
        index = clamp(0, index, len(branches) - 1)
        return branches[index](*operands)

    Internally this wraps XLA's `Conditional
    <https://www.tensorflow.org/xla/operation_semantics#conditional>`_
    operator. However, when transformed with :func:`~jax.vmap` to operate over a
    batch of indices, ``switch`` is converted to :func:`~jax.lax.select`.

    Every branch is traced up front so that the states read or written by any
    branch can be threaded through :func:`jax.lax.switch`. They are written back
    after the call via ``write_back_state_values``.

    Args:
      index: Integer scalar type, indicating which branch function to apply.
      branches: Sequence (or any iterable) of functions (A -> B) to be applied
        based on ``index``.
      operands: Operands (A) input to whichever branch is applied.

    Returns:
      Value (B) of ``branch(*operands)`` for the branch that was selected based
      on ``index``.

    Raises:
      TypeError: If a branch is not callable, or ``index`` is not an integer
        scalar.
      ValueError: If ``branches`` is empty.
    """
    # Materialize once: validating a generator with ``all`` would otherwise
    # exhaust it before the branches are actually used below.
    branches = tuple(branches)
    if not all(callable(branch) for branch in branches):
        raise TypeError("branches argument should be a sequence of callables.")

    # check index
    if len(np.shape(index)) != 0:
        raise TypeError(f"Branch index must be scalar, got {index} of shape {np.shape(index)}.")
    try:
        index_dtype = jax.dtypes.result_type(index)
    except TypeError as err:
        msg = f"Index type must be an integer, got {index}."
        raise TypeError(msg) from err
    if index_dtype.kind not in 'iu':
        raise TypeError(f"Index type must be an integer, got {index} as {index_dtype}")

    # trivial branch counts
    if len(branches) == 0:
        raise ValueError("Empty branch sequence")
    elif len(branches) == 1:
        return branches[0](*operands)

    # clamp the index into [0, len(branches) - 1]
    index = jax.lax.convert_element_type(index, np.int32)
    lo = np.array(0, np.int32)
    hi = np.array(len(branches) - 1, np.int32)
    index = jax.lax.clamp(lo, index, hi)

    # not jit: with JIT disabled and a concrete index, call the branch eagerly.
    # Uses ``jax.core.get_aval`` (as ``cond`` does); ``jax.core.core`` is not a
    # public attribute.
    if jax.config.jax_disable_jit and isinstance(jax.core.get_aval(index), jax.core.ConcreteArray):
        return branches[int(index)](*operands)

    # trace every branch to discover the states it touches
    with jax.ensure_compile_time_eval():
        wrapped_branches = [StatefulFunction(branch) for branch in branches]
        for wrapped_branch in wrapped_branches:
            wrapped_branch.make_jaxpr(*operands)

    # merge the state traces of all branches
    state_trace = wrapped_branches[0].get_state_trace() + wrapped_branches[1].get_state_trace()
    state_trace.merge(*[wrapped_branch.get_state_trace() for wrapped_branch in wrapped_branches[2:]])
    read_state_vals = state_trace.get_read_state_values(True)
    write_state_vals = state_trace.get_write_state_values(True)

    # wrap each branch so that it receives and returns the written state values
    branches = [
        wrap_single_fun_in_multi_branches(wrapped_branch, state_trace, read_state_vals, True)
        for wrapped_branch in wrapped_branches
    ]

    # switch
    write_state_vals, out = jax.lax.switch(index, branches, write_state_vals, *operands)

    # write back state values or restore them
    write_back_state_values(state_trace, read_state_vals, write_state_vals)
    return out
|
199
|
+
|
200
|
+
|
201
|
+
@set_module_as('brainstate.compile')
def ifelse(conditions, branches, *operands, check_cond: bool = True):
    """
    ``If-else`` control flow that looks like native Pythonic programming.

    The branch at the position of the True entry of ``conditions`` is executed
    with ``*operands``. If ``branches`` has one more element than
    ``conditions``, the last branch acts as the ``else`` branch. It is executed
    when no condition is True.

    Examples
    --------

    >>> import brainstate as bst
    >>> import jax.numpy as jnp
    >>> def f(a):
    ...     return bst.compile.ifelse(
    ...         conditions=[a > 10,
    ...                     jnp.logical_and(a > 5, a <= 10),
    ...                     jnp.logical_and(a > 2, a <= 5),
    ...                     jnp.logical_and(a > 0, a <= 2)],
    ...         branches=[lambda: 1, lambda: 2, lambda: 3, lambda: 4, lambda: 5])
    >>> int(f(1))
    4
    >>> int(f(0))
    5

    Parameters
    ----------
    conditions: sequence of bool, Array
      The boolean conditions, one per branch. The ``else`` branch, if present,
      has no condition.
    branches: sequence of callable
      The branches, each called as ``branch(*operands)``. It must satisfy
      ``len(branches) == len(conditions)`` or
      ``len(branches) == len(conditions) + 1``. In the latter case the last
      branch is the ``else`` branch.
    *operands: optional, Any
      The operands passed to whichever branch is executed.
    check_cond: bool
      Whether to raise a runtime error when the conditions are not mutually
      exclusive. Without an ``else`` branch exactly one condition must be True;
      with one, at most one may be True. Default is True. When disabled and
      several conditions are True, the first True condition wins. When none is
      True, the last branch is executed.

    Returns
    -------
    res: Any
      The output of the executed branch.

    Raises
    ------
    TypeError
      If a branch is not callable.
    ValueError
      If ``branches`` is empty, or the numbers of conditions and branches do
      not match as described above.
    """
    # Materialize once so that a generator is not exhausted by validation.
    branches = tuple(branches)
    if not all(callable(branch) for branch in branches):
        raise TypeError("branches argument should be a sequence of callables.")

    # trivial branch counts
    if len(branches) == 0:
        raise ValueError("Empty branch sequence")
    elif len(branches) == 1:
        return branches[0](*operands)

    num_branches = len(branches)
    num_conditions = len(conditions)
    # one fewer condition than branches => the last branch is the `else` branch
    has_else = num_conditions + 1 == num_branches
    if num_conditions != num_branches and not has_else:
        raise ValueError(
            "The number of conditions should be equal to the number of branches "
            "(or one less, in which case the last branch is the 'else' branch)."
        )

    # format index
    conditions = jnp.asarray(conditions, np.int32)
    if check_cond:
        if has_else:
            jit_error_if(jnp.sum(conditions) > 1,
                         "At most one condition can be True. But got {err_arg}.",
                         err_arg=conditions)
        else:
            jit_error_if(jnp.sum(conditions) != 1,
                         "Only one condition can be True. But got {err_arg}.",
                         err_arg=conditions)
    # Index of the first True condition. When none is True, fall back to the
    # last branch, which is the `else` branch when one is given.
    index = jnp.where(conditions, size=1, fill_value=num_branches - 1)[0][0]
    return switch(index, branches, *operands)
|
@@ -0,0 +1,221 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
from __future__ import annotations
|
16
|
+
|
17
|
+
import unittest
|
18
|
+
|
19
|
+
import jax
|
20
|
+
import jax.numpy as jnp
|
21
|
+
|
22
|
+
import brainstate as bst
|
23
|
+
|
24
|
+
|
25
|
+
class TestCond(unittest.TestCase):
    def test1(self):
        bst.random.seed(1)
        out_true = bst.compile.cond(True, lambda: bst.random.random(10), lambda: bst.random.random(10))
        out_false = bst.compile.cond(False, lambda: bst.random.random(10), lambda: bst.random.random(10))
        self.assertEqual(out_true.shape, (10,))
        self.assertEqual(out_false.shape, (10,))
        # the random key consumed inside the first cond must be written back,
        # so the second draw uses a different key
        self.assertFalse(jnp.allclose(out_true, out_false))

    def test2(self):
        st1 = bst.State(bst.random.rand(10))
        st2 = bst.State(bst.random.rand(2))
        st3 = bst.State(bst.random.rand(5))
        st4 = bst.State(bst.random.rand(2, 10))

        def true_fun(x):
            st1.value = st2.value @ st4.value + x

        def false_fun(x):
            st3.value = (st3.value + 1.) * x

        # True branch: st1 is updated, st3 keeps its value.
        st1_expected = st2.value @ st4.value + 2.
        st3_before = st3.value
        bst.compile.cond(True, true_fun, false_fun, 2.)
        for st in (st1, st2, st3, st4):
            self.assertNotIsInstance(st.value, jax.core.Tracer)
        self.assertTrue(jnp.allclose(st1.value, st1_expected))
        self.assertTrue(jnp.allclose(st3.value, st3_before))

        # False branch: st3 is updated, st1 keeps its value.
        st1_before = st1.value
        st3_expected = (st3.value + 1.) * 2.
        bst.compile.cond(False, true_fun, false_fun, 2.)
        for st in (st1, st2, st3, st4):
            self.assertNotIsInstance(st.value, jax.core.Tracer)
        self.assertTrue(jnp.allclose(st1.value, st1_before))
        self.assertTrue(jnp.allclose(st3.value, st3_expected))
|
48
|
+
|
49
|
+
|
50
|
+
class TestSwitch(unittest.TestCase):
    """Tests for ``bst.compile.switch``, mirroring JAX's ``lax.switch`` tests."""

    def testSwitch(self):
        def double_twice(x):
            y = jax.lax.mul(2, x)
            return y, jax.lax.mul(2, y)

        branches = [lambda x: (x, x),
                    double_twice,
                    lambda x: (x, -x)]

        def reference(x):
            # plain-Python equivalent of the clamped switch
            if x <= 0:
                which = 0
            elif x == 1:
                which = 1
            else:
                which = 2
            return branches[which](x)

        def compiled(x):
            return bst.compile.switch(x, branches, x)

        probes = (-1, 0, 1, 2, 3)
        for v in probes:
            self.assertEqual(reference(v), compiled(v))

        jitted = jax.jit(compiled)
        for v in probes:
            self.assertEqual(reference(v), jitted(v))

    def testSwitchMultiOperands(self):
        branches = [jax.lax.add, jax.lax.mul]

        def reference(x):
            return branches[1 if x > 0 else 0](x, x)

        def compiled(x):
            return bst.compile.switch(x, branches, x, x)

        probes = (-1, 0, 1, 2)
        for v in probes:
            self.assertEqual(reference(v), compiled(v))
        jitted = jax.jit(compiled)
        for v in probes:
            self.assertEqual(reference(v), jitted(v))

    def testSwitchResidualsMerge(self):
        def cond_eqns(fun):
            closed = jax.make_jaxpr(jax.grad(fun))(0., 0)
            return [eqn for eqn in closed.jaxpr.eqns if eqn.primitive.name == 'cond']

        def unique_len(cond_eqn, attr):
            # every branch of a cond must agree on the number of in/out vars
            sizes = {len(getattr(br.jaxpr, attr)) for br in cond_eqn.params['branches']}
            assert len(sizes) == 1
            return sizes.pop()

        branches1 = [lambda x: jnp.sin(x),
                     lambda x: jnp.cos(x)]  # branch residuals overlap, should be reused
        branches2 = branches1 + [lambda x: jnp.sinh(x)]  # another overlapping residual, expect reuse
        branches3 = branches2 + [lambda x: jnp.sin(x) + jnp.cos(x)]  # requires one more residual slot

        def make_fun(brs):
            return lambda x, i: bst.compile.switch(i + 1, brs, x)

        fwd1, bwd1 = cond_eqns(make_fun(branches1))
        fwd2, bwd2 = cond_eqns(make_fun(branches2))
        fwd3, bwd3 = cond_eqns(make_fun(branches3))

        n_out1, n_out2, n_out3 = (unique_len(e, 'outvars') for e in (fwd1, fwd2, fwd3))
        assert n_out1 == n_out2
        assert n_out3 == n_out2 + 1

        n_in1, n_in2, n_in3 = (unique_len(e, 'invars') for e in (bwd1, bwd2, bwd3))
        assert n_in1 == n_in2
        assert n_in3 == n_in2 + 1

    def testOneBranchSwitch(self):
        branch = lambda x: -x
        f = lambda i, x: bst.compile.switch(i, [branch], x)
        x = 7.
        for fn in (f, jax.jit(f), jax.jit(f, static_argnums=0)):
            for i in (-1, 0, 1):
                self.assertEqual(fn(i, x), branch(x))
|
165
|
+
|
166
|
+
|
167
|
+
class TestIfElse(unittest.TestCase):
    def test1(self):
        def f(a):
            return bst.compile.ifelse(conditions=[a < 0,
                                                  a >= 0 and a < 2,
                                                  a >= 2 and a < 5,
                                                  a >= 5 and a < 10,
                                                  a >= 10],
                                      branches=[lambda: 1,
                                                lambda: 2,
                                                lambda: 3,
                                                lambda: 4,
                                                lambda: 5])

        self.assertTrue(f(3) == 3)
        self.assertTrue(f(1) == 2)
        self.assertTrue(f(-1) == 1)
        self.assertTrue(f(7) == 4)
        self.assertTrue(f(10) == 5)

    def test_vmap(self):
        def pick(a):
            return bst.compile.ifelse([a > 10,
                                       jnp.logical_and(a <= 10, a > 5),
                                       jnp.logical_and(a <= 5, a > 2),
                                       jnp.logical_and(a <= 2, a > 0),
                                       a <= 0],
                                      [lambda _: 1,
                                       lambda _: 2,
                                       lambda _: 3,
                                       lambda _: 4,
                                       lambda _: 5, ],
                                      a)

        def f(operands):
            return jax.vmap(pick)(operands)

        r = f(bst.random.randint(-20, 20, 200))
        self.assertTrue(r.size == 200)

        # one probe per branch, checked against the expected branch outputs
        r = f(jnp.array([-3, 1, 3, 7, 12]))
        self.assertTrue(jnp.array_equal(r, jnp.array([5, 4, 3, 2, 1])))

    def test_grad1(self):
        def F2(x):
            return bst.compile.ifelse((x >= 10, x < 10),
                                      [lambda x: x, lambda x: x ** 2, ],
                                      x)

        self.assertTrue(jax.grad(F2)(9.0) == 18.)
        self.assertTrue(jax.grad(F2)(11.0) == 1.)

    def test_grad2(self):
        def F3(x):
            return bst.compile.ifelse((x >= 10, jnp.logical_and(x >= 0, x < 10), x < 0),
                                      [lambda x: x,
                                       lambda x: x ** 2,
                                       lambda x: x ** 4, ],
                                      x)

        self.assertTrue(jax.grad(F3)(9.0) == 18.)
        self.assertTrue(jax.grad(F3)(11.0) == 1.)
        # d/dx x**4 at x = -1 is 4 * (-1)**3 = -4
        self.assertTrue(jax.grad(F3)(-1.0) == -4.)
|
@@ -0,0 +1,94 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from __future__ import annotations
|
17
|
+
|
18
|
+
import functools
|
19
|
+
from functools import partial
|
20
|
+
from typing import Callable, Union
|
21
|
+
|
22
|
+
import jax
|
23
|
+
|
24
|
+
from brainstate._utils import set_module_as
|
25
|
+
from ._unvmap import unvmap
|
26
|
+
|
27
|
+
__all__ = [
|
28
|
+
'jit_error_if',
|
29
|
+
]
|
30
|
+
|
31
|
+
|
32
|
+
def _err_jit_true_branch(err_fun, args, kwargs):
|
33
|
+
jax.debug.callback(err_fun, *args, **kwargs)
|
34
|
+
|
35
|
+
|
36
|
+
def _err_jit_false_branch(args, kwargs):
|
37
|
+
pass
|
38
|
+
|
39
|
+
|
40
|
+
def _error_msg(msg, *arg, **kwargs):
|
41
|
+
if len(arg):
|
42
|
+
msg = msg % arg
|
43
|
+
if len(kwargs):
|
44
|
+
msg = msg.format(**kwargs)
|
45
|
+
raise ValueError(msg)
|
46
|
+
|
47
|
+
|
48
|
+
@set_module_as('brainstate.compile')
def jit_error_if(
    pred,
    error: Union[Callable, str],
    *err_args,
    **err_kwargs,
):
    """
    Raise an error from inside (possibly jitted / vmapped) code when ``pred`` holds.

    Examples
    --------

    ``error`` can be a function that receives the runtime values passed from
    the traced computation and raises an error itself.

    >>> def error(x):
    >>>    raise ValueError(f'error {x}')
    >>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
    >>> jit_error_if(x.sum() < 5., error, x)

    Or ``error`` can be a message string, formatted with the given values.

    >>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
    >>> jit_error_if(x.sum() < 5., "Error: the sum is less than 5. Got {s}", s=x.sum())


    Parameters
    ----------
    pred: bool, Array
      The boolean predicate. Under ``vmap`` it is reduced with ``unvmap(..., op='any')``,
      so the error fires if any batch element is True.
    error: callable, str
      Either a function that raises the error, or a message string. A string
      is formatted with ``err_args`` / ``err_kwargs`` and raised as ``ValueError``.
    err_args:
      Positional values passed to ``error`` (or used to format the message).
    err_kwargs:
      Keyword values passed to ``error`` (or used to format the message).
    """
    raise_fn = partial(_error_msg, error) if isinstance(error, str) else error
    strip_batch = partial(unvmap, op='none')
    jax.lax.cond(
        unvmap(pred, op='any'),
        partial(_err_jit_true_branch, raise_fn),
        _err_jit_false_branch,
        jax.tree.map(strip_batch, err_args),
        jax.tree.map(strip_batch, err_kwargs),
    )
|
@@ -0,0 +1,54 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from __future__ import annotations
|
17
|
+
|
18
|
+
import unittest
|
19
|
+
|
20
|
+
import jax
|
21
|
+
import jax.numpy as jnp
|
22
|
+
import jaxlib.xla_extension
|
23
|
+
|
24
|
+
import brainstate as bst
|
25
|
+
|
26
|
+
|
27
|
+
class TestJitError(unittest.TestCase):
    def test1(self):
        # a False predicate must be a no-op
        bst.compile.jit_error_if(False, 'error')
        with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
            bst.compile.jit_error_if(True, 'error')

        def err_f(x):
            raise ValueError(f'error: {x}')

        bst.compile.jit_error_if(False, err_f, 1.)
        with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
            bst.compile.jit_error_if(True, err_f, 1.)

    def test_vmap(self):
        def f(x):
            bst.compile.jit_error_if(x, 'error: {x}', x=x)

        jax.vmap(f)(jnp.array([False, False, False]))
        with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
            jax.vmap(f)(jnp.array([True, False, False]))

    def test_vmap_vmap(self):
        def f(x):
            bst.compile.jit_error_if(x, 'error: {x}', x=x)

        jax.vmap(jax.vmap(f))(jnp.array([[False, False, False],
                                         [False, False, False]]))
        with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
            jax.vmap(jax.vmap(f))(jnp.array([[False, False, False],
                                             [True, False, False]]))
|