bartz 0.4.1__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/.DS_Store +0 -0
- bartz/BART.py +266 -113
- bartz/__init__.py +4 -12
- bartz/_version.py +1 -1
- bartz/debug.py +42 -16
- bartz/grove.py +62 -12
- bartz/jaxext.py +111 -37
- bartz/mcmcloop.py +419 -105
- bartz/mcmcstep.py +1528 -760
- bartz/prepcovars.py +25 -10
- {bartz-0.4.1.dist-info → bartz-0.6.0.dist-info}/METADATA +14 -16
- bartz-0.6.0.dist-info/RECORD +13 -0
- bartz-0.6.0.dist-info/WHEEL +4 -0
- bartz-0.4.1.dist-info/LICENSE +0 -21
- bartz-0.4.1.dist-info/RECORD +0 -13
- bartz-0.4.1.dist-info/WHEEL +0 -4
bartz/.DS_Store
ADDED
Binary file
bartz/BART.py
CHANGED
@@ -1,6 +1,6 @@
 # bartz/src/bartz/BART.py
 #
-# Copyright (c) 2024, Giacomo Petrillo
+# Copyright (c) 2024-2025, Giacomo Petrillo
 #
 # This file is part of bartz.
 #
@@ -22,16 +22,21 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
+"""Implement a user interface that mimics the R BART package."""
+
 import functools
+import math
+from typing import Any, Literal
 
 import jax
 import jax.numpy as jnp
+from jax.scipy.special import ndtri
+from jaxtyping import Array, Bool, Float, Float32
+
+from . import grove, jaxext, mcmcloop, mcmcstep, prepcovars
+
+FloatLike = float | Float[Any, '']
 
-from . import jaxext
-from . import grove
-from . import mcmcstep
-from . import mcmcloop
-from . import prepcovars
 
 class gbart:
     """
@@ -49,14 +54,18 @@ class gbart:
         The training responses.
     x_test : array (p, m) or DataFrame, optional
         The test predictors.
+    type
+        The type of regression. 'wbart' for continuous regression, 'pbart' for
+        binary regression with probit link.
     usequants : bool, default False
         Whether to use predictors quantiles instead of a uniform grid to bin
        predictors.
     sigest : float, optional
-        An estimate of the residual standard deviation on `y_train`, used to …
-        …
-        …
-        is set to …
+        An estimate of the residual standard deviation on `y_train`, used to set
+        `lamda`. If not specified, it is estimated by linear regression (with
+        intercept, and without taking into account `w`). If `y_train` has less
+        than two elements, it is set to 1. If n <= p, it is set to the standard
+        deviation of `y_train`. Ignored if `lamda` is specified.
     sigdf : int, default 3
         The degrees of freedom of the scaled inverse-chisquared prior on the
         noise variance.
@@ -72,16 +81,26 @@ class gbart:
         Parameters of the prior on tree node generation. The probability that a
         node at depth `d` (0-based) is non-terminal is ``base / (1 + d) **
         power``.
-    …
-        The …
-        …
-        …
-        …
-        …
-        …
-        …
+    lamda
+        The prior harmonic mean of the error variance. (The harmonic mean of x
+        is 1/mean(1/x).) If not specified, it is set based on `sigest` and
+        `sigquant`.
+    tau_num
+        The numerator in the expression that determines the prior standard
+        deviation of leaves. If not specified, default to ``(max(y_train) -
+        min(y_train)) / 2`` (or 1 if `y_train` has less than two elements) for
+        continuous regression, and 3 for binary regression.
+    offset
         The prior mean of the latent mean function. If not specified, it is set
-        to the mean of `y_train` …
+        to the mean of `y_train` for continuous regression, and to
+        ``Phi^-1(mean(y_train))`` for binary regression. If `y_train` is empty,
+        `offset` is set to 0.
+    w : array (n,), optional
+        Coefficients that rescale the error standard deviation on each
+        datapoint. Not specifying `w` is equivalent to setting it to 1 for all
+        datapoints. Note: `w` is ignored in the automatic determination of
+        `sigest`, so either the weights should be O(1), or `sigest` should be
+        specified by the user.
     ntree : int, default 200
         The number of trees used to represent the latent mean function.
     numcut : int, default 255
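The binary-regression `offset` documented above is the probit transform of the empirical success rate. A standalone check of that formula with `ndtri` (the inverse normal CDF imported in this diff); the data here are invented:

import jax.numpy as jnp
from jax.scipy.special import ndtri

y = jnp.array([True, True, True, False])  # 75% successes
print(ndtri(y.mean()))  # ~0.6745, the standard normal 0.75-quantile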
@@ -104,10 +123,24 @@ class gbart:
         The number of initial MCMC samples to discard as burn-in.
     keepevery : int, default 1
         The thinning factor for the MCMC samples, after burn-in.
-    printevery : int, default 100
-        The number of iterations (including …
+    printevery : int or None, default 100
+        The number of iterations (including thinned-away ones) between each log
+        line. Set to `None` to disable logging.
+
+        `printevery` has a few unexpected side effects. On cpu, interrupting
+        with ^C halts the MCMC only on the next log. And the total number of
+        iterations is a multiple of `printevery`, so if ``nskip + keepevery *
+        ndpost`` is not a multiple of `printevery`, some of the last iterations
+        will not be saved.
     seed : int or jax random key, default 0
         The seed for the random number generator.
+    maxdepth : int, default 6
+        The maximum depth of the trees. This is 1-based, so with the default
+        ``maxdepth=6``, the depths of the levels range from 0 to 5.
+    init_kw : dict
+        Additional arguments passed to `mcmcstep.init`.
+    run_mcmc_kw : dict
+        Additional arguments passed to `mcmcloop.run_mcmc`.
 
     Attributes
     ----------
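To make the `maxdepth` wording concrete together with the `base`/`power` prior documented earlier: with the default ``maxdepth=6`` only depths 0 to 4 can be non-terminal, and the non-terminal probabilities work out as below (a hand-computed sketch, not bartz output):

base, power, maxdepth = 0.95, 2, 6
p_nonterminal = [base / (1 + d) ** power for d in range(maxdepth - 1)]
print(p_nonterminal)  # [0.95, 0.2375, ~0.1056, 0.059375, 0.038]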
@@ -125,22 +158,8 @@ class gbart:
         The standard deviation of the error in the burn-in phase.
     offset : float
         The prior mean of the latent mean function.
-    scale : float
-        The prior standard deviation of the latent mean function.
-    lamda : float
-        The prior harmonic mean of the error variance.
     sigest : float or None
         The estimated standard deviation of the error used to set `lamda`.
-    ntree : int
-        The number of trees.
-    maxdepth : int
-        The maximum depth of the trees.
-    initkw : dict
-        Additional arguments passed to `mcmcstep.init`.
-
-    Methods
-    -------
-    predict
 
     Notes
     -----
@@ -149,20 +168,27 @@ class gbart:
 
     - If `x_train` and `x_test` are matrices, they have one predictor per row
       instead of per column.
+    - If `type` is not specified, it is determined solely based on the data type
+      of `y_train`, and not on whether it contains only two unique values.
     - If ``usequants=False``, R BART switches to quantiles anyway if there are
       less predictor values than the required number of bins, while bartz
       always follows the specification.
     - The error variance parameter is called `lamda` instead of `lambda`.
     - `rm_const` is always `False`.
     - The default `numcut` is 255 instead of 100.
-    - A lot of functionality is missing (variable selection …
+    - A lot of functionality is missing (e.g., variable selection).
     - There are some additional attributes, and some missing.
+    - The trees have a maximum depth.
 
-    The linear regression used to set `sigest` adds an intercept.
     """
 
-    def __init__( …
+    def __init__(
+        self,
+        x_train,
+        y_train,
+        *,
         x_test=None,
+        type: Literal['wbart', 'pbart'] = 'wbart',
         usequants=False,
         sigest=None,
         sigdf=3,
@@ -170,9 +196,10 @@ class gbart:
         k=2,
         power=2,
         base=0.95,
-        …
-        …
-        offset=None,
+        lamda: FloatLike | None = None,
+        tau_num: FloatLike | None = None,
+        offset: FloatLike | None = None,
+        w=None,
         ntree=200,
         numcut=255,
         ndpost=1000,
@@ -180,36 +207,52 @@ class gbart:
         keepevery=1,
         printevery=100,
         seed=0,
-        …
-        …
-        …
+        maxdepth=6,
+        init_kw=None,
+        run_mcmc_kw=None,
+    ):
         x_train, x_train_fmt = self._process_predictor_input(x_train)
-        …
-        y_train, y_train_fmt = self._process_response_input(y_train)
+        y_train, _ = self._process_response_input(y_train)
         self._check_same_length(x_train, y_train)
+        if w is not None:
+            w, _ = self._process_response_input(w)
+            self._check_same_length(x_train, w)
 
+        y_train = self._process_type_settings(y_train, type, w)
+        # from here onwards, the type is determined by y_train.dtype == bool
         offset = self._process_offset_settings(y_train, offset)
-        …
-        lamda, sigest = self.…
+        sigma_mu = self._process_leaf_sdev_settings(y_train, k, ntree, tau_num)
+        lamda, sigest = self._process_error_variance_settings(
+            x_train, y_train, sigest, sigdf, sigquant, lamda
+        )
 
         splits, max_split = self._determine_splits(x_train, usequants, numcut)
         x_train = self._bin_predictors(x_train, splits)
 
-        …
-        …
-        …
-        …
-        …
+        mcmc_state = self._setup_mcmc(
+            x_train,
+            y_train,
+            offset,
+            w,
+            max_split,
+            lamda,
+            sigma_mu,
+            sigdf,
+            power,
+            base,
+            maxdepth,
+            ntree,
+            init_kw,
+        )
+        final_state, burnin_trace, main_trace = self._run_mcmc(
+            mcmc_state, ndpost, nskip, keepevery, printevery, seed, run_mcmc_kw
+        )
 
-        sigma = self._extract_sigma(main_trace …
-        first_sigma = self._extract_sigma(burnin_trace …
+        sigma = self._extract_sigma(main_trace)
+        first_sigma = self._extract_sigma(burnin_trace)
 
-        self.offset = offset
-        self.scale = scale
-        self.lamda = lamda
+        self.offset = final_state.offset  # from the state because of buffer donation
         self.sigest = sigest
-        self.ntree = ntree
-        self.maxdepth = maxdepth
         self.sigma = sigma
         self.first_sigma = first_sigma
 
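The `# from the state because of buffer donation` comment added above refers to JAX buffer donation: arrays donated to a jitted function must not be reused by the caller, which is why `__init__` reads `offset` back from `final_state` rather than keeping its own copy. A minimal illustration of the mechanism, unrelated to bartz itself:

import jax
import jax.numpy as jnp

f = jax.jit(lambda x: x + 1, donate_argnums=0)
x = jnp.ones(3)
y = f(x)
# On GPU/TPU the buffer of x is reused for y, and touching x afterwards is an
# error; on CPU donation is currently ignored (JAX warns instead).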
@@ -225,9 +268,8 @@ class gbart:
 
     @functools.cached_property
     def yhat_train(self):
-        x_train = self._mcmc_state …
-        …
-        return self._transform_output(yhat_train, self.offset, self.scale)
+        x_train = self._mcmc_state.X
+        return self._predict(self._main_trace, x_train)
 
     @functools.cached_property
     def yhat_train_mean(self):
@@ -239,19 +281,26 @@ class gbart:
 
         Parameters
         ----------
-        x_test : array ( …
+        x_test : array (p, m) or DataFrame
             The test predictors.
 
         Returns
         -------
         yhat_test : array (ndpost, m)
             The conditional posterior mean at `x_test` for each MCMC iteration.
+
+        Raises
+        ------
+        ValueError
+            If `x_test` has a different format than `x_train`.
         """
         x_test, x_test_fmt = self._process_predictor_input(x_test)
-        …
+        if x_test_fmt != self._x_train_fmt:
+            raise ValueError(
+                f'Input format mismatch: {x_test_fmt=} != x_train_fmt={self._x_train_fmt!r}'
+            )
         x_test = self._bin_predictors(x_test, self._splits)
-        …
-        return self._transform_output(yhat_test, self.offset, self.scale)
+        return self._predict(self._main_trace, x_test)
 
     @staticmethod
     def _process_predictor_input(x):
@@ -264,10 +313,6 @@ class gbart:
         assert x.ndim == 2
         return x, fmt
 
-    @staticmethod
-    def _check_compatible_formats(fmt1, fmt2):
-        assert fmt1 == fmt2
-
     @staticmethod
     def _process_response_input(y):
         if hasattr(y, 'to_numpy'):
@@ -285,20 +330,30 @@ class gbart:
         assert get_length(x1) == get_length(x2)
 
     @staticmethod
-    def …
-        …
+    def _process_error_variance_settings(
+        x_train, y_train, sigest, sigdf, sigquant, lamda
+    ) -> tuple[Float32[Array, ''] | None, ...]:
+        if y_train.dtype == bool:
+            if sigest is not None:
+                raise ValueError('Let `sigest=None` for binary regression')
+            if lamda is not None:
+                raise ValueError('Let `lamda=None` for binary regression')
+            return None, None
+        elif lamda is not None:
+            if sigest is not None:
+                raise ValueError('Let `sigest=None` if `lamda` is specified')
             return lamda, None
         else:
             if sigest is not None:
-                sigest2 = sigest …
+                sigest2 = jnp.square(sigest)
             elif y_train.size < 2:
                 sigest2 = 1
             elif y_train.size <= x_train.shape[0]:
-                sigest2 = jnp.var(y_train …
+                sigest2 = jnp.var(y_train)
             else:
                 x_centered = x_train.T - x_train.mean(axis=1)
                 y_centered = y_train - y_train.mean()
-                …
+                # centering is equivalent to adding an intercept column
                 _, chisq, rank, _ = jnp.linalg.lstsq(x_centered, y_centered)
                 chisq = chisq.squeeze(0)
                 dof = len(y_train) - rank
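The new comment in `_process_error_variance_settings`, "centering is equivalent to adding an intercept column", can be verified independently: the residual sum of squares from `lstsq` on centered data equals the one from a fit with an explicit intercept column. A standalone NumPy sketch with invented data:

import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((50, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + 3.0 + 0.1 * rng.standard_normal(50)

# fit with an explicit intercept column
_, rss_icept, *_ = np.linalg.lstsq(np.column_stack([np.ones(50), X]), y, rcond=None)
# fit with both sides centered instead
_, rss_cent, *_ = np.linalg.lstsq(X - X.mean(axis=0), y - y.mean(), rcond=None)

assert np.allclose(rss_icept, rss_cent)  # same residual sum of squares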
@@ -309,20 +364,62 @@ class gbart:
             return sigest2 / invchi2rid, jnp.sqrt(sigest2)
 
     @staticmethod
-    def …
+    def _process_type_settings(y_train, type, w):
+        match type:
+            case 'wbart':
+                if y_train.dtype != jnp.float32:
+                    raise TypeError(
+                        'Continuous regression requires y_train.dtype=float32,'
+                        f' got {y_train.dtype=} instead.'
+                    )
+            case 'pbart':
+                if w is not None:
+                    raise ValueError(
+                        'Binary regression does not support weights, set `w=None`'
+                    )
+                if y_train.dtype != bool:
+                    raise TypeError(
+                        'Binary regression requires y_train.dtype=bool,'
+                        f' got {y_train.dtype=} instead.'
+                    )
+            case _:
+                raise ValueError(f'Invalid {type=}')
+
+        return y_train
+
+    @staticmethod
+    def _process_offset_settings(
+        y_train: Float32[Array, 'n'] | Bool[Array, 'n'],
+        offset: float | Float32[Any, ''] | None,
+    ) -> Float32[Array, '']:
         if offset is not None:
-            return offset
+            return jnp.asarray(offset)
         elif y_train.size < 1:
-            return 0
+            return jnp.array(0.0)
         else:
-            …
+            mean = y_train.mean()
 
-        …
-        …
-        if y_train.size < 2:
-            return 1
+        if y_train.dtype == bool:
+            return ndtri(mean)
         else:
-            return …
+            return mean
+
+    @staticmethod
+    def _process_leaf_sdev_settings(
+        y_train: Float32[Array, 'n'] | Bool[Array, 'n'],
+        k: float,
+        ntree: int,
+        tau_num: FloatLike | None,
+    ):
+        if tau_num is None:
+            if y_train.dtype == bool:
+                tau_num = 3.0
+            elif y_train.size < 2:
+                tau_num = 1.0
+            else:
+                tau_num = (y_train.max() - y_train.min()) / 2
+
+        return tau_num / (k * math.sqrt(ntree))
 
     @staticmethod
     def _determine_splits(x_train, usequants, numcut):
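`_process_leaf_sdev_settings` above implements the standard BART leaf scale ``tau_num / (k * sqrt(ntree))``, chosen so that the sum of `ntree` independent leaves has prior standard deviation ``tau_num / k``. Worked out for the continuous-regression defaults, assuming `y_train` spans [0, 1]:

import math

tau_num = (1.0 - 0.0) / 2  # (max(y_train) - min(y_train)) / 2
k, ntree = 2, 200
sigma_mu = tau_num / (k * math.sqrt(ntree))
print(sigma_mu)                     # ~0.0177 per leaf
print(sigma_mu * math.sqrt(ntree))  # 0.25 = tau_num / k for the whole forest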
@@ -336,52 +433,86 @@ class gbart:
         return prepcovars.bin_predictors(x, splits)
 
     @staticmethod
-    def …
-        …
-        …
-        …
-        …
+    def _setup_mcmc(
+        x_train,
+        y_train,
+        offset,
+        w,
+        max_split,
+        lamda,
+        sigma_mu,
+        sigdf,
+        power,
+        base,
+        maxdepth,
+        ntree,
+        init_kw,
+    ):
         depth = jnp.arange(maxdepth - 1)
         p_nonterminal = base / (1 + depth).astype(float) ** power
-        …
-        …
+
+        if y_train.dtype == bool:
+            sigma2_alpha = None
+            sigma2_beta = None
+        else:
+            sigma2_alpha = sigdf / 2
+            sigma2_beta = lamda * sigma2_alpha
+
         kw = dict(
             X=x_train,
-            …
+            # copy y_train because it's going to be donated in the mcmc loop
+            y=jnp.array(y_train),
+            offset=offset,
+            error_scale=w,
             max_split=max_split,
             num_trees=ntree,
             p_nonterminal=p_nonterminal,
+            sigma_mu2=jnp.square(sigma_mu),
             sigma2_alpha=sigma2_alpha,
             sigma2_beta=sigma2_beta,
             min_points_per_leaf=5,
         )
-        …
+        if init_kw is not None:
+            kw.update(init_kw)
         return mcmcstep.init(**kw)
 
     @staticmethod
-    def _run_mcmc(mcmc_state, ndpost, nskip, keepevery, printevery, seed):
-        if isinstance(seed, jax.Array) and jnp.issubdtype( …
-            …
+    def _run_mcmc(mcmc_state, ndpost, nskip, keepevery, printevery, seed, run_mcmc_kw):
+        if isinstance(seed, jax.Array) and jnp.issubdtype(
+            seed.dtype, jax.dtypes.prng_key
+        ):
+            key = seed.copy()
+            # copy because the inner loop in run_mcmc will donate the buffer
         else:
             key = jax.random.key(seed)
-        callback = mcmcloop.make_simple_print_callback(printevery)
-        return mcmcloop.run_mcmc(mcmc_state, nskip, ndpost, keepevery, callback, key)
 
-        …
-        …
-        …
+        kw = dict(
+            n_burn=nskip,
+            n_skip=keepevery,
+            inner_loop_length=printevery,
+            allow_overflow=True,
+        )
+        if printevery is not None:
+            kw.update(mcmcloop.make_print_callbacks())
+        if run_mcmc_kw is not None:
+            kw.update(run_mcmc_kw)
 
-    …
-    def _transform_output(y, offset, scale):
-        return offset + scale * y
+        return mcmcloop.run_mcmc(key, mcmc_state, ndpost, **kw)
 
     @staticmethod
-    def _extract_sigma(trace, …
-        …
+    def _extract_sigma(trace) -> Float32[Array, 'trace_length'] | None:
+        if trace['sigma2'] is None:
+            return None
+        else:
+            return jnp.sqrt(trace['sigma2'])
 
+    @staticmethod
+    def _predict(trace, x):
+        return mcmcloop.evaluate_trace(trace, x)
 
     def _show_tree(self, i_sample, i_tree, print_all=False):
         from . import debug
+
         trace = self._main_trace
         leaf_tree = trace['leaf_trees'][i_sample, i_tree]
         var_tree = trace['var_trees'][i_sample, i_tree]
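The seed handling in `_run_mcmc` above accepts either a plain integer or a typed JAX PRNG key. The detection idiom in isolation (a sketch; only the check itself comes from the diff):

import jax
import jax.numpy as jnp

for seed in (0, jax.random.key(0)):
    if isinstance(seed, jax.Array) and jnp.issubdtype(seed.dtype, jax.dtypes.prng_key):
        key = seed.copy()  # copy so the caller's key survives buffer donation
    else:
        key = jax.random.key(seed)
    print(key.dtype)  # key<fry> in both cases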
@@ -396,30 +527,49 @@ class gbart:
         else:
             resid = bart['resid']
             alpha = bart['sigma2_alpha'] + resid.size / 2
-            norm2 = jnp.dot( …
+            norm2 = jnp.dot(
+                resid, resid, preferred_element_type=bart['sigma2_beta'].dtype
+            )
             beta = bart['sigma2_beta'] + norm2 / 2
             sigma2 = beta / alpha
-        return jnp.sqrt(sigma2)
+            return jnp.sqrt(sigma2)
 
     def _compare_resid(self):
         bart = self._mcmc_state
-        resid1 = bart …
-        …
-        …
+        resid1 = bart.resid
+
+        trees = grove.evaluate_forest(
+            bart.X,
+            bart.forest.leaf_trees,
+            bart.forest.var_trees,
+            bart.forest.split_trees,
+            jnp.float32,  # TODO remove these configurable dtypes around
+        )
+
+        if bart.z is not None:
+            ref = bart.z
+        else:
+            ref = bart.y
+        resid2 = ref - (trees + bart.offset)
+
         return resid1, resid2
 
     def _avg_acc(self):
         trace = self._main_trace
+
         def acc(prefix):
             acc = trace[f'{prefix}_acc_count']
             prop = trace[f'{prefix}_prop_count']
             return acc.sum() / prop.sum()
+
         return acc('grow'), acc('prune')
 
     def _avg_prop(self):
         trace = self._main_trace
+
         def prop(prefix):
             return trace[f'{prefix}_prop_count'].sum()
+
         pgrow = prop('grow')
         pprune = prop('prune')
         total = pgrow + pprune
@@ -432,16 +582,19 @@ class gbart:
 
     def _depth_distr(self):
         from . import debug
+
         trace = self._main_trace
         split_trees = trace['split_trees']
         return debug.trace_depth_distr(split_trees)
 
     def _points_per_leaf_distr(self):
         from . import debug
-        …
+
+        return debug.trace_points_per_leaf_distr(self._main_trace, self._mcmc_state.X)
 
     def _check_trees(self):
         from . import debug
+
         return debug.check_trace(self._main_trace, self._mcmc_state)
 
     def _tree_goes_bad(self):
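Putting the BART.py changes together, a hypothetical end-to-end use of the 0.6.0 interface; the shapes, values, and expected output shape are illustrative assumptions, while the keyword names come from the signature in this diff:

import numpy as np
from bartz.BART import gbart

rng = np.random.default_rng(0)
X = rng.standard_normal((5, 100)).astype(np.float32)  # (p, n): one predictor per row
y = X[0] + X[1] > 0  # bool responses require type='pbart'

fit = gbart(X, y, type='pbart', ndpost=200, nskip=100, printevery=None, seed=0)
yhat = fit.predict(X[:, :10])  # latent-function draws at 10 points, shape (200, 10)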
bartz/__init__.py
CHANGED
@@ -1,6 +1,6 @@
 # bartz/src/bartz/__init__.py
 #
-# Copyright (c) 2024, Giacomo Petrillo
+# Copyright (c) 2024-2025, Giacomo Petrillo
 #
 # This file is part of bartz.
 #
@@ -23,18 +23,10 @@
 # SOFTWARE.
 
 """
-Super-fast BART (Bayesian Additive Regression Trees) in Python
+Super-fast BART (Bayesian Additive Regression Trees) in Python.
 
 See the manual at https://gattocrucco.github.io/bartz/docs
 """
 
-from . …
-…
-from . import BART
-…
-from . import debug
-from . import grove
-from . import mcmcstep
-from . import mcmcloop
-from . import prepcovars
-from . import jaxext
+from . import BART, debug, grove, jaxext, mcmcloop, mcmcstep, prepcovars  # noqa: F401
+from ._version import __version__  # noqa: F401
bartz/_version.py
CHANGED
@@ -1 +1 @@
-__version__ = '0.4.1'
+__version__ = '0.6.0'