brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -24,7 +24,7 @@ from brainstate._utils import set_module_as
24
24
  from ._make_jaxpr import StatefulFunction
25
25
  from ._progress_bar import ProgressBar
26
26
  from ._unvmap import unvmap
27
- from ._util import write_back_state_values, wrap_single_fun
27
+ from ._util import wrap_single_fun
28
28
 
29
29
  __all__ = [
30
30
  # "scan" syntax, which is similar to jax.lax.scan
@@ -54,7 +54,7 @@ def _wrap_fun_with_pbar(
54
54
  return new_fun
55
55
 
56
56
 
57
- @set_module_as('brainstate.compile')
57
+ @set_module_as('brainstate.transform')
58
58
  def scan(
59
59
  f: Callable[[Carry, X], Tuple[Carry, Y]],
60
60
  init: Carry,
@@ -80,17 +80,19 @@ def scan(
80
80
 
81
81
  When the type of ``xs`` (denoted `a` above) is an array type or None, and the type
82
82
  of ``ys`` (denoted `b` above) is an array type, the semantics of :func:`~scan` are
83
- given roughly by this Python implementation::
83
+ given roughly by this Python implementation:
84
84
 
85
- def scan(f, init, xs, length=None):
86
- if xs is None:
87
- xs = [None] * length
88
- carry = init
89
- ys = []
90
- for x in xs:
91
- carry, y = f(carry, x)
92
- ys.append(y)
93
- return carry, np.stack(ys)
85
+ .. code-block:: python
86
+
87
+ >>> def scan(f, init, xs, length=None):
88
+ ... if xs is None:
89
+ ... xs = [None] * length
90
+ ... carry = init
91
+ ... ys = []
92
+ ... for x in xs:
93
+ ... carry, y = f(carry, x)
94
+ ... ys.append(y)
95
+ ... return carry, np.stack(ys)
94
96
 
95
97
  Unlike that Python version, both ``xs`` and ``ys`` may be arbitrary pytree
96
98
  values, and so multiple arrays can be scanned over at once and produce multiple
@@ -110,40 +112,75 @@ def scan(
110
112
  dtype (or a nested tuple/list/dict container data structure with a fixed
111
113
  structure and arrays with fixed shape and dtype at the leaves).
112
114
 
113
- Args:
114
- f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
115
+ Parameters
116
+ ----------
117
+ f : callable
118
+ A Python function to be scanned of type ``c -> a -> (c, b)``, meaning
115
119
  that ``f`` accepts two arguments where the first is a value of the loop
116
120
  carry and the second is a slice of ``xs`` along its leading axis, and that
117
121
  ``f`` returns a pair where the first element represents a new value for
118
122
  the loop carry and the second represents a slice of the output.
119
- init: an initial loop carry value of type ``c``, which can be a scalar,
123
+ init : Carry
124
+ An initial loop carry value of type ``c``, which can be a scalar,
120
125
  array, or any pytree (nested Python tuple/list/dict) thereof, representing
121
126
  the initial loop carry value. This value must have the same structure as
122
127
  the first element of the pair returned by ``f``.
123
- xs: the value of type ``[a]`` over which to scan along the leading axis,
128
+ xs : X
129
+ The value of type ``[a]`` over which to scan along the leading axis,
124
130
  where ``[a]`` can be an array or any pytree (nested Python
125
131
  tuple/list/dict) thereof with consistent leading axis sizes.
126
- length: optional integer specifying the number of loop iterations, which
132
+ length : int, optional
133
+ Optional integer specifying the number of loop iterations, which
127
134
  must agree with the sizes of leading axes of the arrays in ``xs`` (but can
128
135
  be used to perform scans where no input ``xs`` are needed).
129
- reverse: optional boolean specifying whether to run the scan iteration
136
+ reverse : bool, default False
137
+ Optional boolean specifying whether to run the scan iteration
130
138
  forward (the default) or in reverse, equivalent to reversing the leading
131
139
  axes of the arrays in both ``xs`` and in ``ys``.
132
- unroll: optional positive int or bool specifying, in the underlying
140
+ unroll : int or bool, default 1
141
+ Optional positive int or bool specifying, in the underlying
133
142
  operation of the scan primitive, how many scan iterations to unroll within
134
143
  a single iteration of a loop. If an integer is provided, it determines how
135
144
  many unrolled loop iterations to run within a single rolled iteration of
136
145
  the loop. If a boolean is provided, it will determine if the loop is
137
146
  completely unrolled (i.e. `unroll=True`) or left completely unrolled (i.e.
138
147
  `unroll=False`).
139
- pbar: optional :class:`~.ProgressBar` instance to display the progress
148
+ pbar : ProgressBar or int, optional
149
+ Optional :class:`~.ProgressBar` instance to display the progress
140
150
  of the scan operation.
141
151
 
142
- Returns:
143
- A pair of type ``(c, [b])`` where the first element represents the final
144
- loop carry value and the second element represents the stacked outputs of
145
- the second output of ``f`` when scanned over the leading axis of the inputs.
152
+ Returns
153
+ -------
154
+ tuple of (Carry, Y)
155
+ A pair of type ``(c, [b])`` where the first element represents the final
156
+ loop carry value and the second element represents the stacked outputs of
157
+ the second output of ``f`` when scanned over the leading axis of the inputs.
158
+
159
+ Examples
160
+ --------
161
+ Basic scan operation:
162
+
163
+ .. code-block:: python
146
164
 
165
+ >>> import brainstate
166
+ >>> import jax.numpy as jnp
167
+ >>>
168
+ >>> def step_fn(carry, x):
169
+ ... return carry + x, carry * x
170
+ >>>
171
+ >>> init = 0.0
172
+ >>> xs = jnp.array([1.0, 2.0, 3.0])
173
+ >>> final_carry, ys = brainstate.transform.scan(step_fn, init, xs)
174
+
175
+ Scan with progress bar:
176
+
177
+ .. code-block:: python
178
+
179
+ >>> pbar = brainstate.transform.ProgressBar(freq=10)
180
+ >>> final_carry, ys = brainstate.transform.scan(step_fn, init, xs, pbar=pbar)
181
+
182
+ References
183
+ ----------
147
184
  .. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
148
185
  """
149
186
  # check "f"
@@ -207,8 +244,9 @@ def scan(
207
244
  # ------------------------------ #
208
245
  xs_avals = [jax.core.get_aval(x) for x in xs_flat]
209
246
  x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
210
- stateful_fun = StatefulFunction(f, name='scan').make_jaxpr(init, xs_tree.unflatten(x_avals))
211
- state_trace = stateful_fun.get_state_trace()
247
+ args = [init, xs_tree.unflatten(x_avals)]
248
+ stateful_fun = StatefulFunction(f, name='scan').make_jaxpr(*args)
249
+ state_trace = stateful_fun.get_state_trace(*args)
212
250
  all_writen_state_vals = state_trace.get_write_state_values(True)
213
251
  all_read_state_vals = state_trace.get_read_state_values(True)
214
252
  wrapped_f = wrap_single_fun(stateful_fun, state_trace.been_writen, all_read_state_vals)
@@ -230,13 +268,14 @@ def scan(
230
268
  unroll=unroll
231
269
  )
232
270
  # assign the written state values and restore the read state values
233
- write_back_state_values(state_trace, all_read_state_vals, all_writen_state_vals)
271
+ state_trace.assign_state_vals_v2(all_read_state_vals, all_writen_state_vals)
234
272
  # carry
235
273
  if has_pbar:
236
274
  carry = carry[1]
237
275
  return carry, ys
238
276
 
239
277
 
278
+ @set_module_as('brainstate.transform')
240
279
  def checkpointed_scan(
241
280
  f: Callable[[Carry, X], Tuple[Carry, Y]],
242
281
  init: Carry,
@@ -249,30 +288,63 @@ def checkpointed_scan(
249
288
  Scan a function over leading array axes while carrying along state.
250
289
  This function is similar to :func:`~scan` but with a checkpointed version.
251
290
 
252
- Args:
253
- f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
291
+ Parameters
292
+ ----------
293
+ f : callable
294
+ A Python function to be scanned of type ``c -> a -> (c, b)``, meaning
254
295
  that ``f`` accepts two arguments where the first is a value of the loop
255
296
  carry and the second is a slice of ``xs`` along its leading axis, and that
256
297
  ``f`` returns a pair where the first element represents a new value for
257
298
  the loop carry and the second represents a slice of the output.
258
- init: an initial loop carry value of type ``c``, which can be a scalar,
299
+ init : Carry
300
+ An initial loop carry value of type ``c``, which can be a scalar,
259
301
  array, or any pytree (nested Python tuple/list/dict) thereof, representing
260
302
  the initial loop carry value. This value must have the same structure as
261
303
  the first element of the pair returned by ``f``.
262
- xs: the value of type ``[a]`` over which to scan along the leading axis,
304
+ xs : X
305
+ The value of type ``[a]`` over which to scan along the leading axis,
263
306
  where ``[a]`` can be an array or any pytree (nested Python
264
307
  tuple/list/dict) thereof with consistent leading axis sizes.
265
- length: optional integer specifying the number of loop iterations, which
308
+ length : int, optional
309
+ Optional integer specifying the number of loop iterations, which
266
310
  must agree with the sizes of leading axes of the arrays in ``xs`` (but can
267
311
  be used to perform scans where no input ``xs`` are needed).
268
- base: optional integer specifying the base for the bounded scan loop.
269
- pbar: optional :class:`~.ProgressBar` instance to display the progress
312
+ base : int, default 16
313
+ Optional integer specifying the base for the bounded scan loop.
314
+ pbar : ProgressBar or int, optional
315
+ Optional :class:`~.ProgressBar` instance to display the progress
270
316
  of the scan operation.
271
317
 
272
- Returns:
273
- A pair of type ``(c, [b])`` where the first element represents the final
274
- loop carry value and the second element represents the stacked outputs of
275
- the second output of ``f`` when scanned over the leading axis of the inputs.
318
+ Returns
319
+ -------
320
+ tuple of (Carry, Y)
321
+ A pair of type ``(c, [b])`` where the first element represents the final
322
+ loop carry value and the second element represents the stacked outputs of
323
+ the second output of ``f`` when scanned over the leading axis of the inputs.
324
+
325
+ Examples
326
+ --------
327
+ Basic checkpointed scan operation:
328
+
329
+ .. code-block:: python
330
+
331
+ >>> import brainstate
332
+ >>> import jax.numpy as jnp
333
+ >>>
334
+ >>> def step_fn(carry, x):
335
+ ... return carry + x, carry * x
336
+ >>>
337
+ >>> init = 0.0
338
+ >>> xs = jnp.array([1.0, 2.0, 3.0])
339
+ >>> final_carry, ys = brainstate.transform.checkpointed_scan(step_fn, init, xs)
340
+
341
+ Using custom base for checkpointing:
342
+
343
+ .. code-block:: python
344
+
345
+ >>> final_carry, ys = brainstate.transform.checkpointed_scan(
346
+ ... step_fn, init, xs, base=8
347
+ ... )
276
348
  """
277
349
  # check "f"
278
350
  if not callable(f):
@@ -311,15 +383,17 @@ def checkpointed_scan(
311
383
  # evaluate jaxpr
312
384
  xs_avals = [jax.core.get_aval(x) for x in xs_flat]
313
385
  x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
314
- stateful_fun = StatefulFunction(f, name='checkpoint_scan').make_jaxpr(init, xs_tree.unflatten(x_avals))
315
- state_trace = stateful_fun.get_state_trace()
386
+ args = (init, xs_tree.unflatten(x_avals))
387
+ stateful_fun = StatefulFunction(f, name='checkpoint_scan').make_jaxpr(*args)
388
+ state_trace = stateful_fun.get_state_trace(*args)
389
+ cache_key = stateful_fun.get_arg_cache_key(*args)
316
390
  # get all states
317
391
  been_written = state_trace.been_writen
318
392
  read_state_vals = state_trace.get_read_state_values(True)
319
393
  write_state_vals = state_trace.get_write_state_values(True)
320
394
 
321
395
  # initialize the collected values/dataa
322
- out_info = stateful_fun.get_out_shapes()[0]
396
+ out_info = stateful_fun.get_out_shapes_by_cache(cache_key)[0]
323
397
  assert len(out_info) == 2, "function in checkpointed_scan should return two data: carray and out."
324
398
  data2collection = jax.tree.map(lambda x: jnp.zeros((length,) + x.shape, x.dtype), out_info[1])
325
399
  del out_info
@@ -375,7 +449,7 @@ def checkpointed_scan(
375
449
  )
376
450
  )
377
451
  # assign the written state values and restore the read state values
378
- write_back_state_values(state_trace, read_state_vals, write_state_vals)
452
+ state_trace.assign_state_vals_v2(read_state_vals, write_state_vals)
379
453
  del write_state_vals, read_state_vals, stateful_fun
380
454
  return carry, data2collection
381
455
 
@@ -388,7 +462,7 @@ def _forloop_to_scan_fun(f: Callable):
388
462
  return scan_fun
389
463
 
390
464
 
391
- @set_module_as('brainstate.compile')
465
+ @set_module_as('brainstate.transform')
392
466
  def for_loop(
393
467
  f: Callable[..., Y],
394
468
  *xs,
@@ -400,35 +474,69 @@ def for_loop(
400
474
  """
401
475
  ``for-loop`` control flow with :py:class:`~.State`.
402
476
 
403
- Args:
404
- f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
405
- that ``f`` accepts two arguments where the first is a value of the loop
406
- carry and the second is a slice of ``xs`` along its leading axis, and that
407
- ``f`` returns a pair where the first element represents a new value for
408
- the loop carry and the second represents a slice of the output.
409
- xs: the value of type ``[a]`` over which to scan along the leading axis,
410
- where ``[a]`` can be an array or any pytree (nested Python
477
+ Parameters
478
+ ----------
479
+ f : callable
480
+ A Python function to be looped over that accepts variadic arguments
481
+ corresponding to slices of ``xs`` along their leading axes, and returns
482
+ the output for that iteration.
483
+ *xs
484
+ The values over which to loop along the leading axis,
485
+ where each can be an array or any pytree (nested Python
411
486
  tuple/list/dict) thereof with consistent leading axis sizes.
412
- length: optional integer specifying the number of loop iterations, which
487
+ length : int, optional
488
+ Optional integer specifying the number of loop iterations, which
413
489
  must agree with the sizes of leading axes of the arrays in ``xs`` (but can
414
- be used to perform scans where no input ``xs`` are needed).
415
- reverse: optional boolean specifying whether to run the scan iteration
490
+ be used to perform loops where no input ``xs`` are needed).
491
+ reverse : bool, default False
492
+ Optional boolean specifying whether to run the loop iteration
416
493
  forward (the default) or in reverse, equivalent to reversing the leading
417
494
  axes of the arrays in both ``xs`` and in ``ys``.
418
- unroll: optional positive int or bool specifying, in the underlying
419
- operation of the scan primitive, how many scan iterations to unroll within
495
+ unroll : int or bool, default 1
496
+ Optional positive int or bool specifying, in the underlying
497
+ operation of the scan primitive, how many loop iterations to unroll within
420
498
  a single iteration of a loop. If an integer is provided, it determines how
421
499
  many unrolled loop iterations to run within a single rolled iteration of
422
500
  the loop. If a boolean is provided, it will determine if the loop is
423
501
  completely unrolled (i.e. `unroll=True`) or left completely unrolled (i.e.
424
502
  `unroll=False`).
425
- pbar: optional :class:`~.ProgressBar` instance to display the progress
426
- of the scan operation.
503
+ pbar : ProgressBar or int, optional
504
+ Optional :class:`~.ProgressBar` instance to display the progress
505
+ of the loop operation.
506
+
507
+ Returns
508
+ -------
509
+ Y
510
+ The stacked outputs of ``f`` when looped over the leading axis of the inputs.
511
+
512
+ Examples
513
+ --------
514
+ Basic for-loop operation:
515
+
516
+ .. code-block:: python
517
+
518
+ >>> import brainstate
519
+ >>> import jax.numpy as jnp
520
+ >>>
521
+ >>> def process_item(x, y):
522
+ ... return x * y + 1
523
+ >>>
524
+ >>> xs = jnp.array([1.0, 2.0, 3.0])
525
+ >>> ys = jnp.array([4.0, 5.0, 6.0])
526
+ >>> results = brainstate.transform.for_loop(process_item, xs, ys)
527
+
528
+ For-loop with progress bar:
529
+
530
+ .. code-block:: python
531
+
532
+ >>> pbar = brainstate.transform.ProgressBar(freq=10)
533
+ >>> results = brainstate.transform.for_loop(process_item, xs, ys, pbar=pbar)
427
534
 
428
- Returns:
429
- The return represents the stacked outputs of the second output of ``f``
430
- when scanned over the leading axis of the inputs.
535
+ For-loop with reverse iteration:
431
536
 
537
+ .. code-block:: python
538
+
539
+ >>> results = brainstate.transform.for_loop(process_item, xs, ys, reverse=True)
432
540
  """
433
541
  _, ys = scan(
434
542
  _forloop_to_scan_fun(f),
@@ -442,6 +550,7 @@ def for_loop(
442
550
  return ys
443
551
 
444
552
 
553
+ @set_module_as('brainstate.transform')
445
554
  def checkpointed_for_loop(
446
555
  f: Callable[..., Y],
447
556
  *xs: X,
@@ -452,25 +561,54 @@ def checkpointed_for_loop(
452
561
  """
453
562
  ``for-loop`` control flow with :py:class:`~.State` with a checkpointed version, similar to :py:func:`for_loop`.
454
563
 
455
- Args:
456
- f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning
457
- that ``f`` accepts two arguments where the first is a value of the loop
458
- carry and the second is a slice of ``xs`` along its leading axis, and that
459
- ``f`` returns a pair where the first element represents a new value for
460
- the loop carry and the second represents a slice of the output.
461
- xs: the value of type ``[a]`` over which to scan along the leading axis,
462
- where ``[a]`` can be an array or any pytree (nested Python
564
+ Parameters
565
+ ----------
566
+ f : callable
567
+ A Python function to be looped over that accepts variadic arguments
568
+ corresponding to slices of ``xs`` along their leading axes, and returns
569
+ the output for that iteration.
570
+ *xs : X
571
+ The values over which to loop along the leading axis,
572
+ where each can be an array or any pytree (nested Python
463
573
  tuple/list/dict) thereof with consistent leading axis sizes.
464
- length: optional integer specifying the number of loop iterations, which
574
+ length : int, optional
575
+ Optional integer specifying the number of loop iterations, which
465
576
  must agree with the sizes of leading axes of the arrays in ``xs`` (but can
466
- be used to perform scans where no input ``xs`` are needed).
467
- base: optional integer specifying the base for the bounded scan loop.
468
- pbar: optional :class:`~.ProgressBar` instance to display the progress
469
- of the scan operation.
470
-
471
- Returns:
472
- The return represents the stacked outputs of the second output of ``f``
473
- when scanned over the leading axis of the inputs.
577
+ be used to perform loops where no input ``xs`` are needed).
578
+ base : int, default 16
579
+ Optional integer specifying the base for the bounded loop.
580
+ pbar : ProgressBar or int, optional
581
+ Optional :class:`~.ProgressBar` instance to display the progress
582
+ of the loop operation.
583
+
584
+ Returns
585
+ -------
586
+ Y
587
+ The stacked outputs of ``f`` when looped over the leading axis of the inputs.
588
+
589
+ Examples
590
+ --------
591
+ Basic checkpointed for-loop operation:
592
+
593
+ .. code-block:: python
594
+
595
+ >>> import brainstate
596
+ >>> import jax.numpy as jnp
597
+ >>>
598
+ >>> def process_item(x, y):
599
+ ... return x * y + 1
600
+ >>>
601
+ >>> xs = jnp.array([1.0, 2.0, 3.0])
602
+ >>> ys = jnp.array([4.0, 5.0, 6.0])
603
+ >>> results = brainstate.transform.checkpointed_for_loop(process_item, xs, ys)
604
+
605
+ Using custom base for checkpointing:
606
+
607
+ .. code-block:: python
608
+
609
+ >>> results = brainstate.transform.checkpointed_for_loop(
610
+ ... process_item, xs, ys, base=8
611
+ ... )
474
612
  """
475
613
  _, ys = checkpointed_scan(
476
614
  _forloop_to_scan_fun(f),
@@ -483,7 +621,8 @@ def checkpointed_for_loop(
483
621
  return ys
484
622
 
485
623
 
486
- # This function is adapted from ``while_loop`` in `equinox <https://github.com/patrick-kidger/equinox/blob/main/equinox/internal/_loop/loop.py>`_.
624
+ # This function is adapted from ``while_loop`` in
625
+ # `equinox <https://github.com/patrick-kidger/equinox/blob/main/equinox/internal/_loop/loop.py>`_.
487
626
 
488
627
  # There's several tricks happening here to work around various limitations of JAX.
489
628
  # (Also see https://github.com/google/jax/issues/2139#issuecomment-1039293633)
@@ -1,4 +1,4 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.