brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +95 -29
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.9.dist-info/RECORD +0 -130
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2024
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -24,7 +24,7 @@ from brainstate._utils import set_module_as
|
|
24
24
|
from ._make_jaxpr import StatefulFunction
|
25
25
|
from ._progress_bar import ProgressBar
|
26
26
|
from ._unvmap import unvmap
|
27
|
-
from ._util import
|
27
|
+
from ._util import wrap_single_fun
|
28
28
|
|
29
29
|
__all__ = [
|
30
30
|
# "scan" syntax, which is similar to jax.lax.scan
|
@@ -54,7 +54,7 @@ def _wrap_fun_with_pbar(
|
|
54
54
|
return new_fun
|
55
55
|
|
56
56
|
|
57
|
-
@set_module_as('brainstate.
|
57
|
+
@set_module_as('brainstate.transform')
|
58
58
|
def scan(
|
59
59
|
f: Callable[[Carry, X], Tuple[Carry, Y]],
|
60
60
|
init: Carry,
|
@@ -80,17 +80,19 @@ def scan(
|
|
80
80
|
|
81
81
|
When the type of ``xs`` (denoted `a` above) is an array type or None, and the type
|
82
82
|
of ``ys`` (denoted `b` above) is an array type, the semantics of :func:`~scan` are
|
83
|
-
given roughly by this Python implementation
|
83
|
+
given roughly by this Python implementation:
|
84
84
|
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
85
|
+
.. code-block:: python
|
86
|
+
|
87
|
+
>>> def scan(f, init, xs, length=None):
|
88
|
+
... if xs is None:
|
89
|
+
... xs = [None] * length
|
90
|
+
... carry = init
|
91
|
+
... ys = []
|
92
|
+
... for x in xs:
|
93
|
+
... carry, y = f(carry, x)
|
94
|
+
... ys.append(y)
|
95
|
+
... return carry, np.stack(ys)
|
94
96
|
|
95
97
|
Unlike that Python version, both ``xs`` and ``ys`` may be arbitrary pytree
|
96
98
|
values, and so multiple arrays can be scanned over at once and produce multiple
|
@@ -110,40 +112,75 @@ def scan(
|
|
110
112
|
dtype (or a nested tuple/list/dict container data structure with a fixed
|
111
113
|
structure and arrays with fixed shape and dtype at the leaves).
|
112
114
|
|
113
|
-
|
114
|
-
|
115
|
+
Parameters
|
116
|
+
----------
|
117
|
+
f : callable
|
118
|
+
A Python function to be scanned of type ``c -> a -> (c, b)``, meaning
|
115
119
|
that ``f`` accepts two arguments where the first is a value of the loop
|
116
120
|
carry and the second is a slice of ``xs`` along its leading axis, and that
|
117
121
|
``f`` returns a pair where the first element represents a new value for
|
118
122
|
the loop carry and the second represents a slice of the output.
|
119
|
-
|
123
|
+
init : Carry
|
124
|
+
An initial loop carry value of type ``c``, which can be a scalar,
|
120
125
|
array, or any pytree (nested Python tuple/list/dict) thereof, representing
|
121
126
|
the initial loop carry value. This value must have the same structure as
|
122
127
|
the first element of the pair returned by ``f``.
|
123
|
-
|
128
|
+
xs : X
|
129
|
+
The value of type ``[a]`` over which to scan along the leading axis,
|
124
130
|
where ``[a]`` can be an array or any pytree (nested Python
|
125
131
|
tuple/list/dict) thereof with consistent leading axis sizes.
|
126
|
-
|
132
|
+
length : int, optional
|
133
|
+
Optional integer specifying the number of loop iterations, which
|
127
134
|
must agree with the sizes of leading axes of the arrays in ``xs`` (but can
|
128
135
|
be used to perform scans where no input ``xs`` are needed).
|
129
|
-
|
136
|
+
reverse : bool, default False
|
137
|
+
Optional boolean specifying whether to run the scan iteration
|
130
138
|
forward (the default) or in reverse, equivalent to reversing the leading
|
131
139
|
axes of the arrays in both ``xs`` and in ``ys``.
|
132
|
-
|
140
|
+
unroll : int or bool, default 1
|
141
|
+
Optional positive int or bool specifying, in the underlying
|
133
142
|
operation of the scan primitive, how many scan iterations to unroll within
|
134
143
|
a single iteration of a loop. If an integer is provided, it determines how
|
135
144
|
many unrolled loop iterations to run within a single rolled iteration of
|
136
145
|
the loop. If a boolean is provided, it will determine if the loop is
|
137
146
|
completely unrolled (i.e. `unroll=True`) or left completely rolled (i.e.
|
138
147
|
`unroll=False`).
|
139
|
-
|
148
|
+
pbar : ProgressBar or int, optional
|
149
|
+
Optional :class:`~.ProgressBar` instance to display the progress
|
140
150
|
of the scan operation.
|
141
151
|
|
142
|
-
Returns
|
143
|
-
|
144
|
-
|
145
|
-
|
152
|
+
Returns
|
153
|
+
-------
|
154
|
+
tuple of (Carry, Y)
|
155
|
+
A pair of type ``(c, [b])`` where the first element represents the final
|
156
|
+
loop carry value and the second element represents the stacked outputs of
|
157
|
+
the second output of ``f`` when scanned over the leading axis of the inputs.
|
158
|
+
|
159
|
+
Examples
|
160
|
+
--------
|
161
|
+
Basic scan operation:
|
162
|
+
|
163
|
+
.. code-block:: python
|
146
164
|
|
165
|
+
>>> import brainstate
|
166
|
+
>>> import jax.numpy as jnp
|
167
|
+
>>>
|
168
|
+
>>> def step_fn(carry, x):
|
169
|
+
... return carry + x, carry * x
|
170
|
+
>>>
|
171
|
+
>>> init = 0.0
|
172
|
+
>>> xs = jnp.array([1.0, 2.0, 3.0])
|
173
|
+
>>> final_carry, ys = brainstate.transform.scan(step_fn, init, xs)
|
174
|
+
|
175
|
+
Scan with progress bar:
|
176
|
+
|
177
|
+
.. code-block:: python
|
178
|
+
|
179
|
+
>>> pbar = brainstate.transform.ProgressBar(freq=10)
|
180
|
+
>>> final_carry, ys = brainstate.transform.scan(step_fn, init, xs, pbar=pbar)
|
181
|
+
|
182
|
+
References
|
183
|
+
----------
|
147
184
|
.. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
|
148
185
|
"""
|
149
186
|
# check "f"
|
@@ -207,8 +244,9 @@ def scan(
|
|
207
244
|
# ------------------------------ #
|
208
245
|
xs_avals = [jax.core.get_aval(x) for x in xs_flat]
|
209
246
|
x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
|
210
|
-
|
211
|
-
|
247
|
+
args = [init, xs_tree.unflatten(x_avals)]
|
248
|
+
stateful_fun = StatefulFunction(f, name='scan').make_jaxpr(*args)
|
249
|
+
state_trace = stateful_fun.get_state_trace(*args)
|
212
250
|
all_writen_state_vals = state_trace.get_write_state_values(True)
|
213
251
|
all_read_state_vals = state_trace.get_read_state_values(True)
|
214
252
|
wrapped_f = wrap_single_fun(stateful_fun, state_trace.been_writen, all_read_state_vals)
|
@@ -230,13 +268,14 @@ def scan(
|
|
230
268
|
unroll=unroll
|
231
269
|
)
|
232
270
|
# assign the written state values and restore the read state values
|
233
|
-
|
271
|
+
state_trace.assign_state_vals_v2(all_read_state_vals, all_writen_state_vals)
|
234
272
|
# carry
|
235
273
|
if has_pbar:
|
236
274
|
carry = carry[1]
|
237
275
|
return carry, ys
|
238
276
|
|
239
277
|
|
278
|
+
@set_module_as('brainstate.transform')
|
240
279
|
def checkpointed_scan(
|
241
280
|
f: Callable[[Carry, X], Tuple[Carry, Y]],
|
242
281
|
init: Carry,
|
@@ -249,30 +288,63 @@ def checkpointed_scan(
|
|
249
288
|
Scan a function over leading array axes while carrying along state.
|
250
289
|
This function is similar to :func:`~scan` but with a checkpointed version.
|
251
290
|
|
252
|
-
|
253
|
-
|
291
|
+
Parameters
|
292
|
+
----------
|
293
|
+
f : callable
|
294
|
+
A Python function to be scanned of type ``c -> a -> (c, b)``, meaning
|
254
295
|
that ``f`` accepts two arguments where the first is a value of the loop
|
255
296
|
carry and the second is a slice of ``xs`` along its leading axis, and that
|
256
297
|
``f`` returns a pair where the first element represents a new value for
|
257
298
|
the loop carry and the second represents a slice of the output.
|
258
|
-
|
299
|
+
init : Carry
|
300
|
+
An initial loop carry value of type ``c``, which can be a scalar,
|
259
301
|
array, or any pytree (nested Python tuple/list/dict) thereof, representing
|
260
302
|
the initial loop carry value. This value must have the same structure as
|
261
303
|
the first element of the pair returned by ``f``.
|
262
|
-
|
304
|
+
xs : X
|
305
|
+
The value of type ``[a]`` over which to scan along the leading axis,
|
263
306
|
where ``[a]`` can be an array or any pytree (nested Python
|
264
307
|
tuple/list/dict) thereof with consistent leading axis sizes.
|
265
|
-
|
308
|
+
length : int, optional
|
309
|
+
Optional integer specifying the number of loop iterations, which
|
266
310
|
must agree with the sizes of leading axes of the arrays in ``xs`` (but can
|
267
311
|
be used to perform scans where no input ``xs`` are needed).
|
268
|
-
|
269
|
-
|
312
|
+
base : int, default 16
|
313
|
+
Optional integer specifying the base for the bounded scan loop.
|
314
|
+
pbar : ProgressBar or int, optional
|
315
|
+
Optional :class:`~.ProgressBar` instance to display the progress
|
270
316
|
of the scan operation.
|
271
317
|
|
272
|
-
Returns
|
273
|
-
|
274
|
-
|
275
|
-
|
318
|
+
Returns
|
319
|
+
-------
|
320
|
+
tuple of (Carry, Y)
|
321
|
+
A pair of type ``(c, [b])`` where the first element represents the final
|
322
|
+
loop carry value and the second element represents the stacked outputs of
|
323
|
+
the second output of ``f`` when scanned over the leading axis of the inputs.
|
324
|
+
|
325
|
+
Examples
|
326
|
+
--------
|
327
|
+
Basic checkpointed scan operation:
|
328
|
+
|
329
|
+
.. code-block:: python
|
330
|
+
|
331
|
+
>>> import brainstate
|
332
|
+
>>> import jax.numpy as jnp
|
333
|
+
>>>
|
334
|
+
>>> def step_fn(carry, x):
|
335
|
+
... return carry + x, carry * x
|
336
|
+
>>>
|
337
|
+
>>> init = 0.0
|
338
|
+
>>> xs = jnp.array([1.0, 2.0, 3.0])
|
339
|
+
>>> final_carry, ys = brainstate.transform.checkpointed_scan(step_fn, init, xs)
|
340
|
+
|
341
|
+
Using custom base for checkpointing:
|
342
|
+
|
343
|
+
.. code-block:: python
|
344
|
+
|
345
|
+
>>> final_carry, ys = brainstate.transform.checkpointed_scan(
|
346
|
+
... step_fn, init, xs, base=8
|
347
|
+
... )
|
276
348
|
"""
|
277
349
|
# check "f"
|
278
350
|
if not callable(f):
|
@@ -311,15 +383,17 @@ def checkpointed_scan(
|
|
311
383
|
# evaluate jaxpr
|
312
384
|
xs_avals = [jax.core.get_aval(x) for x in xs_flat]
|
313
385
|
x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
|
314
|
-
|
315
|
-
|
386
|
+
args = (init, xs_tree.unflatten(x_avals))
|
387
|
+
stateful_fun = StatefulFunction(f, name='checkpoint_scan').make_jaxpr(*args)
|
388
|
+
state_trace = stateful_fun.get_state_trace(*args)
|
389
|
+
cache_key = stateful_fun.get_arg_cache_key(*args)
|
316
390
|
# get all states
|
317
391
|
been_written = state_trace.been_writen
|
318
392
|
read_state_vals = state_trace.get_read_state_values(True)
|
319
393
|
write_state_vals = state_trace.get_write_state_values(True)
|
320
394
|
|
321
395
|
# initialize the collected values/data
|
322
|
-
out_info = stateful_fun.
|
396
|
+
out_info = stateful_fun.get_out_shapes_by_cache(cache_key)[0]
|
323
397
|
assert len(out_info) == 2, "function in checkpointed_scan should return two data: carray and out."
|
324
398
|
data2collection = jax.tree.map(lambda x: jnp.zeros((length,) + x.shape, x.dtype), out_info[1])
|
325
399
|
del out_info
|
@@ -375,7 +449,7 @@ def checkpointed_scan(
|
|
375
449
|
)
|
376
450
|
)
|
377
451
|
# assign the written state values and restore the read state values
|
378
|
-
|
452
|
+
state_trace.assign_state_vals_v2(read_state_vals, write_state_vals)
|
379
453
|
del write_state_vals, read_state_vals, stateful_fun
|
380
454
|
return carry, data2collection
|
381
455
|
|
@@ -388,7 +462,7 @@ def _forloop_to_scan_fun(f: Callable):
|
|
388
462
|
return scan_fun
|
389
463
|
|
390
464
|
|
391
|
-
@set_module_as('brainstate.
|
465
|
+
@set_module_as('brainstate.transform')
|
392
466
|
def for_loop(
|
393
467
|
f: Callable[..., Y],
|
394
468
|
*xs,
|
@@ -400,35 +474,69 @@ def for_loop(
|
|
400
474
|
"""
|
401
475
|
``for-loop`` control flow with :py:class:`~.State`.
|
402
476
|
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
the
|
409
|
-
|
410
|
-
|
477
|
+
Parameters
|
478
|
+
----------
|
479
|
+
f : callable
|
480
|
+
A Python function to be looped over that accepts variadic arguments
|
481
|
+
corresponding to slices of ``xs`` along their leading axes, and returns
|
482
|
+
the output for that iteration.
|
483
|
+
*xs
|
484
|
+
The values over which to loop along the leading axis,
|
485
|
+
where each can be an array or any pytree (nested Python
|
411
486
|
tuple/list/dict) thereof with consistent leading axis sizes.
|
412
|
-
|
487
|
+
length : int, optional
|
488
|
+
Optional integer specifying the number of loop iterations, which
|
413
489
|
must agree with the sizes of leading axes of the arrays in ``xs`` (but can
|
414
|
-
be used to perform
|
415
|
-
|
490
|
+
be used to perform loops where no input ``xs`` are needed).
|
491
|
+
reverse : bool, default False
|
492
|
+
Optional boolean specifying whether to run the loop iteration
|
416
493
|
forward (the default) or in reverse, equivalent to reversing the leading
|
417
494
|
axes of the arrays in both ``xs`` and in ``ys``.
|
418
|
-
|
419
|
-
|
495
|
+
unroll : int or bool, default 1
|
496
|
+
Optional positive int or bool specifying, in the underlying
|
497
|
+
operation of the scan primitive, how many loop iterations to unroll within
|
420
498
|
a single iteration of a loop. If an integer is provided, it determines how
|
421
499
|
many unrolled loop iterations to run within a single rolled iteration of
|
422
500
|
the loop. If a boolean is provided, it will determine if the loop is
|
423
501
|
completely unrolled (i.e. `unroll=True`) or left completely rolled (i.e.
|
424
502
|
`unroll=False`).
|
425
|
-
|
426
|
-
|
503
|
+
pbar : ProgressBar or int, optional
|
504
|
+
Optional :class:`~.ProgressBar` instance to display the progress
|
505
|
+
of the loop operation.
|
506
|
+
|
507
|
+
Returns
|
508
|
+
-------
|
509
|
+
Y
|
510
|
+
The stacked outputs of ``f`` when looped over the leading axis of the inputs.
|
511
|
+
|
512
|
+
Examples
|
513
|
+
--------
|
514
|
+
Basic for-loop operation:
|
515
|
+
|
516
|
+
.. code-block:: python
|
517
|
+
|
518
|
+
>>> import brainstate
|
519
|
+
>>> import jax.numpy as jnp
|
520
|
+
>>>
|
521
|
+
>>> def process_item(x, y):
|
522
|
+
... return x * y + 1
|
523
|
+
>>>
|
524
|
+
>>> xs = jnp.array([1.0, 2.0, 3.0])
|
525
|
+
>>> ys = jnp.array([4.0, 5.0, 6.0])
|
526
|
+
>>> results = brainstate.transform.for_loop(process_item, xs, ys)
|
527
|
+
|
528
|
+
For-loop with progress bar:
|
529
|
+
|
530
|
+
.. code-block:: python
|
531
|
+
|
532
|
+
>>> pbar = brainstate.transform.ProgressBar(freq=10)
|
533
|
+
>>> results = brainstate.transform.for_loop(process_item, xs, ys, pbar=pbar)
|
427
534
|
|
428
|
-
|
429
|
-
The return represents the stacked outputs of the second output of ``f``
|
430
|
-
when scanned over the leading axis of the inputs.
|
535
|
+
For-loop with reverse iteration:
|
431
536
|
|
537
|
+
.. code-block:: python
|
538
|
+
|
539
|
+
>>> results = brainstate.transform.for_loop(process_item, xs, ys, reverse=True)
|
432
540
|
"""
|
433
541
|
_, ys = scan(
|
434
542
|
_forloop_to_scan_fun(f),
|
@@ -442,6 +550,7 @@ def for_loop(
|
|
442
550
|
return ys
|
443
551
|
|
444
552
|
|
553
|
+
@set_module_as('brainstate.transform')
|
445
554
|
def checkpointed_for_loop(
|
446
555
|
f: Callable[..., Y],
|
447
556
|
*xs: X,
|
@@ -452,25 +561,54 @@ def checkpointed_for_loop(
|
|
452
561
|
"""
|
453
562
|
``for-loop`` control flow with :py:class:`~.State` with a checkpointed version, similar to :py:func:`for_loop`.
|
454
563
|
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
the
|
461
|
-
|
462
|
-
|
564
|
+
Parameters
|
565
|
+
----------
|
566
|
+
f : callable
|
567
|
+
A Python function to be looped over that accepts variadic arguments
|
568
|
+
corresponding to slices of ``xs`` along their leading axes, and returns
|
569
|
+
the output for that iteration.
|
570
|
+
*xs : X
|
571
|
+
The values over which to loop along the leading axis,
|
572
|
+
where each can be an array or any pytree (nested Python
|
463
573
|
tuple/list/dict) thereof with consistent leading axis sizes.
|
464
|
-
|
574
|
+
length : int, optional
|
575
|
+
Optional integer specifying the number of loop iterations, which
|
465
576
|
must agree with the sizes of leading axes of the arrays in ``xs`` (but can
|
466
|
-
be used to perform
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
-
|
472
|
-
|
473
|
-
|
577
|
+
be used to perform loops where no input ``xs`` are needed).
|
578
|
+
base : int, default 16
|
579
|
+
Optional integer specifying the base for the bounded loop.
|
580
|
+
pbar : ProgressBar or int, optional
|
581
|
+
Optional :class:`~.ProgressBar` instance to display the progress
|
582
|
+
of the loop operation.
|
583
|
+
|
584
|
+
Returns
|
585
|
+
-------
|
586
|
+
Y
|
587
|
+
The stacked outputs of ``f`` when looped over the leading axis of the inputs.
|
588
|
+
|
589
|
+
Examples
|
590
|
+
--------
|
591
|
+
Basic checkpointed for-loop operation:
|
592
|
+
|
593
|
+
.. code-block:: python
|
594
|
+
|
595
|
+
>>> import brainstate
|
596
|
+
>>> import jax.numpy as jnp
|
597
|
+
>>>
|
598
|
+
>>> def process_item(x, y):
|
599
|
+
... return x * y + 1
|
600
|
+
>>>
|
601
|
+
>>> xs = jnp.array([1.0, 2.0, 3.0])
|
602
|
+
>>> ys = jnp.array([4.0, 5.0, 6.0])
|
603
|
+
>>> results = brainstate.transform.checkpointed_for_loop(process_item, xs, ys)
|
604
|
+
|
605
|
+
Using custom base for checkpointing:
|
606
|
+
|
607
|
+
.. code-block:: python
|
608
|
+
|
609
|
+
>>> results = brainstate.transform.checkpointed_for_loop(
|
610
|
+
... process_item, xs, ys, base=8
|
611
|
+
... )
|
474
612
|
"""
|
475
613
|
_, ys = checkpointed_scan(
|
476
614
|
_forloop_to_scan_fun(f),
|
@@ -483,7 +621,8 @@ def checkpointed_for_loop(
|
|
483
621
|
return ys
|
484
622
|
|
485
623
|
|
486
|
-
# This function is adapted from ``while_loop`` in
|
624
|
+
# This function is adapted from ``while_loop`` in
|
625
|
+
# `equinox <https://github.com/patrick-kidger/equinox/blob/main/equinox/internal/_loop/loop.py>`_.
|
487
626
|
|
488
627
|
# There are several tricks happening here to work around various limitations of JAX.
|
489
628
|
# (Also see https://github.com/google/jax/issues/2139#issuecomment-1039293633)
|