bartz 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/.DS_Store +0 -0
- bartz/BART/__init__.py +27 -0
- bartz/BART/_gbart.py +522 -0
- bartz/__init__.py +6 -4
- bartz/_interface.py +937 -0
- bartz/_profiler.py +318 -0
- bartz/_version.py +1 -1
- bartz/debug.py +1217 -82
- bartz/grove.py +205 -103
- bartz/jaxext/__init__.py +287 -0
- bartz/jaxext/_autobatch.py +444 -0
- bartz/jaxext/scipy/__init__.py +25 -0
- bartz/jaxext/scipy/special.py +239 -0
- bartz/jaxext/scipy/stats.py +36 -0
- bartz/mcmcloop.py +662 -314
- bartz/mcmcstep/__init__.py +35 -0
- bartz/mcmcstep/_moves.py +904 -0
- bartz/mcmcstep/_state.py +1114 -0
- bartz/mcmcstep/_step.py +1603 -0
- bartz/prepcovars.py +140 -44
- bartz/testing/__init__.py +29 -0
- bartz/testing/_dgp.py +442 -0
- {bartz-0.6.0.dist-info → bartz-0.8.0.dist-info}/METADATA +18 -13
- bartz-0.8.0.dist-info/RECORD +25 -0
- {bartz-0.6.0.dist-info → bartz-0.8.0.dist-info}/WHEEL +1 -1
- bartz/BART.py +0 -603
- bartz/jaxext.py +0 -423
- bartz/mcmcstep.py +0 -2335
- bartz-0.6.0.dist-info/RECORD +0 -13
bartz/.DS_Store
CHANGED
|
File without changes
|
bartz/BART/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
# bartz/src/bartz/BART/__init__.py
|
|
2
|
+
#
|
|
3
|
+
# Copyright (c) 2026, The Bartz Contributors
|
|
4
|
+
#
|
|
5
|
+
# This file is part of bartz.
|
|
6
|
+
#
|
|
7
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
8
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
9
|
+
# in the Software without restriction, including without limitation the rights
|
|
10
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
11
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
12
|
+
# furnished to do so, subject to the following conditions:
|
|
13
|
+
#
|
|
14
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
15
|
+
# copies or substantial portions of the Software.
|
|
16
|
+
#
|
|
17
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
18
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
19
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
20
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
21
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
22
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
23
|
+
# SOFTWARE.
|
|
24
|
+
|
|
25
|
+
"""Implement classes `mc_gbart` and `gbart` that mimic the R BART3 package."""
|
|
26
|
+
|
|
27
|
+
from bartz.BART._gbart import gbart, mc_gbart # noqa: F401
|
bartz/BART/_gbart.py
ADDED
|
@@ -0,0 +1,522 @@
|
|
|
1
|
+
# bartz/src/bartz/BART/_gbart.py
|
|
2
|
+
#
|
|
3
|
+
# Copyright (c) 2024-2026, The Bartz Contributors
|
|
4
|
+
#
|
|
5
|
+
# This file is part of bartz.
|
|
6
|
+
#
|
|
7
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
8
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
9
|
+
# in the Software without restriction, including without limitation the rights
|
|
10
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
11
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
12
|
+
# furnished to do so, subject to the following conditions:
|
|
13
|
+
#
|
|
14
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
15
|
+
# copies or substantial portions of the Software.
|
|
16
|
+
#
|
|
17
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
18
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
19
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
20
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
21
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
22
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
23
|
+
# SOFTWARE.
|
|
24
|
+
|
|
25
|
+
"""Implement classes `mc_gbart` and `gbart` that mimic the R BART3 package."""
|
|
26
|
+
|
|
27
|
+
from collections.abc import Mapping
|
|
28
|
+
from functools import cached_property
|
|
29
|
+
from os import cpu_count
|
|
30
|
+
from types import MappingProxyType
|
|
31
|
+
from typing import Any, Literal
|
|
32
|
+
from warnings import warn
|
|
33
|
+
|
|
34
|
+
from equinox import Module
|
|
35
|
+
from jax import device_count
|
|
36
|
+
from jax import numpy as jnp
|
|
37
|
+
from jaxtyping import Array, Bool, Float, Float32, Int32, Key, Real
|
|
38
|
+
|
|
39
|
+
from bartz import mcmcloop, mcmcstep
|
|
40
|
+
from bartz._interface import Bart, DataFrame, FloatLike, Series
|
|
41
|
+
from bartz.jaxext import get_default_device
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class mc_gbart(Module):
    R"""
    Nonparametric regression with Bayesian Additive Regression Trees (BART) [2]_.

    Regress `y_train` on `x_train` with a latent mean function represented as
    a sum of decision trees. The inference is carried out by sampling the
    posterior distribution of the tree ensemble with an MCMC.

    Parameters
    ----------
    x_train
        The training predictors.
    y_train
        The training responses.
    x_test
        The test predictors.
    type
        The type of regression. 'wbart' for continuous regression, 'pbart' for
        binary regression with probit link.
    sparse
        Whether to activate variable selection on the predictors as done in
        [1]_.
    theta
    a
    b
    rho
        Hyperparameters of the sparsity prior used for variable selection.

        The prior distribution on the choice of predictor for each decision rule
        is

        .. math::
            (s_1, \ldots, s_p) \sim
            \operatorname{Dirichlet}(\mathtt{theta}/p, \ldots, \mathtt{theta}/p).

        If `theta` is not specified, it's a priori distributed according to

        .. math::
            \frac{\mathtt{theta}}{\mathtt{theta} + \mathtt{rho}} \sim
            \operatorname{Beta}(\mathtt{a}, \mathtt{b}).

        If not specified, `rho` is set to the number of predictors p. To tune
        the prior, consider setting a lower `rho` to prefer more sparsity.
        If setting `theta` directly, it should be in the ballpark of p or lower
        as well.
    xinfo
        A matrix with the cutpoints to use to bin each predictor. If not
        specified, it is generated automatically according to `usequants` and
        `numcut`.

        Each row shall contain a sorted list of cutpoints for a predictor. If
        there are fewer cutpoints than the number of columns in the matrix,
        fill the remaining cells with NaN.

        `xinfo` shall be a matrix even if `x_train` is a dataframe.
    usequants
        Whether to use predictors quantiles instead of a uniform grid to bin
        predictors. Ignored if `xinfo` is specified.
    rm_const
        How to treat predictors with no associated decision rules (i.e., there
        are no available cutpoints for that predictor). If `True` (default),
        they are ignored. If `False`, an error is raised if there are any.
    sigest
        An estimate of the residual standard deviation on `y_train`, used to set
        `lamda`. If not specified, it is estimated by linear regression (with
        intercept, and without taking into account `w`). If `y_train` has fewer
        than two elements, it is set to 1. If n <= p, it is set to the standard
        deviation of `y_train`. Ignored if `lamda` is specified.
    sigdf
        The degrees of freedom of the scaled inverse-chisquared prior on the
        noise variance.
    sigquant
        The quantile of the prior on the noise variance that shall match
        `sigest` to set the scale of the prior. Ignored if `lamda` is specified.
    k
        The inverse scale of the prior standard deviation on the latent mean
        function, relative to half the observed range of `y_train`. If `y_train`
        has fewer than two elements, `k` is ignored and the scale is set to 1.
    power
    base
        Parameters of the prior on tree node generation. The probability that a
        node at depth `d` (0-based) is non-terminal is ``base / (1 + d) **
        power``.
    lamda
        The prior harmonic mean of the error variance. (The harmonic mean of x
        is 1/mean(1/x).) If not specified, it is set based on `sigest` and
        `sigquant`.
    tau_num
        The numerator in the expression that determines the prior standard
        deviation of leaves. If not specified, default to ``(max(y_train) -
        min(y_train)) / 2`` (or 1 if `y_train` has fewer than two elements) for
        continuous regression, and 3 for binary regression.
    offset
        The prior mean of the latent mean function. If not specified, it is set
        to the mean of `y_train` for continuous regression, and to
        ``Phi^-1(mean(y_train))`` for binary regression. If `y_train` is empty,
        `offset` is set to 0. With binary regression, if `y_train` is all
        `False` or `True`, it is set to ``Phi^-1(1/(n+1))`` or
        ``Phi^-1(n/(n+1))``, respectively.
    w
        Coefficients that rescale the error standard deviation on each
        datapoint. Not specifying `w` is equivalent to setting it to 1 for all
        datapoints. Note: `w` is ignored in the automatic determination of
        `sigest`, so either the weights should be O(1), or `sigest` should be
        specified by the user.
    ntree
        The number of trees used to represent the latent mean function. By
        default 200 for continuous regression and 50 for binary regression.
    numcut
        If `usequants` is `False`: the exact number of cutpoints used to bin the
        predictors, ranging between the minimum and maximum observed values
        (excluded).

        If `usequants` is `True`: the maximum number of cutpoints to use for
        binning the predictors. Each predictor is binned such that its
        distribution in `x_train` is approximately uniform across bins. The
        number of bins is at most the number of unique values appearing in
        `x_train`, or ``numcut + 1``.

        Before running the algorithm, the predictors are compressed to the
        smallest integer type that fits the bin indices, so `numcut` is best set
        to the maximum value of an unsigned integer type, like 255.

        Ignored if `xinfo` is specified.
    ndpost
        The number of MCMC samples to save, after burn-in. `ndpost` is the
        total number of samples across all chains. `ndpost` is rounded up to the
        first multiple of `mc_cores`.
    nskip
        The number of initial MCMC samples to discard as burn-in. This number
        of samples is discarded from each chain.
    keepevery
        The thinning factor for the MCMC samples, after burn-in. By default, 1
        for continuous regression and 10 for binary regression.
    printevery
        The number of iterations (including thinned-away ones) between each log
        line. Set to `None` to disable logging. ^C interrupts the MCMC only
        every `printevery` iterations, so with logging disabled it's impossible
        to kill the MCMC conveniently.
    mc_cores
        The number of independent MCMC chains. A negative value runs
        ``abs(mc_cores)`` chains and additionally allows falling back to the
        current default jax device to determine the platform when it can't be
        inferred from `y_train` (e.g., under `jax.jit`). On CPU, the chains are
        sharded across multiple virtual jax CPU devices when possible; the
        number of such devices is controlled by the ``jax_num_cpu_devices``
        jax configuration option.
    seed
        The seed for the random number generator.
    bart_kwargs
        Additional arguments passed to `bartz.Bart`. They override the
        arguments derived from the other parameters.

    Notes
    -----
    This interface imitates the function ``mc_gbart`` from the R package `BART3
    <https://github.com/rsparapa/bnptools>`_, but with these differences:

    - If `x_train` and `x_test` are matrices, they have one predictor per row
      instead of per column.
    - If ``usequants=False``, R BART3 switches to quantiles anyway if there are
      fewer predictor values than the required number of bins, while bartz
      always follows the specification.
    - Some functionality is missing.
    - The error variance parameter is called `lamda` instead of `lambda`.
    - There are some additional attributes, and some missing.
    - The trees have a maximum depth of 6.
    - `rm_const` refers to predictors without decision rules instead of
      predictors that are constant in `x_train`.
    - If `rm_const=True` and some variables are dropped, the predictors
      matrix/dataframe passed to `predict` should still include them.

    References
    ----------
    .. [1] Linero, Antonio R. (2018). "Bayesian Regression Trees for
       High-Dimensional Prediction and Variable Selection". In: Journal of the
       American Statistical Association 113.522, pp. 626-636.
    .. [2] Hugh A. Chipman, Edward I. George, Robert E. McCulloch "BART:
       Bayesian additive regression trees," The Annals of Applied Statistics,
       Ann. Appl. Stat. 4(1), 266-298, (March 2010).
    """

    # The wrapped `Bart` object; all attributes below forward to it.
    _bart: Bart

    def __init__(
        self,
        x_train: Real[Array, 'p n'] | DataFrame,
        y_train: Bool[Array, ' n'] | Float32[Array, ' n'] | Series,
        *,
        x_test: Real[Array, 'p m'] | DataFrame | None = None,
        type: Literal['wbart', 'pbart'] = 'wbart',  # noqa: A002
        sparse: bool = False,
        theta: FloatLike | None = None,
        a: FloatLike = 0.5,
        b: FloatLike = 1.0,
        rho: FloatLike | None = None,
        xinfo: Float[Array, 'p max_num_splits'] | None = None,
        usequants: bool = False,
        rm_const: bool = True,
        sigest: FloatLike | None = None,
        sigdf: FloatLike = 3.0,
        sigquant: FloatLike = 0.9,
        k: FloatLike = 2.0,
        power: FloatLike = 2.0,
        base: FloatLike = 0.95,
        lamda: FloatLike | None = None,
        tau_num: FloatLike | None = None,
        offset: FloatLike | None = None,
        w: Float[Array, ' n'] | None = None,
        ntree: int | None = None,
        numcut: int = 100,
        ndpost: int = 1000,
        nskip: int = 100,
        keepevery: int | None = None,
        printevery: int | None = 100,
        mc_cores: int = 2,
        seed: int | Key[Array, ''] = 0,
        bart_kwargs: Mapping = MappingProxyType({}),
    ):
        kwargs: dict = dict(
            x_train=x_train,
            y_train=y_train,
            x_test=x_test,
            type=type,
            sparse=sparse,
            theta=theta,
            a=a,
            b=b,
            rho=rho,
            xinfo=xinfo,
            usequants=usequants,
            rm_const=rm_const,
            sigest=sigest,
            sigdf=sigdf,
            sigquant=sigquant,
            k=k,
            power=power,
            base=base,
            lamda=lamda,
            tau_num=tau_num,
            offset=offset,
            w=w,
            ntree=ntree,
            numcut=numcut,
            ndpost=ndpost,
            nskip=nskip,
            keepevery=keepevery,
            printevery=printevery,
            seed=seed,
            # fixed tree depth, see the Notes in the class docstring
            maxdepth=6,
            # translate `mc_cores` into `num_chains` / `num_chain_devices`
            **process_mc_cores(y_train, mc_cores),
        )
        # user-supplied `bart_kwargs` take precedence over the values above
        kwargs.update(bart_kwargs)
        self._bart = Bart(**kwargs)

    # Public attributes from Bart

    @property
    def ndpost(self) -> int:
        """The number of MCMC samples saved, after burn-in."""
        return self._bart.ndpost

    @property
    def offset(self) -> Float32[Array, '']:
        """The prior mean of the latent mean function."""
        return self._bart.offset

    @property
    def sigest(self) -> Float32[Array, ''] | None:
        """The estimated standard deviation of the error used to set `lamda`."""
        return self._bart.sigest

    @property
    def yhat_test(self) -> Float32[Array, 'ndpost m'] | None:
        """The conditional posterior mean at `x_test` for each MCMC iteration."""
        return self._bart.yhat_test

    # Private attributes from Bart

    @property
    def _main_trace(self) -> mcmcloop.MainTrace:
        """Forward `_main_trace` from the wrapped `Bart` object."""
        return self._bart._main_trace  # noqa: SLF001

    @property
    def _burnin_trace(self) -> mcmcloop.BurninTrace:
        """Forward `_burnin_trace` from the wrapped `Bart` object."""
        return self._bart._burnin_trace  # noqa: SLF001

    @property
    def _mcmc_state(self) -> mcmcstep.State:
        """Forward `_mcmc_state` from the wrapped `Bart` object."""
        return self._bart._mcmc_state  # noqa: SLF001

    @property
    def _splits(self) -> Real[Array, 'p max_num_splits']:
        """Forward `_splits` from the wrapped `Bart` object."""
        return self._bart._splits  # noqa: SLF001

    @property
    def _x_train_fmt(self) -> Any:
        """Forward `_x_train_fmt` from the wrapped `Bart` object."""
        return self._bart._x_train_fmt  # noqa: SLF001

    # Cached properties from Bart

    @cached_property
    def prob_test(self) -> Float32[Array, 'ndpost m'] | None:
        """The posterior probability of y being True at `x_test` for each MCMC iteration."""
        return self._bart.prob_test

    @cached_property
    def prob_test_mean(self) -> Float32[Array, ' m'] | None:
        """The marginal posterior probability of y being True at `x_test`."""
        return self._bart.prob_test_mean

    @cached_property
    def prob_train(self) -> Float32[Array, 'ndpost n'] | None:
        """The posterior probability of y being True at `x_train` for each MCMC iteration."""
        return self._bart.prob_train

    @cached_property
    def prob_train_mean(self) -> Float32[Array, ' n'] | None:
        """The marginal posterior probability of y being True at `x_train`."""
        return self._bart.prob_train_mean

    @cached_property
    def sigma(
        self,
    ) -> (
        Float32[Array, ' nskip+ndpost']
        | Float32[Array, 'nskip+ndpost/mc_cores mc_cores']
        | None
    ):
        """The standard deviation of the error, including burn-in samples."""
        return self._bart.sigma

    @cached_property
    def sigma_(self) -> Float32[Array, ' ndpost'] | None:
        """The standard deviation of the error, only over the post-burnin samples and flattened."""
        return self._bart.sigma_

    @cached_property
    def sigma_mean(self) -> Float32[Array, ''] | None:
        """The mean of `sigma`, only over the post-burnin samples."""
        return self._bart.sigma_mean

    @cached_property
    def varcount(self) -> Int32[Array, 'ndpost p']:
        """Histogram of predictor usage for decision rules in the trees."""
        return self._bart.varcount

    @cached_property
    def varcount_mean(self) -> Float32[Array, ' p']:
        """Average of `varcount` across MCMC iterations."""
        return self._bart.varcount_mean

    @cached_property
    def varprob(self) -> Float32[Array, 'ndpost p']:
        """Posterior samples of the probability of choosing each predictor for a decision rule."""
        return self._bart.varprob

    @cached_property
    def varprob_mean(self) -> Float32[Array, ' p']:
        """The marginal posterior probability of each predictor being chosen for a decision rule."""
        return self._bart.varprob_mean

    @cached_property
    def yhat_test_mean(self) -> Float32[Array, ' m'] | None:
        """The marginal posterior mean at `x_test`.

        Not defined with binary regression because it's error-prone, typically
        the right thing to consider would be `prob_test_mean`.
        """
        return self._bart.yhat_test_mean

    @cached_property
    def yhat_train(self) -> Float32[Array, 'ndpost n']:
        """The conditional posterior mean at `x_train` for each MCMC iteration."""
        return self._bart.yhat_train

    @cached_property
    def yhat_train_mean(self) -> Float32[Array, ' n'] | None:
        """The marginal posterior mean at `x_train`.

        Not defined with binary regression because it's error-prone, typically
        the right thing to consider would be `prob_train_mean`.
        """
        return self._bart.yhat_train_mean

    # Public methods from Bart

    def predict(
        self, x_test: Real[Array, 'p m'] | DataFrame
    ) -> Float32[Array, 'ndpost m']:
        """
        Compute the posterior mean at `x_test` for each MCMC iteration.

        Parameters
        ----------
        x_test
            The test predictors.

        Returns
        -------
        The conditional posterior mean at `x_test` for each MCMC iteration.
        """
        return self._bart.predict(x_test)
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
class gbart(mc_gbart):
    """Single-chain variant of `mc_gbart`.

    Accepts the same arguments as `mc_gbart` except `mc_cores`, which is
    always set to 1; passing it explicitly raises `TypeError`.
    """

    def __init__(self, *args, **kwargs):
        if 'mc_cores' in kwargs:
            # mimic the error Python raises for an unknown keyword argument
            msg = "gbart.__init__() got an unexpected keyword argument 'mc_cores'"
            raise TypeError(msg)
        super().__init__(*args, mc_cores=1, **kwargs)
|
|
451
|
+
|
|
452
|
+
|
|
453
|
+
def process_mc_cores(y_train: Array | Any, mc_cores: int) -> dict[str, Any]:
|
|
454
|
+
"""Determine the arguments to pass to `Bart` to configure multiple chains."""
|
|
455
|
+
# one chain, leave default configuration which is num_chains=None
|
|
456
|
+
if abs(mc_cores) == 1:
|
|
457
|
+
return {}
|
|
458
|
+
|
|
459
|
+
# determine if we are on cpu; this point may raise an exception
|
|
460
|
+
platform = get_platform(y_train, mc_cores)
|
|
461
|
+
|
|
462
|
+
# set the num_chains argument
|
|
463
|
+
mc_cores = abs(mc_cores)
|
|
464
|
+
kwargs = dict(num_chains=mc_cores)
|
|
465
|
+
|
|
466
|
+
# if on cpu, try to shard the chains across multiple virtual cpus
|
|
467
|
+
if platform == 'cpu':
|
|
468
|
+
# determine number of logical cpu cores
|
|
469
|
+
num_cores = cpu_count()
|
|
470
|
+
assert num_cores is not None, 'could not determine number of cpu cores'
|
|
471
|
+
|
|
472
|
+
# determine number of shards that evenly divides chains
|
|
473
|
+
for num_shards in range(num_cores, 0, -1):
|
|
474
|
+
if mc_cores % num_shards == 0:
|
|
475
|
+
break
|
|
476
|
+
|
|
477
|
+
# handle the case where there are less jax cpu devices that that
|
|
478
|
+
if num_shards > 1:
|
|
479
|
+
num_jax_cpus = device_count('cpu')
|
|
480
|
+
if num_jax_cpus < num_shards:
|
|
481
|
+
for new_num_shards in range(num_jax_cpus, 0, -1):
|
|
482
|
+
if mc_cores % new_num_shards == 0:
|
|
483
|
+
break
|
|
484
|
+
msg = (
|
|
485
|
+
f'`mc_gbart` would like to shard {mc_cores} chains across '
|
|
486
|
+
f'{num_shards} virtual jax cpu devices, but jax is set up '
|
|
487
|
+
f'with only {num_jax_cpus} cpu devices, so it will use '
|
|
488
|
+
f'{new_num_shards} devices instead. To enable '
|
|
489
|
+
'parallelization, please increase the limit with '
|
|
490
|
+
'`jax.config.update("jax_num_cpu_devices", <num_devices>)`.'
|
|
491
|
+
)
|
|
492
|
+
warn(msg)
|
|
493
|
+
num_shards = new_num_shards
|
|
494
|
+
|
|
495
|
+
# set the number of shards
|
|
496
|
+
if num_shards > 1:
|
|
497
|
+
kwargs.update(num_chain_devices=num_shards)
|
|
498
|
+
|
|
499
|
+
return kwargs
|
|
500
|
+
|
|
501
|
+
|
|
502
|
+
def get_platform(y_train: Array | Any, mc_cores: int) -> str:
    """Return the name of the platform the MCMC is expected to run on.

    Helper of `process_mc_cores`. The platform is read from `y_train` when it
    is an array exposing it. Otherwise the platform of the default jax device
    is returned, provided that either `y_train` is not an array and we are not
    tracing (so it will be converted to an array on the default device), or
    the caller opted in by passing a negative `mc_cores`.

    Raises
    ------
    RuntimeError
        If the platform can not be determined, e.g., inside `jax.jit` with a
        positive `mc_cores`.
    """
    y_is_array = isinstance(y_train, Array)
    if y_is_array and hasattr(y_train, 'platform'):
        return y_train.platform()

    # A freshly created array exposes `platform` only when not under jit; in
    # that case a non-array `y_train` ends up on the default device.
    lands_on_default_device = not y_is_array and hasattr(jnp.zeros(()), 'platform')
    if lands_on_default_device or mc_cores < 0:
        return get_default_device().platform

    msg = (
        'Could not determine the platform from `y_train`, maybe `mc_gbart` '
        'was used with a `jax.jit`ted function? The platform is needed to '
        'determine whether the computation is going to run on CPU to '
        'automatically shard the chains across multiple virtual CPU '
        'devices. To acknowledge this problem and circumvent it '
        'by using the current default jax device, negate `mc_cores`.'
    )
    raise RuntimeError(msg)
|
bartz/__init__.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
# bartz/src/bartz/__init__.py
|
|
2
2
|
#
|
|
3
|
-
# Copyright (c) 2024-2025,
|
|
3
|
+
# Copyright (c) 2024-2025, The Bartz Contributors
|
|
4
4
|
#
|
|
5
5
|
# This file is part of bartz.
|
|
6
6
|
#
|
|
@@ -25,8 +25,10 @@
|
|
|
25
25
|
"""
|
|
26
26
|
Super-fast BART (Bayesian Additive Regression Trees) in Python.
|
|
27
27
|
|
|
28
|
-
See the manual at https://
|
|
28
|
+
See the manual at https://bartz-org.github.io/bartz/docs
|
|
29
29
|
"""
|
|
30
30
|
|
|
31
|
-
from
|
|
32
|
-
from .
|
|
31
|
+
from bartz import BART, grove, jaxext, mcmcloop, mcmcstep, prepcovars # noqa: F401
|
|
32
|
+
from bartz._interface import Bart # noqa: F401
|
|
33
|
+
from bartz._profiler import profile_mode # noqa: F401
|
|
34
|
+
from bartz._version import __version__ # noqa: F401
|