bartz-0.7.0-py3-none-any.whl → bartz-0.8.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/.DS_Store +0 -0
- bartz/BART/__init__.py +27 -0
- bartz/BART/_gbart.py +522 -0
- bartz/__init__.py +4 -2
- bartz/{BART.py → _interface.py} +256 -132
- bartz/_profiler.py +318 -0
- bartz/_version.py +1 -1
- bartz/debug.py +269 -314
- bartz/grove.py +124 -68
- bartz/jaxext/__init__.py +101 -27
- bartz/jaxext/_autobatch.py +257 -51
- bartz/jaxext/scipy/__init__.py +1 -1
- bartz/jaxext/scipy/special.py +3 -4
- bartz/jaxext/scipy/stats.py +1 -1
- bartz/mcmcloop.py +399 -208
- bartz/mcmcstep/__init__.py +35 -0
- bartz/mcmcstep/_moves.py +904 -0
- bartz/mcmcstep/_state.py +1114 -0
- bartz/mcmcstep/_step.py +1603 -0
- bartz/prepcovars.py +1 -1
- bartz/testing/__init__.py +29 -0
- bartz/testing/_dgp.py +442 -0
- {bartz-0.7.0.dist-info → bartz-0.8.0.dist-info}/METADATA +17 -11
- bartz-0.8.0.dist-info/RECORD +25 -0
- {bartz-0.7.0.dist-info → bartz-0.8.0.dist-info}/WHEEL +1 -1
- bartz/mcmcstep.py +0 -2616
- bartz-0.7.0.dist-info/RECORD +0 -17
bartz/{BART.py → _interface.py}
RENAMED
@@ -1,6 +1,6 @@
-# bartz/src/bartz/BART.py
+# bartz/src/bartz/_interface.py
 #
-# Copyright (c)
+# Copyright (c) 2025-2026, The Bartz Contributors
 #
 # This file is part of bartz.
 #
@@ -22,17 +22,20 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
-"""
+"""Main high-level interface of the package."""
 
 import math
 from collections.abc import Sequence
 from functools import cached_property
-from typing import Any, Literal, Protocol
+from typing import Any, Literal, Protocol, TypedDict
 
 import jax
 import jax.numpy as jnp
 from equinox import Module, field
+from jax import Device, device_put, jit, make_mesh
+from jax.lax import collapse
 from jax.scipy.special import ndtr
+from jax.sharding import AxisType, Mesh
 from jaxtyping import (
     Array,
     Bool,
@@ -48,22 +51,21 @@ from jaxtyping import (
 from numpy import ndarray
 
 from bartz import mcmcloop, mcmcstep, prepcovars
+from bartz.jaxext import is_key
 from bartz.jaxext.scipy.special import ndtri
 from bartz.jaxext.scipy.stats import invgamma
+from bartz.mcmcloop import compute_varcount, evaluate_trace, run_mcmc
+from bartz.mcmcstep import make_p_nonterminal
+from bartz.mcmcstep._state import get_num_chains
 
 FloatLike = float | Float[Any, '']
 
 
 class DataFrame(Protocol):
-    """DataFrame duck-type for `
-
-    Attributes
-    ----------
-    columns : Sequence[str]
-        The names of the columns.
-    """
+    """DataFrame duck-type for `Bart`."""
 
     columns: Sequence[str]
+    """The names of the columns."""
 
     def to_numpy(self) -> ndarray:
         """Convert the dataframe to a 2d numpy array with columns on the second axis."""
@@ -71,22 +73,17 @@ class DataFrame(Protocol):
 
 
 class Series(Protocol):
-    """Series duck-type for `
-
-    Attributes
-    ----------
-    name : str | None
-        The name of the series.
-    """
+    """Series duck-type for `Bart`."""
 
     name: str | None
+    """The name of the series."""
 
     def to_numpy(self) -> ndarray:
         """Convert the series to a 1d numpy array."""
         ...
 
 
-class gbart(Module):
+class Bart(Module):
     R"""
     Nonparametric regression with Bayesian Additive Regression Trees (BART) [2]_.
 
@@ -147,10 +144,7 @@ class gbart(Module):
     rm_const
         How to treat predictors with no associated decision rules (i.e., there
        are no available cutpoints for that predictor). If `True` (default),
-        they are ignored. If `False`, an error is raised if there are any.
-        `None`, no check is performed, and the output of the MCMC may not make
-        sense if there are predictors without cutpoints. The option `None` is
-        provided only to allow jax tracing.
+        they are ignored. If `False`, an error is raised if there are any.
    sigest
         An estimate of the residual standard deviation on `y_train`, used to set
         `lamda`. If not specified, it is estimated by linear regression (with
@@ -214,21 +208,40 @@
 
         Ignored if `xinfo` is specified.
     ndpost
-        The number of MCMC samples to save, after burn-in.
+        The number of MCMC samples to save, after burn-in. `ndpost` is the
+        total number of samples across all chains. `ndpost` is rounded up to the
+        first multiple of `mc_cores`.
     nskip
-        The number of initial MCMC samples to discard as burn-in.
+        The number of initial MCMC samples to discard as burn-in. This number
+        of samples is discarded from each chain.
     keepevery
         The thinning factor for the MCMC samples, after burn-in. By default, 1
         for continuous regression and 10 for binary regression.
     printevery
         The number of iterations (including thinned-away ones) between each log
-        line. Set to `None` to disable logging.
-
-
-
-
-
+        line. Set to `None` to disable logging. ^C interrupts the MCMC only
+        every `printevery` iterations, so with logging disabled it's impossible
+        to kill the MCMC conveniently.
+    num_chains
+        The number of independent Markov chains to run. By default only one
+        chain is run.
+
+        The difference between not specifying `num_chains` and setting it to 1
+        is that in the latter case in the object attributes and some methods
+        there will be an explicit chain axis of size 1.
+    num_chain_devices
+        The number of devices to spread the chains across. Must be a divisor of
+        `num_chains`. Each device will run a fraction of the chains.
+    num_data_devices
+        The number of devices to split datapoints across. Must be a divisor of
+        `n`. This is useful only with very high `n`, about > 1000_000.
+
+        If both num_chain_devices and num_data_devices are specified, the total
+        number of devices used is the product of the two.
+    devices
+        One or more devices used to run the MCMC on. If not specified, the
+        computation will follow the placement of the input arrays. If a list of
+        devices, this argument can be longer than the number of devices needed.
     seed
         The seed for the random number generator.
     maxdepth
@@ -239,34 +252,6 @@ class gbart(Module):
     run_mcmc_kw
         Additional arguments passed to `bartz.mcmcloop.run_mcmc`.
 
-    Attributes
-    ----------
-    offset : Float32[Array, '']
-        The prior mean of the latent mean function.
-    sigest : Float32[Array, ''] | None
-        The estimated standard deviation of the error used to set `lamda`.
-    yhat_test : Float32[Array, 'ndpost m'] | None
-        The conditional posterior mean at `x_test` for each MCMC iteration.
-
-    Notes
-    -----
-    This interface imitates the function ``gbart`` from the R package `BART
-    <https://cran.r-project.org/package=BART>`_, but with these differences:
-
-    - If `x_train` and `x_test` are matrices, they have one predictor per row
-      instead of per column.
-    - If ``usequants=False``, R BART switches to quantiles anyway if there are
-      less predictor values than the required number of bins, while bartz
-      always follows the specification.
-    - Some functionality is missing.
-    - The error variance parameter is called `lamda` instead of `lambda`.
-    - There are some additional attributes, and some missing.
-    - The trees have a maximum depth.
-    - `rm_const` refers to predictors without decision rules instead of
-      predictors that are constant in `x_train`.
-    - If `rm_const=True` and some variables are dropped, the predictors
-      matrix/dataframe passed to `predict` should still include them.
-
     References
     ----------
     .. [1] Linero, Antonio R. (2018). “Bayesian Regression Trees for
@@ -283,10 +268,14 @@ class gbart(Module):
     _splits: Real[Array, 'p max_num_splits']
     _x_train_fmt: Any = field(static=True)
 
-    ndpost: int = field(static=True)
     offset: Float32[Array, '']
+    """The prior mean of the latent mean function."""
+
     sigest: Float32[Array, ''] | None = None
+    """The estimated standard deviation of the error used to set `lamda`."""
+
     yhat_test: Float32[Array, 'ndpost m'] | None = None
+    """The conditional posterior mean at `x_test` for each MCMC iteration."""
 
     def __init__(
         self,
@@ -302,7 +291,7 @@ class gbart(Module):
         rho: FloatLike | None = None,
         xinfo: Float[Array, 'p n'] | None = None,
         usequants: bool = False,
-        rm_const: bool
+        rm_const: bool = True,
         sigest: FloatLike | None = None,
         sigdf: FloatLike = 3.0,
         sigquant: FloatLike = 0.9,
@@ -312,13 +301,17 @@ class gbart(Module):
         lamda: FloatLike | None = None,
         tau_num: FloatLike | None = None,
         offset: FloatLike | None = None,
-        w: Float[Array, ' n'] | None = None,
+        w: Float[Array, ' n'] | Series | None = None,
         ntree: int | None = None,
         numcut: int = 100,
         ndpost: int = 1000,
         nskip: int = 100,
         keepevery: int | None = None,
         printevery: int | None = 100,
+        num_chains: int | None = None,
+        num_chain_devices: int | None = None,
+        num_data_devices: int | None = None,
+        devices: Device | Sequence[Device] | None = None,
         seed: int | Key[Array, ''] = 0,
         maxdepth: int = 6,
         init_kw: dict | None = None,
@@ -378,21 +371,19 @@ class gbart(Module):
             a,
             b,
             rho,
+            num_chains,
+            num_chain_devices,
+            num_data_devices,
+            devices,
+            sparse,
+            nskip,
         )
         final_state, burnin_trace, main_trace = self._run_mcmc(
-            initial_state,
-            ndpost,
-            nskip,
-            keepevery,
-            printevery,
-            seed,
-            run_mcmc_kw,
-            sparse,
+            initial_state, ndpost, nskip, keepevery, printevery, seed, run_mcmc_kw
         )
 
         # set public attributes
         self.offset = final_state.offset  # from the state because of buffer donation
-        self.ndpost = ndpost
         self.sigest = sigest
 
         # set private attributes
@@ -406,6 +397,15 @@ class gbart(Module):
         if x_test is not None:
             self.yhat_test = self.predict(x_test)
 
+    @property
+    def ndpost(self):
+        """The total number of posterior samples after burn-in across all chains.
+
+        May be larger than the initialization argument `ndpost` if it was not
+        divisible by the number of chains.
+        """
+        return self._main_trace.grow_prop_count.size
+
     @cached_property
     def prob_test(self) -> Float32[Array, 'ndpost m'] | None:
         """The posterior probability of y being True at `x_test` for each MCMC iteration."""
@@ -439,30 +439,53 @@ class gbart(Module):
         return self.prob_train.mean(axis=0)
 
     @cached_property
-    def sigma(
+    def sigma(
+        self,
+    ) -> (
+        Float32[Array, ' nskip+ndpost']
+        | Float32[Array, 'nskip+ndpost/mc_cores mc_cores']
+        | None
+    ):
         """The standard deviation of the error, including burn-in samples."""
-        if self._burnin_trace.
+        if self._burnin_trace.error_cov_inv is None:
             return None
-
-
-
-            jnp.concatenate(
+        assert self._main_trace.error_cov_inv is not None
+        return jnp.sqrt(
+            jnp.reciprocal(
+                jnp.concatenate(
+                    [
+                        self._burnin_trace.error_cov_inv.T,
+                        self._main_trace.error_cov_inv.T,
+                    ],
+                    axis=0,
+                    # error_cov_inv has shape (chains? samples) in the trace
+                )
             )
+        )
+
+    @cached_property
+    def sigma_(self) -> Float32[Array, 'ndpost'] | None:
+        """The standard deviation of the error, only over the post-burnin samples and flattened."""
+        error_cov_inv = self._main_trace.error_cov_inv
+        if error_cov_inv is None:
+            return None
+        else:
+            return jnp.sqrt(jnp.reciprocal(error_cov_inv)).reshape(-1)
 
     @cached_property
     def sigma_mean(self) -> Float32[Array, ''] | None:
         """The mean of `sigma`, only over the post-burnin samples."""
-        if self.
+        if self.sigma_ is None:
             return None
-
-        return self.sigma[len(self.sigma) - self.ndpost :].mean(axis=0)
+        return self.sigma_.mean()
 
     @cached_property
     def varcount(self) -> Int32[Array, 'ndpost p']:
         """Histogram of predictor usage for decision rules in the trees."""
-
-
-        )
+        p = self._mcmc_state.forest.max_split.size
+        varcount: Int32[Array, '*chains samples p']
+        varcount = compute_varcount(p, self._main_trace)
+        return collapse(varcount, 0, -1)
 
     @cached_property
     def varcount_mean(self) -> Float32[Array, ' p']:
@@ -472,13 +495,15 @@ class gbart(Module):
     @cached_property
     def varprob(self) -> Float32[Array, 'ndpost p']:
         """Posterior samples of the probability of choosing each predictor for a decision rule."""
+        max_split = self._mcmc_state.forest.max_split
+        p = max_split.size
         varprob = self._main_trace.varprob
         if varprob is None:
-            max_split = self._mcmc_state.forest.max_split
-            p = max_split.size
             peff = jnp.count_nonzero(max_split)
             varprob = jnp.where(max_split, 1 / peff, 0)
             varprob = jnp.broadcast_to(varprob, (self.ndpost, p))
+        else:
+            varprob = varprob.reshape(-1, p)
         return varprob
 
     @cached_property
@@ -567,10 +592,11 @@ class gbart(Module):
         get_length = lambda x: x.shape[-1]
         assert get_length(x1) == get_length(x2)
 
-    @staticmethod
+    @classmethod
     def _process_error_variance_settings(
-        x_train, y_train, sigest, sigdf, sigquant, lamda
+        cls, x_train, y_train, sigest, sigdf, sigquant, lamda
     ) -> tuple[Float32[Array, ''] | None, ...]:
+        """Return (lamda, sigest)."""
         if y_train.dtype == bool:
             if sigest is not None:
                 msg = 'Let `sigest=None` for binary regression'
@@ -592,18 +618,26 @@ class gbart(Module):
         elif y_train.size <= x_train.shape[0]:
             sigest2 = jnp.var(y_train)
         else:
-
-            y_centered = y_train - y_train.mean()
-            # centering is equivalent to adding an intercept column
-            _, chisq, rank, _ = jnp.linalg.lstsq(x_centered, y_centered)
-            chisq = chisq.squeeze(0)
-            dof = len(y_train) - rank
-            sigest2 = chisq / dof
+            sigest2 = cls._linear_regression(x_train, y_train)
         alpha = sigdf / 2
         invchi2 = invgamma.ppf(sigquant, alpha) / 2
         invchi2rid = invchi2 * sigdf
         return sigest2 / invchi2rid, jnp.sqrt(sigest2)
 
+    @staticmethod
+    @jit
+    def _linear_regression(
+        x_train: Shaped[Array, 'p n'], y_train: Float32[Array, ' n']
+    ):
+        """Return the error variance estimated with OLS with intercept."""
+        x_centered = x_train.T - x_train.mean(axis=1)
+        y_centered = y_train - y_train.mean()
+        # centering is equivalent to adding an intercept column
+        _, chisq, rank, _ = jnp.linalg.lstsq(x_centered, y_centered)
+        chisq = chisq.squeeze(0)
+        dof = len(y_train) - rank
+        return chisq / dof
+
     @staticmethod
     def _check_type_settings(y_train, type, w):  # noqa: A002
         match type:
@@ -641,6 +675,7 @@ class gbart(Module):
         | tuple[FloatLike, None, None, None]
         | tuple[None, FloatLike, FloatLike, FloatLike]
     ):
+        """Return (theta, a, b, rho)."""
        if not sparse:
             return None, None, None, None
         elif theta is not None:
@@ -656,6 +691,7 @@ class gbart(Module):
         y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
         offset: float | Float32[Any, ''] | None,
     ) -> Float32[Array, '']:
+        """Return offset."""
         if offset is not None:
             return jnp.asarray(offset)
         elif y_train.size < 1:
@@ -677,6 +713,7 @@ class gbart(Module):
         ntree: int,
         tau_num: FloatLike | None,
     ):
+        """Return sigma_mu."""
         if tau_num is None:
             if y_train.dtype == bool:
                 tau_num = 3.0
@@ -705,7 +742,9 @@ class gbart(Module):
             return prepcovars.uniform_splits_from_matrix(x_train, numcut + 1)
 
     @staticmethod
-    def _bin_predictors(
+    def _bin_predictors(
+        x: Real[Array, 'p n'], splits: Real[Array, 'p max_num_splits']
+    ) -> UInt[Array, 'p n']:
         return prepcovars.bin_predictors(x, splits)
 
     @staticmethod
@@ -723,23 +762,35 @@ class gbart(Module):
         maxdepth: int,
         ntree: int,
         init_kw: dict[str, Any] | None,
-        rm_const: bool
+        rm_const: bool,
         theta: FloatLike | None,
         a: FloatLike | None,
         b: FloatLike | None,
         rho: FloatLike | None,
+        num_chains: int | None,
+        num_chain_devices: int | None,
+        num_data_devices: int | None,
+        devices: Device | Sequence[Device] | None,
+        sparse: bool,
+        nskip: int,
     ):
-
-        p_nonterminal = base / (1 + depth).astype(float) ** power
+        p_nonterminal = make_p_nonterminal(maxdepth, base, power)
 
         if y_train.dtype == bool:
-
-
+            error_cov_df = None
+            error_cov_scale = None
         else:
-
-
+            assert lamda is not None
+            # inverse gamma prior: alpha = df / 2, beta = scale / 2
+            error_cov_df = sigdf
+            error_cov_scale = lamda * sigdf
+
+        # process device settings
+        device_kw, device = process_device_settings(
+            y_train, num_chains, num_chain_devices, num_data_devices, devices
+        )
 
-        kw = dict(
+        kw: dict = dict(
             X=x_train,
             # copy y_train because it's going to be donated in the mcmc loop
             y=jnp.array(y_train),
@@ -748,35 +799,37 @@ class gbart(Module):
             max_split=max_split,
             num_trees=ntree,
             p_nonterminal=p_nonterminal,
-
-
-
+            leaf_prior_cov_inv=jnp.reciprocal(jnp.square(sigma_mu)),
+            error_cov_df=error_cov_df,
+            error_cov_scale=error_cov_scale,
             min_points_per_decision_node=10,
             min_points_per_leaf=5,
             theta=theta,
             a=a,
             b=b,
             rho=rho,
+            sparse_on_at=nskip // 2 if sparse else None,
+            **device_kw,
         )
 
-        if rm_const
-
-
-            kw.update(filter_splitless_vars=True)
-        else:
-            n_empty = jnp.count_nonzero(max_split == 0)
-            if n_empty:
-                msg = f'There are {n_empty}/{max_split.size} predictors without decision rules'
-                raise ValueError(msg)
-            kw.update(filter_splitless_vars=False)
+        if rm_const:
+            n_empty = jnp.sum(max_split == 0).item()
+            kw.update(filter_splitless_vars=n_empty)
 
         if init_kw is not None:
             kw.update(init_kw)
 
-
+        state = mcmcstep.init(**kw)
 
-
+        # put state on device if requested explicitly by the user
+        if device is not None:
+            state = device_put(state, device, donate=True)
+
+        return state
+
+    @classmethod
     def _run_mcmc(
+        cls,
         mcmc_state: mcmcstep.State,
         ndpost: int,
         nskip: int,
@@ -784,30 +837,101 @@ class gbart(Module):
         printevery: int | None,
         seed: int | Integer[Array, ''] | Key[Array, ''],
         run_mcmc_kw: dict | None,
-
-    ):
+    ) -> tuple[mcmcstep.State, mcmcloop.BurninTrace, mcmcloop.MainTrace]:
         # prepare random generator seed
-        if
-
-        ):
-            key = seed.copy()
-            # copy because the inner loop in run_mcmc will donate the buffer
+        if is_key(seed):
+            key = jnp.copy(seed)
         else:
             key = jax.random.key(seed)
 
+        # round up ndpost
+        num_chains = get_num_chains(mcmc_state)
+        if num_chains is None:
+            num_chains = 1
+        n_save = ndpost // num_chains + bool(ndpost % num_chains)
+
         # prepare arguments
-        kw = dict(n_burn=nskip, n_skip=keepevery, inner_loop_length=printevery)
+        kw: dict = dict(n_burn=nskip, n_skip=keepevery, inner_loop_length=printevery)
         kw.update(
             mcmcloop.make_default_callback(
+                mcmc_state,
                 dot_every=None if printevery is None or printevery == 1 else 1,
                 report_every=printevery,
-                sparse_on_at=nskip // 2 if sparse else None,
             )
         )
        if run_mcmc_kw is not None:
             kw.update(run_mcmc_kw)
 
-        return
-
-    def _predict(self, x):
-
+        return run_mcmc(key, mcmc_state, n_save, **kw)
+
+    def _predict(self, x: UInt[Array, 'p m']) -> Float32[Array, 'ndpost m']:
+        """Evaluate trees on already quantized `x`."""
+        out = evaluate_trace(x, self._main_trace)
+        return collapse(out, 0, -1)
+
+
+class DeviceKwArgs(TypedDict):
+    num_chains: int | None
+    mesh: Mesh | None
+    target_platform: Literal['cpu', 'gpu'] | None
+
+
+def process_device_settings(
+    y_train: Array,
+    num_chains: int | None,
+    num_chain_devices: int | None,
+    num_data_devices: int | None,
+    devices: Device | Sequence[Device] | None,
+) -> tuple[DeviceKwArgs, Device | None]:
+    """Return the arguments for `mcmcstep.init` related to devices, and an optional device where to put the state."""
+    # determine devices
+    if devices is not None:
+        if not hasattr(devices, '__len__'):
+            devices = (devices,)
+        device = devices[0]
+        platform = device.platform
+    elif hasattr(y_train, 'platform'):
+        platform = y_train.platform()
+        device = None
+        # set device=None because if the devices were not specified explicitly
+        # we may be in the case where computation will follow data placement,
+        # do not disturb jax as the user may be playing with vmap, jit, reshard...
+        devices = jax.devices(platform)
+    else:
+        msg = 'not possible to infer device from `y_train`, please set `devices`'
+        raise ValueError(msg)
+
+    # create mesh
+    if num_chain_devices is None and num_data_devices is None:
+        mesh = None
+    else:
+        mesh = dict()
+        if num_chain_devices is not None:
+            mesh.update(chains=num_chain_devices)
+        if num_data_devices is not None:
+            mesh.update(data=num_data_devices)
+        mesh = make_mesh(
+            axis_shapes=tuple(mesh.values()),
+            axis_names=tuple(mesh),
+            axis_types=(AxisType.Auto,) * len(mesh),
+            devices=devices,
+        )
+        device = None
+        # set device=None because `mcmcstep.init` will `device_put` with the
+        # mesh already, we don't want to undo its work
+
+    # prepare arguments to `init`
+    settings = DeviceKwArgs(
+        num_chains=num_chains,
+        mesh=mesh,
+        target_platform=None
+        if mesh is not None or hasattr(y_train, 'platform')
+        else platform,
+        # here we don't take into account the case where the user has set both
+        # batch sizes; since the user has to be playing with `init_kw` to do
        # that, we'll let `init` throw the error and the user set
        # `target_platform` themselves so they have a clearer idea how the
        # thing works.
    )

    return settings, device