brainstate 0.1.8__py2.py3-none-any.whl → 0.1.10__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +58 -51
- brainstate/_compatible_import.py +148 -148
- brainstate/_state.py +1605 -1663
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/augment/__init__.py +30 -30
- brainstate/augment/_autograd.py +778 -778
- brainstate/augment/_autograd_test.py +1289 -1289
- brainstate/augment/_eval_shape.py +99 -99
- brainstate/augment/_eval_shape_test.py +38 -38
- brainstate/augment/_mapping.py +1060 -1060
- brainstate/augment/_mapping_test.py +597 -597
- brainstate/augment/_random.py +151 -151
- brainstate/compile/__init__.py +38 -38
- brainstate/compile/_ad_checkpoint.py +204 -204
- brainstate/compile/_ad_checkpoint_test.py +49 -49
- brainstate/compile/_conditions.py +256 -256
- brainstate/compile/_conditions_test.py +220 -220
- brainstate/compile/_error_if.py +92 -92
- brainstate/compile/_error_if_test.py +52 -52
- brainstate/compile/_jit.py +346 -346
- brainstate/compile/_jit_test.py +143 -143
- brainstate/compile/_loop_collect_return.py +536 -536
- brainstate/compile/_loop_collect_return_test.py +58 -58
- brainstate/compile/_loop_no_collection.py +184 -184
- brainstate/compile/_loop_no_collection_test.py +50 -50
- brainstate/compile/_make_jaxpr.py +888 -888
- brainstate/compile/_make_jaxpr_test.py +156 -156
- brainstate/compile/_progress_bar.py +202 -202
- brainstate/compile/_unvmap.py +159 -159
- brainstate/compile/_util.py +147 -147
- brainstate/environ.py +563 -563
- brainstate/environ_test.py +62 -62
- brainstate/functional/__init__.py +27 -26
- brainstate/graph/__init__.py +29 -29
- brainstate/graph/_graph_node.py +244 -244
- brainstate/graph/_graph_node_test.py +73 -73
- brainstate/graph/_graph_operation.py +1738 -1738
- brainstate/graph/_graph_operation_test.py +563 -563
- brainstate/init/__init__.py +26 -26
- brainstate/init/_base.py +52 -52
- brainstate/init/_generic.py +244 -244
- brainstate/init/_random_inits.py +553 -553
- brainstate/init/_random_inits_test.py +149 -149
- brainstate/init/_regular_inits.py +105 -105
- brainstate/init/_regular_inits_test.py +50 -50
- brainstate/mixin.py +365 -363
- brainstate/mixin_test.py +77 -73
- brainstate/nn/__init__.py +135 -131
- brainstate/{functional → nn}/_activations.py +808 -813
- brainstate/{functional → nn}/_activations_test.py +331 -331
- brainstate/nn/_collective_ops.py +514 -514
- brainstate/nn/_collective_ops_test.py +43 -43
- brainstate/nn/_common.py +178 -178
- brainstate/nn/_conv.py +501 -501
- brainstate/nn/_conv_test.py +238 -238
- brainstate/nn/_delay.py +588 -502
- brainstate/nn/_delay_test.py +238 -184
- brainstate/nn/_dropout.py +426 -426
- brainstate/nn/_dropout_test.py +100 -100
- brainstate/nn/_dynamics.py +1343 -1343
- brainstate/nn/_dynamics_test.py +78 -78
- brainstate/nn/_elementwise.py +1119 -1119
- brainstate/nn/_elementwise_test.py +169 -169
- brainstate/nn/_embedding.py +58 -58
- brainstate/nn/_exp_euler.py +92 -92
- brainstate/nn/_exp_euler_test.py +35 -35
- brainstate/nn/_fixedprob.py +239 -239
- brainstate/nn/_fixedprob_test.py +114 -114
- brainstate/nn/_inputs.py +608 -608
- brainstate/nn/_linear.py +424 -424
- brainstate/nn/_linear_mv.py +83 -83
- brainstate/nn/_linear_mv_test.py +120 -120
- brainstate/nn/_linear_test.py +107 -107
- brainstate/nn/_ltp.py +28 -28
- brainstate/nn/_module.py +377 -377
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_neuron.py +705 -705
- brainstate/nn/_neuron_test.py +161 -161
- brainstate/nn/_normalizations.py +975 -918
- brainstate/nn/_normalizations_test.py +73 -73
- brainstate/{functional → nn}/_others.py +46 -46
- brainstate/nn/_poolings.py +1177 -1177
- brainstate/nn/_poolings_test.py +217 -217
- brainstate/nn/_projection.py +486 -486
- brainstate/nn/_rate_rnns.py +554 -554
- brainstate/nn/_rate_rnns_test.py +63 -63
- brainstate/nn/_readout.py +209 -209
- brainstate/nn/_readout_test.py +53 -53
- brainstate/nn/_stp.py +236 -236
- brainstate/nn/_synapse.py +505 -505
- brainstate/nn/_synapse_test.py +131 -131
- brainstate/nn/_synaptic_projection.py +423 -423
- brainstate/nn/_synouts.py +162 -162
- brainstate/nn/_synouts_test.py +57 -57
- brainstate/nn/_utils.py +89 -89
- brainstate/nn/metrics.py +388 -388
- brainstate/optim/__init__.py +38 -38
- brainstate/optim/_base.py +64 -64
- brainstate/optim/_lr_scheduler.py +448 -448
- brainstate/optim/_lr_scheduler_test.py +50 -50
- brainstate/optim/_optax_optimizer.py +152 -152
- brainstate/optim/_optax_optimizer_test.py +53 -53
- brainstate/optim/_sgd_optimizer.py +1104 -1104
- brainstate/random/__init__.py +24 -24
- brainstate/random/_rand_funs.py +3616 -3616
- brainstate/random/_rand_funs_test.py +567 -567
- brainstate/random/_rand_seed.py +210 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1409 -1409
- brainstate/random/_random_for_unit.py +52 -52
- brainstate/surrogate.py +1957 -1957
- brainstate/transform.py +23 -23
- brainstate/typing.py +304 -304
- brainstate/util/__init__.py +50 -50
- brainstate/util/caller.py +98 -98
- brainstate/util/error.py +55 -55
- brainstate/util/filter.py +469 -469
- brainstate/util/others.py +540 -540
- brainstate/util/pretty_pytree.py +945 -945
- brainstate/util/pretty_pytree_test.py +159 -159
- brainstate/util/pretty_repr.py +328 -328
- brainstate/util/pretty_table.py +2954 -2954
- brainstate/util/scaling.py +258 -258
- brainstate/util/struct.py +523 -523
- {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/METADATA +91 -99
- brainstate-0.1.10.dist-info/RECORD +130 -0
- {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/WHEEL +1 -1
- {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info/licenses}/LICENSE +202 -202
- brainstate/functional/_normalization.py +0 -81
- brainstate/functional/_spikes.py +0 -204
- brainstate-0.1.8.dist-info/RECORD +0 -132
- {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/top_level.txt +0 -0
@@ -1,58 +1,58 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
import unittest
|
17
|
-
|
18
|
-
import jax.numpy as jnp
|
19
|
-
import numpy as np
|
20
|
-
|
21
|
-
import brainstate
|
22
|
-
|
23
|
-
|
24
|
-
class TestForLoop(unittest.TestCase):
    """Checks that ``for_loop`` / ``checkpointed_for_loop`` thread State updates."""

    def test_for_loop(self):
        a = brainstate.ShortTermState(0.)
        b = brainstate.ShortTermState(0.)

        def f(i):
            # ``b`` is only read, so every iteration adds exactly 1 to ``a``.
            a.value += (1 + b.value)
            return a.value

        n_iter = 10
        ops = np.arange(n_iter)
        r = brainstate.compile.for_loop(f, ops)

        # assertEqual / assert_allclose report the actual values on failure,
        # unlike assertTrue(x == y).
        self.assertEqual(float(a.value), n_iter)
        self.assertEqual(float(b.value), 0.)
        np.testing.assert_allclose(np.asarray(r), ops + 1)

    def test_checkpointed_for_loop(self):
        a = brainstate.ShortTermState(0.)
        b = brainstate.ShortTermState(0.)

        def f(i):
            a.value += (1 + b.value)
            return a.value

        n_iter = 18
        ops = jnp.arange(n_iter)
        r = brainstate.compile.checkpointed_for_loop(f, ops, base=2, pbar=brainstate.compile.ProgressBar())

        self.assertEqual(float(a.value), n_iter)
        self.assertEqual(float(b.value), 0.)
        np.testing.assert_allclose(np.asarray(r), np.asarray(ops) + 1)
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import unittest
|
17
|
+
|
18
|
+
import jax.numpy as jnp
|
19
|
+
import numpy as np
|
20
|
+
|
21
|
+
import brainstate
|
22
|
+
|
23
|
+
|
24
|
+
class TestForLoop(unittest.TestCase):
    """Checks that ``for_loop`` / ``checkpointed_for_loop`` thread State updates."""

    def test_for_loop(self):
        a = brainstate.ShortTermState(0.)
        b = brainstate.ShortTermState(0.)

        def f(i):
            # ``b`` is only read, so every iteration adds exactly 1 to ``a``.
            a.value += (1 + b.value)
            return a.value

        n_iter = 10
        ops = np.arange(n_iter)
        r = brainstate.compile.for_loop(f, ops)

        # assertEqual / assert_allclose report the actual values on failure,
        # unlike assertTrue(x == y).
        self.assertEqual(float(a.value), n_iter)
        self.assertEqual(float(b.value), 0.)
        np.testing.assert_allclose(np.asarray(r), ops + 1)

    def test_checkpointed_for_loop(self):
        a = brainstate.ShortTermState(0.)
        b = brainstate.ShortTermState(0.)

        def f(i):
            a.value += (1 + b.value)
            return a.value

        n_iter = 18
        ops = jnp.arange(n_iter)
        r = brainstate.compile.checkpointed_for_loop(f, ops, base=2, pbar=brainstate.compile.ProgressBar())

        self.assertEqual(float(a.value), n_iter)
        self.assertEqual(float(b.value), 0.)
        np.testing.assert_allclose(np.asarray(r), np.asarray(ops) + 1)
|
@@ -1,184 +1,184 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
import math
|
17
|
-
from typing import Any, Callable, TypeVar
|
18
|
-
|
19
|
-
import jax
|
20
|
-
|
21
|
-
from brainstate._utils import set_module_as
|
22
|
-
from ._loop_collect_return import _bounded_while_loop
|
23
|
-
from ._make_jaxpr import StatefulFunction
|
24
|
-
from ._util import wrap_single_fun_in_multi_branches_while_loop as wrap_fn
|
25
|
-
from ._util import write_back_state_values
|
26
|
-
|
27
|
-
# Alias for loop predicates: either a Python bool or a boolean (JAX/NumPy) array.
BooleanNumeric = Any

# Generic placeholders used in the loop signatures of this module.
X = TypeVar('X')
Y = TypeVar('Y')
T = TypeVar('T')  # type of the loop-carried value
Carry = TypeVar('Carry')
|
32
|
-
|
33
|
-
# Public API re-exported by ``brainstate.compile``.
__all__ = ['while_loop', 'bounded_while_loop']
|
36
|
-
|
37
|
-
|
38
|
-
@set_module_as('brainstate.compile')
def while_loop(
    cond_fun: Callable[[T], BooleanNumeric],
    body_fun: Callable[[T], T],
    init_val: T
) -> T:
    """
    Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.

    The `Haskell-like type signature`_ in brief is

    .. code-block:: haskell

      while_loop :: (a -> Bool) -> (a -> a) -> a -> a

    The semantics of ``while_loop`` are given by this Python implementation::

      def while_loop(cond_fun, body_fun, init_val):
        val = init_val
        while cond_fun(val):
          val = body_fun(val)
        return val

    Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
    to a single WhileOp. That makes it useful for reducing compilation times
    for jit-compiled functions, since native Python loop constructs in an ``@jit``
    function are unrolled, leading to large XLA computations.

    Also unlike the Python analogue, the loop-carried value ``val`` must hold a
    fixed shape and dtype across all iterations (and not just be consistent up to
    NumPy rank/shape broadcasting and dtype promotion rules, for example). In
    other words, the type ``a`` in the type signature above represents an array
    with a fixed shape and dtype (or a nested tuple/list/dict container data
    structure with a fixed structure and arrays with fixed shape and dtype at the
    leaves).

    Another difference from using Python-native loop constructs is that
    ``while_loop`` is not reverse-mode differentiable because XLA computations
    require static bounds on memory requirements.

    States read by ``cond_fun`` or ``body_fun`` are traced and threaded through
    the loop; the final values of the states written by ``body_fun`` are written
    back after the loop finishes. ``cond_fun`` may read states but must not
    write any.

    Args:
      cond_fun: function of type ``a -> Bool``.
      body_fun: function of type ``a -> a``.
      init_val: value of type ``a``, a type that can be a scalar, array, or any
        pytree (nested Python tuple/list/dict) thereof, representing the initial
        loop carry value.

    Returns:
      The output from the final iteration of body_fun, of type ``a``.

    Raises:
      TypeError: if ``cond_fun`` or ``body_fun`` is not callable.
      ValueError: if ``cond_fun`` writes to any state.

    .. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
    """
    if not (callable(body_fun) and callable(cond_fun)):
        raise TypeError("while_loop: body_fun and cond_fun arguments should be callable.")
    if jax.config.jax_disable_jit:
        # With JIT disabled, run the loop eagerly in Python so it can be debugged.
        try:
            val = init_val
            while cond_fun(val):
                val = body_fun(val)
            return val
        except jax.errors.ConcretizationTypeError:
            # Can't run this while_loop in Python (e.g. because there's a vmap
            # transformation on it), so we fall back to the primitive version.
            # NOTE(review): State writes made by iterations that already ran
            # before the error are not rolled back here; confirm this is safe.
            pass

    # evaluate jaxpr
    stateful_cond = StatefulFunction(cond_fun, name='while:cond').make_jaxpr(init_val)
    stateful_body = StatefulFunction(body_fun, name='while:body').make_jaxpr(init_val)
    if len(stateful_cond.get_write_states()) != 0:
        raise ValueError("while_loop: cond_fun should not have any write states.")

    # state trace and state values
    state_trace = stateful_cond.get_state_trace() + stateful_body.get_state_trace()
    read_state_vals = state_trace.get_read_state_values(True)
    write_state_vals = state_trace.get_write_state_values(True)
    new_cond_fn = wrap_fn(stateful_cond, state_trace, read_state_vals, False)
    new_body_fn = wrap_fn(stateful_body, state_trace, read_state_vals, True)

    # while_loop: the written state values travel alongside the user's carry
    state_vals, final_val = jax.lax.while_loop(new_cond_fn, new_body_fn, (write_state_vals, init_val))

    # write back state values or restore them
    write_back_state_values(state_trace, read_state_vals, state_vals)
    return final_val
|
122
|
-
|
123
|
-
|
124
|
-
def _ceil_power(n: int, base: int) -> int:
    """Return the smallest integer power of ``base`` that is >= ``n`` (for ``n >= 1``)."""
    # Exact integer arithmetic: ``base ** ceil(log(n, base))`` overshoots on exact
    # powers because of float error, e.g. ``math.log(125, 5) == 3.0000000000000004``
    # would round 125 up to 625.
    power = 1
    while power < n:
        power *= base
    return power


@set_module_as('brainstate.compile')
def bounded_while_loop(
    cond_fun: Callable[[T], BooleanNumeric],
    body_fun: Callable[[T], T],
    init_val: T,
    *,
    max_steps: int,
    base: int = 16,
):
    """
    While loop with a bound on the maximum number of steps.

    This function is adapted from ``while_loop`` in `equinox <https://github.com/patrick-kidger/equinox/blob/main/equinox/internal/_loop/loop.py>`_.

    This function is useful when you want to ensure that a while loop terminates
    even if the condition function is never false. The loop itself is delegated
    to ``_bounded_while_loop``, which (following equinox) is built on scans so
    that it is intended to be reverse-mode differentiable.

    States read by ``cond_fun`` or ``body_fun`` are traced and threaded through
    the loop; the final values of the written states are written back afterwards.
    ``cond_fun`` may read states but must not write any.

    Args:
      cond_fun: A function of type ``a -> Bool``.
      body_fun: A function of type ``a -> a``.
      init_val: The initial value of type ``a``. Every leaf is converted with
        ``jax.numpy.asarray``.
      max_steps: A bound on the maximum number of steps, after which the loop
        terminates unconditionally. Must be a non-negative ``int``; ``0`` returns
        the (array-converted) ``init_val`` without running anything.
      base: Integer ``>= 2``. Run time will increase slightly as `base` increases.
        Compilation time will decrease substantially as
        `math.ceil(math.log(max_steps, base))` decreases.
        (Which happens as `base` increases.)

    Returns:
      The final value, as if computed by a `lax.while_loop`.

    Raises:
      ValueError: if ``max_steps`` is not a non-negative integer, if ``base`` is
        not an integer ``>= 2``, or if ``cond_fun`` writes to any state.
    """

    # checking
    if not isinstance(max_steps, int) or max_steps < 0:
        raise ValueError("max_steps must be a non-negative integer")
    if not isinstance(base, int) or base < 2:
        # base 0/1 would otherwise fail later with an obscure math error.
        raise ValueError("base must be an integer >= 2")
    init_val = jax.tree.map(jax.numpy.asarray, init_val)
    if max_steps == 0:
        return init_val

    # evaluate jaxpr
    stateful_cond = StatefulFunction(cond_fun, name='bounded_while:cond').make_jaxpr(init_val)
    stateful_body = StatefulFunction(body_fun, name='bounded_while:body').make_jaxpr(init_val)
    if len(stateful_cond.get_write_states()) != 0:
        raise ValueError("bounded_while_loop: cond_fun should not have any write states.")

    # state trace and state values
    state_trace = stateful_cond.get_state_trace() + stateful_body.get_state_trace()
    read_state_vals = state_trace.get_read_state_values(True)
    write_state_vals = state_trace.get_write_state_values(True)
    new_cond_fn = wrap_fn(stateful_cond, state_trace, read_state_vals, False)
    new_body_fn = wrap_fn(stateful_body, state_trace, read_state_vals, True)

    # initial value: written state values travel alongside the user's carry
    init_val = (write_state_vals, init_val)

    # while_loop
    rounded_max_steps = _ceil_power(max_steps, base)
    # NOTE(review): only ``rounded_max_steps`` (a power of ``base`` that may exceed
    # ``max_steps``) is passed on; confirm ``_bounded_while_loop`` stops after
    # exactly ``max_steps`` iterations as the docstring promises.
    state_vals, val = _bounded_while_loop(new_cond_fn, new_body_fn, init_val, rounded_max_steps, base, None)

    # write back state values or restore them
    write_back_state_values(state_trace, read_state_vals, state_vals)
    return val
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import math
|
17
|
+
from typing import Any, Callable, TypeVar
|
18
|
+
|
19
|
+
import jax
|
20
|
+
|
21
|
+
from brainstate._utils import set_module_as
|
22
|
+
from ._loop_collect_return import _bounded_while_loop
|
23
|
+
from ._make_jaxpr import StatefulFunction
|
24
|
+
from ._util import wrap_single_fun_in_multi_branches_while_loop as wrap_fn
|
25
|
+
from ._util import write_back_state_values
|
26
|
+
|
27
|
+
# Alias for loop predicates: either a Python bool or a boolean (JAX/NumPy) array.
BooleanNumeric = Any

# Generic placeholders used in the loop signatures of this module.
X = TypeVar('X')
Y = TypeVar('Y')
T = TypeVar('T')  # type of the loop-carried value
Carry = TypeVar('Carry')
|
32
|
+
|
33
|
+
# Public API re-exported by ``brainstate.compile``.
__all__ = ['while_loop', 'bounded_while_loop']
|
36
|
+
|
37
|
+
|
38
|
+
@set_module_as('brainstate.compile')
def while_loop(
    cond_fun: Callable[[T], BooleanNumeric],
    body_fun: Callable[[T], T],
    init_val: T
) -> T:
    """
    Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.

    The `Haskell-like type signature`_ in brief is

    .. code-block:: haskell

      while_loop :: (a -> Bool) -> (a -> a) -> a -> a

    The semantics of ``while_loop`` are given by this Python implementation::

      def while_loop(cond_fun, body_fun, init_val):
        val = init_val
        while cond_fun(val):
          val = body_fun(val)
        return val

    Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
    to a single WhileOp. That makes it useful for reducing compilation times
    for jit-compiled functions, since native Python loop constructs in an ``@jit``
    function are unrolled, leading to large XLA computations.

    Also unlike the Python analogue, the loop-carried value ``val`` must hold a
    fixed shape and dtype across all iterations (and not just be consistent up to
    NumPy rank/shape broadcasting and dtype promotion rules, for example). In
    other words, the type ``a`` in the type signature above represents an array
    with a fixed shape and dtype (or a nested tuple/list/dict container data
    structure with a fixed structure and arrays with fixed shape and dtype at the
    leaves).

    Another difference from using Python-native loop constructs is that
    ``while_loop`` is not reverse-mode differentiable because XLA computations
    require static bounds on memory requirements.

    States read by ``cond_fun`` or ``body_fun`` are traced and threaded through
    the loop; the final values of the states written by ``body_fun`` are written
    back after the loop finishes. ``cond_fun`` may read states but must not
    write any.

    Args:
      cond_fun: function of type ``a -> Bool``.
      body_fun: function of type ``a -> a``.
      init_val: value of type ``a``, a type that can be a scalar, array, or any
        pytree (nested Python tuple/list/dict) thereof, representing the initial
        loop carry value.

    Returns:
      The output from the final iteration of body_fun, of type ``a``.

    Raises:
      TypeError: if ``cond_fun`` or ``body_fun`` is not callable.
      ValueError: if ``cond_fun`` writes to any state.

    .. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
    """
    if not (callable(body_fun) and callable(cond_fun)):
        raise TypeError("while_loop: body_fun and cond_fun arguments should be callable.")
    if jax.config.jax_disable_jit:
        # With JIT disabled, run the loop eagerly in Python so it can be debugged.
        try:
            val = init_val
            while cond_fun(val):
                val = body_fun(val)
            return val
        except jax.errors.ConcretizationTypeError:
            # Can't run this while_loop in Python (e.g. because there's a vmap
            # transformation on it), so we fall back to the primitive version.
            # NOTE(review): State writes made by iterations that already ran
            # before the error are not rolled back here; confirm this is safe.
            pass

    # evaluate jaxpr
    stateful_cond = StatefulFunction(cond_fun, name='while:cond').make_jaxpr(init_val)
    stateful_body = StatefulFunction(body_fun, name='while:body').make_jaxpr(init_val)
    if len(stateful_cond.get_write_states()) != 0:
        raise ValueError("while_loop: cond_fun should not have any write states.")

    # state trace and state values
    state_trace = stateful_cond.get_state_trace() + stateful_body.get_state_trace()
    read_state_vals = state_trace.get_read_state_values(True)
    write_state_vals = state_trace.get_write_state_values(True)
    new_cond_fn = wrap_fn(stateful_cond, state_trace, read_state_vals, False)
    new_body_fn = wrap_fn(stateful_body, state_trace, read_state_vals, True)

    # while_loop: the written state values travel alongside the user's carry
    state_vals, final_val = jax.lax.while_loop(new_cond_fn, new_body_fn, (write_state_vals, init_val))

    # write back state values or restore them
    write_back_state_values(state_trace, read_state_vals, state_vals)
    return final_val
|
122
|
+
|
123
|
+
|
124
|
+
def _ceil_power(n: int, base: int) -> int:
    """Return the smallest integer power of ``base`` that is >= ``n`` (for ``n >= 1``)."""
    # Exact integer arithmetic: ``base ** ceil(log(n, base))`` overshoots on exact
    # powers because of float error, e.g. ``math.log(125, 5) == 3.0000000000000004``
    # would round 125 up to 625.
    power = 1
    while power < n:
        power *= base
    return power


@set_module_as('brainstate.compile')
def bounded_while_loop(
    cond_fun: Callable[[T], BooleanNumeric],
    body_fun: Callable[[T], T],
    init_val: T,
    *,
    max_steps: int,
    base: int = 16,
):
    """
    While loop with a bound on the maximum number of steps.

    This function is adapted from ``while_loop`` in `equinox <https://github.com/patrick-kidger/equinox/blob/main/equinox/internal/_loop/loop.py>`_.

    This function is useful when you want to ensure that a while loop terminates
    even if the condition function is never false. The loop itself is delegated
    to ``_bounded_while_loop``, which (following equinox) is built on scans so
    that it is intended to be reverse-mode differentiable.

    States read by ``cond_fun`` or ``body_fun`` are traced and threaded through
    the loop; the final values of the written states are written back afterwards.
    ``cond_fun`` may read states but must not write any.

    Args:
      cond_fun: A function of type ``a -> Bool``.
      body_fun: A function of type ``a -> a``.
      init_val: The initial value of type ``a``. Every leaf is converted with
        ``jax.numpy.asarray``.
      max_steps: A bound on the maximum number of steps, after which the loop
        terminates unconditionally. Must be a non-negative ``int``; ``0`` returns
        the (array-converted) ``init_val`` without running anything.
      base: Integer ``>= 2``. Run time will increase slightly as `base` increases.
        Compilation time will decrease substantially as
        `math.ceil(math.log(max_steps, base))` decreases.
        (Which happens as `base` increases.)

    Returns:
      The final value, as if computed by a `lax.while_loop`.

    Raises:
      ValueError: if ``max_steps`` is not a non-negative integer, if ``base`` is
        not an integer ``>= 2``, or if ``cond_fun`` writes to any state.
    """

    # checking
    if not isinstance(max_steps, int) or max_steps < 0:
        raise ValueError("max_steps must be a non-negative integer")
    if not isinstance(base, int) or base < 2:
        # base 0/1 would otherwise fail later with an obscure math error.
        raise ValueError("base must be an integer >= 2")
    init_val = jax.tree.map(jax.numpy.asarray, init_val)
    if max_steps == 0:
        return init_val

    # evaluate jaxpr
    stateful_cond = StatefulFunction(cond_fun, name='bounded_while:cond').make_jaxpr(init_val)
    stateful_body = StatefulFunction(body_fun, name='bounded_while:body').make_jaxpr(init_val)
    if len(stateful_cond.get_write_states()) != 0:
        raise ValueError("bounded_while_loop: cond_fun should not have any write states.")

    # state trace and state values
    state_trace = stateful_cond.get_state_trace() + stateful_body.get_state_trace()
    read_state_vals = state_trace.get_read_state_values(True)
    write_state_vals = state_trace.get_write_state_values(True)
    new_cond_fn = wrap_fn(stateful_cond, state_trace, read_state_vals, False)
    new_body_fn = wrap_fn(stateful_body, state_trace, read_state_vals, True)

    # initial value: written state values travel alongside the user's carry
    init_val = (write_state_vals, init_val)

    # while_loop
    rounded_max_steps = _ceil_power(max_steps, base)
    # NOTE(review): only ``rounded_max_steps`` (a power of ``base`` that may exceed
    # ``max_steps``) is passed on; confirm ``_bounded_while_loop`` stops after
    # exactly ``max_steps`` iterations as the docstring promises.
    state_vals, val = _bounded_while_loop(new_cond_fn, new_body_fn, init_val, rounded_max_steps, base, None)

    # write back state values or restore them
    write_back_state_values(state_trace, read_state_vals, state_vals)
    return val
|
@@ -1,50 +1,50 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
|
17
|
-
from unittest import TestCase
|
18
|
-
|
19
|
-
import brainstate
|
20
|
-
|
21
|
-
|
22
|
-
class TestWhileLoop(TestCase):
    """Checks that ``while_loop`` threads State reads/writes through the loop."""

    def test1(self):
        a = brainstate.State(1.)
        b = brainstate.State(20.)

        def cond(_):
            return a.value < b.value

        def body(_):
            a.value += 1.

        brainstate.compile.while_loop(cond, body, None)

        # ``a`` counts up from 1 until it is no longer < 20; ``b`` is only read.
        self.assertEqual(float(a.value), 20.)
        self.assertEqual(float(b.value), 20.)

    def test2(self):
        a = brainstate.State(1.)
        b = brainstate.State(20.)

        def cond(x):
            return a.value < b.value

        def body(x):
            a.value += x
            return x

        r = brainstate.compile.while_loop(cond, body, 1.)

        # The carry is returned unchanged while ``a`` accumulates it each step.
        self.assertEqual(float(a.value), 20.)
        self.assertEqual(float(b.value), 20.)
        self.assertEqual(float(r), 1.)
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
|
17
|
+
from unittest import TestCase
|
18
|
+
|
19
|
+
import brainstate
|
20
|
+
|
21
|
+
|
22
|
+
class TestWhileLoop(TestCase):
    """Checks that ``while_loop`` threads State reads/writes through the loop."""

    def test1(self):
        a = brainstate.State(1.)
        b = brainstate.State(20.)

        def cond(_):
            return a.value < b.value

        def body(_):
            a.value += 1.

        brainstate.compile.while_loop(cond, body, None)

        # ``a`` counts up from 1 until it is no longer < 20; ``b`` is only read.
        self.assertEqual(float(a.value), 20.)
        self.assertEqual(float(b.value), 20.)

    def test2(self):
        a = brainstate.State(1.)
        b = brainstate.State(20.)

        def cond(x):
            return a.value < b.value

        def body(x):
            a.value += x
            return x

        r = brainstate.compile.while_loop(cond, body, 1.)

        # The carry is returned unchanged while ``a`` accumulates it each step.
        self.assertEqual(float(a.value), 20.)
        self.assertEqual(float(b.value), 20.)
        self.assertEqual(float(r), 1.)
|