brainstate 0.0.2.post20241009__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175) hide show
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1360 -1318
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241009.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,516 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from __future__ import annotations
17
+
18
+ import math
19
+ from functools import wraps
20
+ from typing import Callable, Optional, TypeVar, Tuple, Any
21
+
22
+ import jax
23
+ import jax.numpy as jnp
24
+
25
+ from brainstate._utils import set_module_as
26
+ from ._make_jaxpr import StatefulFunction
27
+ from ._progress_bar import ProgressBar
28
+ from ._unvmap import unvmap
29
+ from ._util import write_back_state_values, wrap_single_fun
30
+
31
# Names exported from this module.
__all__ = [
    # "scan" syntax, which is similar to jax.lax.scan
    'scan', 'checkpointed_scan',
    # "for_loop" syntax
    'for_loop', 'checkpointed_for_loop',
]

# Generic type variables used in the signatures of the loop functions below.
X = TypeVar('X')  # type of the scanned-over inputs ``xs``
Y = TypeVar('Y')  # type of the per-step output
T = TypeVar('T')  # not referenced in the visible part of this file
Carry = TypeVar('Carry')  # type of the loop carry
42
+
43
+
44
def _wrap_fun_with_pbar(
    fun: Callable[[Carry, X], Tuple[Carry, Y]],
    pbar_runner: Callable
):
    """
    Wrap a scan body so that it reports its step index to a progress bar.

    The returned function expects a carry of the form ``(step, user_carry)``.
    It calls ``fun(user_carry, inputs)``, passes the current step index to
    ``pbar_runner`` and returns ``((step + 1, new_user_carry), outputs)``.

    Args:
      fun: the scan body of type ``c -> a -> (c, b)``.
      pbar_runner: callable invoked once per step with the step index.

    Returns:
      The wrapped scan body operating on the ``(step, user_carry)`` carry.
    """

    @wraps(fun)
    def new_fun(new_carry, inputs):
        step, user_carry = new_carry
        result = fun(user_carry, inputs)
        next_carry, outputs = result
        # The step index goes through ``unvmap`` before reaching the runner.
        pbar_runner(unvmap(step, op='none'))
        advanced = (step + 1, next_carry)
        return advanced, outputs

    return new_fun
56
+
57
+
58
@set_module_as('brainstate.compile')
def scan(
    f: Callable[[Carry, X], Tuple[Carry, Y]],
    init: Carry,
    xs: X,
    length: int | None = None,
    reverse: bool = False,
    unroll: int | bool = 1,
    pbar: ProgressBar | None = None,
) -> Tuple[Carry, Y]:
    """
    Scan a function over leading array axes while carrying along state.

    The `Haskell-like type signature`_ in brief is

    .. code-block:: haskell

      scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])

    where for any array type specifier ``t``, ``[t]`` represents the type with an additional
    leading axis, and if ``t`` is a pytree (container) type with array leaves then ``[t]``
    represents the type with the same pytree structure and corresponding leaves
    each with an additional leading axis.

    When the type of ``xs`` (denoted `a` above) is an array type or None, and the type
    of ``ys`` (denoted `b` above) is an array type, the semantics of :func:`~scan` are
    given roughly by this Python implementation::

      def scan(f, init, xs, length=None):
        if xs is None:
          xs = [None] * length
        carry = init
        ys = []
        for x in xs:
          carry, y = f(carry, x)
          ys.append(y)
        return carry, np.stack(ys)

    Unlike that Python version, both ``xs`` and ``ys`` may be arbitrary pytree
    values, and so multiple arrays can be scanned over at once and produce multiple
    output arrays. ``None`` is actually a special case of this, as it represents an
    empty pytree.

    Also unlike that Python version, :func:`~scan` is a JAX primitive and is
    lowered to a single WhileOp. That makes it useful for reducing
    compilation times for JIT-compiled functions, since native Python
    loop constructs in an :func:`~jax.jit` function are unrolled, leading to large
    XLA computations.

    Finally, the loop-carried value ``carry`` must hold a fixed shape and dtype
    across all iterations (and not just be consistent up to NumPy rank/shape
    broadcasting and dtype promotion rules, for example). In other words, the type
    ``c`` in the type signature above represents an array with a fixed shape and
    dtype (or a nested tuple/list/dict container data structure with a fixed
    structure and arrays with fixed shape and dtype at the leaves).

    Args:
      f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
        that ``f`` accepts two arguments where the first is a value of the loop
        carry and the second is a slice of ``xs`` along its leading axis, and that
        ``f`` returns a pair where the first element represents a new value for
        the loop carry and the second represents a slice of the output.
      init: an initial loop carry value of type ``c``, which can be a scalar,
        array, or any pytree (nested Python tuple/list/dict) thereof, representing
        the initial loop carry value. This value must have the same structure as
        the first element of the pair returned by ``f``.
      xs: the value of type ``[a]`` over which to scan along the leading axis,
        where ``[a]`` can be an array or any pytree (nested Python
        tuple/list/dict) thereof with consistent leading axis sizes.
      length: optional integer specifying the number of loop iterations, which
        must agree with the sizes of leading axes of the arrays in ``xs`` (but can
        be used to perform scans where no input ``xs`` are needed).
      reverse: optional boolean specifying whether to run the scan iteration
        forward (the default) or in reverse, equivalent to reversing the leading
        axes of the arrays in both ``xs`` and in ``ys``.
      unroll: optional positive int or bool specifying, in the underlying
        operation of the scan primitive, how many scan iterations to unroll within
        a single iteration of a loop. If an integer is provided, it determines how
        many unrolled loop iterations to run within a single rolled iteration of
        the loop. If a boolean is provided, it will determine if the loop is
        completely unrolled (i.e. `unroll=True`) or left completely rolled (i.e.
        `unroll=False`).
      pbar: optional :class:`~.ProgressBar` instance to display the progress
        of the scan operation.

    Returns:
      A pair of type ``(c, [b])`` where the first element represents the final
      loop carry value and the second element represents the stacked outputs of
      the second output of ``f`` when scanned over the leading axis of the inputs.

    Raises:
      TypeError: if ``f`` is not callable.
      ValueError: if a leaf of ``xs`` has no ``shape`` attribute, if the leading
        axis sizes disagree with each other or with ``length``, if neither
        ``xs`` nor ``length`` determines the number of iterations, or if a
        zero-length scan is requested while JIT is disabled.

    .. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
    """
    # check "f"
    if not callable(f):
        raise TypeError("f argument should be a callable.")

    # check "xs": every leaf must expose a leading axis, all of the same size
    xs_flat, xs_tree = jax.tree.flatten(xs)
    try:
        lengths = [x.shape[0] for x in xs_flat]
    except AttributeError as err:
        raise ValueError("scan got value with no leading axis to scan over: "
                         "{}.".format(', '.join(str(x) for x in xs_flat if not hasattr(x, 'shape')))) from err
    if length is not None:
        length = int(length)
        if not all(length == l for l in lengths):
            raise ValueError(("scan got `length` argument of {} which disagrees with "
                              "leading axis sizes {}.").format(length, lengths))
    else:
        unique_lengths = set(lengths)
        if len(unique_lengths) > 1:
            msg = "scan got values with different leading axis sizes: {}."
            raise ValueError(msg.format(', '.join(str(n) for n in lengths)))
        elif len(unique_lengths) == 0:
            raise ValueError("scan got no values to scan over and `length` not provided.")
        else:
            length, = unique_lengths

    # function with progress bar
    # A single ``is not None`` test decides both whether ``f`` is wrapped and
    # whether the carry gets the extra step counter, so the two always agree.
    has_pbar = pbar is not None
    if has_pbar:
        f = _wrap_fun_with_pbar(f, pbar.init(length))
        init = (0, init)

    # not jit
    if jax.config.jax_disable_jit:
        if length == 0:
            raise ValueError(
                "zero-length scan is not supported in disable_jit() mode because the output type is unknown.")
        carry = init
        ys = []
        maybe_reversed = reversed if reverse else lambda x: x
        for i in maybe_reversed(range(length)):
            xs_slice = [jax.lax.index_in_dim(x, i, keepdims=False) for x in xs_flat]
            carry, y = f(carry, jax.tree.unflatten(xs_tree, xs_slice))
            ys.append(y)
        stacked_y = jax.tree.map(lambda *elems: jnp.stack(elems), *maybe_reversed(ys))
        if has_pbar:
            # drop the step counter added for the progress bar
            return carry[1], stacked_y
        else:
            return carry, stacked_y

    # evaluate jaxpr, get all states #
    # ------------------------------ #
    xs_avals = [jax.core.raise_to_shaped(jax.core.get_aval(x)) for x in xs_flat]
    x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
    with jax.ensure_compile_time_eval():
        stateful_fun = StatefulFunction(f).make_jaxpr(init, xs_tree.unflatten(x_avals))
    state_trace = stateful_fun.get_state_trace()
    all_writen_state_vals = state_trace.get_write_state_values(True)
    all_read_state_vals = state_trace.get_read_state_values(True)
    wrapped_f = wrap_single_fun(stateful_fun, state_trace.been_writen, all_read_state_vals)

    # scan: the written state values travel through the loop as part of the carry
    init = (all_writen_state_vals, init)
    (all_writen_state_vals, carry), ys = jax.lax.scan(wrapped_f, init, xs, length=length, reverse=reverse,
                                                      unroll=unroll)
    # assign the written state values and restore the read state values
    write_back_state_values(state_trace, all_read_state_vals, all_writen_state_vals)
    # carry
    if has_pbar:
        # drop the step counter added for the progress bar
        carry = carry[1]
    return carry, ys
222
+
223
+
224
def checkpointed_scan(
    f: Callable[[Carry, X], Tuple[Carry, Y]],
    init: Carry,
    xs: X,
    length: Optional[int] = None,
    base: int = 16,
    pbar: Optional[ProgressBar] = None,
) -> Tuple[Carry, Y]:
    """
    Scan a function over leading array axes while carrying along state.
    This function is similar to :func:`~scan` but with a checkpointed version.

    Args:
      f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
        that ``f`` accepts two arguments where the first is a value of the loop
        carry and the second is a slice of ``xs`` along its leading axis, and that
        ``f`` returns a pair where the first element represents a new value for
        the loop carry and the second represents a slice of the output.
      init: an initial loop carry value of type ``c``, which can be a scalar,
        array, or any pytree (nested Python tuple/list/dict) thereof, representing
        the initial loop carry value. This value must have the same structure as
        the first element of the pair returned by ``f``.
      xs: the value of type ``[a]`` over which to scan along the leading axis,
        where ``[a]`` can be an array or any pytree (nested Python
        tuple/list/dict) thereof with consistent leading axis sizes.
      length: optional integer specifying the number of loop iterations, which
        must agree with the sizes of leading axes of the arrays in ``xs`` (but can
        be used to perform scans where no input ``xs`` are needed).
      base: optional integer (at least 2) specifying the base for the bounded
        scan loop.
      pbar: optional :class:`~.ProgressBar` instance to display the progress
        of the scan operation.

    Returns:
      A pair of type ``(c, [b])`` where the first element represents the final
      loop carry value and the second element represents the stacked outputs of
      the second output of ``f`` when scanned over the leading axis of the inputs.

    Raises:
      TypeError: if ``f`` is not callable.
      ValueError: if a leaf of ``xs`` has no ``shape`` attribute, if the leading
        axis sizes disagree with each other or with ``length``, if neither
        ``xs`` nor ``length`` determines the number of iterations, if ``base``
        is smaller than 2, or if the number of iterations is smaller than 1.
    """
    # check "f"
    if not callable(f):
        raise TypeError("f argument should be a callable.")

    # check "xs"
    xs_flat, xs_tree = jax.tree.flatten(xs)
    try:
        lengths = [x.shape[0] for x in xs_flat]
    except AttributeError as err:
        raise ValueError("scan got value with no leading axis to scan over: "
                         "{}.".format(', '.join(str(x) for x in xs_flat if not hasattr(x, 'shape')))) from err
    if length is not None:
        length = int(length)
        if not all(length == l for l in lengths):
            raise ValueError(("scan got `length` argument of {} which disagrees with "
                              "leading axis sizes {}.").format(length, lengths))
    else:
        unique_lengths = set(lengths)
        if len(unique_lengths) > 1:
            msg = "scan got values with different leading axis sizes: {}."
            raise ValueError(msg.format(', '.join(str(n) for n in lengths)))
        elif len(unique_lengths) == 0:
            raise ValueError("scan got no values to scan over and `length` not provided.")
        else:
            length, = unique_lengths

    # check "base" and "length": the bounded loop below needs a step budget
    # that is a power of ``base`` and at least one iteration to run.
    if base < 2:
        raise ValueError(f"checkpointed_scan requires `base` to be at least 2, but got {base}.")
    if length < 1:
        raise ValueError(f"checkpointed_scan requires at least one iteration, but got length {length}.")

    # function with progress bar
    if pbar is not None:
        pbar_runner = pbar.init(length)
    else:
        pbar_runner = None

    # evaluate jaxpr
    xs_avals = [jax.core.raise_to_shaped(jax.core.get_aval(x)) for x in xs_flat]
    x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
    stateful_fun = StatefulFunction(f).make_jaxpr(init, xs_tree.unflatten(x_avals))
    state_trace = stateful_fun.get_state_trace()
    # get all states
    been_written = state_trace.been_writen
    read_state_vals = state_trace.get_read_state_values(True)
    write_state_vals = state_trace.get_write_state_values(True)

    # initialize the collected values/data
    out_info = stateful_fun.get_out_shapes()[0]
    assert len(out_info) == 2, "function in checkpointed_scan should return two data: carray and out."
    data2collection = jax.tree.map(lambda x: jnp.zeros((length,) + x.shape, x.dtype), out_info[1])
    del out_info

    def wrapped_cond_fun(inp):
        # the last element of the loop value is the step counter
        return inp[-1] < length

    def wrapped_body_fun(inp):
        (prev_write_states, carray), prev_collect, i = inp
        # progress bar
        if pbar_runner is not None:
            pbar_runner(unvmap(i, op='none'))
        # call the function
        prev_states = [w_val if write else r_val
                       for write, w_val, r_val in zip(been_written, prev_write_states, read_state_vals)]
        new_states, (new_carray, out4updates) = stateful_fun.jaxpr_call(
            prev_states, carray, jax.tree.map(lambda x: x[i], xs)
        )
        # new written states
        new_write_states = tuple([val if write else None for write, val in zip(been_written, new_states)])

        # out of length: once ``i`` has passed ``length`` every value is kept unchanged
        pred = i < length
        new_collect = jax.tree.map(
            # lambda x, update: x.at[i].set(jax.lax.select(pred, update, x[i])),
            lambda x, update: jax.lax.select(pred, x.at[i].set(update), x),
            prev_collect,
            out4updates,
        )
        new_write_states = jax.tree.map(
            lambda ps, ns: None if ns is None else jax.lax.select(pred, ns, ps),
            prev_write_states,
            new_write_states,
            is_leaf=lambda x: x is None
        )
        new_carray = jax.tree.map(
            lambda pc, nc: jax.lax.select(pred, nc, pc),
            carray,
            new_carray,
        )
        return (new_write_states, new_carray), new_collect, i + 1

    # while_loop
    # Smallest power of ``base`` that is >= ``length``, computed with exact
    # integer arithmetic. ``math.log(length, base)`` can land just above an
    # integer (e.g. ``math.log(125, 5) == 3.0000000000000004``), which would
    # inflate the step budget by a full factor of ``base``.
    rounded_max_steps = 1
    while rounded_max_steps < length:
        rounded_max_steps *= base
    (write_state_vals, carry), data2collection, _ = (
        _bounded_while_loop(
            wrapped_cond_fun,
            wrapped_body_fun,
            ((write_state_vals, init), data2collection, 0),
            rounded_max_steps,
            base,
            pbar_runner
        )
    )
    # assign the written state values and restore the read state values
    write_back_state_values(state_trace, read_state_vals, write_state_vals)
    del write_state_vals, read_state_vals, stateful_fun
    return carry, data2collection
363
+
364
+
365
+ def _forloop_to_scan_fun(f: Callable):
366
+ @wraps(f)
367
+ def scan_fun(carry, x):
368
+ return carry, f(*x)
369
+
370
+ return scan_fun
371
+
372
+
373
@set_module_as('brainstate.compile')
def for_loop(
    f: Callable[..., Y],
    *xs,
    length: Optional[int] = None,
    reverse: bool = False,
    unroll: int | bool = 1,
    pbar: Optional[ProgressBar] = None
) -> Y:
    """
    ``for-loop`` control flow with :py:class:`~.State`.

    This is a thin wrapper around :func:`~scan`: it scans over ``xs`` with a
    ``None`` carry and returns only the stacked outputs.

    Args:
      f: a Python function called once per iteration as ``f(*x)``, where ``x``
        holds the slices of the positional ``xs`` along their leading axes. Its
        return value is the output of that iteration.
      xs: the values over which to loop along the leading axis. Each one can be
        an array or any pytree (nested Python tuple/list/dict) thereof with
        consistent leading axis sizes.
      length: optional integer specifying the number of loop iterations, which
        must agree with the sizes of leading axes of the arrays in ``xs`` (but can
        be used to perform loops where no input ``xs`` are needed).
      reverse: optional boolean specifying whether to run the loop iteration
        forward (the default) or in reverse, equivalent to reversing the leading
        axes of the arrays in both ``xs`` and in ``ys``.
      unroll: optional positive int or bool forwarded to :func:`~scan`; see its
        documentation for how loop iterations are unrolled.
      pbar: optional :class:`~.ProgressBar` instance to display the progress
        of the loop.

    Returns:
      The stacked outputs of ``f`` over the leading axis of the inputs.
    """
    scan_body = _forloop_to_scan_fun(f)
    # the carry is unused by a for-loop, so only the stacked outputs are returned
    result = scan(
        scan_body,
        init=None,
        xs=xs,
        length=length,
        reverse=reverse,
        unroll=unroll,
        pbar=pbar
    )
    return result[1]
425
+
426
+
427
def checkpointed_for_loop(
    f: Callable[..., Y],
    *xs: X,
    length: Optional[int] = None,
    base: int = 16,
    pbar: Optional[ProgressBar] = None,
) -> Y:
    """
    ``for-loop`` control flow with :py:class:`~.State` with a checkpointed version, similar to :py:func:`for_loop`.

    This is a thin wrapper around :func:`~checkpointed_scan`: it scans over
    ``xs`` with a ``None`` carry and returns only the collected outputs.

    Args:
      f: a Python function called once per iteration as ``f(*x)``, where ``x``
        holds the slices of the positional ``xs`` along their leading axes. Its
        return value is the output of that iteration.
      xs: the values over which to loop along the leading axis. Each one can be
        an array or any pytree (nested Python tuple/list/dict) thereof with
        consistent leading axis sizes.
      length: optional integer specifying the number of loop iterations, which
        must agree with the sizes of leading axes of the arrays in ``xs`` (but can
        be used to perform loops where no input ``xs`` are needed).
      base: optional integer specifying the base for the bounded scan loop.
      pbar: optional :class:`~.ProgressBar` instance to display the progress
        of the loop.

    Returns:
      The stacked outputs of ``f`` over the leading axis of the inputs.
    """
    scan_body = _forloop_to_scan_fun(f)
    # the carry is unused by a for-loop, so only the collected outputs are returned
    result = checkpointed_scan(
        scan_body,
        init=None,
        xs=xs,
        length=length,
        base=base,
        pbar=pbar
    )
    return result[1]
466
+
467
+
468
+ # There are several tricks happening here to work around various limitations of JAX.
469
+ # (Also see https://github.com/google/jax/issues/2139#issuecomment-1039293633)
470
+ # 1. `unvmap_any` prior to using `lax.cond`. JAX has a problem in that vmap-of-cond
471
+ # is converted to a `lax.select`, which executes both branches unconditionally.
472
+ # Thus writing this naively, using a plain `lax.cond`, will mean the loop always
473
+ # runs to `max_steps` when executing under vmap. Instead we run (only) until every
474
+ # batch element has finished.
475
+ # 2. Treating in-place updates specially in the body_fun. Specifically we need to
476
+ # `lax.select` the update-to-make, not the updated buffer. This is because the
477
+ # latter instead results in XLA:CPU failing to determine that the buffer can be
478
+ # updated in-place, and instead it makes a copy. c.f. JAX issue #8192.
479
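+ # This is done through the extra `inplace` argument provided to `body_fun` (NOTE(review): `body_fun` is called with a single argument in this file, so this trick may not apply here; confirm).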
480
+ # 3. The use of the `@jax.checkpoint` decorator. Backpropagation through a
481
+ # `bounded_while_loop` will otherwise run in θ(max_steps) time, rather than
482
+ # θ(number of steps actually taken).
483
+ # 4. The use of `base`. In theory `base=2` is optimal at run time, as it implies the
484
+ # fewest superfluous operations. In practice this implies quite deep recursion in
485
+ # the construction of the bounded while loop, and this slows down the jaxpr
486
+ # creation and the XLA compilation. We choose `base=16` as a reasonable-looking
487
+ # compromise between compilation time and run time.
488
+
489
+ def _bounded_while_loop(
490
+ cond_fun: Callable,
491
+ body_fun: Callable,
492
+ val: Any,
493
+ max_steps: int,
494
+ base: int,
495
+ pbar_runner: Optional[Callable] = None
496
+ ):
497
+ if max_steps == 1:
498
+ return body_fun(val)
499
+ else:
500
+
501
+ def true_call(val_):
502
+ return _bounded_while_loop(cond_fun, body_fun, val_, max_steps // base, base, pbar_runner)
503
+
504
+ def false_call(val_):
505
+ if pbar_runner is not None:
506
+ pbar_runner(unvmap(val_[-1] + max_steps, op='none'))
507
+ return val_[:-1] + (val_[-1] + max_steps,)
508
+
509
+ def scan_fn(val_, _):
510
+ return jax.lax.cond(unvmap(cond_fun(val_), op='any'), true_call, false_call, val_), None
511
+
512
+ # Don't put checkpointing on the lowest level
513
+ if max_steps != base:
514
+ scan_fn = jax.checkpoint(scan_fn, prevent_cse=False) # pyright: ignore
515
+
516
+ return jax.lax.scan(scan_fn, val, xs=None, length=base)[0]
@@ -0,0 +1,59 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from __future__ import annotations
17
+ import unittest
18
+
19
+ import jax.numpy as jnp
20
+ import numpy as np
21
+
22
+ import brainstate as bst
23
+
24
+
25
class TestForLoop(unittest.TestCase):
    """Tests that the for-loop helpers accumulate state and collect per-step outputs."""

    def test_for_loop(self):
        """``for_loop`` updates the state on every step and stacks the results."""
        counter = bst.ShortTermState(0.)
        offset = bst.ShortTermState(0.)

        def step(i):
            # ``offset`` is only read, ``counter`` is read and written
            counter.value += (1 + offset.value)
            return counter.value

        n_iter = 10
        ops = np.arange(n_iter)
        outputs = bst.compile.for_loop(step, ops)

        print(counter)
        print(offset)
        self.assertTrue(counter.value == n_iter)
        self.assertTrue(jnp.allclose(outputs, ops + 1))

    def test_checkpointed_for_loop(self):
        """``checkpointed_for_loop`` gives the same result when the length is not a power of the base."""
        counter = bst.ShortTermState(0.)
        offset = bst.ShortTermState(0.)

        def step(i):
            # ``offset`` is only read, ``counter`` is read and written
            counter.value += (1 + offset.value)
            return counter.value

        n_iter = 18
        ops = jnp.arange(n_iter)
        progress = bst.compile.ProgressBar()
        outputs = bst.compile.checkpointed_for_loop(step, ops, base=2, pbar=progress)

        print(counter)
        print(offset)
        print(outputs)
        self.assertTrue(counter.value == n_iter)
        self.assertTrue(jnp.allclose(outputs, ops + 1))