brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +95 -29
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.9.dist-info/RECORD +0 -130
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,316 @@
|
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from collections.abc import Callable, Sequence
|
17
|
+
|
18
|
+
import jax
|
19
|
+
import jax.numpy as jnp
|
20
|
+
import numpy as np
|
21
|
+
|
22
|
+
from brainstate._compatible_import import to_concrete_aval, Tracer
|
23
|
+
from brainstate._utils import set_module_as
|
24
|
+
from ._error_if import jit_error_if
|
25
|
+
from ._make_jaxpr import StatefulFunction
|
26
|
+
from ._util import wrap_single_fun_in_multi_branches
|
27
|
+
|
28
|
+
# Public control-flow API exported by this module.
__all__ = [
    'cond',
    'switch',
    'ifelse',
]
|
31
|
+
|
32
|
+
|
33
|
+
@set_module_as('brainstate.transform')
def cond(pred, true_fun: Callable, false_fun: Callable, *operands):
    """
    Conditionally apply ``true_fun`` or ``false_fun``.

    Parameters
    ----------
    pred : bool or array-like
        Boolean scalar selecting which branch to execute. Integer, unsigned and
        floating-point inputs are treated as ``True`` when non-zero.
    true_fun : Callable
        Function that receives ``*operands`` when ``pred`` is ``True``.
    false_fun : Callable
        Function that receives ``*operands`` when ``pred`` is ``False``.
    *operands : Any
        Operands forwarded to either branch. May be any pytree of arrays,
        scalars, or nested containers thereof.

    Returns
    -------
    Any
        Value returned by the selected branch with the same pytree structure
        as produced by ``true_fun`` or ``false_fun``.

    Raises
    ------
    TypeError
        If either branch is not callable, ``pred`` is ``None``, ``pred`` is a
        Python sequence or a non-scalar array, or ``pred`` has a dtype that is
        neither boolean nor numeric.

    Notes
    -----
    Provided the arguments are correctly typed, :func:`cond` has semantics
    that match the following Python implementation, where ``pred`` must be a
    scalar:

    .. code-block:: python

        >>> def cond(pred, true_fun, false_fun, *operands):
        ...     if pred:
        ...         return true_fun(*operands)
        ...     return false_fun(*operands)

    In contrast with :func:`jax.lax.select`, using :func:`cond` indicates that only
    one branch runs (subject to compiler rewrites and optimizations). When
    transformed with :func:`~jax.vmap` over a batch of predicates, :func:`cond` is
    converted to :func:`~jax.lax.select`.

    When ``jax_disable_jit`` is enabled and ``pred`` is concrete (not a tracer),
    the selected branch is simply called in Python. Otherwise both branches are
    traced with :class:`StatefulFunction`, and the states they read or write are
    threaded through :func:`jax.lax.cond` and assigned back afterwards.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> def branch_true(x):
        ...     return x + 1
        >>>
        >>> def branch_false(x):
        ...     return x - 1
        >>>
        >>> brainstate.transform.cond(True, branch_true, branch_false, 3)
    """
    if not (callable(true_fun) and callable(false_fun)):
        raise TypeError("true_fun and false_fun arguments should be callable.")

    if pred is None:
        raise TypeError("cond predicate is None")
    if isinstance(pred, Sequence) or np.ndim(pred) != 0:
        raise TypeError(f"Pred must be a scalar, got {pred} of " +
                        (f"type {type(pred)}" if isinstance(pred, Sequence) else f"shape {np.shape(pred)}."))

    # check pred: only boolean or numeric dtypes are accepted
    try:
        pred_dtype = jax.dtypes.result_type(pred)
    except TypeError as err:
        raise TypeError("Pred type must be either boolean or number, got {}.".format(pred)) from err
    if pred_dtype.kind != 'b':
        if pred_dtype.kind in 'iuf':
            # numeric predicate: non-zero means True
            pred = pred != 0
        else:
            raise TypeError("Pred type must be either boolean or number, got {}.".format(pred_dtype))

    # not jit: with a concrete predicate just run the selected branch eagerly
    if jax.config.jax_disable_jit and not isinstance(to_concrete_aval(pred), Tracer):
        if pred:
            return true_fun(*operands)
        else:
            return false_fun(*operands)

    # evaluate jaxpr of both branches (labels are used to identify the traced functions)
    stateful_true = StatefulFunction(true_fun, name='cond:true').make_jaxpr(*operands)
    stateful_false = StatefulFunction(false_fun, name='cond:false').make_jaxpr(*operands)

    # state trace and state values: union of the states touched by either branch
    state_trace = (stateful_true.get_state_trace(*operands) +
                   stateful_false.get_state_trace(*operands))
    read_state_vals = state_trace.get_read_state_values(True)
    write_state_vals = state_trace.get_write_state_values(True)

    # wrap the functions so both branches share the same state signature
    true_fun = wrap_single_fun_in_multi_branches(
        stateful_true, state_trace, read_state_vals, True, stateful_true.get_arg_cache_key(*operands)
    )
    false_fun = wrap_single_fun_in_multi_branches(
        stateful_false, state_trace, read_state_vals, True, stateful_false.get_arg_cache_key(*operands)
    )

    # cond
    write_state_vals, out = jax.lax.cond(pred, true_fun, false_fun, write_state_vals, *operands)

    # assign the written state values and restore the read state values
    state_trace.assign_state_vals_v2(read_state_vals, write_state_vals)
    return out
|
140
|
+
|
141
|
+
|
142
|
+
@set_module_as('brainstate.transform')
def switch(index, branches: Sequence[Callable], *operands):
    """
    Apply exactly one branch from ``branches`` based on ``index``.

    Parameters
    ----------
    index : int or array-like
        Scalar integer specifying which branch to execute.
    branches : Sequence[Callable]
        Sequence of callables; each receives ``*operands``.
    *operands : Any
        Operands forwarded to the selected branch. May be any pytree of arrays,
        scalars, or nested containers thereof.

    Returns
    -------
    Any
        Value returned by the selected branch with the same pytree structure
        as the selected callable.

    Raises
    ------
    TypeError
        If any branch is not callable, or ``index`` is not a scalar of integer dtype.
    ValueError
        If ``branches`` is empty.

    Notes
    -----
    If ``index`` is out of bounds, it is clamped to ``[0, len(branches) - 1]``.
    Conceptually, :func:`switch` behaves like:

    .. code-block:: python

        >>> def switch(index, branches, *operands):
        ...     safe_index = clamp(0, index, len(branches) - 1)
        ...     return branches[safe_index](*operands)

    Internally this wraps XLA's `Conditional <https://www.tensorflow.org/xla/operation_semantics#conditional>`_
    operator. When transformed with :func:`~jax.vmap` over a batch of indices,
    :func:`switch` is converted to :func:`~jax.lax.select`.

    A single branch is called directly without tracing. When ``jax_disable_jit``
    is enabled and ``index`` is concrete (not a tracer), the selected branch is
    likewise called eagerly in Python.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> branches = (
        ...     lambda x: x - 1,
        ...     lambda x: x,
        ...     lambda x: x + 1,
        ... )
        >>>
        >>> brainstate.transform.switch(2, branches, 3)
    """
    # Materialize first: a one-shot iterable (e.g. a generator) would otherwise be
    # exhausted by the callable check below, leaving an empty tuple behind.
    branches = tuple(branches)

    # check branches
    if not all(callable(branch) for branch in branches):
        raise TypeError("branches argument should be a sequence of callables.")

    # check index
    if len(np.shape(index)) != 0:
        raise TypeError(f"Branch index must be scalar, got {index} of shape {np.shape(index)}.")
    try:
        index_dtype = jax.dtypes.result_type(index)
    except TypeError as err:
        msg = f"Index type must be an integer, got {index}."
        raise TypeError(msg) from err
    if index_dtype.kind not in 'iu':
        raise TypeError(f"Index type must be an integer, got {index} as {index_dtype}")

    # format branches
    if len(branches) == 0:
        raise ValueError("Empty branch sequence")
    elif len(branches) == 1:
        return branches[0](*operands)

    # format index: cast to int32 and clamp into the valid branch range
    index = jax.lax.convert_element_type(index, np.int32)
    lo = np.array(0, np.int32)
    hi = np.array(len(branches) - 1, np.int32)
    index = jax.lax.clamp(lo, index, hi)

    # not jit: same concreteness test as ``cond``. ``jax.core.ConcreteArray`` is
    # no longer available in recent JAX releases, so it must not be used here.
    if jax.config.jax_disable_jit and not isinstance(to_concrete_aval(index), Tracer):
        return branches[int(index)](*operands)

    # evaluate jaxpr of every branch
    wrapped_branches = [StatefulFunction(branch, name='switch').make_jaxpr(*operands) for branch in branches]

    # wrap the functions: merge the state traces of all branches into one
    state_trace = (wrapped_branches[0].get_state_trace(*operands) +
                   wrapped_branches[1].get_state_trace(*operands))
    state_trace.merge(*[wrapped_branch.get_state_trace(*operands)
                        for wrapped_branch in wrapped_branches[2:]])
    read_state_vals = state_trace.get_read_state_values(True)
    write_state_vals = state_trace.get_write_state_values(True)
    branches = [
        wrap_single_fun_in_multi_branches(
            wrapped_branch, state_trace, read_state_vals, True, wrapped_branch.get_arg_cache_key(*operands)
        )
        for wrapped_branch in wrapped_branches
    ]

    # switch
    write_state_vals, out = jax.lax.switch(index, branches, write_state_vals, *operands)

    # write back state values or restore them
    state_trace.assign_state_vals_v2(read_state_vals, write_state_vals)
    return out
|
247
|
+
|
248
|
+
|
249
|
+
@set_module_as('brainstate.transform')
def ifelse(conditions, branches, *operands, check_cond: bool = True):
    """
    Represent multi-way ``if``/``elif``/``else`` control flow.

    Parameters
    ----------
    conditions : Sequence[bool] or Array
        Boolean predicates, one per branch. When ``check_cond`` is ``True``,
        exactly one entry must evaluate to ``True``.
    branches : Sequence[Callable]
        Sequence of branch callables evaluated lazily. Must have the same length as
        ``conditions``; each branch receives ``*operands`` when selected.
    *operands : Any
        Operands forwarded to the selected branch as positional arguments.
    check_cond : bool, default=True
        Whether to verify (via :func:`jit_error_if`) that exactly one condition
        evaluates to ``True``.

    Returns
    -------
    Any
        Value produced by the branch corresponding to the active condition.

    Raises
    ------
    TypeError
        If any branch is not callable.
    ValueError
        If ``branches`` is empty, or its length differs from ``conditions``.

    Notes
    -----
    The branch of the first ``True`` condition is executed; if no condition is
    ``True`` the last branch is executed. With ``check_cond=True`` an error is
    raised unless exactly one condition is ``True``.

    To encode a default branch, make the final condition ``True`` and pass
    ``check_cond=False`` (otherwise the default would clash with any other
    ``True`` condition).

    If only one branch is given it is called directly and ``conditions`` is not
    inspected.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>>
        >>> def describe(a):
        ...     return brainstate.transform.ifelse(
        ...         conditions=[a > 5, jnp.logical_and(a > 0, a <= 5), a <= 0],
        ...         branches=[lambda: 2, lambda: 1, lambda: 0],
        ...     )
        >>>
        >>> describe(7)   # selects the first branch
        >>> describe(-1)  # selects the last branch
        >>>
        >>> def with_default(a):
        ...     return brainstate.transform.ifelse(
        ...         conditions=[a > 0, a < 0, True],
        ...         branches=[lambda: 1, lambda: -1, lambda: 0],
        ...         check_cond=False,
        ...     )
    """
    # Materialize once so a one-shot iterable is not exhausted by the check below.
    branches = tuple(branches)

    # check branches
    if not all(callable(branch) for branch in branches):
        raise TypeError("branches argument should be a sequence of callables.")

    # format branches
    if len(branches) == 0:
        raise ValueError("Empty branch sequence")
    elif len(branches) == 1:
        return branches[0](*operands)
    if len(conditions) != len(branches):
        raise ValueError("The number of conditions should be equal to the number of branches.")

    # format index: booleans -> int32 so they can be summed for the check below
    conditions = jnp.asarray(conditions, np.int32)
    if check_cond:
        # Named placeholder matching the ``err_arg`` keyword, following the
        # ``"... {s}", s=...`` convention documented for ``jit_error_if``; a bare
        # ``{}`` cannot be filled by a keyword argument in ``str.format``.
        jit_error_if(jnp.sum(conditions) != 1,
                     "Only one condition can be True. But got {err_arg}.",
                     err_arg=conditions)
    # Index of the first True condition; falls back to the last branch if none is True.
    index = jnp.where(conditions, size=1, fill_value=len(conditions) - 1)[0][0]
    return switch(index, branches, *operands)
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2024
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -172,10 +172,10 @@ class TestIfElse(unittest.TestCase):
|
|
172
172
|
a >= 5 and a < 10,
|
173
173
|
a >= 10],
|
174
174
|
branches=[lambda: 1,
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
175
|
+
lambda: 2,
|
176
|
+
lambda: 3,
|
177
|
+
lambda: 4,
|
178
|
+
lambda: 5])
|
179
179
|
|
180
180
|
self.assertTrue(f(3) == 3)
|
181
181
|
self.assertTrue(f(1) == 2)
|
@@ -189,10 +189,10 @@ class TestIfElse(unittest.TestCase):
|
|
189
189
|
jnp.logical_and(a <= 2, a > 0),
|
190
190
|
a <= 0],
|
191
191
|
[lambda _: 1,
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
192
|
+
lambda _: 2,
|
193
|
+
lambda _: 3,
|
194
|
+
lambda _: 4,
|
195
|
+
lambda _: 5, ],
|
196
196
|
a)
|
197
197
|
return jax.vmap(f)(operands)
|
198
198
|
|
@@ -212,8 +212,8 @@ class TestIfElse(unittest.TestCase):
|
|
212
212
|
def F3(x):
|
213
213
|
return brainstate.compile.ifelse((x >= 10, jnp.logical_and(x >= 0, x < 10), x < 0),
|
214
214
|
[lambda x: x,
|
215
|
-
|
216
|
-
|
215
|
+
lambda x: x ** 2,
|
216
|
+
lambda x: x ** 4, ],
|
217
217
|
x)
|
218
218
|
|
219
219
|
self.assertTrue(jax.grad(F3)(9.0) == 18.)
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2024
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -43,7 +43,7 @@ def _error_msg(msg, *arg, **kwargs):
|
|
43
43
|
raise ValueError(msg)
|
44
44
|
|
45
45
|
|
46
|
-
@set_module_as('brainstate.
|
46
|
+
@set_module_as('brainstate.transform')
|
47
47
|
def jit_error_if(
|
48
48
|
pred,
|
49
49
|
error: Union[Callable, str],
|
@@ -53,32 +53,34 @@ def jit_error_if(
|
|
53
53
|
"""
|
54
54
|
Check errors in a jit function.
|
55
55
|
|
56
|
+
Parameters
|
57
|
+
----------
|
58
|
+
pred : bool or Array
|
59
|
+
The boolean prediction.
|
60
|
+
error : callable or str
|
61
|
+
The error function, which raise errors, or a string indicating the error message.
|
62
|
+
*err_args
|
63
|
+
The arguments which passed into the error function.
|
64
|
+
**err_kwargs
|
65
|
+
The keywords which passed into the error function.
|
66
|
+
|
56
67
|
Examples
|
57
68
|
--------
|
58
|
-
|
59
69
|
It can give a function which receive arguments that passed from the JIT variables and raise errors.
|
60
70
|
|
61
|
-
|
62
|
-
>>> raise ValueError(f'error {x}')
|
63
|
-
>>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
|
64
|
-
>>> jit_error_if(x.sum() < 5., error, x)
|
71
|
+
.. code-block:: python
|
65
72
|
|
66
|
-
|
73
|
+
>>> def error(x):
|
74
|
+
... raise ValueError(f'error {x}')
|
75
|
+
>>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
|
76
|
+
>>> jit_error_if(x.sum() < 5., error, x)
|
67
77
|
|
68
|
-
|
69
|
-
>>> jit_error_if(x.sum() < 5., "Error: the sum is less than 5. Got {s}", s=x.sum())
|
78
|
+
Or, it can be a simple string message.
|
70
79
|
|
80
|
+
.. code-block:: python
|
71
81
|
|
72
|
-
|
73
|
-
|
74
|
-
pred: bool, Array
|
75
|
-
The boolean prediction.
|
76
|
-
error: callable, str
|
77
|
-
The error function, which raise errors, or a string indicating the error message.
|
78
|
-
err_args:
|
79
|
-
The arguments which passed into `err_f`.
|
80
|
-
err_kwargs:
|
81
|
-
The keywords which passed into `err_f`.
|
82
|
+
>>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
|
83
|
+
>>> jit_error_if(x.sum() < 5., "Error: the sum is less than 5. Got {s}", s=x.sum())
|
82
84
|
"""
|
83
85
|
if isinstance(error, str):
|
84
86
|
error = partial(_error_msg, error)
|
@@ -0,0 +1,145 @@
|
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import functools
|
17
|
+
from typing import Any, TypeVar, Callable, Sequence, Union
|
18
|
+
|
19
|
+
import jax
|
20
|
+
|
21
|
+
from brainstate import random
|
22
|
+
from brainstate._utils import set_module_as
|
23
|
+
from brainstate.graph import Node, flatten, unflatten
|
24
|
+
from ._random import restore_rngs
|
25
|
+
|
26
|
+
# Public API of this module.
__all__ = ['abstract_init']

# Return type of the function passed to ``abstract_init``.
A = TypeVar('A')
|
31
|
+
|
32
|
+
|
33
|
+
@set_module_as('brainstate.transform')
def abstract_init(
    fn: Callable[..., A],
    *args: Any,
    rngs: Union[random.RandomState, Sequence[random.RandomState]] = random.DEFAULT,
    **kwargs: Any,
) -> A:
    """
    Abstractly evaluate a Node-constructing ``fn`` without any FLOPs.

    ``fn`` is run under :func:`jax.eval_shape`, so no real computation happens.
    The Node it returns is flattened into its graph definition and states, and
    then rebuilt from their abstract (shape/dtype-only) counterparts. This is
    useful to inspect a model's structure and parameter shapes without
    allocating or initializing real arrays.

    Parameters
    ----------
    fn : callable
        Function that builds and returns a :class:`~brainstate.graph.Node`
        (e.g. a :class:`brainstate.nn.Module`).
    *args
        Positional argument tuple of arrays, scalars, or (nested) standard
        Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of
        those types. Since only the ``shape`` and ``dtype`` attributes are
        accessed, one can use :class:`jax.ShapeDtypeStruct` or another container
        that duck-types as ndarrays (note however that duck-typed objects cannot
        be namedtuples because those are treated as standard Python containers).
    rngs : RandomState or sequence of RandomState, default random.DEFAULT
        A :class:`RandomState` or a sequence of :class:`RandomState` objects
        whose state is restored (via ``restore_rngs``) around the call to ``fn``.
        If not provided, the default random number generator is used.
    **kwargs
        Keyword argument dict of arrays, scalars, or (nested) standard
        Python containers (pytrees) of those types. As in ``args``, array values
        need only be duck-typed to have ``shape`` and ``dtype`` attributes.

    Returns
    -------
    A
        The Node returned by ``fn``, reconstructed with ``unflatten`` from the
        abstractly evaluated graph definition and states; array leaves of the
        states are :class:`jax.ShapeDtypeStruct` objects rather than concrete
        arrays.

    Raises
    ------
    AssertionError
        If ``fn`` does not return a :class:`~brainstate.graph.Node`.

    Examples
    --------
    Basic usage with neural network initialization:

    .. code-block:: python

        >>> import brainstate
        >>> import jax
        >>> import jax.numpy as jnp
        >>>
        >>> class MLP(brainstate.nn.Module):
        ...     def __init__(self, n_in, n_mid, n_out):
        ...         super().__init__()
        ...         self.dense1 = brainstate.nn.Linear(n_in, n_mid)
        ...         self.dense2 = brainstate.nn.Linear(n_mid, n_out)
        >>>
        >>> # Get shape information without actual computation
        >>> model_shape = brainstate.transform.abstract_init(lambda: MLP(1, 2, 3))

    With function arguments:

    .. code-block:: python

        >>> def create_model(input_size, hidden_size, output_size):
        ...     return brainstate.nn.Sequential([
        ...         brainstate.nn.Linear(input_size, hidden_size),
        ...         brainstate.nn.ReLU(),
        ...         brainstate.nn.Linear(hidden_size, output_size)
        ...     ])
        >>>
        >>> # Abstract initialization with arguments
        >>> model_shape = brainstate.transform.abstract_init(
        ...     create_model, 784, 256, 10
        ... )

    Using custom random number generators:

    .. code-block:: python

        >>> import brainstate.random as random
        >>>
        >>> # Create custom RNG
        >>> rng = random.RandomState(42)
        >>>
        >>> def init_with_custom_weights():
        ...     return brainstate.nn.Linear(10, 5)
        >>>
        >>> model_shape = brainstate.transform.abstract_init(
        ...     init_with_custom_weights, rngs=rng
        ... )

    Building a model whose size depends on an abstract input:

    .. code-block:: python

        >>> def build_from_input(x):
        ...     # ``fn`` must return a Node, so build the layer rather than calling it
        ...     return brainstate.nn.Linear(x.shape[-1], 128)
        >>>
        >>> # Use ShapeDtypeStruct to represent input without actual data
        >>> input_shape = jax.ShapeDtypeStruct((32, 784), jnp.float32)
        >>> model_shape = brainstate.transform.abstract_init(build_from_input, input_shape)
    """

    @functools.wraps(fn)
    @restore_rngs(rngs=rngs)
    def _eval_shape_fn(*args_, **kwargs_):
        out = fn(*args_, **kwargs_)
        if not isinstance(out, Node):
            # Explicit raise instead of ``assert`` so the check survives ``python -O``;
            # AssertionError is kept so existing callers catching it keep working.
            raise AssertionError('The output of the function must be Node')
        # Split the Node into a static graph definition and its state pytree so
        # that ``jax.eval_shape`` can trace through it.
        graph_def, treefy_states = flatten(out)
        return graph_def, treefy_states

    graph_def_, treefy_states_ = jax.eval_shape(_eval_shape_fn, *args, **kwargs)
    return unflatten(graph_def_, treefy_states_)
|