bartz 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/.DS_Store +0 -0
- bartz/BART/__init__.py +27 -0
- bartz/BART/_gbart.py +522 -0
- bartz/__init__.py +4 -2
- bartz/{BART.py → _interface.py} +256 -132
- bartz/_profiler.py +318 -0
- bartz/_version.py +1 -1
- bartz/debug.py +269 -314
- bartz/grove.py +124 -68
- bartz/jaxext/__init__.py +101 -27
- bartz/jaxext/_autobatch.py +257 -51
- bartz/jaxext/scipy/__init__.py +1 -1
- bartz/jaxext/scipy/special.py +3 -4
- bartz/jaxext/scipy/stats.py +1 -1
- bartz/mcmcloop.py +399 -208
- bartz/mcmcstep/__init__.py +35 -0
- bartz/mcmcstep/_moves.py +904 -0
- bartz/mcmcstep/_state.py +1114 -0
- bartz/mcmcstep/_step.py +1603 -0
- bartz/prepcovars.py +1 -1
- bartz/testing/__init__.py +29 -0
- bartz/testing/_dgp.py +442 -0
- {bartz-0.7.0.dist-info → bartz-0.8.0.dist-info}/METADATA +17 -11
- bartz-0.8.0.dist-info/RECORD +25 -0
- {bartz-0.7.0.dist-info → bartz-0.8.0.dist-info}/WHEEL +1 -1
- bartz/mcmcstep.py +0 -2616
- bartz-0.7.0.dist-info/RECORD +0 -17
bartz/mcmcloop.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
# bartz/src/bartz/mcmcloop.py
|
|
2
2
|
#
|
|
3
|
-
# Copyright (c) 2024-
|
|
3
|
+
# Copyright (c) 2024-2026, The Bartz Contributors
|
|
4
4
|
#
|
|
5
5
|
# This file is part of bartz.
|
|
6
6
|
#
|
|
@@ -28,39 +28,62 @@ The entry points are `run_mcmc` and `make_default_callback`.
|
|
|
28
28
|
"""
|
|
29
29
|
|
|
30
30
|
from collections.abc import Callable
|
|
31
|
-
from dataclasses import fields
|
|
31
|
+
from dataclasses import fields
|
|
32
32
|
from functools import partial, wraps
|
|
33
|
+
from math import floor
|
|
33
34
|
from typing import Any, Protocol
|
|
34
35
|
|
|
35
36
|
import jax
|
|
36
37
|
import numpy
|
|
37
38
|
from equinox import Module
|
|
38
|
-
from jax import
|
|
39
|
+
from jax import (
|
|
40
|
+
NamedSharding,
|
|
41
|
+
ShapeDtypeStruct,
|
|
42
|
+
debug,
|
|
43
|
+
device_put,
|
|
44
|
+
eval_shape,
|
|
45
|
+
jit,
|
|
46
|
+
tree,
|
|
47
|
+
)
|
|
39
48
|
from jax import numpy as jnp
|
|
40
49
|
from jax.nn import softmax
|
|
50
|
+
from jax.sharding import Mesh, PartitionSpec
|
|
41
51
|
from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, PyTree, Shaped, UInt
|
|
42
52
|
|
|
43
|
-
from bartz import
|
|
53
|
+
from bartz import jaxext, mcmcstep
|
|
54
|
+
from bartz._profiler import (
|
|
55
|
+
cond_if_not_profiling,
|
|
56
|
+
get_profile_mode,
|
|
57
|
+
jit_if_not_profiling,
|
|
58
|
+
scan_if_not_profiling,
|
|
59
|
+
)
|
|
60
|
+
from bartz.grove import TreeHeaps, evaluate_forest, forest_fill, var_histogram
|
|
61
|
+
from bartz.jaxext import autobatch
|
|
44
62
|
from bartz.mcmcstep import State
|
|
63
|
+
from bartz.mcmcstep._state import chain_vmap_axes, field, get_axis_size, get_num_chains
|
|
45
64
|
|
|
46
65
|
|
|
47
66
|
class BurninTrace(Module):
|
|
48
67
|
"""MCMC trace with only diagnostic values."""
|
|
49
68
|
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
69
|
+
error_cov_inv: (
|
|
70
|
+
Float32[Array, '*chains_and_samples']
|
|
71
|
+
| Float32[Array, '*chains_and_samples k k']
|
|
72
|
+
| None
|
|
73
|
+
) = field(chains=True)
|
|
74
|
+
theta: Float32[Array, '*chains_and_samples'] | None = field(chains=True)
|
|
75
|
+
grow_prop_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
|
|
76
|
+
grow_acc_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
|
|
77
|
+
prune_prop_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
|
|
78
|
+
prune_acc_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
|
|
79
|
+
log_likelihood: Float32[Array, '*chains_and_samples'] | None = field(chains=True)
|
|
80
|
+
log_trans_prior: Float32[Array, '*chains_and_samples'] | None = field(chains=True)
|
|
58
81
|
|
|
59
82
|
@classmethod
|
|
60
83
|
def from_state(cls, state: State) -> 'BurninTrace':
|
|
61
84
|
"""Create a single-item burn-in trace from a MCMC state."""
|
|
62
85
|
return cls(
|
|
63
|
-
|
|
86
|
+
error_cov_inv=state.error_cov_inv,
|
|
64
87
|
theta=state.forest.theta,
|
|
65
88
|
grow_prop_count=state.forest.grow_prop_count,
|
|
66
89
|
grow_acc_count=state.forest.grow_acc_count,
|
|
@@ -74,11 +97,14 @@ class BurninTrace(Module):
|
|
|
74
97
|
class MainTrace(BurninTrace):
|
|
75
98
|
"""MCMC trace with trees and diagnostic values."""
|
|
76
99
|
|
|
77
|
-
leaf_tree:
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
100
|
+
leaf_tree: (
|
|
101
|
+
Float32[Array, '*chains_and_samples 2**d']
|
|
102
|
+
| Float32[Array, '*chains_and_samples k 2**d']
|
|
103
|
+
) = field(chains=True)
|
|
104
|
+
var_tree: UInt[Array, '*chains_and_samples 2**(d-1)'] = field(chains=True)
|
|
105
|
+
split_tree: UInt[Array, '*chains_and_samples 2**(d-1)'] = field(chains=True)
|
|
106
|
+
offset: Float32[Array, '*samples'] | Float32[Array, '*samples k']
|
|
107
|
+
varprob: Float32[Array, '*chains_and_samples p'] | None = field(chains=True)
|
|
82
108
|
|
|
83
109
|
@classmethod
|
|
84
110
|
def from_state(cls, state: State) -> 'MainTrace':
|
|
@@ -171,8 +197,12 @@ class _Carry(Module):
|
|
|
171
197
|
bart: State
|
|
172
198
|
i_total: Int32[Array, '']
|
|
173
199
|
key: Key[Array, '']
|
|
174
|
-
burnin_trace: PyTree[
|
|
175
|
-
|
|
200
|
+
burnin_trace: PyTree[
|
|
201
|
+
Shaped[Array, 'n_burn ...'] | Shaped[Array, 'num_chains n_burn ...']
|
|
202
|
+
]
|
|
203
|
+
main_trace: PyTree[
|
|
204
|
+
Shaped[Array, 'n_save ...'] | Shaped[Array, 'num_chains n_save ...']
|
|
205
|
+
]
|
|
176
206
|
callback_state: CallbackState
|
|
177
207
|
|
|
178
208
|
|
|
@@ -188,7 +218,11 @@ def run_mcmc(
|
|
|
188
218
|
callback_state: CallbackState = None,
|
|
189
219
|
burnin_extractor: Callable[[State], PyTree] = BurninTrace.from_state,
|
|
190
220
|
main_extractor: Callable[[State], PyTree] = MainTrace.from_state,
|
|
191
|
-
) -> tuple[
|
|
221
|
+
) -> tuple[
|
|
222
|
+
State,
|
|
223
|
+
PyTree[Shaped[Array, 'n_burn ...'] | Shaped[Array, 'num_chains n_burn ...']],
|
|
224
|
+
PyTree[Shaped[Array, 'n_save ...'] | Shaped[Array, 'num_chains n_save ...']],
|
|
225
|
+
]:
|
|
192
226
|
"""
|
|
193
227
|
Run the MCMC for the BART posterior.
|
|
194
228
|
|
|
@@ -226,9 +260,9 @@ def run_mcmc(
|
|
|
226
260
|
The initial custom state for the callback.
|
|
227
261
|
burnin_extractor
|
|
228
262
|
main_extractor
|
|
229
|
-
Functions that extract the variables to be saved respectively
|
|
230
|
-
|
|
231
|
-
|
|
263
|
+
Functions that extract the variables to be saved respectively in the
|
|
264
|
+
burnin trace and main traces, given the MCMC state as argument. Must
|
|
265
|
+
return a pytree, and must be vmappable.
|
|
232
266
|
|
|
233
267
|
Returns
|
|
234
268
|
-------
|
|
@@ -239,17 +273,20 @@ def run_mcmc(
|
|
|
239
273
|
main_trace : PyTree[Shaped[Array, 'n_save *']]
|
|
240
274
|
The trace of the main phase. For the default layout, see `MainTrace`.
|
|
241
275
|
|
|
276
|
+
Raises
|
|
277
|
+
------
|
|
278
|
+
RuntimeError
|
|
279
|
+
If `run_mcmc` detects it's being invoked in a `jit`-wrapped context and
|
|
280
|
+
with settings that would create unrolled loops in the trace.
|
|
281
|
+
|
|
242
282
|
Notes
|
|
243
283
|
-----
|
|
244
284
|
The number of MCMC updates is ``n_burn + n_skip * n_save``. The traces do
|
|
245
285
|
not include the initial state, and include the final state.
|
|
246
286
|
"""
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
burnin_trace = empty_trace(n_burn, bart, burnin_extractor)
|
|
252
|
-
main_trace = empty_trace(n_save, bart, main_extractor)
|
|
287
|
+
# create empty traces
|
|
288
|
+
burnin_trace = _empty_trace(n_burn, bart, burnin_extractor)
|
|
289
|
+
main_trace = _empty_trace(n_save, bart, main_extractor)
|
|
253
290
|
|
|
254
291
|
# determine number of iterations for inner and outer loops
|
|
255
292
|
n_iters = n_burn + n_skip * n_save
|
|
@@ -262,7 +299,27 @@ def run_mcmc(
|
|
|
262
299
|
# setting to 0 would make for a clean noop, but it's useful to keep the
|
|
263
300
|
# same code path for benchmarking and testing
|
|
264
301
|
|
|
265
|
-
|
|
302
|
+
# error if under jit and there are unrolled loops or profile mode is on
|
|
303
|
+
under_jit = not hasattr(jnp.empty(0), 'platform')
|
|
304
|
+
if under_jit and (n_outer > 1 or get_profile_mode()):
|
|
305
|
+
msg = (
|
|
306
|
+
'`run_mcmc` was called within a jit-compiled function and '
|
|
307
|
+
'there are either more than 1 outer loops or profile mode is active, '
|
|
308
|
+
'please either do not jit, set `inner_loop_length=None`, or disable '
|
|
309
|
+
'profile mode.'
|
|
310
|
+
)
|
|
311
|
+
raise RuntimeError(msg)
|
|
312
|
+
|
|
313
|
+
replicate = partial(_replicate, mesh=bart.config.mesh)
|
|
314
|
+
carry = _Carry(
|
|
315
|
+
bart,
|
|
316
|
+
replicate(jnp.int32(0)),
|
|
317
|
+
replicate(key),
|
|
318
|
+
burnin_trace,
|
|
319
|
+
main_trace,
|
|
320
|
+
callback_state,
|
|
321
|
+
)
|
|
322
|
+
_run_mcmc_inner_loop._fun.reset_call_counter() # noqa: SLF001
|
|
266
323
|
for i_outer in range(n_outer):
|
|
267
324
|
carry = _run_mcmc_inner_loop(
|
|
268
325
|
carry,
|
|
@@ -280,6 +337,30 @@ def run_mcmc(
|
|
|
280
337
|
return carry.bart, carry.burnin_trace, carry.main_trace
|
|
281
338
|
|
|
282
339
|
|
|
340
|
+
def _replicate(x: Array, mesh: Mesh | None) -> Array:
|
|
341
|
+
if mesh is None:
|
|
342
|
+
return x
|
|
343
|
+
else:
|
|
344
|
+
return device_put(x, NamedSharding(mesh, PartitionSpec()))
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
@partial(jit, static_argnums=(0, 2))
|
|
348
|
+
def _empty_trace(
|
|
349
|
+
length: int, bart: State, extractor: Callable[[State], PyTree]
|
|
350
|
+
) -> PyTree:
|
|
351
|
+
num_chains = get_num_chains(bart)
|
|
352
|
+
if num_chains is None:
|
|
353
|
+
out_axes = 0
|
|
354
|
+
else:
|
|
355
|
+
example_output = eval_shape(extractor, bart)
|
|
356
|
+
chain_axes = chain_vmap_axes(example_output)
|
|
357
|
+
out_axes = tree.map(
|
|
358
|
+
lambda a: 0 if a is None else 1, chain_axes, is_leaf=lambda a: a is None
|
|
359
|
+
)
|
|
360
|
+
return jax.vmap(extractor, in_axes=None, out_axes=out_axes, axis_size=length)(bart)
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
@jit
|
|
283
364
|
def _compute_i_skip(
|
|
284
365
|
i_total: Int32[Array, ''], n_burn: Int32[Array, ''], n_skip: Int32[Array, '']
|
|
285
366
|
) -> Int32[Array, '']:
|
|
@@ -293,7 +374,34 @@ def _compute_i_skip(
|
|
|
293
374
|
)
|
|
294
375
|
|
|
295
376
|
|
|
296
|
-
|
|
377
|
+
class _CallCounter:
|
|
378
|
+
"""Wrap a callable to check it's not called more than once."""
|
|
379
|
+
|
|
380
|
+
def __init__(self, func: Callable) -> None:
|
|
381
|
+
self.func = func
|
|
382
|
+
self.n_calls = 0
|
|
383
|
+
|
|
384
|
+
def reset_call_counter(self) -> None:
|
|
385
|
+
"""Reset the call counter."""
|
|
386
|
+
self.n_calls = 0
|
|
387
|
+
|
|
388
|
+
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
|
389
|
+
if self.n_calls and not get_profile_mode():
|
|
390
|
+
msg = (
|
|
391
|
+
'The inner loop of `run_mcmc` was traced more than once, '
|
|
392
|
+
'which indicates a double compilation of the MCMC code. This '
|
|
393
|
+
'probably depends on the input state having different type from the '
|
|
394
|
+
'output state. Check the input is in a format that is the '
|
|
395
|
+
'same jax would output, e.g., all arrays and scalars are jax '
|
|
396
|
+
'arrays, with the right shardings.'
|
|
397
|
+
)
|
|
398
|
+
raise RuntimeError(msg)
|
|
399
|
+
self.n_calls += 1
|
|
400
|
+
return self.func(*args, **kwargs)
|
|
401
|
+
|
|
402
|
+
|
|
403
|
+
@partial(jit_if_not_profiling, donate_argnums=(0,), static_argnums=(1, 2, 3, 4))
|
|
404
|
+
@_CallCounter
|
|
297
405
|
def _run_mcmc_inner_loop(
|
|
298
406
|
carry: _Carry,
|
|
299
407
|
inner_loop_length: int,
|
|
@@ -305,28 +413,27 @@ def _run_mcmc_inner_loop(
|
|
|
305
413
|
n_skip: Int32[Array, ''],
|
|
306
414
|
i_outer: Int32[Array, ''],
|
|
307
415
|
n_iters: Int32[Array, ''],
|
|
308
|
-
):
|
|
416
|
+
) -> _Carry:
|
|
309
417
|
def loop_impl(carry: _Carry) -> _Carry:
|
|
310
418
|
"""Loop body to run if i_total < n_iters."""
|
|
311
419
|
# split random key
|
|
312
420
|
keys = jaxext.split(carry.key, 3)
|
|
313
|
-
|
|
421
|
+
key = keys.pop()
|
|
314
422
|
|
|
315
423
|
# update state
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
burnin = carry.i_total < n_burn
|
|
424
|
+
bart = mcmcstep.step(keys.pop(), carry.bart)
|
|
319
425
|
|
|
320
426
|
# invoke callback
|
|
427
|
+
callback_state = carry.callback_state
|
|
321
428
|
if callback is not None:
|
|
322
429
|
i_skip = _compute_i_skip(carry.i_total, n_burn, n_skip)
|
|
323
430
|
rt = callback(
|
|
324
431
|
key=keys.pop(),
|
|
325
|
-
bart=
|
|
326
|
-
burnin=
|
|
432
|
+
bart=bart,
|
|
433
|
+
burnin=carry.i_total < n_burn,
|
|
327
434
|
i_total=carry.i_total,
|
|
328
435
|
i_skip=i_skip,
|
|
329
|
-
callback_state=
|
|
436
|
+
callback_state=callback_state,
|
|
330
437
|
n_burn=n_burn,
|
|
331
438
|
n_save=n_save,
|
|
332
439
|
n_skip=n_skip,
|
|
@@ -335,28 +442,26 @@ def _run_mcmc_inner_loop(
|
|
|
335
442
|
)
|
|
336
443
|
if rt is not None:
|
|
337
444
|
bart, callback_state = rt
|
|
338
|
-
carry = replace(carry, bart=bart, callback_state=callback_state)
|
|
339
|
-
|
|
340
|
-
def save_to_burnin_trace() -> tuple[PyTree, PyTree]:
|
|
341
|
-
return _pytree_at_set(
|
|
342
|
-
carry.burnin_trace, carry.i_total, burnin_extractor(carry.bart)
|
|
343
|
-
), carry.main_trace
|
|
344
|
-
|
|
345
|
-
def save_to_main_trace() -> tuple[PyTree, PyTree]:
|
|
346
|
-
idx = (carry.i_total - n_burn) // n_skip
|
|
347
|
-
return carry.burnin_trace, _pytree_at_set(
|
|
348
|
-
carry.main_trace, idx, main_extractor(carry.bart)
|
|
349
|
-
)
|
|
350
445
|
|
|
351
|
-
# save
|
|
352
|
-
burnin_trace, main_trace =
|
|
353
|
-
|
|
446
|
+
# save to trace
|
|
447
|
+
burnin_trace, main_trace = _save_state_to_trace(
|
|
448
|
+
carry.burnin_trace,
|
|
449
|
+
carry.main_trace,
|
|
450
|
+
burnin_extractor,
|
|
451
|
+
main_extractor,
|
|
452
|
+
bart,
|
|
453
|
+
carry.i_total,
|
|
454
|
+
n_burn,
|
|
455
|
+
n_skip,
|
|
354
456
|
)
|
|
355
|
-
|
|
356
|
-
|
|
457
|
+
|
|
458
|
+
return _Carry(
|
|
459
|
+
bart=bart,
|
|
357
460
|
i_total=carry.i_total + 1,
|
|
461
|
+
key=key,
|
|
358
462
|
burnin_trace=burnin_trace,
|
|
359
463
|
main_trace=main_trace,
|
|
464
|
+
callback_state=callback_state,
|
|
360
465
|
)
|
|
361
466
|
|
|
362
467
|
def loop_noop(carry: _Carry) -> _Carry:
|
|
@@ -364,34 +469,86 @@ def _run_mcmc_inner_loop(
|
|
|
364
469
|
return carry
|
|
365
470
|
|
|
366
471
|
def loop(carry: _Carry, _) -> tuple[_Carry, None]:
|
|
367
|
-
carry =
|
|
472
|
+
carry = cond_if_not_profiling(
|
|
473
|
+
carry.i_total < n_iters, loop_impl, loop_noop, carry
|
|
474
|
+
)
|
|
368
475
|
return carry, None
|
|
369
476
|
|
|
370
|
-
carry, _ =
|
|
477
|
+
carry, _ = scan_if_not_profiling(loop, carry, None, inner_loop_length)
|
|
371
478
|
return carry
|
|
372
479
|
|
|
373
480
|
|
|
374
|
-
|
|
375
|
-
|
|
481
|
+
@partial(jit, donate_argnums=(0, 1), static_argnums=(2, 3))
|
|
482
|
+
# this is jitted because under profiling _run_mcmc_inner_loop and the loop
|
|
483
|
+
# within it are not, so I need the donate_argnums feature of jit to avoid
|
|
484
|
+
# creating copies of the traces
|
|
485
|
+
def _save_state_to_trace(
|
|
486
|
+
burnin_trace: PyTree,
|
|
487
|
+
main_trace: PyTree,
|
|
488
|
+
burnin_extractor: Callable[[State], PyTree],
|
|
489
|
+
main_extractor: Callable[[State], PyTree],
|
|
490
|
+
bart: State,
|
|
491
|
+
i_total: Int32[Array, ''],
|
|
492
|
+
n_burn: Int32[Array, ''],
|
|
493
|
+
n_skip: Int32[Array, ''],
|
|
494
|
+
) -> tuple[PyTree, PyTree]:
|
|
495
|
+
# trace index where to save during burnin; out-of-bounds => noop after
|
|
496
|
+
# burnin
|
|
497
|
+
burnin_idx = i_total
|
|
498
|
+
|
|
499
|
+
# trace index where to save during main phase; force it out-of-bounds
|
|
500
|
+
# during burnin
|
|
501
|
+
main_idx = (i_total - n_burn) // n_skip
|
|
502
|
+
noop_idx = jnp.iinfo(jnp.int32).max
|
|
503
|
+
noop_cond = i_total < n_burn
|
|
504
|
+
main_idx = jnp.where(noop_cond, noop_idx, main_idx)
|
|
505
|
+
|
|
506
|
+
# prepare array index
|
|
507
|
+
num_chains = get_num_chains(bart)
|
|
508
|
+
burnin_trace = _set(burnin_trace, burnin_idx, burnin_extractor(bart), num_chains)
|
|
509
|
+
main_trace = _set(main_trace, main_idx, main_extractor(bart), num_chains)
|
|
510
|
+
|
|
511
|
+
return burnin_trace, main_trace
|
|
512
|
+
|
|
513
|
+
|
|
514
|
+
def _set(
|
|
515
|
+
trace: PyTree[Array, ' T'],
|
|
516
|
+
index: Int32[Array, ''],
|
|
517
|
+
val: PyTree[Array, ' T'],
|
|
518
|
+
num_chains: int | None,
|
|
376
519
|
) -> PyTree[Array, ' T']:
|
|
377
|
-
"""
|
|
520
|
+
"""Do ``trace[index] = val`` but fancier."""
|
|
521
|
+
chain_axis = chain_vmap_axes(val)
|
|
522
|
+
|
|
523
|
+
def at_set(
|
|
524
|
+
trace: Shaped[Array, 'chains samples *shape']
|
|
525
|
+
| Shaped[Array, ' samples *shape']
|
|
526
|
+
| None,
|
|
527
|
+
val: Shaped[Array, ' chains *shape'] | Shaped[Array, '*shape'] | None,
|
|
528
|
+
chain_axis: int | None,
|
|
529
|
+
):
|
|
530
|
+
if trace is None or trace.size == 0:
|
|
531
|
+
# this handles the case where an array is empty because jax refuses
|
|
532
|
+
# to index into an axis of length 0, even if just in the abstract,
|
|
533
|
+
# and optional elements that are considered leaves due to `is_leaf`
|
|
534
|
+
# below needed to traverse `chain_axis`.
|
|
535
|
+
return trace
|
|
378
536
|
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
return dest.at[index, ...].set(val)
|
|
537
|
+
if num_chains is None or chain_axis is None:
|
|
538
|
+
ndindex = (index, ...)
|
|
382
539
|
else:
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
540
|
+
ndindex = (slice(None), index, ...)
|
|
541
|
+
|
|
542
|
+
return trace.at[ndindex].set(val, mode='drop')
|
|
386
543
|
|
|
387
|
-
return tree.map(at_set,
|
|
544
|
+
return tree.map(at_set, trace, val, chain_axis, is_leaf=lambda x: x is None)
|
|
388
545
|
|
|
389
546
|
|
|
390
547
|
def make_default_callback(
|
|
548
|
+
state: State,
|
|
391
549
|
*,
|
|
392
550
|
dot_every: int | Integer[Array, ''] | None = 1,
|
|
393
551
|
report_every: int | Integer[Array, ''] | None = 100,
|
|
394
|
-
sparse_on_at: int | Integer[Array, ''] | None = None,
|
|
395
552
|
) -> dict[str, Any]:
|
|
396
553
|
"""
|
|
397
554
|
Prepare a default callback for `run_mcmc`.
|
|
@@ -401,14 +558,14 @@ def make_default_callback(
|
|
|
401
558
|
|
|
402
559
|
Parameters
|
|
403
560
|
----------
|
|
561
|
+
state
|
|
562
|
+
The bart state to use the callback with, used to determine device
|
|
563
|
+
sharding.
|
|
404
564
|
dot_every
|
|
405
565
|
A dot is printed every `dot_every` MCMC iterations, `None` to disable.
|
|
406
566
|
report_every
|
|
407
567
|
A one line report is printed every `report_every` MCMC iterations,
|
|
408
568
|
`None` to disable.
|
|
409
|
-
sparse_on_at
|
|
410
|
-
If specified, variable selection is activated starting from this
|
|
411
|
-
iteration. If `None`, variable selection is not used.
|
|
412
569
|
|
|
413
570
|
Returns
|
|
414
571
|
-------
|
|
@@ -416,44 +573,30 @@ def make_default_callback(
|
|
|
416
573
|
|
|
417
574
|
Examples
|
|
418
575
|
--------
|
|
419
|
-
>>> run_mcmc(..., **make_default_callback())
|
|
576
|
+
>>> run_mcmc(key, state, ..., **make_default_callback(state, ...))
|
|
420
577
|
"""
|
|
421
578
|
|
|
422
|
-
def
|
|
423
|
-
return None if val is None else jnp.asarray(val)
|
|
424
|
-
|
|
425
|
-
def callback(*, bart, callback_state, **kwargs):
|
|
426
|
-
print_state, sparse_state = callback_state
|
|
427
|
-
bart, _ = sparse_callback(callback_state=sparse_state, bart=bart, **kwargs)
|
|
428
|
-
print_callback(callback_state=print_state, bart=bart, **kwargs)
|
|
429
|
-
return bart, callback_state
|
|
430
|
-
# here I assume that the callbacks don't update their states
|
|
579
|
+
def as_replicated_array_or_none(val: None | Any) -> None | Array:
|
|
580
|
+
return None if val is None else _replicate(jnp.asarray(val), state.config.mesh)
|
|
431
581
|
|
|
432
582
|
return dict(
|
|
433
|
-
callback=
|
|
434
|
-
callback_state=(
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
),
|
|
438
|
-
SparseCallbackState(asarray_or_none(sparse_on_at)),
|
|
583
|
+
callback=print_callback,
|
|
584
|
+
callback_state=PrintCallbackState(
|
|
585
|
+
as_replicated_array_or_none(dot_every),
|
|
586
|
+
as_replicated_array_or_none(report_every),
|
|
439
587
|
),
|
|
440
588
|
)
|
|
441
589
|
|
|
442
590
|
|
|
443
591
|
class PrintCallbackState(Module):
|
|
444
|
-
"""State for `print_callback`.
|
|
445
|
-
|
|
446
|
-
Parameters
|
|
447
|
-
----------
|
|
448
|
-
dot_every
|
|
449
|
-
A dot is printed every `dot_every` MCMC iterations, `None` to disable.
|
|
450
|
-
report_every
|
|
451
|
-
A one line report is printed every `report_every` MCMC iterations,
|
|
452
|
-
`None` to disable.
|
|
453
|
-
"""
|
|
592
|
+
"""State for `print_callback`."""
|
|
454
593
|
|
|
455
594
|
dot_every: Int32[Array, ''] | None
|
|
595
|
+
"""A dot is printed every `dot_every` MCMC iterations, `None` to disable."""
|
|
596
|
+
|
|
456
597
|
report_every: Int32[Array, ''] | None
|
|
598
|
+
"""A one line report is printed every `report_every` MCMC iterations,
|
|
599
|
+
`None` to disable."""
|
|
457
600
|
|
|
458
601
|
|
|
459
602
|
def print_callback(
|
|
@@ -468,34 +611,51 @@ def print_callback(
|
|
|
468
611
|
**_,
|
|
469
612
|
):
|
|
470
613
|
"""Print a dot and/or a report periodically during the MCMC."""
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
614
|
+
report_every = callback_state.report_every
|
|
615
|
+
dot_every = callback_state.dot_every
|
|
616
|
+
it = i_total + 1
|
|
617
|
+
|
|
618
|
+
def get_cond(every: Int32[Array, ''] | None) -> bool | Bool[Array, '']:
|
|
619
|
+
return False if every is None else it % every == 0
|
|
620
|
+
|
|
621
|
+
report_cond = get_cond(report_every)
|
|
622
|
+
dot_cond = get_cond(dot_every)
|
|
623
|
+
|
|
624
|
+
def line_report_branch():
|
|
625
|
+
if report_every is None:
|
|
626
|
+
return
|
|
627
|
+
if dot_every is None:
|
|
628
|
+
print_newline = False
|
|
629
|
+
else:
|
|
630
|
+
print_newline = it % report_every > it % dot_every
|
|
631
|
+
debug.callback(
|
|
632
|
+
_print_report,
|
|
633
|
+
print_dot=dot_cond,
|
|
634
|
+
print_newline=print_newline,
|
|
635
|
+
burnin=burnin,
|
|
636
|
+
it=it,
|
|
637
|
+
n_iters=n_burn + n_save * n_skip,
|
|
638
|
+
num_chains=bart.forest.num_chains(),
|
|
639
|
+
grow_prop_count=bart.forest.grow_prop_count.mean(),
|
|
640
|
+
grow_acc_count=bart.forest.grow_acc_count.mean(),
|
|
641
|
+
prune_acc_count=bart.forest.prune_acc_count.mean(),
|
|
642
|
+
prop_total=bart.forest.split_tree.shape[-2],
|
|
643
|
+
fill=forest_fill(bart.forest.split_tree),
|
|
478
644
|
)
|
|
479
645
|
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
i_total=i_total,
|
|
488
|
-
n_iters=n_burn + n_save * n_skip,
|
|
489
|
-
grow_prop_count=bart.forest.grow_prop_count,
|
|
490
|
-
grow_acc_count=bart.forest.grow_acc_count,
|
|
491
|
-
prune_prop_count=bart.forest.prune_prop_count,
|
|
492
|
-
prune_acc_count=bart.forest.prune_acc_count,
|
|
493
|
-
prop_total=len(bart.forest.leaf_tree),
|
|
494
|
-
fill=grove.forest_fill(bart.forest.split_tree),
|
|
495
|
-
)
|
|
646
|
+
def just_dot_branch():
|
|
647
|
+
if dot_every is None:
|
|
648
|
+
return
|
|
649
|
+
debug.callback(
|
|
650
|
+
lambda: print('.', end='', flush=True) # noqa: T201
|
|
651
|
+
)
|
|
652
|
+
# logging can't do in-line printing so we use print
|
|
496
653
|
|
|
497
|
-
|
|
498
|
-
|
|
654
|
+
cond_if_not_profiling(
|
|
655
|
+
report_cond,
|
|
656
|
+
line_report_branch,
|
|
657
|
+
lambda: cond_if_not_profiling(dot_cond, just_dot_branch, lambda: None),
|
|
658
|
+
)
|
|
499
659
|
|
|
500
660
|
|
|
501
661
|
def _convert_jax_arrays_in_args(func: Callable) -> Callable:
|
|
@@ -506,15 +666,15 @@ def _convert_jax_arrays_in_args(func: Callable) -> Callable:
|
|
|
506
666
|
"""
|
|
507
667
|
|
|
508
668
|
def convert_jax_arrays(pytree: PyTree) -> PyTree:
|
|
509
|
-
def
|
|
510
|
-
if not isinstance(val,
|
|
669
|
+
def convert_jax_array(val: Any) -> Any:
|
|
670
|
+
if not isinstance(val, Array):
|
|
511
671
|
return val
|
|
512
672
|
elif val.shape:
|
|
513
673
|
return numpy.array(val)
|
|
514
674
|
else:
|
|
515
675
|
return val.item()
|
|
516
676
|
|
|
517
|
-
return tree.map(
|
|
677
|
+
return tree.map(convert_jax_array, pytree)
|
|
518
678
|
|
|
519
679
|
@wraps(func)
|
|
520
680
|
def new_func(*args, **kw):
|
|
@@ -530,126 +690,157 @@ def _convert_jax_arrays_in_args(func: Callable) -> Callable:
|
|
|
530
690
|
# deadlock with the main thread
|
|
531
691
|
def _print_report(
|
|
532
692
|
*,
|
|
533
|
-
|
|
693
|
+
print_dot: bool,
|
|
694
|
+
print_newline: bool,
|
|
534
695
|
burnin: bool,
|
|
535
|
-
|
|
696
|
+
it: int,
|
|
536
697
|
n_iters: int,
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
prune_acc_count:
|
|
698
|
+
num_chains: int | None,
|
|
699
|
+
grow_prop_count: float,
|
|
700
|
+
grow_acc_count: float,
|
|
701
|
+
prune_acc_count: float,
|
|
541
702
|
prop_total: int,
|
|
542
703
|
fill: float,
|
|
543
704
|
):
|
|
544
705
|
"""Print the report for `print_callback`."""
|
|
545
|
-
|
|
546
|
-
def acc_string(acc_count, prop_count):
|
|
547
|
-
if prop_count:
|
|
548
|
-
return f'{acc_count / prop_count:.0%}'
|
|
549
|
-
else:
|
|
550
|
-
return 'n/d'
|
|
551
|
-
|
|
706
|
+
# compute fractions
|
|
552
707
|
grow_prop = grow_prop_count / prop_total
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
708
|
+
move_acc = (grow_acc_count + prune_acc_count) / prop_total
|
|
709
|
+
|
|
710
|
+
# determine prefix
|
|
711
|
+
if print_dot:
|
|
712
|
+
prefix = '.\n'
|
|
713
|
+
elif print_newline:
|
|
714
|
+
prefix = '\n'
|
|
715
|
+
else:
|
|
716
|
+
prefix = ''
|
|
556
717
|
|
|
557
|
-
|
|
558
|
-
|
|
718
|
+
# determine suffix in parentheses
|
|
719
|
+
msgs = []
|
|
720
|
+
if num_chains is not None:
|
|
721
|
+
msgs.append(f'avg. {num_chains} chains')
|
|
722
|
+
if burnin:
|
|
723
|
+
msgs.append('burnin')
|
|
724
|
+
suffix = f' ({", ".join(msgs)})' if msgs else ''
|
|
559
725
|
|
|
560
726
|
print( # noqa: T201, see print_callback for why not logging
|
|
561
|
-
f'{prefix}
|
|
562
|
-
f'grow
|
|
563
|
-
f'
|
|
564
|
-
f'fill
|
|
727
|
+
f'{prefix}Iteration {it}/{n_iters}, '
|
|
728
|
+
f'grow prob: {grow_prop:.0%}, '
|
|
729
|
+
f'move acc: {move_acc:.0%}, '
|
|
730
|
+
f'fill: {fill:.0%}{suffix}'
|
|
565
731
|
)
|
|
566
732
|
|
|
567
733
|
|
|
568
|
-
class
|
|
569
|
-
"""State for `sparse_callback`.
|
|
570
|
-
|
|
571
|
-
Parameters
|
|
572
|
-
----------
|
|
573
|
-
sparse_on_at
|
|
574
|
-
If specified, variable selection is activated starting from this
|
|
575
|
-
iteration. If `None`, variable selection is not used.
|
|
576
|
-
"""
|
|
577
|
-
|
|
578
|
-
sparse_on_at: Int32[Array, ''] | None
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
def sparse_callback(
|
|
582
|
-
*,
|
|
583
|
-
key: Key[Array, ''],
|
|
584
|
-
bart: State,
|
|
585
|
-
i_total: Int32[Array, ''],
|
|
586
|
-
callback_state: SparseCallbackState,
|
|
587
|
-
**_,
|
|
588
|
-
):
|
|
589
|
-
"""Perform variable selection, see `mcmcstep.step_sparse`."""
|
|
590
|
-
if callback_state.sparse_on_at is not None:
|
|
591
|
-
bart = lax.cond(
|
|
592
|
-
i_total < callback_state.sparse_on_at,
|
|
593
|
-
lambda: bart,
|
|
594
|
-
lambda: mcmcstep.step_sparse(key, bart),
|
|
595
|
-
)
|
|
596
|
-
return bart, callback_state
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
class Trace(grove.TreeHeaps, Protocol):
|
|
734
|
+
class Trace(TreeHeaps, Protocol):
|
|
600
735
|
"""Protocol for a MCMC trace."""
|
|
601
736
|
|
|
602
|
-
offset: Float32[Array, '
|
|
737
|
+
offset: Float32[Array, '*trace_shape']
|
|
603
738
|
|
|
604
739
|
|
|
605
740
|
class TreesTrace(Module):
|
|
606
741
|
"""Implementation of `bartz.grove.TreeHeaps` for an MCMC trace."""
|
|
607
742
|
|
|
608
|
-
leaf_tree:
|
|
609
|
-
|
|
610
|
-
|
|
743
|
+
leaf_tree: (
|
|
744
|
+
Float32[Array, '*trace_shape num_trees 2**d']
|
|
745
|
+
| Float32[Array, '*trace_shape num_trees k 2**d']
|
|
746
|
+
)
|
|
747
|
+
var_tree: UInt[Array, '*trace_shape num_trees 2**(d-1)']
|
|
748
|
+
split_tree: UInt[Array, '*trace_shape num_trees 2**(d-1)']
|
|
611
749
|
|
|
612
750
|
@classmethod
|
|
613
|
-
def from_dataclass(cls, obj:
|
|
751
|
+
def from_dataclass(cls, obj: TreeHeaps):
|
|
614
752
|
"""Create a `TreesTrace` from any `bartz.grove.TreeHeaps`."""
|
|
615
753
|
return cls(**{f.name: getattr(obj, f.name) for f in fields(cls)})
|
|
616
754
|
|
|
617
755
|
|
|
618
|
-
@
|
|
756
|
+
@jit
|
|
619
757
|
def evaluate_trace(
|
|
620
|
-
|
|
621
|
-
) -> Float32[Array, '
|
|
758
|
+
X: UInt[Array, 'p n'], trace: Trace
|
|
759
|
+
) -> Float32[Array, '*trace_shape n'] | Float32[Array, '*trace_shape k n']:
|
|
622
760
|
"""
|
|
623
761
|
Compute predictions for all iterations of the BART MCMC.
|
|
624
762
|
|
|
625
763
|
Parameters
|
|
626
764
|
----------
|
|
627
|
-
trace
|
|
628
|
-
A trace of the BART MCMC, as returned by `run_mcmc`.
|
|
629
765
|
X
|
|
630
766
|
The predictors matrix, with `p` predictors and `n` observations.
|
|
767
|
+
trace
|
|
768
|
+
A main trace of the BART MCMC, as returned by `run_mcmc`.
|
|
631
769
|
|
|
632
770
|
Returns
|
|
633
771
|
-------
|
|
634
|
-
The predictions for each iteration of the MCMC.
|
|
772
|
+
The predictions for each chain and iteration of the MCMC.
|
|
635
773
|
"""
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
774
|
+
# per-device memory limit
|
|
775
|
+
max_io_nbytes = 2**27 # 128 MiB
|
|
776
|
+
|
|
777
|
+
# adjust memory limit for number of devices
|
|
778
|
+
mesh = jax.typeof(trace.leaf_tree).sharding.mesh
|
|
779
|
+
num_devices = get_axis_size(mesh, 'chains') * get_axis_size(mesh, 'data')
|
|
780
|
+
max_io_nbytes *= num_devices
|
|
781
|
+
|
|
782
|
+
# determine batching axes
|
|
783
|
+
has_chains = trace.split_tree.ndim > 3 # chains, samples, trees, nodes
|
|
784
|
+
if has_chains:
|
|
785
|
+
sample_axis = 1
|
|
786
|
+
tree_axis = 2
|
|
787
|
+
else:
|
|
788
|
+
sample_axis = 0
|
|
789
|
+
tree_axis = 1
|
|
790
|
+
|
|
791
|
+
# batch and sum over trees
|
|
792
|
+
batched_eval = autobatch(
|
|
793
|
+
evaluate_forest,
|
|
794
|
+
max_io_nbytes,
|
|
795
|
+
(None, tree_axis),
|
|
796
|
+
tree_axis,
|
|
797
|
+
reduce_ufunc=jnp.add,
|
|
798
|
+
)
|
|
639
799
|
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
800
|
+
# determine output shape (to avoid autobatch tracing everything 4 times)
|
|
801
|
+
is_mv = trace.leaf_tree.ndim > trace.split_tree.ndim
|
|
802
|
+
k = trace.leaf_tree.shape[-2] if is_mv else 1
|
|
803
|
+
mv_shape = (k,) if is_mv else ()
|
|
804
|
+
_, n = X.shape
|
|
805
|
+
out_shape = (*trace.split_tree.shape[:-2], *mv_shape, n)
|
|
806
|
+
|
|
807
|
+
# adjust memory limit keeping into account that trees are summed over
|
|
808
|
+
num_trees, hts = trace.split_tree.shape[-2:]
|
|
809
|
+
out_size = k * n * jnp.float32.dtype.itemsize # the value of the forest
|
|
810
|
+
core_io_size = (
|
|
811
|
+
num_trees
|
|
812
|
+
* hts
|
|
813
|
+
* (
|
|
814
|
+
2 * k * trace.leaf_tree.itemsize
|
|
815
|
+
+ trace.var_tree.itemsize
|
|
816
|
+
+ trace.split_tree.itemsize
|
|
817
|
+
)
|
|
818
|
+
+ out_size
|
|
819
|
+
)
|
|
820
|
+
core_int_size = (num_trees - 1) * out_size
|
|
821
|
+
max_io_nbytes = max(1, floor(max_io_nbytes / (1 + core_int_size / core_io_size)))
|
|
822
|
+
|
|
823
|
+
# batch over mcmc samples
|
|
824
|
+
batched_eval = autobatch(
|
|
825
|
+
batched_eval,
|
|
826
|
+
max_io_nbytes,
|
|
827
|
+
(None, sample_axis),
|
|
828
|
+
sample_axis,
|
|
829
|
+
warn_on_overflow=False, # the inner autobatch will handle it
|
|
830
|
+
result_shape_dtype=ShapeDtypeStruct(out_shape, jnp.float32),
|
|
831
|
+
)
|
|
832
|
+
|
|
833
|
+
# extract only the trees from the trace
|
|
834
|
+
trees = TreesTrace.from_dataclass(trace)
|
|
644
835
|
|
|
645
|
-
|
|
646
|
-
|
|
836
|
+
# evaluate trees
|
|
837
|
+
y_centered: Float32[Array, '*trace_shape n'] | Float32[Array, '*trace_shape k n']
|
|
838
|
+
y_centered = batched_eval(X, trees)
|
|
839
|
+
return y_centered + trace.offset[..., None]
|
|
647
840
|
|
|
648
841
|
|
|
649
|
-
@partial(
|
|
650
|
-
def compute_varcount(
|
|
651
|
-
p: int, trace: grove.TreeHeaps
|
|
652
|
-
) -> Int32[Array, 'trace_length {p}']:
|
|
842
|
+
@partial(jit, static_argnums=(0,))
|
|
843
|
+
def compute_varcount(p: int, trace: TreeHeaps) -> Int32[Array, '*trace_shape {p}']:
|
|
653
844
|
"""
|
|
654
845
|
Count how many times each predictor is used in each MCMC state.
|
|
655
846
|
|
|
@@ -658,11 +849,11 @@ def compute_varcount(
|
|
|
658
849
|
p
|
|
659
850
|
The number of predictors.
|
|
660
851
|
trace
|
|
661
|
-
A trace of the BART MCMC, as returned by `run_mcmc`.
|
|
852
|
+
A main trace of the BART MCMC, as returned by `run_mcmc`.
|
|
662
853
|
|
|
663
854
|
Returns
|
|
664
855
|
-------
|
|
665
856
|
Histogram of predictor usage in each MCMC state.
|
|
666
857
|
"""
|
|
667
|
-
|
|
668
|
-
return
|
|
858
|
+
# var_tree has shape (chains? samples trees nodes)
|
|
859
|
+
return var_histogram(p, trace.var_tree, trace.split_tree, sum_batch_axis=-1)
|