bartz 0.4.1__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/mcmcloop.py CHANGED
@@ -1,6 +1,6 @@
1
1
  # bartz/src/bartz/mcmcloop.py
2
2
  #
3
- # Copyright (c) 2024, Giacomo Petrillo
3
+ # Copyright (c) 2024-2025, Giacomo Petrillo
4
4
  #
5
5
  # This file is part of bartz.
6
6
  #
@@ -22,154 +22,464 @@
22
22
  # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
23
  # SOFTWARE.
24
24
 
25
- """
26
- Functions that implement the full BART posterior MCMC loop.
27
- """
25
+ """Functions that implement the full BART posterior MCMC loop."""
28
26
 
29
27
  import functools
30
28
 
31
29
  import jax
32
- from jax import random
33
- from jax import debug
30
+ import numpy
31
+ from jax import debug, lax, tree
34
32
  from jax import numpy as jnp
35
- from jax import lax
33
+ from jaxtyping import Array, Real
36
34
 
37
- from . import jaxext
38
- from . import grove
39
- from . import mcmcstep
35
+ from . import grove, jaxext, mcmcstep
36
+ from .mcmcstep import State
40
37
 
41
- @functools.partial(jax.jit, static_argnums=(1, 2, 3, 4))
42
- def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
38
+
39
+ def default_onlymain_extractor(state: State) -> dict[str, Real[Array, 'samples *']]:
40
+ """Extract variables for the main trace, to be used in `run_mcmc`."""
41
+ return dict(
42
+ leaf_trees=state.forest.leaf_trees,
43
+ var_trees=state.forest.var_trees,
44
+ split_trees=state.forest.split_trees,
45
+ offset=state.offset,
46
+ )
47
+
48
+
49
+ def default_both_extractor(state: State) -> dict[str, Real[Array, 'samples *'] | None]:
50
+ """Extract variables for main & burn-in traces, to be used in `run_mcmc`."""
51
+ return dict(
52
+ sigma2=state.sigma2,
53
+ grow_prop_count=state.forest.grow_prop_count,
54
+ grow_acc_count=state.forest.grow_acc_count,
55
+ prune_prop_count=state.forest.prune_prop_count,
56
+ prune_acc_count=state.forest.prune_acc_count,
57
+ log_likelihood=state.forest.log_likelihood,
58
+ log_trans_prior=state.forest.log_trans_prior,
59
+ )
60
+
61
+
62
+ def run_mcmc(
63
+ key,
64
+ bart,
65
+ n_save,
66
+ *,
67
+ n_burn=0,
68
+ n_skip=1,
69
+ inner_loop_length=None,
70
+ allow_overflow=False,
71
+ inner_callback=None,
72
+ outer_callback=None,
73
+ callback_state=None,
74
+ onlymain_extractor=default_onlymain_extractor,
75
+ both_extractor=default_both_extractor,
76
+ ):
43
77
  """
44
78
  Run the MCMC for the BART posterior.
45
79
 
46
80
  Parameters
47
81
  ----------
82
+ key : jax.dtypes.prng_key array
83
+ A key for random number generation.
48
84
  bart : dict
49
85
  The initial MCMC state, as created and updated by the functions in
50
- `bartz.mcmcstep`.
51
- n_burn : int
52
- The number of initial iterations which are not saved.
86
+ `bartz.mcmcstep`. The MCMC loop uses buffer donation to avoid copies,
87
+ so this variable is invalidated after running `run_mcmc`. Make a copy
88
+ beforehand to use it again.
53
89
  n_save : int
54
90
  The number of iterations to save.
55
- n_skip : int
91
+ n_burn : int, default 0
92
+ The number of initial iterations which are not saved.
93
+ n_skip : int, default 1
56
94
  The number of iterations to skip between each saved iteration, plus 1.
57
- callback : callable
58
- An arbitrary function run at each iteration, called with the following
59
- arguments, passed by keyword:
95
+ The effective burn-in is ``n_burn + n_skip - 1``.
96
+ inner_loop_length : int, optional
97
+ The MCMC loop is split into an outer and an inner loop. The outer loop
98
+ is in Python, while the inner loop is in JAX. `inner_loop_length` is the
99
+ number of iterations of the inner loop to run for each iteration of the
100
+ outer loop. If not specified, the outer loop will iterate just once,
101
+ with all iterations done in a single inner loop run. The inner stride is
102
+ unrelated to the stride used for saving the trace.
103
+ allow_overflow : bool, default False
104
+ If `False`, `inner_loop_length` must be a divisor of the total number of
105
+ iterations ``n_burn + n_skip * n_save``. If `True` and
106
+ `inner_loop_length` is not a divisor, some of the MCMC iterations in the
107
+ last outer loop iteration will not be saved to the trace.
108
+ inner_callback : callable, optional
109
+ outer_callback : callable, optional
110
+ Arbitrary functions run during the loop after updating the state.
111
+ `inner_callback` is called after each update, while `outer_callback` is
112
+ called after completing an inner loop. The callbacks are invoked with
113
+ the following arguments, passed by keyword:
60
114
 
61
115
  bart : dict
62
- The current MCMC state.
116
+ The MCMC state just after updating it.
63
117
  burnin : bool
64
118
  Whether the last iteration was in the burn-in phase.
119
+ overflow : bool
120
+ Whether the last iteration was in the overflow phase (iterations
121
+ not saved due to `inner_loop_length` not being a divisor of the
122
+ total number of iterations).
65
123
  i_total : int
66
- The index of the last iteration (0-based).
124
+ The index of the last MCMC iteration (0-based).
67
125
  i_skip : int
68
- The index of the last iteration, starting from the last saved
69
- iteration.
126
+ The number of MCMC updates from the last saved state. The initial
127
+ state counts as saved, even if it's not copied into the trace.
128
+ callback_state : jax pytree
129
+ The callback state, initially set to the argument passed to
130
+ `run_mcmc`, afterwards to the value returned by the last invocation
131
+ of `inner_callback` or `outer_callback`.
70
132
  n_burn, n_save, n_skip : int
71
133
  The corresponding arguments as-is.
134
+ i_outer : int
135
+ The index of the last outer loop iteration (0-based).
136
+ inner_loop_length : int
137
+ The number of MCMC iterations in the inner loop.
72
138
 
73
- Since this function is called under the jax jit, the values are not
74
- available at the time the Python code is executed. Use the utilities in
75
- `jax.debug` to access the values at actual runtime.
76
- key : jax.dtypes.prng_key array
77
- The key for random number generation.
139
+ `inner_callback` is called under the jax jit, so the argument values are
140
+ not available at the time the Python code is executed. Use the utilities
141
+ in `jax.debug` to access the values at actual runtime.
142
+
143
+ The callbacks must return two values:
144
+
145
+ bart : dict
146
+ A possibly modified MCMC state. To avoid modifying the state,
147
+ return the `bart` argument passed to the callback as-is.
148
+ callback_state : jax pytree
149
+ The new state to be passed on the next callback invocation.
150
+
151
+ For convenience, if a callback returns `None`, the states are not
152
+ updated.
153
+ callback_state : jax pytree, optional
154
+ The initial state for the callbacks.
155
+ onlymain_extractor : callable, optional
156
+ both_extractor : callable, optional
157
+ Functions that extract the variables to be saved respectively only in
158
+ the main trace and in both traces, given the MCMC state as argument.
159
+ Must return a pytree, and must be vmappable.
78
160
 
79
161
  Returns
80
162
  -------
81
163
  bart : dict
82
164
  The final MCMC state.
83
- burnin_trace : dict
165
+ burnin_trace : dict of (n_burn, ...) arrays
84
166
  The trace of the burn-in phase, containing the following subset of
85
167
  fields from the `bart` dictionary, with an additional head index that
86
168
  runs over MCMC iterations: 'sigma2', 'grow_prop_count',
87
- 'grow_acc_count', 'prune_prop_count', 'prune_acc_count'.
88
- main_trace : dict
169
+ 'grow_acc_count', 'prune_prop_count', 'prune_acc_count' (or if specified
170
+ the fields in `tracevars_both`).
171
+ main_trace : dict of (n_save, ...) arrays
89
172
  The trace of the main phase, containing the following subset of fields
90
- from the `bart` dictionary, with an additional head index that runs
91
- over MCMC iterations: 'leaf_trees', 'var_trees', 'split_trees', plus
92
- the fields in `burnin_trace`.
173
+ from the `bart` dictionary, with an additional head index that runs over
174
+ MCMC iterations: 'leaf_trees', 'var_trees', 'split_trees' (or if
175
+ specified the fields in `tracevars_onlymain`), plus the fields in
176
+ `burnin_trace`.
177
+
178
+ Raises
179
+ ------
180
+ ValueError
181
+ If `inner_loop_length` is not a divisor of the total number of
182
+ iterations and `allow_overflow` is `False`.
183
+
184
+ Notes
185
+ -----
186
+ The number of MCMC updates is ``n_burn + n_skip * n_save``. The traces do
187
+ not include the initial state, and include the final state.
93
188
  """
94
189
 
95
- tracelist_burnin = 'sigma2', 'grow_prop_count', 'grow_acc_count', 'prune_prop_count', 'prune_acc_count', 'ratios'
96
-
97
- tracelist_main = tracelist_burnin + ('leaf_trees', 'var_trees', 'split_trees')
98
-
99
- callback_kw = dict(n_burn=n_burn, n_save=n_save, n_skip=n_skip)
100
-
101
- def inner_loop(carry, _, tracelist, burnin):
102
- bart, i_total, i_skip, key = carry
103
- key, subkey = random.split(key)
104
- bart = mcmcstep.step(bart, subkey)
105
- callback(bart=bart, burnin=burnin, i_total=i_total, i_skip=i_skip, **callback_kw)
106
- output = {key: bart[key] for key in tracelist if key in bart}
107
- return (bart, i_total + 1, i_skip + 1, key), output
108
-
109
- def empty_trace(bart, tracelist):
110
- return jax.vmap(lambda x: x, in_axes=None, out_axes=0, axis_size=0)(bart)
111
-
112
- if n_burn > 0:
113
- carry = bart, 0, 0, key
114
- burnin_loop = functools.partial(inner_loop, tracelist=tracelist_burnin, burnin=True)
115
- (bart, i_total, _, key), burnin_trace = lax.scan(burnin_loop, carry, None, n_burn)
116
- else:
117
- i_total = 0
118
- burnin_trace = empty_trace(bart, tracelist_burnin)
119
-
120
- def outer_loop(carry, _):
121
- bart, i_total, key = carry
122
- main_loop = functools.partial(inner_loop, tracelist=[], burnin=False)
123
- inner_carry = bart, i_total, 0, key
124
- (bart, i_total, _, key), _ = lax.scan(main_loop, inner_carry, None, n_skip)
125
- output = {key: bart[key] for key in tracelist_main if key in bart}
126
- return (bart, i_total, key), output
127
-
128
- if n_save > 0:
129
- carry = bart, i_total, key
130
- (bart, _, _), main_trace = lax.scan(outer_loop, carry, None, n_save)
131
- else:
132
- main_trace = empty_trace(bart, tracelist_main)
190
+ def empty_trace(length, bart, extractor):
191
+ return jax.vmap(extractor, in_axes=None, out_axes=0, axis_size=length)(bart)
192
+
193
+ trace_both = empty_trace(n_burn + n_save, bart, both_extractor)
194
+ trace_onlymain = empty_trace(n_save, bart, onlymain_extractor)
195
+
196
+ # determine number of iterations for inner and outer loops
197
+ n_iters = n_burn + n_skip * n_save
198
+ if inner_loop_length is None:
199
+ inner_loop_length = n_iters
200
+ n_outer = n_iters // inner_loop_length
201
+ if n_iters % inner_loop_length:
202
+ if allow_overflow:
203
+ n_outer += 1
204
+ else:
205
+ raise ValueError(f'{n_iters=} is not divisible by {inner_loop_length=}')
206
+
207
+ carry = (bart, 0, key, trace_both, trace_onlymain, callback_state)
208
+ for i_outer in range(n_outer):
209
+ carry = _run_mcmc_inner_loop(
210
+ carry,
211
+ inner_loop_length,
212
+ inner_callback,
213
+ onlymain_extractor,
214
+ both_extractor,
215
+ n_burn,
216
+ n_save,
217
+ n_skip,
218
+ i_outer,
219
+ )
220
+ if outer_callback is not None:
221
+ bart, i_total, key, trace_both, trace_onlymain, callback_state = carry
222
+ i_total -= 1 # because i_total is updated at the end of the inner loop
223
+ i_skip = _compute_i_skip(i_total, n_burn, n_skip)
224
+ rt = outer_callback(
225
+ bart=bart,
226
+ burnin=i_total < n_burn,
227
+ overflow=i_total >= n_iters,
228
+ i_total=i_total,
229
+ i_skip=i_skip,
230
+ callback_state=callback_state,
231
+ n_burn=n_burn,
232
+ n_save=n_save,
233
+ n_skip=n_skip,
234
+ i_outer=i_outer,
235
+ inner_loop_length=inner_loop_length,
236
+ )
237
+ if rt is not None:
238
+ bart, callback_state = rt
239
+ i_total += 1
240
+ carry = (bart, i_total, key, trace_both, trace_onlymain, callback_state)
241
+
242
+ bart, _, _, trace_both, trace_onlymain, _ = carry
243
+
244
+ burnin_trace = tree.map(lambda x: x[:n_burn, ...], trace_both)
245
+ main_trace = tree.map(lambda x: x[n_burn:, ...], trace_both)
246
+ main_trace.update(trace_onlymain)
133
247
 
134
248
  return bart, burnin_trace, main_trace
135
249
 
136
- @functools.lru_cache
137
- # cache to make the callback function object unique, such that the jit
138
- # of run_mcmc recognizes it
139
- def make_simple_print_callback(printevery):
250
+
251
+ def _compute_i_skip(i_total, n_burn, n_skip):
252
+ burnin = i_total < n_burn
253
+ return jnp.where(
254
+ burnin,
255
+ i_total + 1,
256
+ (i_total + 1) % n_skip + jnp.where(i_total + 1 < n_skip, n_burn, 0),
257
+ )
258
+
259
+
260
+ @functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1, 2, 3, 4))
261
+ def _run_mcmc_inner_loop(
262
+ carry,
263
+ inner_loop_length,
264
+ inner_callback,
265
+ onlymain_extractor,
266
+ both_extractor,
267
+ n_burn,
268
+ n_save,
269
+ n_skip,
270
+ i_outer,
271
+ ):
272
+ def loop(carry, _):
273
+ bart, i_total, key, trace_both, trace_onlymain, callback_state = carry
274
+
275
+ keys = jaxext.split(key)
276
+ key = keys.pop()
277
+ bart = mcmcstep.step(keys.pop(), bart)
278
+
279
+ burnin = i_total < n_burn
280
+ if inner_callback is not None:
281
+ i_skip = _compute_i_skip(i_total, n_burn, n_skip)
282
+ rt = inner_callback(
283
+ bart=bart,
284
+ burnin=burnin,
285
+ overflow=i_total >= n_burn + n_save * n_skip,
286
+ i_total=i_total,
287
+ i_skip=i_skip,
288
+ callback_state=callback_state,
289
+ n_burn=n_burn,
290
+ n_save=n_save,
291
+ n_skip=n_skip,
292
+ i_outer=i_outer,
293
+ inner_loop_length=inner_loop_length,
294
+ )
295
+ if rt is not None:
296
+ bart, callback_state = rt
297
+
298
+ i_onlymain = jnp.where(burnin, 0, (i_total - n_burn) // n_skip)
299
+ i_both = jnp.where(burnin, i_total, n_burn + i_onlymain)
300
+
301
+ def update_trace(index, trace, state):
302
+ def assign_at_index(trace_array, state_array):
303
+ if trace_array.size:
304
+ return trace_array.at[index, ...].set(state_array)
305
+ else:
306
+ # this handles the case where a trace is empty (e.g.,
307
+ # no burn-in) because jax refuses to index into an array
308
+ # of length 0
309
+ return trace_array
310
+
311
+ return tree.map(assign_at_index, trace, state)
312
+
313
+ trace_onlymain = update_trace(
314
+ i_onlymain, trace_onlymain, onlymain_extractor(bart)
315
+ )
316
+ trace_both = update_trace(i_both, trace_both, both_extractor(bart))
317
+
318
+ i_total += 1
319
+ carry = (bart, i_total, key, trace_both, trace_onlymain, callback_state)
320
+ return carry, None
321
+
322
+ carry, _ = lax.scan(loop, carry, None, inner_loop_length)
323
+ return carry
324
+
325
+
326
+ def make_print_callbacks(dot_every_inner=1, report_every_outer=1):
140
327
  """
141
- Create a logging callback function for MCMC iterations.
328
+ Prepare logging callbacks for `run_mcmc`.
329
+
330
+ Prepare callbacks which print a dot on every iteration, and a longer
331
+ report outer loop iteration.
142
332
 
143
333
  Parameters
144
334
  ----------
145
- printevery : int
146
- The number of iterations between each log.
335
+ dot_every_inner : int, default 1
336
+ A dot is printed every `dot_every_inner` MCMC iterations.
337
+ report_every_outer : int, default 1
338
+ A report is printed every `report_every_outer` outer loop
339
+ iterations.
147
340
 
148
341
  Returns
149
342
  -------
150
- callback : callable
151
- A function in the format required by `run_mcmc`.
343
+ kwargs : dict
344
+ A dictionary with the arguments to pass to `run_mcmc` as keyword
345
+ arguments to set up the callbacks.
346
+
347
+ Examples
348
+ --------
349
+ >>> run_mcmc(..., **make_print_callbacks())
152
350
  """
153
- def callback(*, bart, burnin, i_total, i_skip, n_burn, n_save, n_skip):
154
- prop_total = len(bart['leaf_trees'])
155
- grow_prop = bart['grow_prop_count'] / prop_total
156
- prune_prop = bart['prune_prop_count'] / prop_total
157
- grow_acc = bart['grow_acc_count'] / bart['grow_prop_count']
158
- prune_acc = bart['prune_acc_count'] / bart['prune_prop_count']
159
- n_total = n_burn + n_save * n_skip
160
- printcond = (i_total + 1) % printevery == 0
161
- debug.callback(_simple_print_callback, burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printcond)
162
- return callback
163
-
164
- def _simple_print_callback(burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printcond):
165
- if printcond:
166
- burnin_flag = ' (burnin)' if burnin else ''
167
- total_str = str(n_total)
168
- ndigits = len(total_str)
169
- i_str = str(i_total + 1).rjust(ndigits)
170
- print(f'Iteration {i_str}/{total_str} '
171
- f'P_grow={grow_prop:.2f} P_prune={prune_prop:.2f} '
172
- f'A_grow={grow_acc:.2f} A_prune={prune_acc:.2f}{burnin_flag}')
351
+ return dict(
352
+ inner_callback=_print_callback_inner,
353
+ outer_callback=_print_callback_outer,
354
+ callback_state=dict(
355
+ dot_every_inner=dot_every_inner, report_every_outer=report_every_outer
356
+ ),
357
+ )
358
+
359
+
360
+ def _print_callback_inner(*, i_total, callback_state, **_):
361
+ dot_every_inner = callback_state['dot_every_inner']
362
+ if dot_every_inner is not None:
363
+ cond = (i_total + 1) % dot_every_inner == 0
364
+ debug.callback(_print_dot, cond)
365
+
366
+
367
+ def _print_dot(cond):
368
+ if cond:
369
+ print('.', end='', flush=True)
370
+
371
+
372
+ def _print_callback_outer(
373
+ *,
374
+ bart,
375
+ burnin,
376
+ overflow,
377
+ i_total,
378
+ n_burn,
379
+ n_save,
380
+ n_skip,
381
+ callback_state,
382
+ i_outer,
383
+ inner_loop_length,
384
+ **_,
385
+ ):
386
+ report_every_outer = callback_state['report_every_outer']
387
+ if report_every_outer is not None:
388
+ dot_every_inner = callback_state['dot_every_inner']
389
+ if dot_every_inner is None:
390
+ newline = False
391
+ else:
392
+ newline = dot_every_inner < inner_loop_length
393
+ debug.callback(
394
+ _print_report,
395
+ cond=(i_outer + 1) % report_every_outer == 0,
396
+ newline=newline,
397
+ burnin=burnin,
398
+ overflow=overflow,
399
+ i_total=i_total,
400
+ n_iters=n_burn + n_save * n_skip,
401
+ grow_prop_count=bart.forest.grow_prop_count,
402
+ grow_acc_count=bart.forest.grow_acc_count,
403
+ prune_prop_count=bart.forest.prune_prop_count,
404
+ prune_acc_count=bart.forest.prune_acc_count,
405
+ prop_total=len(bart.forest.leaf_trees),
406
+ fill=grove.forest_fill(bart.forest.split_trees),
407
+ )
408
+
409
+
410
+ def _convert_jax_arrays_in_args(func):
411
+ """Remove jax arrays from a function arguments.
412
+
413
+ Converts all jax.Array instances in the arguments to either Python scalars
414
+ or numpy arrays.
415
+ """
416
+
417
+ def convert_jax_arrays(pytree):
418
+ def convert_jax_arrays(val):
419
+ if not isinstance(val, jax.Array):
420
+ return val
421
+ elif val.shape:
422
+ return numpy.array(val)
423
+ else:
424
+ return val.item()
425
+
426
+ return tree.map(convert_jax_arrays, pytree)
427
+
428
+ @functools.wraps(func)
429
+ def new_func(*args, **kw):
430
+ args = convert_jax_arrays(args)
431
+ kw = convert_jax_arrays(kw)
432
+ return func(*args, **kw)
433
+
434
+ return new_func
435
+
436
+
437
+ @_convert_jax_arrays_in_args
438
+ # convert all jax arrays in arguments because operations on them could lead to
439
+ # deadlock with the main thread
440
+ def _print_report(
441
+ *,
442
+ cond,
443
+ newline,
444
+ burnin,
445
+ overflow,
446
+ i_total,
447
+ n_iters,
448
+ grow_prop_count,
449
+ grow_acc_count,
450
+ prune_prop_count,
451
+ prune_acc_count,
452
+ prop_total,
453
+ fill,
454
+ ):
455
+ if cond:
456
+ newline = '\n' if newline else ''
457
+
458
+ def acc_string(acc_count, prop_count):
459
+ if prop_count:
460
+ return f'{acc_count / prop_count:.0%}'
461
+ else:
462
+ return ' n/d'
463
+
464
+ grow_prop = grow_prop_count / prop_total
465
+ prune_prop = prune_prop_count / prop_total
466
+ grow_acc = acc_string(grow_acc_count, grow_prop_count)
467
+ prune_acc = acc_string(prune_acc_count, prune_prop_count)
468
+
469
+ if burnin:
470
+ flag = ' (burnin)'
471
+ elif overflow:
472
+ flag = ' (overflow)'
473
+ else:
474
+ flag = ''
475
+
476
+ print(
477
+ f'{newline}It {i_total + 1}/{n_iters} '
478
+ f'grow P={grow_prop:.0%} A={grow_acc}, '
479
+ f'prune P={prune_prop:.0%} A={prune_acc}, '
480
+ f'fill={fill:.0%}{flag}'
481
+ )
482
+
173
483
 
174
484
  @jax.jit
175
485
  def evaluate_trace(trace, X):
@@ -189,9 +499,13 @@ def evaluate_trace(trace, X):
189
499
  The predictions for each iteration of the MCMC.
190
500
  """
191
501
  evaluate_trees = functools.partial(grove.evaluate_forest, sum_trees=False)
192
- evaluate_trees = jaxext.autobatch(evaluate_trees, 2 ** 29, (None, 0, 0, 0))
193
- def loop(_, state):
194
- values = evaluate_trees(X, state['leaf_trees'], state['var_trees'], state['split_trees'])
195
- return None, jnp.sum(values, axis=0, dtype=jnp.float32)
502
+ evaluate_trees = jaxext.autobatch(evaluate_trees, 2**29, (None, 0, 0, 0))
503
+
504
+ def loop(_, row):
505
+ values = evaluate_trees(
506
+ X, row['leaf_trees'], row['var_trees'], row['split_trees']
507
+ )
508
+ return None, row['offset'] + jnp.sum(values, axis=0, dtype=jnp.float32)
509
+
196
510
  _, y = lax.scan(loop, None, trace)
197
511
  return y