brainstate 0.0.2.post20240825__py2.py3-none-any.whl → 0.0.2.post20240826__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_module.py +6 -6
- brainstate/random.py +5 -5
- brainstate/transform/__init__.py +16 -6
- brainstate/transform/_conditions.py +334 -0
- brainstate/transform/{_controls_test.py → _conditions_test.py} +35 -35
- brainstate/transform/_error_if.py +94 -0
- brainstate/transform/{_jit_error_test.py → _error_if_test.py} +4 -4
- brainstate/transform/_loop_collect_return.py +502 -0
- brainstate/transform/_loop_no_collection.py +170 -0
- brainstate/transform/_mapping.py +109 -0
- brainstate/transform/_unvmap.py +143 -0
- brainstate/typing.py +55 -1
- {brainstate-0.0.2.post20240825.dist-info → brainstate-0.0.2.post20240826.dist-info}/METADATA +2 -2
- {brainstate-0.0.2.post20240825.dist-info → brainstate-0.0.2.post20240826.dist-info}/RECORD +17 -13
- {brainstate-0.0.2.post20240825.dist-info → brainstate-0.0.2.post20240826.dist-info}/WHEEL +1 -1
- brainstate/transform/_control.py +0 -665
- brainstate/transform/_jit_error.py +0 -180
- {brainstate-0.0.2.post20240825.dist-info → brainstate-0.0.2.post20240826.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20240825.dist-info → brainstate-0.0.2.post20240826.dist-info}/top_level.txt +0 -0
brainstate/transform/_control.py
DELETED
@@ -1,665 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
from __future__ import annotations
|
17
|
-
|
18
|
-
import operator
|
19
|
-
from collections.abc import Sequence
|
20
|
-
from functools import wraps, reduce
|
21
|
-
from typing import Callable, TypeVar, Any, Optional
|
22
|
-
|
23
|
-
import jax
|
24
|
-
import jax.numpy as jnp
|
25
|
-
import numpy as np
|
26
|
-
|
27
|
-
from brainstate._utils import set_module_as
|
28
|
-
from ._jit_error import jit_error, remove_vmap
|
29
|
-
from ._make_jaxpr import StatefulFunction, _assign_state_values
|
30
|
-
from ._progress_bar import ProgressBar
|
31
|
-
|
32
|
-
# Type variables shared by the control-flow signatures in this module.
Carry = TypeVar('Carry')  # loop-carried value threaded through ``scan``
X = TypeVar('X')  # per-iteration slice of the scanned inputs ``xs``
Y = TypeVar('Y')  # per-iteration output that ``scan`` stacks
T = TypeVar('T')  # value carried through ``while_loop``
BooleanNumeric = Any  # A bool, or a Boolean array.

# Public API of this module.
__all__ = [
    'cond',
    'switch',
    'ifelse',
    'scan',
    'for_loop',
    'while_loop',
]
|
41
|
-
|
42
|
-
|
43
|
-
def _wrapped_fun(stateful_fun: StatefulFunction, states, return_states=True):
|
44
|
-
@wraps(stateful_fun.fun)
|
45
|
-
def wrapped_branch(state_vals, *operands):
|
46
|
-
assert len(states) == len(state_vals)
|
47
|
-
for st, val in zip(states, state_vals):
|
48
|
-
st.value = val
|
49
|
-
out = stateful_fun.jaxpr_call_auto(*operands)
|
50
|
-
if return_states:
|
51
|
-
return tuple(st.value for st in states), out
|
52
|
-
return out
|
53
|
-
|
54
|
-
return wrapped_branch
|
55
|
-
|
56
|
-
|
57
|
-
@set_module_as('brainstate.transform')
def cond(pred, true_fun: Callable, false_fun: Callable, *operands):
    """
    Conditionally apply ``true_fun`` or ``false_fun``.

    Provided arguments are correctly typed, ``cond()`` has equivalent
    semantics to this Python implementation, where ``pred`` must be a
    scalar type::

      def cond(pred, true_fun, false_fun, *operands):
        if pred:
          return true_fun(*operands)
        else:
          return false_fun(*operands)


    In contrast with :func:`jax.lax.select`, using ``cond`` indicates that only one of
    the two branches is executed (up to compiler rewrites and optimizations).
    However, when transformed with :func:`~jax.vmap` to operate over a batch of
    predicates, ``cond`` is converted to :func:`~jax.lax.select`.

    States read or written by either branch are passed explicitly through
    :func:`jax.lax.cond`. Their resulting values are written back via
    ``_assign_state_values`` after the call.

    Args:
      pred: Boolean scalar type, indicating which branch function to apply.
        Integer and floating scalars are accepted and compared against zero.
      true_fun: Function (A -> B), to be applied if ``pred`` is True.
      false_fun: Function (A -> B), to be applied if ``pred`` is False.
      operands: Operands (A) input to either branch depending on ``pred``. The
        type can be a scalar, array, or any pytree (nested Python tuple/list/dict)
        thereof.

    Returns:
      Value (B) of either ``true_fun(*operands)`` or ``false_fun(*operands)``,
      depending on the value of ``pred``. The type can be a scalar, array, or any
      pytree (nested Python tuple/list/dict) thereof.

    Raises:
      TypeError: If a branch is not callable, or ``pred`` is None, non-scalar,
        or of a non-boolean / non-numeric dtype.
    """
    if not (callable(true_fun) and callable(false_fun)):
        raise TypeError("true_fun and false_fun arguments should be callable.")

    if pred is None:
        raise TypeError("cond predicate is None")
    if isinstance(pred, Sequence) or np.ndim(pred) != 0:
        raise TypeError(f"Pred must be a scalar, got {pred} of " +
                        (f"type {type(pred)}" if isinstance(pred, Sequence)
                         else f"shape {np.shape(pred)}."))

    # check pred: booleans pass through; numbers are converted to ``pred != 0``
    try:
        pred_dtype = jax.dtypes.result_type(pred)
    except TypeError as err:
        raise TypeError("Pred type must be either boolean or number, got {}.".format(pred)) from err
    if pred_dtype.kind != 'b':
        if pred_dtype.kind in 'iuf':
            pred = pred != 0
        else:
            raise TypeError("Pred type must be either boolean or number, got {}.".format(pred_dtype))

    # not jit: with a concrete predicate, simply run the selected branch eagerly
    if jax.config.jax_disable_jit and isinstance(jax.core.get_aval(pred), jax.core.ConcreteArray):
        if pred:
            return true_fun(*operands)
        else:
            return false_fun(*operands)

    # evaluate jaxpr
    true_fun_wrap = StatefulFunction(true_fun).make_jaxpr(*operands)
    false_fun_wrap = StatefulFunction(false_fun).make_jaxpr(*operands)

    # Union of the states touched by either branch. ``dict.fromkeys`` removes
    # duplicates while keeping a deterministic first-seen order. A ``set`` would
    # order states by hash/id, which changes the traced argument order between runs.
    all_states = tuple(dict.fromkeys(true_fun_wrap.get_states() + false_fun_wrap.get_states()))
    true_fun = _wrapped_fun(true_fun_wrap, all_states)
    false_fun = _wrapped_fun(false_fun_wrap, all_states)

    # the current state values become the first operand of both branches
    operands = ([st.value for st in all_states],) + operands

    # cond
    state_vals, out = jax.lax.cond(pred, true_fun, false_fun, *operands)
    _assign_state_values(all_states, state_vals)
    return out
|
135
|
-
|
136
|
-
# ops, ops_tree = jax.tree.flatten(operands)
|
137
|
-
# linear_ops = [False] * len(ops)
|
138
|
-
# ops_avals = tuple(jax.util.safe_map(_abstractify, ops))
|
139
|
-
#
|
140
|
-
# # true and false jaxprs
|
141
|
-
# jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
|
142
|
-
# (true_fun, false_fun), ops_tree, ops_avals, 'cond')
|
143
|
-
# if any(isinstance(op_aval, state.AbstractRef) for op_aval in ops_avals):
|
144
|
-
# raise ValueError("Cannot pass `Ref`s into `cond`.")
|
145
|
-
# true_jaxpr, false_jaxpr = jaxprs
|
146
|
-
# out_tree, false_out_tree = out_trees
|
147
|
-
# if any(isinstance(out_aval, state.AbstractRef) for out_aval in true_jaxpr.out_avals + false_jaxpr.out_avals):
|
148
|
-
# raise ValueError("Cannot return `Ref`s from `cond`.")
|
149
|
-
#
|
150
|
-
# _check_tree_and_avals("true_fun and false_fun output",
|
151
|
-
# out_tree, true_jaxpr.out_avals,
|
152
|
-
# false_out_tree, false_jaxpr.out_avals)
|
153
|
-
# joined_effects = jax.core.join_effects(true_jaxpr.effects, false_jaxpr.effects)
|
154
|
-
# disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects)
|
155
|
-
# if disallowed_effects:
|
156
|
-
# raise NotImplementedError(f'Effects not supported in `cond`: {disallowed_effects}')
|
157
|
-
#
|
158
|
-
# # replace jaxpr effects
|
159
|
-
# index = jax.lax.convert_element_type(pred, np.int32)
|
160
|
-
# if joined_effects:
|
161
|
-
# # Raise index in case of effects to allow data-dependence-based discharging
|
162
|
-
# # of those effects (even if they don't have an explicit data dependence).
|
163
|
-
# index = jax.core.raise_as_much_as_possible(index)
|
164
|
-
# false_jaxpr = _replace_jaxpr_effects(false_jaxpr, joined_effects)
|
165
|
-
# true_jaxpr = _replace_jaxpr_effects(true_jaxpr, joined_effects)
|
166
|
-
#
|
167
|
-
# # bind
|
168
|
-
# linear = [False] * len(consts) + linear_ops
|
169
|
-
# cond_outs = jax.lax.cond_p.bind(index, *consts, *ops, branches=(false_jaxpr, true_jaxpr), linear=tuple(linear))
|
170
|
-
#
|
171
|
-
# # outputs
|
172
|
-
# st_vals, out = jax.tree.unflatten(out_tree, cond_outs)
|
173
|
-
# for st, val in zip(all_states, st_vals):
|
174
|
-
# st.value = val
|
175
|
-
# return out
|
176
|
-
|
177
|
-
|
178
|
-
@set_module_as('brainstate.transform')
def switch(index, branches: Sequence[Callable], *operands):
    """
    Apply exactly one of ``branches`` given by ``index``.

    If ``index`` is out of bounds, it is clamped to within bounds.

    Has the semantics of the following Python::

      def switch(index, branches, *operands):
        index = clamp(0, index, len(branches) - 1)
        return branches[index](*operands)

    Internally this wraps XLA's `Conditional
    <https://www.tensorflow.org/xla/operation_semantics#conditional>`_
    operator. However, when transformed with :func:`~jax.vmap` to operate over a
    batch of indices, ``switch`` is converted to a select over the outputs of
    all branches.

    States read or written by any branch are passed explicitly through
    :func:`jax.lax.switch`. Their resulting values are written back via
    ``_assign_state_values`` after the call.

    Args:
      index: Integer scalar type, indicating which branch function to apply.
      branches: Sequence (or any iterable) of functions (A -> B) to be applied
        based on ``index``.
      operands: Operands (A) input to whichever branch is applied.

    Returns:
      Value (B) of ``branch(*operands)`` for the branch that was selected based
      on ``index``.

    Raises:
      TypeError: If a branch is not callable, or ``index`` is not an integer scalar.
      ValueError: If ``branches`` is empty.
    """
    # Materialize first, so that a one-shot iterable (e.g. a generator) is not
    # exhausted by the callable check below.
    branches = tuple(branches)

    # check branches
    if not all(callable(branch) for branch in branches):
        raise TypeError("branches argument should be a sequence of callables.")

    # check index
    if len(np.shape(index)) != 0:
        raise TypeError(f"Branch index must be scalar, got {index} of shape {np.shape(index)}.")
    try:
        index_dtype = jax.dtypes.result_type(index)
    except TypeError as err:
        msg = f"Index type must be an integer, got {index}."
        raise TypeError(msg) from err
    if index_dtype.kind not in 'iu':
        raise TypeError(f"Index type must be an integer, got {index} as {index_dtype}")

    # format branches
    if len(branches) == 0:
        raise ValueError("Empty branch sequence")
    elif len(branches) == 1:
        return branches[0](*operands)

    # format index: clamp into [0, len(branches) - 1]
    index = jax.lax.convert_element_type(index, np.int32)
    lo = np.array(0, np.int32)
    hi = np.array(len(branches) - 1, np.int32)
    index = jax.lax.clamp(lo, index, hi)

    # not jit: with a concrete index, simply run the selected branch eagerly
    if jax.config.jax_disable_jit and isinstance(jax.core.get_aval(index), jax.core.ConcreteArray):
        return branches[int(index)](*operands)

    # evaluate jaxpr
    wrapped_branches = [StatefulFunction(branch) for branch in branches]
    for wrapped_branch in wrapped_branches:
        wrapped_branch.make_jaxpr(*operands)

    # Union of the states touched by any branch, de-duplicated in a
    # deterministic first-seen order (a ``set`` would order by hash/id).
    all_states = tuple(dict.fromkeys(
        state
        for wrapped_branch in wrapped_branches
        for state in wrapped_branch.get_states()
    ))
    branches = tuple(_wrapped_fun(wrapped_branch, all_states) for wrapped_branch in wrapped_branches)

    # the current state values become the first operand of every branch
    operands = ([st.value for st in all_states],) + operands

    # switch
    state_vals, out = jax.lax.switch(index, branches, *operands)
    _assign_state_values(all_states, state_vals)
    return out
|
253
|
-
|
254
|
-
# ops, ops_tree = jax.tree.flatten(operands)
|
255
|
-
# ops_avals = tuple(jax.util.safe_map(_abstractify, ops))
|
256
|
-
#
|
257
|
-
# # true jaxprs
|
258
|
-
# jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
|
259
|
-
# branches, ops_tree, ops_avals, primitive_name='switch')
|
260
|
-
# for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])):
|
261
|
-
# _check_tree_and_avals(f"branch 0 and {i + 1} outputs",
|
262
|
-
# out_trees[0], jaxprs[0].out_avals,
|
263
|
-
# out_tree, jaxpr.out_avals)
|
264
|
-
# joined_effects = jax.core.join_effects(*(jaxpr.effects for jaxpr in jaxprs))
|
265
|
-
# disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects)
|
266
|
-
# if disallowed_effects:
|
267
|
-
# raise NotImplementedError(f'Effects not supported in `switch`: {disallowed_effects}')
|
268
|
-
# if joined_effects:
|
269
|
-
# # Raise index in case of effects to allow data-dependence-based discharging
|
270
|
-
# # of those effects (even if they don't have an explicit data dependence).
|
271
|
-
# index = jax.core.raise_as_much_as_possible(index)
|
272
|
-
#
|
273
|
-
# # bind
|
274
|
-
# linear = (False,) * (len(consts) + len(ops))
|
275
|
-
# cond_outs = jax.lax.cond_p.bind(index, *consts, *ops, branches=tuple(jaxprs), linear=linear)
|
276
|
-
#
|
277
|
-
# # outputs
|
278
|
-
# st_vals, out = jax.tree.unflatten(out_trees[0], cond_outs)
|
279
|
-
# for st, val in zip(all_states, st_vals):
|
280
|
-
# st.value = val
|
281
|
-
# return out
|
282
|
-
|
283
|
-
|
284
|
-
@set_module_as('brainstate.transform')
def ifelse(conditions, branches, *operands, check_cond: bool = True):
    """
    ``If-else`` control flow that reads like native Pythonic programming.

    ``conditions`` and ``branches`` are paired one-to-one. The branch whose
    condition is True is executed.

    Examples
    --------

    >>> import brainstate as bst
    >>> def f(a):
    ...     return bst.transform.ifelse(
    ...         conditions=[a > 10,
    ...                     (a > 5) & (a <= 10),
    ...                     (a > 2) & (a <= 5),
    ...                     (a > 0) & (a <= 2),
    ...                     a <= 0],
    ...         branches=[lambda: 1,
    ...                   lambda: 2,
    ...                   lambda: 3,
    ...                   lambda: 4,
    ...                   lambda: 5])
    >>> f(1)
    4
    >>> f(0)
    5

    Parameters
    ----------
    conditions: bool, sequence of bool, Array
        The boolean conditions, one per branch. Exactly one of them is
        expected to be True.
    branches: Sequence[Callable]
        The branch functions. There must be as many branches as conditions.
        ``branches[i]`` is applied when ``conditions[i]`` is True, and each
        branch is called as ``branch(*operands)``. If only one branch is
        given, it is called directly and ``conditions`` is ignored.
    *operands: optional, Any
        The operands passed to the selected branch.
    check_cond: bool
        Whether to report an error (through ``jit_error``) when the number of
        True conditions is not exactly one. Default is True. Without the
        check, the first True condition wins, and the last branch is used when
        no condition is True.

    Returns
    -------
    res: Any
        The result of the selected branch.

    Raises
    ------
    TypeError
        If any branch is not callable.
    ValueError
        If ``branches`` is empty, or its length differs from that of ``conditions``.
    """
    # Materialize first, so that a one-shot iterable (e.g. a generator) is not
    # exhausted by the callable check below.
    branches = tuple(branches)

    # check branches
    if not all(callable(branch) for branch in branches):
        raise TypeError("branches argument should be a sequence of callables.")

    # format branches
    if len(branches) == 0:
        raise ValueError("Empty branch sequence")
    elif len(branches) == 1:
        return branches[0](*operands)
    if len(conditions) != len(branches):
        raise ValueError("The number of conditions should be equal to the number of branches.")

    # format index
    conditions = jnp.asarray(conditions, np.int32)
    if check_cond:
        jit_error(jnp.sum(conditions) != 1, "Only one condition can be True. But got {}.", err_arg=conditions)
    # index of the first True condition; falls back to the last branch if none is True
    index = jnp.where(conditions, size=1, fill_value=len(conditions) - 1)[0][0]
    return switch(index, branches, *operands)
|
343
|
-
|
344
|
-
|
345
|
-
def _wrap_fun_with_pbar(fun, pbar_runner):
    """Wrap a scan body so that it also drives a progress bar.

    The wrapped body expects its carry as ``(step, carry)``. On each call it:

    * applies ``fun`` to the inner carry and the inputs;
    * reports ``step`` to ``pbar_runner``;
    * returns the incremented step together with ``fun``'s new carry and outputs.
    """

    @wraps(fun)
    def new_fun(new_carry, inputs):
        step, inner_carry = new_carry
        inner_carry, step_outputs = fun(inner_carry, inputs)
        # The counter goes through ``remove_vmap(..., op='none')`` before being
        # reported (presumably to drop vmap batching; see its definition).
        pbar_runner(remove_vmap(step, op='none'))
        next_carry = (step + 1, inner_carry)
        return next_carry, step_outputs

    return new_fun
|
354
|
-
|
355
|
-
|
356
|
-
def _wrapped_scan_fun(stateful_fun: StatefulFunction, states):
|
357
|
-
@wraps(stateful_fun.fun)
|
358
|
-
def wrapped_fun(new_carry, inputs):
|
359
|
-
state_vals, carry = new_carry
|
360
|
-
assert len(states) == len(state_vals)
|
361
|
-
for st, val in zip(states, state_vals):
|
362
|
-
st.value = val
|
363
|
-
carry, out = stateful_fun.jaxpr_call_auto(carry, inputs)
|
364
|
-
return (tuple(st.value for st in states), carry), out
|
365
|
-
|
366
|
-
return wrapped_fun
|
367
|
-
|
368
|
-
|
369
|
-
@set_module_as('brainstate.transform')
def scan(
    f: Callable[[Carry, X], tuple[Carry, Y]],
    init: Carry,
    xs: X,
    length: int | None = None,
    reverse: bool = False,
    unroll: int | bool = 1,
    pbar: ProgressBar | None = None,
) -> tuple[Carry, Y]:
    """
    Scan a function over leading array axes while carrying along state.

    The `Haskell-like type signature`_ in brief is

    .. code-block:: haskell

      scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])

    where for any array type specifier ``t``, ``[t]`` represents the type with an additional
    leading axis, and if ``t`` is a pytree (container) type with array leaves then ``[t]``
    represents the type with the same pytree structure and corresponding leaves
    each with an additional leading axis.

    When the type of ``xs`` (denoted `a` above) is an array type or None, and the type
    of ``ys`` (denoted `b` above) is an array type, the semantics of :func:`~scan` are
    given roughly by this Python implementation::

      def scan(f, init, xs, length=None):
        if xs is None:
          xs = [None] * length
        carry = init
        ys = []
        for x in xs:
          carry, y = f(carry, x)
          ys.append(y)
        return carry, np.stack(ys)

    Unlike that Python version, both ``xs`` and ``ys`` may be arbitrary pytree
    values, and so multiple arrays can be scanned over at once and produce multiple
    output arrays. ``None`` is actually a special case of this, as it represents an
    empty pytree.

    Also unlike that Python version, :func:`~scan` is a JAX primitive and is
    lowered to a single WhileOp. That makes it useful for reducing
    compilation times for JIT-compiled functions, since native Python
    loop constructs in an :func:`~jax.jit` function are unrolled, leading to large
    XLA computations.

    Finally, the loop-carried value ``carry`` must hold a fixed shape and dtype
    across all iterations (and not just be consistent up to NumPy rank/shape
    broadcasting and dtype promotion rules, for example). In other words, the type
    ``c`` in the type signature above represents an array with a fixed shape and
    dtype (or a nested tuple/list/dict container data structure with a fixed
    structure and arrays with fixed shape and dtype at the leaves).

    States read or written by ``f`` are carried through :func:`jax.lax.scan`
    alongside ``carry``. Their final values are written back via
    ``_assign_state_values`` after the loop.

    Args:
      f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
        that ``f`` accepts two arguments where the first is a value of the loop
        carry and the second is a slice of ``xs`` along its leading axis, and that
        ``f`` returns a pair where the first element represents a new value for
        the loop carry and the second represents a slice of the output.
      init: an initial loop carry value of type ``c``, which can be a scalar,
        array, or any pytree (nested Python tuple/list/dict) thereof, representing
        the initial loop carry value. This value must have the same structure as
        the first element of the pair returned by ``f``.
      xs: the value of type ``[a]`` over which to scan along the leading axis,
        where ``[a]`` can be an array or any pytree (nested Python
        tuple/list/dict) thereof with consistent leading axis sizes.
      length: optional integer specifying the number of loop iterations, which
        must agree with the sizes of leading axes of the arrays in ``xs`` (but can
        be used to perform scans where no input ``xs`` are needed).
      reverse: optional boolean specifying whether to run the scan iteration
        forward (the default) or in reverse, equivalent to reversing the leading
        axes of the arrays in both ``xs`` and in ``ys``.
      unroll: optional positive int or bool specifying, in the underlying
        operation of the scan primitive, how many scan iterations to unroll within
        a single iteration of a loop. If an integer is provided, it determines how
        many unrolled loop iterations to run within a single rolled iteration of
        the loop. If a boolean is provided, it will determine if the loop is
        completely unrolled (i.e. `unroll=True`) or left completely rolled (i.e.
        `unroll=False`).
      pbar: optional :class:`~.ProgressBar` instance to display the progress
        of the scan operation.

    Returns:
      A pair of type ``(c, [b])`` where the first element represents the final
      loop carry value and the second element represents the stacked outputs of
      the second output of ``f`` when scanned over the leading axis of the inputs.

    Raises:
      TypeError: If ``f`` is not callable.
      ValueError: If a leaf of ``xs`` has no leading axis, the leading axis sizes
        disagree with each other or with ``length``, no length can be inferred,
        or a zero-length scan is requested under ``jax.disable_jit()``.

    .. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
    """
    # check "f"
    if not callable(f):
        raise TypeError("f argument should be a callable.")

    # check "xs" and determine the number of iterations
    xs_flat, xs_tree = jax.tree.flatten(xs)
    try:
        lengths = [x.shape[0] for x in xs_flat]
    except AttributeError as err:
        raise ValueError("scan got value with no leading axis to scan over: "
                         "{}.".format(', '.join(str(x) for x in xs_flat if not hasattr(x, 'shape')))) from err
    if length is not None:
        length = int(length)
        if not all(length == n for n in lengths):
            raise ValueError(("scan got `length` argument of {} which disagrees with "
                              "leading axis sizes {}.").format(length, [x.shape[0] for x in xs_flat]))
    else:
        unique_lengths = set(lengths)
        if len(unique_lengths) > 1:
            msg = "scan got values with different leading axis sizes: {}."
            raise ValueError(msg.format(', '.join(str(x.shape[0]) for x in xs_flat)))
        elif len(unique_lengths) == 0:
            raise ValueError("scan got no values to scan over and `length` not provided.")
        else:
            length, = unique_lengths

    # function with progress bar: the wrapped ``f`` carries ``(step, carry)``.
    # Test ``pbar is not None`` (not the truthiness of the ProgressBar object)
    # so that the carry layout always matches the wrapping of ``f``.
    has_pbar = pbar is not None
    if has_pbar:
        f = _wrap_fun_with_pbar(f, pbar.init(length))
        init = (0, init)

    # not jit: run the loop eagerly in Python
    if jax.config.jax_disable_jit:
        if length == 0:
            raise ValueError("zero-length scan is not supported in disable_jit() mode because the output type is unknown.")
        carry = init
        ys = []
        maybe_reversed = reversed if reverse else lambda x: x
        for i in maybe_reversed(range(length)):
            xs_slice = [jax.lax.index_in_dim(x, i, keepdims=False) for x in xs_flat]
            carry, y = f(carry, jax.tree.unflatten(xs_tree, xs_slice))
            ys.append(y)
        stacked_y = jax.tree.map(lambda *elems: jnp.stack(elems), *maybe_reversed(ys))
        if has_pbar:
            return carry[1], stacked_y
        else:
            return carry, stacked_y

    # evaluate jaxpr on one slice of ``xs`` to collect every state used by ``f``
    xs_avals = [jax.core.raise_to_shaped(jax.core.get_aval(x)) for x in xs_flat]
    x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
    stateful_fun = StatefulFunction(f).make_jaxpr(init, xs_tree.unflatten(x_avals))
    all_states = stateful_fun.get_states()
    wrapped_f = _wrapped_scan_fun(stateful_fun, all_states)

    # scan, carrying the state values next to the user carry
    init = (tuple(st.value for st in all_states), init)
    (state_vals, carry), ys = jax.lax.scan(wrapped_f, init, xs, length=length, reverse=reverse, unroll=unroll)
    _assign_state_values(all_states, state_vals)
    if has_pbar:
        carry = carry[1]
    return carry, ys
|
526
|
-
|
527
|
-
|
528
|
-
def _forloop_to_scan_fun(f: Callable):
|
529
|
-
@wraps(f)
|
530
|
-
def scan_fun(carry, x):
|
531
|
-
return carry, f(*x)
|
532
|
-
|
533
|
-
return scan_fun
|
534
|
-
|
535
|
-
|
536
|
-
@set_module_as('brainstate.transform')
def for_loop(
    f,
    *xs,
    length: Optional[int] = None,
    reverse: bool = False,
    unroll: int | bool = 1,
    pbar: Optional[ProgressBar] = None
):
    """
    ``for-loop`` control flow with :py:class:`~.State`.

    Semantically equivalent to::

      ys = [f(*[x[i] for x in xs]) for i in range(length)]
      return stack(ys)

    It is implemented on top of :func:`scan` with a ``None`` carry, so state
    updates made by ``f`` are handled as in :func:`scan`.

    Args:
      f: a Python function called once per iteration as ``f(*x)``. Here ``x``
        holds, for each positional argument in ``xs``, its slice along the
        leading axis, so ``f`` receives one positional argument per entry of
        ``xs``, or none if ``xs`` is empty. Its return value is the
        per-iteration output.
      xs: the values of type ``[a]`` over which to loop along the leading axis,
        each of which can be an array or any pytree (nested Python
        tuple/list/dict) thereof, all with consistent leading axis sizes.
      length: optional integer specifying the number of loop iterations, which
        must agree with the sizes of leading axes of the arrays in ``xs`` (but can
        be used to perform loops where no input ``xs`` are needed).
      reverse: optional boolean specifying whether to run the loop iteration
        forward (the default) or in reverse, equivalent to reversing the leading
        axes of the arrays in both ``xs`` and in the outputs.
      unroll: optional positive int or bool specifying, in the underlying
        operation of the scan primitive, how many scan iterations to unroll within
        a single iteration of a loop. If an integer is provided, it determines how
        many unrolled loop iterations to run within a single rolled iteration of
        the loop. If a boolean is provided, it will determine if the loop is
        completely unrolled (i.e. `unroll=True`) or left completely rolled (i.e.
        `unroll=False`).
      pbar: optional :class:`~.ProgressBar` instance to display the progress
        of the loop.

    Returns:
      The per-iteration outputs of ``f``, stacked along a new leading axis.
    """
    _, ys = scan(_forloop_to_scan_fun(f),
                 init=None,
                 xs=xs,
                 length=length,
                 reverse=reverse,
                 unroll=unroll,
                 pbar=pbar)
    return ys
|
586
|
-
|
587
|
-
|
588
|
-
@set_module_as('brainstate.transform')
def while_loop(
    cond_fun: Callable[[T], BooleanNumeric],
    body_fun: Callable[[T], T],
    init_val: T
) -> T:
    """
    Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.

    The `Haskell-like type signature`_ in brief is

    .. code-block:: haskell

      while_loop :: (a -> Bool) -> (a -> a) -> a -> a

    The semantics of ``while_loop`` are given by this Python implementation::

      def while_loop(cond_fun, body_fun, init_val):
        val = init_val
        while cond_fun(val):
          val = body_fun(val)
        return val

    Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
    to a single WhileOp. That makes it useful for reducing compilation times
    for jit-compiled functions, since native Python loop constructs in an ``@jit``
    function are unrolled, leading to large XLA computations.

    Also unlike the Python analogue, the loop-carried value ``val`` must hold a
    fixed shape and dtype across all iterations (and not just be consistent up to
    NumPy rank/shape broadcasting and dtype promotion rules, for example). In
    other words, the type ``a`` in the type signature above represents an array
    with a fixed shape and dtype (or a nested tuple/list/dict container data
    structure with a fixed structure and arrays with fixed shape and dtype at the
    leaves).

    Another difference from using Python-native loop constructs is that
    ``while_loop`` is not reverse-mode differentiable because XLA computations
    require static bounds on memory requirements.

    States read or written by ``cond_fun`` or ``body_fun`` are carried through
    :func:`jax.lax.while_loop` next to ``val``. Their final values are written
    back via ``_assign_state_values`` after the loop.

    Args:
      cond_fun: function of type ``a -> Bool``.
      body_fun: function of type ``a -> a``.
      init_val: value of type ``a``, a type that can be a scalar, array, or any
        pytree (nested Python tuple/list/dict) thereof, representing the initial
        loop carry value.

    Returns:
      The output from the final iteration of body_fun, of type ``a``.

    Raises:
      TypeError: If ``cond_fun`` or ``body_fun`` is not callable.

    .. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
    """
    if not (callable(body_fun) and callable(cond_fun)):
        raise TypeError("while_loop: body_fun and cond_fun arguments should be callable.")
    if jax.config.jax_disable_jit:
        try:
            val = init_val
            while cond_fun(val):
                val = body_fun(val)
            return val
        except jax.core.ConcretizationTypeError:
            # Can't run this while_loop in Python (e.g. because there's a vmap
            # transformation on it), so we fall back to the primitive version.
            pass

    # evaluate jaxpr
    stateful_cond = StatefulFunction(cond_fun).make_jaxpr(init_val)
    stateful_body = StatefulFunction(body_fun).make_jaxpr(init_val)
    # Union of the states touched by either function, de-duplicated in a
    # deterministic first-seen order (a ``set`` would order by hash/id).
    all_states = tuple(dict.fromkeys(stateful_cond.get_states() + stateful_body.get_states()))
    cond_branch = _wrapped_fun(stateful_cond, all_states, return_states=False)
    body_branch = _wrapped_fun(stateful_body, all_states, return_states=True)

    # ``jax.lax.while_loop`` passes the whole carry ``(state_vals, val)`` as a
    # single argument, whereas the ``_wrapped_fun`` wrappers expect
    # ``(state_vals, *operands)``. Unpack the carry before forwarding.
    def new_cond_fun(carry):
        return cond_branch(*carry)

    def new_body_fun(carry):
        return body_branch(*carry)

    # while_loop
    state_vals, final_val = jax.lax.while_loop(
        new_cond_fun,
        new_body_fun,
        (tuple(st.value for st in all_states), init_val),
    )
    _assign_state_values(all_states, state_vals)
    return final_val
|