bartz 0.6.0-py3-none-any.whl → 0.8.0-py3-none-any.whl
- bartz/.DS_Store +0 -0
- bartz/BART/__init__.py +27 -0
- bartz/BART/_gbart.py +522 -0
- bartz/__init__.py +6 -4
- bartz/_interface.py +937 -0
- bartz/_profiler.py +318 -0
- bartz/_version.py +1 -1
- bartz/debug.py +1217 -82
- bartz/grove.py +205 -103
- bartz/jaxext/__init__.py +287 -0
- bartz/jaxext/_autobatch.py +444 -0
- bartz/jaxext/scipy/__init__.py +25 -0
- bartz/jaxext/scipy/special.py +239 -0
- bartz/jaxext/scipy/stats.py +36 -0
- bartz/mcmcloop.py +662 -314
- bartz/mcmcstep/__init__.py +35 -0
- bartz/mcmcstep/_moves.py +904 -0
- bartz/mcmcstep/_state.py +1114 -0
- bartz/mcmcstep/_step.py +1603 -0
- bartz/prepcovars.py +140 -44
- bartz/testing/__init__.py +29 -0
- bartz/testing/_dgp.py +442 -0
- {bartz-0.6.0.dist-info → bartz-0.8.0.dist-info}/METADATA +18 -13
- bartz-0.8.0.dist-info/RECORD +25 -0
- {bartz-0.6.0.dist-info → bartz-0.8.0.dist-info}/WHEEL +1 -1
- bartz/BART.py +0 -603
- bartz/jaxext.py +0 -423
- bartz/mcmcstep.py +0 -2335
- bartz-0.6.0.dist-info/RECORD +0 -13
bartz/BART.py
DELETED
@@ -1,603 +0,0 @@
-# bartz/src/bartz/BART.py
-#
-# Copyright (c) 2024-2025, Giacomo Petrillo
-#
-# This file is part of bartz.
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in all
-# copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-
-"""Implement a user interface that mimics the R BART package."""
-
-import functools
-import math
-from typing import Any, Literal
-
-import jax
-import jax.numpy as jnp
-from jax.scipy.special import ndtri
-from jaxtyping import Array, Bool, Float, Float32
-
-from . import grove, jaxext, mcmcloop, mcmcstep, prepcovars
-
-FloatLike = float | Float[Any, '']
-
-
-class gbart:
-    """
-    Nonparametric regression with Bayesian Additive Regression Trees (BART).
-
-    Regress `y_train` on `x_train` with a latent mean function represented as
-    a sum of decision trees. The inference is carried out by sampling the
-    posterior distribution of the tree ensemble with an MCMC.
-
-    Parameters
-    ----------
-    x_train : array (p, n) or DataFrame
-        The training predictors.
-    y_train : array (n,) or Series
-        The training responses.
-    x_test : array (p, m) or DataFrame, optional
-        The test predictors.
-    type
-        The type of regression. 'wbart' for continuous regression, 'pbart' for
-        binary regression with probit link.
-    usequants : bool, default False
-        Whether to use predictors quantiles instead of a uniform grid to bin
-        predictors.
-    sigest : float, optional
-        An estimate of the residual standard deviation on `y_train`, used to set
-        `lamda`. If not specified, it is estimated by linear regression (with
-        intercept, and without taking into account `w`). If `y_train` has less
-        than two elements, it is set to 1. If n <= p, it is set to the standard
-        deviation of `y_train`. Ignored if `lamda` is specified.
-    sigdf : int, default 3
-        The degrees of freedom of the scaled inverse-chisquared prior on the
-        noise variance.
-    sigquant : float, default 0.9
-        The quantile of the prior on the noise variance that shall match
-        `sigest` to set the scale of the prior. Ignored if `lamda` is specified.
-    k : float, default 2
-        The inverse scale of the prior standard deviation on the latent mean
-        function, relative to half the observed range of `y_train`. If `y_train`
-        has less than two elements, `k` is ignored and the scale is set to 1.
-    power : float, default 2
-    base : float, default 0.95
-        Parameters of the prior on tree node generation. The probability that a
-        node at depth `d` (0-based) is non-terminal is ``base / (1 + d) **
-        power``.
-    lamda
-        The prior harmonic mean of the error variance. (The harmonic mean of x
-        is 1/mean(1/x).) If not specified, it is set based on `sigest` and
-        `sigquant`.
-    tau_num
-        The numerator in the expression that determines the prior standard
-        deviation of leaves. If not specified, default to ``(max(y_train) -
-        min(y_train)) / 2`` (or 1 if `y_train` has less than two elements) for
-        continuous regression, and 3 for binary regression.
-    offset
-        The prior mean of the latent mean function. If not specified, it is set
-        to the mean of `y_train` for continuous regression, and to
-        ``Phi^-1(mean(y_train))`` for binary regression. If `y_train` is empty,
-        `offset` is set to 0.
-    w : array (n,), optional
-        Coefficients that rescale the error standard deviation on each
-        datapoint. Not specifying `w` is equivalent to setting it to 1 for all
-        datapoints. Note: `w` is ignored in the automatic determination of
-        `sigest`, so either the weights should be O(1), or `sigest` should be
-        specified by the user.
-    ntree : int, default 200
-        The number of trees used to represent the latent mean function.
-    numcut : int, default 255
-        If `usequants` is `False`: the exact number of cutpoints used to bin the
-        predictors, ranging between the minimum and maximum observed values
-        (excluded).
-
-        If `usequants` is `True`: the maximum number of cutpoints to use for
-        binning the predictors. Each predictor is binned such that its
-        distribution in `x_train` is approximately uniform across bins. The
-        number of bins is at most the number of unique values appearing in
-        `x_train`, or ``numcut + 1``.
-
-        Before running the algorithm, the predictors are compressed to the
-        smallest integer type that fits the bin indices, so `numcut` is best set
-        to the maximum value of an unsigned integer type.
-    ndpost : int, default 1000
-        The number of MCMC samples to save, after burn-in.
-    nskip : int, default 100
-        The number of initial MCMC samples to discard as burn-in.
-    keepevery : int, default 1
-        The thinning factor for the MCMC samples, after burn-in.
-    printevery : int or None, default 100
-        The number of iterations (including thinned-away ones) between each log
-        line. Set to `None` to disable logging.
-
-        `printevery` has a few unexpected side effects. On cpu, interrupting
-        with ^C halts the MCMC only on the next log. And the total number of
-        iterations is a multiple of `printevery`, so if ``nskip + keepevery *
-        ndpost`` is not a multiple of `printevery`, some of the last iterations
-        will not be saved.
-    seed : int or jax random key, default 0
-        The seed for the random number generator.
-    maxdepth : int, default 6
-        The maximum depth of the trees. This is 1-based, so with the default
-        ``maxdepth=6``, the depths of the levels range from 0 to 5.
-    init_kw : dict
-        Additional arguments passed to `mcmcstep.init`.
-    run_mcmc_kw : dict
-        Additional arguments passed to `mcmcloop.run_mcmc`.
-
-    Attributes
-    ----------
-    yhat_train : array (ndpost, n)
-        The conditional posterior mean at `x_train` for each MCMC iteration.
-    yhat_train_mean : array (n,)
-        The marginal posterior mean at `x_train`.
-    yhat_test : array (ndpost, m)
-        The conditional posterior mean at `x_test` for each MCMC iteration.
-    yhat_test_mean : array (m,)
-        The marginal posterior mean at `x_test`.
-    sigma : array (ndpost,)
-        The standard deviation of the error.
-    first_sigma : array (nskip,)
-        The standard deviation of the error in the burn-in phase.
-    offset : float
-        The prior mean of the latent mean function.
-    sigest : float or None
-        The estimated standard deviation of the error used to set `lamda`.
-
-    Notes
-    -----
-    This interface imitates the function ``gbart`` from the R package `BART
-    <https://cran.r-project.org/package=BART>`_, but with these differences:
-
-    - If `x_train` and `x_test` are matrices, they have one predictor per row
-      instead of per column.
-    - If `type` is not specified, it is determined solely based on the data type
-      of `y_train`, and not on whether it contains only two unique values.
-    - If ``usequants=False``, R BART switches to quantiles anyway if there are
-      less predictor values than the required number of bins, while bartz
-      always follows the specification.
-    - The error variance parameter is called `lamda` instead of `lambda`.
-    - `rm_const` is always `False`.
-    - The default `numcut` is 255 instead of 100.
-    - A lot of functionality is missing (e.g., variable selection).
-    - There are some additional attributes, and some missing.
-    - The trees have a maximum depth.
-
-    """
-
-    def __init__(
-        self,
-        x_train,
-        y_train,
-        *,
-        x_test=None,
-        type: Literal['wbart', 'pbart'] = 'wbart',
-        usequants=False,
-        sigest=None,
-        sigdf=3,
-        sigquant=0.9,
-        k=2,
-        power=2,
-        base=0.95,
-        lamda: FloatLike | None = None,
-        tau_num: FloatLike | None = None,
-        offset: FloatLike | None = None,
-        w=None,
-        ntree=200,
-        numcut=255,
-        ndpost=1000,
-        nskip=100,
-        keepevery=1,
-        printevery=100,
-        seed=0,
-        maxdepth=6,
-        init_kw=None,
-        run_mcmc_kw=None,
-    ):
-        x_train, x_train_fmt = self._process_predictor_input(x_train)
-        y_train, _ = self._process_response_input(y_train)
-        self._check_same_length(x_train, y_train)
-        if w is not None:
-            w, _ = self._process_response_input(w)
-            self._check_same_length(x_train, w)
-
-        y_train = self._process_type_settings(y_train, type, w)
-        # from here onwards, the type is determined by y_train.dtype == bool
-        offset = self._process_offset_settings(y_train, offset)
-        sigma_mu = self._process_leaf_sdev_settings(y_train, k, ntree, tau_num)
-        lamda, sigest = self._process_error_variance_settings(
-            x_train, y_train, sigest, sigdf, sigquant, lamda
-        )
-
-        splits, max_split = self._determine_splits(x_train, usequants, numcut)
-        x_train = self._bin_predictors(x_train, splits)
-
-        mcmc_state = self._setup_mcmc(
-            x_train,
-            y_train,
-            offset,
-            w,
-            max_split,
-            lamda,
-            sigma_mu,
-            sigdf,
-            power,
-            base,
-            maxdepth,
-            ntree,
-            init_kw,
-        )
-        final_state, burnin_trace, main_trace = self._run_mcmc(
-            mcmc_state, ndpost, nskip, keepevery, printevery, seed, run_mcmc_kw
-        )
-
-        sigma = self._extract_sigma(main_trace)
-        first_sigma = self._extract_sigma(burnin_trace)
-
-        self.offset = final_state.offset  # from the state because of buffer donation
-        self.sigest = sigest
-        self.sigma = sigma
-        self.first_sigma = first_sigma
-
-        self._x_train_fmt = x_train_fmt
-        self._splits = splits
-        self._main_trace = main_trace
-        self._mcmc_state = final_state
-
-        if x_test is not None:
-            yhat_test = self.predict(x_test)
-            self.yhat_test = yhat_test
-            self.yhat_test_mean = yhat_test.mean(axis=0)
-
-    @functools.cached_property
-    def yhat_train(self):
-        x_train = self._mcmc_state.X
-        return self._predict(self._main_trace, x_train)
-
-    @functools.cached_property
-    def yhat_train_mean(self):
-        return self.yhat_train.mean(axis=0)
-
-    def predict(self, x_test):
-        """
-        Compute the posterior mean at `x_test` for each MCMC iteration.
-
-        Parameters
-        ----------
-        x_test : array (p, m) or DataFrame
-            The test predictors.
-
-        Returns
-        -------
-        yhat_test : array (ndpost, m)
-            The conditional posterior mean at `x_test` for each MCMC iteration.
-
-        Raises
-        ------
-        ValueError
-            If `x_test` has a different format than `x_train`.
-        """
-        x_test, x_test_fmt = self._process_predictor_input(x_test)
-        if x_test_fmt != self._x_train_fmt:
-            raise ValueError(
-                f'Input format mismatch: {x_test_fmt=} != x_train_fmt={self._x_train_fmt!r}'
-            )
-        x_test = self._bin_predictors(x_test, self._splits)
-        return self._predict(self._main_trace, x_test)
-
-    @staticmethod
-    def _process_predictor_input(x):
-        if hasattr(x, 'columns'):
-            fmt = dict(kind='dataframe', columns=x.columns)
-            x = x.to_numpy().T
-        else:
-            fmt = dict(kind='array', num_covar=x.shape[0])
-        x = jnp.asarray(x)
-        assert x.ndim == 2
-        return x, fmt
-
-    @staticmethod
-    def _process_response_input(y):
-        if hasattr(y, 'to_numpy'):
-            fmt = dict(kind='series', name=y.name)
-            y = y.to_numpy()
-        else:
-            fmt = dict(kind='array')
-        y = jnp.asarray(y)
-        assert y.ndim == 1
-        return y, fmt
-
-    @staticmethod
-    def _check_same_length(x1, x2):
-        get_length = lambda x: x.shape[-1]
-        assert get_length(x1) == get_length(x2)
-
-    @staticmethod
-    def _process_error_variance_settings(
-        x_train, y_train, sigest, sigdf, sigquant, lamda
-    ) -> tuple[Float32[Array, ''] | None, ...]:
-        if y_train.dtype == bool:
-            if sigest is not None:
-                raise ValueError('Let `sigest=None` for binary regression')
-            if lamda is not None:
-                raise ValueError('Let `lamda=None` for binary regression')
-            return None, None
-        elif lamda is not None:
-            if sigest is not None:
-                raise ValueError('Let `sigest=None` if `lamda` is specified')
-            return lamda, None
-        else:
-            if sigest is not None:
-                sigest2 = jnp.square(sigest)
-            elif y_train.size < 2:
-                sigest2 = 1
-            elif y_train.size <= x_train.shape[0]:
-                sigest2 = jnp.var(y_train)
-            else:
-                x_centered = x_train.T - x_train.mean(axis=1)
-                y_centered = y_train - y_train.mean()
-                # centering is equivalent to adding an intercept column
-                _, chisq, rank, _ = jnp.linalg.lstsq(x_centered, y_centered)
-                chisq = chisq.squeeze(0)
-                dof = len(y_train) - rank
-                sigest2 = chisq / dof
-            alpha = sigdf / 2
-            invchi2 = jaxext.scipy.stats.invgamma.ppf(sigquant, alpha) / 2
-            invchi2rid = invchi2 * sigdf
-            return sigest2 / invchi2rid, jnp.sqrt(sigest2)
-
-    @staticmethod
-    def _process_type_settings(y_train, type, w):
-        match type:
-            case 'wbart':
-                if y_train.dtype != jnp.float32:
-                    raise TypeError(
-                        'Continuous regression requires y_train.dtype=float32,'
-                        f' got {y_train.dtype=} instead.'
-                    )
-            case 'pbart':
-                if w is not None:
-                    raise ValueError(
-                        'Binary regression does not support weights, set `w=None`'
-                    )
-                if y_train.dtype != bool:
-                    raise TypeError(
-                        'Binary regression requires y_train.dtype=bool,'
-                        f' got {y_train.dtype=} instead.'
-                    )
-            case _:
-                raise ValueError(f'Invalid {type=}')
-
-        return y_train
-
-    @staticmethod
-    def _process_offset_settings(
-        y_train: Float32[Array, 'n'] | Bool[Array, 'n'],
-        offset: float | Float32[Any, ''] | None,
-    ) -> Float32[Array, '']:
-        if offset is not None:
-            return jnp.asarray(offset)
-        elif y_train.size < 1:
-            return jnp.array(0.0)
-        else:
-            mean = y_train.mean()
-
-            if y_train.dtype == bool:
-                return ndtri(mean)
-            else:
-                return mean
-
-    @staticmethod
-    def _process_leaf_sdev_settings(
-        y_train: Float32[Array, 'n'] | Bool[Array, 'n'],
-        k: float,
-        ntree: int,
-        tau_num: FloatLike | None,
-    ):
-        if tau_num is None:
-            if y_train.dtype == bool:
-                tau_num = 3.0
-            elif y_train.size < 2:
-                tau_num = 1.0
-            else:
-                tau_num = (y_train.max() - y_train.min()) / 2
-
-        return tau_num / (k * math.sqrt(ntree))
-
-    @staticmethod
-    def _determine_splits(x_train, usequants, numcut):
-        if usequants:
-            return prepcovars.quantilized_splits_from_matrix(x_train, numcut + 1)
-        else:
-            return prepcovars.uniform_splits_from_matrix(x_train, numcut + 1)
-
-    @staticmethod
-    def _bin_predictors(x, splits):
-        return prepcovars.bin_predictors(x, splits)
-
-    @staticmethod
-    def _setup_mcmc(
-        x_train,
-        y_train,
-        offset,
-        w,
-        max_split,
-        lamda,
-        sigma_mu,
-        sigdf,
-        power,
-        base,
-        maxdepth,
-        ntree,
-        init_kw,
-    ):
-        depth = jnp.arange(maxdepth - 1)
-        p_nonterminal = base / (1 + depth).astype(float) ** power
-
-        if y_train.dtype == bool:
-            sigma2_alpha = None
-            sigma2_beta = None
-        else:
-            sigma2_alpha = sigdf / 2
-            sigma2_beta = lamda * sigma2_alpha
-
-        kw = dict(
-            X=x_train,
-            # copy y_train because it's going to be donated in the mcmc loop
-            y=jnp.array(y_train),
-            offset=offset,
-            error_scale=w,
-            max_split=max_split,
-            num_trees=ntree,
-            p_nonterminal=p_nonterminal,
-            sigma_mu2=jnp.square(sigma_mu),
-            sigma2_alpha=sigma2_alpha,
-            sigma2_beta=sigma2_beta,
-            min_points_per_leaf=5,
-        )
-        if init_kw is not None:
-            kw.update(init_kw)
-        return mcmcstep.init(**kw)
-
-    @staticmethod
-    def _run_mcmc(mcmc_state, ndpost, nskip, keepevery, printevery, seed, run_mcmc_kw):
-        if isinstance(seed, jax.Array) and jnp.issubdtype(
-            seed.dtype, jax.dtypes.prng_key
-        ):
-            key = seed.copy()
-            # copy because the inner loop in run_mcmc will donate the buffer
-        else:
-            key = jax.random.key(seed)
-
-        kw = dict(
-            n_burn=nskip,
-            n_skip=keepevery,
-            inner_loop_length=printevery,
-            allow_overflow=True,
-        )
-        if printevery is not None:
-            kw.update(mcmcloop.make_print_callbacks())
-        if run_mcmc_kw is not None:
-            kw.update(run_mcmc_kw)
-
-        return mcmcloop.run_mcmc(key, mcmc_state, ndpost, **kw)
-
-    @staticmethod
-    def _extract_sigma(trace) -> Float32[Array, 'trace_length'] | None:
-        if trace['sigma2'] is None:
-            return None
-        else:
-            return jnp.sqrt(trace['sigma2'])
-
-    @staticmethod
-    def _predict(trace, x):
-        return mcmcloop.evaluate_trace(trace, x)
-
-    def _show_tree(self, i_sample, i_tree, print_all=False):
-        from . import debug
-
-        trace = self._main_trace
-        leaf_tree = trace['leaf_trees'][i_sample, i_tree]
-        var_tree = trace['var_trees'][i_sample, i_tree]
-        split_tree = trace['split_trees'][i_sample, i_tree]
-        debug.print_tree(leaf_tree, var_tree, split_tree, print_all)
-
-    def _sigma_harmonic_mean(self, prior=False):
-        bart = self._mcmc_state
-        if prior:
-            alpha = bart['sigma2_alpha']
-            beta = bart['sigma2_beta']
-        else:
-            resid = bart['resid']
-            alpha = bart['sigma2_alpha'] + resid.size / 2
-            norm2 = jnp.dot(
-                resid, resid, preferred_element_type=bart['sigma2_beta'].dtype
-            )
-            beta = bart['sigma2_beta'] + norm2 / 2
-        sigma2 = beta / alpha
-        return jnp.sqrt(sigma2)
-
-    def _compare_resid(self):
-        bart = self._mcmc_state
-        resid1 = bart.resid
-
-        trees = grove.evaluate_forest(
-            bart.X,
-            bart.forest.leaf_trees,
-            bart.forest.var_trees,
-            bart.forest.split_trees,
-            jnp.float32,  # TODO remove these configurable dtypes around
-        )
-
-        if bart.z is not None:
-            ref = bart.z
-        else:
-            ref = bart.y
-        resid2 = ref - (trees + bart.offset)
-
-        return resid1, resid2
-
-    def _avg_acc(self):
-        trace = self._main_trace
-
-        def acc(prefix):
-            acc = trace[f'{prefix}_acc_count']
-            prop = trace[f'{prefix}_prop_count']
-            return acc.sum() / prop.sum()
-
-        return acc('grow'), acc('prune')
-
-    def _avg_prop(self):
-        trace = self._main_trace
-
-        def prop(prefix):
-            return trace[f'{prefix}_prop_count'].sum()
-
-        pgrow = prop('grow')
-        pprune = prop('prune')
-        total = pgrow + pprune
-        return pgrow / total, pprune / total
-
-    def _avg_move(self):
-        agrow, aprune = self._avg_acc()
-        pgrow, pprune = self._avg_prop()
-        return agrow * pgrow, aprune * pprune
-
-    def _depth_distr(self):
-        from . import debug
-
-        trace = self._main_trace
-        split_trees = trace['split_trees']
-        return debug.trace_depth_distr(split_trees)
-
-    def _points_per_leaf_distr(self):
-        from . import debug
-
-        return debug.trace_points_per_leaf_distr(self._main_trace, self._mcmc_state.X)
-
-    def _check_trees(self):
-        from . import debug
-
-        return debug.check_trace(self._main_trace, self._mcmc_state)
-
-    def _tree_goes_bad(self):
-        bad = self._check_trees().astype(bool)
-        bad_before = jnp.pad(bad[:-1], [(1, 0), (0, 0)])
-        return bad & ~bad_before