bartz 0.4.0__tar.gz → 0.5.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz-0.5.0/PKG-INFO +48 -0
- bartz-0.5.0/README.md +31 -0
- {bartz-0.4.0 → bartz-0.5.0}/pyproject.toml +68 -38
- bartz-0.5.0/src/bartz/.DS_Store +0 -0
- {bartz-0.4.0 → bartz-0.5.0}/src/bartz/BART.py +99 -39
- {bartz-0.4.0 → bartz-0.5.0}/src/bartz/__init__.py +6 -14
- bartz-0.5.0/src/bartz/_version.py +1 -0
- {bartz-0.4.0 → bartz-0.5.0}/src/bartz/debug.py +42 -16
- {bartz-0.4.0 → bartz-0.5.0}/src/bartz/grove.py +20 -11
- {bartz-0.4.0 → bartz-0.5.0}/src/bartz/jaxext.py +44 -38
- {bartz-0.4.0 → bartz-0.5.0}/src/bartz/mcmcloop.py +119 -58
- {bartz-0.4.0 → bartz-0.5.0}/src/bartz/mcmcstep.py +426 -173
- {bartz-0.4.0 → bartz-0.5.0}/src/bartz/prepcovars.py +22 -9
- bartz-0.4.0/LICENSE +0 -21
- bartz-0.4.0/PKG-INFO +0 -77
- bartz-0.4.0/README.md +0 -54
- bartz-0.4.0/src/bartz/_version.py +0 -1
bartz-0.5.0/PKG-INFO
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: bartz
|
|
3
|
+
Version: 0.5.0
|
|
4
|
+
Summary: Super-fast BART (Bayesian Additive Regression Trees) in Python
|
|
5
|
+
Author: Giacomo Petrillo
|
|
6
|
+
Author-email: Giacomo Petrillo <info@giacomopetrillo.com>
|
|
7
|
+
License-Expression: MIT
|
|
8
|
+
Requires-Dist: jax>=0.4.35,<1
|
|
9
|
+
Requires-Dist: jaxlib>=0.4.35,<1
|
|
10
|
+
Requires-Dist: numpy>=1.25.2,<3
|
|
11
|
+
Requires-Dist: scipy>=1.11.4,<2
|
|
12
|
+
Requires-Python: >=3.10
|
|
13
|
+
Project-URL: Documentation, https://gattocrucco.github.io/bartz/docs-dev
|
|
14
|
+
Project-URL: Homepage, https://github.com/Gattocrucco/bartz
|
|
15
|
+
Project-URL: Issues, https://github.com/Gattocrucco/bartz/issues
|
|
16
|
+
Description-Content-Type: text/markdown
|
|
17
|
+
|
|
18
|
+
[PyPI](https://pypi.org/project/bartz/)
|
|
19
|
+
[DOI](https://doi.org/10.5281/zenodo.13931477)
|
|
20
|
+
|
|
21
|
+
# BART vectoriZed
|
|
22
|
+
|
|
23
|
+
An implementation of Bayesian Additive Regression Trees (BART) in JAX.
|
|
24
|
+
|
|
25
|
+
If you don't know what BART is, but know XGBoost, consider BART as a sort of Bayesian XGBoost. bartz makes BART run as fast as XGBoost.
|
|
26
|
+
|
|
27
|
+
BART is a nonparametric Bayesian regression technique. Given training predictors $X$ and responses $y$, BART finds a function to predict $y$ given $X$. The result of the inference is a sample of possible functions, representing the uncertainty over the determination of the function.
|
|
28
|
+
|
|
29
|
+
This Python module provides an implementation of BART that runs on GPU, to process large datasets faster. It is also good on CPU. Most other implementations of BART are for R, and run on CPU only.
|
|
30
|
+
|
|
31
|
+
On CPU, bartz runs at the speed of dbarts (the fastest implementation I know of) if n > 20,000, while using 1/20 of the memory. On GPU, the speedup depends on the sample size; running on GPU is faster than on CPU only for n > 10,000. The maximum speedup is currently 200x, on an Nvidia A100 and with at least 2,000,000 observations.
|
|
32
|
+
|
|
33
|
+
[This Colab notebook](https://colab.research.google.com/github/Gattocrucco/bartz/blob/main/docs/examples/basic_simdata.ipynb) runs bartz with n = 100,000 observations, p = 1000 predictors, 10,000 trees, for 1000 MCMC iterations, in 5 minutes.
|
|
34
|
+
|
|
35
|
+
## Links
|
|
36
|
+
|
|
37
|
+
- [Documentation (latest release)](https://gattocrucco.github.io/bartz/docs)
|
|
38
|
+
- [Documentation (development version)](https://gattocrucco.github.io/bartz/docs-dev)
|
|
39
|
+
- [Repository](https://github.com/Gattocrucco/bartz)
|
|
40
|
+
- [Code coverage](https://gattocrucco.github.io/bartz/coverage)
|
|
41
|
+
- [Benchmarks](https://gattocrucco.github.io/bartz/benchmarks)
|
|
42
|
+
- [List of BART packages](https://gattocrucco.github.io/bartz/docs-dev/pkglist.html)
|
|
43
|
+
|
|
44
|
+
## Citing bartz
|
|
45
|
+
|
|
46
|
+
Article: Petrillo (2024), "Very fast Bayesian Additive Regression Trees on GPU", [arXiv:2410.23244](https://arxiv.org/abs/2410.23244).
|
|
47
|
+
|
|
48
|
+
To cite the software directly, including the specific version, use [zenodo](https://doi.org/10.5281/zenodo.13931477).
|
bartz-0.5.0/README.md
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
[PyPI](https://pypi.org/project/bartz/)
|
|
2
|
+
[DOI](https://doi.org/10.5281/zenodo.13931477)
|
|
3
|
+
|
|
4
|
+
# BART vectoriZed
|
|
5
|
+
|
|
6
|
+
An implementation of Bayesian Additive Regression Trees (BART) in JAX.
|
|
7
|
+
|
|
8
|
+
If you don't know what BART is, but know XGBoost, consider BART as a sort of Bayesian XGBoost. bartz makes BART run as fast as XGBoost.
|
|
9
|
+
|
|
10
|
+
BART is a nonparametric Bayesian regression technique. Given training predictors $X$ and responses $y$, BART finds a function to predict $y$ given $X$. The result of the inference is a sample of possible functions, representing the uncertainty over the determination of the function.
|
|
11
|
+
|
|
12
|
+
This Python module provides an implementation of BART that runs on GPU, to process large datasets faster. It is also good on CPU. Most other implementations of BART are for R, and run on CPU only.
|
|
13
|
+
|
|
14
|
+
On CPU, bartz runs at the speed of dbarts (the fastest implementation I know of) if n > 20,000, while using 1/20 of the memory. On GPU, the speedup depends on the sample size; running on GPU is faster than on CPU only for n > 10,000. The maximum speedup is currently 200x, on an Nvidia A100 and with at least 2,000,000 observations.
|
|
15
|
+
|
|
16
|
+
[This Colab notebook](https://colab.research.google.com/github/Gattocrucco/bartz/blob/main/docs/examples/basic_simdata.ipynb) runs bartz with n = 100,000 observations, p = 1000 predictors, 10,000 trees, for 1000 MCMC iterations, in 5 minutes.
|
|
17
|
+
|
|
18
|
+
## Links
|
|
19
|
+
|
|
20
|
+
- [Documentation (latest release)](https://gattocrucco.github.io/bartz/docs)
|
|
21
|
+
- [Documentation (development version)](https://gattocrucco.github.io/bartz/docs-dev)
|
|
22
|
+
- [Repository](https://github.com/Gattocrucco/bartz)
|
|
23
|
+
- [Code coverage](https://gattocrucco.github.io/bartz/coverage)
|
|
24
|
+
- [Benchmarks](https://gattocrucco.github.io/bartz/benchmarks)
|
|
25
|
+
- [List of BART packages](https://gattocrucco.github.io/bartz/docs-dev/pkglist.html)
|
|
26
|
+
|
|
27
|
+
## Citing bartz
|
|
28
|
+
|
|
29
|
+
Article: Petrillo (2024), "Very fast Bayesian Additive Regression Trees on GPU", [arXiv:2410.23244](https://arxiv.org/abs/2410.23244).
|
|
30
|
+
|
|
31
|
+
To cite the software directly, including the specific version, use [zenodo](https://doi.org/10.5281/zenodo.13931477).
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
# bartz/pyproject.toml
|
|
2
2
|
#
|
|
3
|
-
# Copyright (c) 2024, Giacomo Petrillo
|
|
3
|
+
# Copyright (c) 2024-2025, Giacomo Petrillo
|
|
4
4
|
#
|
|
5
5
|
# This file is part of bartz.
|
|
6
6
|
#
|
|
@@ -23,51 +23,58 @@
|
|
|
23
23
|
# SOFTWARE.
|
|
24
24
|
|
|
25
25
|
[build-system]
|
|
26
|
-
requires = ["
|
|
27
|
-
build-backend = "
|
|
26
|
+
requires = ["uv_build>=0.7.3,<0.8.0"]
|
|
27
|
+
build-backend = "uv_build"
|
|
28
28
|
|
|
29
|
-
[
|
|
29
|
+
[project]
|
|
30
30
|
name = "bartz"
|
|
31
|
-
version = "0.
|
|
32
|
-
description = "
|
|
33
|
-
authors = [
|
|
31
|
+
version = "0.5.0"
|
|
32
|
+
description = "Super-fast BART (Bayesian Additive Regression Trees) in Python"
|
|
33
|
+
authors = [
|
|
34
|
+
{name = "Giacomo Petrillo", email = "info@giacomopetrillo.com"},
|
|
35
|
+
]
|
|
34
36
|
license = "MIT"
|
|
35
37
|
readme = "README.md"
|
|
36
|
-
|
|
38
|
+
requires-python = ">=3.10"
|
|
37
39
|
packages = [
|
|
38
40
|
{ include = "bartz", from = "src" },
|
|
39
41
|
]
|
|
42
|
+
dependencies = [
|
|
43
|
+
"jax >=0.4.35,<1",
|
|
44
|
+
"jaxlib >=0.4.35,<1",
|
|
45
|
+
"numpy >=1.25.2,<3",
|
|
46
|
+
"scipy >=1.11.4,<2",
|
|
47
|
+
]
|
|
48
|
+
|
|
49
|
+
[project.urls]
|
|
50
|
+
Homepage = "https://github.com/Gattocrucco/bartz"
|
|
51
|
+
Documentation = "https://gattocrucco.github.io/bartz/docs-dev"
|
|
52
|
+
Issues = "https://github.com/Gattocrucco/bartz/issues"
|
|
40
53
|
|
|
41
|
-
[
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
pytest = "^8.1.1"
|
|
66
|
-
|
|
67
|
-
[tool.poetry.group.docs.dependencies]
|
|
68
|
-
Sphinx = "^7.2.6"
|
|
69
|
-
numpydoc = "^1.6.0"
|
|
70
|
-
myst-parser = "^2.0.0"
|
|
54
|
+
[dependency-groups]
|
|
55
|
+
only-local = [
|
|
56
|
+
"appnope>=0.1.4",
|
|
57
|
+
"ipython>=8.36.0",
|
|
58
|
+
"matplotlib>=3.10.3",
|
|
59
|
+
"matplotlib-label-lines>=0.8.1",
|
|
60
|
+
"polars[pandas,pyarrow]>=1.29.0",
|
|
61
|
+
"pre-commit>=4.2.0",
|
|
62
|
+
"ruff>=0.11.9",
|
|
63
|
+
"scikit-learn>=1.6.1",
|
|
64
|
+
"tomli>=2.2.1",
|
|
65
|
+
"virtualenv>=20.31.2",
|
|
66
|
+
"xgboost>=3.0.0",
|
|
67
|
+
]
|
|
68
|
+
ci = [
|
|
69
|
+
"asv>=0.6.4",
|
|
70
|
+
"coverage>=7.8.0",
|
|
71
|
+
"myst-parser>=4.0.1",
|
|
72
|
+
"numpydoc>=1.8.0",
|
|
73
|
+
"packaging>=25.0",
|
|
74
|
+
"pytest>=8.3.5",
|
|
75
|
+
"pytest-timeout>=2.4.0",
|
|
76
|
+
"sphinx>=8.1.3",
|
|
77
|
+
]
|
|
71
78
|
|
|
72
79
|
[tool.pytest.ini_options]
|
|
73
80
|
testpaths = ["tests"]
|
|
@@ -79,6 +86,8 @@ addopts = [
|
|
|
79
86
|
"--pdbcls=IPython.terminal.debugger:TerminalPdb",
|
|
80
87
|
"--durations=3",
|
|
81
88
|
]
|
|
89
|
+
timeout = 32
|
|
90
|
+
timeout_method = "thread" # when jax hangs, signals do not work
|
|
82
91
|
|
|
83
92
|
# I wanted to use `--import-mode=importlib`, but it breaks importing submodules,
|
|
84
93
|
# in particular `from . import util`.
|
|
@@ -117,3 +126,24 @@ local = [
|
|
|
117
126
|
'/opt/hostedtoolcache/Python/*/*/lib/python*/site-packages/bartz/',
|
|
118
127
|
'C:\hostedtoolcache\windows\Python\*\*\Lib\site-packages\bartz\',
|
|
119
128
|
]
|
|
129
|
+
|
|
130
|
+
[tool.ruff]
|
|
131
|
+
exclude = [".asv", "*.ipynb"]
|
|
132
|
+
|
|
133
|
+
[tool.ruff.format]
|
|
134
|
+
quote-style = "single"
|
|
135
|
+
|
|
136
|
+
[tool.ruff.lint]
|
|
137
|
+
select = [
|
|
138
|
+
"B", # bugbear: grab bag of additional stuff
|
|
139
|
+
"UP", # pyupgrade: fix some outdated idioms
|
|
140
|
+
"I", # isort: sort and reformat import statements
|
|
141
|
+
"F", # flake8
|
|
142
|
+
]
|
|
143
|
+
ignore = [
|
|
144
|
+
"B028", # warn with stacklevel = 2
|
|
145
|
+
]
|
|
146
|
+
|
|
147
|
+
[tool.uv]
|
|
148
|
+
python-downloads = "never"
|
|
149
|
+
python-preference = "only-system"
|
|
Binary file
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
# bartz/src/bartz/BART.py
|
|
2
2
|
#
|
|
3
|
-
# Copyright (c) 2024, Giacomo Petrillo
|
|
3
|
+
# Copyright (c) 2024-2025, Giacomo Petrillo
|
|
4
4
|
#
|
|
5
5
|
# This file is part of bartz.
|
|
6
6
|
#
|
|
@@ -27,11 +27,8 @@ import functools
|
|
|
27
27
|
import jax
|
|
28
28
|
import jax.numpy as jnp
|
|
29
29
|
|
|
30
|
-
from . import jaxext
|
|
31
|
-
|
|
32
|
-
from . import mcmcstep
|
|
33
|
-
from . import mcmcloop
|
|
34
|
-
from . import prepcovars
|
|
30
|
+
from . import grove, jaxext, mcmcloop, mcmcstep, prepcovars
|
|
31
|
+
|
|
35
32
|
|
|
36
33
|
class gbart:
|
|
37
34
|
"""
|
|
@@ -53,10 +50,11 @@ class gbart:
|
|
|
53
50
|
Whether to use predictors quantiles instead of a uniform grid to bin
|
|
54
51
|
predictors.
|
|
55
52
|
sigest : float, optional
|
|
56
|
-
An estimate of the residual standard deviation on `y_train`, used to
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
is set to
|
|
53
|
+
An estimate of the residual standard deviation on `y_train`, used to set
|
|
54
|
+
`lamda`. If not specified, it is estimated by linear regression (with
|
|
55
|
+
intercept, and without taking into account `w`). If `y_train` has less
|
|
56
|
+
than two elements, it is set to 1. If n <= p, it is set to the standard
|
|
57
|
+
deviation of `y_train`. Ignored if `lamda` is specified.
|
|
60
58
|
sigdf : int, default 3
|
|
61
59
|
The degrees of freedom of the scaled inverse-chisquared prior on the
|
|
62
60
|
noise variance.
|
|
@@ -82,6 +80,12 @@ class gbart:
|
|
|
82
80
|
offset : float, optional
|
|
83
81
|
The prior mean of the latent mean function. If not specified, it is set
|
|
84
82
|
to the mean of `y_train`. If `y_train` is empty, it is set to 0.
|
|
83
|
+
w : array (n,), optional
|
|
84
|
+
Coefficients that rescale the error standard deviation on each
|
|
85
|
+
datapoint. Not specifying `w` is equivalent to setting it to 1 for all
|
|
86
|
+
datapoints. Note: `w` is ignored in the automatic determination of
|
|
87
|
+
`sigest`, so either the weights should be O(1), or `sigest` should be
|
|
88
|
+
specified by the user.
|
|
85
89
|
ntree : int, default 200
|
|
86
90
|
The number of trees used to represent the latent mean function.
|
|
87
91
|
numcut : int, default 255
|
|
@@ -108,6 +112,8 @@ class gbart:
|
|
|
108
112
|
The number of iterations (including skipped ones) between each log.
|
|
109
113
|
seed : int or jax random key, default 0
|
|
110
114
|
The seed for the random number generator.
|
|
115
|
+
initkw : dict
|
|
116
|
+
Additional arguments passed to `mcmcstep.init`.
|
|
111
117
|
|
|
112
118
|
Attributes
|
|
113
119
|
----------
|
|
@@ -135,8 +141,6 @@ class gbart:
|
|
|
135
141
|
The number of trees.
|
|
136
142
|
maxdepth : int
|
|
137
143
|
The maximum depth of the trees.
|
|
138
|
-
initkw : dict
|
|
139
|
-
Additional arguments passed to `mcmcstep.init`.
|
|
140
144
|
|
|
141
145
|
Methods
|
|
142
146
|
-------
|
|
@@ -158,10 +162,13 @@ class gbart:
|
|
|
158
162
|
- A lot of functionality is missing (variable selection, discrete response).
|
|
159
163
|
- There are some additional attributes, and some missing.
|
|
160
164
|
|
|
161
|
-
The linear regression used to set `sigest` adds an intercept.
|
|
162
165
|
"""
|
|
163
166
|
|
|
164
|
-
def __init__(
|
|
167
|
+
def __init__(
|
|
168
|
+
self,
|
|
169
|
+
x_train,
|
|
170
|
+
y_train,
|
|
171
|
+
*,
|
|
165
172
|
x_test=None,
|
|
166
173
|
usequants=False,
|
|
167
174
|
sigest=None,
|
|
@@ -173,6 +180,7 @@ class gbart:
|
|
|
173
180
|
maxdepth=6,
|
|
174
181
|
lamda=None,
|
|
175
182
|
offset=None,
|
|
183
|
+
w=None,
|
|
176
184
|
ntree=200,
|
|
177
185
|
numcut=255,
|
|
178
186
|
ndpost=1000,
|
|
@@ -180,26 +188,41 @@ class gbart:
|
|
|
180
188
|
keepevery=1,
|
|
181
189
|
printevery=100,
|
|
182
190
|
seed=0,
|
|
183
|
-
initkw=
|
|
184
|
-
|
|
185
|
-
|
|
191
|
+
initkw=None,
|
|
192
|
+
):
|
|
186
193
|
x_train, x_train_fmt = self._process_predictor_input(x_train)
|
|
187
|
-
|
|
188
|
-
y_train, y_train_fmt = self._process_response_input(y_train)
|
|
194
|
+
y_train, _ = self._process_response_input(y_train)
|
|
189
195
|
self._check_same_length(x_train, y_train)
|
|
196
|
+
if w is not None:
|
|
197
|
+
w, _ = self._process_response_input(w)
|
|
198
|
+
self._check_same_length(x_train, w)
|
|
190
199
|
|
|
191
200
|
offset = self._process_offset_settings(y_train, offset)
|
|
192
201
|
scale = self._process_scale_settings(y_train, k)
|
|
193
|
-
lamda, sigest = self._process_noise_variance_settings(
|
|
202
|
+
lamda, sigest = self._process_noise_variance_settings(
|
|
203
|
+
x_train, y_train, sigest, sigdf, sigquant, lamda, offset
|
|
204
|
+
)
|
|
194
205
|
|
|
195
206
|
splits, max_split = self._determine_splits(x_train, usequants, numcut)
|
|
196
207
|
x_train = self._bin_predictors(x_train, splits)
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
208
|
+
y_train, lamda_scaled = self._transform_input(y_train, lamda, offset, scale)
|
|
209
|
+
|
|
210
|
+
mcmc_state = self._setup_mcmc(
|
|
211
|
+
x_train,
|
|
212
|
+
y_train,
|
|
213
|
+
w,
|
|
214
|
+
max_split,
|
|
215
|
+
lamda_scaled,
|
|
216
|
+
sigdf,
|
|
217
|
+
power,
|
|
218
|
+
base,
|
|
219
|
+
maxdepth,
|
|
220
|
+
ntree,
|
|
221
|
+
initkw,
|
|
222
|
+
)
|
|
223
|
+
final_state, burnin_trace, main_trace = self._run_mcmc(
|
|
224
|
+
mcmc_state, ndpost, nskip, keepevery, printevery, seed
|
|
225
|
+
)
|
|
203
226
|
|
|
204
227
|
sigma = self._extract_sigma(main_trace, scale)
|
|
205
228
|
first_sigma = self._extract_sigma(burnin_trace, scale)
|
|
@@ -239,7 +262,7 @@ class gbart:
|
|
|
239
262
|
|
|
240
263
|
Parameters
|
|
241
264
|
----------
|
|
242
|
-
x_test : array (
|
|
265
|
+
x_test : array (p, m) or DataFrame
|
|
243
266
|
The test predictors.
|
|
244
267
|
|
|
245
268
|
Returns
|
|
@@ -285,7 +308,9 @@ class gbart:
|
|
|
285
308
|
assert get_length(x1) == get_length(x2)
|
|
286
309
|
|
|
287
310
|
@staticmethod
|
|
288
|
-
def _process_noise_variance_settings(
|
|
311
|
+
def _process_noise_variance_settings(
|
|
312
|
+
x_train, y_train, sigest, sigdf, sigquant, lamda, offset
|
|
313
|
+
):
|
|
289
314
|
if lamda is not None:
|
|
290
315
|
return lamda, None
|
|
291
316
|
else:
|
|
@@ -298,7 +323,7 @@ class gbart:
|
|
|
298
323
|
else:
|
|
299
324
|
x_centered = x_train.T - x_train.mean(axis=1)
|
|
300
325
|
y_centered = y_train - y_train.mean()
|
|
301
|
-
|
|
326
|
+
# centering is equivalent to adding an intercept column
|
|
302
327
|
_, chisq, rank, _ = jnp.linalg.lstsq(x_centered, y_centered)
|
|
303
328
|
chisq = chisq.squeeze(0)
|
|
304
329
|
dof = len(y_train) - rank
|
|
@@ -336,11 +361,25 @@ class gbart:
|
|
|
336
361
|
return prepcovars.bin_predictors(x, splits)
|
|
337
362
|
|
|
338
363
|
@staticmethod
|
|
339
|
-
def _transform_input(y, offset, scale):
|
|
340
|
-
|
|
364
|
+
def _transform_input(y, lamda, offset, scale):
|
|
365
|
+
y = (y - offset) / scale
|
|
366
|
+
lamda = lamda / (scale * scale)
|
|
367
|
+
return y, lamda
|
|
341
368
|
|
|
342
369
|
@staticmethod
|
|
343
|
-
def _setup_mcmc(
|
|
370
|
+
def _setup_mcmc(
|
|
371
|
+
x_train,
|
|
372
|
+
y_train,
|
|
373
|
+
w,
|
|
374
|
+
max_split,
|
|
375
|
+
lamda,
|
|
376
|
+
sigdf,
|
|
377
|
+
power,
|
|
378
|
+
base,
|
|
379
|
+
maxdepth,
|
|
380
|
+
ntree,
|
|
381
|
+
initkw,
|
|
382
|
+
):
|
|
344
383
|
depth = jnp.arange(maxdepth - 1)
|
|
345
384
|
p_nonterminal = base / (1 + depth).astype(float) ** power
|
|
346
385
|
sigma2_alpha = sigdf / 2
|
|
@@ -348,6 +387,7 @@ class gbart:
|
|
|
348
387
|
kw = dict(
|
|
349
388
|
X=x_train,
|
|
350
389
|
y=y_train,
|
|
390
|
+
error_scale=w,
|
|
351
391
|
max_split=max_split,
|
|
352
392
|
num_trees=ntree,
|
|
353
393
|
p_nonterminal=p_nonterminal,
|
|
@@ -355,17 +395,20 @@ class gbart:
|
|
|
355
395
|
sigma2_beta=sigma2_beta,
|
|
356
396
|
min_points_per_leaf=5,
|
|
357
397
|
)
|
|
358
|
-
|
|
398
|
+
if initkw is not None:
|
|
399
|
+
kw.update(initkw)
|
|
359
400
|
return mcmcstep.init(**kw)
|
|
360
401
|
|
|
361
402
|
@staticmethod
|
|
362
403
|
def _run_mcmc(mcmc_state, ndpost, nskip, keepevery, printevery, seed):
|
|
363
|
-
if isinstance(seed, jax.Array) and jnp.issubdtype(
|
|
404
|
+
if isinstance(seed, jax.Array) and jnp.issubdtype(
|
|
405
|
+
seed.dtype, jax.dtypes.prng_key
|
|
406
|
+
):
|
|
364
407
|
key = seed
|
|
365
408
|
else:
|
|
366
409
|
key = jax.random.key(seed)
|
|
367
410
|
callback = mcmcloop.make_simple_print_callback(printevery)
|
|
368
|
-
return mcmcloop.run_mcmc(mcmc_state, nskip, ndpost, keepevery, callback
|
|
411
|
+
return mcmcloop.run_mcmc(key, mcmc_state, nskip, ndpost, keepevery, callback)
|
|
369
412
|
|
|
370
413
|
@staticmethod
|
|
371
414
|
def _predict(trace, x):
|
|
@@ -379,9 +422,9 @@ class gbart:
|
|
|
379
422
|
def _extract_sigma(trace, scale):
|
|
380
423
|
return scale * jnp.sqrt(trace['sigma2'])
|
|
381
424
|
|
|
382
|
-
|
|
383
425
|
def _show_tree(self, i_sample, i_tree, print_all=False):
|
|
384
426
|
from . import debug
|
|
427
|
+
|
|
385
428
|
trace = self._main_trace
|
|
386
429
|
leaf_tree = trace['leaf_trees'][i_sample, i_tree]
|
|
387
430
|
var_tree = trace['var_trees'][i_sample, i_tree]
|
|
@@ -396,7 +439,9 @@ class gbart:
|
|
|
396
439
|
else:
|
|
397
440
|
resid = bart['resid']
|
|
398
441
|
alpha = bart['sigma2_alpha'] + resid.size / 2
|
|
399
|
-
norm2 = jnp.dot(
|
|
442
|
+
norm2 = jnp.dot(
|
|
443
|
+
resid, resid, preferred_element_type=bart['sigma2_beta'].dtype
|
|
444
|
+
)
|
|
400
445
|
beta = bart['sigma2_beta'] + norm2 / 2
|
|
401
446
|
sigma2 = beta / alpha
|
|
402
447
|
return jnp.sqrt(sigma2) * self.scale
|
|
@@ -404,22 +449,32 @@ class gbart:
|
|
|
404
449
|
def _compare_resid(self):
|
|
405
450
|
bart = self._mcmc_state
|
|
406
451
|
resid1 = bart['resid']
|
|
407
|
-
yhat = grove.evaluate_forest(
|
|
452
|
+
yhat = grove.evaluate_forest(
|
|
453
|
+
bart['X'],
|
|
454
|
+
bart['leaf_trees'],
|
|
455
|
+
bart['var_trees'],
|
|
456
|
+
bart['split_trees'],
|
|
457
|
+
jnp.float32,
|
|
458
|
+
)
|
|
408
459
|
resid2 = bart['y'] - yhat
|
|
409
460
|
return resid1, resid2
|
|
410
461
|
|
|
411
462
|
def _avg_acc(self):
|
|
412
463
|
trace = self._main_trace
|
|
464
|
+
|
|
413
465
|
def acc(prefix):
|
|
414
466
|
acc = trace[f'{prefix}_acc_count']
|
|
415
467
|
prop = trace[f'{prefix}_prop_count']
|
|
416
468
|
return acc.sum() / prop.sum()
|
|
469
|
+
|
|
417
470
|
return acc('grow'), acc('prune')
|
|
418
471
|
|
|
419
472
|
def _avg_prop(self):
|
|
420
473
|
trace = self._main_trace
|
|
474
|
+
|
|
421
475
|
def prop(prefix):
|
|
422
476
|
return trace[f'{prefix}_prop_count'].sum()
|
|
477
|
+
|
|
423
478
|
pgrow = prop('grow')
|
|
424
479
|
pprune = prop('prune')
|
|
425
480
|
total = pgrow + pprune
|
|
@@ -432,16 +487,21 @@ class gbart:
|
|
|
432
487
|
|
|
433
488
|
def _depth_distr(self):
|
|
434
489
|
from . import debug
|
|
490
|
+
|
|
435
491
|
trace = self._main_trace
|
|
436
492
|
split_trees = trace['split_trees']
|
|
437
493
|
return debug.trace_depth_distr(split_trees)
|
|
438
494
|
|
|
439
495
|
def _points_per_leaf_distr(self):
|
|
440
496
|
from . import debug
|
|
441
|
-
|
|
497
|
+
|
|
498
|
+
return debug.trace_points_per_leaf_distr(
|
|
499
|
+
self._main_trace, self._mcmc_state['X']
|
|
500
|
+
)
|
|
442
501
|
|
|
443
502
|
def _check_trees(self):
|
|
444
503
|
from . import debug
|
|
504
|
+
|
|
445
505
|
return debug.check_trace(self._main_trace, self._mcmc_state)
|
|
446
506
|
|
|
447
507
|
def _tree_goes_bad(self):
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
# bartz/src/bartz/__init__.py
|
|
2
2
|
#
|
|
3
|
-
# Copyright (c) 2024, Giacomo Petrillo
|
|
3
|
+
# Copyright (c) 2024-2025, Giacomo Petrillo
|
|
4
4
|
#
|
|
5
5
|
# This file is part of bartz.
|
|
6
6
|
#
|
|
@@ -10,10 +10,10 @@
|
|
|
10
10
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
11
11
|
# copies of the Software, and to permit persons to whom the Software is
|
|
12
12
|
# furnished to do so, subject to the following conditions:
|
|
13
|
-
#
|
|
13
|
+
#
|
|
14
14
|
# The above copyright notice and this permission notice shall be included in all
|
|
15
15
|
# copies or substantial portions of the Software.
|
|
16
|
-
#
|
|
16
|
+
#
|
|
17
17
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
18
18
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
19
19
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
@@ -23,18 +23,10 @@
|
|
|
23
23
|
# SOFTWARE.
|
|
24
24
|
|
|
25
25
|
"""
|
|
26
|
-
|
|
26
|
+
Super-fast BART (Bayesian Additive Regression Trees) in Python
|
|
27
27
|
|
|
28
28
|
See the manual at https://gattocrucco.github.io/bartz/docs
|
|
29
29
|
"""
|
|
30
30
|
|
|
31
|
-
from .
|
|
32
|
-
|
|
33
|
-
from . import BART
|
|
34
|
-
|
|
35
|
-
from . import debug
|
|
36
|
-
from . import grove
|
|
37
|
-
from . import mcmcstep
|
|
38
|
-
from . import mcmcloop
|
|
39
|
-
from . import prepcovars
|
|
40
|
-
from . import jaxext
|
|
31
|
+
from . import BART, debug, grove, jaxext, mcmcloop, mcmcstep, prepcovars # noqa: F401
|
|
32
|
+
from ._version import __version__ # noqa: F401
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = '0.5.0'
|