bartz 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/mcmcloop.py CHANGED
@@ -22,57 +22,141 @@
22
22
  # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
23
  # SOFTWARE.
24
24
 
25
- """
26
- Functions that implement the full BART posterior MCMC loop.
27
- """
25
+ """Functions that implement the full BART posterior MCMC loop."""
28
26
 
29
27
  import functools
30
28
 
31
29
  import jax
32
- from jax import debug, lax, random, tree
30
+ import numpy
31
+ from jax import debug, lax, tree
33
32
  from jax import numpy as jnp
33
+ from jaxtyping import Array, Real
34
34
 
35
35
  from . import grove, jaxext, mcmcstep
36
+ from .mcmcstep import State
36
37
 
37
38
 
38
- @functools.partial(jax.jit, static_argnums=(2, 3, 4, 5))
39
- def run_mcmc(key, bart, n_burn, n_save, n_skip, callback):
39
+ def default_onlymain_extractor(state: State) -> dict[str, Real[Array, 'samples *']]:
40
+ """Extract variables for the main trace, to be used in `run_mcmc`."""
41
+ return dict(
42
+ leaf_trees=state.forest.leaf_trees,
43
+ var_trees=state.forest.var_trees,
44
+ split_trees=state.forest.split_trees,
45
+ offset=state.offset,
46
+ )
47
+
48
+
49
+ def default_both_extractor(state: State) -> dict[str, Real[Array, 'samples *'] | None]:
50
+ """Extract variables for main & burn-in traces, to be used in `run_mcmc`."""
51
+ return dict(
52
+ sigma2=state.sigma2,
53
+ grow_prop_count=state.forest.grow_prop_count,
54
+ grow_acc_count=state.forest.grow_acc_count,
55
+ prune_prop_count=state.forest.prune_prop_count,
56
+ prune_acc_count=state.forest.prune_acc_count,
57
+ log_likelihood=state.forest.log_likelihood,
58
+ log_trans_prior=state.forest.log_trans_prior,
59
+ )
60
+
61
+
62
+ def run_mcmc(
63
+ key,
64
+ bart,
65
+ n_save,
66
+ *,
67
+ n_burn=0,
68
+ n_skip=1,
69
+ inner_loop_length=None,
70
+ allow_overflow=False,
71
+ inner_callback=None,
72
+ outer_callback=None,
73
+ callback_state=None,
74
+ onlymain_extractor=default_onlymain_extractor,
75
+ both_extractor=default_both_extractor,
76
+ ):
40
77
  """
41
78
  Run the MCMC for the BART posterior.
42
79
 
43
80
  Parameters
44
81
  ----------
45
82
  key : jax.dtypes.prng_key array
46
- The key for random number generation.
83
+ A key for random number generation.
47
84
  bart : dict
48
85
  The initial MCMC state, as created and updated by the functions in
49
- `bartz.mcmcstep`.
50
- n_burn : int
51
- The number of initial iterations which are not saved.
86
+ `bartz.mcmcstep`. The MCMC loop uses buffer donation to avoid copies,
87
+ so this variable is invalidated after running `run_mcmc`. Make a copy
88
+ beforehand to use it again.
52
89
  n_save : int
53
90
  The number of iterations to save.
54
- n_skip : int
91
+ n_burn : int, default 0
92
+ The number of initial iterations which are not saved.
93
+ n_skip : int, default 1
55
94
  The number of iterations to skip between each saved iteration, plus 1.
56
95
  The effective burn-in is ``n_burn + n_skip - 1``.
57
- callback : callable
58
- An arbitrary function run at each iteration, called with the following
59
- arguments, passed by keyword:
96
+ inner_loop_length : int, optional
97
+ The MCMC loop is split into an outer and an inner loop. The outer loop
98
+ is in Python, while the inner loop is in JAX. `inner_loop_length` is the
99
+ number of iterations of the inner loop to run for each iteration of the
100
+ outer loop. If not specified, the outer loop will iterate just once,
101
+ with all iterations done in a single inner loop run. The inner stride is
102
+ unrelated to the stride used for saving the trace.
103
+ allow_overflow : bool, default False
104
+ If `False`, `inner_loop_length` must be a divisor of the total number of
105
+ iterations ``n_burn + n_skip * n_save``. If `True` and
106
+ `inner_loop_length` is not a divisor, some of the MCMC iterations in the
107
+ last outer loop iteration will not be saved to the trace.
108
+ inner_callback : callable, optional
109
+ outer_callback : callable, optional
110
+ Arbitrary functions run during the loop after updating the state.
111
+ `inner_callback` is called after each update, while `outer_callback` is
112
+ called after completing an inner loop. The callbacks are invoked with
113
+ the following arguments, passed by keyword:
60
114
 
61
115
  bart : dict
62
116
  The MCMC state just after updating it.
63
117
  burnin : bool
64
118
  Whether the last iteration was in the burn-in phase.
119
+ overflow : bool
120
+ Whether the last iteration was in the overflow phase (iterations
121
+ not saved due to `inner_loop_length` not being a divisor of the
122
+ total number of iterations).
65
123
  i_total : int
66
- The index of the last iteration (0-based).
124
+ The index of the last MCMC iteration (0-based).
67
125
  i_skip : int
68
126
  The number of MCMC updates from the last saved state. The initial
69
127
  state counts as saved, even if it's not copied into the trace.
128
+ callback_state : jax pytree
129
+ The callback state, initially set to the argument passed to
130
+ `run_mcmc`, afterwards to the value returned by the last invocation
131
+ of `inner_callback` or `outer_callback`.
70
132
  n_burn, n_save, n_skip : int
71
133
  The corresponding arguments as-is.
134
+ i_outer : int
135
+ The index of the last outer loop iteration (0-based).
136
+ inner_loop_length : int
137
+ The number of MCMC iterations in the inner loop.
138
+
139
+ `inner_callback` is called under the jax jit, so the argument values are
140
+ not available at the time the Python code is executed. Use the utilities
141
+ in `jax.debug` to access the values at actual runtime.
72
142
 
73
- Since this function is called under the jax jit, the values are not
74
- available at the time the Python code is executed. Use the utilities in
75
- `jax.debug` to access the values at actual runtime.
143
+ The callbacks must return two values:
144
+
145
+ bart : dict
146
+ A possibly modified MCMC state. To avoid modifying the state,
147
+ return the `bart` argument passed to the callback as-is.
148
+ callback_state : jax pytree
149
+ The new state to be passed on the next callback invocation.
150
+
151
+ For convenience, if a callback returns `None`, the states are not
152
+ updated.
153
+ callback_state : jax pytree, optional
154
+ The initial state for the callbacks.
155
+ onlymain_extractor : callable, optional
156
+ both_extractor : callable, optional
157
+ Functions that extract the variables to be saved respectively only in
158
+ the main trace and in both traces, given the MCMC state as argument.
159
+ Must return a pytree, and must be vmappable.
76
160
 
77
161
  Returns
78
162
  -------
@@ -82,12 +166,20 @@ def run_mcmc(key, bart, n_burn, n_save, n_skip, callback):
82
166
  The trace of the burn-in phase, containing the following subset of
83
167
  fields from the `bart` dictionary, with an additional head index that
84
168
  runs over MCMC iterations: 'sigma2', 'grow_prop_count',
85
- 'grow_acc_count', 'prune_prop_count', 'prune_acc_count'.
169
+ 'grow_acc_count', 'prune_prop_count', 'prune_acc_count' (or if specified
170
+ the fields in `tracevars_both`).
86
171
  main_trace : dict of (n_save, ...) arrays
87
172
  The trace of the main phase, containing the following subset of fields
88
- from the `bart` dictionary, with an additional head index that runs
89
- over MCMC iterations: 'leaf_trees', 'var_trees', 'split_trees', plus
90
- the fields in `burnin_trace`.
173
+ from the `bart` dictionary, with an additional head index that runs over
174
+ MCMC iterations: 'leaf_trees', 'var_trees', 'split_trees' (or if
175
+ specified the fields in `tracevars_onlymain`), plus the fields in
176
+ `burnin_trace`.
177
+
178
+ Raises
179
+ ------
180
+ ValueError
181
+ If `inner_loop_length` is not a divisor of the total number of
182
+ iterations and `allow_overflow` is `False`.
91
183
 
92
184
  Notes
93
185
  -----
@@ -95,49 +187,118 @@ def run_mcmc(key, bart, n_burn, n_save, n_skip, callback):
95
187
  not include the initial state, and include the final state.
96
188
  """
97
189
 
98
- tracevars_light = (
99
- 'sigma2',
100
- 'grow_prop_count',
101
- 'grow_acc_count',
102
- 'prune_prop_count',
103
- 'prune_acc_count',
104
- 'ratios',
105
- )
106
- tracevars_heavy = ('leaf_trees', 'var_trees', 'split_trees')
190
+ def empty_trace(length, bart, extractor):
191
+ return jax.vmap(extractor, in_axes=None, out_axes=0, axis_size=length)(bart)
192
+
193
+ trace_both = empty_trace(n_burn + n_save, bart, both_extractor)
194
+ trace_onlymain = empty_trace(n_save, bart, onlymain_extractor)
195
+
196
+ # determine number of iterations for inner and outer loops
197
+ n_iters = n_burn + n_skip * n_save
198
+ if inner_loop_length is None:
199
+ inner_loop_length = n_iters
200
+ n_outer = n_iters // inner_loop_length
201
+ if n_iters % inner_loop_length:
202
+ if allow_overflow:
203
+ n_outer += 1
204
+ else:
205
+ raise ValueError(f'{n_iters=} is not divisible by {inner_loop_length=}')
206
+
207
+ carry = (bart, 0, key, trace_both, trace_onlymain, callback_state)
208
+ for i_outer in range(n_outer):
209
+ carry = _run_mcmc_inner_loop(
210
+ carry,
211
+ inner_loop_length,
212
+ inner_callback,
213
+ onlymain_extractor,
214
+ both_extractor,
215
+ n_burn,
216
+ n_save,
217
+ n_skip,
218
+ i_outer,
219
+ )
220
+ if outer_callback is not None:
221
+ bart, i_total, key, trace_both, trace_onlymain, callback_state = carry
222
+ i_total -= 1 # because i_total is updated at the end of the inner loop
223
+ i_skip = _compute_i_skip(i_total, n_burn, n_skip)
224
+ rt = outer_callback(
225
+ bart=bart,
226
+ burnin=i_total < n_burn,
227
+ overflow=i_total >= n_iters,
228
+ i_total=i_total,
229
+ i_skip=i_skip,
230
+ callback_state=callback_state,
231
+ n_burn=n_burn,
232
+ n_save=n_save,
233
+ n_skip=n_skip,
234
+ i_outer=i_outer,
235
+ inner_loop_length=inner_loop_length,
236
+ )
237
+ if rt is not None:
238
+ bart, callback_state = rt
239
+ i_total += 1
240
+ carry = (bart, i_total, key, trace_both, trace_onlymain, callback_state)
241
+
242
+ bart, _, _, trace_both, trace_onlymain, _ = carry
243
+
244
+ burnin_trace = tree.map(lambda x: x[:n_burn, ...], trace_both)
245
+ main_trace = tree.map(lambda x: x[n_burn:, ...], trace_both)
246
+ main_trace.update(trace_onlymain)
107
247
 
108
- def empty_trace(length, bart, tracelist):
109
- bart = {k: v for k, v in bart.items() if k in tracelist}
110
- return jax.vmap(lambda x: x, in_axes=None, out_axes=0, axis_size=length)(bart)
248
+ return bart, burnin_trace, main_trace
111
249
 
112
- trace_light = empty_trace(n_burn + n_save, bart, tracevars_light)
113
- trace_heavy = empty_trace(n_save, bart, tracevars_heavy)
114
250
 
115
- callback_kw = dict(n_burn=n_burn, n_save=n_save, n_skip=n_skip)
251
+ def _compute_i_skip(i_total, n_burn, n_skip):
252
+ burnin = i_total < n_burn
253
+ return jnp.where(
254
+ burnin,
255
+ i_total + 1,
256
+ (i_total + 1) % n_skip + jnp.where(i_total + 1 < n_skip, n_burn, 0),
257
+ )
116
258
 
117
- carry = (bart, 0, key, trace_light, trace_heavy)
118
259
 
260
+ @functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1, 2, 3, 4))
261
+ def _run_mcmc_inner_loop(
262
+ carry,
263
+ inner_loop_length,
264
+ inner_callback,
265
+ onlymain_extractor,
266
+ both_extractor,
267
+ n_burn,
268
+ n_save,
269
+ n_skip,
270
+ i_outer,
271
+ ):
119
272
  def loop(carry, _):
120
- bart, i_total, key, trace_light, trace_heavy = carry
273
+ bart, i_total, key, trace_both, trace_onlymain, callback_state = carry
121
274
 
122
- key, subkey = random.split(key)
123
- bart = mcmcstep.step(subkey, bart)
275
+ keys = jaxext.split(key)
276
+ key = keys.pop()
277
+ bart = mcmcstep.step(keys.pop(), bart)
124
278
 
125
279
  burnin = i_total < n_burn
126
- i_skip = jnp.where(
127
- burnin,
128
- i_total + 1,
129
- (i_total + 1) % n_skip + jnp.where(i_total + 1 < n_skip, n_burn, 0),
130
- )
131
- callback(
132
- bart=bart, burnin=burnin, i_total=i_total, i_skip=i_skip, **callback_kw
133
- )
134
-
135
- i_heavy = jnp.where(burnin, 0, (i_total - n_burn) // n_skip)
136
- i_light = jnp.where(burnin, i_total, n_burn + i_heavy)
137
-
138
- def update_trace(index, trace, bart):
139
- bart = {k: v for k, v in bart.items() if k in trace}
140
-
280
+ if inner_callback is not None:
281
+ i_skip = _compute_i_skip(i_total, n_burn, n_skip)
282
+ rt = inner_callback(
283
+ bart=bart,
284
+ burnin=burnin,
285
+ overflow=i_total >= n_burn + n_save * n_skip,
286
+ i_total=i_total,
287
+ i_skip=i_skip,
288
+ callback_state=callback_state,
289
+ n_burn=n_burn,
290
+ n_save=n_save,
291
+ n_skip=n_skip,
292
+ i_outer=i_outer,
293
+ inner_loop_length=inner_loop_length,
294
+ )
295
+ if rt is not None:
296
+ bart, callback_state = rt
297
+
298
+ i_onlymain = jnp.where(burnin, 0, (i_total - n_burn) // n_skip)
299
+ i_both = jnp.where(burnin, i_total, n_burn + i_onlymain)
300
+
301
+ def update_trace(index, trace, state):
141
302
  def assign_at_index(trace_array, state_array):
142
303
  if trace_array.size:
143
304
  return trace_array.at[index, ...].set(state_array)
@@ -147,84 +308,176 @@ def run_mcmc(key, bart, n_burn, n_save, n_skip, callback):
147
308
  # of length 0
148
309
  return trace_array
149
310
 
150
- return tree.map(assign_at_index, trace, bart)
311
+ return tree.map(assign_at_index, trace, state)
151
312
 
152
- trace_heavy = update_trace(i_heavy, trace_heavy, bart)
153
- trace_light = update_trace(i_light, trace_light, bart)
313
+ trace_onlymain = update_trace(
314
+ i_onlymain, trace_onlymain, onlymain_extractor(bart)
315
+ )
316
+ trace_both = update_trace(i_both, trace_both, both_extractor(bart))
154
317
 
155
318
  i_total += 1
156
- carry = (bart, i_total, key, trace_light, trace_heavy)
319
+ carry = (bart, i_total, key, trace_both, trace_onlymain, callback_state)
157
320
  return carry, None
158
321
 
159
- carry, _ = lax.scan(loop, carry, None, n_burn + n_skip * n_save)
160
-
161
- bart, _, _, trace_light, trace_heavy = carry
322
+ carry, _ = lax.scan(loop, carry, None, inner_loop_length)
323
+ return carry
162
324
 
163
- burnin_trace = tree.map(lambda x: x[:n_burn, ...], trace_light)
164
- main_trace = tree.map(lambda x: x[n_burn:, ...], trace_light)
165
- main_trace.update(trace_heavy)
166
325
 
167
- return bart, burnin_trace, main_trace
168
-
169
-
170
- @functools.lru_cache
171
- # cache to make the callback function object unique, such that the jit
172
- # of run_mcmc recognizes it
173
- def make_simple_print_callback(printevery):
326
+ def make_print_callbacks(dot_every_inner=1, report_every_outer=1):
174
327
  """
175
- Create a logging callback function for MCMC iterations.
328
+ Prepare logging callbacks for `run_mcmc`.
329
+
330
+ Prepare callbacks which print a dot on every iteration, and a longer
331
+ report outer loop iteration.
176
332
 
177
333
  Parameters
178
334
  ----------
179
- printevery : int
180
- The number of iterations between each log.
335
+ dot_every_inner : int, default 1
336
+ A dot is printed every `dot_every_inner` MCMC iterations.
337
+ report_every_outer : int, default 1
338
+ A report is printed every `report_every_outer` outer loop
339
+ iterations.
181
340
 
182
341
  Returns
183
342
  -------
184
- callback : callable
185
- A function in the format required by `run_mcmc`.
343
+ kwargs : dict
344
+ A dictionary with the arguments to pass to `run_mcmc` as keyword
345
+ arguments to set up the callbacks.
346
+
347
+ Examples
348
+ --------
349
+ >>> run_mcmc(..., **make_print_callbacks())
186
350
  """
351
+ return dict(
352
+ inner_callback=_print_callback_inner,
353
+ outer_callback=_print_callback_outer,
354
+ callback_state=dict(
355
+ dot_every_inner=dot_every_inner, report_every_outer=report_every_outer
356
+ ),
357
+ )
358
+
359
+
360
+ def _print_callback_inner(*, i_total, callback_state, **_):
361
+ dot_every_inner = callback_state['dot_every_inner']
362
+ if dot_every_inner is not None:
363
+ cond = (i_total + 1) % dot_every_inner == 0
364
+ debug.callback(_print_dot, cond)
365
+
366
+
367
+ def _print_dot(cond):
368
+ if cond:
369
+ print('.', end='', flush=True)
187
370
 
188
- def callback(*, bart, burnin, i_total, i_skip, n_burn, n_save, n_skip):
189
- prop_total = len(bart['leaf_trees'])
190
- grow_prop = bart['grow_prop_count'] / prop_total
191
- prune_prop = bart['prune_prop_count'] / prop_total
192
- grow_acc = bart['grow_acc_count'] / bart['grow_prop_count']
193
- prune_acc = bart['prune_acc_count'] / bart['prune_prop_count']
194
- n_total = n_burn + n_save * n_skip
195
- printcond = (i_total + 1) % printevery == 0
371
+
372
+ def _print_callback_outer(
373
+ *,
374
+ bart,
375
+ burnin,
376
+ overflow,
377
+ i_total,
378
+ n_burn,
379
+ n_save,
380
+ n_skip,
381
+ callback_state,
382
+ i_outer,
383
+ inner_loop_length,
384
+ **_,
385
+ ):
386
+ report_every_outer = callback_state['report_every_outer']
387
+ if report_every_outer is not None:
388
+ dot_every_inner = callback_state['dot_every_inner']
389
+ if dot_every_inner is None:
390
+ newline = False
391
+ else:
392
+ newline = dot_every_inner < inner_loop_length
196
393
  debug.callback(
197
- _simple_print_callback,
198
- burnin,
199
- i_total,
200
- n_total,
201
- grow_prop,
202
- grow_acc,
203
- prune_prop,
204
- prune_acc,
205
- printcond,
394
+ _print_report,
395
+ cond=(i_outer + 1) % report_every_outer == 0,
396
+ newline=newline,
397
+ burnin=burnin,
398
+ overflow=overflow,
399
+ i_total=i_total,
400
+ n_iters=n_burn + n_save * n_skip,
401
+ grow_prop_count=bart.forest.grow_prop_count,
402
+ grow_acc_count=bart.forest.grow_acc_count,
403
+ prune_prop_count=bart.forest.prune_prop_count,
404
+ prune_acc_count=bart.forest.prune_acc_count,
405
+ prop_total=len(bart.forest.leaf_trees),
406
+ fill=grove.forest_fill(bart.forest.split_trees),
206
407
  )
207
408
 
208
- return callback
209
409
 
410
+ def _convert_jax_arrays_in_args(func):
411
+ """Remove jax arrays from a function arguments.
210
412
 
211
- def _simple_print_callback(
212
- burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printcond
413
+ Converts all jax.Array instances in the arguments to either Python scalars
414
+ or numpy arrays.
415
+ """
416
+
417
+ def convert_jax_arrays(pytree):
418
+ def convert_jax_arrays(val):
419
+ if not isinstance(val, jax.Array):
420
+ return val
421
+ elif val.shape:
422
+ return numpy.array(val)
423
+ else:
424
+ return val.item()
425
+
426
+ return tree.map(convert_jax_arrays, pytree)
427
+
428
+ @functools.wraps(func)
429
+ def new_func(*args, **kw):
430
+ args = convert_jax_arrays(args)
431
+ kw = convert_jax_arrays(kw)
432
+ return func(*args, **kw)
433
+
434
+ return new_func
435
+
436
+
437
+ @_convert_jax_arrays_in_args
438
+ # convert all jax arrays in arguments because operations on them could lead to
439
+ # deadlock with the main thread
440
+ def _print_report(
441
+ *,
442
+ cond,
443
+ newline,
444
+ burnin,
445
+ overflow,
446
+ i_total,
447
+ n_iters,
448
+ grow_prop_count,
449
+ grow_acc_count,
450
+ prune_prop_count,
451
+ prune_acc_count,
452
+ prop_total,
453
+ fill,
213
454
  ):
214
- if printcond:
215
- burnin_flag = ' (burnin)' if burnin else ''
216
- total_str = str(n_total)
217
- ndigits = len(total_str)
218
- i_str = str(i_total.item() + 1).rjust(ndigits)
219
- # I do i_total.item() + 1 instead of just i_total + 1 to solve a bug
220
- # originating when jax is combined with some outdated dependencies. (I
221
- # did not track down which dependencies exactly.) Doing .item() makes
222
- # the + 1 operation be done by Python instead of by jax. The bug is that
223
- # jax hangs completely, with a secondary thread blocked at this line.
455
+ if cond:
456
+ newline = '\n' if newline else ''
457
+
458
+ def acc_string(acc_count, prop_count):
459
+ if prop_count:
460
+ return f'{acc_count / prop_count:.0%}'
461
+ else:
462
+ return ' n/d'
463
+
464
+ grow_prop = grow_prop_count / prop_total
465
+ prune_prop = prune_prop_count / prop_total
466
+ grow_acc = acc_string(grow_acc_count, grow_prop_count)
467
+ prune_acc = acc_string(prune_acc_count, prune_prop_count)
468
+
469
+ if burnin:
470
+ flag = ' (burnin)'
471
+ elif overflow:
472
+ flag = ' (overflow)'
473
+ else:
474
+ flag = ''
475
+
224
476
  print(
225
- f'Iteration {i_str}/{total_str} '
226
- f'P_grow={grow_prop:.2f} P_prune={prune_prop:.2f} '
227
- f'A_grow={grow_acc:.2f} A_prune={prune_acc:.2f}{burnin_flag}'
477
+ f'{newline}It {i_total + 1}/{n_iters} '
478
+ f'grow P={grow_prop:.0%} A={grow_acc}, '
479
+ f'prune P={prune_prop:.0%} A={prune_acc}, '
480
+ f'fill={fill:.0%}{flag}'
228
481
  )
229
482
 
230
483
 
@@ -248,11 +501,11 @@ def evaluate_trace(trace, X):
248
501
  evaluate_trees = functools.partial(grove.evaluate_forest, sum_trees=False)
249
502
  evaluate_trees = jaxext.autobatch(evaluate_trees, 2**29, (None, 0, 0, 0))
250
503
 
251
- def loop(_, state):
504
+ def loop(_, row):
252
505
  values = evaluate_trees(
253
- X, state['leaf_trees'], state['var_trees'], state['split_trees']
506
+ X, row['leaf_trees'], row['var_trees'], row['split_trees']
254
507
  )
255
- return None, jnp.sum(values, axis=0, dtype=jnp.float32)
508
+ return None, row['offset'] + jnp.sum(values, axis=0, dtype=jnp.float32)
256
509
 
257
510
  _, y = lax.scan(loop, None, trace)
258
511
  return y