brainstate 0.1.7__py2.py3-none-any.whl → 0.1.9__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133) hide show
  1. brainstate/__init__.py +58 -51
  2. brainstate/_compatible_import.py +148 -148
  3. brainstate/_state.py +1605 -1663
  4. brainstate/_state_test.py +52 -52
  5. brainstate/_utils.py +47 -47
  6. brainstate/augment/__init__.py +30 -30
  7. brainstate/augment/_autograd.py +778 -778
  8. brainstate/augment/_autograd_test.py +1289 -1289
  9. brainstate/augment/_eval_shape.py +99 -99
  10. brainstate/augment/_eval_shape_test.py +38 -38
  11. brainstate/augment/_mapping.py +1060 -1060
  12. brainstate/augment/_mapping_test.py +597 -597
  13. brainstate/augment/_random.py +151 -151
  14. brainstate/compile/__init__.py +38 -38
  15. brainstate/compile/_ad_checkpoint.py +204 -204
  16. brainstate/compile/_ad_checkpoint_test.py +49 -49
  17. brainstate/compile/_conditions.py +256 -256
  18. brainstate/compile/_conditions_test.py +220 -220
  19. brainstate/compile/_error_if.py +92 -92
  20. brainstate/compile/_error_if_test.py +52 -52
  21. brainstate/compile/_jit.py +346 -346
  22. brainstate/compile/_jit_test.py +143 -143
  23. brainstate/compile/_loop_collect_return.py +536 -536
  24. brainstate/compile/_loop_collect_return_test.py +58 -58
  25. brainstate/compile/_loop_no_collection.py +184 -184
  26. brainstate/compile/_loop_no_collection_test.py +50 -50
  27. brainstate/compile/_make_jaxpr.py +888 -888
  28. brainstate/compile/_make_jaxpr_test.py +156 -146
  29. brainstate/compile/_progress_bar.py +202 -202
  30. brainstate/compile/_unvmap.py +159 -159
  31. brainstate/compile/_util.py +147 -147
  32. brainstate/environ.py +563 -563
  33. brainstate/environ_test.py +62 -62
  34. brainstate/functional/__init__.py +27 -26
  35. brainstate/graph/__init__.py +29 -29
  36. brainstate/graph/_graph_node.py +244 -244
  37. brainstate/graph/_graph_node_test.py +73 -73
  38. brainstate/graph/_graph_operation.py +1738 -1738
  39. brainstate/graph/_graph_operation_test.py +563 -563
  40. brainstate/init/__init__.py +26 -26
  41. brainstate/init/_base.py +52 -52
  42. brainstate/init/_generic.py +244 -244
  43. brainstate/init/_random_inits.py +553 -553
  44. brainstate/init/_random_inits_test.py +149 -149
  45. brainstate/init/_regular_inits.py +105 -105
  46. brainstate/init/_regular_inits_test.py +50 -50
  47. brainstate/mixin.py +365 -363
  48. brainstate/mixin_test.py +77 -73
  49. brainstate/nn/__init__.py +135 -131
  50. brainstate/{functional → nn}/_activations.py +808 -813
  51. brainstate/{functional → nn}/_activations_test.py +331 -331
  52. brainstate/nn/_collective_ops.py +514 -514
  53. brainstate/nn/_collective_ops_test.py +43 -43
  54. brainstate/nn/_common.py +178 -178
  55. brainstate/nn/_conv.py +501 -501
  56. brainstate/nn/_conv_test.py +238 -238
  57. brainstate/nn/_delay.py +509 -470
  58. brainstate/nn/_delay_test.py +238 -0
  59. brainstate/nn/_dropout.py +426 -426
  60. brainstate/nn/_dropout_test.py +100 -100
  61. brainstate/nn/_dynamics.py +1343 -1361
  62. brainstate/nn/_dynamics_test.py +78 -78
  63. brainstate/nn/_elementwise.py +1119 -1120
  64. brainstate/nn/_elementwise_test.py +169 -169
  65. brainstate/nn/_embedding.py +58 -58
  66. brainstate/nn/_exp_euler.py +92 -92
  67. brainstate/nn/_exp_euler_test.py +35 -35
  68. brainstate/nn/_fixedprob.py +239 -239
  69. brainstate/nn/_fixedprob_test.py +114 -114
  70. brainstate/nn/_inputs.py +608 -608
  71. brainstate/nn/_linear.py +424 -424
  72. brainstate/nn/_linear_mv.py +83 -83
  73. brainstate/nn/_linear_mv_test.py +120 -120
  74. brainstate/nn/_linear_test.py +107 -107
  75. brainstate/nn/_ltp.py +28 -28
  76. brainstate/nn/_module.py +377 -377
  77. brainstate/nn/_module_test.py +40 -208
  78. brainstate/nn/_neuron.py +705 -705
  79. brainstate/nn/_neuron_test.py +161 -161
  80. brainstate/nn/_normalizations.py +975 -918
  81. brainstate/nn/_normalizations_test.py +73 -73
  82. brainstate/{functional → nn}/_others.py +46 -46
  83. brainstate/nn/_poolings.py +1177 -1177
  84. brainstate/nn/_poolings_test.py +217 -217
  85. brainstate/nn/_projection.py +486 -486
  86. brainstate/nn/_rate_rnns.py +554 -554
  87. brainstate/nn/_rate_rnns_test.py +63 -63
  88. brainstate/nn/_readout.py +209 -209
  89. brainstate/nn/_readout_test.py +53 -53
  90. brainstate/nn/_stp.py +236 -236
  91. brainstate/nn/_synapse.py +505 -505
  92. brainstate/nn/_synapse_test.py +131 -131
  93. brainstate/nn/_synaptic_projection.py +423 -423
  94. brainstate/nn/_synouts.py +162 -162
  95. brainstate/nn/_synouts_test.py +57 -57
  96. brainstate/nn/_utils.py +89 -89
  97. brainstate/nn/metrics.py +388 -388
  98. brainstate/optim/__init__.py +38 -38
  99. brainstate/optim/_base.py +64 -64
  100. brainstate/optim/_lr_scheduler.py +448 -448
  101. brainstate/optim/_lr_scheduler_test.py +50 -50
  102. brainstate/optim/_optax_optimizer.py +152 -152
  103. brainstate/optim/_optax_optimizer_test.py +53 -53
  104. brainstate/optim/_sgd_optimizer.py +1104 -1104
  105. brainstate/random/__init__.py +24 -24
  106. brainstate/random/_rand_funs.py +3616 -3616
  107. brainstate/random/_rand_funs_test.py +567 -567
  108. brainstate/random/_rand_seed.py +210 -210
  109. brainstate/random/_rand_seed_test.py +48 -48
  110. brainstate/random/_rand_state.py +1409 -1409
  111. brainstate/random/_random_for_unit.py +52 -52
  112. brainstate/surrogate.py +1957 -1957
  113. brainstate/transform.py +23 -23
  114. brainstate/typing.py +304 -304
  115. brainstate/util/__init__.py +50 -50
  116. brainstate/util/caller.py +98 -98
  117. brainstate/util/error.py +55 -55
  118. brainstate/util/filter.py +469 -469
  119. brainstate/util/others.py +540 -540
  120. brainstate/util/pretty_pytree.py +945 -945
  121. brainstate/util/pretty_pytree_test.py +159 -159
  122. brainstate/util/pretty_repr.py +328 -328
  123. brainstate/util/pretty_table.py +2954 -2954
  124. brainstate/util/scaling.py +258 -258
  125. brainstate/util/struct.py +523 -523
  126. {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/METADATA +91 -99
  127. brainstate-0.1.9.dist-info/RECORD +130 -0
  128. {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/WHEEL +1 -1
  129. {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info/licenses}/LICENSE +202 -202
  130. brainstate/functional/_normalization.py +0 -81
  131. brainstate/functional/_spikes.py +0 -204
  132. brainstate-0.1.7.dist-info/RECORD +0 -131
  133. {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/top_level.txt +0 -0
@@ -1,536 +1,536 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- import math
17
- from functools import wraps
18
- from typing import Callable, Optional, TypeVar, Tuple, Any
19
-
20
- import jax
21
- import jax.numpy as jnp
22
-
23
- from brainstate._utils import set_module_as
24
- from ._make_jaxpr import StatefulFunction
25
- from ._progress_bar import ProgressBar
26
- from ._unvmap import unvmap
27
- from ._util import write_back_state_values, wrap_single_fun
28
-
29
- __all__ = [
30
- # "scan" syntax, which is similar to jax.lax.scan
31
- 'scan', 'checkpointed_scan',
32
-
33
- # "for_loop" syntax
34
- 'for_loop', 'checkpointed_for_loop',
35
- ]
36
-
37
- X = TypeVar('X')
38
- Y = TypeVar('Y')
39
- T = TypeVar('T')
40
- Carry = TypeVar('Carry')
41
-
42
-
43
- def _wrap_fun_with_pbar(
44
- fun: Callable[[Carry, X], Tuple[Carry, Y]],
45
- pbar_runner: Callable
46
- ):
47
- @wraps(fun)
48
- def new_fun(new_carry, inputs):
49
- i, old_carry = new_carry
50
- new_carry, new_outputs = fun(old_carry, inputs)
51
- pbar_runner(unvmap(i, op='none'), carry=new_carry, y=new_outputs)
52
- return (i + 1, new_carry), new_outputs
53
-
54
- return new_fun
55
-
56
-
57
- @set_module_as('brainstate.compile')
58
- def scan(
59
- f: Callable[[Carry, X], Tuple[Carry, Y]],
60
- init: Carry,
61
- xs: X,
62
- length: int | None = None,
63
- reverse: bool = False,
64
- unroll: int | bool = 1,
65
- pbar: ProgressBar | int | None = None,
66
- ) -> Tuple[Carry, Y]:
67
- """
68
- Scan a function over leading array axes while carrying along state.
69
-
70
- The `Haskell-like type signature`_ in brief is
71
-
72
- .. code-block:: haskell
73
-
74
- scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
75
-
76
- where for any array type specifier ``t``, ``[t]`` represents the type with an additional
77
- leading axis, and if ``t`` is a pytree (container) type with array leaves then ``[t]``
78
- represents the type with the same pytree structure and corresponding leaves
79
- each with an additional leading axis.
80
-
81
- When the type of ``xs`` (denoted `a` above) is an array type or None, and the type
82
- of ``ys`` (denoted `b` above) is an array type, the semantics of :func:`~scan` are
83
- given roughly by this Python implementation::
84
-
85
- def scan(f, init, xs, length=None):
86
- if xs is None:
87
- xs = [None] * length
88
- carry = init
89
- ys = []
90
- for x in xs:
91
- carry, y = f(carry, x)
92
- ys.append(y)
93
- return carry, np.stack(ys)
94
-
95
- Unlike that Python version, both ``xs`` and ``ys`` may be arbitrary pytree
96
- values, and so multiple arrays can be scanned over at once and produce multiple
97
- output arrays. ``None`` is actually a special case of this, as it represents an
98
- empty pytree.
99
-
100
- Also unlike that Python version, :func:`~scan` is a JAX primitive and is
101
- lowered to a single WhileOp. That makes it useful for reducing
102
- compilation times for JIT-compiled functions, since native Python
103
- loop constructs in an :func:`~jax.jit` function are unrolled, leading to large
104
- XLA computations.
105
-
106
- Finally, the loop-carried value ``carry`` must hold a fixed shape and dtype
107
- across all iterations (and not just be consistent up to NumPy rank/shape
108
- broadcasting and dtype promotion rules, for example). In other words, the type
109
- ``c`` in the type signature above represents an array with a fixed shape and
110
- dtype (or a nested tuple/list/dict container data structure with a fixed
111
- structure and arrays with fixed shape and dtype at the leaves).
112
-
113
- Args:
114
- f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
115
- that ``f`` accepts two arguments where the first is a value of the loop
116
- carry and the second is a slice of ``xs`` along its leading axis, and that
117
- ``f`` returns a pair where the first element represents a new value for
118
- the loop carry and the second represents a slice of the output.
119
- init: an initial loop carry value of type ``c``, which can be a scalar,
120
- array, or any pytree (nested Python tuple/list/dict) thereof, representing
121
- the initial loop carry value. This value must have the same structure as
122
- the first element of the pair returned by ``f``.
123
- xs: the value of type ``[a]`` over which to scan along the leading axis,
124
- where ``[a]`` can be an array or any pytree (nested Python
125
- tuple/list/dict) thereof with consistent leading axis sizes.
126
- length: optional integer specifying the number of loop iterations, which
127
- must agree with the sizes of leading axes of the arrays in ``xs`` (but can
128
- be used to perform scans where no input ``xs`` are needed).
129
- reverse: optional boolean specifying whether to run the scan iteration
130
- forward (the default) or in reverse, equivalent to reversing the leading
131
- axes of the arrays in both ``xs`` and in ``ys``.
132
- unroll: optional positive int or bool specifying, in the underlying
133
- operation of the scan primitive, how many scan iterations to unroll within
134
- a single iteration of a loop. If an integer is provided, it determines how
135
- many unrolled loop iterations to run within a single rolled iteration of
136
- the loop. If a boolean is provided, it will determine if the loop is
137
- completely unrolled (i.e. `unroll=True`) or left completely rolled (i.e.
138
- `unroll=False`).
139
- pbar: optional :class:`~.ProgressBar` instance to display the progress
140
- of the scan operation.
141
-
142
- Returns:
143
- A pair of type ``(c, [b])`` where the first element represents the final
144
- loop carry value and the second element represents the stacked outputs of
145
- the second output of ``f`` when scanned over the leading axis of the inputs.
146
-
147
- .. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
148
- """
149
- # check "f"
150
- if not callable(f):
151
- raise TypeError("f argument should be a callable.")
152
-
153
- # check "xs"
154
- xs_flat, xs_tree = jax.tree.flatten(xs)
155
- try:
156
- lengths = [x.shape[0] for x in xs_flat]
157
- except AttributeError as err:
158
- raise ValueError("scan got value with no leading axis to scan over: "
159
- "{}.".format(', '.join(str(x) for x in xs_flat if not hasattr(x, 'shape')))) from err
160
- if length is not None:
161
- length = int(length)
162
- if not all(length == l for l in lengths):
163
- raise ValueError(("scan got `length` argument of {} which disagrees with "
164
- "leading axis sizes {}.").format(length, [x.shape[0] for x in xs_flat]))
165
- else:
166
- unique_lengths = set(lengths)
167
- if len(unique_lengths) > 1:
168
- msg = "scan got values with different leading axis sizes: {}."
169
- raise ValueError(msg.format(', '.join(str(x.shape[0]) for x in xs_flat)))
170
- elif len(unique_lengths) == 0:
171
- raise ValueError("scan got no values to scan over and `length` not provided.")
172
- else:
173
- length, = unique_lengths
174
-
175
- # function with progress bar
176
- has_pbar = False
177
- if pbar is not None:
178
- has_pbar = True
179
- if isinstance(pbar, ProgressBar):
180
- pbar_runner = pbar.init(length)
181
- elif isinstance(pbar, int):
182
- pbar_runner = ProgressBar(freq=pbar).init(length)
183
- else:
184
- raise TypeError("pbar argument should be a ProgressBar instance or an integer.")
185
- f = _wrap_fun_with_pbar(f, pbar_runner)
186
- init = (0, init) if pbar else init
187
-
188
- # not jit
189
- if jax.config.jax_disable_jit:
190
- if length == 0:
191
- raise ValueError(
192
- "zero-length scan is not supported in disable_jit() mode because the output type is unknown.")
193
- carry = init
194
- ys = []
195
- maybe_reversed = reversed if reverse else lambda x: x
196
- for i in maybe_reversed(range(length)):
197
- xs_slice = [jax.lax.index_in_dim(x, i, keepdims=False) for x in xs_flat]
198
- carry, y = f(carry, jax.tree.unflatten(xs_tree, xs_slice))
199
- ys.append(y)
200
- stacked_y = jax.tree.map(lambda *elems: jnp.stack(elems), *maybe_reversed(ys))
201
- if has_pbar:
202
- return carry[1], stacked_y
203
- else:
204
- return carry, stacked_y
205
-
206
- # evaluate jaxpr, get all states #
207
- # ------------------------------ #
208
- xs_avals = [jax.core.get_aval(x) for x in xs_flat]
209
- x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
210
- stateful_fun = StatefulFunction(f, name='scan').make_jaxpr(init, xs_tree.unflatten(x_avals))
211
- state_trace = stateful_fun.get_state_trace()
212
- all_writen_state_vals = state_trace.get_write_state_values(True)
213
- all_read_state_vals = state_trace.get_read_state_values(True)
214
- wrapped_f = wrap_single_fun(stateful_fun, state_trace.been_writen, all_read_state_vals)
215
-
216
- # scan
217
- init = (all_writen_state_vals, init)
218
- (
219
- (
220
- all_writen_state_vals,
221
- carry
222
- ),
223
- ys
224
- ) = jax.lax.scan(
225
- wrapped_f,
226
- init,
227
- xs,
228
- length=length,
229
- reverse=reverse,
230
- unroll=unroll
231
- )
232
- # assign the written state values and restore the read state values
233
- write_back_state_values(state_trace, all_read_state_vals, all_writen_state_vals)
234
- # carry
235
- if has_pbar:
236
- carry = carry[1]
237
- return carry, ys
238
-
239
-
240
- def checkpointed_scan(
241
- f: Callable[[Carry, X], Tuple[Carry, Y]],
242
- init: Carry,
243
- xs: X,
244
- length: Optional[int] = None,
245
- base: int = 16,
246
- pbar: Optional[ProgressBar | int] = None,
247
- ) -> Tuple[Carry, Y]:
248
- """
249
- Scan a function over leading array axes while carrying along state.
250
- This function is similar to :func:`~scan` but with a checkpointed version.
251
-
252
- Args:
253
- f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
254
- that ``f`` accepts two arguments where the first is a value of the loop
255
- carry and the second is a slice of ``xs`` along its leading axis, and that
256
- ``f`` returns a pair where the first element represents a new value for
257
- the loop carry and the second represents a slice of the output.
258
- init: an initial loop carry value of type ``c``, which can be a scalar,
259
- array, or any pytree (nested Python tuple/list/dict) thereof, representing
260
- the initial loop carry value. This value must have the same structure as
261
- the first element of the pair returned by ``f``.
262
- xs: the value of type ``[a]`` over which to scan along the leading axis,
263
- where ``[a]`` can be an array or any pytree (nested Python
264
- tuple/list/dict) thereof with consistent leading axis sizes.
265
- length: optional integer specifying the number of loop iterations, which
266
- must agree with the sizes of leading axes of the arrays in ``xs`` (but can
267
- be used to perform scans where no input ``xs`` are needed).
268
- base: optional integer specifying the base for the bounded scan loop.
269
- pbar: optional :class:`~.ProgressBar` instance to display the progress
270
- of the scan operation.
271
-
272
- Returns:
273
- A pair of type ``(c, [b])`` where the first element represents the final
274
- loop carry value and the second element represents the stacked outputs of
275
- the second output of ``f`` when scanned over the leading axis of the inputs.
276
- """
277
- # check "f"
278
- if not callable(f):
279
- raise TypeError("f argument should be a callable.")
280
-
281
- # check "xs"
282
- xs_flat, xs_tree = jax.tree.flatten(xs)
283
- try:
284
- lengths = [x.shape[0] for x in xs_flat]
285
- except AttributeError as err:
286
- raise ValueError("scan got value with no leading axis to scan over: "
287
- "{}.".format(', '.join(str(x) for x in xs_flat if not hasattr(x, 'shape')))) from err
288
- if length is not None:
289
- length = int(length)
290
- if not all(length == l for l in lengths):
291
- raise ValueError(("scan got `length` argument of {} which disagrees with "
292
- "leading axis sizes {}.").format(length, [x.shape[0] for x in xs_flat]))
293
- else:
294
- unique_lengths = set(lengths)
295
- if len(unique_lengths) > 1:
296
- msg = "scan got values with different leading axis sizes: {}."
297
- raise ValueError(msg.format(', '.join(str(x.shape[0]) for x in xs_flat)))
298
- elif len(unique_lengths) == 0:
299
- raise ValueError("scan got no values to scan over and `length` not provided.")
300
- else:
301
- length, = unique_lengths
302
-
303
- # function with progress bar
304
- if isinstance(pbar, ProgressBar):
305
- pbar_runner = pbar.init(length)
306
- elif isinstance(pbar, int):
307
- pbar_runner = ProgressBar(freq=pbar).init(length)
308
- else:
309
- pbar_runner = None
310
-
311
- # evaluate jaxpr
312
- xs_avals = [jax.core.get_aval(x) for x in xs_flat]
313
- x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
314
- stateful_fun = StatefulFunction(f, name='checkpoint_scan').make_jaxpr(init, xs_tree.unflatten(x_avals))
315
- state_trace = stateful_fun.get_state_trace()
316
- # get all states
317
- been_written = state_trace.been_writen
318
- read_state_vals = state_trace.get_read_state_values(True)
319
- write_state_vals = state_trace.get_write_state_values(True)
320
-
321
- # initialize the collected values/data
322
- out_info = stateful_fun.get_out_shapes()[0]
323
- assert len(out_info) == 2, "function in checkpointed_scan should return two data: carray and out."
324
- data2collection = jax.tree.map(lambda x: jnp.zeros((length,) + x.shape, x.dtype), out_info[1])
325
- del out_info
326
-
327
- def wrapped_cond_fun(inp):
328
- return inp[-1] < length
329
-
330
- def wrapped_body_fun(inp):
331
- (prev_write_states, carray), prev_collect, i = inp
332
- # progress bar
333
- if pbar_runner is not None:
334
- pbar_runner(unvmap(i, op='none'))
335
- # call the function
336
- prev_states = [w_val if write else r_val
337
- for write, w_val, r_val in zip(been_written, prev_write_states, read_state_vals)]
338
- new_states, (new_carray, out4updates) = stateful_fun.jaxpr_call(
339
- prev_states, carray, jax.tree.map(lambda x: x[i], xs)
340
- )
341
- # new written states
342
- new_write_states = tuple([val if write else None for write, val in zip(been_written, new_states)])
343
-
344
- # out of length
345
- pred = i < length
346
- new_collect = jax.tree.map(
347
- # lambda x, update: x.at[i].set(jax.lax.select(pred, update, x[i])),
348
- lambda x, update: jax.lax.select(pred, x.at[i].set(update), x),
349
- prev_collect,
350
- out4updates,
351
- )
352
- new_write_states = jax.tree.map(
353
- lambda ps, ns: None if ns is None else jax.lax.select(pred, ns, ps),
354
- prev_write_states,
355
- new_write_states,
356
- is_leaf=lambda x: x is None
357
- )
358
- new_carray = jax.tree.map(
359
- lambda pc, nc: jax.lax.select(pred, nc, pc),
360
- carray,
361
- new_carray,
362
- )
363
- return (new_write_states, new_carray), new_collect, i + 1
364
-
365
- # while_loop
366
- rounded_max_steps = base ** int(math.ceil(math.log(length, base)))
367
- (write_state_vals, carry), data2collection, _ = (
368
- _bounded_while_loop(
369
- wrapped_cond_fun,
370
- wrapped_body_fun,
371
- ((write_state_vals, init), data2collection, 0),
372
- rounded_max_steps,
373
- base,
374
- pbar_runner
375
- )
376
- )
377
- # assign the written state values and restore the read state values
378
- write_back_state_values(state_trace, read_state_vals, write_state_vals)
379
- del write_state_vals, read_state_vals, stateful_fun
380
- return carry, data2collection
381
-
382
-
383
- def _forloop_to_scan_fun(f: Callable):
384
- @wraps(f)
385
- def scan_fun(carry, x):
386
- return carry, f(*x)
387
-
388
- return scan_fun
389
-
390
-
391
- @set_module_as('brainstate.compile')
392
- def for_loop(
393
- f: Callable[..., Y],
394
- *xs,
395
- length: Optional[int] = None,
396
- reverse: bool = False,
397
- unroll: int | bool = 1,
398
- pbar: Optional[ProgressBar | int] = None
399
- ) -> Y:
400
- """
401
- ``for-loop`` control flow with :py:class:`~.State`.
402
-
403
- Args:
404
- f: a Python function to be called at every iteration. It does not take
405
- a loop carry: it is called as ``f(*x)``, where ``x`` holds one slice of
406
- each positional argument in ``xs`` taken along its leading axis, and
407
- its return value is one slice of the stacked output. (Internally ``f``
408
- is wrapped into a scan body whose carry is ``None``.)
409
- xs: the value of type ``[a]`` over which to scan along the leading axis,
410
- where ``[a]`` can be an array or any pytree (nested Python
411
- tuple/list/dict) thereof with consistent leading axis sizes.
412
- length: optional integer specifying the number of loop iterations, which
413
- must agree with the sizes of leading axes of the arrays in ``xs`` (but can
414
- be used to perform scans where no input ``xs`` are needed).
415
- reverse: optional boolean specifying whether to run the scan iteration
416
- forward (the default) or in reverse, equivalent to reversing the leading
417
- axes of the arrays in both ``xs`` and in ``ys``.
418
- unroll: optional positive int or bool specifying, in the underlying
419
- operation of the scan primitive, how many scan iterations to unroll within
420
- a single iteration of a loop. If an integer is provided, it determines how
421
- many unrolled loop iterations to run within a single rolled iteration of
422
- the loop. If a boolean is provided, it will determine if the loop is
423
- completely unrolled (i.e. `unroll=True`) or left completely rolled (i.e.
424
- `unroll=False`).
425
- pbar: optional :class:`~.ProgressBar` instance to display the progress
426
- of the scan operation.
427
-
428
- Returns:
429
- The return represents the stacked outputs of ``f``
430
- when scanned over the leading axis of the inputs.
431
-
432
- """
433
- _, ys = scan(
434
- _forloop_to_scan_fun(f),
435
- init=None,
436
- xs=xs,
437
- length=length,
438
- reverse=reverse,
439
- unroll=unroll,
440
- pbar=pbar
441
- )
442
- return ys
443
-
444
-
445
- def checkpointed_for_loop(
446
- f: Callable[..., Y],
447
- *xs: X,
448
- length: Optional[int] = None,
449
- base: int = 16,
450
- pbar: Optional[ProgressBar | int] = None,
451
- ) -> Y:
452
- """
453
- ``for-loop`` control flow with :py:class:`~.State` with a checkpointed version, similar to :py:func:`for_loop`.
454
-
455
- Args:
456
- f: a Python function to be called at every iteration. It does not take
457
- a loop carry: it is called as ``f(*x)``, where ``x`` holds one slice of
458
- each positional argument in ``xs`` taken along its leading axis, and
459
- its return value is one slice of the stacked output. (Internally ``f``
460
- is wrapped into a scan body whose carry is ``None``.)
461
- xs: the value of type ``[a]`` over which to scan along the leading axis,
462
- where ``[a]`` can be an array or any pytree (nested Python
463
- tuple/list/dict) thereof with consistent leading axis sizes.
464
- length: optional integer specifying the number of loop iterations, which
465
- must agree with the sizes of leading axes of the arrays in ``xs`` (but can
466
- be used to perform scans where no input ``xs`` are needed).
467
- base: optional integer specifying the base for the bounded scan loop.
468
- pbar: optional :class:`~.ProgressBar` instance to display the progress
469
- of the scan operation.
470
-
471
- Returns:
472
- The return represents the stacked outputs of ``f``
473
- when scanned over the leading axis of the inputs.
474
- """
475
- _, ys = checkpointed_scan(
476
- _forloop_to_scan_fun(f),
477
- init=None,
478
- xs=xs,
479
- length=length,
480
- base=base,
481
- pbar=pbar
482
- )
483
- return ys
484
-
485
-
486
- # This function is adapted from ``while_loop`` in `equinox <https://github.com/patrick-kidger/equinox/blob/main/equinox/internal/_loop/loop.py>`_.
487
-
488
- # There are several tricks happening here to work around various limitations of JAX.
489
- # (Also see https://github.com/google/jax/issues/2139#issuecomment-1039293633)
490
- # 1. `unvmap_any` prior to using `lax.cond`. JAX has a problem in that vmap-of-cond
491
- # is converted to a `lax.select`, which executes both branches unconditionally.
492
- # Thus writing this naively, using a plain `lax.cond`, will mean the loop always
493
- # runs to `max_steps` when executing under vmap. Instead we run (only) until every
494
- # batch element has finished.
495
- # 2. Treating in-place updates specially in the body_fun. Specifically we need to
496
- # `lax.select` the update-to-make, not the updated buffer. This is because the
497
- # latter instead results in XLA:CPU failing to determine that the buffer can be
498
- # updated in-place, and instead it makes a copy. c.f. JAX issue #8192.
499
- # This is done through the extra `inplace` argument provided to `body_fun`.
500
- # 3. The use of the `@jax.checkpoint` decorator. Backpropagation through a
501
- # `bounded_while_loop` will otherwise run in θ(max_steps) time, rather than
502
- # θ(number of steps actually taken).
503
- # 4. The use of `base`. In theory `base=2` is optimal at run time, as it implies the
504
- # fewest superfluous operations. In practice this implies quite deep recursion in
505
- # the construction of the bounded while loop, and this slows down the jaxpr
506
- # creation and the XLA compilation. We choose `base=16` as a reasonable-looking
507
- # compromise between compilation time and run time.
508
-
509
- def _bounded_while_loop(
510
- cond_fun: Callable,
511
- body_fun: Callable,
512
- val: Any,
513
- max_steps: int,
514
- base: int,
515
- pbar_runner: Optional[Callable] = None
516
- ):
517
- if max_steps == 1:
518
- return body_fun(val)
519
- else:
520
-
521
- def true_call(val_):
522
- return _bounded_while_loop(cond_fun, body_fun, val_, max_steps // base, base, pbar_runner)
523
-
524
- def false_call(val_):
525
- if pbar_runner is not None:
526
- pbar_runner(unvmap(val_[-1] + max_steps, op='none'))
527
- return val_[:-1] + (val_[-1] + max_steps,)
528
-
529
- def scan_fn(val_, _):
530
- return jax.lax.cond(unvmap(cond_fun(val_), op='any'), true_call, false_call, val_), None
531
-
532
- # Don't put checkpointing on the lowest level
533
- if max_steps != base:
534
- scan_fn = jax.checkpoint(scan_fn, prevent_cse=False) # pyright: ignore
535
-
536
- return jax.lax.scan(scan_fn, val, xs=None, length=base)[0]
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import math
17
+ from functools import wraps
18
+ from typing import Callable, Optional, TypeVar, Tuple, Any
19
+
20
+ import jax
21
+ import jax.numpy as jnp
22
+
23
+ from brainstate._utils import set_module_as
24
+ from ._make_jaxpr import StatefulFunction
25
+ from ._progress_bar import ProgressBar
26
+ from ._unvmap import unvmap
27
+ from ._util import write_back_state_values, wrap_single_fun
28
+
29
+ __all__ = [
30
+ # "scan" syntax, which is similar to jax.lax.scan
31
+ 'scan', 'checkpointed_scan',
32
+
33
+ # "for_loop" syntax
34
+ 'for_loop', 'checkpointed_for_loop',
35
+ ]
36
+
37
+ X = TypeVar('X')
38
+ Y = TypeVar('Y')
39
+ T = TypeVar('T')
40
+ Carry = TypeVar('Carry')
41
+
42
+
43
+ def _wrap_fun_with_pbar(
44
+ fun: Callable[[Carry, X], Tuple[Carry, Y]],
45
+ pbar_runner: Callable
46
+ ):
47
+ @wraps(fun)
48
+ def new_fun(new_carry, inputs):
49
+ i, old_carry = new_carry
50
+ new_carry, new_outputs = fun(old_carry, inputs)
51
+ pbar_runner(unvmap(i, op='none'), carry=new_carry, y=new_outputs)
52
+ return (i + 1, new_carry), new_outputs
53
+
54
+ return new_fun
55
+
56
+
57
@set_module_as('brainstate.compile')
def scan(
    f: Callable[[Carry, X], Tuple[Carry, Y]],
    init: Carry,
    xs: X,
    length: int | None = None,
    reverse: bool = False,
    unroll: int | bool = 1,
    pbar: ProgressBar | int | None = None,
) -> Tuple[Carry, Y]:
    """
    Scan a function over leading array axes while carrying along state.

    The `Haskell-like type signature`_ in brief is

    .. code-block:: haskell

      scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])

    where for any array type specifier ``t``, ``[t]`` represents the type with an additional
    leading axis, and if ``t`` is a pytree (container) type with array leaves then ``[t]``
    represents the type with the same pytree structure and corresponding leaves
    each with an additional leading axis.

    When the type of ``xs`` (denoted `a` above) is an array type or None, and the type
    of ``ys`` (denoted `b` above) is an array type, the semantics of :func:`~scan` are
    given roughly by this Python implementation::

      def scan(f, init, xs, length=None):
        if xs is None:
          xs = [None] * length
        carry = init
        ys = []
        for x in xs:
          carry, y = f(carry, x)
          ys.append(y)
        return carry, np.stack(ys)

    Unlike that Python version, both ``xs`` and ``ys`` may be arbitrary pytree
    values, and so multiple arrays can be scanned over at once and produce multiple
    output arrays. ``None`` is actually a special case of this, as it represents an
    empty pytree.

    Also unlike that Python version, :func:`~scan` is a JAX primitive and is
    lowered to a single WhileOp. That makes it useful for reducing
    compilation times for JIT-compiled functions, since native Python
    loop constructs in an :func:`~jax.jit` function are unrolled, leading to large
    XLA computations.

    Finally, the loop-carried value ``carry`` must hold a fixed shape and dtype
    across all iterations (and not just be consistent up to NumPy rank/shape
    broadcasting and dtype promotion rules, for example). In other words, the type
    ``c`` in the type signature above represents an array with a fixed shape and
    dtype (or a nested tuple/list/dict container data structure with a fixed
    structure and arrays with fixed shape and dtype at the leaves).

    Args:
      f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
        that ``f`` accepts two arguments where the first is a value of the loop
        carry and the second is a slice of ``xs`` along its leading axis, and that
        ``f`` returns a pair where the first element represents a new value for
        the loop carry and the second represents a slice of the output.
      init: an initial loop carry value of type ``c``, which can be a scalar,
        array, or any pytree (nested Python tuple/list/dict) thereof, representing
        the initial loop carry value. This value must have the same structure as
        the first element of the pair returned by ``f``.
      xs: the value of type ``[a]`` over which to scan along the leading axis,
        where ``[a]`` can be an array or any pytree (nested Python
        tuple/list/dict) thereof with consistent leading axis sizes.
      length: optional integer specifying the number of loop iterations, which
        must agree with the sizes of leading axes of the arrays in ``xs`` (but can
        be used to perform scans where no input ``xs`` are needed).
      reverse: optional boolean specifying whether to run the scan iteration
        forward (the default) or in reverse, equivalent to reversing the leading
        axes of the arrays in both ``xs`` and in ``ys``.
      unroll: optional positive int or bool specifying, in the underlying
        operation of the scan primitive, how many scan iterations to unroll within
        a single iteration of a loop. If an integer is provided, it determines how
        many unrolled loop iterations to run within a single rolled iteration of
        the loop. If a boolean is provided, it will determine if the loop is
        completely unrolled (i.e. `unroll=True`) or left completely rolled (i.e.
        `unroll=False`).
      pbar: optional :class:`~.ProgressBar` instance (or an integer, which is
        used as the ``freq`` argument of a new :class:`~.ProgressBar`) to display
        the progress of the scan operation.

    Returns:
      A pair of type ``(c, [b])`` where the first element represents the final
      loop carry value and the second element represents the stacked outputs of
      the second output of ``f`` when scanned over the leading axis of the inputs.

    Raises:
      TypeError: if ``f`` is not callable, or ``pbar`` is neither ``None``, a
        :class:`~.ProgressBar` nor an integer.
      ValueError: if a leaf of ``xs`` has no ``shape`` attribute, if the leading
        axis sizes disagree with each other or with ``length``, if there is nothing
        to scan over and ``length`` is not given, or if a zero-length scan is
        requested while JIT is disabled.

    .. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
    """
    # check "f"
    if not callable(f):
        raise TypeError("f argument should be a callable.")

    # check "xs"
    xs_flat, xs_tree = jax.tree.flatten(xs)
    try:
        lengths = [x.shape[0] for x in xs_flat]
    except AttributeError as err:
        raise ValueError("scan got value with no leading axis to scan over: "
                         "{}.".format(', '.join(str(x) for x in xs_flat if not hasattr(x, 'shape')))) from err
    if length is not None:
        length = int(length)
        if not all(length == l for l in lengths):
            raise ValueError(("scan got `length` argument of {} which disagrees with "
                              "leading axis sizes {}.").format(length, lengths))
    else:
        unique_lengths = set(lengths)
        if len(unique_lengths) > 1:
            msg = "scan got values with different leading axis sizes: {}."
            raise ValueError(msg.format(', '.join(str(l) for l in lengths)))
        elif len(unique_lengths) == 0:
            raise ValueError("scan got no values to scan over and `length` not provided.")
        else:
            length, = unique_lengths

    # function with progress bar
    has_pbar = pbar is not None
    if has_pbar:
        if isinstance(pbar, ProgressBar):
            pbar_runner = pbar.init(length)
        elif isinstance(pbar, int):
            pbar_runner = ProgressBar(freq=pbar).init(length)
        else:
            raise TypeError("pbar argument should be a ProgressBar instance or an integer.")
        f = _wrap_fun_with_pbar(f, pbar_runner)
        # The wrapped function carries an iteration counter next to the user
        # carry. This must follow ``has_pbar`` rather than the truthiness of
        # ``pbar``: a falsy value such as ``pbar=0`` still wraps ``f`` above,
        # and the wrapped ``f`` always unpacks a ``(counter, carry)`` pair.
        init = (0, init)

    # not jit
    if jax.config.jax_disable_jit:
        if length == 0:
            raise ValueError(
                "zero-length scan is not supported in disable_jit() mode because the output type is unknown.")
        carry = init
        ys = []
        maybe_reversed = reversed if reverse else lambda x: x
        for i in maybe_reversed(range(length)):
            xs_slice = [jax.lax.index_in_dim(x, i, keepdims=False) for x in xs_flat]
            carry, y = f(carry, jax.tree.unflatten(xs_tree, xs_slice))
            ys.append(y)
        # restore the original ordering of the outputs before stacking
        stacked_y = jax.tree.map(lambda *elems: jnp.stack(elems), *maybe_reversed(ys))
        if has_pbar:
            # drop the iteration counter added for the progress bar
            return carry[1], stacked_y
        else:
            return carry, stacked_y

    # evaluate jaxpr, get all states #
    # ------------------------------ #
    xs_avals = [jax.core.get_aval(x) for x in xs_flat]
    x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
    stateful_fun = StatefulFunction(f, name='scan').make_jaxpr(init, xs_tree.unflatten(x_avals))
    state_trace = stateful_fun.get_state_trace()
    all_writen_state_vals = state_trace.get_write_state_values(True)
    all_read_state_vals = state_trace.get_read_state_values(True)
    wrapped_f = wrap_single_fun(stateful_fun, state_trace.been_writen, all_read_state_vals)

    # scan: the written state values travel through the loop together with the carry
    init = (all_writen_state_vals, init)
    (
        (
            all_writen_state_vals,
            carry
        ),
        ys
    ) = jax.lax.scan(
        wrapped_f,
        init,
        xs,
        length=length,
        reverse=reverse,
        unroll=unroll
    )
    # assign the written state values and restore the read state values
    write_back_state_values(state_trace, all_read_state_vals, all_writen_state_vals)
    # carry: drop the iteration counter added for the progress bar
    if has_pbar:
        carry = carry[1]
    return carry, ys
238
+
239
+
240
def checkpointed_scan(
    f: Callable[[Carry, X], Tuple[Carry, Y]],
    init: Carry,
    xs: X,
    length: Optional[int] = None,
    base: int = 16,
    pbar: Optional[ProgressBar | int] = None,
) -> Tuple[Carry, Y]:
    """
    Scan a function over leading array axes while carrying along state.
    This function is similar to :func:`~scan` but with a checkpointed version.

    Args:
      f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
        that ``f`` accepts two arguments where the first is a value of the loop
        carry and the second is a slice of ``xs`` along its leading axis, and that
        ``f`` returns a pair where the first element represents a new value for
        the loop carry and the second represents a slice of the output.
      init: an initial loop carry value of type ``c``, which can be a scalar,
        array, or any pytree (nested Python tuple/list/dict) thereof, representing
        the initial loop carry value. This value must have the same structure as
        the first element of the pair returned by ``f``.
      xs: the value of type ``[a]`` over which to scan along the leading axis,
        where ``[a]`` can be an array or any pytree (nested Python
        tuple/list/dict) thereof with consistent leading axis sizes.
      length: optional integer specifying the number of loop iterations, which
        must agree with the sizes of leading axes of the arrays in ``xs`` (but can
        be used to perform scans where no input ``xs`` are needed).
      base: optional integer (at least 2) specifying the base for the bounded
        scan loop. The loop is built from nested scans of ``base`` iterations each.
      pbar: optional :class:`~.ProgressBar` instance (or an integer, which is
        used as the ``freq`` argument of a new :class:`~.ProgressBar`) to display
        the progress of the scan operation.

    Returns:
      A pair of type ``(c, [b])`` where the first element represents the final
      loop carry value and the second element represents the stacked outputs of
      the second output of ``f`` when scanned over the leading axis of the inputs.

    Raises:
      TypeError: if ``f`` is not callable.
      ValueError: if a leaf of ``xs`` has no ``shape`` attribute, if the leading
        axis sizes disagree with each other or with ``length``, if there is nothing
        to scan over and ``length`` is not given, if the scan has zero length, or
        if ``base`` is smaller than 2.
    """
    # check "f"
    if not callable(f):
        raise TypeError("f argument should be a callable.")

    # check "xs"
    xs_flat, xs_tree = jax.tree.flatten(xs)
    try:
        lengths = [x.shape[0] for x in xs_flat]
    except AttributeError as err:
        raise ValueError("scan got value with no leading axis to scan over: "
                         "{}.".format(', '.join(str(x) for x in xs_flat if not hasattr(x, 'shape')))) from err
    if length is not None:
        length = int(length)
        if not all(length == l for l in lengths):
            raise ValueError(("scan got `length` argument of {} which disagrees with "
                              "leading axis sizes {}.").format(length, lengths))
    else:
        unique_lengths = set(lengths)
        if len(unique_lengths) > 1:
            msg = "scan got values with different leading axis sizes: {}."
            raise ValueError(msg.format(', '.join(str(l) for l in lengths)))
        elif len(unique_lengths) == 0:
            raise ValueError("scan got no values to scan over and `length` not provided.")
        else:
            length, = unique_lengths

    # Fail early, with a clear message, on inputs the bounded loop cannot handle.
    # (These previously surfaced as an opaque "math domain error" or a
    # ZeroDivisionError raised by ``math.log``.)
    if length <= 0:
        raise ValueError("checkpointed_scan does not support zero-length scans.")
    if base < 2:
        raise ValueError("`base` must be an integer greater than or equal to 2, but got {}.".format(base))

    # function with progress bar
    if isinstance(pbar, ProgressBar):
        pbar_runner = pbar.init(length)
    elif isinstance(pbar, int):
        pbar_runner = ProgressBar(freq=pbar).init(length)
    else:
        pbar_runner = None

    # evaluate jaxpr
    xs_avals = [jax.core.get_aval(x) for x in xs_flat]
    x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
    stateful_fun = StatefulFunction(f, name='checkpoint_scan').make_jaxpr(init, xs_tree.unflatten(x_avals))
    state_trace = stateful_fun.get_state_trace()
    # get all states
    been_written = state_trace.been_writen
    read_state_vals = state_trace.get_read_state_values(True)
    write_state_vals = state_trace.get_write_state_values(True)

    # initialize the collected values/data
    out_info = stateful_fun.get_out_shapes()[0]
    assert len(out_info) == 2, "function in checkpointed_scan should return two data: carray and out."
    data2collection = jax.tree.map(lambda x: jnp.zeros((length,) + x.shape, x.dtype), out_info[1])
    del out_info

    def wrapped_cond_fun(inp):
        # the last element of the loop value is the iteration counter
        return inp[-1] < length

    def wrapped_body_fun(inp):
        (prev_write_states, carray), prev_collect, i = inp
        # progress bar
        if pbar_runner is not None:
            pbar_runner(unvmap(i, op='none'))
        # call the function
        prev_states = [w_val if write else r_val
                       for write, w_val, r_val in zip(been_written, prev_write_states, read_state_vals)]
        new_states, (new_carray, out4updates) = stateful_fun.jaxpr_call(
            prev_states, carray, jax.tree.map(lambda x: x[i], xs)
        )
        # new written states
        new_write_states = tuple([val if write else None for write, val in zip(been_written, new_states)])

        # out of length: the bounded loop may call the body for ``i >= length``,
        # in which case every update below is discarded.
        pred = i < length
        new_collect = jax.tree.map(
            lambda x, update: jax.lax.select(pred, x.at[i].set(update), x),
            prev_collect,
            out4updates,
        )
        new_write_states = jax.tree.map(
            lambda ps, ns: None if ns is None else jax.lax.select(pred, ns, ps),
            prev_write_states,
            new_write_states,
            is_leaf=lambda x: x is None
        )
        new_carray = jax.tree.map(
            lambda pc, nc: jax.lax.select(pred, nc, pc),
            carray,
            new_carray,
        )
        return (new_write_states, new_carray), new_collect, i + 1

    # while_loop
    # The bounded loop needs the smallest power of ``base`` that is >= ``length``.
    # Integer arithmetic is used on purpose: ``math.log(length, base)`` is subject
    # to floating-point error (e.g. ``math.log(125, 5)`` is slightly above 3), which
    # would round up to a needlessly large power and add a superfluous nesting level.
    rounded_max_steps = 1
    while rounded_max_steps < length:
        rounded_max_steps *= base
    (write_state_vals, carry), data2collection, _ = (
        _bounded_while_loop(
            wrapped_cond_fun,
            wrapped_body_fun,
            ((write_state_vals, init), data2collection, 0),
            rounded_max_steps,
            base,
            pbar_runner
        )
    )
    # assign the written state values and restore the read state values
    write_back_state_values(state_trace, read_state_vals, write_state_vals)
    del write_state_vals, read_state_vals, stateful_fun
    return carry, data2collection
381
+
382
+
383
def _forloop_to_scan_fun(f: Callable):
    """
    Adapt a for-loop body to the ``(carry, x) -> (carry, y)`` scan signature.

    The returned function unpacks ``x`` into positional arguments for ``f`` and
    passes the carry through unchanged.
    """

    @wraps(f)
    def scan_fun(carry, x):
        step_output = f(*x)
        return carry, step_output

    return scan_fun
389
+
390
+
391
@set_module_as('brainstate.compile')
def for_loop(
    f: Callable[..., Y],
    *xs,
    length: Optional[int] = None,
    reverse: bool = False,
    unroll: int | bool = 1,
    pbar: Optional[ProgressBar | int] = None
) -> Y:
    """
    ``for-loop`` control flow with :py:class:`~.State`.

    This is a thin wrapper around :func:`~scan` that uses no loop carry: at each
    iteration ``f`` is called as ``f(*x)``, where ``x`` holds the current slice
    of every positional operand in ``xs``, and its return value is collected.

    Args:
      f: a Python function called once per iteration with one slice (along the
        leading axis) of each operand in ``xs`` as positional arguments. Its
        return value is the output of that iteration.
      xs: the operands to loop over along their leading axis. Each can be an
        array or any pytree (nested Python tuple/list/dict) thereof, with
        consistent leading axis sizes.
      length: optional integer specifying the number of loop iterations, which
        must agree with the sizes of leading axes of the arrays in ``xs`` (but can
        be used to perform loops where no input ``xs`` are needed).
      reverse: optional boolean specifying whether to run the iteration forward
        (the default) or in reverse. Forwarded to :func:`~scan`.
      unroll: optional positive int or bool controlling how many iterations are
        unrolled. Forwarded to :func:`~scan`.
      pbar: optional :class:`~.ProgressBar` instance or integer used to display
        the progress of the loop. Forwarded to :func:`~scan`.

    Returns:
      The outputs of ``f`` for all iterations, stacked along a new leading axis.
    """
    scan_options = dict(length=length, reverse=reverse, unroll=unroll, pbar=pbar)
    # the carry is unused (``init=None``); only the stacked outputs are returned
    _, stacked_outputs = scan(_forloop_to_scan_fun(f), init=None, xs=xs, **scan_options)
    return stacked_outputs
443
+
444
+
445
def checkpointed_for_loop(
    f: Callable[..., Y],
    *xs: X,
    length: Optional[int] = None,
    base: int = 16,
    pbar: Optional[ProgressBar | int] = None,
) -> Y:
    """
    ``for-loop`` control flow with :py:class:`~.State` with a checkpointed version, similar to :py:func:`for_loop`.

    This is a thin wrapper around :func:`~checkpointed_scan` that uses no loop
    carry: at each iteration ``f`` is called as ``f(*x)``, where ``x`` holds the
    current slice of every positional operand in ``xs``.

    Args:
      f: a Python function called once per iteration with one slice (along the
        leading axis) of each operand in ``xs`` as positional arguments. Its
        return value is the output of that iteration.
      xs: the operands to loop over along their leading axis. Each can be an
        array or any pytree (nested Python tuple/list/dict) thereof, with
        consistent leading axis sizes.
      length: optional integer specifying the number of loop iterations, which
        must agree with the sizes of leading axes of the arrays in ``xs`` (but can
        be used to perform loops where no input ``xs`` are needed).
      base: optional integer specifying the base for the bounded scan loop.
        Forwarded to :func:`~checkpointed_scan`.
      pbar: optional :class:`~.ProgressBar` instance or integer used to display
        the progress of the loop. Forwarded to :func:`~checkpointed_scan`.

    Returns:
      The outputs of ``f`` for all iterations, stacked along a new leading axis.
    """
    scan_options = dict(length=length, base=base, pbar=pbar)
    # the carry is unused (``init=None``); only the stacked outputs are returned
    _, stacked_outputs = checkpointed_scan(_forloop_to_scan_fun(f), init=None, xs=xs, **scan_options)
    return stacked_outputs
484
+
485
+
486
+ # This function is adapted from ``while_loop`` in `equinox <https://github.com/patrick-kidger/equinox/blob/main/equinox/internal/_loop/loop.py>`_.
487
+
488
+ # There are several tricks happening here to work around various limitations of JAX.
489
+ # (Also see https://github.com/google/jax/issues/2139#issuecomment-1039293633)
490
+ # 1. `unvmap_any` prior to using `lax.cond`. JAX has a problem in that vmap-of-cond
491
+ # is converted to a `lax.select`, which executes both branches unconditionally.
492
+ # Thus writing this naively, using a plain `lax.cond`, will mean the loop always
493
+ # runs to `max_steps` when executing under vmap. Instead we run (only) until every
494
+ # batch element has finished.
495
+ # 2. Treating in-place updates specially in the body_fun. Specifically we need to
496
+ # `lax.select` the update-to-make, not the updated buffer. This is because the
497
+ # latter instead results in XLA:CPU failing to determine that the buffer can be
498
+ # updated in-place, and instead it makes a copy. c.f. JAX issue #8192.
499
+ # This is done through the extra `inplace` argument provided to `body_fun`.
500
+ # 3. The use of the `@jax.checkpoint` decorator. Backpropagation through a
501
+ # `bounded_while_loop` will otherwise run in θ(max_steps) time, rather than
502
+ # θ(number of steps actually taken).
503
+ # 4. The use of `base`. In theory `base=2` is optimal at run time, as it implies the
504
+ # fewest superfluous operations. In practice this implies quite deep recursion in
505
+ # the construction of the bounded while loop, and this slows down the jaxpr
506
+ # creation and the XLA compilation. We choose `base=16` as a reasonable-looking
507
+ # compromise between compilation time and run time.
508
+
509
def _bounded_while_loop(
    cond_fun: Callable,
    body_fun: Callable,
    val: Any,
    max_steps: int,
    base: int,
    pbar_runner: Optional[Callable] = None
):
    """
    Run ``body_fun`` while ``cond_fun`` holds, using nested scans of ``base`` iterations.

    Args:
      cond_fun: predicate evaluated on the loop value; a chunk is executed when
        it holds for any element (``unvmap(..., op='any')``).
      body_fun: function applied to the loop value for a single step.
      val: the loop value, a tuple whose last element is the step counter.
      max_steps: number of steps covered by this level. Each nested level
        covers ``max_steps // base`` steps; the recursion ends at ``1``.
      base: number of iterations of the scan at every level.
      pbar_runner: optional callable notified with the advanced step counter
        whenever a chunk is skipped.

    Returns:
      The loop value after this level has finished.
    """
    # lowest level: a single application of the body
    if max_steps == 1:
        return body_fun(val)

    def run_chunk(state):
        # descend one level; each chunk covers ``max_steps // base`` steps
        return _bounded_while_loop(cond_fun, body_fun, state, max_steps // base, base, pbar_runner)

    def skip_chunk(state):
        # leave the payload untouched and only advance the step counter
        advanced = state[-1] + max_steps
        if pbar_runner is not None:
            pbar_runner(unvmap(advanced, op='none'))
        return state[:-1] + (advanced,)

    def step(state, _):
        keep_running = unvmap(cond_fun(state), op='any')
        return jax.lax.cond(keep_running, run_chunk, skip_chunk, state), None

    # Don't put checkpointing on the lowest level
    if max_steps != base:
        step = jax.checkpoint(step, prevent_cse=False)  # pyright: ignore

    final_state, _ = jax.lax.scan(step, val, xs=None, length=base)
    return final_state