brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +167 -169
- brainstate/_compatible_import.py +340 -340
- brainstate/_compatible_import_test.py +681 -681
- brainstate/_deprecation.py +210 -210
- brainstate/_deprecation_test.py +2297 -2319
- brainstate/_error.py +45 -45
- brainstate/_state.py +2157 -1652
- brainstate/_state_test.py +1129 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -1495
- brainstate/environ_test.py +1223 -1223
- brainstate/graph/__init__.py +22 -22
- brainstate/graph/_node.py +240 -240
- brainstate/graph/_node_test.py +589 -589
- brainstate/graph/_operation.py +1620 -1624
- brainstate/graph/_operation_test.py +1147 -1147
- brainstate/mixin.py +1447 -1433
- brainstate/mixin_test.py +1017 -1017
- brainstate/nn/__init__.py +146 -137
- brainstate/nn/_activations.py +1100 -1100
- brainstate/nn/_activations_test.py +354 -354
- brainstate/nn/_collective_ops.py +635 -633
- brainstate/nn/_collective_ops_test.py +774 -774
- brainstate/nn/_common.py +226 -226
- brainstate/nn/_common_test.py +134 -154
- brainstate/nn/_conv.py +2010 -2010
- brainstate/nn/_conv_test.py +849 -849
- brainstate/nn/_delay.py +575 -575
- brainstate/nn/_delay_test.py +243 -243
- brainstate/nn/_dropout.py +618 -618
- brainstate/nn/_dropout_test.py +480 -477
- brainstate/nn/_dynamics.py +870 -1267
- brainstate/nn/_dynamics_test.py +53 -67
- brainstate/nn/_elementwise.py +1298 -1298
- brainstate/nn/_elementwise_test.py +829 -829
- brainstate/nn/_embedding.py +408 -408
- brainstate/nn/_embedding_test.py +156 -156
- brainstate/nn/_event_fixedprob.py +233 -233
- brainstate/nn/_event_fixedprob_test.py +115 -115
- brainstate/nn/_event_linear.py +83 -83
- brainstate/nn/_event_linear_test.py +121 -121
- brainstate/nn/_exp_euler.py +254 -254
- brainstate/nn/_exp_euler_test.py +377 -377
- brainstate/nn/_linear.py +744 -744
- brainstate/nn/_linear_test.py +475 -475
- brainstate/nn/_metrics.py +1070 -1070
- brainstate/nn/_metrics_test.py +611 -611
- brainstate/nn/_module.py +391 -384
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -1334
- brainstate/nn/_normalizations_test.py +699 -699
- brainstate/nn/_paddings.py +1020 -1020
- brainstate/nn/_paddings_test.py +722 -722
- brainstate/nn/_poolings.py +2239 -2239
- brainstate/nn/_poolings_test.py +952 -952
- brainstate/nn/_rnns.py +946 -946
- brainstate/nn/_rnns_test.py +592 -592
- brainstate/nn/_utils.py +216 -216
- brainstate/nn/_utils_test.py +401 -401
- brainstate/nn/init.py +809 -809
- brainstate/nn/init_test.py +180 -180
- brainstate/random/__init__.py +270 -270
- brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
- brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
- brainstate/random/_impl.py +672 -0
- brainstate/random/{_rand_seed.py → _seed.py} +675 -675
- brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
- brainstate/random/{_rand_state.py → _state.py} +1320 -1617
- brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
- brainstate/transform/__init__.py +56 -59
- brainstate/transform/_ad_checkpoint.py +176 -176
- brainstate/transform/_ad_checkpoint_test.py +49 -49
- brainstate/transform/_autograd.py +1025 -1025
- brainstate/transform/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -316
- brainstate/transform/_conditions_test.py +220 -220
- brainstate/transform/_error_if.py +94 -94
- brainstate/transform/_error_if_test.py +52 -52
- brainstate/transform/_find_state.py +200 -0
- brainstate/transform/_find_state_test.py +84 -0
- brainstate/transform/_jit.py +399 -399
- brainstate/transform/_jit_test.py +143 -143
- brainstate/transform/_loop_collect_return.py +675 -675
- brainstate/transform/_loop_collect_return_test.py +58 -58
- brainstate/transform/_loop_no_collection.py +283 -283
- brainstate/transform/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2176 -2016
- brainstate/transform/_make_jaxpr_test.py +1634 -1510
- brainstate/transform/_mapping.py +607 -529
- brainstate/transform/_mapping_test.py +104 -194
- brainstate/transform/_progress_bar.py +255 -255
- brainstate/transform/_unvmap.py +256 -256
- brainstate/transform/_util.py +286 -286
- brainstate/typing.py +837 -837
- brainstate/typing_test.py +780 -780
- brainstate/util/__init__.py +27 -27
- brainstate/util/_others.py +1024 -1024
- brainstate/util/_others_test.py +962 -962
- brainstate/util/_pretty_pytree.py +1301 -1301
- brainstate/util/_pretty_pytree_test.py +675 -675
- brainstate/util/_pretty_repr.py +462 -462
- brainstate/util/_pretty_repr_test.py +696 -696
- brainstate/util/filter.py +945 -945
- brainstate/util/filter_test.py +911 -911
- brainstate/util/struct.py +910 -910
- brainstate/util/struct_test.py +602 -602
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
- brainstate-0.2.2.dist-info/RECORD +111 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
- brainstate/transform/_eval_shape.py +0 -145
- brainstate/transform/_eval_shape_test.py +0 -38
- brainstate/transform/_random.py +0 -171
- brainstate-0.2.1.dist-info/RECORD +0 -111
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
@@ -1,675 +1,675 @@
|
|
1
|
-
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
import math
|
17
|
-
from functools import wraps
|
18
|
-
from typing import Callable, Optional, TypeVar, Tuple, Any
|
19
|
-
|
20
|
-
import jax
|
21
|
-
import jax.numpy as jnp
|
22
|
-
|
23
|
-
from brainstate._utils import set_module_as
|
24
|
-
from ._make_jaxpr import StatefulFunction
|
25
|
-
from ._progress_bar import ProgressBar
|
26
|
-
from ._unvmap import unvmap
|
27
|
-
from ._util import wrap_single_fun
|
28
|
-
|
29
|
-
__all__ = [
|
30
|
-
# "scan" syntax, which is similar to jax.lax.scan
|
31
|
-
'scan', 'checkpointed_scan',
|
32
|
-
|
33
|
-
# "for_loop" syntax
|
34
|
-
'for_loop', 'checkpointed_for_loop',
|
35
|
-
]
|
36
|
-
|
37
|
-
X = TypeVar('X')
|
38
|
-
Y = TypeVar('Y')
|
39
|
-
T = TypeVar('T')
|
40
|
-
Carry = TypeVar('Carry')
|
41
|
-
|
42
|
-
|
43
|
-
def _wrap_fun_with_pbar(
|
44
|
-
fun: Callable[[Carry, X], Tuple[Carry, Y]],
|
45
|
-
pbar_runner: Callable
|
46
|
-
):
|
47
|
-
@wraps(fun)
|
48
|
-
def new_fun(new_carry, inputs):
|
49
|
-
i, old_carry = new_carry
|
50
|
-
new_carry, new_outputs = fun(old_carry, inputs)
|
51
|
-
pbar_runner(unvmap(i, op='none'), carry=new_carry, y=new_outputs)
|
52
|
-
return (i + 1, new_carry), new_outputs
|
53
|
-
|
54
|
-
return new_fun
|
55
|
-
|
56
|
-
|
57
|
-
@set_module_as('brainstate.transform')
|
58
|
-
def scan(
|
59
|
-
f: Callable[[Carry, X], Tuple[Carry, Y]],
|
60
|
-
init: Carry,
|
61
|
-
xs: X,
|
62
|
-
length: int | None = None,
|
63
|
-
reverse: bool = False,
|
64
|
-
unroll: int | bool = 1,
|
65
|
-
pbar: ProgressBar | int | None = None,
|
66
|
-
) -> Tuple[Carry, Y]:
|
67
|
-
"""
|
68
|
-
Scan a function over leading array axes while carrying along state.
|
69
|
-
|
70
|
-
The `Haskell-like type signature`_ in brief is
|
71
|
-
|
72
|
-
.. code-block:: haskell
|
73
|
-
|
74
|
-
scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
|
75
|
-
|
76
|
-
where for any array type specifier ``t``, ``[t]`` represents the type with an additional
|
77
|
-
leading axis, and if ``t`` is a pytree (container) type with array leaves then ``[t]``
|
78
|
-
represents the type with the same pytree structure and corresponding leaves
|
79
|
-
each with an additional leading axis.
|
80
|
-
|
81
|
-
When the type of ``xs`` (denoted `a` above) is an array type or None, and the type
|
82
|
-
of ``ys`` (denoted `b` above) is an array type, the semantics of :func:`~scan` are
|
83
|
-
given roughly by this Python implementation:
|
84
|
-
|
85
|
-
.. code-block:: python
|
86
|
-
|
87
|
-
>>> def scan(f, init, xs, length=None):
|
88
|
-
... if xs is None:
|
89
|
-
... xs = [None] * length
|
90
|
-
... carry = init
|
91
|
-
... ys = []
|
92
|
-
... for x in xs:
|
93
|
-
... carry, y = f(carry, x)
|
94
|
-
... ys.append(y)
|
95
|
-
... return carry, np.stack(ys)
|
96
|
-
|
97
|
-
Unlike that Python version, both ``xs`` and ``ys`` may be arbitrary pytree
|
98
|
-
values, and so multiple arrays can be scanned over at once and produce multiple
|
99
|
-
output arrays. ``None`` is actually a special case of this, as it represents an
|
100
|
-
empty pytree.
|
101
|
-
|
102
|
-
Also unlike that Python version, :func:`~scan` is a JAX primitive and is
|
103
|
-
lowered to a single WhileOp. That makes it useful for reducing
|
104
|
-
compilation times for JIT-compiled functions, since native Python
|
105
|
-
loop constructs in an :func:`~jax.jit` function are unrolled, leading to large
|
106
|
-
XLA computations.
|
107
|
-
|
108
|
-
Finally, the loop-carried value ``carry`` must hold a fixed shape and dtype
|
109
|
-
across all iterations (and not just be consistent up to NumPy rank/shape
|
110
|
-
broadcasting and dtype promotion rules, for example). In other words, the type
|
111
|
-
``c`` in the type signature above represents an array with a fixed shape and
|
112
|
-
dtype (or a nested tuple/list/dict container data structure with a fixed
|
113
|
-
structure and arrays with fixed shape and dtype at the leaves).
|
114
|
-
|
115
|
-
Parameters
|
116
|
-
----------
|
117
|
-
f : callable
|
118
|
-
A Python function to be scanned of type ``c -> a -> (c, b)``, meaning
|
119
|
-
that ``f`` accepts two arguments where the first is a value of the loop
|
120
|
-
carry and the second is a slice of ``xs`` along its leading axis, and that
|
121
|
-
``f`` returns a pair where the first element represents a new value for
|
122
|
-
the loop carry and the second represents a slice of the output.
|
123
|
-
init : Carry
|
124
|
-
An initial loop carry value of type ``c``, which can be a scalar,
|
125
|
-
array, or any pytree (nested Python tuple/list/dict) thereof, representing
|
126
|
-
the initial loop carry value. This value must have the same structure as
|
127
|
-
the first element of the pair returned by ``f``.
|
128
|
-
xs : X
|
129
|
-
The value of type ``[a]`` over which to scan along the leading axis,
|
130
|
-
where ``[a]`` can be an array or any pytree (nested Python
|
131
|
-
tuple/list/dict) thereof with consistent leading axis sizes.
|
132
|
-
length : int, optional
|
133
|
-
Optional integer specifying the number of loop iterations, which
|
134
|
-
must agree with the sizes of leading axes of the arrays in ``xs`` (but can
|
135
|
-
be used to perform scans where no input ``xs`` are needed).
|
136
|
-
reverse : bool, default False
|
137
|
-
Optional boolean specifying whether to run the scan iteration
|
138
|
-
forward (the default) or in reverse, equivalent to reversing the leading
|
139
|
-
axes of the arrays in both ``xs`` and in ``ys``.
|
140
|
-
unroll : int or bool, default 1
|
141
|
-
Optional positive int or bool specifying, in the underlying
|
142
|
-
operation of the scan primitive, how many scan iterations to unroll within
|
143
|
-
a single iteration of a loop. If an integer is provided, it determines how
|
144
|
-
many unrolled loop iterations to run within a single rolled iteration of
|
145
|
-
the loop. If a boolean is provided, it will determine if the loop is
|
146
|
-
completely unrolled (i.e. `unroll=True`) or left completely rolled (i.e.
|
147
|
-
`unroll=False`).
|
148
|
-
pbar : ProgressBar or int, optional
|
149
|
-
Optional :class:`~.ProgressBar` instance to display the progress
|
150
|
-
of the scan operation.
|
151
|
-
|
152
|
-
Returns
|
153
|
-
-------
|
154
|
-
tuple of (Carry, Y)
|
155
|
-
A pair of type ``(c, [b])`` where the first element represents the final
|
156
|
-
loop carry value and the second element represents the stacked outputs of
|
157
|
-
the second output of ``f`` when scanned over the leading axis of the inputs.
|
158
|
-
|
159
|
-
Examples
|
160
|
-
--------
|
161
|
-
Basic scan operation:
|
162
|
-
|
163
|
-
.. code-block:: python
|
164
|
-
|
165
|
-
>>> import brainstate
|
166
|
-
>>> import jax.numpy as jnp
|
167
|
-
>>>
|
168
|
-
>>> def step_fn(carry, x):
|
169
|
-
... return carry + x, carry * x
|
170
|
-
>>>
|
171
|
-
>>> init = 0.0
|
172
|
-
>>> xs = jnp.array([1.0, 2.0, 3.0])
|
173
|
-
>>> final_carry, ys = brainstate.transform.scan(step_fn, init, xs)
|
174
|
-
|
175
|
-
Scan with progress bar:
|
176
|
-
|
177
|
-
.. code-block:: python
|
178
|
-
|
179
|
-
>>> pbar = brainstate.transform.ProgressBar(freq=10)
|
180
|
-
>>> final_carry, ys = brainstate.transform.scan(step_fn, init, xs, pbar=pbar)
|
181
|
-
|
182
|
-
References
|
183
|
-
----------
|
184
|
-
.. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
|
185
|
-
"""
|
186
|
-
# check "f"
|
187
|
-
if not callable(f):
|
188
|
-
raise TypeError("f argument should be a callable.")
|
189
|
-
|
190
|
-
# check "xs"
|
191
|
-
xs_flat, xs_tree = jax.tree.flatten(xs)
|
192
|
-
try:
|
193
|
-
lengths = [x.shape[0] for x in xs_flat]
|
194
|
-
except AttributeError as err:
|
195
|
-
raise ValueError("scan got value with no leading axis to scan over: "
|
196
|
-
"{}.".format(', '.join(str(x) for x in xs_flat if not hasattr(x, 'shape')))) from err
|
197
|
-
if length is not None:
|
198
|
-
length = int(length)
|
199
|
-
if not all(length == l for l in lengths):
|
200
|
-
raise ValueError(("scan got `length` argument of {} which disagrees with "
|
201
|
-
"leading axis sizes {}.").format(length, [x.shape[0] for x in xs_flat]))
|
202
|
-
else:
|
203
|
-
unique_lengths = set(lengths)
|
204
|
-
if len(unique_lengths) > 1:
|
205
|
-
msg = "scan got values with different leading axis sizes: {}."
|
206
|
-
raise ValueError(msg.format(', '.join(str(x.shape[0]) for x in xs_flat)))
|
207
|
-
elif len(unique_lengths) == 0:
|
208
|
-
raise ValueError("scan got no values to scan over and `length` not provided.")
|
209
|
-
else:
|
210
|
-
length, = unique_lengths
|
211
|
-
|
212
|
-
# function with progress bar
|
213
|
-
has_pbar = False
|
214
|
-
if pbar is not None:
|
215
|
-
has_pbar = True
|
216
|
-
if isinstance(pbar, ProgressBar):
|
217
|
-
pbar_runner = pbar.init(length)
|
218
|
-
elif isinstance(pbar, int):
|
219
|
-
pbar_runner = ProgressBar(freq=pbar).init(length)
|
220
|
-
else:
|
221
|
-
raise TypeError("pbar argument should be a ProgressBar instance or an integer.")
|
222
|
-
f = _wrap_fun_with_pbar(f, pbar_runner)
|
223
|
-
init = (0, init) if pbar else init
|
224
|
-
|
225
|
-
# not jit
|
226
|
-
if jax.config.jax_disable_jit:
|
227
|
-
if length == 0:
|
228
|
-
raise ValueError(
|
229
|
-
"zero-length scan is not supported in disable_jit() mode because the output type is unknown.")
|
230
|
-
carry = init
|
231
|
-
ys = []
|
232
|
-
maybe_reversed = reversed if reverse else lambda x: x
|
233
|
-
for i in maybe_reversed(range(length)):
|
234
|
-
xs_slice = [jax.lax.index_in_dim(x, i, keepdims=False) for x in xs_flat]
|
235
|
-
carry, y = f(carry, jax.tree.unflatten(xs_tree, xs_slice))
|
236
|
-
ys.append(y)
|
237
|
-
stacked_y = jax.tree.map(lambda *elems: jnp.stack(elems), *maybe_reversed(ys))
|
238
|
-
if has_pbar:
|
239
|
-
return carry[1], stacked_y
|
240
|
-
else:
|
241
|
-
return carry, stacked_y
|
242
|
-
|
243
|
-
# evaluate jaxpr, get all states #
|
244
|
-
# ------------------------------ #
|
245
|
-
xs_avals = [jax.core.get_aval(x) for x in xs_flat]
|
246
|
-
x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
|
247
|
-
args = [init, xs_tree.unflatten(x_avals)]
|
248
|
-
stateful_fun = StatefulFunction(f, name='scan').make_jaxpr(*args)
|
249
|
-
state_trace = stateful_fun.get_state_trace(*args)
|
250
|
-
all_writen_state_vals = state_trace.get_write_state_values(True)
|
251
|
-
all_read_state_vals = state_trace.get_read_state_values(True)
|
252
|
-
wrapped_f = wrap_single_fun(stateful_fun, state_trace.been_writen, all_read_state_vals)
|
253
|
-
|
254
|
-
# scan
|
255
|
-
init = (all_writen_state_vals, init)
|
256
|
-
(
|
257
|
-
(
|
258
|
-
all_writen_state_vals,
|
259
|
-
carry
|
260
|
-
),
|
261
|
-
ys
|
262
|
-
) = jax.lax.scan(
|
263
|
-
wrapped_f,
|
264
|
-
init,
|
265
|
-
xs,
|
266
|
-
length=length,
|
267
|
-
reverse=reverse,
|
268
|
-
unroll=unroll
|
269
|
-
)
|
270
|
-
# assign the written state values and restore the read state values
|
271
|
-
state_trace.assign_state_vals_v2(all_read_state_vals, all_writen_state_vals)
|
272
|
-
# carry
|
273
|
-
if has_pbar:
|
274
|
-
carry = carry[1]
|
275
|
-
return carry, ys
|
276
|
-
|
277
|
-
|
278
|
-
@set_module_as('brainstate.transform')
|
279
|
-
def checkpointed_scan(
|
280
|
-
f: Callable[[Carry, X], Tuple[Carry, Y]],
|
281
|
-
init: Carry,
|
282
|
-
xs: X,
|
283
|
-
length: Optional[int] = None,
|
284
|
-
base: int = 16,
|
285
|
-
pbar: Optional[ProgressBar | int] = None,
|
286
|
-
) -> Tuple[Carry, Y]:
|
287
|
-
"""
|
288
|
-
Scan a function over leading array axes while carrying along state.
|
289
|
-
This function is a checkpointed version of :func:`~scan`.
|
290
|
-
|
291
|
-
Parameters
|
292
|
-
----------
|
293
|
-
f : callable
|
294
|
-
A Python function to be scanned of type ``c -> a -> (c, b)``, meaning
|
295
|
-
that ``f`` accepts two arguments where the first is a value of the loop
|
296
|
-
carry and the second is a slice of ``xs`` along its leading axis, and that
|
297
|
-
``f`` returns a pair where the first element represents a new value for
|
298
|
-
the loop carry and the second represents a slice of the output.
|
299
|
-
init : Carry
|
300
|
-
An initial loop carry value of type ``c``, which can be a scalar,
|
301
|
-
array, or any pytree (nested Python tuple/list/dict) thereof, representing
|
302
|
-
the initial loop carry value. This value must have the same structure as
|
303
|
-
the first element of the pair returned by ``f``.
|
304
|
-
xs : X
|
305
|
-
The value of type ``[a]`` over which to scan along the leading axis,
|
306
|
-
where ``[a]`` can be an array or any pytree (nested Python
|
307
|
-
tuple/list/dict) thereof with consistent leading axis sizes.
|
308
|
-
length : int, optional
|
309
|
-
Optional integer specifying the number of loop iterations, which
|
310
|
-
must agree with the sizes of leading axes of the arrays in ``xs`` (but can
|
311
|
-
be used to perform scans where no input ``xs`` are needed).
|
312
|
-
base : int, default 16
|
313
|
-
Optional integer specifying the base for the bounded scan loop.
|
314
|
-
pbar : ProgressBar or int, optional
|
315
|
-
Optional :class:`~.ProgressBar` instance to display the progress
|
316
|
-
of the scan operation.
|
317
|
-
|
318
|
-
Returns
|
319
|
-
-------
|
320
|
-
tuple of (Carry, Y)
|
321
|
-
A pair of type ``(c, [b])`` where the first element represents the final
|
322
|
-
loop carry value and the second element represents the stacked outputs of
|
323
|
-
the second output of ``f`` when scanned over the leading axis of the inputs.
|
324
|
-
|
325
|
-
Examples
|
326
|
-
--------
|
327
|
-
Basic checkpointed scan operation:
|
328
|
-
|
329
|
-
.. code-block:: python
|
330
|
-
|
331
|
-
>>> import brainstate
|
332
|
-
>>> import jax.numpy as jnp
|
333
|
-
>>>
|
334
|
-
>>> def step_fn(carry, x):
|
335
|
-
... return carry + x, carry * x
|
336
|
-
>>>
|
337
|
-
>>> init = 0.0
|
338
|
-
>>> xs = jnp.array([1.0, 2.0, 3.0])
|
339
|
-
>>> final_carry, ys = brainstate.transform.checkpointed_scan(step_fn, init, xs)
|
340
|
-
|
341
|
-
Using custom base for checkpointing:
|
342
|
-
|
343
|
-
.. code-block:: python
|
344
|
-
|
345
|
-
>>> final_carry, ys = brainstate.transform.checkpointed_scan(
|
346
|
-
... step_fn, init, xs, base=8
|
347
|
-
... )
|
348
|
-
"""
|
349
|
-
# check "f"
|
350
|
-
if not callable(f):
|
351
|
-
raise TypeError("f argument should be a callable.")
|
352
|
-
|
353
|
-
# check "xs"
|
354
|
-
xs_flat, xs_tree = jax.tree.flatten(xs)
|
355
|
-
try:
|
356
|
-
lengths = [x.shape[0] for x in xs_flat]
|
357
|
-
except AttributeError as err:
|
358
|
-
raise ValueError("scan got value with no leading axis to scan over: "
|
359
|
-
"{}.".format(', '.join(str(x) for x in xs_flat if not hasattr(x, 'shape')))) from err
|
360
|
-
if length is not None:
|
361
|
-
length = int(length)
|
362
|
-
if not all(length == l for l in lengths):
|
363
|
-
raise ValueError(("scan got `length` argument of {} which disagrees with "
|
364
|
-
"leading axis sizes {}.").format(length, [x.shape[0] for x in xs_flat]))
|
365
|
-
else:
|
366
|
-
unique_lengths = set(lengths)
|
367
|
-
if len(unique_lengths) > 1:
|
368
|
-
msg = "scan got values with different leading axis sizes: {}."
|
369
|
-
raise ValueError(msg.format(', '.join(str(x.shape[0]) for x in xs_flat)))
|
370
|
-
elif len(unique_lengths) == 0:
|
371
|
-
raise ValueError("scan got no values to scan over and `length` not provided.")
|
372
|
-
else:
|
373
|
-
length, = unique_lengths
|
374
|
-
|
375
|
-
# function with progress bar
|
376
|
-
if isinstance(pbar, ProgressBar):
|
377
|
-
pbar_runner = pbar.init(length)
|
378
|
-
elif isinstance(pbar, int):
|
379
|
-
pbar_runner = ProgressBar(freq=pbar).init(length)
|
380
|
-
else:
|
381
|
-
pbar_runner = None
|
382
|
-
|
383
|
-
# evaluate jaxpr
|
384
|
-
xs_avals = [jax.core.get_aval(x) for x in xs_flat]
|
385
|
-
x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
|
386
|
-
args = (init, xs_tree.unflatten(x_avals))
|
387
|
-
stateful_fun = StatefulFunction(f, name='checkpoint_scan').make_jaxpr(*args)
|
388
|
-
state_trace = stateful_fun.get_state_trace(*args)
|
389
|
-
cache_key = stateful_fun.get_arg_cache_key(*args)
|
390
|
-
# get all states
|
391
|
-
been_written = state_trace.been_writen
|
392
|
-
read_state_vals = state_trace.get_read_state_values(True)
|
393
|
-
write_state_vals = state_trace.get_write_state_values(True)
|
394
|
-
|
395
|
-
# initialize the collected values/data
|
396
|
-
out_info = stateful_fun.get_out_shapes_by_cache(cache_key)[0]
|
397
|
-
assert len(out_info) == 2, "function in checkpointed_scan should return two data: carray and out."
|
398
|
-
data2collection = jax.tree.map(lambda x: jnp.zeros((length,) + x.shape, x.dtype), out_info[1])
|
399
|
-
del out_info
|
400
|
-
|
401
|
-
def wrapped_cond_fun(inp):
|
402
|
-
return inp[-1] < length
|
403
|
-
|
404
|
-
def wrapped_body_fun(inp):
|
405
|
-
(prev_write_states, carray), prev_collect, i = inp
|
406
|
-
# progress bar
|
407
|
-
if pbar_runner is not None:
|
408
|
-
pbar_runner(unvmap(i, op='none'))
|
409
|
-
# call the function
|
410
|
-
prev_states = [w_val if write else r_val
|
411
|
-
for write, w_val, r_val in zip(been_written, prev_write_states, read_state_vals)]
|
412
|
-
new_states, (new_carray, out4updates) = stateful_fun.jaxpr_call(
|
413
|
-
prev_states, carray, jax.tree.map(lambda x: x[i], xs)
|
414
|
-
)
|
415
|
-
# new written states
|
416
|
-
new_write_states = tuple([val if write else None for write, val in zip(been_written, new_states)])
|
417
|
-
|
418
|
-
# out of length
|
419
|
-
pred = i < length
|
420
|
-
new_collect = jax.tree.map(
|
421
|
-
# lambda x, update: x.at[i].set(jax.lax.select(pred, update, x[i])),
|
422
|
-
lambda x, update: jax.lax.select(pred, x.at[i].set(update), x),
|
423
|
-
prev_collect,
|
424
|
-
out4updates,
|
425
|
-
)
|
426
|
-
new_write_states = jax.tree.map(
|
427
|
-
lambda ps, ns: None if ns is None else jax.lax.select(pred, ns, ps),
|
428
|
-
prev_write_states,
|
429
|
-
new_write_states,
|
430
|
-
is_leaf=lambda x: x is None
|
431
|
-
)
|
432
|
-
new_carray = jax.tree.map(
|
433
|
-
lambda pc, nc: jax.lax.select(pred, nc, pc),
|
434
|
-
carray,
|
435
|
-
new_carray,
|
436
|
-
)
|
437
|
-
return (new_write_states, new_carray), new_collect, i + 1
|
438
|
-
|
439
|
-
# while_loop
|
440
|
-
rounded_max_steps = base ** int(math.ceil(math.log(length, base)))
|
441
|
-
(write_state_vals, carry), data2collection, _ = (
|
442
|
-
_bounded_while_loop(
|
443
|
-
wrapped_cond_fun,
|
444
|
-
wrapped_body_fun,
|
445
|
-
((write_state_vals, init), data2collection, 0),
|
446
|
-
rounded_max_steps,
|
447
|
-
base,
|
448
|
-
pbar_runner
|
449
|
-
)
|
450
|
-
)
|
451
|
-
# assign the written state values and restore the read state values
|
452
|
-
state_trace.assign_state_vals_v2(read_state_vals, write_state_vals)
|
453
|
-
del write_state_vals, read_state_vals, stateful_fun
|
454
|
-
return carry, data2collection
|
455
|
-
|
456
|
-
|
457
|
-
def _forloop_to_scan_fun(f: Callable):
|
458
|
-
@wraps(f)
|
459
|
-
def scan_fun(carry, x):
|
460
|
-
return carry, f(*x)
|
461
|
-
|
462
|
-
return scan_fun
|
463
|
-
|
464
|
-
|
465
|
-
@set_module_as('brainstate.transform')
|
466
|
-
def for_loop(
|
467
|
-
f: Callable[..., Y],
|
468
|
-
*xs,
|
469
|
-
length: Optional[int] = None,
|
470
|
-
reverse: bool = False,
|
471
|
-
unroll: int | bool = 1,
|
472
|
-
pbar: Optional[ProgressBar | int] = None
|
473
|
-
) -> Y:
|
474
|
-
"""
|
475
|
-
``for-loop`` control flow with :py:class:`~.State`.
|
476
|
-
|
477
|
-
Parameters
|
478
|
-
----------
|
479
|
-
f : callable
|
480
|
-
A Python function to be looped over that accepts variadic arguments
|
481
|
-
corresponding to slices of ``xs`` along their leading axes, and returns
|
482
|
-
the output for that iteration.
|
483
|
-
*xs
|
484
|
-
The values over which to loop along the leading axis,
|
485
|
-
where each can be an array or any pytree (nested Python
|
486
|
-
tuple/list/dict) thereof with consistent leading axis sizes.
|
487
|
-
length : int, optional
|
488
|
-
Optional integer specifying the number of loop iterations, which
|
489
|
-
must agree with the sizes of leading axes of the arrays in ``xs`` (but can
|
490
|
-
be used to perform loops where no input ``xs`` are needed).
|
491
|
-
reverse : bool, default False
|
492
|
-
Optional boolean specifying whether to run the loop iteration
|
493
|
-
forward (the default) or in reverse, equivalent to reversing the leading
|
494
|
-
axes of the arrays in both ``xs`` and in ``ys``.
|
495
|
-
unroll : int or bool, default 1
|
496
|
-
Optional positive int or bool specifying, in the underlying
|
497
|
-
operation of the scan primitive, how many loop iterations to unroll within
|
498
|
-
a single iteration of a loop. If an integer is provided, it determines how
|
499
|
-
many unrolled loop iterations to run within a single rolled iteration of
|
500
|
-
the loop. If a boolean is provided, it will determine if the loop is
|
501
|
-
completely unrolled (i.e. `unroll=True`) or left completely rolled (i.e.
|
502
|
-
`unroll=False`).
|
503
|
-
pbar : ProgressBar or int, optional
|
504
|
-
Optional :class:`~.ProgressBar` instance to display the progress
|
505
|
-
of the loop operation.
|
506
|
-
|
507
|
-
Returns
|
508
|
-
-------
|
509
|
-
Y
|
510
|
-
The stacked outputs of ``f`` when looped over the leading axis of the inputs.
|
511
|
-
|
512
|
-
Examples
|
513
|
-
--------
|
514
|
-
Basic for-loop operation:
|
515
|
-
|
516
|
-
.. code-block:: python
|
517
|
-
|
518
|
-
>>> import brainstate
|
519
|
-
>>> import jax.numpy as jnp
|
520
|
-
>>>
|
521
|
-
>>> def process_item(x, y):
|
522
|
-
... return x * y + 1
|
523
|
-
>>>
|
524
|
-
>>> xs = jnp.array([1.0, 2.0, 3.0])
|
525
|
-
>>> ys = jnp.array([4.0, 5.0, 6.0])
|
526
|
-
>>> results = brainstate.transform.for_loop(process_item, xs, ys)
|
527
|
-
|
528
|
-
For-loop with progress bar:
|
529
|
-
|
530
|
-
.. code-block:: python
|
531
|
-
|
532
|
-
>>> pbar = brainstate.transform.ProgressBar(freq=10)
|
533
|
-
>>> results = brainstate.transform.for_loop(process_item, xs, ys, pbar=pbar)
|
534
|
-
|
535
|
-
For-loop with reverse iteration:
|
536
|
-
|
537
|
-
.. code-block:: python
|
538
|
-
|
539
|
-
>>> results = brainstate.transform.for_loop(process_item, xs, ys, reverse=True)
|
540
|
-
"""
|
541
|
-
_, ys = scan(
|
542
|
-
_forloop_to_scan_fun(f),
|
543
|
-
init=None,
|
544
|
-
xs=xs,
|
545
|
-
length=length,
|
546
|
-
reverse=reverse,
|
547
|
-
unroll=unroll,
|
548
|
-
pbar=pbar
|
549
|
-
)
|
550
|
-
return ys
|
551
|
-
|
552
|
-
|
553
|
-
@set_module_as('brainstate.transform')
|
554
|
-
def checkpointed_for_loop(
|
555
|
-
f: Callable[..., Y],
|
556
|
-
*xs: X,
|
557
|
-
length: Optional[int] = None,
|
558
|
-
base: int = 16,
|
559
|
-
pbar: Optional[ProgressBar | int] = None,
|
560
|
-
) -> Y:
|
561
|
-
"""
|
562
|
-
Checkpointed version of the ``for-loop`` control flow with :py:class:`~.State`, similar to :py:func:`for_loop`.
|
563
|
-
|
564
|
-
Parameters
|
565
|
-
----------
|
566
|
-
f : callable
|
567
|
-
A Python function to be looped over that accepts variadic arguments
|
568
|
-
corresponding to slices of ``xs`` along their leading axes, and returns
|
569
|
-
the output for that iteration.
|
570
|
-
*xs : X
|
571
|
-
The values over which to loop along the leading axis,
|
572
|
-
where each can be an array or any pytree (nested Python
|
573
|
-
tuple/list/dict) thereof with consistent leading axis sizes.
|
574
|
-
length : int, optional
|
575
|
-
Optional integer specifying the number of loop iterations, which
|
576
|
-
must agree with the sizes of leading axes of the arrays in ``xs`` (but can
|
577
|
-
be used to perform loops where no input ``xs`` are needed).
|
578
|
-
base : int, default 16
|
579
|
-
Optional integer specifying the base for the bounded loop.
|
580
|
-
pbar : ProgressBar or int, optional
|
581
|
-
Optional :class:`~.ProgressBar` instance to display the progress
|
582
|
-
of the loop operation.
|
583
|
-
|
584
|
-
Returns
|
585
|
-
-------
|
586
|
-
Y
|
587
|
-
The stacked outputs of ``f`` when looped over the leading axis of the inputs.
|
588
|
-
|
589
|
-
Examples
|
590
|
-
--------
|
591
|
-
Basic checkpointed for-loop operation:
|
592
|
-
|
593
|
-
.. code-block:: python
|
594
|
-
|
595
|
-
>>> import brainstate
|
596
|
-
>>> import jax.numpy as jnp
|
597
|
-
>>>
|
598
|
-
>>> def process_item(x, y):
|
599
|
-
... return x * y + 1
|
600
|
-
>>>
|
601
|
-
>>> xs = jnp.array([1.0, 2.0, 3.0])
|
602
|
-
>>> ys = jnp.array([4.0, 5.0, 6.0])
|
603
|
-
>>> results = brainstate.transform.checkpointed_for_loop(process_item, xs, ys)
|
604
|
-
|
605
|
-
Using custom base for checkpointing:
|
606
|
-
|
607
|
-
.. code-block:: python
|
608
|
-
|
609
|
-
>>> results = brainstate.transform.checkpointed_for_loop(
|
610
|
-
... process_item, xs, ys, base=8
|
611
|
-
... )
|
612
|
-
"""
|
613
|
-
_, ys = checkpointed_scan(
|
614
|
-
_forloop_to_scan_fun(f),
|
615
|
-
init=None,
|
616
|
-
xs=xs,
|
617
|
-
length=length,
|
618
|
-
base=base,
|
619
|
-
pbar=pbar
|
620
|
-
)
|
621
|
-
return ys
|
622
|
-
|
623
|
-
|
624
|
-
# This function is adapted from ``while_loop`` in
|
625
|
-
# `equinox <https://github.com/patrick-kidger/equinox/blob/main/equinox/internal/_loop/loop.py>`_.
|
626
|
-
|
627
|
-
# There are several tricks happening here to work around various limitations of JAX.
|
628
|
-
# (Also see https://github.com/google/jax/issues/2139#issuecomment-1039293633)
|
629
|
-
# 1. `unvmap_any` prior to using `lax.cond`. JAX has a problem in that vmap-of-cond
|
630
|
-
# is converted to a `lax.select`, which executes both branches unconditionally.
|
631
|
-
# Thus writing this naively, using a plain `lax.cond`, will mean the loop always
|
632
|
-
# runs to `max_steps` when executing under vmap. Instead we run (only) until every
|
633
|
-
# batch element has finished.
|
634
|
-
# 2. Treating in-place updates specially in the body_fun. Specifically we need to
|
635
|
-
# `lax.select` the update-to-make, not the updated buffer. This is because the
|
636
|
-
# latter instead results in XLA:CPU failing to determine that the buffer can be
|
637
|
-
# updated in-place, and instead it makes a copy. c.f. JAX issue #8192.
|
638
|
-
# This is done through the extra `inplace` argument provided to `body_fun`.
|
639
|
-
# 3. The use of the `@jax.checkpoint` decorator. Backpropagation through a
|
640
|
-
# `bounded_while_loop` will otherwise run in θ(max_steps) time, rather than
|
641
|
-
# θ(number of steps actually taken).
|
642
|
-
# 4. The use of `base`. In theory `base=2` is optimal at run time, as it implies the
|
643
|
-
# fewest superfluous operations. In practice this implies quite deep recursion in
|
644
|
-
# the construction of the bounded while loop, and this slows down the jaxpr
|
645
|
-
# creation and the XLA compilation. We choose `base=16` as a reasonable-looking
|
646
|
-
# compromise between compilation time and run time.
|
647
|
-
|
648
|
-
def _bounded_while_loop(
|
649
|
-
cond_fun: Callable,
|
650
|
-
body_fun: Callable,
|
651
|
-
val: Any,
|
652
|
-
max_steps: int,
|
653
|
-
base: int,
|
654
|
-
pbar_runner: Optional[Callable] = None
|
655
|
-
):
|
656
|
-
if max_steps == 1:
|
657
|
-
return body_fun(val)
|
658
|
-
else:
|
659
|
-
|
660
|
-
def true_call(val_):
|
661
|
-
return _bounded_while_loop(cond_fun, body_fun, val_, max_steps // base, base, pbar_runner)
|
662
|
-
|
663
|
-
def false_call(val_):
|
664
|
-
if pbar_runner is not None:
|
665
|
-
pbar_runner(unvmap(val_[-1] + max_steps, op='none'))
|
666
|
-
return val_[:-1] + (val_[-1] + max_steps,)
|
667
|
-
|
668
|
-
def scan_fn(val_, _):
|
669
|
-
return jax.lax.cond(unvmap(cond_fun(val_), op='any'), true_call, false_call, val_), None
|
670
|
-
|
671
|
-
# Don't put checkpointing on the lowest level
|
672
|
-
if max_steps != base:
|
673
|
-
scan_fn = jax.checkpoint(scan_fn, prevent_cse=False) # pyright: ignore
|
674
|
-
|
675
|
-
return jax.lax.scan(scan_fn, val, xs=None, length=base)[0]
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import math
|
17
|
+
from functools import wraps
|
18
|
+
from typing import Callable, Optional, TypeVar, Tuple, Any
|
19
|
+
|
20
|
+
import jax
|
21
|
+
import jax.numpy as jnp
|
22
|
+
|
23
|
+
from brainstate._utils import set_module_as
|
24
|
+
from ._make_jaxpr import StatefulFunction
|
25
|
+
from ._progress_bar import ProgressBar
|
26
|
+
from ._unvmap import unvmap
|
27
|
+
from ._util import wrap_single_fun
|
28
|
+
|
29
|
+
__all__ = [
|
30
|
+
# "scan" syntax, which is similar to jax.lax.scan
|
31
|
+
'scan', 'checkpointed_scan',
|
32
|
+
|
33
|
+
# "for_loop" syntax
|
34
|
+
'for_loop', 'checkpointed_for_loop',
|
35
|
+
]
|
36
|
+
|
37
|
+
X = TypeVar('X')
|
38
|
+
Y = TypeVar('Y')
|
39
|
+
T = TypeVar('T')
|
40
|
+
Carry = TypeVar('Carry')
|
41
|
+
|
42
|
+
|
43
|
+
def _wrap_fun_with_pbar(
    fun: Callable[[Carry, X], Tuple[Carry, Y]],
    pbar_runner: Callable
):
    """
    Wrap a scan body so that it reports progress on every call.

    The wrapped function expects its carry to be a pair ``(i, carry)`` where
    ``i`` is an iteration counter. On each call it runs ``fun`` on the inner
    carry, hands the counter (passed through ``unvmap(..., op='none')``)
    together with the new carry and the output to ``pbar_runner``, and returns
    the pair with the counter increased by one.

    Parameters
    ----------
    fun : callable
        Scan body of the form ``(carry, x) -> (carry, y)``.
    pbar_runner : callable
        Called once per step as ``pbar_runner(i, carry=..., y=...)``.

    Returns
    -------
    callable
        A body of the form ``((i, carry), x) -> ((i + 1, carry), y)``.
    """

    @wraps(fun)
    def fun_with_progress(counted_carry, x):
        step, inner_carry = counted_carry
        next_carry, y = fun(inner_carry, x)
        pbar_runner(unvmap(step, op='none'), carry=next_carry, y=y)
        return (step + 1, next_carry), y

    return fun_with_progress
|
55
|
+
|
56
|
+
|
57
|
+
@set_module_as('brainstate.transform')
def scan(
    f: Callable[[Carry, X], Tuple[Carry, Y]],
    init: Carry,
    xs: X,
    length: int | None = None,
    reverse: bool = False,
    unroll: int | bool = 1,
    pbar: ProgressBar | int | None = None,
) -> Tuple[Carry, Y]:
    """
    Scan a function over leading array axes while carrying along state.

    The `Haskell-like type signature`_ in brief is

    .. code-block:: haskell

        scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])

    where for any array type specifier ``t``, ``[t]`` represents the type with an additional
    leading axis, and if ``t`` is a pytree (container) type with array leaves then ``[t]``
    represents the type with the same pytree structure and corresponding leaves
    each with an additional leading axis.

    When the type of ``xs`` (denoted `a` above) is an array type or None, and the type
    of ``ys`` (denoted `b` above) is an array type, the semantics of :func:`~scan` are
    given roughly by this Python implementation:

    .. code-block:: python

        >>> def scan(f, init, xs, length=None):
        ...     if xs is None:
        ...         xs = [None] * length
        ...     carry = init
        ...     ys = []
        ...     for x in xs:
        ...         carry, y = f(carry, x)
        ...         ys.append(y)
        ...     return carry, np.stack(ys)

    Unlike that Python version, both ``xs`` and ``ys`` may be arbitrary pytree
    values, and so multiple arrays can be scanned over at once and produce multiple
    output arrays. ``None`` is actually a special case of this, as it represents an
    empty pytree.

    Also unlike that Python version, :func:`~scan` is a JAX primitive and is
    lowered to a single WhileOp. That makes it useful for reducing
    compilation times for JIT-compiled functions, since native Python
    loop constructs in an :func:`~jax.jit` function are unrolled, leading to large
    XLA computations.

    Finally, the loop-carried value ``carry`` must hold a fixed shape and dtype
    across all iterations (and not just be consistent up to NumPy rank/shape
    broadcasting and dtype promotion rules, for example). In other words, the type
    ``c`` in the type signature above represents an array with a fixed shape and
    dtype (or a nested tuple/list/dict container data structure with a fixed
    structure and arrays with fixed shape and dtype at the leaves).

    Parameters
    ----------
    f : callable
        A Python function to be scanned of type ``c -> a -> (c, b)``, meaning
        that ``f`` accepts two arguments where the first is a value of the loop
        carry and the second is a slice of ``xs`` along its leading axis, and that
        ``f`` returns a pair where the first element represents a new value for
        the loop carry and the second represents a slice of the output.
    init : Carry
        An initial loop carry value of type ``c``, which can be a scalar,
        array, or any pytree (nested Python tuple/list/dict) thereof, representing
        the initial loop carry value. This value must have the same structure as
        the first element of the pair returned by ``f``.
    xs : X
        The value of type ``[a]`` over which to scan along the leading axis,
        where ``[a]`` can be an array or any pytree (nested Python
        tuple/list/dict) thereof with consistent leading axis sizes.
    length : int, optional
        Optional integer specifying the number of loop iterations, which
        must agree with the sizes of leading axes of the arrays in ``xs`` (but can
        be used to perform scans where no input ``xs`` are needed).
    reverse : bool, default False
        Optional boolean specifying whether to run the scan iteration
        forward (the default) or in reverse, equivalent to reversing the leading
        axes of the arrays in both ``xs`` and in ``ys``.
    unroll : int or bool, default 1
        Optional positive int or bool specifying, in the underlying
        operation of the scan primitive, how many scan iterations to unroll within
        a single iteration of a loop. If an integer is provided, it determines how
        many unrolled loop iterations to run within a single rolled iteration of
        the loop. If a boolean is provided, it will determine if the loop is
        completely unrolled (i.e. `unroll=True`) or left completely rolled (i.e.
        `unroll=False`).
    pbar : ProgressBar or int, optional
        Optional :class:`~.ProgressBar` instance to display the progress
        of the scan operation. An integer is used as the ``freq`` argument of a
        new :class:`~.ProgressBar`.

    Returns
    -------
    tuple of (Carry, Y)
        A pair of type ``(c, [b])`` where the first element represents the final
        loop carry value and the second element represents the stacked outputs of
        the second output of ``f`` when scanned over the leading axis of the inputs.

    Raises
    ------
    TypeError
        If ``f`` is not callable, or ``pbar`` is neither ``None``, a
        :class:`~.ProgressBar`, nor an integer.
    ValueError
        If a leaf of ``xs`` has no ``shape`` attribute, if the leading axis sizes
        disagree with each other or with ``length``, if there is nothing to scan
        over and ``length`` is not given, or if a zero-length scan is requested
        while JIT is disabled.

    Examples
    --------
    Basic scan operation:

    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>>
        >>> def step_fn(carry, x):
        ...     return carry + x, carry * x
        >>>
        >>> init = 0.0
        >>> xs = jnp.array([1.0, 2.0, 3.0])
        >>> final_carry, ys = brainstate.transform.scan(step_fn, init, xs)

    Scan with progress bar:

    .. code-block:: python

        >>> pbar = brainstate.transform.ProgressBar(freq=10)
        >>> final_carry, ys = brainstate.transform.scan(step_fn, init, xs, pbar=pbar)

    References
    ----------
    .. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
    """
    # check "f"
    if not callable(f):
        raise TypeError("f argument should be a callable.")

    # check "xs": every leaf needs a leading axis, and all leading axes must agree
    xs_flat, xs_tree = jax.tree.flatten(xs)
    try:
        lengths = [x.shape[0] for x in xs_flat]
    except AttributeError as err:
        raise ValueError("scan got value with no leading axis to scan over: "
                         "{}.".format(', '.join(str(x) for x in xs_flat if not hasattr(x, 'shape')))) from err
    if length is not None:
        length = int(length)
        if not all(length == l for l in lengths):
            raise ValueError(("scan got `length` argument of {} which disagrees with "
                              "leading axis sizes {}.").format(length, lengths))
    else:
        unique_lengths = set(lengths)
        if len(unique_lengths) > 1:
            msg = "scan got values with different leading axis sizes: {}."
            raise ValueError(msg.format(', '.join(str(l) for l in lengths)))
        elif len(unique_lengths) == 0:
            raise ValueError("scan got no values to scan over and `length` not provided.")
        else:
            length, = unique_lengths

    # function with progress bar
    has_pbar = pbar is not None
    if has_pbar:
        if isinstance(pbar, ProgressBar):
            pbar_runner = pbar.init(length)
        elif isinstance(pbar, int):
            pbar_runner = ProgressBar(freq=pbar).init(length)
        else:
            raise TypeError("pbar argument should be a ProgressBar instance or an integer.")
        f = _wrap_fun_with_pbar(f, pbar_runner)
        # The wrapped ``f`` always unpacks ``(counter, carry)``, so the carry must be
        # wrapped whenever ``f`` is wrapped -- not only when ``pbar`` happens to be
        # truthy (``pbar=0`` is an int but falsy).
        init = (0, init)

    # not jit: run a plain Python loop and stack the per-step outputs
    if jax.config.jax_disable_jit:
        if length == 0:
            raise ValueError(
                "zero-length scan is not supported in disable_jit() mode because the output type is unknown.")
        carry = init
        ys = []
        maybe_reversed = reversed if reverse else lambda x: x
        for i in maybe_reversed(range(length)):
            xs_slice = [jax.lax.index_in_dim(x, i, keepdims=False) for x in xs_flat]
            carry, y = f(carry, jax.tree.unflatten(xs_tree, xs_slice))
            ys.append(y)
        stacked_y = jax.tree.map(lambda *elems: jnp.stack(elems), *maybe_reversed(ys))
        if has_pbar:
            # drop the progress counter
            return carry[1], stacked_y
        else:
            return carry, stacked_y

    # evaluate jaxpr, get all states #
    # ------------------------------ #
    xs_avals = [jax.core.get_aval(x) for x in xs_flat]
    x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
    args = [init, xs_tree.unflatten(x_avals)]
    stateful_fun = StatefulFunction(f, name='scan').make_jaxpr(*args)
    state_trace = stateful_fun.get_state_trace(*args)
    all_writen_state_vals = state_trace.get_write_state_values(True)
    all_read_state_vals = state_trace.get_read_state_values(True)
    wrapped_f = wrap_single_fun(stateful_fun, state_trace.been_writen, all_read_state_vals)

    # scan: the written state values travel in the carry next to the user carry
    init = (all_writen_state_vals, init)
    (all_writen_state_vals, carry), ys = jax.lax.scan(
        wrapped_f,
        init,
        xs,
        length=length,
        reverse=reverse,
        unroll=unroll
    )
    # assign the written state values and restore the read state values
    state_trace.assign_state_vals_v2(all_read_state_vals, all_writen_state_vals)
    # carry: drop the progress counter
    if has_pbar:
        carry = carry[1]
    return carry, ys
|
276
|
+
|
277
|
+
|
278
|
+
@set_module_as('brainstate.transform')
def checkpointed_scan(
    f: Callable[[Carry, X], Tuple[Carry, Y]],
    init: Carry,
    xs: X,
    length: Optional[int] = None,
    base: int = 16,
    pbar: Optional[ProgressBar | int] = None,
) -> Tuple[Carry, Y]:
    """
    Scan a function over leading array axes while carrying along state.
    This function is similar to :func:`~scan` but with a checkpointed version.

    Parameters
    ----------
    f : callable
        A Python function to be scanned of type ``c -> a -> (c, b)``, meaning
        that ``f`` accepts two arguments where the first is a value of the loop
        carry and the second is a slice of ``xs`` along its leading axis, and that
        ``f`` returns a pair where the first element represents a new value for
        the loop carry and the second represents a slice of the output.
    init : Carry
        An initial loop carry value of type ``c``, which can be a scalar,
        array, or any pytree (nested Python tuple/list/dict) thereof, representing
        the initial loop carry value. This value must have the same structure as
        the first element of the pair returned by ``f``.
    xs : X
        The value of type ``[a]`` over which to scan along the leading axis,
        where ``[a]`` can be an array or any pytree (nested Python
        tuple/list/dict) thereof with consistent leading axis sizes.
    length : int, optional
        Optional integer specifying the number of loop iterations, which
        must agree with the sizes of leading axes of the arrays in ``xs`` (but can
        be used to perform scans where no input ``xs`` are needed). At least one
        iteration is required.
    base : int, default 16
        Optional integer specifying the base for the bounded scan loop. The number
        of iterations is rounded up to a power of ``base``; must be at least 2.
    pbar : ProgressBar or int, optional
        Optional :class:`~.ProgressBar` instance to display the progress
        of the scan operation. An integer is used as the ``freq`` argument of a
        new :class:`~.ProgressBar`; any other value disables the progress bar.

    Returns
    -------
    tuple of (Carry, Y)
        A pair of type ``(c, [b])`` where the first element represents the final
        loop carry value and the second element represents the stacked outputs of
        the second output of ``f`` when scanned over the leading axis of the inputs.

    Raises
    ------
    TypeError
        If ``f`` is not callable.
    ValueError
        If ``base`` is smaller than 2, if a leaf of ``xs`` has no ``shape``
        attribute, if the leading axis sizes disagree with each other or with
        ``length``, if there is nothing to scan over and ``length`` is not given,
        or if the resulting number of iterations is smaller than 1.

    Examples
    --------
    Basic checkpointed scan operation:

    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>>
        >>> def step_fn(carry, x):
        ...     return carry + x, carry * x
        >>>
        >>> init = 0.0
        >>> xs = jnp.array([1.0, 2.0, 3.0])
        >>> final_carry, ys = brainstate.transform.checkpointed_scan(step_fn, init, xs)

    Using custom base for checkpointing:

    .. code-block:: python

        >>> final_carry, ys = brainstate.transform.checkpointed_scan(
        ...     step_fn, init, xs, base=8
        ... )
    """
    # check "f"
    if not callable(f):
        raise TypeError("f argument should be a callable.")

    # check "base": the bounded loop divides its step budget by ``base`` on every
    # level, so a base below 2 can never reach the single-step level.
    if base < 2:
        raise ValueError("checkpointed_scan requires `base` >= 2, but got {}.".format(base))

    # check "xs"
    xs_flat, xs_tree = jax.tree.flatten(xs)
    try:
        lengths = [x.shape[0] for x in xs_flat]
    except AttributeError as err:
        raise ValueError("scan got value with no leading axis to scan over: "
                         "{}.".format(', '.join(str(x) for x in xs_flat if not hasattr(x, 'shape')))) from err
    if length is not None:
        length = int(length)
        if not all(length == l for l in lengths):
            raise ValueError(("scan got `length` argument of {} which disagrees with "
                              "leading axis sizes {}.").format(length, lengths))
    else:
        unique_lengths = set(lengths)
        if len(unique_lengths) > 1:
            msg = "scan got values with different leading axis sizes: {}."
            raise ValueError(msg.format(', '.join(str(l) for l in lengths)))
        elif len(unique_lengths) == 0:
            raise ValueError("scan got no values to scan over and `length` not provided.")
        else:
            length, = unique_lengths

    # The bounded loop below always executes the body at least once, so a
    # zero-length scan cannot be expressed; report that explicitly instead of
    # failing later with an opaque "math domain error".
    if length < 1:
        raise ValueError(
            "checkpointed_scan requires at least one iteration, but got length={}.".format(length))

    # function with progress bar
    if isinstance(pbar, ProgressBar):
        pbar_runner = pbar.init(length)
    elif isinstance(pbar, int):
        pbar_runner = ProgressBar(freq=pbar).init(length)
    else:
        pbar_runner = None

    # evaluate jaxpr
    xs_avals = [jax.core.get_aval(x) for x in xs_flat]
    x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
    args = (init, xs_tree.unflatten(x_avals))
    stateful_fun = StatefulFunction(f, name='checkpoint_scan').make_jaxpr(*args)
    state_trace = stateful_fun.get_state_trace(*args)
    cache_key = stateful_fun.get_arg_cache_key(*args)
    # get all states
    been_written = state_trace.been_writen
    read_state_vals = state_trace.get_read_state_values(True)
    write_state_vals = state_trace.get_write_state_values(True)

    # initialize the collected values/data: one zero buffer per output leaf,
    # with an extra leading axis of size ``length``
    out_info = stateful_fun.get_out_shapes_by_cache(cache_key)[0]
    assert len(out_info) == 2, "function in checkpointed_scan should return two data: carray and out."
    data2collection = jax.tree.map(lambda x: jnp.zeros((length,) + x.shape, x.dtype), out_info[1])
    del out_info

    def wrapped_cond_fun(inp):
        # the last entry of the loop value is the step counter
        return inp[-1] < length

    def wrapped_body_fun(inp):
        (prev_write_states, carray), prev_collect, i = inp
        # progress bar
        if pbar_runner is not None:
            pbar_runner(unvmap(i, op='none'))
        # call the function: written states come from the loop value,
        # read-only states from the values captured before the loop
        prev_states = [w_val if write else r_val
                       for write, w_val, r_val in zip(been_written, prev_write_states, read_state_vals)]
        new_states, (new_carray, out4updates) = stateful_fun.jaxpr_call(
            prev_states, carray, jax.tree.map(lambda x: x[i], xs)
        )
        # new written states
        new_write_states = tuple([val if write else None for write, val in zip(been_written, new_states)])

        # out of length: keep the previous values once the counter has passed ``length``
        pred = i < length
        new_collect = jax.tree.map(
            # NOTE(review): this selects between whole buffers; selecting only the
            # update, ``x.at[i].set(jax.lax.select(pred, update, x[i]))``, is the
            # alternative the original author left disabled -- confirm which is wanted.
            lambda x, update: jax.lax.select(pred, x.at[i].set(update), x),
            prev_collect,
            out4updates,
        )
        new_write_states = jax.tree.map(
            lambda ps, ns: None if ns is None else jax.lax.select(pred, ns, ps),
            prev_write_states,
            new_write_states,
            is_leaf=lambda x: x is None
        )
        new_carray = jax.tree.map(
            lambda pc, nc: jax.lax.select(pred, nc, pc),
            carray,
            new_carray,
        )
        return (new_write_states, new_carray), new_collect, i + 1

    # while_loop: round the step budget up to the next power of ``base``.
    # Integer arithmetic is used on purpose: ``math.log(length, base)`` can land
    # just above an exact integer (e.g. ``math.log(125, 5)``), which would add a
    # whole superfluous level of recursion.
    rounded_max_steps = 1
    while rounded_max_steps < length:
        rounded_max_steps *= base
    (write_state_vals, carry), data2collection, _ = (
        _bounded_while_loop(
            wrapped_cond_fun,
            wrapped_body_fun,
            ((write_state_vals, init), data2collection, 0),
            rounded_max_steps,
            base,
            pbar_runner
        )
    )
    # assign the written state values and restore the read state values
    state_trace.assign_state_vals_v2(read_state_vals, write_state_vals)
    del write_state_vals, read_state_vals, stateful_fun
    return carry, data2collection
|
455
|
+
|
456
|
+
|
457
|
+
def _forloop_to_scan_fun(f: Callable):
|
458
|
+
@wraps(f)
|
459
|
+
def scan_fun(carry, x):
|
460
|
+
return carry, f(*x)
|
461
|
+
|
462
|
+
return scan_fun
|
463
|
+
|
464
|
+
|
465
|
+
@set_module_as('brainstate.transform')
def for_loop(
    f: Callable[..., Y],
    *xs,
    length: Optional[int] = None,
    reverse: bool = False,
    unroll: int | bool = 1,
    pbar: Optional[ProgressBar | int] = None
) -> Y:
    """
    ``for-loop`` control flow with :py:class:`~.State`.

    This is a thin wrapper around :func:`~scan`: ``f`` is adapted into a scan
    body that carries nothing (the initial carry is ``None``), the positional
    ``xs`` are scanned as one tuple, and only the stacked outputs are returned.

    Parameters
    ----------
    f : callable
        A Python function to be looped over that accepts variadic arguments
        corresponding to slices of ``xs`` along their leading axes, and returns
        the output for that iteration.
    *xs
        The values over which to loop along the leading axis,
        where each can be an array or any pytree (nested Python
        tuple/list/dict) thereof with consistent leading axis sizes.
    length : int, optional
        Optional integer specifying the number of loop iterations, which
        must agree with the sizes of leading axes of the arrays in ``xs`` (but can
        be used to perform loops where no input ``xs`` are needed).
    reverse : bool, default False
        Optional boolean specifying whether to run the loop iteration
        forward (the default) or in reverse, equivalent to reversing the leading
        axes of the arrays in both ``xs`` and in ``ys``.
    unroll : int or bool, default 1
        Forwarded unchanged to the ``unroll`` argument of :func:`~scan`; see
        there for the meaning of integer and boolean values.
    pbar : ProgressBar or int, optional
        Optional :class:`~.ProgressBar` instance to display the progress
        of the loop operation.

    Returns
    -------
    Y
        The stacked outputs of ``f`` when looped over the leading axis of the inputs.

    Examples
    --------
    Basic for-loop operation:

    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>>
        >>> def process_item(x, y):
        ...     return x * y + 1
        >>>
        >>> xs = jnp.array([1.0, 2.0, 3.0])
        >>> ys = jnp.array([4.0, 5.0, 6.0])
        >>> results = brainstate.transform.for_loop(process_item, xs, ys)

    For-loop with progress bar:

    .. code-block:: python

        >>> pbar = brainstate.transform.ProgressBar(freq=10)
        >>> results = brainstate.transform.for_loop(process_item, xs, ys, pbar=pbar)

    For-loop with reverse iteration:

    .. code-block:: python

        >>> results = brainstate.transform.for_loop(process_item, xs, ys, reverse=True)
    """
    scan_body = _forloop_to_scan_fun(f)
    scan_result = scan(
        scan_body,
        init=None,
        xs=xs,
        length=length,
        reverse=reverse,
        unroll=unroll,
        pbar=pbar
    )
    # scan returns (final carry, stacked outputs); the carry is always None here
    return scan_result[1]
|
551
|
+
|
552
|
+
|
553
|
+
@set_module_as('brainstate.transform')
def checkpointed_for_loop(
    f: Callable[..., Y],
    *xs: X,
    length: Optional[int] = None,
    base: int = 16,
    pbar: Optional[ProgressBar | int] = None,
) -> Y:
    """
    ``for-loop`` control flow with :py:class:`~.State` with a checkpointed version, similar to :py:func:`for_loop`.

    This is a thin wrapper around :func:`~checkpointed_scan`: ``f`` is adapted
    into a scan body that carries nothing (the initial carry is ``None``), the
    positional ``xs`` are scanned as one tuple, and only the stacked outputs
    are returned.

    Parameters
    ----------
    f : callable
        A Python function to be looped over that accepts variadic arguments
        corresponding to slices of ``xs`` along their leading axes, and returns
        the output for that iteration.
    *xs : X
        The values over which to loop along the leading axis,
        where each can be an array or any pytree (nested Python
        tuple/list/dict) thereof with consistent leading axis sizes.
    length : int, optional
        Optional integer specifying the number of loop iterations, which
        must agree with the sizes of leading axes of the arrays in ``xs`` (but can
        be used to perform loops where no input ``xs`` are needed).
    base : int, default 16
        Optional integer specifying the base for the bounded loop.
    pbar : ProgressBar or int, optional
        Optional :class:`~.ProgressBar` instance to display the progress
        of the loop operation.

    Returns
    -------
    Y
        The stacked outputs of ``f`` when looped over the leading axis of the inputs.

    Examples
    --------
    Basic checkpointed for-loop operation:

    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>>
        >>> def process_item(x, y):
        ...     return x * y + 1
        >>>
        >>> xs = jnp.array([1.0, 2.0, 3.0])
        >>> ys = jnp.array([4.0, 5.0, 6.0])
        >>> results = brainstate.transform.checkpointed_for_loop(process_item, xs, ys)

    Using custom base for checkpointing:

    .. code-block:: python

        >>> results = brainstate.transform.checkpointed_for_loop(
        ...     process_item, xs, ys, base=8
        ... )
    """
    scan_body = _forloop_to_scan_fun(f)
    scan_result = checkpointed_scan(
        scan_body,
        init=None,
        xs=xs,
        length=length,
        base=base,
        pbar=pbar
    )
    # checkpointed_scan returns (final carry, stacked outputs); the carry is always None here
    return scan_result[1]
|
622
|
+
|
623
|
+
|
624
|
+
# This function is adapted from ``while_loop`` in
|
625
|
+
# `equinox <https://github.com/patrick-kidger/equinox/blob/main/equinox/internal/_loop/loop.py>`_.
|
626
|
+
|
627
|
+
# There are several tricks happening here to work around various limitations of JAX.
|
628
|
+
# (Also see https://github.com/google/jax/issues/2139#issuecomment-1039293633)
|
629
|
+
# 1. `unvmap_any` prior to using `lax.cond`. JAX has a problem in that vmap-of-cond
|
630
|
+
# is converted to a `lax.select`, which executes both branches unconditionally.
|
631
|
+
# Thus writing this naively, using a plain `lax.cond`, will mean the loop always
|
632
|
+
# runs to `max_steps` when executing under vmap. Instead we run (only) until every
|
633
|
+
# batch element has finished.
|
634
|
+
# 2. Treating in-place updates specially in the body_fun. Specifically we need to
|
635
|
+
# `lax.select` the update-to-make, not the updated buffer. This is because the
|
636
|
+
# latter instead results in XLA:CPU failing to determine that the buffer can be
|
637
|
+
# updated in-place, and instead it makes a copy. c.f. JAX issue #8192.
|
638
|
+
# This is done through the extra `inplace` argument provided to `body_fun`.
|
639
|
+
# 3. The use of the `@jax.checkpoint` decorator. Backpropagation through a
|
640
|
+
# `bounded_while_loop` will otherwise run in θ(max_steps) time, rather than
|
641
|
+
# θ(number of steps actually taken).
|
642
|
+
# 4. The use of `base`. In theory `base=2` is optimal at run time, as it implies the
|
643
|
+
# fewest superfluous operations. In practice this implies quite deep recursion in
|
644
|
+
# the construction of the bounded while loop, and this slows down the jaxpr
|
645
|
+
# creation and the XLA compilation. We choose `base=16` as a reasonable-looking
|
646
|
+
# compromise between compilation time and run time.
|
647
|
+
|
648
|
+
def _bounded_while_loop(
|
649
|
+
cond_fun: Callable,
|
650
|
+
body_fun: Callable,
|
651
|
+
val: Any,
|
652
|
+
max_steps: int,
|
653
|
+
base: int,
|
654
|
+
pbar_runner: Optional[Callable] = None
|
655
|
+
):
|
656
|
+
if max_steps == 1:
|
657
|
+
return body_fun(val)
|
658
|
+
else:
|
659
|
+
|
660
|
+
def true_call(val_):
|
661
|
+
return _bounded_while_loop(cond_fun, body_fun, val_, max_steps // base, base, pbar_runner)
|
662
|
+
|
663
|
+
def false_call(val_):
|
664
|
+
if pbar_runner is not None:
|
665
|
+
pbar_runner(unvmap(val_[-1] + max_steps, op='none'))
|
666
|
+
return val_[:-1] + (val_[-1] + max_steps,)
|
667
|
+
|
668
|
+
def scan_fn(val_, _):
|
669
|
+
return jax.lax.cond(unvmap(cond_fun(val_), op='any'), true_call, false_call, val_), None
|
670
|
+
|
671
|
+
# Don't put checkpointing on the lowest level
|
672
|
+
if max_steps != base:
|
673
|
+
scan_fn = jax.checkpoint(scan_fn, prevent_cse=False) # pyright: ignore
|
674
|
+
|
675
|
+
return jax.lax.scan(scan_fn, val, xs=None, length=base)[0]
|