brainstate 0.0.2.post20240825__py2.py3-none-any.whl → 0.0.2.post20240910__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,502 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from __future__ import annotations
17
+
18
+ import math
19
+ from functools import wraps
20
+ from typing import Callable, Optional, TypeVar, Tuple, Any
21
+
22
+ import jax
23
+ import jax.numpy as jnp
24
+
25
+ from brainstate._utils import set_module_as
26
+ from ._make_jaxpr import StatefulFunction, _assign_state_values
27
+ from ._progress_bar import ProgressBar
28
+ from ._unvmap import unvmap
29
+
30
+ X = TypeVar('X')
31
+ Y = TypeVar('Y')
32
+ T = TypeVar('T')
33
+ Carry = TypeVar('Carry')
34
+
35
+ __all__ = [
36
+ # for loop & scan
37
+ 'scan', 'checkpointed_scan',
38
+ 'for_loop', 'checkpointed_for_loop',
39
+ ]
40
+
41
+
42
def _wrap_fun_with_pbar(fun, pbar_runner):
    """
    Wrap a scan body so that it reports its progress on every iteration.

    The returned function expects its carry to be a pair ``(i, carry)`` whose
    first element is the iteration counter. It forwards ``carry`` and the
    per-step inputs to ``fun``, hands the counter to ``pbar_runner`` and
    returns the counter advanced by one together with the new carry.

    Args:
      fun: the scan body, of the form ``(carry, inputs) -> (carry, outputs)``.
      pbar_runner: a callable invoked with the current iteration counter.

    Returns:
      A scan body of the form
      ``((i, carry), inputs) -> ((i + 1, carry), outputs)``.
    """

    @wraps(fun)
    def fun_with_pbar(counted_carry, inputs):
        step, carry = counted_carry
        next_carry, outputs = fun(carry, inputs)
        # The counter goes through ``unvmap`` before it reaches the runner.
        pbar_runner(unvmap(step, op='none'))
        return (step + 1, next_carry), outputs

    return fun_with_pbar
51
+
52
+
53
def _wrapped_scan_fun(stateful_fun: StatefulFunction, states):
    """
    Build a scan body that threads the values of ``states`` through the carry.

    The returned function takes a carry of the form ``(state_vals, carry)``.
    It writes ``state_vals`` into ``states``, calls
    ``stateful_fun.jaxpr_call_auto(carry, inputs)`` and returns the values the
    states hold afterwards together with the new carry and the step output.

    Args:
      stateful_fun: the traced function; its ``fun`` attribute is used for
        ``functools.wraps`` and ``jaxpr_call_auto`` is what gets executed.
      states: the sequence of state objects whose ``value`` is carried.

    Returns:
      A scan body of the form
      ``((state_vals, carry), inputs) -> ((state_vals, carry), out)``.
    """

    @wraps(stateful_fun.fun)
    def scan_step(packed_carry, inputs):
        state_vals, carry = packed_carry
        assert len(states) == len(state_vals)
        # Restore the carried values into the states before running the step.
        for state, value in zip(states, state_vals):
            state.value = value
        carry, out = stateful_fun.jaxpr_call_auto(carry, inputs)
        # Read the states back so that their updates travel with the carry.
        updated_vals = tuple(state.value for state in states)
        return (updated_vals, carry), out

    return scan_step
64
+
65
+
66
@set_module_as('brainstate.transform')
def scan(
    f: Callable[[Carry, X], Tuple[Carry, Y]],
    init: Carry,
    xs: X,
    length: int | None = None,
    reverse: bool = False,
    unroll: int | bool = 1,
    pbar: ProgressBar | None = None,
) -> Tuple[Carry, Y]:
    """
    Scan a function over leading array axes while carrying along state.

    The `Haskell-like type signature`_ in brief is

    .. code-block:: haskell

      scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])

    where for any array type specifier ``t``, ``[t]`` represents the type with an additional
    leading axis, and if ``t`` is a pytree (container) type with array leaves then ``[t]``
    represents the type with the same pytree structure and corresponding leaves
    each with an additional leading axis.

    When the type of ``xs`` (denoted `a` above) is an array type or None, and the type
    of ``ys`` (denoted `b` above) is an array type, the semantics of :func:`~scan` are
    given roughly by this Python implementation::

      def scan(f, init, xs, length=None):
        if xs is None:
          xs = [None] * length
        carry = init
        ys = []
        for x in xs:
          carry, y = f(carry, x)
          ys.append(y)
        return carry, np.stack(ys)

    Unlike that Python version, both ``xs`` and ``ys`` may be arbitrary pytree
    values, and so multiple arrays can be scanned over at once and produce multiple
    output arrays. ``None`` is actually a special case of this, as it represents an
    empty pytree.

    Also unlike that Python version, :func:`~scan` is a JAX primitive and is
    lowered to a single WhileOp. That makes it useful for reducing
    compilation times for JIT-compiled functions, since native Python
    loop constructs in an :func:`~jax.jit` function are unrolled, leading to large
    XLA computations.

    Finally, the loop-carried value ``carry`` must hold a fixed shape and dtype
    across all iterations (and not just be consistent up to NumPy rank/shape
    broadcasting and dtype promotion rules, for example). In other words, the type
    ``c`` in the type signature above represents an array with a fixed shape and
    dtype (or a nested tuple/list/dict container data structure with a fixed
    structure and arrays with fixed shape and dtype at the leaves).

    Args:
      f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
        that ``f`` accepts two arguments where the first is a value of the loop
        carry and the second is a slice of ``xs`` along its leading axis, and that
        ``f`` returns a pair where the first element represents a new value for
        the loop carry and the second represents a slice of the output.
      init: an initial loop carry value of type ``c``, which can be a scalar,
        array, or any pytree (nested Python tuple/list/dict) thereof, representing
        the initial loop carry value. This value must have the same structure as
        the first element of the pair returned by ``f``.
      xs: the value of type ``[a]`` over which to scan along the leading axis,
        where ``[a]`` can be an array or any pytree (nested Python
        tuple/list/dict) thereof with consistent leading axis sizes.
      length: optional integer specifying the number of loop iterations, which
        must agree with the sizes of leading axes of the arrays in ``xs`` (but can
        be used to perform scans where no input ``xs`` are needed).
      reverse: optional boolean specifying whether to run the scan iteration
        forward (the default) or in reverse, equivalent to reversing the leading
        axes of the arrays in both ``xs`` and in ``ys``.
      unroll: optional positive int or bool specifying, in the underlying
        operation of the scan primitive, how many scan iterations to unroll within
        a single iteration of a loop. If an integer is provided, it determines how
        many unrolled loop iterations to run within a single rolled iteration of
        the loop. If a boolean is provided, it will determine if the loop is
        completely unrolled (i.e. `unroll=True`) or left completely rolled (i.e.
        `unroll=False`).
      pbar: optional :class:`~.ProgressBar` instance to display the progress
        of the scan operation.

    Returns:
      A pair of type ``(c, [b])`` where the first element represents the final
      loop carry value and the second element represents the stacked outputs of
      the second output of ``f`` when scanned over the leading axis of the inputs.

    Raises:
      TypeError: if ``f`` is not callable.
      ValueError: if a leaf of ``xs`` has no ``shape`` attribute, if ``length``
        disagrees with the leading axis sizes of ``xs``, if the leaves of ``xs``
        have different leading axis sizes, if there is nothing to scan over and
        ``length`` is not given, or if a zero-length scan is requested while
        ``jax_disable_jit`` is set.

    .. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
    """
    # check "f"
    if not callable(f):
        raise TypeError("f argument should be a callable.")

    # check "xs"
    xs_flat, xs_tree = jax.tree.flatten(xs)
    try:
        lengths = [x.shape[0] for x in xs_flat]
    except AttributeError as err:
        raise ValueError("scan got value with no leading axis to scan over: "
                         "{}.".format(', '.join(str(x) for x in xs_flat if not hasattr(x, 'shape')))) from err
    if length is not None:
        length = int(length)
        if not all(length == l for l in lengths):
            raise ValueError(("scan got `length` argument of {} which disagrees with "
                              "leading axis sizes {}.").format(length, lengths))
    else:
        unique_lengths = set(lengths)
        if len(unique_lengths) > 1:
            msg = "scan got values with different leading axis sizes: {}."
            raise ValueError(msg.format(', '.join(str(l) for l in lengths)))
        elif len(unique_lengths) == 0:
            raise ValueError("scan got no values to scan over and `length` not provided.")
        else:
            length, = unique_lengths

    # function with progress bar
    # ``f`` and ``init`` must be wrapped together: the wrapped function expects
    # the carry to be ``(counter, carry)``. Both are therefore decided by the
    # same ``pbar is not None`` test rather than by the truthiness of ``pbar``.
    has_pbar = pbar is not None
    if has_pbar:
        f = _wrap_fun_with_pbar(f, pbar.init(length))
        init = (0, init)

    # not jit
    if jax.config.jax_disable_jit:
        if length == 0:
            raise ValueError("zero-length scan is not supported in disable_jit() mode because the output type is unknown.")
        carry = init
        ys = []
        maybe_reversed = reversed if reverse else lambda x: x
        for i in maybe_reversed(range(length)):
            xs_slice = [jax.lax.index_in_dim(x, i, keepdims=False) for x in xs_flat]
            carry, y = f(carry, jax.tree.unflatten(xs_tree, xs_slice))
            ys.append(y)
        # ``ys`` was collected in iteration order; restore the input order
        # before stacking when the scan ran in reverse.
        stacked_y = jax.tree.map(lambda *elems: jnp.stack(elems), *maybe_reversed(ys))
        if has_pbar:
            # drop the iteration counter added for the progress bar
            return carry[1], stacked_y
        else:
            return carry, stacked_y

    # evaluate jaxpr, get all states #
    # ------------------------------ #
    xs_avals = [jax.core.raise_to_shaped(jax.core.get_aval(x)) for x in xs_flat]
    # abstract value of a single slice of every leaf (leading axis removed)
    x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
    stateful_fun = StatefulFunction(f).make_jaxpr(init, xs_tree.unflatten(x_avals))
    all_states = stateful_fun.get_states()
    wrapped_f = _wrapped_scan_fun(stateful_fun, all_states)

    # scan
    # the state values travel in the carry next to the user carry
    init = (tuple(st.value for st in all_states), init)
    (state_vals, carry), ys = jax.lax.scan(wrapped_f, init, xs, length=length, reverse=reverse, unroll=unroll)
    _assign_state_values(all_states, state_vals)
    if has_pbar:
        # drop the iteration counter added for the progress bar
        carry = carry[1]
    return carry, ys
223
+
224
+
225
def checkpointed_scan(
    f: Callable[[Carry, X], Tuple[Carry, Y]],
    init: Carry,
    xs: X,
    length: Optional[int] = None,
    base: int = 16,
    pbar: Optional[ProgressBar] = None,
):
    """
    Scan a function over leading array axes while carrying along state.
    This function is similar to :func:`~scan` but with a checkpointed version.

    Args:
      f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
        that ``f`` accepts two arguments where the first is a value of the loop
        carry and the second is a slice of ``xs`` along its leading axis, and that
        ``f`` returns a pair where the first element represents a new value for
        the loop carry and the second represents a slice of the output.
      init: an initial loop carry value of type ``c``, which can be a scalar,
        array, or any pytree (nested Python tuple/list/dict) thereof, representing
        the initial loop carry value. This value must have the same structure as
        the first element of the pair returned by ``f``.
      xs: the value of type ``[a]`` over which to scan along the leading axis,
        where ``[a]`` can be an array or any pytree (nested Python
        tuple/list/dict) thereof with consistent leading axis sizes.
      length: optional integer specifying the number of loop iterations, which
        must agree with the sizes of leading axes of the arrays in ``xs`` (but can
        be used to perform scans where no input ``xs`` are needed).
      base: optional integer (at least 2) specifying the base for the bounded
        scan loop.
      pbar: optional :class:`~.ProgressBar` instance to display the progress
        of the scan operation.

    Returns:
      A pair of type ``(c, [b])`` where the first element represents the final
      loop carry value and the second element represents the stacked outputs of
      the second output of ``f`` when scanned over the leading axis of the inputs.
      For a zero-length scan, ``init`` is returned unchanged together with
      outputs whose leading axis has size zero.

    Raises:
      TypeError: if ``f`` is not callable.
      ValueError: if ``base`` is smaller than 2, if a leaf of ``xs`` has no
        ``shape`` attribute, if ``length`` disagrees with the leading axis sizes
        of ``xs``, if the leaves of ``xs`` have different leading axis sizes, or
        if there is nothing to scan over and ``length`` is not given.
    """
    # check "f"
    if not callable(f):
        raise TypeError("f argument should be a callable.")

    # check "base"
    # A base below 2 can never reach the ``max_steps == 1`` base case of the
    # bounded loop, so reject it up front with a clear message.
    base = int(base)
    if base < 2:
        raise ValueError("`base` must be an integer greater than or equal to 2, but got {}.".format(base))

    # check "xs"
    xs_flat, xs_tree = jax.tree.flatten(xs)
    try:
        lengths = [x.shape[0] for x in xs_flat]
    except AttributeError as err:
        raise ValueError("scan got value with no leading axis to scan over: "
                         "{}.".format(', '.join(str(x) for x in xs_flat if not hasattr(x, 'shape')))) from err
    if length is not None:
        length = int(length)
        if not all(length == l for l in lengths):
            raise ValueError(("scan got `length` argument of {} which disagrees with "
                              "leading axis sizes {}.").format(length, lengths))
    else:
        unique_lengths = set(lengths)
        if len(unique_lengths) > 1:
            msg = "scan got values with different leading axis sizes: {}."
            raise ValueError(msg.format(', '.join(str(l) for l in lengths)))
        elif len(unique_lengths) == 0:
            raise ValueError("scan got no values to scan over and `length` not provided.")
        else:
            length, = unique_lengths

    # function with progress bar
    pbar_runner = None if pbar is None else pbar.init(length)

    # evaluate jaxpr
    xs_avals = [jax.core.raise_to_shaped(jax.core.get_aval(x)) for x in xs_flat]
    # abstract value of a single slice of every leaf (leading axis removed)
    x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
    stateful_fun = StatefulFunction(f).make_jaxpr(init, xs_tree.unflatten(x_avals))
    all_states = stateful_fun.get_states()
    out_info = stateful_fun.get_out_shapes()[0]

    # initialize the collected values/data
    assert len(out_info) == 2, "function in checkpointed_scan should return two data: carray and out."
    # one zero-filled buffer per output leaf, with a leading axis of size ``length``
    data2collection = jax.tree.map(lambda x: jnp.zeros((length,) + x.shape, x.dtype), out_info[1])
    del out_info

    if length == 0:
        # Nothing to iterate over: the carry is untouched and the collected
        # outputs are empty. Returning here also avoids rounding a length of
        # zero up to a power of ``base``.
        return init, data2collection

    def wrapped_cond_fun(inp):
        # the last element of the loop value is the step counter
        return inp[-1] < length

    def wrapped_body_fun(inp):
        (prev_states, carray), prev_collect, i = inp
        # progress bar
        if pbar_runner is not None:
            pbar_runner(unvmap(i, op='none'))
        # call the function
        new_states, (new_carray, out4updates) = stateful_fun.jaxpr_call(
            prev_states, carray, jax.tree.map(lambda x: x[i], xs))
        # out of bounds
        # The bounded loop may run this body once the counter has passed
        # ``length``; every update is gated on ``pred`` so that such a call
        # leaves the outputs, the states and the carry unchanged.
        pred = i < length
        new_collect = jax.tree.map(
            lambda x, update: x.at[i].set(jax.lax.select(pred, update, x[i])),
            prev_collect,
            out4updates,
        )
        new_states = jax.tree.map(
            lambda ps, ns: jax.lax.select(pred, ns, ps),
            prev_states,
            new_states,
        )
        new_carray = jax.tree.map(
            lambda pc, nc: jax.lax.select(pred, nc, pc),
            carray,
            new_carray,
        )
        return (new_states, new_carray), new_collect, i + 1

    # while_loop
    # Smallest power of ``base`` that is >= ``length``. Integer arithmetic is
    # exact, whereas ``math.log`` can overshoot for exact powers of the base.
    rounded_max_steps = 1
    while rounded_max_steps < length:
        rounded_max_steps *= base
    (state_vals, carry), data2collection, _ = _bounded_while_loop(
        wrapped_cond_fun,
        wrapped_body_fun,
        ((tuple(st.value for st in all_states), init), data2collection, 0),
        rounded_max_steps,
        base,
        pbar_runner
    )
    _assign_state_values(all_states, state_vals)
    del state_vals, all_states, stateful_fun
    return carry, data2collection
349
+
350
+
351
+ def _forloop_to_scan_fun(f: Callable):
352
+ @wraps(f)
353
+ def scan_fun(carry, x):
354
+ return carry, f(*x)
355
+
356
+ return scan_fun
357
+
358
+
359
@set_module_as('brainstate.transform')
def for_loop(
    f: Callable[[X], Y],
    *xs,
    length: Optional[int] = None,
    reverse: bool = False,
    unroll: int | bool = 1,
    pbar: Optional[ProgressBar] = None
):
    """
    ``for-loop`` control flow with :py:class:`~.State`.

    This is a thin wrapper around :func:`~scan` that carries no loop value:
    ``f`` is adapted to the scan signature, the scan is run with ``init=None``
    and only the stacked outputs are returned.

    Args:
      f: a Python function applied at every iteration. At each step it is
        called with one slice (along the leading axis) of every operand in
        ``xs`` as positional arguments, and it returns the output of that step.
      xs: the operands over which to loop along the leading axis. Each one can
        be an array or any pytree (nested Python tuple/list/dict) thereof with
        consistent leading axis sizes.
      length: optional integer specifying the number of loop iterations, which
        must agree with the sizes of leading axes of the arrays in ``xs`` (but can
        be used to perform loops where no input ``xs`` are needed).
      reverse: optional boolean specifying whether to run the iteration
        forward (the default) or in reverse, equivalent to reversing the leading
        axes of the arrays in both ``xs`` and in ``ys``.
      unroll: optional positive int or bool forwarded to :func:`~scan`; it
        controls how many iterations are unrolled within a single iteration of
        the underlying loop.
      pbar: optional :class:`~.ProgressBar` instance to display the progress
        of the loop.

    Returns:
      The outputs of ``f`` stacked along a new leading axis.
    """
    scan_step = _forloop_to_scan_fun(f)
    # The carry is unused (``None``); only the stacked outputs are of interest.
    results = scan(
        scan_step,
        None,
        xs,
        length=length,
        reverse=reverse,
        unroll=unroll,
        pbar=pbar,
    )
    return results[1]
411
+
412
+
413
def checkpointed_for_loop(
    f: Callable[[X], Y],
    *xs: X,
    length: Optional[int] = None,
    base: int = 16,
    pbar: Optional[ProgressBar] = None,
):
    """
    ``for-loop`` control flow with :py:class:`~.State` with a checkpointed version, similar to :py:func:`for_loop`.

    This is a thin wrapper around :func:`~checkpointed_scan` that carries no
    loop value: ``f`` is adapted to the scan signature, the scan is run with
    ``init=None`` and only the stacked outputs are returned.

    Args:
      f: a Python function applied at every iteration. At each step it is
        called with one slice (along the leading axis) of every operand in
        ``xs`` as positional arguments, and it returns the output of that step.
      xs: the operands over which to loop along the leading axis. Each one can
        be an array or any pytree (nested Python tuple/list/dict) thereof with
        consistent leading axis sizes.
      length: optional integer specifying the number of loop iterations, which
        must agree with the sizes of leading axes of the arrays in ``xs`` (but can
        be used to perform loops where no input ``xs`` are needed).
      base: optional integer specifying the base for the bounded scan loop.
      pbar: optional :class:`~.ProgressBar` instance to display the progress
        of the loop.

    Returns:
      The outputs of ``f`` stacked along a new leading axis.
    """
    scan_step = _forloop_to_scan_fun(f)
    # The carry is unused (``None``); only the stacked outputs are of interest.
    results = checkpointed_scan(
        scan_step,
        None,
        xs,
        length=length,
        base=base,
        pbar=pbar,
    )
    return results[1]
452
+
453
+
454
+ # There are several tricks happening here to work around various limitations of JAX.
455
+ # (Also see https://github.com/google/jax/issues/2139#issuecomment-1039293633)
456
+ # 1. `unvmap_any` prior to using `lax.cond`. JAX has a problem in that vmap-of-cond
457
+ # is converted to a `lax.select`, which executes both branches unconditionally.
458
+ # Thus writing this naively, using a plain `lax.cond`, will mean the loop always
459
+ # runs to `max_steps` when executing under vmap. Instead we run (only) until every
460
+ # batch element has finished.
461
+ # 2. Treating in-place updates specially in the body_fun. Specifically we need to
462
+ # `lax.select` the update-to-make, not the updated buffer. This is because the
463
+ # latter instead results in XLA:CPU failing to determine that the buffer can be
464
+ # updated in-place, and instead it makes a copy. c.f. JAX issue #8192.
465
+ # This is done through the extra `inplace` argument provided to `body_fun`.
466
+ # 3. The use of the `@jax.checkpoint` decorator. Backpropagation through a
467
+ # `bounded_while_loop` will otherwise run in θ(max_steps) time, rather than
468
+ # θ(number of steps actually taken).
469
+ # 4. The use of `base`. In theory `base=2` is optimal at run time, as it implies the
470
+ # fewest superfluous operations. In practice this implies quite deep recursion in
471
+ # the construction of the bounded while loop, and this slows down the jaxpr
472
+ # creation and the XLA compilation. We choose `base=16` as a reasonable-looking
473
+ # compromise between compilation time and run time.
474
+
475
+ def _bounded_while_loop(
476
+ cond_fun: Callable,
477
+ body_fun: Callable,
478
+ val: Any,
479
+ max_steps: int,
480
+ base: int,
481
+ pbar_runner: Optional[Callable] = None
482
+ ):
483
+ if max_steps == 1:
484
+ return body_fun(val)
485
+ else:
486
+
487
+ def true_call(val_):
488
+ return _bounded_while_loop(cond_fun, body_fun, val_, max_steps // base, base, pbar_runner)
489
+
490
+ def false_call(val_):
491
+ if pbar_runner is not None:
492
+ pbar_runner(unvmap(val_[-1] + max_steps, op='none'))
493
+ return val_[:-1] + (val_[-1] + max_steps,)
494
+
495
+ def scan_fn(val_, _):
496
+ return jax.lax.cond(unvmap(cond_fun(val_), op='any'), true_call, false_call, val_), None
497
+
498
+ # Don't put checkpointing on the lowest level
499
+ if max_steps != base:
500
+ scan_fn = jax.checkpoint(scan_fn, prevent_cse=False) # pyright: ignore
501
+
502
+ return jax.lax.scan(scan_fn, val, xs=None, length=base)[0]
@@ -0,0 +1,170 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from __future__ import annotations
17
+
18
+ import math
19
+ from typing import Any, Callable, TypeVar
20
+
21
+ import jax
22
+
23
+ from brainstate._utils import set_module_as
24
+ from ._conditions import _wrapped_fun
25
+ from ._loop_collect_return import _bounded_while_loop
26
+ from ._make_jaxpr import StatefulFunction, _assign_state_values
27
+
28
+ X = TypeVar('X')
29
+ Y = TypeVar('Y')
30
+ T = TypeVar('T')
31
+ Carry = TypeVar('Carry')
32
+ BooleanNumeric = Any # A bool, or a Boolean array.
33
+
34
+ __all__ = [
35
+ # while loop
36
+ 'while_loop', 'bounded_while_loop',
37
+ ]
38
+
39
+
40
@set_module_as('brainstate.transform')
def while_loop(
    cond_fun: Callable[[T], BooleanNumeric],
    body_fun: Callable[[T], T],
    init_val: T
) -> T:
    """
    Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.

    The `Haskell-like type signature`_ in brief is

    .. code-block:: haskell

      while_loop :: (a -> Bool) -> (a -> a) -> a -> a

    The semantics of ``while_loop`` are given by this Python implementation::

      def while_loop(cond_fun, body_fun, init_val):
        val = init_val
        while cond_fun(val):
          val = body_fun(val)
        return val

    Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
    to a single WhileOp. That makes it useful for reducing compilation times
    for jit-compiled functions, since native Python loop constructs in an ``@jit``
    function are unrolled, leading to large XLA computations.

    Also unlike the Python analogue, the loop-carried value ``val`` must hold a
    fixed shape and dtype across all iterations (and not just be consistent up to
    NumPy rank/shape broadcasting and dtype promotion rules, for example). In
    other words, the type ``a`` in the type signature above represents an array
    with a fixed shape and dtype (or a nested tuple/list/dict container data
    structure with a fixed structure and arrays with fixed shape and dtype at the
    leaves).

    Another difference from using Python-native loop constructs is that
    ``while_loop`` is not reverse-mode differentiable because XLA computations
    require static bounds on memory requirements.

    Args:
      cond_fun: function of type ``a -> Bool``.
      body_fun: function of type ``a -> a``.
      init_val: value of type ``a``, a type that can be a scalar, array, or any
        pytree (nested Python tuple/list/dict) thereof, representing the initial
        loop carry value.

    Returns:
      The output from the final iteration of body_fun, of type ``a``.

    Raises:
      TypeError: if ``cond_fun`` or ``body_fun`` is not callable.

    .. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
    """
    if not (callable(body_fun) and callable(cond_fun)):
        raise TypeError("while_loop: body_fun and cond_fun arguments should be callable.")
    if jax.config.jax_disable_jit:
        try:
            val = init_val
            while cond_fun(val):
                val = body_fun(val)
            return val
        except jax.core.ConcretizationTypeError:
            # Can't run this while_loop in Python (e.g. because there's a vmap
            # transformation on it), so we fall back to the primitive version.
            pass

    # evaluate jaxpr
    stateful_cond = StatefulFunction(cond_fun).make_jaxpr(init_val)
    stateful_body = StatefulFunction(body_fun).make_jaxpr(init_val)
    # De-duplicate the states used by the two functions while keeping the order
    # in which they were first seen. A plain ``set`` would order them by hash,
    # which can differ between runs and so change the layout of the loop carry.
    all_states = tuple(dict.fromkeys(stateful_cond.get_states() + stateful_body.get_states()))
    new_cond_fun = _wrapped_fun(stateful_cond, all_states, return_states=False)
    new_body_fun = _wrapped_fun(stateful_body, all_states, return_states=True)

    # while_loop
    # the state values travel in the carry next to the user value
    state_vals, final_val = jax.lax.while_loop(new_cond_fun,
                                               new_body_fun,
                                               (tuple(st.value for st in all_states), init_val))
    _assign_state_values(all_states, state_vals)
    return final_val
118
+
119
+
120
def bounded_while_loop(
    cond_fun: Callable[[T], BooleanNumeric],
    body_fun: Callable[[T], T],
    init_val: T,
    *,
    max_steps: int,
    base: int = 16,
):
    """
    While loop with a bound on the maximum number of steps.

    This function is useful when you want to ensure that a while loop terminates
    even if the condition function is never false. The function is implemented
    using a scan operation, so it is reverse-mode differentiable.

    Args:
      cond_fun: A function of type ``a -> Bool``.
      body_fun: A function of type ``a -> a``.
      init_val: The initial value of type ``a``.
      max_steps: A bound on the maximum number of steps, after which the loop
        terminates unconditionally.
      base: An integer of at least 2. Run time will increase slightly as `base`
        increases. Compilation time will decrease substantially as
        `math.ceil(math.log(max_steps, base))` decreases.
        (Which happens as `base` increases.)

    Returns:
      The final value, as if computed by a `lax.while_loop`.

    Raises:
      ValueError: if ``max_steps`` is not a non-negative integer, or if ``base``
        is smaller than 2.
    """

    # checking
    if not isinstance(max_steps, int) or max_steps < 0:
        raise ValueError("max_steps must be a non-negative integer")
    init_val = jax.tree.map(jax.numpy.asarray, init_val)
    if max_steps == 0:
        return init_val
    # A base below 2 can never reach the ``max_steps == 1`` base case of the
    # nested loop, so reject it up front with a clear message.
    base = int(base)
    if base < 2:
        raise ValueError("`base` must be an integer greater than or equal to 2, but got {}.".format(base))

    # evaluate jaxpr
    stateful_cond = StatefulFunction(cond_fun).make_jaxpr(init_val)
    stateful_body = StatefulFunction(body_fun).make_jaxpr(init_val)
    # De-duplicate the states while keeping the order in which they were first
    # seen, so that the layout of the loop carry is the same on every run.
    all_states = tuple(dict.fromkeys(stateful_cond.get_states() + stateful_body.get_states()))
    new_cond_fun = _wrapped_fun(stateful_cond, all_states, return_states=False)
    new_body_fun = _wrapped_fun(stateful_body, all_states, return_states=True)

    # ``_bounded_while_loop`` treats the last element of the loop value as a
    # step counter and adds to it whenever it skips a sub-loop. Give it a
    # dedicated counter so that the user's value is never modified that way.
    def counted_cond(carry):
        inner, step = carry
        # Stop after ``max_steps`` body calls, even though the nested scans
        # cover the next power of ``base``.
        return jax.numpy.logical_and(step < max_steps, new_cond_fun(inner))

    def counted_body(carry):
        inner, step = carry
        return new_body_fun(inner), step + 1

    # initial value
    init_carry = (
        (tuple(st.value for st in all_states), init_val),
        jax.numpy.zeros((), dtype=int),
    )

    # while_loop
    # Smallest power of ``base`` that is >= ``max_steps``, and at least ``base``
    # itself: the single-step base case of ``_bounded_while_loop`` applies the
    # body without checking the condition, so it must never be the top level.
    rounded_max_steps = base
    while rounded_max_steps < max_steps:
        rounded_max_steps *= base
    (state_vals, val), _ = _bounded_while_loop(
        counted_cond, counted_body, init_carry, rounded_max_steps, base, None
    )
    _assign_state_values(all_states, state_vals)
    return val