brainstate 0.2.0__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. brainstate/__init__.py +169 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2319 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +1652 -1652
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1624 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1433 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +137 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +633 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +154 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +477 -477
  32. brainstate/nn/_dynamics.py +1267 -1267
  33. brainstate/nn/_dynamics_test.py +67 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +384 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/_rand_funs.py +3938 -3938
  64. brainstate/random/_rand_funs_test.py +640 -640
  65. brainstate/random/_rand_seed.py +675 -675
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1617
  68. brainstate/random/_rand_state_test.py +551 -551
  69. brainstate/transform/__init__.py +59 -59
  70. brainstate/transform/_ad_checkpoint.py +176 -176
  71. brainstate/transform/_ad_checkpoint_test.py +49 -49
  72. brainstate/transform/_autograd.py +1025 -1025
  73. brainstate/transform/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -316
  75. brainstate/transform/_conditions_test.py +220 -220
  76. brainstate/transform/_error_if.py +94 -94
  77. brainstate/transform/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -145
  79. brainstate/transform/_eval_shape_test.py +38 -38
  80. brainstate/transform/_jit.py +399 -399
  81. brainstate/transform/_jit_test.py +143 -143
  82. brainstate/transform/_loop_collect_return.py +675 -675
  83. brainstate/transform/_loop_collect_return_test.py +58 -58
  84. brainstate/transform/_loop_no_collection.py +283 -283
  85. brainstate/transform/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -2016
  87. brainstate/transform/_make_jaxpr_test.py +1510 -1510
  88. brainstate/transform/_mapping.py +529 -529
  89. brainstate/transform/_mapping_test.py +194 -194
  90. brainstate/transform/_progress_bar.py +255 -255
  91. brainstate/transform/_random.py +171 -171
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -108
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate-0.2.0.dist-info/RECORD +0 -111
  111. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  112. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -1,675 +1,675 @@
1
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- import math
17
- from functools import wraps
18
- from typing import Callable, Optional, TypeVar, Tuple, Any
19
-
20
- import jax
21
- import jax.numpy as jnp
22
-
23
- from brainstate._utils import set_module_as
24
- from ._make_jaxpr import StatefulFunction
25
- from ._progress_bar import ProgressBar
26
- from ._unvmap import unvmap
27
- from ._util import wrap_single_fun
28
-
29
- __all__ = [
30
- # "scan" syntax, which is similar to jax.lax.scan
31
- 'scan', 'checkpointed_scan',
32
-
33
- # "for_loop" syntax
34
- 'for_loop', 'checkpointed_for_loop',
35
- ]
36
-
37
- X = TypeVar('X')
38
- Y = TypeVar('Y')
39
- T = TypeVar('T')
40
- Carry = TypeVar('Carry')
41
-
42
-
43
- def _wrap_fun_with_pbar(
44
- fun: Callable[[Carry, X], Tuple[Carry, Y]],
45
- pbar_runner: Callable
46
- ):
47
- @wraps(fun)
48
- def new_fun(new_carry, inputs):
49
- i, old_carry = new_carry
50
- new_carry, new_outputs = fun(old_carry, inputs)
51
- pbar_runner(unvmap(i, op='none'), carry=new_carry, y=new_outputs)
52
- return (i + 1, new_carry), new_outputs
53
-
54
- return new_fun
55
-
56
-
57
- @set_module_as('brainstate.transform')
58
- def scan(
59
- f: Callable[[Carry, X], Tuple[Carry, Y]],
60
- init: Carry,
61
- xs: X,
62
- length: int | None = None,
63
- reverse: bool = False,
64
- unroll: int | bool = 1,
65
- pbar: ProgressBar | int | None = None,
66
- ) -> Tuple[Carry, Y]:
67
- """
68
- Scan a function over leading array axes while carrying along state.
69
-
70
- The `Haskell-like type signature`_ in brief is
71
-
72
- .. code-block:: haskell
73
-
74
- scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
75
-
76
- where for any array type specifier ``t``, ``[t]`` represents the type with an additional
77
- leading axis, and if ``t`` is a pytree (container) type with array leaves then ``[t]``
78
- represents the type with the same pytree structure and corresponding leaves
79
- each with an additional leading axis.
80
-
81
- When the type of ``xs`` (denoted `a` above) is an array type or None, and the type
82
- of ``ys`` (denoted `b` above) is an array type, the semantics of :func:`~scan` are
83
- given roughly by this Python implementation:
84
-
85
- .. code-block:: python
86
-
87
- >>> def scan(f, init, xs, length=None):
88
- ... if xs is None:
89
- ... xs = [None] * length
90
- ... carry = init
91
- ... ys = []
92
- ... for x in xs:
93
- ... carry, y = f(carry, x)
94
- ... ys.append(y)
95
- ... return carry, np.stack(ys)
96
-
97
- Unlike that Python version, both ``xs`` and ``ys`` may be arbitrary pytree
98
- values, and so multiple arrays can be scanned over at once and produce multiple
99
- output arrays. ``None`` is actually a special case of this, as it represents an
100
- empty pytree.
101
-
102
- Also unlike that Python version, :func:`~scan` is a JAX primitive and is
103
- lowered to a single WhileOp. That makes it useful for reducing
104
- compilation times for JIT-compiled functions, since native Python
105
- loop constructs in an :func:`~jax.jit` function are unrolled, leading to large
106
- XLA computations.
107
-
108
- Finally, the loop-carried value ``carry`` must hold a fixed shape and dtype
109
- across all iterations (and not just be consistent up to NumPy rank/shape
110
- broadcasting and dtype promotion rules, for example). In other words, the type
111
- ``c`` in the type signature above represents an array with a fixed shape and
112
- dtype (or a nested tuple/list/dict container data structure with a fixed
113
- structure and arrays with fixed shape and dtype at the leaves).
114
-
115
- Parameters
116
- ----------
117
- f : callable
118
- A Python function to be scanned of type ``c -> a -> (c, b)``, meaning
119
- that ``f`` accepts two arguments where the first is a value of the loop
120
- carry and the second is a slice of ``xs`` along its leading axis, and that
121
- ``f`` returns a pair where the first element represents a new value for
122
- the loop carry and the second represents a slice of the output.
123
- init : Carry
124
- An initial loop carry value of type ``c``, which can be a scalar,
125
- array, or any pytree (nested Python tuple/list/dict) thereof, representing
126
- the initial loop carry value. This value must have the same structure as
127
- the first element of the pair returned by ``f``.
128
- xs : X
129
- The value of type ``[a]`` over which to scan along the leading axis,
130
- where ``[a]`` can be an array or any pytree (nested Python
131
- tuple/list/dict) thereof with consistent leading axis sizes.
132
- length : int, optional
133
- Optional integer specifying the number of loop iterations, which
134
- must agree with the sizes of leading axes of the arrays in ``xs`` (but can
135
- be used to perform scans where no input ``xs`` are needed).
136
- reverse : bool, default False
137
- Optional boolean specifying whether to run the scan iteration
138
- forward (the default) or in reverse, equivalent to reversing the leading
139
- axes of the arrays in both ``xs`` and in ``ys``.
140
- unroll : int or bool, default 1
141
- Optional positive int or bool specifying, in the underlying
142
- operation of the scan primitive, how many scan iterations to unroll within
143
- a single iteration of a loop. If an integer is provided, it determines how
144
- many unrolled loop iterations to run within a single rolled iteration of
145
- the loop. If a boolean is provided, it will determine if the loop is
146
- completely unrolled (i.e. `unroll=True`) or left completely rolled (i.e.
147
- `unroll=False`).
148
- pbar : ProgressBar or int, optional
149
- Optional :class:`~.ProgressBar` instance to display the progress
150
- of the scan operation.
151
-
152
- Returns
153
- -------
154
- tuple of (Carry, Y)
155
- A pair of type ``(c, [b])`` where the first element represents the final
156
- loop carry value and the second element represents the stacked outputs of
157
- the second output of ``f`` when scanned over the leading axis of the inputs.
158
-
159
- Examples
160
- --------
161
- Basic scan operation:
162
-
163
- .. code-block:: python
164
-
165
- >>> import brainstate
166
- >>> import jax.numpy as jnp
167
- >>>
168
- >>> def step_fn(carry, x):
169
- ... return carry + x, carry * x
170
- >>>
171
- >>> init = 0.0
172
- >>> xs = jnp.array([1.0, 2.0, 3.0])
173
- >>> final_carry, ys = brainstate.transform.scan(step_fn, init, xs)
174
-
175
- Scan with progress bar:
176
-
177
- .. code-block:: python
178
-
179
- >>> pbar = brainstate.transform.ProgressBar(freq=10)
180
- >>> final_carry, ys = brainstate.transform.scan(step_fn, init, xs, pbar=pbar)
181
-
182
- References
183
- ----------
184
- .. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
185
- """
186
- # check "f"
187
- if not callable(f):
188
- raise TypeError("f argument should be a callable.")
189
-
190
- # check "xs"
191
- xs_flat, xs_tree = jax.tree.flatten(xs)
192
- try:
193
- lengths = [x.shape[0] for x in xs_flat]
194
- except AttributeError as err:
195
- raise ValueError("scan got value with no leading axis to scan over: "
196
- "{}.".format(', '.join(str(x) for x in xs_flat if not hasattr(x, 'shape')))) from err
197
- if length is not None:
198
- length = int(length)
199
- if not all(length == l for l in lengths):
200
- raise ValueError(("scan got `length` argument of {} which disagrees with "
201
- "leading axis sizes {}.").format(length, [x.shape[0] for x in xs_flat]))
202
- else:
203
- unique_lengths = set(lengths)
204
- if len(unique_lengths) > 1:
205
- msg = "scan got values with different leading axis sizes: {}."
206
- raise ValueError(msg.format(', '.join(str(x.shape[0]) for x in xs_flat)))
207
- elif len(unique_lengths) == 0:
208
- raise ValueError("scan got no values to scan over and `length` not provided.")
209
- else:
210
- length, = unique_lengths
211
-
212
- # function with progress bar
213
- has_pbar = False
214
- if pbar is not None:
215
- has_pbar = True
216
- if isinstance(pbar, ProgressBar):
217
- pbar_runner = pbar.init(length)
218
- elif isinstance(pbar, int):
219
- pbar_runner = ProgressBar(freq=pbar).init(length)
220
- else:
221
- raise TypeError("pbar argument should be a ProgressBar instance or an integer.")
222
- f = _wrap_fun_with_pbar(f, pbar_runner)
223
- init = (0, init) if pbar else init
224
-
225
- # not jit
226
- if jax.config.jax_disable_jit:
227
- if length == 0:
228
- raise ValueError(
229
- "zero-length scan is not supported in disable_jit() mode because the output type is unknown.")
230
- carry = init
231
- ys = []
232
- maybe_reversed = reversed if reverse else lambda x: x
233
- for i in maybe_reversed(range(length)):
234
- xs_slice = [jax.lax.index_in_dim(x, i, keepdims=False) for x in xs_flat]
235
- carry, y = f(carry, jax.tree.unflatten(xs_tree, xs_slice))
236
- ys.append(y)
237
- stacked_y = jax.tree.map(lambda *elems: jnp.stack(elems), *maybe_reversed(ys))
238
- if has_pbar:
239
- return carry[1], stacked_y
240
- else:
241
- return carry, stacked_y
242
-
243
- # evaluate jaxpr, get all states #
244
- # ------------------------------ #
245
- xs_avals = [jax.core.get_aval(x) for x in xs_flat]
246
- x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
247
- args = [init, xs_tree.unflatten(x_avals)]
248
- stateful_fun = StatefulFunction(f, name='scan').make_jaxpr(*args)
249
- state_trace = stateful_fun.get_state_trace(*args)
250
- all_writen_state_vals = state_trace.get_write_state_values(True)
251
- all_read_state_vals = state_trace.get_read_state_values(True)
252
- wrapped_f = wrap_single_fun(stateful_fun, state_trace.been_writen, all_read_state_vals)
253
-
254
- # scan
255
- init = (all_writen_state_vals, init)
256
- (
257
- (
258
- all_writen_state_vals,
259
- carry
260
- ),
261
- ys
262
- ) = jax.lax.scan(
263
- wrapped_f,
264
- init,
265
- xs,
266
- length=length,
267
- reverse=reverse,
268
- unroll=unroll
269
- )
270
- # assign the written state values and restore the read state values
271
- state_trace.assign_state_vals_v2(all_read_state_vals, all_writen_state_vals)
272
- # carry
273
- if has_pbar:
274
- carry = carry[1]
275
- return carry, ys
276
-
277
-
278
- @set_module_as('brainstate.transform')
279
- def checkpointed_scan(
280
- f: Callable[[Carry, X], Tuple[Carry, Y]],
281
- init: Carry,
282
- xs: X,
283
- length: Optional[int] = None,
284
- base: int = 16,
285
- pbar: Optional[ProgressBar | int] = None,
286
- ) -> Tuple[Carry, Y]:
287
- """
288
- Scan a function over leading array axes while carrying along state.
289
- This function is similar to :func:`~scan` but with a checkpointed version.
290
-
291
- Parameters
292
- ----------
293
- f : callable
294
- A Python function to be scanned of type ``c -> a -> (c, b)``, meaning
295
- that ``f`` accepts two arguments where the first is a value of the loop
296
- carry and the second is a slice of ``xs`` along its leading axis, and that
297
- ``f`` returns a pair where the first element represents a new value for
298
- the loop carry and the second represents a slice of the output.
299
- init : Carry
300
- An initial loop carry value of type ``c``, which can be a scalar,
301
- array, or any pytree (nested Python tuple/list/dict) thereof, representing
302
- the initial loop carry value. This value must have the same structure as
303
- the first element of the pair returned by ``f``.
304
- xs : X
305
- The value of type ``[a]`` over which to scan along the leading axis,
306
- where ``[a]`` can be an array or any pytree (nested Python
307
- tuple/list/dict) thereof with consistent leading axis sizes.
308
- length : int, optional
309
- Optional integer specifying the number of loop iterations, which
310
- must agree with the sizes of leading axes of the arrays in ``xs`` (but can
311
- be used to perform scans where no input ``xs`` are needed).
312
- base : int, default 16
313
- Optional integer specifying the base for the bounded scan loop.
314
- pbar : ProgressBar or int, optional
315
- Optional :class:`~.ProgressBar` instance to display the progress
316
- of the scan operation.
317
-
318
- Returns
319
- -------
320
- tuple of (Carry, Y)
321
- A pair of type ``(c, [b])`` where the first element represents the final
322
- loop carry value and the second element represents the stacked outputs of
323
- the second output of ``f`` when scanned over the leading axis of the inputs.
324
-
325
- Examples
326
- --------
327
- Basic checkpointed scan operation:
328
-
329
- .. code-block:: python
330
-
331
- >>> import brainstate
332
- >>> import jax.numpy as jnp
333
- >>>
334
- >>> def step_fn(carry, x):
335
- ... return carry + x, carry * x
336
- >>>
337
- >>> init = 0.0
338
- >>> xs = jnp.array([1.0, 2.0, 3.0])
339
- >>> final_carry, ys = brainstate.transform.checkpointed_scan(step_fn, init, xs)
340
-
341
- Using custom base for checkpointing:
342
-
343
- .. code-block:: python
344
-
345
- >>> final_carry, ys = brainstate.transform.checkpointed_scan(
346
- ... step_fn, init, xs, base=8
347
- ... )
348
- """
349
- # check "f"
350
- if not callable(f):
351
- raise TypeError("f argument should be a callable.")
352
-
353
- # check "xs"
354
- xs_flat, xs_tree = jax.tree.flatten(xs)
355
- try:
356
- lengths = [x.shape[0] for x in xs_flat]
357
- except AttributeError as err:
358
- raise ValueError("scan got value with no leading axis to scan over: "
359
- "{}.".format(', '.join(str(x) for x in xs_flat if not hasattr(x, 'shape')))) from err
360
- if length is not None:
361
- length = int(length)
362
- if not all(length == l for l in lengths):
363
- raise ValueError(("scan got `length` argument of {} which disagrees with "
364
- "leading axis sizes {}.").format(length, [x.shape[0] for x in xs_flat]))
365
- else:
366
- unique_lengths = set(lengths)
367
- if len(unique_lengths) > 1:
368
- msg = "scan got values with different leading axis sizes: {}."
369
- raise ValueError(msg.format(', '.join(str(x.shape[0]) for x in xs_flat)))
370
- elif len(unique_lengths) == 0:
371
- raise ValueError("scan got no values to scan over and `length` not provided.")
372
- else:
373
- length, = unique_lengths
374
-
375
- # function with progress bar
376
- if isinstance(pbar, ProgressBar):
377
- pbar_runner = pbar.init(length)
378
- elif isinstance(pbar, int):
379
- pbar_runner = ProgressBar(freq=pbar).init(length)
380
- else:
381
- pbar_runner = None
382
-
383
- # evaluate jaxpr
384
- xs_avals = [jax.core.get_aval(x) for x in xs_flat]
385
- x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
386
- args = (init, xs_tree.unflatten(x_avals))
387
- stateful_fun = StatefulFunction(f, name='checkpoint_scan').make_jaxpr(*args)
388
- state_trace = stateful_fun.get_state_trace(*args)
389
- cache_key = stateful_fun.get_arg_cache_key(*args)
390
- # get all states
391
- been_written = state_trace.been_writen
392
- read_state_vals = state_trace.get_read_state_values(True)
393
- write_state_vals = state_trace.get_write_state_values(True)
394
-
395
- # initialize the collected values/data
396
- out_info = stateful_fun.get_out_shapes_by_cache(cache_key)[0]
397
- assert len(out_info) == 2, "function in checkpointed_scan should return two data: carray and out."
398
- data2collection = jax.tree.map(lambda x: jnp.zeros((length,) + x.shape, x.dtype), out_info[1])
399
- del out_info
400
-
401
- def wrapped_cond_fun(inp):
402
- return inp[-1] < length
403
-
404
- def wrapped_body_fun(inp):
405
- (prev_write_states, carray), prev_collect, i = inp
406
- # progress bar
407
- if pbar_runner is not None:
408
- pbar_runner(unvmap(i, op='none'))
409
- # call the function
410
- prev_states = [w_val if write else r_val
411
- for write, w_val, r_val in zip(been_written, prev_write_states, read_state_vals)]
412
- new_states, (new_carray, out4updates) = stateful_fun.jaxpr_call(
413
- prev_states, carray, jax.tree.map(lambda x: x[i], xs)
414
- )
415
- # new written states
416
- new_write_states = tuple([val if write else None for write, val in zip(been_written, new_states)])
417
-
418
- # out of length
419
- pred = i < length
420
- new_collect = jax.tree.map(
421
- # lambda x, update: x.at[i].set(jax.lax.select(pred, update, x[i])),
422
- lambda x, update: jax.lax.select(pred, x.at[i].set(update), x),
423
- prev_collect,
424
- out4updates,
425
- )
426
- new_write_states = jax.tree.map(
427
- lambda ps, ns: None if ns is None else jax.lax.select(pred, ns, ps),
428
- prev_write_states,
429
- new_write_states,
430
- is_leaf=lambda x: x is None
431
- )
432
- new_carray = jax.tree.map(
433
- lambda pc, nc: jax.lax.select(pred, nc, pc),
434
- carray,
435
- new_carray,
436
- )
437
- return (new_write_states, new_carray), new_collect, i + 1
438
-
439
- # while_loop
440
- rounded_max_steps = base ** int(math.ceil(math.log(length, base)))
441
- (write_state_vals, carry), data2collection, _ = (
442
- _bounded_while_loop(
443
- wrapped_cond_fun,
444
- wrapped_body_fun,
445
- ((write_state_vals, init), data2collection, 0),
446
- rounded_max_steps,
447
- base,
448
- pbar_runner
449
- )
450
- )
451
- # assign the written state values and restore the read state values
452
- state_trace.assign_state_vals_v2(read_state_vals, write_state_vals)
453
- del write_state_vals, read_state_vals, stateful_fun
454
- return carry, data2collection
455
-
456
-
457
- def _forloop_to_scan_fun(f: Callable):
458
- @wraps(f)
459
- def scan_fun(carry, x):
460
- return carry, f(*x)
461
-
462
- return scan_fun
463
-
464
-
465
- @set_module_as('brainstate.transform')
466
- def for_loop(
467
- f: Callable[..., Y],
468
- *xs,
469
- length: Optional[int] = None,
470
- reverse: bool = False,
471
- unroll: int | bool = 1,
472
- pbar: Optional[ProgressBar | int] = None
473
- ) -> Y:
474
- """
475
- ``for-loop`` control flow with :py:class:`~.State`.
476
-
477
- Parameters
478
- ----------
479
- f : callable
480
- A Python function to be looped over that accepts variadic arguments
481
- corresponding to slices of ``xs`` along their leading axes, and returns
482
- the output for that iteration.
483
- *xs
484
- The values over which to loop along the leading axis,
485
- where each can be an array or any pytree (nested Python
486
- tuple/list/dict) thereof with consistent leading axis sizes.
487
- length : int, optional
488
- Optional integer specifying the number of loop iterations, which
489
- must agree with the sizes of leading axes of the arrays in ``xs`` (but can
490
- be used to perform loops where no input ``xs`` are needed).
491
- reverse : bool, default False
492
- Optional boolean specifying whether to run the loop iteration
493
- forward (the default) or in reverse, equivalent to reversing the leading
494
- axes of the arrays in both ``xs`` and in ``ys``.
495
- unroll : int or bool, default 1
496
- Optional positive int or bool specifying, in the underlying
497
- operation of the scan primitive, how many loop iterations to unroll within
498
- a single iteration of a loop. If an integer is provided, it determines how
499
- many unrolled loop iterations to run within a single rolled iteration of
500
- the loop. If a boolean is provided, it will determine if the loop is
501
- completely unrolled (i.e. `unroll=True`) or left completely rolled (i.e.
502
- `unroll=False`).
503
- pbar : ProgressBar or int, optional
504
- Optional :class:`~.ProgressBar` instance to display the progress
505
- of the loop operation.
506
-
507
- Returns
508
- -------
509
- Y
510
- The stacked outputs of ``f`` when looped over the leading axis of the inputs.
511
-
512
- Examples
513
- --------
514
- Basic for-loop operation:
515
-
516
- .. code-block:: python
517
-
518
- >>> import brainstate
519
- >>> import jax.numpy as jnp
520
- >>>
521
- >>> def process_item(x, y):
522
- ... return x * y + 1
523
- >>>
524
- >>> xs = jnp.array([1.0, 2.0, 3.0])
525
- >>> ys = jnp.array([4.0, 5.0, 6.0])
526
- >>> results = brainstate.transform.for_loop(process_item, xs, ys)
527
-
528
- For-loop with progress bar:
529
-
530
- .. code-block:: python
531
-
532
- >>> pbar = brainstate.transform.ProgressBar(freq=10)
533
- >>> results = brainstate.transform.for_loop(process_item, xs, ys, pbar=pbar)
534
-
535
- For-loop with reverse iteration:
536
-
537
- .. code-block:: python
538
-
539
- >>> results = brainstate.transform.for_loop(process_item, xs, ys, reverse=True)
540
- """
541
- _, ys = scan(
542
- _forloop_to_scan_fun(f),
543
- init=None,
544
- xs=xs,
545
- length=length,
546
- reverse=reverse,
547
- unroll=unroll,
548
- pbar=pbar
549
- )
550
- return ys
551
-
552
-
553
- @set_module_as('brainstate.transform')
554
- def checkpointed_for_loop(
555
- f: Callable[..., Y],
556
- *xs: X,
557
- length: Optional[int] = None,
558
- base: int = 16,
559
- pbar: Optional[ProgressBar | int] = None,
560
- ) -> Y:
561
- """
562
- ``for-loop`` control flow with :py:class:`~.State` with a checkpointed version, similar to :py:func:`for_loop`.
563
-
564
- Parameters
565
- ----------
566
- f : callable
567
- A Python function to be looped over that accepts variadic arguments
568
- corresponding to slices of ``xs`` along their leading axes, and returns
569
- the output for that iteration.
570
- *xs : X
571
- The values over which to loop along the leading axis,
572
- where each can be an array or any pytree (nested Python
573
- tuple/list/dict) thereof with consistent leading axis sizes.
574
- length : int, optional
575
- Optional integer specifying the number of loop iterations, which
576
- must agree with the sizes of leading axes of the arrays in ``xs`` (but can
577
- be used to perform loops where no input ``xs`` are needed).
578
- base : int, default 16
579
- Optional integer specifying the base for the bounded loop.
580
- pbar : ProgressBar or int, optional
581
- Optional :class:`~.ProgressBar` instance to display the progress
582
- of the loop operation.
583
-
584
- Returns
585
- -------
586
- Y
587
- The stacked outputs of ``f`` when looped over the leading axis of the inputs.
588
-
589
- Examples
590
- --------
591
- Basic checkpointed for-loop operation:
592
-
593
- .. code-block:: python
594
-
595
- >>> import brainstate
596
- >>> import jax.numpy as jnp
597
- >>>
598
- >>> def process_item(x, y):
599
- ... return x * y + 1
600
- >>>
601
- >>> xs = jnp.array([1.0, 2.0, 3.0])
602
- >>> ys = jnp.array([4.0, 5.0, 6.0])
603
- >>> results = brainstate.transform.checkpointed_for_loop(process_item, xs, ys)
604
-
605
- Using custom base for checkpointing:
606
-
607
- .. code-block:: python
608
-
609
- >>> results = brainstate.transform.checkpointed_for_loop(
610
- ... process_item, xs, ys, base=8
611
- ... )
612
- """
613
- _, ys = checkpointed_scan(
614
- _forloop_to_scan_fun(f),
615
- init=None,
616
- xs=xs,
617
- length=length,
618
- base=base,
619
- pbar=pbar
620
- )
621
- return ys
622
-
623
-
624
- # This function is adapted from ``while_loop`` in
625
- # `equinox <https://github.com/patrick-kidger/equinox/blob/main/equinox/internal/_loop/loop.py>`_.
626
-
627
- # There are several tricks happening here to work around various limitations of JAX.
628
- # (Also see https://github.com/google/jax/issues/2139#issuecomment-1039293633)
629
- # 1. `unvmap_any` prior to using `lax.cond`. JAX has a problem in that vmap-of-cond
630
- # is converted to a `lax.select`, which executes both branches unconditionally.
631
- # Thus writing this naively, using a plain `lax.cond`, will mean the loop always
632
- # runs to `max_steps` when executing under vmap. Instead we run (only) until every
633
- # batch element has finished.
634
- # 2. Treating in-place updates specially in the body_fun. Specifically we need to
635
- # `lax.select` the update-to-make, not the updated buffer. This is because the
636
- # latter instead results in XLA:CPU failing to determine that the buffer can be
637
- # updated in-place, and instead it makes a copy. c.f. JAX issue #8192.
638
- # This is done through the extra `inplace` argument provided to `body_fun`.
639
- # 3. The use of the `@jax.checkpoint` decorator. Backpropagation through a
640
- # `bounded_while_loop` will otherwise run in θ(max_steps) time, rather than
641
- # θ(number of steps actually taken).
642
- # 4. The use of `base`. In theory `base=2` is optimal at run time, as it implies the
643
- # fewest superfluous operations. In practice this implies quite deep recursion in
644
- # the construction of the bounded while loop, and this slows down the jaxpr
645
- # creation and the XLA compilation. We choose `base=16` as a reasonable-looking
646
- # compromise between compilation time and run time.
647
-
648
- def _bounded_while_loop(
649
- cond_fun: Callable,
650
- body_fun: Callable,
651
- val: Any,
652
- max_steps: int,
653
- base: int,
654
- pbar_runner: Optional[Callable] = None
655
- ):
656
- if max_steps == 1:
657
- return body_fun(val)
658
- else:
659
-
660
- def true_call(val_):
661
- return _bounded_while_loop(cond_fun, body_fun, val_, max_steps // base, base, pbar_runner)
662
-
663
- def false_call(val_):
664
- if pbar_runner is not None:
665
- pbar_runner(unvmap(val_[-1] + max_steps, op='none'))
666
- return val_[:-1] + (val_[-1] + max_steps,)
667
-
668
- def scan_fn(val_, _):
669
- return jax.lax.cond(unvmap(cond_fun(val_), op='any'), true_call, false_call, val_), None
670
-
671
- # Don't put checkpointing on the lowest level
672
- if max_steps != base:
673
- scan_fn = jax.checkpoint(scan_fn, prevent_cse=False) # pyright: ignore
674
-
675
- return jax.lax.scan(scan_fn, val, xs=None, length=base)[0]
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import math
17
+ from functools import wraps
18
+ from typing import Callable, Optional, TypeVar, Tuple, Any
19
+
20
+ import jax
21
+ import jax.numpy as jnp
22
+
23
+ from brainstate._utils import set_module_as
24
+ from ._make_jaxpr import StatefulFunction
25
+ from ._progress_bar import ProgressBar
26
+ from ._unvmap import unvmap
27
+ from ._util import wrap_single_fun
28
+
29
# Public API of this module.
__all__ = [
    # "scan" syntax, which is similar to jax.lax.scan
    'scan', 'checkpointed_scan',

    # "for_loop" syntax
    'for_loop', 'checkpointed_for_loop',
]

# Type variables used by the annotations of the functions below.
X = TypeVar('X')  # type of the scanned-over inputs ``xs``
Y = TypeVar('Y')  # type of the per-iteration output
T = TypeVar('T')  # not referenced by any annotation in this file
Carry = TypeVar('Carry')  # type of the loop-carried value
41
+
42
+
43
def _wrap_fun_with_pbar(
    fun: Callable[[Carry, X], Tuple[Carry, Y]],
    pbar_runner: Callable
):
    """
    Wrap a scan body so that it reports progress on every iteration.

    The returned function takes a carry of the form ``(i, carry)``, where
    ``i`` is the iteration counter. It calls ``fun`` on the inner carry,
    passes the counter together with the new carry and output to
    ``pbar_runner``, and returns ``((i + 1, new_carry), output)``.

    Parameters
    ----------
    fun : callable
        The original scan body of type ``(carry, x) -> (carry, y)``.
    pbar_runner : callable
        Called as ``pbar_runner(i, carry=new_carry, y=output)``.

    Returns
    -------
    callable
        The counter-carrying scan body.
    """

    def counted_step(indexed_carry, inputs):
        step, inner_carry = indexed_carry
        next_carry, outputs = fun(inner_carry, inputs)
        pbar_runner(unvmap(step, op='none'), carry=next_carry, y=outputs)
        return (step + 1, next_carry), outputs

    return wraps(fun)(counted_step)
55
+
56
+
57
@set_module_as('brainstate.transform')
def scan(
    f: Callable[[Carry, X], Tuple[Carry, Y]],
    init: Carry,
    xs: X,
    length: int | None = None,
    reverse: bool = False,
    unroll: int | bool = 1,
    pbar: ProgressBar | int | None = None,
) -> Tuple[Carry, Y]:
    """
    Scan a function over leading array axes while carrying along state.

    The `Haskell-like type signature`_ in brief is

    .. code-block:: haskell

        scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])

    where for any array type specifier ``t``, ``[t]`` represents the type with an additional
    leading axis, and if ``t`` is a pytree (container) type with array leaves then ``[t]``
    represents the type with the same pytree structure and corresponding leaves
    each with an additional leading axis.

    When the type of ``xs`` (denoted `a` above) is an array type or None, and the type
    of ``ys`` (denoted `b` above) is an array type, the semantics of :func:`~scan` are
    given roughly by this Python implementation:

    .. code-block:: python

        >>> def scan(f, init, xs, length=None):
        ...     if xs is None:
        ...         xs = [None] * length
        ...     carry = init
        ...     ys = []
        ...     for x in xs:
        ...         carry, y = f(carry, x)
        ...         ys.append(y)
        ...     return carry, np.stack(ys)

    Unlike that Python version, both ``xs`` and ``ys`` may be arbitrary pytree
    values, and so multiple arrays can be scanned over at once and produce multiple
    output arrays. ``None`` is actually a special case of this, as it represents an
    empty pytree.

    Also unlike that Python version, :func:`~scan` is a JAX primitive and is
    lowered to a single WhileOp. That makes it useful for reducing
    compilation times for JIT-compiled functions, since native Python
    loop constructs in an :func:`~jax.jit` function are unrolled, leading to large
    XLA computations.

    Finally, the loop-carried value ``carry`` must hold a fixed shape and dtype
    across all iterations (and not just be consistent up to NumPy rank/shape
    broadcasting and dtype promotion rules, for example). In other words, the type
    ``c`` in the type signature above represents an array with a fixed shape and
    dtype (or a nested tuple/list/dict container data structure with a fixed
    structure and arrays with fixed shape and dtype at the leaves).

    Parameters
    ----------
    f : callable
        A Python function to be scanned of type ``c -> a -> (c, b)``, meaning
        that ``f`` accepts two arguments where the first is a value of the loop
        carry and the second is a slice of ``xs`` along its leading axis, and that
        ``f`` returns a pair where the first element represents a new value for
        the loop carry and the second represents a slice of the output.
    init : Carry
        An initial loop carry value of type ``c``, which can be a scalar,
        array, or any pytree (nested Python tuple/list/dict) thereof, representing
        the initial loop carry value. This value must have the same structure as
        the first element of the pair returned by ``f``.
    xs : X
        The value of type ``[a]`` over which to scan along the leading axis,
        where ``[a]`` can be an array or any pytree (nested Python
        tuple/list/dict) thereof with consistent leading axis sizes.
    length : int, optional
        Optional integer specifying the number of loop iterations, which
        must agree with the sizes of leading axes of the arrays in ``xs`` (but can
        be used to perform scans where no input ``xs`` are needed).
    reverse : bool, default False
        Optional boolean specifying whether to run the scan iteration
        forward (the default) or in reverse, equivalent to reversing the leading
        axes of the arrays in both ``xs`` and in ``ys``.
    unroll : int or bool, default 1
        Optional positive int or bool specifying, in the underlying
        operation of the scan primitive, how many scan iterations to unroll within
        a single iteration of a loop. If an integer is provided, it determines how
        many unrolled loop iterations to run within a single rolled iteration of
        the loop. If a boolean is provided, it will determine if the loop is
        completely unrolled (i.e. `unroll=True`) or left completely rolled (i.e.
        `unroll=False`).
    pbar : ProgressBar or int, optional
        Optional :class:`~.ProgressBar` instance to display the progress
        of the scan operation. An integer is used as the ``freq`` argument of
        a newly created :class:`~.ProgressBar`.

    Returns
    -------
    tuple of (Carry, Y)
        A pair of type ``(c, [b])`` where the first element represents the final
        loop carry value and the second element represents the stacked outputs of
        the second output of ``f`` when scanned over the leading axis of the inputs.

    Raises
    ------
    TypeError
        If ``f`` is not callable, or ``pbar`` is neither a
        :class:`~.ProgressBar`, an integer, nor ``None``.
    ValueError
        If a leaf of ``xs`` has no ``shape`` attribute, if the leading axis
        sizes disagree with each other or with ``length``, if there is nothing
        to scan over and ``length`` is not given, or if a zero-length scan is
        requested while JIT is disabled.

    Examples
    --------
    Basic scan operation:

    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>>
        >>> def step_fn(carry, x):
        ...     return carry + x, carry * x
        >>>
        >>> init = 0.0
        >>> xs = jnp.array([1.0, 2.0, 3.0])
        >>> final_carry, ys = brainstate.transform.scan(step_fn, init, xs)

    Scan with progress bar:

    .. code-block:: python

        >>> pbar = brainstate.transform.ProgressBar(freq=10)
        >>> final_carry, ys = brainstate.transform.scan(step_fn, init, xs, pbar=pbar)

    References
    ----------
    .. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
    """
    # check "f"
    if not callable(f):
        raise TypeError("f argument should be a callable.")

    # check "xs"
    xs_flat, xs_tree = jax.tree.flatten(xs)
    try:
        lengths = [x.shape[0] for x in xs_flat]
    except AttributeError as err:
        raise ValueError("scan got value with no leading axis to scan over: "
                         "{}.".format(', '.join(str(x) for x in xs_flat if not hasattr(x, 'shape')))) from err
    if length is not None:
        length = int(length)
        if not all(length == l for l in lengths):
            raise ValueError(("scan got `length` argument of {} which disagrees with "
                              "leading axis sizes {}.").format(length, lengths))
    else:
        unique_lengths = set(lengths)
        if len(unique_lengths) > 1:
            msg = "scan got values with different leading axis sizes: {}."
            raise ValueError(msg.format(', '.join(str(l) for l in lengths)))
        elif len(unique_lengths) == 0:
            raise ValueError("scan got no values to scan over and `length` not provided.")
        else:
            length, = unique_lengths

    # function with progress bar
    has_pbar = pbar is not None
    if has_pbar:
        if isinstance(pbar, ProgressBar):
            pbar_runner = pbar.init(length)
        elif isinstance(pbar, int):
            pbar_runner = ProgressBar(freq=pbar).init(length)
        else:
            raise TypeError("pbar argument should be a ProgressBar instance or an integer.")
        f = _wrap_fun_with_pbar(f, pbar_runner)
        # The wrapped body always expects an ``(index, carry)`` pair, so the
        # carry must be wrapped under exactly the same condition as ``f``
        # (``pbar is not None``), not under the truthiness of ``pbar``.
        init = (0, init)

    # not jit
    if jax.config.jax_disable_jit:
        if length == 0:
            raise ValueError(
                "zero-length scan is not supported in disable_jit() mode because the output type is unknown.")
        carry = init
        ys = []
        maybe_reversed = reversed if reverse else lambda x: x
        for i in maybe_reversed(range(length)):
            xs_slice = [jax.lax.index_in_dim(x, i, keepdims=False) for x in xs_flat]
            carry, y = f(carry, jax.tree.unflatten(xs_tree, xs_slice))
            ys.append(y)
        stacked_y = jax.tree.map(lambda *elems: jnp.stack(elems), *maybe_reversed(ys))
        # drop the iteration counter added for the progress bar
        return (carry[1] if has_pbar else carry), stacked_y

    # evaluate jaxpr, get all states #
    # ------------------------------ #
    xs_avals = [jax.core.get_aval(x) for x in xs_flat]
    x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
    args = [init, xs_tree.unflatten(x_avals)]
    stateful_fun = StatefulFunction(f, name='scan').make_jaxpr(*args)
    state_trace = stateful_fun.get_state_trace(*args)
    all_writen_state_vals = state_trace.get_write_state_values(True)
    all_read_state_vals = state_trace.get_read_state_values(True)
    wrapped_f = wrap_single_fun(stateful_fun, state_trace.been_writen, all_read_state_vals)

    # scan: the written state values travel alongside the user carry
    init = (all_writen_state_vals, init)
    (all_writen_state_vals, carry), ys = jax.lax.scan(
        wrapped_f,
        init,
        xs,
        length=length,
        reverse=reverse,
        unroll=unroll
    )
    # assign the written state values and restore the read state values
    state_trace.assign_state_vals_v2(all_read_state_vals, all_writen_state_vals)
    # carry: drop the iteration counter added for the progress bar
    if has_pbar:
        carry = carry[1]
    return carry, ys
276
+
277
+
278
@set_module_as('brainstate.transform')
def checkpointed_scan(
    f: Callable[[Carry, X], Tuple[Carry, Y]],
    init: Carry,
    xs: X,
    length: Optional[int] = None,
    base: int = 16,
    pbar: Optional[ProgressBar | int] = None,
) -> Tuple[Carry, Y]:
    """
    Scan a function over leading array axes while carrying along state.
    This function is similar to :func:`~scan` but with a checkpointed version.

    Parameters
    ----------
    f : callable
        A Python function to be scanned of type ``c -> a -> (c, b)``, meaning
        that ``f`` accepts two arguments where the first is a value of the loop
        carry and the second is a slice of ``xs`` along its leading axis, and that
        ``f`` returns a pair where the first element represents a new value for
        the loop carry and the second represents a slice of the output.
    init : Carry
        An initial loop carry value of type ``c``, which can be a scalar,
        array, or any pytree (nested Python tuple/list/dict) thereof, representing
        the initial loop carry value. This value must have the same structure as
        the first element of the pair returned by ``f``.
    xs : X
        The value of type ``[a]`` over which to scan along the leading axis,
        where ``[a]`` can be an array or any pytree (nested Python
        tuple/list/dict) thereof with consistent leading axis sizes.
    length : int, optional
        Optional integer specifying the number of loop iterations, which
        must agree with the sizes of leading axes of the arrays in ``xs`` (but can
        be used to perform scans where no input ``xs`` are needed).
    base : int, default 16
        Optional integer specifying the base for the bounded scan loop. The
        number of iterations is rounded up to the next power of ``base`` and
        the loop is built as nested scans of length ``base``. Must be at
        least 2.
    pbar : ProgressBar or int, optional
        Optional :class:`~.ProgressBar` instance to display the progress
        of the scan operation. An integer is used as the ``freq`` argument of
        a newly created :class:`~.ProgressBar`.

    Returns
    -------
    tuple of (Carry, Y)
        A pair of type ``(c, [b])`` where the first element represents the final
        loop carry value and the second element represents the stacked outputs of
        the second output of ``f`` when scanned over the leading axis of the inputs.

    Raises
    ------
    TypeError
        If ``f`` is not callable.
    ValueError
        If ``base`` is smaller than 2, if a leaf of ``xs`` has no ``shape``
        attribute, if the leading axis sizes disagree with each other or with
        ``length``, if there is nothing to scan over and ``length`` is not
        given, or if the resulting number of iterations is zero.

    Examples
    --------
    Basic checkpointed scan operation:

    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>>
        >>> def step_fn(carry, x):
        ...     return carry + x, carry * x
        >>>
        >>> init = 0.0
        >>> xs = jnp.array([1.0, 2.0, 3.0])
        >>> final_carry, ys = brainstate.transform.checkpointed_scan(step_fn, init, xs)

    Using custom base for checkpointing:

    .. code-block:: python

        >>> final_carry, ys = brainstate.transform.checkpointed_scan(
        ...     step_fn, init, xs, base=8
        ... )
    """
    # check "f"
    if not callable(f):
        raise TypeError("f argument should be a callable.")

    # check "base": a base below 2 cannot be rounded up to a power of itself
    if base < 2:
        raise ValueError(f"checkpointed_scan requires `base` to be at least 2, but got {base}.")

    # check "xs"
    xs_flat, xs_tree = jax.tree.flatten(xs)
    try:
        lengths = [x.shape[0] for x in xs_flat]
    except AttributeError as err:
        raise ValueError("scan got value with no leading axis to scan over: "
                         "{}.".format(', '.join(str(x) for x in xs_flat if not hasattr(x, 'shape')))) from err
    if length is not None:
        length = int(length)
        if not all(length == l for l in lengths):
            raise ValueError(("scan got `length` argument of {} which disagrees with "
                              "leading axis sizes {}.").format(length, lengths))
    else:
        unique_lengths = set(lengths)
        if len(unique_lengths) > 1:
            msg = "scan got values with different leading axis sizes: {}."
            raise ValueError(msg.format(', '.join(str(l) for l in lengths)))
        elif len(unique_lengths) == 0:
            raise ValueError("scan got no values to scan over and `length` not provided.")
        else:
            length, = unique_lengths

    # A zero-length scan used to fail with an opaque "math domain error" when
    # the number of steps was rounded; report it explicitly (same exception type).
    if length == 0:
        raise ValueError("checkpointed_scan does not support zero-length scans.")

    # function with progress bar
    if isinstance(pbar, ProgressBar):
        pbar_runner = pbar.init(length)
    elif isinstance(pbar, int):
        pbar_runner = ProgressBar(freq=pbar).init(length)
    else:
        pbar_runner = None

    # evaluate jaxpr
    xs_avals = [jax.core.get_aval(x) for x in xs_flat]
    x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
    args = (init, xs_tree.unflatten(x_avals))
    stateful_fun = StatefulFunction(f, name='checkpoint_scan').make_jaxpr(*args)
    state_trace = stateful_fun.get_state_trace(*args)
    cache_key = stateful_fun.get_arg_cache_key(*args)
    # get all states
    been_written = state_trace.been_writen
    read_state_vals = state_trace.get_read_state_values(True)
    write_state_vals = state_trace.get_write_state_values(True)

    # initialize the buffers that collect the per-iteration outputs
    out_info = stateful_fun.get_out_shapes_by_cache(cache_key)[0]
    assert len(out_info) == 2, "function in checkpointed_scan should return two data: carray and out."
    data2collection = jax.tree.map(lambda x: jnp.zeros((length,) + x.shape, x.dtype), out_info[1])
    del out_info

    def wrapped_cond_fun(inp):
        # the step counter is the last element of the loop state
        return inp[-1] < length

    def wrapped_body_fun(inp):
        (prev_write_states, carry), prev_collect, i = inp
        # progress bar
        if pbar_runner is not None:
            pbar_runner(unvmap(i, op='none'))
        # call the function: written states come from the loop state,
        # read-only states from the values captured before the loop
        prev_states = [w_val if write else r_val
                       for write, w_val, r_val in zip(been_written, prev_write_states, read_state_vals)]
        new_states, (new_carry, out4updates) = stateful_fun.jaxpr_call(
            prev_states, carry, jax.tree.map(lambda x: x[i], xs)
        )
        # new written states
        new_write_states = tuple([val if write else None for write, val in zip(been_written, new_states)])

        # Steps at or beyond ``length`` must leave every part of the state untouched.
        pred = i < length
        # NOTE: this selects between whole buffers; the alternative form
        # ``x.at[i].set(jax.lax.select(pred, update, x[i]))`` selects only the written slice.
        new_collect = jax.tree.map(
            lambda x, update: jax.lax.select(pred, x.at[i].set(update), x),
            prev_collect,
            out4updates,
        )
        new_write_states = jax.tree.map(
            lambda ps, ns: None if ns is None else jax.lax.select(pred, ns, ps),
            prev_write_states,
            new_write_states,
            is_leaf=lambda x: x is None
        )
        new_carry = jax.tree.map(
            lambda pc, nc: jax.lax.select(pred, nc, pc),
            carry,
            new_carry,
        )
        return (new_write_states, new_carry), new_collect, i + 1

    # while_loop
    # Round ``length`` up to the next power of ``base`` with exact integer
    # arithmetic. ``math.log`` can land just above an exact integer
    # (e.g. ``math.log(125, 5)``), which would add a superfluous level.
    rounded_max_steps = 1
    while rounded_max_steps < length:
        rounded_max_steps *= base
    (write_state_vals, carry), data2collection, _ = (
        _bounded_while_loop(
            wrapped_cond_fun,
            wrapped_body_fun,
            ((write_state_vals, init), data2collection, 0),
            rounded_max_steps,
            base,
            pbar_runner
        )
    )
    # assign the written state values and restore the read state values
    state_trace.assign_state_vals_v2(read_state_vals, write_state_vals)
    del write_state_vals, read_state_vals, stateful_fun
    return carry, data2collection
455
+
456
+
457
+ def _forloop_to_scan_fun(f: Callable):
458
+ @wraps(f)
459
+ def scan_fun(carry, x):
460
+ return carry, f(*x)
461
+
462
+ return scan_fun
463
+
464
+
465
@set_module_as('brainstate.transform')
def for_loop(
    f: Callable[..., Y],
    *xs,
    length: Optional[int] = None,
    reverse: bool = False,
    unroll: int | bool = 1,
    pbar: Optional[ProgressBar | int] = None
) -> Y:
    """
    ``for-loop`` control flow with :py:class:`~.State`.

    This is a thin wrapper around :func:`~scan`: the loop carries ``None``
    and ``f`` is called with the current slices of ``xs`` as positional
    arguments.

    Parameters
    ----------
    f : callable
        A Python function to be looped over that accepts variadic arguments
        corresponding to slices of ``xs`` along their leading axes, and returns
        the output for that iteration.
    *xs
        The values over which to loop along the leading axis,
        where each can be an array or any pytree (nested Python
        tuple/list/dict) thereof with consistent leading axis sizes.
    length : int, optional
        Optional integer specifying the number of loop iterations, which
        must agree with the sizes of leading axes of the arrays in ``xs`` (but can
        be used to perform loops where no input ``xs`` are needed).
    reverse : bool, default False
        Optional boolean specifying whether to run the loop iteration
        forward (the default) or in reverse, equivalent to reversing the leading
        axes of the arrays in both ``xs`` and in ``ys``.
    unroll : int or bool, default 1
        Optional positive int or bool specifying, in the underlying
        operation of the scan primitive, how many loop iterations to unroll within
        a single iteration of a loop. If an integer is provided, it determines how
        many unrolled loop iterations to run within a single rolled iteration of
        the loop. If a boolean is provided, it will determine if the loop is
        completely unrolled (i.e. `unroll=True`) or left completely rolled (i.e.
        `unroll=False`).
    pbar : ProgressBar or int, optional
        Optional :class:`~.ProgressBar` instance to display the progress
        of the loop operation.

    Returns
    -------
    Y
        The stacked outputs of ``f`` when looped over the leading axis of the inputs.

    Examples
    --------
    Basic for-loop operation:

    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>>
        >>> def process_item(x, y):
        ...     return x * y + 1
        >>>
        >>> xs = jnp.array([1.0, 2.0, 3.0])
        >>> ys = jnp.array([4.0, 5.0, 6.0])
        >>> results = brainstate.transform.for_loop(process_item, xs, ys)

    For-loop with progress bar:

    .. code-block:: python

        >>> pbar = brainstate.transform.ProgressBar(freq=10)
        >>> results = brainstate.transform.for_loop(process_item, xs, ys, pbar=pbar)

    For-loop with reverse iteration:

    .. code-block:: python

        >>> results = brainstate.transform.for_loop(process_item, xs, ys, reverse=True)
    """
    loop_body = _forloop_to_scan_fun(f)
    scanned = scan(
        loop_body,
        init=None,
        xs=xs,
        length=length,
        reverse=reverse,
        unroll=unroll,
        pbar=pbar
    )
    # scan returns (final carry, stacked outputs); the carry is always None here
    return scanned[1]
551
+
552
+
553
@set_module_as('brainstate.transform')
def checkpointed_for_loop(
    f: Callable[..., Y],
    *xs: X,
    length: Optional[int] = None,
    base: int = 16,
    pbar: Optional[ProgressBar | int] = None,
) -> Y:
    """
    ``for-loop`` control flow with :py:class:`~.State` with a checkpointed version, similar to :py:func:`for_loop`.

    This is a thin wrapper around :func:`~checkpointed_scan`: the loop carries
    ``None`` and ``f`` is called with the current slices of ``xs`` as
    positional arguments.

    Parameters
    ----------
    f : callable
        A Python function to be looped over that accepts variadic arguments
        corresponding to slices of ``xs`` along their leading axes, and returns
        the output for that iteration.
    *xs : X
        The values over which to loop along the leading axis,
        where each can be an array or any pytree (nested Python
        tuple/list/dict) thereof with consistent leading axis sizes.
    length : int, optional
        Optional integer specifying the number of loop iterations, which
        must agree with the sizes of leading axes of the arrays in ``xs`` (but can
        be used to perform loops where no input ``xs`` are needed).
    base : int, default 16
        Optional integer specifying the base for the bounded loop.
    pbar : ProgressBar or int, optional
        Optional :class:`~.ProgressBar` instance to display the progress
        of the loop operation.

    Returns
    -------
    Y
        The stacked outputs of ``f`` when looped over the leading axis of the inputs.

    Examples
    --------
    Basic checkpointed for-loop operation:

    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>>
        >>> def process_item(x, y):
        ...     return x * y + 1
        >>>
        >>> xs = jnp.array([1.0, 2.0, 3.0])
        >>> ys = jnp.array([4.0, 5.0, 6.0])
        >>> results = brainstate.transform.checkpointed_for_loop(process_item, xs, ys)

    Using custom base for checkpointing:

    .. code-block:: python

        >>> results = brainstate.transform.checkpointed_for_loop(
        ...     process_item, xs, ys, base=8
        ... )
    """
    loop_body = _forloop_to_scan_fun(f)
    scanned = checkpointed_scan(
        loop_body,
        init=None,
        xs=xs,
        length=length,
        base=base,
        pbar=pbar
    )
    # checkpointed_scan returns (final carry, stacked outputs); only the outputs are of interest
    return scanned[1]
622
+
623
+
624
+ # This function is adapted from ``while_loop`` in
625
+ # `equinox <https://github.com/patrick-kidger/equinox/blob/main/equinox/internal/_loop/loop.py>`_.
626
+
627
+ # There are several tricks happening here to work around various limitations of JAX.
628
+ # (Also see https://github.com/google/jax/issues/2139#issuecomment-1039293633)
629
+ # 1. `unvmap_any` prior to using `lax.cond`. JAX has a problem in that vmap-of-cond
630
+ # is converted to a `lax.select`, which executes both branches unconditionally.
631
+ # Thus writing this naively, using a plain `lax.cond`, will mean the loop always
632
+ # runs to `max_steps` when executing under vmap. Instead we run (only) until every
633
+ # batch element has finished.
634
+ # 2. Treating in-place updates specially in the body_fun. Specifically we need to
635
+ # `lax.select` the update-to-make, not the updated buffer. This is because the
636
+ # latter instead results in XLA:CPU failing to determine that the buffer can be
637
+ # updated in-place, and instead it makes a copy. c.f. JAX issue #8192.
638
+ # In equinox this is done through an extra `inplace` argument to `body_fun`; this adaptation has no such argument.
639
+ # 3. The use of the `@jax.checkpoint` decorator. Backpropagation through a
640
+ # `bounded_while_loop` will otherwise run in θ(max_steps) time, rather than
641
+ # θ(number of steps actually taken).
642
+ # 4. The use of `base`. In theory `base=2` is optimal at run time, as it implies the
643
+ # fewest superfluous operations. In practice this implies quite deep recursion in
644
+ # the construction of the bounded while loop, and this slows down the jaxpr
645
+ # creation and the XLA compilation. We choose `base=16` as a reasonable-looking
646
+ # compromise between compilation time and run time.
647
+
648
+ def _bounded_while_loop(
649
+ cond_fun: Callable,
650
+ body_fun: Callable,
651
+ val: Any,
652
+ max_steps: int,
653
+ base: int,
654
+ pbar_runner: Optional[Callable] = None
655
+ ):
656
+ if max_steps == 1:
657
+ return body_fun(val)
658
+ else:
659
+
660
+ def true_call(val_):
661
+ return _bounded_while_loop(cond_fun, body_fun, val_, max_steps // base, base, pbar_runner)
662
+
663
+ def false_call(val_):
664
+ if pbar_runner is not None:
665
+ pbar_runner(unvmap(val_[-1] + max_steps, op='none'))
666
+ return val_[:-1] + (val_[-1] + max_steps,)
667
+
668
+ def scan_fn(val_, _):
669
+ return jax.lax.cond(unvmap(cond_fun(val_), op='any'), true_call, false_call, val_), None
670
+
671
+ # Don't put checkpointing on the lowest level
672
+ if max_steps != base:
673
+ scan_fn = jax.checkpoint(scan_fn, prevent_cse=False) # pyright: ignore
674
+
675
+ return jax.lax.scan(scan_fn, val, xs=None, length=base)[0]