bartz 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/mcmcloop.py CHANGED
@@ -1,6 +1,6 @@
1
1
  # bartz/src/bartz/mcmcloop.py
2
2
  #
3
- # Copyright (c) 2024-2025, Giacomo Petrillo
3
+ # Copyright (c) 2024-2026, The Bartz Contributors
4
4
  #
5
5
  # This file is part of bartz.
6
6
  #
@@ -28,39 +28,62 @@ The entry points are `run_mcmc` and `make_default_callback`.
28
28
  """
29
29
 
30
30
  from collections.abc import Callable
31
- from dataclasses import fields, replace
31
+ from dataclasses import fields
32
32
  from functools import partial, wraps
33
+ from math import floor
33
34
  from typing import Any, Protocol
34
35
 
35
36
  import jax
36
37
  import numpy
37
38
  from equinox import Module
38
- from jax import debug, lax, tree
39
+ from jax import (
40
+ NamedSharding,
41
+ ShapeDtypeStruct,
42
+ debug,
43
+ device_put,
44
+ eval_shape,
45
+ jit,
46
+ tree,
47
+ )
39
48
  from jax import numpy as jnp
40
49
  from jax.nn import softmax
50
+ from jax.sharding import Mesh, PartitionSpec
41
51
  from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, PyTree, Shaped, UInt
42
52
 
43
- from bartz import grove, jaxext, mcmcstep
53
+ from bartz import jaxext, mcmcstep
54
+ from bartz._profiler import (
55
+ cond_if_not_profiling,
56
+ get_profile_mode,
57
+ jit_if_not_profiling,
58
+ scan_if_not_profiling,
59
+ )
60
+ from bartz.grove import TreeHeaps, evaluate_forest, forest_fill, var_histogram
61
+ from bartz.jaxext import autobatch
44
62
  from bartz.mcmcstep import State
63
+ from bartz.mcmcstep._state import chain_vmap_axes, field, get_axis_size, get_num_chains
45
64
 
46
65
 
47
66
  class BurninTrace(Module):
48
67
  """MCMC trace with only diagnostic values."""
49
68
 
50
- sigma2: Float32[Array, '*trace_length'] | None
51
- theta: Float32[Array, '*trace_length'] | None
52
- grow_prop_count: Int32[Array, '*trace_length']
53
- grow_acc_count: Int32[Array, '*trace_length']
54
- prune_prop_count: Int32[Array, '*trace_length']
55
- prune_acc_count: Int32[Array, '*trace_length']
56
- log_likelihood: Float32[Array, '*trace_length'] | None
57
- log_trans_prior: Float32[Array, '*trace_length'] | None
69
+ error_cov_inv: (
70
+ Float32[Array, '*chains_and_samples']
71
+ | Float32[Array, '*chains_and_samples k k']
72
+ | None
73
+ ) = field(chains=True)
74
+ theta: Float32[Array, '*chains_and_samples'] | None = field(chains=True)
75
+ grow_prop_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
76
+ grow_acc_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
77
+ prune_prop_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
78
+ prune_acc_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
79
+ log_likelihood: Float32[Array, '*chains_and_samples'] | None = field(chains=True)
80
+ log_trans_prior: Float32[Array, '*chains_and_samples'] | None = field(chains=True)
58
81
 
59
82
  @classmethod
60
83
  def from_state(cls, state: State) -> 'BurninTrace':
61
84
  """Create a single-item burn-in trace from a MCMC state."""
62
85
  return cls(
63
- sigma2=state.sigma2,
86
+ error_cov_inv=state.error_cov_inv,
64
87
  theta=state.forest.theta,
65
88
  grow_prop_count=state.forest.grow_prop_count,
66
89
  grow_acc_count=state.forest.grow_acc_count,
@@ -74,11 +97,14 @@ class BurninTrace(Module):
74
97
  class MainTrace(BurninTrace):
75
98
  """MCMC trace with trees and diagnostic values."""
76
99
 
77
- leaf_tree: Float32[Array, '*trace_length 2**d']
78
- var_tree: UInt[Array, '*trace_length 2**(d-1)']
79
- split_tree: UInt[Array, '*trace_length 2**(d-1)']
80
- offset: Float32[Array, '*trace_length']
81
- varprob: Float32[Array, '*trace_length p'] | None
100
+ leaf_tree: (
101
+ Float32[Array, '*chains_and_samples 2**d']
102
+ | Float32[Array, '*chains_and_samples k 2**d']
103
+ ) = field(chains=True)
104
+ var_tree: UInt[Array, '*chains_and_samples 2**(d-1)'] = field(chains=True)
105
+ split_tree: UInt[Array, '*chains_and_samples 2**(d-1)'] = field(chains=True)
106
+ offset: Float32[Array, '*samples'] | Float32[Array, '*samples k']
107
+ varprob: Float32[Array, '*chains_and_samples p'] | None = field(chains=True)
82
108
 
83
109
  @classmethod
84
110
  def from_state(cls, state: State) -> 'MainTrace':
@@ -171,8 +197,12 @@ class _Carry(Module):
171
197
  bart: State
172
198
  i_total: Int32[Array, '']
173
199
  key: Key[Array, '']
174
- burnin_trace: PyTree[Shaped[Array, 'n_burn *']]
175
- main_trace: PyTree[Shaped[Array, 'n_save *']]
200
+ burnin_trace: PyTree[
201
+ Shaped[Array, 'n_burn ...'] | Shaped[Array, 'num_chains n_burn ...']
202
+ ]
203
+ main_trace: PyTree[
204
+ Shaped[Array, 'n_save ...'] | Shaped[Array, 'num_chains n_save ...']
205
+ ]
176
206
  callback_state: CallbackState
177
207
 
178
208
 
@@ -188,7 +218,11 @@ def run_mcmc(
188
218
  callback_state: CallbackState = None,
189
219
  burnin_extractor: Callable[[State], PyTree] = BurninTrace.from_state,
190
220
  main_extractor: Callable[[State], PyTree] = MainTrace.from_state,
191
- ) -> tuple[State, PyTree[Shaped[Array, 'n_burn *']], PyTree[Shaped[Array, 'n_save *']]]:
221
+ ) -> tuple[
222
+ State,
223
+ PyTree[Shaped[Array, 'n_burn ...'] | Shaped[Array, 'num_chains n_burn ...']],
224
+ PyTree[Shaped[Array, 'n_save ...'] | Shaped[Array, 'num_chains n_save ...']],
225
+ ]:
192
226
  """
193
227
  Run the MCMC for the BART posterior.
194
228
 
@@ -226,9 +260,9 @@ def run_mcmc(
226
260
  The initial custom state for the callback.
227
261
  burnin_extractor
228
262
  main_extractor
229
- Functions that extract the variables to be saved respectively only in
230
- the main trace and in both traces, given the MCMC state as argument.
231
- Must return a pytree, and must be vmappable.
263
+ Functions that extract the variables to be saved respectively in the
264
+ burnin trace and main traces, given the MCMC state as argument. Must
265
+ return a pytree, and must be vmappable.
232
266
 
233
267
  Returns
234
268
  -------
@@ -239,17 +273,20 @@ def run_mcmc(
239
273
  main_trace : PyTree[Shaped[Array, 'n_save *']]
240
274
  The trace of the main phase. For the default layout, see `MainTrace`.
241
275
 
276
+ Raises
277
+ ------
278
+ RuntimeError
279
+ If `run_mcmc` detects it's being invoked in a `jit`-wrapped context and
280
+ with settings that would create unrolled loops in the trace.
281
+
242
282
  Notes
243
283
  -----
244
284
  The number of MCMC updates is ``n_burn + n_skip * n_save``. The traces do
245
285
  not include the initial state, and include the final state.
246
286
  """
247
-
248
- def empty_trace(length, bart, extractor):
249
- return jax.vmap(extractor, in_axes=None, out_axes=0, axis_size=length)(bart)
250
-
251
- burnin_trace = empty_trace(n_burn, bart, burnin_extractor)
252
- main_trace = empty_trace(n_save, bart, main_extractor)
287
+ # create empty traces
288
+ burnin_trace = _empty_trace(n_burn, bart, burnin_extractor)
289
+ main_trace = _empty_trace(n_save, bart, main_extractor)
253
290
 
254
291
  # determine number of iterations for inner and outer loops
255
292
  n_iters = n_burn + n_skip * n_save
@@ -262,7 +299,27 @@ def run_mcmc(
262
299
  # setting to 0 would make for a clean noop, but it's useful to keep the
263
300
  # same code path for benchmarking and testing
264
301
 
265
- carry = _Carry(bart, jnp.int32(0), key, burnin_trace, main_trace, callback_state)
302
+ # error if under jit and there are unrolled loops or profile mode is on
303
+ under_jit = not hasattr(jnp.empty(0), 'platform')
304
+ if under_jit and (n_outer > 1 or get_profile_mode()):
305
+ msg = (
306
+ '`run_mcmc` was called within a jit-compiled function and '
307
+ 'there are either more than 1 outer loops or profile mode is active, '
308
+ 'please either do not jit, set `inner_loop_length=None`, or disable '
309
+ 'profile mode.'
310
+ )
311
+ raise RuntimeError(msg)
312
+
313
+ replicate = partial(_replicate, mesh=bart.config.mesh)
314
+ carry = _Carry(
315
+ bart,
316
+ replicate(jnp.int32(0)),
317
+ replicate(key),
318
+ burnin_trace,
319
+ main_trace,
320
+ callback_state,
321
+ )
322
+ _run_mcmc_inner_loop._fun.reset_call_counter() # noqa: SLF001
266
323
  for i_outer in range(n_outer):
267
324
  carry = _run_mcmc_inner_loop(
268
325
  carry,
@@ -280,6 +337,30 @@ def run_mcmc(
280
337
  return carry.bart, carry.burnin_trace, carry.main_trace
281
338
 
282
339
 
340
+ def _replicate(x: Array, mesh: Mesh | None) -> Array:
341
+ if mesh is None:
342
+ return x
343
+ else:
344
+ return device_put(x, NamedSharding(mesh, PartitionSpec()))
345
+
346
+
347
+ @partial(jit, static_argnums=(0, 2))
348
+ def _empty_trace(
349
+ length: int, bart: State, extractor: Callable[[State], PyTree]
350
+ ) -> PyTree:
351
+ num_chains = get_num_chains(bart)
352
+ if num_chains is None:
353
+ out_axes = 0
354
+ else:
355
+ example_output = eval_shape(extractor, bart)
356
+ chain_axes = chain_vmap_axes(example_output)
357
+ out_axes = tree.map(
358
+ lambda a: 0 if a is None else 1, chain_axes, is_leaf=lambda a: a is None
359
+ )
360
+ return jax.vmap(extractor, in_axes=None, out_axes=out_axes, axis_size=length)(bart)
361
+
362
+
363
+ @jit
283
364
  def _compute_i_skip(
284
365
  i_total: Int32[Array, ''], n_burn: Int32[Array, ''], n_skip: Int32[Array, '']
285
366
  ) -> Int32[Array, '']:
@@ -293,7 +374,34 @@ def _compute_i_skip(
293
374
  )
294
375
 
295
376
 
296
- @partial(jax.jit, donate_argnums=(0,), static_argnums=(1, 2, 3, 4))
377
+ class _CallCounter:
378
+ """Wrap a callable to check it's not called more than once."""
379
+
380
+ def __init__(self, func: Callable) -> None:
381
+ self.func = func
382
+ self.n_calls = 0
383
+
384
+ def reset_call_counter(self) -> None:
385
+ """Reset the call counter."""
386
+ self.n_calls = 0
387
+
388
+ def __call__(self, *args: Any, **kwargs: Any) -> Any:
389
+ if self.n_calls and not get_profile_mode():
390
+ msg = (
391
+ 'The inner loop of `run_mcmc` was traced more than once, '
392
+ 'which indicates a double compilation of the MCMC code. This '
393
+ 'probably depends on the input state having different type from the '
394
+ 'output state. Check the input is in a format that is the '
395
+ 'same jax would output, e.g., all arrays and scalars are jax '
396
+ 'arrays, with the right shardings.'
397
+ )
398
+ raise RuntimeError(msg)
399
+ self.n_calls += 1
400
+ return self.func(*args, **kwargs)
401
+
402
+
403
+ @partial(jit_if_not_profiling, donate_argnums=(0,), static_argnums=(1, 2, 3, 4))
404
+ @_CallCounter
297
405
  def _run_mcmc_inner_loop(
298
406
  carry: _Carry,
299
407
  inner_loop_length: int,
@@ -305,28 +413,27 @@ def _run_mcmc_inner_loop(
305
413
  n_skip: Int32[Array, ''],
306
414
  i_outer: Int32[Array, ''],
307
415
  n_iters: Int32[Array, ''],
308
- ):
416
+ ) -> _Carry:
309
417
  def loop_impl(carry: _Carry) -> _Carry:
310
418
  """Loop body to run if i_total < n_iters."""
311
419
  # split random key
312
420
  keys = jaxext.split(carry.key, 3)
313
- carry = replace(carry, key=keys.pop())
421
+ key = keys.pop()
314
422
 
315
423
  # update state
316
- carry = replace(carry, bart=mcmcstep.step(keys.pop(), carry.bart))
317
-
318
- burnin = carry.i_total < n_burn
424
+ bart = mcmcstep.step(keys.pop(), carry.bart)
319
425
 
320
426
  # invoke callback
427
+ callback_state = carry.callback_state
321
428
  if callback is not None:
322
429
  i_skip = _compute_i_skip(carry.i_total, n_burn, n_skip)
323
430
  rt = callback(
324
431
  key=keys.pop(),
325
- bart=carry.bart,
326
- burnin=burnin,
432
+ bart=bart,
433
+ burnin=carry.i_total < n_burn,
327
434
  i_total=carry.i_total,
328
435
  i_skip=i_skip,
329
- callback_state=carry.callback_state,
436
+ callback_state=callback_state,
330
437
  n_burn=n_burn,
331
438
  n_save=n_save,
332
439
  n_skip=n_skip,
@@ -335,28 +442,26 @@ def _run_mcmc_inner_loop(
335
442
  )
336
443
  if rt is not None:
337
444
  bart, callback_state = rt
338
- carry = replace(carry, bart=bart, callback_state=callback_state)
339
-
340
- def save_to_burnin_trace() -> tuple[PyTree, PyTree]:
341
- return _pytree_at_set(
342
- carry.burnin_trace, carry.i_total, burnin_extractor(carry.bart)
343
- ), carry.main_trace
344
-
345
- def save_to_main_trace() -> tuple[PyTree, PyTree]:
346
- idx = (carry.i_total - n_burn) // n_skip
347
- return carry.burnin_trace, _pytree_at_set(
348
- carry.main_trace, idx, main_extractor(carry.bart)
349
- )
350
445
 
351
- # save state to trace
352
- burnin_trace, main_trace = lax.cond(
353
- burnin, save_to_burnin_trace, save_to_main_trace
446
+ # save to trace
447
+ burnin_trace, main_trace = _save_state_to_trace(
448
+ carry.burnin_trace,
449
+ carry.main_trace,
450
+ burnin_extractor,
451
+ main_extractor,
452
+ bart,
453
+ carry.i_total,
454
+ n_burn,
455
+ n_skip,
354
456
  )
355
- return replace(
356
- carry,
457
+
458
+ return _Carry(
459
+ bart=bart,
357
460
  i_total=carry.i_total + 1,
461
+ key=key,
358
462
  burnin_trace=burnin_trace,
359
463
  main_trace=main_trace,
464
+ callback_state=callback_state,
360
465
  )
361
466
 
362
467
  def loop_noop(carry: _Carry) -> _Carry:
@@ -364,34 +469,86 @@ def _run_mcmc_inner_loop(
364
469
  return carry
365
470
 
366
471
  def loop(carry: _Carry, _) -> tuple[_Carry, None]:
367
- carry = lax.cond(carry.i_total < n_iters, loop_impl, loop_noop, carry)
472
+ carry = cond_if_not_profiling(
473
+ carry.i_total < n_iters, loop_impl, loop_noop, carry
474
+ )
368
475
  return carry, None
369
476
 
370
- carry, _ = lax.scan(loop, carry, None, inner_loop_length)
477
+ carry, _ = scan_if_not_profiling(loop, carry, None, inner_loop_length)
371
478
  return carry
372
479
 
373
480
 
374
- def _pytree_at_set(
375
- dest: PyTree[Array, ' T'], index: Int32[Array, ''], val: PyTree[Array]
481
+ @partial(jit, donate_argnums=(0, 1), static_argnums=(2, 3))
482
+ # this is jitted because under profiling _run_mcmc_inner_loop and the loop
483
+ # within it are not, so I need the donate_argnums feature of jit to avoid
484
+ # creating copies of the traces
485
+ def _save_state_to_trace(
486
+ burnin_trace: PyTree,
487
+ main_trace: PyTree,
488
+ burnin_extractor: Callable[[State], PyTree],
489
+ main_extractor: Callable[[State], PyTree],
490
+ bart: State,
491
+ i_total: Int32[Array, ''],
492
+ n_burn: Int32[Array, ''],
493
+ n_skip: Int32[Array, ''],
494
+ ) -> tuple[PyTree, PyTree]:
495
+ # trace index where to save during burnin; out-of-bounds => noop after
496
+ # burnin
497
+ burnin_idx = i_total
498
+
499
+ # trace index where to save during main phase; force it out-of-bounds
500
+ # during burnin
501
+ main_idx = (i_total - n_burn) // n_skip
502
+ noop_idx = jnp.iinfo(jnp.int32).max
503
+ noop_cond = i_total < n_burn
504
+ main_idx = jnp.where(noop_cond, noop_idx, main_idx)
505
+
506
+ # prepare array index
507
+ num_chains = get_num_chains(bart)
508
+ burnin_trace = _set(burnin_trace, burnin_idx, burnin_extractor(bart), num_chains)
509
+ main_trace = _set(main_trace, main_idx, main_extractor(bart), num_chains)
510
+
511
+ return burnin_trace, main_trace
512
+
513
+
514
+ def _set(
515
+ trace: PyTree[Array, ' T'],
516
+ index: Int32[Array, ''],
517
+ val: PyTree[Array, ' T'],
518
+ num_chains: int | None,
376
519
  ) -> PyTree[Array, ' T']:
377
- """Map ``dest.at[index].set(val)`` over pytrees."""
520
+ """Do ``trace[index] = val`` but fancier."""
521
+ chain_axis = chain_vmap_axes(val)
522
+
523
+ def at_set(
524
+ trace: Shaped[Array, 'chains samples *shape']
525
+ | Shaped[Array, ' samples *shape']
526
+ | None,
527
+ val: Shaped[Array, ' chains *shape'] | Shaped[Array, '*shape'] | None,
528
+ chain_axis: int | None,
529
+ ):
530
+ if trace is None or trace.size == 0:
531
+ # this handles the case where an array is empty because jax refuses
532
+ # to index into an axis of length 0, even if just in the abstract,
533
+ # and optional elements that are considered leaves due to `is_leaf`
534
+ # below needed to traverse `chain_axis`.
535
+ return trace
378
536
 
379
- def at_set(dest, val):
380
- if dest.size:
381
- return dest.at[index, ...].set(val)
537
+ if num_chains is None or chain_axis is None:
538
+ ndindex = (index, ...)
382
539
  else:
383
- # this handles the case where an array is empty because jax refuses
384
- # to index into an array of length 0, even if just in the abstract
385
- return dest
540
+ ndindex = (slice(None), index, ...)
541
+
542
+ return trace.at[ndindex].set(val, mode='drop')
386
543
 
387
- return tree.map(at_set, dest, val)
544
+ return tree.map(at_set, trace, val, chain_axis, is_leaf=lambda x: x is None)
388
545
 
389
546
 
390
547
  def make_default_callback(
548
+ state: State,
391
549
  *,
392
550
  dot_every: int | Integer[Array, ''] | None = 1,
393
551
  report_every: int | Integer[Array, ''] | None = 100,
394
- sparse_on_at: int | Integer[Array, ''] | None = None,
395
552
  ) -> dict[str, Any]:
396
553
  """
397
554
  Prepare a default callback for `run_mcmc`.
@@ -401,14 +558,14 @@ def make_default_callback(
401
558
 
402
559
  Parameters
403
560
  ----------
561
+ state
562
+ The bart state to use the callback with, used to determine device
563
+ sharding.
404
564
  dot_every
405
565
  A dot is printed every `dot_every` MCMC iterations, `None` to disable.
406
566
  report_every
407
567
  A one line report is printed every `report_every` MCMC iterations,
408
568
  `None` to disable.
409
- sparse_on_at
410
- If specified, variable selection is activated starting from this
411
- iteration. If `None`, variable selection is not used.
412
569
 
413
570
  Returns
414
571
  -------
@@ -416,44 +573,30 @@ def make_default_callback(
416
573
 
417
574
  Examples
418
575
  --------
419
- >>> run_mcmc(..., **make_default_callback())
576
+ >>> run_mcmc(key, state, ..., **make_default_callback(state, ...))
420
577
  """
421
578
 
422
- def asarray_or_none(val: None | Any) -> None | Array:
423
- return None if val is None else jnp.asarray(val)
424
-
425
- def callback(*, bart, callback_state, **kwargs):
426
- print_state, sparse_state = callback_state
427
- bart, _ = sparse_callback(callback_state=sparse_state, bart=bart, **kwargs)
428
- print_callback(callback_state=print_state, bart=bart, **kwargs)
429
- return bart, callback_state
430
- # here I assume that the callbacks don't update their states
579
+ def as_replicated_array_or_none(val: None | Any) -> None | Array:
580
+ return None if val is None else _replicate(jnp.asarray(val), state.config.mesh)
431
581
 
432
582
  return dict(
433
- callback=callback,
434
- callback_state=(
435
- PrintCallbackState(
436
- asarray_or_none(dot_every), asarray_or_none(report_every)
437
- ),
438
- SparseCallbackState(asarray_or_none(sparse_on_at)),
583
+ callback=print_callback,
584
+ callback_state=PrintCallbackState(
585
+ as_replicated_array_or_none(dot_every),
586
+ as_replicated_array_or_none(report_every),
439
587
  ),
440
588
  )
441
589
 
442
590
 
443
591
  class PrintCallbackState(Module):
444
- """State for `print_callback`.
445
-
446
- Parameters
447
- ----------
448
- dot_every
449
- A dot is printed every `dot_every` MCMC iterations, `None` to disable.
450
- report_every
451
- A one line report is printed every `report_every` MCMC iterations,
452
- `None` to disable.
453
- """
592
+ """State for `print_callback`."""
454
593
 
455
594
  dot_every: Int32[Array, ''] | None
595
+ """A dot is printed every `dot_every` MCMC iterations, `None` to disable."""
596
+
456
597
  report_every: Int32[Array, ''] | None
598
+ """A one line report is printed every `report_every` MCMC iterations,
599
+ `None` to disable."""
457
600
 
458
601
 
459
602
  def print_callback(
@@ -468,34 +611,51 @@ def print_callback(
468
611
  **_,
469
612
  ):
470
613
  """Print a dot and/or a report periodically during the MCMC."""
471
- if callback_state.dot_every is not None:
472
- cond = (i_total + 1) % callback_state.dot_every == 0
473
- lax.cond(
474
- cond,
475
- lambda: debug.callback(lambda: print('.', end='', flush=True)), # noqa: T201
476
- # logging can't do in-line printing so I'll stick to print
477
- lambda: None,
614
+ report_every = callback_state.report_every
615
+ dot_every = callback_state.dot_every
616
+ it = i_total + 1
617
+
618
+ def get_cond(every: Int32[Array, ''] | None) -> bool | Bool[Array, '']:
619
+ return False if every is None else it % every == 0
620
+
621
+ report_cond = get_cond(report_every)
622
+ dot_cond = get_cond(dot_every)
623
+
624
+ def line_report_branch():
625
+ if report_every is None:
626
+ return
627
+ if dot_every is None:
628
+ print_newline = False
629
+ else:
630
+ print_newline = it % report_every > it % dot_every
631
+ debug.callback(
632
+ _print_report,
633
+ print_dot=dot_cond,
634
+ print_newline=print_newline,
635
+ burnin=burnin,
636
+ it=it,
637
+ n_iters=n_burn + n_save * n_skip,
638
+ num_chains=bart.forest.num_chains(),
639
+ grow_prop_count=bart.forest.grow_prop_count.mean(),
640
+ grow_acc_count=bart.forest.grow_acc_count.mean(),
641
+ prune_acc_count=bart.forest.prune_acc_count.mean(),
642
+ prop_total=bart.forest.split_tree.shape[-2],
643
+ fill=forest_fill(bart.forest.split_tree),
478
644
  )
479
645
 
480
- if callback_state.report_every is not None:
481
-
482
- def print_report():
483
- debug.callback(
484
- _print_report,
485
- newline=callback_state.dot_every is not None,
486
- burnin=burnin,
487
- i_total=i_total,
488
- n_iters=n_burn + n_save * n_skip,
489
- grow_prop_count=bart.forest.grow_prop_count,
490
- grow_acc_count=bart.forest.grow_acc_count,
491
- prune_prop_count=bart.forest.prune_prop_count,
492
- prune_acc_count=bart.forest.prune_acc_count,
493
- prop_total=len(bart.forest.leaf_tree),
494
- fill=grove.forest_fill(bart.forest.split_tree),
495
- )
646
+ def just_dot_branch():
647
+ if dot_every is None:
648
+ return
649
+ debug.callback(
650
+ lambda: print('.', end='', flush=True) # noqa: T201
651
+ )
652
+ # logging can't do in-line printing so we use print
496
653
 
497
- cond = (i_total + 1) % callback_state.report_every == 0
498
- lax.cond(cond, print_report, lambda: None)
654
+ cond_if_not_profiling(
655
+ report_cond,
656
+ line_report_branch,
657
+ lambda: cond_if_not_profiling(dot_cond, just_dot_branch, lambda: None),
658
+ )
499
659
 
500
660
 
501
661
  def _convert_jax_arrays_in_args(func: Callable) -> Callable:
@@ -506,15 +666,15 @@ def _convert_jax_arrays_in_args(func: Callable) -> Callable:
506
666
  """
507
667
 
508
668
  def convert_jax_arrays(pytree: PyTree) -> PyTree:
509
- def convert_jax_arrays(val: Any) -> Any:
510
- if not isinstance(val, jax.Array):
669
+ def convert_jax_array(val: Any) -> Any:
670
+ if not isinstance(val, Array):
511
671
  return val
512
672
  elif val.shape:
513
673
  return numpy.array(val)
514
674
  else:
515
675
  return val.item()
516
676
 
517
- return tree.map(convert_jax_arrays, pytree)
677
+ return tree.map(convert_jax_array, pytree)
518
678
 
519
679
  @wraps(func)
520
680
  def new_func(*args, **kw):
@@ -530,126 +690,157 @@ def _convert_jax_arrays_in_args(func: Callable) -> Callable:
530
690
  # deadlock with the main thread
531
691
  def _print_report(
532
692
  *,
533
- newline: bool,
693
+ print_dot: bool,
694
+ print_newline: bool,
534
695
  burnin: bool,
535
- i_total: int,
696
+ it: int,
536
697
  n_iters: int,
537
- grow_prop_count: int,
538
- grow_acc_count: int,
539
- prune_prop_count: int,
540
- prune_acc_count: int,
698
+ num_chains: int | None,
699
+ grow_prop_count: float,
700
+ grow_acc_count: float,
701
+ prune_acc_count: float,
541
702
  prop_total: int,
542
703
  fill: float,
543
704
  ):
544
705
  """Print the report for `print_callback`."""
545
-
546
- def acc_string(acc_count, prop_count):
547
- if prop_count:
548
- return f'{acc_count / prop_count:.0%}'
549
- else:
550
- return 'n/d'
551
-
706
+ # compute fractions
552
707
  grow_prop = grow_prop_count / prop_total
553
- prune_prop = prune_prop_count / prop_total
554
- grow_acc = acc_string(grow_acc_count, grow_prop_count)
555
- prune_acc = acc_string(prune_acc_count, prune_prop_count)
708
+ move_acc = (grow_acc_count + prune_acc_count) / prop_total
709
+
710
+ # determine prefix
711
+ if print_dot:
712
+ prefix = '.\n'
713
+ elif print_newline:
714
+ prefix = '\n'
715
+ else:
716
+ prefix = ''
556
717
 
557
- prefix = '\n' if newline else ''
558
- suffix = ' (burnin)' if burnin else ''
718
+ # determine suffix in parentheses
719
+ msgs = []
720
+ if num_chains is not None:
721
+ msgs.append(f'avg. {num_chains} chains')
722
+ if burnin:
723
+ msgs.append('burnin')
724
+ suffix = f' ({", ".join(msgs)})' if msgs else ''
559
725
 
560
726
  print( # noqa: T201, see print_callback for why not logging
561
- f'{prefix}It {i_total + 1}/{n_iters} '
562
- f'grow P={grow_prop:.0%} A={grow_acc}, '
563
- f'prune P={prune_prop:.0%} A={prune_acc}, '
564
- f'fill={fill:.0%}{suffix}'
727
+ f'{prefix}Iteration {it}/{n_iters}, '
728
+ f'grow prob: {grow_prop:.0%}, '
729
+ f'move acc: {move_acc:.0%}, '
730
+ f'fill: {fill:.0%}{suffix}'
565
731
  )
566
732
 
567
733
 
568
- class SparseCallbackState(Module):
569
- """State for `sparse_callback`.
570
-
571
- Parameters
572
- ----------
573
- sparse_on_at
574
- If specified, variable selection is activated starting from this
575
- iteration. If `None`, variable selection is not used.
576
- """
577
-
578
- sparse_on_at: Int32[Array, ''] | None
579
-
580
-
581
- def sparse_callback(
582
- *,
583
- key: Key[Array, ''],
584
- bart: State,
585
- i_total: Int32[Array, ''],
586
- callback_state: SparseCallbackState,
587
- **_,
588
- ):
589
- """Perform variable selection, see `mcmcstep.step_sparse`."""
590
- if callback_state.sparse_on_at is not None:
591
- bart = lax.cond(
592
- i_total < callback_state.sparse_on_at,
593
- lambda: bart,
594
- lambda: mcmcstep.step_sparse(key, bart),
595
- )
596
- return bart, callback_state
597
-
598
-
599
- class Trace(grove.TreeHeaps, Protocol):
734
+ class Trace(TreeHeaps, Protocol):
600
735
  """Protocol for a MCMC trace."""
601
736
 
602
- offset: Float32[Array, ' trace_length']
737
+ offset: Float32[Array, '*trace_shape']
603
738
 
604
739
 
605
740
  class TreesTrace(Module):
606
741
  """Implementation of `bartz.grove.TreeHeaps` for an MCMC trace."""
607
742
 
608
- leaf_tree: Float32[Array, 'trace_length num_trees 2**d']
609
- var_tree: UInt[Array, 'trace_length num_trees 2**(d-1)']
610
- split_tree: UInt[Array, 'trace_length num_trees 2**(d-1)']
743
+ leaf_tree: (
744
+ Float32[Array, '*trace_shape num_trees 2**d']
745
+ | Float32[Array, '*trace_shape num_trees k 2**d']
746
+ )
747
+ var_tree: UInt[Array, '*trace_shape num_trees 2**(d-1)']
748
+ split_tree: UInt[Array, '*trace_shape num_trees 2**(d-1)']
611
749
 
612
750
  @classmethod
613
- def from_dataclass(cls, obj: grove.TreeHeaps):
751
+ def from_dataclass(cls, obj: TreeHeaps):
614
752
  """Create a `TreesTrace` from any `bartz.grove.TreeHeaps`."""
615
753
  return cls(**{f.name: getattr(obj, f.name) for f in fields(cls)})
616
754
 
617
755
 
618
- @jax.jit
756
+ @jit
619
757
  def evaluate_trace(
620
- trace: Trace, X: UInt[Array, 'p n']
621
- ) -> Float32[Array, 'trace_length n']:
758
+ X: UInt[Array, 'p n'], trace: Trace
759
+ ) -> Float32[Array, '*trace_shape n'] | Float32[Array, '*trace_shape k n']:
622
760
  """
623
761
  Compute predictions for all iterations of the BART MCMC.
624
762
 
625
763
  Parameters
626
764
  ----------
627
- trace
628
- A trace of the BART MCMC, as returned by `run_mcmc`.
629
765
  X
630
766
  The predictors matrix, with `p` predictors and `n` observations.
767
+ trace
768
+ A main trace of the BART MCMC, as returned by `run_mcmc`.
631
769
 
632
770
  Returns
633
771
  -------
634
- The predictions for each iteration of the MCMC.
772
+ The predictions for each chain and iteration of the MCMC.
635
773
  """
636
- evaluate_trees = partial(grove.evaluate_forest, sum_trees=False)
637
- evaluate_trees = jaxext.autobatch(evaluate_trees, 2**29, (None, 0))
638
- trees = TreesTrace.from_dataclass(trace)
774
+ # per-device memory limit
775
+ max_io_nbytes = 2**27 # 128 MiB
776
+
777
+ # adjust memory limit for number of devices
778
+ mesh = jax.typeof(trace.leaf_tree).sharding.mesh
779
+ num_devices = get_axis_size(mesh, 'chains') * get_axis_size(mesh, 'data')
780
+ max_io_nbytes *= num_devices
781
+
782
+ # determine batching axes
783
+ has_chains = trace.split_tree.ndim > 3 # chains, samples, trees, nodes
784
+ if has_chains:
785
+ sample_axis = 1
786
+ tree_axis = 2
787
+ else:
788
+ sample_axis = 0
789
+ tree_axis = 1
790
+
791
+ # batch and sum over trees
792
+ batched_eval = autobatch(
793
+ evaluate_forest,
794
+ max_io_nbytes,
795
+ (None, tree_axis),
796
+ tree_axis,
797
+ reduce_ufunc=jnp.add,
798
+ )
639
799
 
640
- def loop(_, item):
641
- offset, trees = item
642
- values = evaluate_trees(X, trees)
643
- return None, offset + jnp.sum(values, axis=0, dtype=jnp.float32)
800
+ # determine output shape (to avoid autobatch tracing everything 4 times)
801
+ is_mv = trace.leaf_tree.ndim > trace.split_tree.ndim
802
+ k = trace.leaf_tree.shape[-2] if is_mv else 1
803
+ mv_shape = (k,) if is_mv else ()
804
+ _, n = X.shape
805
+ out_shape = (*trace.split_tree.shape[:-2], *mv_shape, n)
806
+
807
+ # adjust memory limit keeping into account that trees are summed over
808
+ num_trees, hts = trace.split_tree.shape[-2:]
809
+ out_size = k * n * jnp.float32.dtype.itemsize # the value of the forest
810
+ core_io_size = (
811
+ num_trees
812
+ * hts
813
+ * (
814
+ 2 * k * trace.leaf_tree.itemsize
815
+ + trace.var_tree.itemsize
816
+ + trace.split_tree.itemsize
817
+ )
818
+ + out_size
819
+ )
820
+ core_int_size = (num_trees - 1) * out_size
821
+ max_io_nbytes = max(1, floor(max_io_nbytes / (1 + core_int_size / core_io_size)))
822
+
823
+ # batch over mcmc samples
824
+ batched_eval = autobatch(
825
+ batched_eval,
826
+ max_io_nbytes,
827
+ (None, sample_axis),
828
+ sample_axis,
829
+ warn_on_overflow=False, # the inner autobatch will handle it
830
+ result_shape_dtype=ShapeDtypeStruct(out_shape, jnp.float32),
831
+ )
832
+
833
+ # extract only the trees from the trace
834
+ trees = TreesTrace.from_dataclass(trace)
644
835
 
645
- _, y = lax.scan(loop, None, (trace.offset, trees))
646
- return y
836
+ # evaluate trees
837
+ y_centered: Float32[Array, '*trace_shape n'] | Float32[Array, '*trace_shape k n']
838
+ y_centered = batched_eval(X, trees)
839
+ return y_centered + trace.offset[..., None]
647
840
 
648
841
 
649
- @partial(jax.jit, static_argnums=(0,))
650
- def compute_varcount(
651
- p: int, trace: grove.TreeHeaps
652
- ) -> Int32[Array, 'trace_length {p}']:
842
+ @partial(jit, static_argnums=(0,))
843
+ def compute_varcount(p: int, trace: TreeHeaps) -> Int32[Array, '*trace_shape {p}']:
653
844
  """
654
845
  Count how many times each predictor is used in each MCMC state.
655
846
 
@@ -658,11 +849,11 @@ def compute_varcount(
658
849
  p
659
850
  The number of predictors.
660
851
  trace
661
- A trace of the BART MCMC, as returned by `run_mcmc`.
852
+ A main trace of the BART MCMC, as returned by `run_mcmc`.
662
853
 
663
854
  Returns
664
855
  -------
665
856
  Histogram of predictor usage in each MCMC state.
666
857
  """
667
- vmapped_var_histogram = jax.vmap(grove.var_histogram, in_axes=(None, 0, 0))
668
- return vmapped_var_histogram(p, trace.var_tree, trace.split_tree)
858
+ # var_tree has shape (chains? samples trees nodes)
859
+ return var_histogram(p, trace.var_tree, trace.split_tree, sum_batch_axis=-1)