bartz 0.4.1__tar.gz → 0.5.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {bartz-0.4.1 → bartz-0.5.0}/PKG-INFO +12 -16
- {bartz-0.4.1 → bartz-0.5.0}/README.md +1 -0
- {bartz-0.4.1 → bartz-0.5.0}/pyproject.toml +55 -27
- bartz-0.5.0/src/bartz/.DS_Store +0 -0
- {bartz-0.4.1 → bartz-0.5.0}/src/bartz/BART.py +99 -39
- {bartz-0.4.1 → bartz-0.5.0}/src/bartz/__init__.py +3 -11
- bartz-0.5.0/src/bartz/_version.py +1 -0
- {bartz-0.4.1 → bartz-0.5.0}/src/bartz/debug.py +42 -16
- {bartz-0.4.1 → bartz-0.5.0}/src/bartz/grove.py +20 -11
- {bartz-0.4.1 → bartz-0.5.0}/src/bartz/jaxext.py +41 -16
- {bartz-0.4.1 → bartz-0.5.0}/src/bartz/mcmcloop.py +119 -58
- {bartz-0.4.1 → bartz-0.5.0}/src/bartz/mcmcstep.py +426 -173
- {bartz-0.4.1 → bartz-0.5.0}/src/bartz/prepcovars.py +22 -9
- bartz-0.4.1/LICENSE +0 -21
- bartz-0.4.1/src/bartz/_version.py +0 -1
{bartz-0.4.1 → bartz-0.5.0}/PKG-INFO

```diff
@@ -1,22 +1,18 @@
-Metadata-Version: 2.
+Metadata-Version: 2.4
 Name: bartz
-Version: 0.4.1
+Version: 0.5.0
 Summary: Super-fast BART (Bayesian Additive Regression Trees) in Python
-License: MIT
 Author: Giacomo Petrillo
-Author-email: info@giacomopetrillo.com
+Author-email: Giacomo Petrillo <info@giacomopetrillo.com>
+License-Expression: MIT
+Requires-Dist: jax>=0.4.35,<1
+Requires-Dist: jaxlib>=0.4.35,<1
+Requires-Dist: numpy>=1.25.2,<3
+Requires-Dist: scipy>=1.11.4,<2
 Requires-Python: >=3.10
-
-
-
-Classifier: Programming Language :: Python :: 3.11
-Classifier: Programming Language :: Python :: 3.12
-Classifier: Programming Language :: Python :: 3.13
-Requires-Dist: jax (>=0.4.35,<1)
-Requires-Dist: jaxlib (>=0.4.35,<1)
-Requires-Dist: numpy (>=1.25.2,<3)
-Requires-Dist: scipy (>=1.11.4,<2)
-Project-URL: Bug Tracker, https://github.com/Gattocrucco/bartz/issues
+Project-URL: Documentation, https://gattocrucco.github.io/bartz/docs-dev
+Project-URL: Homepage, https://github.com/Gattocrucco/bartz
+Project-URL: Issues, https://github.com/Gattocrucco/bartz/issues
 Description-Content-Type: text/markdown
 
 [](https://pypi.org/project/bartz/)
@@ -42,6 +38,7 @@ On CPU, bartz runs at the speed of dbarts (the fastest implementation I know of)
 - [Documentation (development version)](https://gattocrucco.github.io/bartz/docs-dev)
 - [Repository](https://github.com/Gattocrucco/bartz)
 - [Code coverage](https://gattocrucco.github.io/bartz/coverage)
+- [Benchmarks](https://gattocrucco.github.io/bartz/benchmarks)
 - [List of BART packages](https://gattocrucco.github.io/bartz/docs-dev/pkglist.html)
 
 ## Citing bartz
@@ -49,4 +46,3 @@ On CPU, bartz runs at the speed of dbarts (the fastest implementation I know of)
 Article: Petrillo (2024), "Very fast Bayesian Additive Regression Trees on GPU", [arXiv:2410.23244](https://arxiv.org/abs/2410.23244).
 
 To cite the software directly, including the specific version, use [zenodo](https://doi.org/10.5281/zenodo.13931477).
-
```
{bartz-0.4.1 → bartz-0.5.0}/README.md

```diff
@@ -21,6 +21,7 @@ On CPU, bartz runs at the speed of dbarts (the fastest implementation I know of)
 - [Documentation (development version)](https://gattocrucco.github.io/bartz/docs-dev)
 - [Repository](https://github.com/Gattocrucco/bartz)
 - [Code coverage](https://gattocrucco.github.io/bartz/coverage)
+- [Benchmarks](https://gattocrucco.github.io/bartz/benchmarks)
 - [List of BART packages](https://gattocrucco.github.io/bartz/docs-dev/pkglist.html)
 
 ## Citing bartz
```
{bartz-0.4.1 → bartz-0.5.0}/pyproject.toml

```diff
@@ -23,12 +23,12 @@
 # SOFTWARE.
 
 [build-system]
-requires = ["
-build-backend = "
+requires = ["uv_build>=0.7.3,<0.8.0"]
+build-backend = "uv_build"
 
 [project]
 name = "bartz"
-version = "0.4.1"
+version = "0.5.0"
 description = "Super-fast BART (Bayesian Additive Regression Trees) in Python"
 authors = [
     {name = "Giacomo Petrillo", email = "info@giacomopetrillo.com"},
@@ -36,7 +36,6 @@ authors = [
 license = "MIT"
 readme = "README.md"
 requires-python = ">=3.10"
-repository = "https://github.com/Gattocrucco/bartz"
 packages = [
     { include = "bartz", from = "src" },
 ]
@@ -47,29 +46,35 @@ dependencies = [
     "scipy >=1.11.4,<2",
 ]
 
-[
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-[
-
-
-myst-parser
+[project.urls]
+Homepage = "https://github.com/Gattocrucco/bartz"
+Documentation = "https://gattocrucco.github.io/bartz/docs-dev"
+Issues = "https://github.com/Gattocrucco/bartz/issues"
+
+[dependency-groups]
+only-local = [
+    "appnope>=0.1.4",
+    "ipython>=8.36.0",
+    "matplotlib>=3.10.3",
+    "matplotlib-label-lines>=0.8.1",
+    "polars[pandas,pyarrow]>=1.29.0",
+    "pre-commit>=4.2.0",
+    "ruff>=0.11.9",
+    "scikit-learn>=1.6.1",
+    "tomli>=2.2.1",
+    "virtualenv>=20.31.2",
+    "xgboost>=3.0.0",
+]
+ci = [
+    "asv>=0.6.4",
+    "coverage>=7.8.0",
+    "myst-parser>=4.0.1",
+    "numpydoc>=1.8.0",
+    "packaging>=25.0",
+    "pytest>=8.3.5",
+    "pytest-timeout>=2.4.0",
+    "sphinx>=8.1.3",
+]
 
 [tool.pytest.ini_options]
 testpaths = ["tests"]
@@ -81,6 +86,8 @@ addopts = [
     "--pdbcls=IPython.terminal.debugger:TerminalPdb",
     "--durations=3",
 ]
+timeout = 32
+timeout_method = "thread" # when jax hangs, signals do not work
 
 # I wanted to use `--import-mode=importlib`, but it breaks importing submodules,
 # in particular `from . import util`.
@@ -119,3 +126,24 @@ local = [
     '/opt/hostedtoolcache/Python/*/*/lib/python*/site-packages/bartz/',
     'C:\hostedtoolcache\windows\Python\*\*\Lib\site-packages\bartz\',
 ]
+
+[tool.ruff]
+exclude = [".asv", "*.ipynb"]
+
+[tool.ruff.format]
+quote-style = "single"
+
+[tool.ruff.lint]
+select = [
+    "B", # bugbear: grab bag of additional stuff
+    "UP", # pyupgrade: fix some outdated idioms
+    "I", # isort: sort and reformat import statements
+    "F", # flake8
+]
+ignore = [
+    "B028", # warn with stacklevel = 2
+]
+
+[tool.uv]
+python-downloads = "never"
+python-preference = "only-system"
```
bartz-0.5.0/src/bartz/.DS_Store

Binary file
{bartz-0.4.1 → bartz-0.5.0}/src/bartz/BART.py

```diff
@@ -1,6 +1,6 @@
 # bartz/src/bartz/BART.py
 #
-# Copyright (c) 2024, Giacomo Petrillo
+# Copyright (c) 2024-2025, Giacomo Petrillo
 #
 # This file is part of bartz.
 #
@@ -27,11 +27,8 @@ import functools
 import jax
 import jax.numpy as jnp
 
-from . import jaxext
-
-from . import mcmcstep
-from . import mcmcloop
-from . import prepcovars
+from . import grove, jaxext, mcmcloop, mcmcstep, prepcovars
+
 
 class gbart:
     """
@@ -53,10 +50,11 @@ class gbart:
         Whether to use predictors quantiles instead of a uniform grid to bin
         predictors.
     sigest : float, optional
-        An estimate of the residual standard deviation on `y_train`, used to
-
-
-        is set to
+        An estimate of the residual standard deviation on `y_train`, used to set
+        `lamda`. If not specified, it is estimated by linear regression (with
+        intercept, and without taking into account `w`). If `y_train` has less
+        than two elements, it is set to 1. If n <= p, it is set to the standard
+        deviation of `y_train`. Ignored if `lamda` is specified.
     sigdf : int, default 3
         The degrees of freedom of the scaled inverse-chisquared prior on the
         noise variance.
@@ -82,6 +80,12 @@ class gbart:
     offset : float, optional
         The prior mean of the latent mean function. If not specified, it is set
         to the mean of `y_train`. If `y_train` is empty, it is set to 0.
+    w : array (n,), optional
+        Coefficients that rescale the error standard deviation on each
+        datapoint. Not specifying `w` is equivalent to setting it to 1 for all
+        datapoints. Note: `w` is ignored in the automatic determination of
+        `sigest`, so either the weights should be O(1), or `sigest` should be
+        specified by the user.
     ntree : int, default 200
         The number of trees used to represent the latent mean function.
     numcut : int, default 255
@@ -108,6 +112,8 @@ class gbart:
         The number of iterations (including skipped ones) between each log.
     seed : int or jax random key, default 0
         The seed for the random number generator.
+    initkw : dict
+        Additional arguments passed to `mcmcstep.init`.
 
     Attributes
     ----------
@@ -135,8 +141,6 @@ class gbart:
         The number of trees.
     maxdepth : int
         The maximum depth of the trees.
-    initkw : dict
-        Additional arguments passed to `mcmcstep.init`.
 
     Methods
     -------
@@ -158,10 +162,13 @@ class gbart:
     - A lot of functionality is missing (variable selection, discrete response).
     - There are some additional attributes, and some missing.
 
-    The linear regression used to set `sigest` adds an intercept.
     """
 
-    def __init__(
+    def __init__(
+        self,
+        x_train,
+        y_train,
+        *,
         x_test=None,
         usequants=False,
         sigest=None,
@@ -173,6 +180,7 @@ class gbart:
         maxdepth=6,
         lamda=None,
         offset=None,
+        w=None,
         ntree=200,
         numcut=255,
         ndpost=1000,
@@ -180,26 +188,41 @@ class gbart:
         keepevery=1,
         printevery=100,
         seed=0,
-        initkw=
-
-
+        initkw=None,
+    ):
        x_train, x_train_fmt = self._process_predictor_input(x_train)
-
-        y_train, y_train_fmt = self._process_response_input(y_train)
+        y_train, _ = self._process_response_input(y_train)
         self._check_same_length(x_train, y_train)
+        if w is not None:
+            w, _ = self._process_response_input(w)
+            self._check_same_length(x_train, w)
 
         offset = self._process_offset_settings(y_train, offset)
         scale = self._process_scale_settings(y_train, k)
-        lamda, sigest = self._process_noise_variance_settings(
+        lamda, sigest = self._process_noise_variance_settings(
+            x_train, y_train, sigest, sigdf, sigquant, lamda, offset
+        )
 
         splits, max_split = self._determine_splits(x_train, usequants, numcut)
         x_train = self._bin_predictors(x_train, splits)
-
-
-
-
-
-
+        y_train, lamda_scaled = self._transform_input(y_train, lamda, offset, scale)
+
+        mcmc_state = self._setup_mcmc(
+            x_train,
+            y_train,
+            w,
+            max_split,
+            lamda_scaled,
+            sigdf,
+            power,
+            base,
+            maxdepth,
+            ntree,
+            initkw,
+        )
+        final_state, burnin_trace, main_trace = self._run_mcmc(
+            mcmc_state, ndpost, nskip, keepevery, printevery, seed
+        )
 
         sigma = self._extract_sigma(main_trace, scale)
         first_sigma = self._extract_sigma(burnin_trace, scale)
```
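The rewritten constructor above makes everything after `y_train` keyword-only and threads the new per-datapoint error-scale weights `w` through to `mcmcstep.init`. A minimal, hedged usage sketch of this interface (the data, sizes, and MCMC settings below are made up; predictors use the `(p, n)` layout described in the docstring):

```python
# Hedged sketch: calling the 0.5.0 gbart interface with the new `w` argument.
# Data and sizes are invented; predictors use the (p, n) layout from the docstring.
import jax
import jax.numpy as jnp

from bartz.BART import gbart

kx, ky = jax.random.split(jax.random.key(0))
p, n = 5, 100
x_train = jax.random.normal(kx, (p, n))
y_train = x_train[0] + 0.1 * jax.random.normal(ky, (n,))
w = jnp.ones(n)  # error standard deviation multipliers, one per datapoint

fit = gbart(
    x_train,
    y_train,
    w=w,                      # new in this release
    ntree=20,
    ndpost=100,
    nskip=50,
    seed=jax.random.key(42),  # an integer seed works too
)
# posterior summaries are exposed as attributes (see the Attributes section above)
```

Passing a JAX random key as `seed` relies on the key-detection branch shown further down in `_run_mcmc`.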
{bartz-0.4.1 → bartz-0.5.0}/src/bartz/BART.py (continued)

```diff
@@ -239,7 +262,7 @@ class gbart:
 
         Parameters
         ----------
-        x_test : array (
+        x_test : array (p, m) or DataFrame
             The test predictors.
 
         Returns
@@ -285,7 +308,9 @@ class gbart:
         assert get_length(x1) == get_length(x2)
 
     @staticmethod
-    def _process_noise_variance_settings(
+    def _process_noise_variance_settings(
+        x_train, y_train, sigest, sigdf, sigquant, lamda, offset
+    ):
         if lamda is not None:
             return lamda, None
         else:
@@ -298,7 +323,7 @@ class gbart:
             else:
                 x_centered = x_train.T - x_train.mean(axis=1)
                 y_centered = y_train - y_train.mean()
-
+                # centering is equivalent to adding an intercept column
                 _, chisq, rank, _ = jnp.linalg.lstsq(x_centered, y_centered)
                 chisq = chisq.squeeze(0)
                 dof = len(y_train) - rank
```
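The comment added above notes that centering the predictors and response is equivalent to including an intercept in the least-squares fit used for `sigest`; a small standalone check of that equivalence (arbitrary numbers, not bartz code):

```python
# Standalone numerical check: the least-squares residual after centering X and y
# equals the one obtained with an explicit intercept column.
import jax.numpy as jnp

X = jnp.array([[1.0, 2.0], [2.0, 1.0], [3.0, 5.0], [4.0, 3.0]])
y = jnp.array([1.0, 0.0, 2.0, 1.5])

Xc = X - X.mean(axis=0)
yc = y - y.mean()
_, rss_centered, _, _ = jnp.linalg.lstsq(Xc, yc)

X1 = jnp.concatenate([jnp.ones((X.shape[0], 1)), X], axis=1)
_, rss_intercept, _, _ = jnp.linalg.lstsq(X1, y)

print(rss_centered, rss_intercept)  # the two residual sums of squares agree
```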
{bartz-0.4.1 → bartz-0.5.0}/src/bartz/BART.py (continued)

```diff
@@ -336,11 +361,25 @@ class gbart:
         return prepcovars.bin_predictors(x, splits)
 
     @staticmethod
-    def _transform_input(y, offset, scale):
-
+    def _transform_input(y, lamda, offset, scale):
+        y = (y - offset) / scale
+        lamda = lamda / (scale * scale)
+        return y, lamda
 
     @staticmethod
-    def _setup_mcmc(
+    def _setup_mcmc(
+        x_train,
+        y_train,
+        w,
+        max_split,
+        lamda,
+        sigdf,
+        power,
+        base,
+        maxdepth,
+        ntree,
+        initkw,
+    ):
         depth = jnp.arange(maxdepth - 1)
         p_nonterminal = base / (1 + depth).astype(float) ** power
         sigma2_alpha = sigdf / 2
```
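For reference, the `p_nonterminal` line in the hunk above is the usual BART depth prior (commonly written with parameters α and β), with `base` and `power` in those roles:

$$
P(\text{a node at depth } d \text{ is non-terminal}) = \frac{\text{base}}{(1 + d)^{\text{power}}},
\qquad d = 0, 1, \dots, \text{maxdepth} - 2.
$$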
{bartz-0.4.1 → bartz-0.5.0}/src/bartz/BART.py (continued)

```diff
@@ -348,6 +387,7 @@ class gbart:
         kw = dict(
             X=x_train,
             y=y_train,
+            error_scale=w,
             max_split=max_split,
             num_trees=ntree,
             p_nonterminal=p_nonterminal,
@@ -355,17 +395,20 @@ class gbart:
             sigma2_beta=sigma2_beta,
             min_points_per_leaf=5,
         )
-
+        if initkw is not None:
+            kw.update(initkw)
         return mcmcstep.init(**kw)
 
     @staticmethod
     def _run_mcmc(mcmc_state, ndpost, nskip, keepevery, printevery, seed):
-        if isinstance(seed, jax.Array) and jnp.issubdtype(
+        if isinstance(seed, jax.Array) and jnp.issubdtype(
+            seed.dtype, jax.dtypes.prng_key
+        ):
             key = seed
         else:
             key = jax.random.key(seed)
         callback = mcmcloop.make_simple_print_callback(printevery)
-        return mcmcloop.run_mcmc(mcmc_state, nskip, ndpost, keepevery, callback
+        return mcmcloop.run_mcmc(key, mcmc_state, nskip, ndpost, keepevery, callback)
 
     @staticmethod
     def _predict(trace, x):
```
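As the `_run_mcmc` hunk shows, `mcmcloop.run_mcmc` now takes the PRNG key as its first positional argument. A hedged sketch of driving the loop directly, mirroring that call (the MCMC state construction is elided, and `nskip`, `ndpost`, `keepevery` are placeholders):

```python
# Hedged sketch (not from the package): the key-first run_mcmc signature used
# by gbart._run_mcmc in this diff.
import jax

from bartz import mcmcloop

key = jax.random.key(0)
callback = mcmcloop.make_simple_print_callback(100)  # log every 100 iterations

# With `state` built by mcmcstep.init(...) as in gbart._setup_mcmc:
# final_state, burnin_trace, main_trace = mcmcloop.run_mcmc(
#     key, state, nskip, ndpost, keepevery, callback
# )
```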
{bartz-0.4.1 → bartz-0.5.0}/src/bartz/BART.py (continued)

```diff
@@ -379,9 +422,9 @@ class gbart:
     def _extract_sigma(trace, scale):
         return scale * jnp.sqrt(trace['sigma2'])
 
-
     def _show_tree(self, i_sample, i_tree, print_all=False):
         from . import debug
+
         trace = self._main_trace
         leaf_tree = trace['leaf_trees'][i_sample, i_tree]
         var_tree = trace['var_trees'][i_sample, i_tree]
@@ -396,7 +439,9 @@ class gbart:
         else:
             resid = bart['resid']
             alpha = bart['sigma2_alpha'] + resid.size / 2
-            norm2 = jnp.dot(
+            norm2 = jnp.dot(
+                resid, resid, preferred_element_type=bart['sigma2_beta'].dtype
+            )
             beta = bart['sigma2_beta'] + norm2 / 2
             sigma2 = beta / alpha
         return jnp.sqrt(sigma2) * self.scale
```
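The `preferred_element_type` argument in the `jnp.dot` call above makes the accumulation happen in the dtype of `sigma2_beta` rather than in the residuals' dtype; a tiny standalone illustration (unrelated to bartz data):

```python
# Standalone illustration of preferred_element_type: accumulate a float16
# dot product in a wider dtype than the inputs.
import jax.numpy as jnp

resid = jnp.ones(1000, dtype=jnp.float16)
norm2 = jnp.dot(resid, resid, preferred_element_type=jnp.float32)
print(norm2.dtype, norm2)  # float32 1000.0
```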
{bartz-0.4.1 → bartz-0.5.0}/src/bartz/BART.py (continued)

```diff
@@ -404,22 +449,32 @@ class gbart:
     def _compare_resid(self):
         bart = self._mcmc_state
         resid1 = bart['resid']
-        yhat = grove.evaluate_forest(
+        yhat = grove.evaluate_forest(
+            bart['X'],
+            bart['leaf_trees'],
+            bart['var_trees'],
+            bart['split_trees'],
+            jnp.float32,
+        )
         resid2 = bart['y'] - yhat
         return resid1, resid2
 
     def _avg_acc(self):
         trace = self._main_trace
+
         def acc(prefix):
             acc = trace[f'{prefix}_acc_count']
             prop = trace[f'{prefix}_prop_count']
             return acc.sum() / prop.sum()
+
         return acc('grow'), acc('prune')
 
     def _avg_prop(self):
         trace = self._main_trace
+
         def prop(prefix):
             return trace[f'{prefix}_prop_count'].sum()
+
         pgrow = prop('grow')
         pprune = prop('prune')
         total = pgrow + pprune
@@ -432,16 +487,21 @@ class gbart:
 
     def _depth_distr(self):
         from . import debug
+
         trace = self._main_trace
         split_trees = trace['split_trees']
         return debug.trace_depth_distr(split_trees)
 
     def _points_per_leaf_distr(self):
         from . import debug
-
+
+        return debug.trace_points_per_leaf_distr(
+            self._main_trace, self._mcmc_state['X']
+        )
 
     def _check_trees(self):
         from . import debug
+
         return debug.check_trace(self._main_trace, self._mcmc_state)
 
     def _tree_goes_bad(self):
```
{bartz-0.4.1 → bartz-0.5.0}/src/bartz/__init__.py

```diff
@@ -1,6 +1,6 @@
 # bartz/src/bartz/__init__.py
 #
-# Copyright (c) 2024, Giacomo Petrillo
+# Copyright (c) 2024-2025, Giacomo Petrillo
 #
 # This file is part of bartz.
 #
@@ -28,13 +28,5 @@ Super-fast BART (Bayesian Additive Regression Trees) in Python
 See the manual at https://gattocrucco.github.io/bartz/docs
 """
 
-from .
-
-from . import BART
-
-from . import debug
-from . import grove
-from . import mcmcstep
-from . import mcmcloop
-from . import prepcovars
-from . import jaxext
+from . import BART, debug, grove, jaxext, mcmcloop, mcmcstep, prepcovars  # noqa: F401
+from ._version import __version__  # noqa: F401
```

bartz-0.5.0/src/bartz/_version.py

```diff
@@ -0,0 +1 @@
+__version__ = '0.5.0'
```
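With `__init__.py` re-exporting `__version__` from `_version.py`, the installed version can be read directly:

```python
# The version string recorded in src/bartz/_version.py is re-exported at the
# package top level by the import shown above.
import bartz

print(bartz.__version__)  # '0.5.0'
```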
{bartz-0.4.1 → bartz-0.5.0}/src/bartz/debug.py

```diff
@@ -1,21 +1,19 @@
 import functools
 
 import jax
-from jax import numpy as jnp
 from jax import lax
+from jax import numpy as jnp
 
-from . import grove
-from . import mcmcstep
-from . import jaxext
+from . import grove, jaxext
 
-def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
 
+def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
     tee = '├──'
     corner = '└──'
     join = '│ '
     space = ' '
     down = '┐'
-    bottom = '╢'
+    bottom = '╢' # '┨' #
 
     def traverse_tree(index, depth, indent, first_indent, next_indent, unused):
         if index >= len(leaf_tree):
@@ -58,7 +56,7 @@ def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
 
         indent += next_indent
         unused = unused or is_leaf
-
+
         if unused and not print_all:
             return
 
@@ -67,58 +65,80 @@ def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
 
     traverse_tree(1, 0, '', '', '', False)
 
+
 def tree_actual_depth(split_tree):
     is_leaf = grove.is_actual_leaf(split_tree, add_bottom_level=True)
     depth = grove.tree_depths(is_leaf.size)
     depth = jnp.where(is_leaf, depth, 0)
     return jnp.max(depth)
 
+
 def forest_depth_distr(split_trees):
     depth = grove.tree_depth(split_trees) + 1
     depths = jax.vmap(tree_actual_depth)(split_trees)
     return jnp.bincount(depths, length=depth)
 
+
 def trace_depth_distr(split_trees_trace):
     return jax.vmap(forest_depth_distr)(split_trees_trace)
 
+
 def points_per_leaf_distr(var_tree, split_tree, X):
     traverse_tree = jax.vmap(grove.traverse_tree, in_axes=(1, None, None))
     indices = traverse_tree(X, var_tree, split_tree)
-    count_tree = jnp.zeros(
+    count_tree = jnp.zeros(
+        2 * split_tree.size, dtype=jaxext.minimal_unsigned_dtype(indices.size)
+    )
     count_tree = count_tree.at[indices].add(1)
     is_leaf = grove.is_actual_leaf(split_tree, add_bottom_level=True).view(jnp.uint8)
     return jnp.bincount(count_tree, is_leaf, length=X.shape[1] + 1)
 
+
 def forest_points_per_leaf_distr(bart, X):
     distr = jnp.zeros(X.shape[1] + 1, int)
     trees = bart['var_trees'], bart['split_trees']
+
     def loop(distr, tree):
         return distr + points_per_leaf_distr(*tree, X), None
+
     distr, _ = lax.scan(loop, distr, trees)
     return distr
 
+
 def trace_points_per_leaf_distr(bart, X):
     def loop(_, bart):
         return None, forest_points_per_leaf_distr(bart, X)
+
     _, distr = lax.scan(loop, None, bart)
     return distr
 
+
 def check_types(leaf_tree, var_tree, split_tree, max_split):
     expected_var_dtype = jaxext.minimal_unsigned_dtype(max_split.size - 1)
     expected_split_dtype = max_split.dtype
-    return
+    return (
+        var_tree.dtype == expected_var_dtype
+        and split_tree.dtype == expected_split_dtype
+    )
+
 
 def check_sizes(leaf_tree, var_tree, split_tree, max_split):
     return leaf_tree.size == 2 * var_tree.size == 2 * split_tree.size
 
+
 def check_unused_node(leaf_tree, var_tree, split_tree, max_split):
     return (var_tree[0] == 0) & (split_tree[0] == 0)
 
+
 def check_leaf_values(leaf_tree, var_tree, split_tree, max_split):
     return jnp.all(jnp.isfinite(leaf_tree))
 
+
 def check_stray_nodes(leaf_tree, var_tree, split_tree, max_split):
-    index = jnp.arange(
+    index = jnp.arange(
+        2 * split_tree.size,
+        dtype=jaxext.minimal_unsigned_dtype(2 * split_tree.size - 1),
+    )
     parent_index = index >> 1
     is_not_leaf = split_tree.at[index].get(mode='fill', fill_value=0) != 0
     parent_is_leaf = split_tree[parent_index] == 0
@@ -126,6 +146,7 @@ def check_stray_nodes(leaf_tree, var_tree, split_tree, max_split):
     stray = stray.at[1].set(False)
     return ~jnp.any(stray)
 
+
 check_functions = [
     check_types,
     check_sizes,
@@ -134,6 +155,7 @@ check_functions = [
     check_stray_nodes,
 ]
 
+
 def check_tree(leaf_tree, var_tree, split_tree, max_split):
     error_type = jaxext.minimal_unsigned_dtype(2 ** len(check_functions) - 1)
     error = error_type(0)
@@ -144,15 +166,19 @@ def check_tree(leaf_tree, var_tree, split_tree, max_split):
         error |= bit
     return error
 
+
 def describe_error(error):
-    return [
-
-        for i, func in enumerate(check_functions)
-        if error & (1 << i)
-    ]
+    return [func.__name__ for i, func in enumerate(check_functions) if error & (1 << i)]
+
 
 check_forest = jax.vmap(check_tree, in_axes=(0, 0, 0, None))
 
+
 @functools.partial(jax.vmap, in_axes=(0, None))
 def check_trace(trace, state):
-    return check_forest(
+    return check_forest(
+        trace['leaf_trees'],
+        trace['var_trees'],
+        trace['split_trees'],
+        state['max_split'],
+    )
```
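`check_tree` packs one bit per entry of `check_functions` into an unsigned error code, and `describe_error` decodes that code back into the names of the failing checks. A small hedged illustration with a made-up error value:

```python
# Hedged illustration of the bit-flag convention in bartz.debug: bit i of the
# error code corresponds to check_functions[i]; the value below is invented.
from bartz import debug

error = 0b00101  # hypothetical code with bits 0 and 2 set
print(debug.describe_error(error))  # names of the checks whose bits are set
```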