bartz 0.5.0__tar.gz → 0.6.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {bartz-0.5.0 → bartz-0.6.0}/PKG-INFO +7 -5
- {bartz-0.5.0 → bartz-0.6.0}/pyproject.toml +48 -12
- {bartz-0.5.0 → bartz-0.6.0}/src/bartz/BART.py +196 -103
- {bartz-0.5.0 → bartz-0.6.0}/src/bartz/__init__.py +1 -1
- bartz-0.6.0/src/bartz/_version.py +1 -0
- {bartz-0.5.0 → bartz-0.6.0}/src/bartz/debug.py +1 -1
- {bartz-0.5.0 → bartz-0.6.0}/src/bartz/grove.py +43 -2
- {bartz-0.5.0 → bartz-0.6.0}/src/bartz/jaxext.py +82 -33
- bartz-0.6.0/src/bartz/mcmcloop.py +511 -0
- bartz-0.6.0/src/bartz/mcmcstep.py +2335 -0
- {bartz-0.5.0 → bartz-0.6.0}/src/bartz/prepcovars.py +3 -1
- bartz-0.5.0/src/bartz/_version.py +0 -1
- bartz-0.5.0/src/bartz/mcmcloop.py +0 -258
- bartz-0.5.0/src/bartz/mcmcstep.py +0 -1820
- {bartz-0.5.0 → bartz-0.6.0}/README.md +0 -0
- {bartz-0.5.0 → bartz-0.6.0}/src/bartz/.DS_Store +0 -0
--- bartz-0.5.0/PKG-INFO
+++ bartz-0.6.0/PKG-INFO
@@ -1,14 +1,16 @@
 Metadata-Version: 2.4
 Name: bartz
-Version: 0.5.0
+Version: 0.6.0
 Summary: Super-fast BART (Bayesian Additive Regression Trees) in Python
 Author: Giacomo Petrillo
 Author-email: Giacomo Petrillo <info@giacomopetrillo.com>
 License-Expression: MIT
-Requires-Dist:
-Requires-Dist:
-Requires-Dist:
-Requires-Dist:
+Requires-Dist: equinox>=0.12.2
+Requires-Dist: jax>=0.4.35
+Requires-Dist: jaxlib>=0.4.35
+Requires-Dist: jaxtyping>=0.3.2
+Requires-Dist: numpy>=1.25.2
+Requires-Dist: scipy>=1.11.4
 Requires-Python: >=3.10
 Project-URL: Documentation, https://gattocrucco.github.io/bartz/docs-dev
 Project-URL: Homepage, https://github.com/Gattocrucco/bartz
--- bartz-0.5.0/pyproject.toml
+++ bartz-0.6.0/pyproject.toml
@@ -28,7 +28,7 @@ build-backend = "uv_build"
 
 [project]
 name = "bartz"
-version = "0.5.0"
+version = "0.6.0"
 description = "Super-fast BART (Bayesian Additive Regression Trees) in Python"
 authors = [
     {name = "Giacomo Petrillo", email = "info@giacomopetrillo.com"},
@@ -36,14 +36,13 @@ authors = [
 license = "MIT"
 readme = "README.md"
 requires-python = ">=3.10"
-packages = [
-    { include = "bartz", from = "src" },
-]
 dependencies = [
-    "
-    "
-    "
-    "
+    "equinox>=0.12.2",
+    "jax>=0.4.35",
+    "jaxlib>=0.4.35",
+    "jaxtyping>=0.3.2",
+    "numpy>=1.25.2",
+    "scipy>=1.11.4",
 ]
 
 [project.urls]
@@ -57,8 +56,8 @@ only-local = [
     "ipython>=8.36.0",
     "matplotlib>=3.10.3",
     "matplotlib-label-lines>=0.8.1",
-    "polars[pandas,pyarrow]>=1.29.0",
     "pre-commit>=4.2.0",
+    "pydoclint>=0.6.6",
     "ruff>=0.11.9",
     "scikit-learn>=1.6.1",
     "tomli>=2.2.1",
@@ -71,12 +70,15 @@ ci = [
     "myst-parser>=4.0.1",
     "numpydoc>=1.8.0",
     "packaging>=25.0",
+    "polars[pandas,pyarrow]>=1.29.0",
     "pytest>=8.3.5",
     "pytest-timeout>=2.4.0",
     "sphinx>=8.1.3",
+    "sphinx-autodoc-typehints>=3.0.1",
 ]
 
 [tool.pytest.ini_options]
+cache_dir = "config/pytest_cache"
 testpaths = ["tests"]
 filterwarnings = [
     'error:scatter inputs have incompatible types.*',
@@ -85,8 +87,9 @@ addopts = [
     "-r xXfE",
     "--pdbcls=IPython.terminal.debugger:TerminalPdb",
     "--durations=3",
+    "--verbose",
 ]
-timeout =
+timeout = 64
 timeout_method = "thread" # when jax hangs, signals do not work
 
 # I wanted to use `--import-mode=importlib`, but it breaks importing submodules,
@@ -101,6 +104,7 @@ show_missing = true
 
 [tool.coverage.html]
 show_contexts = true
+directory = "_site/coverage"
 
 [tool.coverage.paths]
 # the first path in each list must be the source directory in the machine that's
@@ -129,6 +133,7 @@ local = [
 
 [tool.ruff]
 exclude = [".asv", "*.ipynb"]
+cache-dir = "config/ruff_cache"
 
 [tool.ruff.format]
 quote-style = "single"
@@ -138,12 +143,43 @@ select = [
     "B", # bugbear: grab bag of additional stuff
     "UP", # pyupgrade: fix some outdated idioms
     "I", # isort: sort and reformat import statements
-    "F", #
+    "F", # pyflakes
+    "D", # pydocstyle
+    "PT", # flake8-pytest-style
 ]
 ignore = [
-    "B028",
+    "B028", # warn with stacklevel = 2
+    "D105", # Missing docstring in magic method
+    "F722", # Syntax error in forward annotation. I ignore this because jaxtyping uses strings for shapes instead of for deferred annotations.
+    "F821", # Undefined name. I ignore this because strings in jaxtyping.
+    "UP037", # Remove quotes from type annotation. Ignore because jaxtyping.
 ]
 
+[tool.ruff.lint.per-file-ignores]
+"{config/*,benchmarks/*,docs/*,src/bartz/debug.py,tests/rbartpackages/*,tests/__init__.py}" = [
+    "D100", # Missing docstring in public module
+    "D101", # Missing docstring in public class
+    "D102", # Missing docstring in public method
+    "D103", # Missing docstring in public function
+    "D104", # Missing docstring in public package
+]
+
+[tool.ruff.lint.pydocstyle]
+convention = "numpy"
+ignore-decorators = ["functools.cached_property"]
+
+[tool.pydoclint]
+baseline = "config/pydoclint-baseline.txt"
+auto-regenerate-baseline = true
+arg-type-hints-in-signature = true
+arg-type-hints-in-docstring = false
+check-return-types = false
+check-yield-types = false
+treat-property-methods-as-class-attributes = true
+check-style-mismatch = true
+show-filenames-in-every-violation-message = true
+check-class-attributes = false
+
 [tool.uv]
 python-downloads = "never"
 python-preference = "only-system"
--- bartz-0.5.0/src/bartz/BART.py
+++ bartz-0.6.0/src/bartz/BART.py
@@ -22,13 +22,21 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
+"""Implement a user interface that mimics the R BART package."""
+
 import functools
+import math
+from typing import Any, Literal
 
 import jax
 import jax.numpy as jnp
+from jax.scipy.special import ndtri
+from jaxtyping import Array, Bool, Float, Float32
 
 from . import grove, jaxext, mcmcloop, mcmcstep, prepcovars
 
+FloatLike = float | Float[Any, '']
+
 
 class gbart:
     """
@@ -46,6 +54,9 @@ class gbart:
         The training responses.
     x_test : array (p, m) or DataFrame, optional
         The test predictors.
+    type
+        The type of regression. 'wbart' for continuous regression, 'pbart' for
+        binary regression with probit link.
     usequants : bool, default False
         Whether to use predictors quantiles instead of a uniform grid to bin
         predictors.
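As a usage illustration of the new `type` parameter (a hypothetical sketch, not from the package docs: the data and keyword values are made up, while the signature follows the diff below):

    import numpy as np
    from bartz.BART import gbart

    # toy data: p=2 predictors on the rows, n=100 observations
    x = np.random.default_rng(0).uniform(size=(2, 100))
    y = x[0] + x[1] > 1.0  # boolean responses, as 'pbart' requires
    fit = gbart(x, y, type='pbart', ndpost=100, nskip=100, printevery=None)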
@@ -70,16 +81,20 @@ class gbart:
         Parameters of the prior on tree node generation. The probability that a
         node at depth `d` (0-based) is non-terminal is ``base / (1 + d) **
         power``.
-
-    The
-
-
-
-
-
-
+    lamda
+        The prior harmonic mean of the error variance. (The harmonic mean of x
+        is 1/mean(1/x).) If not specified, it is set based on `sigest` and
+        `sigquant`.
+    tau_num
+        The numerator in the expression that determines the prior standard
+        deviation of leaves. If not specified, default to ``(max(y_train) -
+        min(y_train)) / 2`` (or 1 if `y_train` has less than two elements) for
+        continuous regression, and 3 for binary regression.
+    offset
         The prior mean of the latent mean function. If not specified, it is set
-        to the mean of `y_train
+        to the mean of `y_train` for continuous regression, and to
+        ``Phi^-1(mean(y_train))`` for binary regression. If `y_train` is empty,
+        `offset` is set to 0.
     w : array (n,), optional
         Coefficients that rescale the error standard deviation on each
         datapoint. Not specifying `w` is equivalent to setting it to 1 for all
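The two prior formulas above can be checked numerically; a minimal standalone sketch with made-up values, mirroring `_process_leaf_sdev_settings` and `_setup_mcmc` later in this diff:

    import math

    # probability that a node at 0-based depth d is non-terminal
    base, power = 0.95, 2.0
    print([base / (1 + d) ** power for d in range(3)])  # [0.95, 0.2375, ~0.1056]

    # prior standard deviation of a leaf: tau_num / (k * sqrt(ntree))
    tau_num, k, ntree = 3.0, 2.0, 200  # tau_num=3 is the binary-regression default
    print(tau_num / (k * math.sqrt(ntree)))  # ~0.106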
@@ -108,12 +123,24 @@ class gbart:
         The number of initial MCMC samples to discard as burn-in.
     keepevery : int, default 1
         The thinning factor for the MCMC samples, after burn-in.
-    printevery : int, default 100
-        The number of iterations (including
+    printevery : int or None, default 100
+        The number of iterations (including thinned-away ones) between each log
+        line. Set to `None` to disable logging.
+
+        `printevery` has a few unexpected side effects. On cpu, interrupting
+        with ^C halts the MCMC only on the next log. And the total number of
+        iterations is a multiple of `printevery`, so if ``nskip + keepevery *
+        ndpost`` is not a multiple of `printevery`, some of the last iterations
+        will not be saved.
     seed : int or jax random key, default 0
         The seed for the random number generator.
-
+    maxdepth : int, default 6
+        The maximum depth of the trees. This is 1-based, so with the default
+        ``maxdepth=6``, the depths of the levels range from 0 to 5.
+    init_kw : dict
         Additional arguments passed to `mcmcstep.init`.
+    run_mcmc_kw : dict
+        Additional arguments passed to `mcmcloop.run_mcmc`.
 
     Attributes
     ----------
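A made-up numeric instance of the saving caveat in the `printevery` docs above:

    nskip, keepevery, ndpost, printevery = 100, 2, 75, 100
    requested = nskip + keepevery * ndpost  # 250 total iterations
    print(requested % printevery)  # 50: not a multiple, so some last draws are lost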
@@ -131,20 +158,8 @@ class gbart:
         The standard deviation of the error in the burn-in phase.
     offset : float
         The prior mean of the latent mean function.
-    scale : float
-        The prior standard deviation of the latent mean function.
-    lamda : float
-        The prior harmonic mean of the error variance.
     sigest : float or None
         The estimated standard deviation of the error used to set `lamda`.
-    ntree : int
-        The number of trees.
-    maxdepth : int
-        The maximum depth of the trees.
-
-    Methods
-    -------
-    predict
 
     Notes
     -----
@@ -153,14 +168,17 @@ class gbart:
 
     - If `x_train` and `x_test` are matrices, they have one predictor per row
       instead of per column.
+    - If `type` is not specified, it is determined solely based on the data type
+      of `y_train`, and not on whether it contains only two unique values.
     - If ``usequants=False``, R BART switches to quantiles anyway if there are
       less predictor values than the required number of bins, while bartz
       always follows the specification.
     - The error variance parameter is called `lamda` instead of `lambda`.
     - `rm_const` is always `False`.
     - The default `numcut` is 255 instead of 100.
-    - A lot of functionality is missing (variable selection
+    - A lot of functionality is missing (e.g., variable selection).
     - There are some additional attributes, and some missing.
+    - The trees have a maximum depth.
 
     """
@@ -170,6 +188,7 @@ class gbart:
         y_train,
         *,
         x_test=None,
+        type: Literal['wbart', 'pbart'] = 'wbart',
         usequants=False,
         sigest=None,
         sigdf=3,
@@ -177,9 +196,9 @@ class gbart:
         k=2,
         power=2,
         base=0.95,
-
-
-        offset=None,
+        lamda: FloatLike | None = None,
+        tau_num: FloatLike | None = None,
+        offset: FloatLike | None = None,
         w=None,
         ntree=200,
         numcut=255,
@@ -188,7 +207,9 @@ class gbart:
         keepevery=1,
         printevery=100,
         seed=0,
-
+        maxdepth=6,
+        init_kw=None,
+        run_mcmc_kw=None,
     ):
         x_train, x_train_fmt = self._process_predictor_input(x_train)
         y_train, _ = self._process_response_input(y_train)
@@ -197,42 +218,41 @@ class gbart:
         w, _ = self._process_response_input(w)
         self._check_same_length(x_train, w)
 
+        y_train = self._process_type_settings(y_train, type, w)
+        # from here onwards, the type is determined by y_train.dtype == bool
         offset = self._process_offset_settings(y_train, offset)
-
-        lamda, sigest = self.
-            x_train, y_train, sigest, sigdf, sigquant, lamda
+        sigma_mu = self._process_leaf_sdev_settings(y_train, k, ntree, tau_num)
+        lamda, sigest = self._process_error_variance_settings(
+            x_train, y_train, sigest, sigdf, sigquant, lamda
         )
 
         splits, max_split = self._determine_splits(x_train, usequants, numcut)
         x_train = self._bin_predictors(x_train, splits)
-        y_train, lamda_scaled = self._transform_input(y_train, lamda, offset, scale)
 
         mcmc_state = self._setup_mcmc(
             x_train,
             y_train,
+            offset,
             w,
             max_split,
-
+            lamda,
+            sigma_mu,
             sigdf,
             power,
             base,
             maxdepth,
             ntree,
-
+            init_kw,
         )
         final_state, burnin_trace, main_trace = self._run_mcmc(
-            mcmc_state, ndpost, nskip, keepevery, printevery, seed
+            mcmc_state, ndpost, nskip, keepevery, printevery, seed, run_mcmc_kw
         )
 
-        sigma = self._extract_sigma(main_trace
-        first_sigma = self._extract_sigma(burnin_trace
+        sigma = self._extract_sigma(main_trace)
+        first_sigma = self._extract_sigma(burnin_trace)
 
-        self.offset = offset
-        self.scale = scale
-        self.lamda = lamda
+        self.offset = final_state.offset  # from the state because of buffer donation
         self.sigest = sigest
-        self.ntree = ntree
-        self.maxdepth = maxdepth
         self.sigma = sigma
         self.first_sigma = first_sigma
@@ -248,9 +268,8 @@ class gbart:
 
     @functools.cached_property
     def yhat_train(self):
-        x_train = self._mcmc_state
-
-        return self._transform_output(yhat_train, self.offset, self.scale)
+        x_train = self._mcmc_state.X
+        return self._predict(self._main_trace, x_train)
 
     @functools.cached_property
     def yhat_train_mean(self):
@@ -269,12 +288,19 @@ class gbart:
         -------
         yhat_test : array (ndpost, m)
             The conditional posterior mean at `x_test` for each MCMC iteration.
+
+        Raises
+        ------
+        ValueError
+            If `x_test` has a different format than `x_train`.
         """
         x_test, x_test_fmt = self._process_predictor_input(x_test)
-
+        if x_test_fmt != self._x_train_fmt:
+            raise ValueError(
+                f'Input format mismatch: {x_test_fmt=} != x_train_fmt={self._x_train_fmt!r}'
+            )
         x_test = self._bin_predictors(x_test, self._splits)
-
-        return self._transform_output(yhat_test, self.offset, self.scale)
+        return self._predict(self._main_trace, x_test)
 
     @staticmethod
     def _process_predictor_input(x):
@@ -287,10 +313,6 @@ class gbart:
         assert x.ndim == 2
         return x, fmt
 
-    @staticmethod
-    def _check_compatible_formats(fmt1, fmt2):
-        assert fmt1 == fmt2
-
     @staticmethod
     def _process_response_input(y):
         if hasattr(y, 'to_numpy'):
@@ -308,18 +330,26 @@ class gbart:
         assert get_length(x1) == get_length(x2)
 
     @staticmethod
-    def
-        x_train, y_train, sigest, sigdf, sigquant, lamda
-    ):
-        if
+    def _process_error_variance_settings(
+        x_train, y_train, sigest, sigdf, sigquant, lamda
+    ) -> tuple[Float32[Array, ''] | None, ...]:
+        if y_train.dtype == bool:
+            if sigest is not None:
+                raise ValueError('Let `sigest=None` for binary regression')
+            if lamda is not None:
+                raise ValueError('Let `lamda=None` for binary regression')
+            return None, None
+        elif lamda is not None:
+            if sigest is not None:
+                raise ValueError('Let `sigest=None` if `lamda` is specified')
             return lamda, None
         else:
             if sigest is not None:
-                sigest2 = sigest
+                sigest2 = jnp.square(sigest)
             elif y_train.size < 2:
                 sigest2 = 1
             elif y_train.size <= x_train.shape[0]:
-                sigest2 = jnp.var(y_train
+                sigest2 = jnp.var(y_train)
             else:
                 x_centered = x_train.T - x_train.mean(axis=1)
                 y_centered = y_train - y_train.mean()
@@ -334,20 +364,62 @@ class gbart:
         return sigest2 / invchi2rid, jnp.sqrt(sigest2)
 
     @staticmethod
-    def
+    def _process_type_settings(y_train, type, w):
+        match type:
+            case 'wbart':
+                if y_train.dtype != jnp.float32:
+                    raise TypeError(
+                        'Continuous regression requires y_train.dtype=float32,'
+                        f' got {y_train.dtype=} instead.'
+                    )
+            case 'pbart':
+                if w is not None:
+                    raise ValueError(
+                        'Binary regression does not support weights, set `w=None`'
+                    )
+                if y_train.dtype != bool:
+                    raise TypeError(
+                        'Binary regression requires y_train.dtype=bool,'
+                        f' got {y_train.dtype=} instead.'
+                    )
+            case _:
+                raise ValueError(f'Invalid {type=}')
+
+        return y_train
+
+    @staticmethod
+    def _process_offset_settings(
+        y_train: Float32[Array, 'n'] | Bool[Array, 'n'],
+        offset: float | Float32[Any, ''] | None,
+    ) -> Float32[Array, '']:
         if offset is not None:
-            return offset
+            return jnp.asarray(offset)
         elif y_train.size < 1:
-            return 0
+            return jnp.array(0.0)
         else:
-
+            mean = y_train.mean()
 
-
-
-        if y_train.size < 2:
-            return 1
+        if y_train.dtype == bool:
+            return ndtri(mean)
         else:
-            return
+            return mean
+
+    @staticmethod
+    def _process_leaf_sdev_settings(
+        y_train: Float32[Array, 'n'] | Bool[Array, 'n'],
+        k: float,
+        ntree: int,
+        tau_num: FloatLike | None,
+    ):
+        if tau_num is None:
+            if y_train.dtype == bool:
+                tau_num = 3.0
+            elif y_train.size < 2:
+                tau_num = 1.0
+            else:
+                tau_num = (y_train.max() - y_train.min()) / 2
+
+        return tau_num / (k * math.sqrt(ntree))
 
     @staticmethod
     def _determine_splits(x_train, usequants, numcut):
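`ndtri` is the standard normal inverse CDF, so the binary-regression branch above sets the default offset to the probit of the empirical success rate, Phi^-1(mean(y_train)); a quick check with made-up data:

    import jax.numpy as jnp
    from jax.scipy.special import ndtri

    y = jnp.array([True, True, True, False])
    print(ndtri(y.mean()))  # Phi^-1(0.75) ~ 0.674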
@@ -360,67 +432,83 @@ class gbart:
     def _bin_predictors(x, splits):
         return prepcovars.bin_predictors(x, splits)
 
-    @staticmethod
-    def _transform_input(y, lamda, offset, scale):
-        y = (y - offset) / scale
-        lamda = lamda / (scale * scale)
-        return y, lamda
-
     @staticmethod
     def _setup_mcmc(
         x_train,
         y_train,
+        offset,
         w,
         max_split,
         lamda,
+        sigma_mu,
         sigdf,
         power,
         base,
         maxdepth,
         ntree,
-
+        init_kw,
     ):
         depth = jnp.arange(maxdepth - 1)
         p_nonterminal = base / (1 + depth).astype(float) ** power
-
-
+
+        if y_train.dtype == bool:
+            sigma2_alpha = None
+            sigma2_beta = None
+        else:
+            sigma2_alpha = sigdf / 2
+            sigma2_beta = lamda * sigma2_alpha
+
         kw = dict(
             X=x_train,
-
+            # copy y_train because it's going to be donated in the mcmc loop
+            y=jnp.array(y_train),
+            offset=offset,
             error_scale=w,
             max_split=max_split,
             num_trees=ntree,
             p_nonterminal=p_nonterminal,
+            sigma_mu2=jnp.square(sigma_mu),
             sigma2_alpha=sigma2_alpha,
             sigma2_beta=sigma2_beta,
             min_points_per_leaf=5,
         )
-        if
-            kw.update(
+        if init_kw is not None:
+            kw.update(init_kw)
         return mcmcstep.init(**kw)
 
     @staticmethod
-    def _run_mcmc(mcmc_state, ndpost, nskip, keepevery, printevery, seed):
+    def _run_mcmc(mcmc_state, ndpost, nskip, keepevery, printevery, seed, run_mcmc_kw):
         if isinstance(seed, jax.Array) and jnp.issubdtype(
             seed.dtype, jax.dtypes.prng_key
         ):
-            key = seed
+            key = seed.copy()
+            # copy because the inner loop in run_mcmc will donate the buffer
        else:
             key = jax.random.key(seed)
-        callback = mcmcloop.make_simple_print_callback(printevery)
-        return mcmcloop.run_mcmc(key, mcmc_state, nskip, ndpost, keepevery, callback)
 
-
-
-
+        kw = dict(
+            n_burn=nskip,
+            n_skip=keepevery,
+            inner_loop_length=printevery,
+            allow_overflow=True,
+        )
+        if printevery is not None:
+            kw.update(mcmcloop.make_print_callbacks())
+        if run_mcmc_kw is not None:
+            kw.update(run_mcmc_kw)
+
+        return mcmcloop.run_mcmc(key, mcmc_state, ndpost, **kw)
 
     @staticmethod
-    def
-
+    def _extract_sigma(trace) -> Float32[Array, 'trace_length'] | None:
+        if trace['sigma2'] is None:
+            return None
+        else:
+            return jnp.sqrt(trace['sigma2'])
 
     @staticmethod
-    def
-        return
+    def _predict(trace, x):
+        return mcmcloop.evaluate_trace(trace, x)
 
     def _show_tree(self, i_sample, i_tree, print_all=False):
         from . import debug
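Because `run_mcmc_kw` is merged into `kw` last, its entries override the defaults assembled above (e.g. the inner loop length set from `printevery`); a hypothetical call with made-up data and values:

    import numpy as np
    from bartz.BART import gbart

    x = np.random.default_rng(1).uniform(size=(2, 50))
    y = np.float32(x.sum(axis=0))  # float32 responses, as the 'wbart' type check requires
    # keep the default logging, but run shorter jitted inner loops
    fit = gbart(x, y, ndpost=50, nskip=50, run_mcmc_kw=dict(inner_loop_length=25))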
@@ -444,19 +532,26 @@ class gbart:
         )
         beta = bart['sigma2_beta'] + norm2 / 2
         sigma2 = beta / alpha
-        return jnp.sqrt(sigma2)
+        return jnp.sqrt(sigma2)
 
     def _compare_resid(self):
         bart = self._mcmc_state
-        resid1 = bart
-
-
-        bart
-        bart
-        bart
-
+        resid1 = bart.resid
+
+        trees = grove.evaluate_forest(
+            bart.X,
+            bart.forest.leaf_trees,
+            bart.forest.var_trees,
+            bart.forest.split_trees,
+            jnp.float32,  # TODO remove these configurable dtypes around
         )
-
+
+        if bart.z is not None:
+            ref = bart.z
+        else:
+            ref = bart.y
+        resid2 = ref - (trees + bart.offset)
+
         return resid1, resid2
 
     def _avg_acc(self):
@@ -495,9 +590,7 @@ class gbart:
     def _points_per_leaf_distr(self):
         from . import debug
 
-        return debug.trace_points_per_leaf_distr(
-            self._main_trace, self._mcmc_state['X']
-        )
+        return debug.trace_points_per_leaf_distr(self._main_trace, self._mcmc_state.X)
 
     def _check_trees(self):
         from . import debug
--- /dev/null
+++ bartz-0.6.0/src/bartz/_version.py
@@ -0,0 +1 @@
+__version__ = '0.6.0'