bartz 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/BART.py +582 -279
- bartz/__init__.py +3 -3
- bartz/_version.py +1 -1
- bartz/debug.py +1259 -79
- bartz/grove.py +168 -81
- bartz/jaxext/__init__.py +213 -0
- bartz/jaxext/_autobatch.py +238 -0
- bartz/jaxext/scipy/__init__.py +25 -0
- bartz/jaxext/scipy/special.py +240 -0
- bartz/jaxext/scipy/stats.py +36 -0
- bartz/mcmcloop.py +568 -158
- bartz/mcmcstep.py +1722 -926
- bartz/prepcovars.py +142 -44
- {bartz-0.5.0.dist-info → bartz-0.7.0.dist-info}/METADATA +6 -5
- bartz-0.7.0.dist-info/RECORD +17 -0
- {bartz-0.5.0.dist-info → bartz-0.7.0.dist-info}/WHEEL +1 -1
- bartz/jaxext.py +0 -374
- bartz-0.5.0.dist-info/RECORD +0 -13
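
For orientation, a minimal usage sketch of the 0.7.0 `gbart` interface as it appears in the `bartz/BART.py` diff below; the data, sizes, and settings here are illustrative assumptions, not values taken from the package:

import jax
from bartz.BART import gbart

key = jax.random.key(0)
X = jax.random.uniform(key, (5, 100))                      # (p, n) training predictors
y = X.sum(axis=0) + 0.1 * jax.random.normal(key, (100,))   # float32 response -> continuous regression

fit = gbart(X, y, ndpost=200, nskip=100, printevery=None, seed=0)
yhat = fit.yhat_train_mean    # marginal posterior mean of the latent function at X
sig = fit.sigma_mean          # posterior mean of the error standard deviation
draws = fit.predict(X)        # (ndpost, n) conditional posterior mean per MCMC draw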
bartz/BART.py
CHANGED
@@ -22,17 +22,73 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
-
+"""Implement a class `gbart` that mimics the R BART package."""
+
+import math
+from collections.abc import Sequence
+from functools import cached_property
+from typing import Any, Literal, Protocol
 
 import jax
 import jax.numpy as jnp
+from equinox import Module, field
+from jax.scipy.special import ndtr
+from jaxtyping import (
+    Array,
+    Bool,
+    Float,
+    Float32,
+    Int32,
+    Integer,
+    Key,
+    Real,
+    Shaped,
+    UInt,
+)
+from numpy import ndarray
+
+from bartz import mcmcloop, mcmcstep, prepcovars
+from bartz.jaxext.scipy.special import ndtri
+from bartz.jaxext.scipy.stats import invgamma
+
+FloatLike = float | Float[Any, '']
+
+
+class DataFrame(Protocol):
+    """DataFrame duck-type for `gbart`.
+
+    Attributes
+    ----------
+    columns : Sequence[str]
+        The names of the columns.
+    """
+
+    columns: Sequence[str]
+
+    def to_numpy(self) -> ndarray:
+        """Convert the dataframe to a 2d numpy array with columns on the second axis."""
+        ...
 
-from . import grove, jaxext, mcmcloop, mcmcstep, prepcovars
 
+class Series(Protocol):
+    """Series duck-type for `gbart`.
 
-
+    Attributes
+    ----------
+    name : str | None
+        The name of the series.
     """
-
+
+    name: str | None
+
+    def to_numpy(self) -> ndarray:
+        """Convert the series to a 1d numpy array."""
+        ...
+
+
+class gbart(Module):
+    R"""
+    Nonparametric regression with Bayesian Additive Regression Trees (BART) [2]_.
 
     Regress `y_train` on `x_train` with a latent mean function represented as
     a sum of decision trees. The inference is carried out by sampling the
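
The `DataFrame` and `Series` protocols added above are structural: the diff suggests `gbart` does not need a hard pandas dependency, only objects exposing `columns`/`name` and `to_numpy()`. A hypothetical sketch (pandas is used here purely as an example of a conforming object; polars dataframes expose the same attributes):

import numpy as np
import pandas as pd

df = pd.DataFrame(np.random.rand(100, 2), columns=['a', 'b'])      # has .columns and .to_numpy()
ys = pd.Series(np.random.rand(100).astype(np.float32), name='y')   # has .name and .to_numpy()
# gbart(df, ys) reads the predictors as df.to_numpy().T and the response as ys.to_numpy()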
@@ -40,55 +96,108 @@ class gbart:
 
     Parameters
     ----------
-    x_train
+    x_train
         The training predictors.
-    y_train
+    y_train
         The training responses.
-    x_test
+    x_test
         The test predictors.
-
+    type
+        The type of regression. 'wbart' for continuous regression, 'pbart' for
+        binary regression with probit link.
+    sparse
+        Whether to activate variable selection on the predictors as done in
+        [1]_.
+    theta
+    a
+    b
+    rho
+        Hyperparameters of the sparsity prior used for variable selection.
+
+        The prior distribution on the choice of predictor for each decision rule
+        is
+
+        .. math::
+            (s_1, \ldots, s_p) \sim
+            \operatorname{Dirichlet}(\mathtt{theta}/p, \ldots, \mathtt{theta}/p).
+
+        If `theta` is not specified, it's a priori distributed according to
+
+        .. math::
+            \frac{\mathtt{theta}}{\mathtt{theta} + \mathtt{rho}} \sim
+            \operatorname{Beta}(\mathtt{a}, \mathtt{b}).
+
+        If not specified, `rho` is set to the number of predictors p. To tune
+        the prior, consider setting a lower `rho` to prefer more sparsity.
+        If setting `theta` directly, it should be in the ballpark of p or lower
+        as well.
+    xinfo
+        A matrix with the cutpoins to use to bin each predictor. If not
+        specified, it is generated automatically according to `usequants` and
+        `numcut`.
+
+        Each row shall contain a sorted list of cutpoints for a predictor. If
+        there are less cutpoints than the number of columns in the matrix,
+        fill the remaining cells with NaN.
+
+        `xinfo` shall be a matrix even if `x_train` is a dataframe.
+    usequants
         Whether to use predictors quantiles instead of a uniform grid to bin
-        predictors.
-
+        predictors. Ignored if `xinfo` is specified.
+    rm_const
+        How to treat predictors with no associated decision rules (i.e., there
+        are no available cutpoints for that predictor). If `True` (default),
+        they are ignored. If `False`, an error is raised if there are any. If
+        `None`, no check is performed, and the output of the MCMC may not make
+        sense if there are predictors without cutpoints. The option `None` is
+        provided only to allow jax tracing.
+    sigest
         An estimate of the residual standard deviation on `y_train`, used to set
         `lamda`. If not specified, it is estimated by linear regression (with
         intercept, and without taking into account `w`). If `y_train` has less
         than two elements, it is set to 1. If n <= p, it is set to the standard
         deviation of `y_train`. Ignored if `lamda` is specified.
-    sigdf
+    sigdf
         The degrees of freedom of the scaled inverse-chisquared prior on the
         noise variance.
-    sigquant
+    sigquant
         The quantile of the prior on the noise variance that shall match
         `sigest` to set the scale of the prior. Ignored if `lamda` is specified.
-    k
+    k
         The inverse scale of the prior standard deviation on the latent mean
         function, relative to half the observed range of `y_train`. If `y_train`
         has less than two elements, `k` is ignored and the scale is set to 1.
-    power
-    base
+    power
+    base
         Parameters of the prior on tree node generation. The probability that a
         node at depth `d` (0-based) is non-terminal is ``base / (1 + d) **
         power``.
-
-        The
-
-
-
-
-
-
+    lamda
+        The prior harmonic mean of the error variance. (The harmonic mean of x
+        is 1/mean(1/x).) If not specified, it is set based on `sigest` and
+        `sigquant`.
+    tau_num
+        The numerator in the expression that determines the prior standard
+        deviation of leaves. If not specified, default to ``(max(y_train) -
+        min(y_train)) / 2`` (or 1 if `y_train` has less than two elements) for
+        continuous regression, and 3 for binary regression.
+    offset
         The prior mean of the latent mean function. If not specified, it is set
-        to the mean of `y_train
-
+        to the mean of `y_train` for continuous regression, and to
+        ``Phi^-1(mean(y_train))`` for binary regression. If `y_train` is empty,
+        `offset` is set to 0. With binary regression, if `y_train` is all
+        `False` or `True`, it is set to ``Phi^-1(1/(n+1))`` or
+        ``Phi^-1(n/(n+1))``, respectively.
+    w
         Coefficients that rescale the error standard deviation on each
         datapoint. Not specifying `w` is equivalent to setting it to 1 for all
         datapoints. Note: `w` is ignored in the automatic determination of
         `sigest`, so either the weights should be O(1), or `sigest` should be
         specified by the user.
-    ntree
-        The number of trees used to represent the latent mean function.
-
+    ntree
+        The number of trees used to represent the latent mean function. By
+        default 200 for continuous regression and 50 for binary regression.
+    numcut
         If `usequants` is `False`: the exact number of cutpoints used to bin the
         predictors, ranging between the minimum and maximum observed values
         (excluded).
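
The Dirichlet sparsity prior documented in the hunk above is switched on with `sparse=True`, and the selection probabilities are exposed through the `varprob` attributes added later in this diff. A sketch continuing the earlier example (the `X`, `y` arrays and iteration counts are illustrative assumptions):

fit = gbart(X, y, sparse=True, ndpost=500, nskip=500, printevery=None, seed=0)
order = fit.varprob_mean.argsort()[::-1]   # predictors ranked by posterior selection probability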
@@ -101,50 +210,43 @@ class gbart:
 
         Before running the algorithm, the predictors are compressed to the
         smallest integer type that fits the bin indices, so `numcut` is best set
-        to the maximum value of an unsigned integer type.
-
+        to the maximum value of an unsigned integer type, like 255.
+
+        Ignored if `xinfo` is specified.
+    ndpost
         The number of MCMC samples to save, after burn-in.
-    nskip
+    nskip
         The number of initial MCMC samples to discard as burn-in.
-    keepevery
-        The thinning factor for the MCMC samples, after burn-in.
-
-
-
+    keepevery
+        The thinning factor for the MCMC samples, after burn-in. By default, 1
+        for continuous regression and 10 for binary regression.
+    printevery
+        The number of iterations (including thinned-away ones) between each log
+        line. Set to `None` to disable logging.
+
+        `printevery` has a few unexpected side effects. On cpu, interrupting
+        with ^C halts the MCMC only on the next log. And the total number of
+        iterations is a multiple of `printevery`, so if ``nskip + keepevery *
+        ndpost`` is not a multiple of `printevery`, some of the last iterations
+        will not be saved.
+    seed
         The seed for the random number generator.
-
-
+    maxdepth
+        The maximum depth of the trees. This is 1-based, so with the default
+        ``maxdepth=6``, the depths of the levels range from 0 to 5.
+    init_kw
+        Additional arguments passed to `bartz.mcmcstep.init`.
+    run_mcmc_kw
+        Additional arguments passed to `bartz.mcmcloop.run_mcmc`.
 
     Attributes
     ----------
-
-        The conditional posterior mean at `x_train` for each MCMC iteration.
-    yhat_train_mean : array (n,)
-        The marginal posterior mean at `x_train`.
-    yhat_test : array (ndpost, m)
-        The conditional posterior mean at `x_test` for each MCMC iteration.
-    yhat_test_mean : array (m,)
-        The marginal posterior mean at `x_test`.
-    sigma : array (ndpost,)
-        The standard deviation of the error.
-    first_sigma : array (nskip,)
-        The standard deviation of the error in the burn-in phase.
-    offset : float
+    offset : Float32[Array, '']
         The prior mean of the latent mean function.
-
-        The prior standard deviation of the latent mean function.
-    lamda : float
-        The prior harmonic mean of the error variance.
-    sigest : float or None
+    sigest : Float32[Array, ''] | None
         The estimated standard deviation of the error used to set `lamda`.
-
-        The
-    maxdepth : int
-        The maximum depth of the trees.
-
-    Methods
-    -------
-    predict
+    yhat_test : Float32[Array, 'ndpost m'] | None
+        The conditional posterior mean at `x_test` for each MCMC iteration.
 
     Notes
     -----
@@ -156,128 +258,293 @@ class gbart:
     - If ``usequants=False``, R BART switches to quantiles anyway if there are
       less predictor values than the required number of bins, while bartz
       always follows the specification.
+    - Some functionality is missing.
     - The error variance parameter is called `lamda` instead of `lambda`.
-    - `rm_const` is always `False`.
-    - The default `numcut` is 255 instead of 100.
-    - A lot of functionality is missing (variable selection, discrete response).
     - There are some additional attributes, and some missing.
+    - The trees have a maximum depth.
+    - `rm_const` refers to predictors without decision rules instead of
+      predictors that are constant in `x_train`.
+    - If `rm_const=True` and some variables are dropped, the predictors
+      matrix/dataframe passed to `predict` should still include them.
 
+    References
+    ----------
+    .. [1] Linero, Antonio R. (2018). “Bayesian Regression Trees for
+       High-Dimensional Prediction and Variable Selection”. In: Journal of the
+       American Statistical Association 113.522, pp. 626-636.
+    .. [2] Hugh A. Chipman, Edward I. George, Robert E. McCulloch "BART:
+       Bayesian additive regression trees," The Annals of Applied Statistics,
+       Ann. Appl. Stat. 4(1), 266-298, (March 2010).
     """
 
+    _main_trace: mcmcloop.MainTrace
+    _burnin_trace: mcmcloop.BurninTrace
+    _mcmc_state: mcmcstep.State
+    _splits: Real[Array, 'p max_num_splits']
+    _x_train_fmt: Any = field(static=True)
+
+    ndpost: int = field(static=True)
+    offset: Float32[Array, '']
+    sigest: Float32[Array, ''] | None = None
+    yhat_test: Float32[Array, 'ndpost m'] | None = None
+
     def __init__(
         self,
-        x_train,
-        y_train,
+        x_train: Real[Array, 'p n'] | DataFrame,
+        y_train: Bool[Array, ' n'] | Float32[Array, ' n'] | Series,
         *,
-        x_test=None,
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        x_test: Real[Array, 'p m'] | DataFrame | None = None,
+        type: Literal['wbart', 'pbart'] = 'wbart',  # noqa: A002
+        sparse: bool = False,
+        theta: FloatLike | None = None,
+        a: FloatLike = 0.5,
+        b: FloatLike = 1.0,
+        rho: FloatLike | None = None,
+        xinfo: Float[Array, 'p n'] | None = None,
+        usequants: bool = False,
+        rm_const: bool | None = True,
+        sigest: FloatLike | None = None,
+        sigdf: FloatLike = 3.0,
+        sigquant: FloatLike = 0.9,
+        k: FloatLike = 2.0,
+        power: FloatLike = 2.0,
+        base: FloatLike = 0.95,
+        lamda: FloatLike | None = None,
+        tau_num: FloatLike | None = None,
+        offset: FloatLike | None = None,
+        w: Float[Array, ' n'] | None = None,
+        ntree: int | None = None,
+        numcut: int = 100,
+        ndpost: int = 1000,
+        nskip: int = 100,
+        keepevery: int | None = None,
+        printevery: int | None = 100,
+        seed: int | Key[Array, ''] = 0,
+        maxdepth: int = 6,
+        init_kw: dict | None = None,
+        run_mcmc_kw: dict | None = None,
     ):
+        # check data and put it in the right format
         x_train, x_train_fmt = self._process_predictor_input(x_train)
-        y_train
+        y_train = self._process_response_input(y_train)
         self._check_same_length(x_train, y_train)
         if w is not None:
-            w
+            w = self._process_response_input(w)
             self._check_same_length(x_train, w)
 
+        # check data types are correct for continuous/binary regression
+        self._check_type_settings(y_train, type, w)
+        # from here onwards, the type is determined by y_train.dtype == bool
+
+        # set defaults that depend on type of regression
+        if ntree is None:
+            ntree = 50 if y_train.dtype == bool else 200
+        if keepevery is None:
+            keepevery = 10 if y_train.dtype == bool else 1
+
+        # process sparsity settings
+        theta, a, b, rho = self._process_sparsity_settings(
+            x_train, sparse, theta, a, b, rho
+        )
+
+        # process "standardization" settings
         offset = self._process_offset_settings(y_train, offset)
-
-        lamda, sigest = self.
-            x_train, y_train, sigest, sigdf, sigquant, lamda
+        sigma_mu = self._process_leaf_sdev_settings(y_train, k, ntree, tau_num)
+        lamda, sigest = self._process_error_variance_settings(
+            x_train, y_train, sigest, sigdf, sigquant, lamda
         )
 
-
+        # determine splits
+        splits, max_split = self._determine_splits(x_train, usequants, numcut, xinfo)
         x_train = self._bin_predictors(x_train, splits)
-        y_train, lamda_scaled = self._transform_input(y_train, lamda, offset, scale)
 
-
+        # setup and run mcmc
+        initial_state = self._setup_mcmc(
             x_train,
             y_train,
+            offset,
             w,
             max_split,
-
+            lamda,
+            sigma_mu,
             sigdf,
             power,
             base,
             maxdepth,
             ntree,
-
+            init_kw,
+            rm_const,
+            theta,
+            a,
+            b,
+            rho,
         )
         final_state, burnin_trace, main_trace = self._run_mcmc(
-
+            initial_state,
+            ndpost,
+            nskip,
+            keepevery,
+            printevery,
+            seed,
+            run_mcmc_kw,
+            sparse,
         )
 
-
-
-
-        self.offset = offset
-        self.scale = scale
-        self.lamda = lamda
+        # set public attributes
+        self.offset = final_state.offset  # from the state because of buffer donation
+        self.ndpost = ndpost
         self.sigest = sigest
-        self.ntree = ntree
-        self.maxdepth = maxdepth
-        self.sigma = sigma
-        self.first_sigma = first_sigma
 
-
-        self._splits = splits
+        # set private attributes
         self._main_trace = main_trace
+        self._burnin_trace = burnin_trace
         self._mcmc_state = final_state
+        self._splits = splits
+        self._x_train_fmt = x_train_fmt
 
+        # predict at test points
         if x_test is not None:
-            yhat_test = self.predict(x_test)
-
-
+            self.yhat_test = self.predict(x_test)
+
+    @cached_property
+    def prob_test(self) -> Float32[Array, 'ndpost m'] | None:
+        """The posterior probability of y being True at `x_test` for each MCMC iteration."""
+        if self.yhat_test is None or self._mcmc_state.y.dtype != bool:
+            return None
+        else:
+            return ndtr(self.yhat_test)
+
+    @cached_property
+    def prob_test_mean(self) -> Float32[Array, ' m'] | None:
+        """The marginal posterior probability of y being True at `x_test`."""
+        if self.prob_test is None:
+            return None
+        else:
+            return self.prob_test.mean(axis=0)
+
+    @cached_property
+    def prob_train(self) -> Float32[Array, 'ndpost n'] | None:
+        """The posterior probability of y being True at `x_train` for each MCMC iteration."""
+        if self._mcmc_state.y.dtype == bool:
+            return ndtr(self.yhat_train)
+        else:
+            return None
+
+    @cached_property
+    def prob_train_mean(self) -> Float32[Array, ' n'] | None:
+        """The marginal posterior probability of y being True at `x_train`."""
+        if self.prob_train is None:
+            return None
+        else:
+            return self.prob_train.mean(axis=0)
+
+    @cached_property
+    def sigma(self) -> Float32[Array, ' nskip+ndpost'] | None:
+        """The standard deviation of the error, including burn-in samples."""
+        if self._burnin_trace.sigma2 is None:
+            return None
+        else:
+            assert self._main_trace.sigma2 is not None
+            return jnp.sqrt(
+                jnp.concatenate([self._burnin_trace.sigma2, self._main_trace.sigma2])
+            )
+
+    @cached_property
+    def sigma_mean(self) -> Float32[Array, ''] | None:
+        """The mean of `sigma`, only over the post-burnin samples."""
+        if self.sigma is None:
+            return None
+        else:
+            return self.sigma[len(self.sigma) - self.ndpost :].mean(axis=0)
+
+    @cached_property
+    def varcount(self) -> Int32[Array, 'ndpost p']:
+        """Histogram of predictor usage for decision rules in the trees."""
+        return mcmcloop.compute_varcount(
+            self._mcmc_state.forest.max_split.size, self._main_trace
+        )
+
+    @cached_property
+    def varcount_mean(self) -> Float32[Array, ' p']:
+        """Average of `varcount` across MCMC iterations."""
+        return self.varcount.mean(axis=0)
+
+    @cached_property
+    def varprob(self) -> Float32[Array, 'ndpost p']:
+        """Posterior samples of the probability of choosing each predictor for a decision rule."""
+        varprob = self._main_trace.varprob
+        if varprob is None:
+            max_split = self._mcmc_state.forest.max_split
+            p = max_split.size
+            peff = jnp.count_nonzero(max_split)
+            varprob = jnp.where(max_split, 1 / peff, 0)
+            varprob = jnp.broadcast_to(varprob, (self.ndpost, p))
+        return varprob
+
+    @cached_property
+    def varprob_mean(self) -> Float32[Array, ' p']:
+        """The marginal posterior probability of each predictor being chosen for a decision rule."""
+        return self.varprob.mean(axis=0)
+
+    @cached_property
+    def yhat_test_mean(self) -> Float32[Array, ' m'] | None:
+        """The marginal posterior mean at `x_test`.
+
+        Not defined with binary regression because it's error-prone, typically
+        the right thing to consider would be `prob_test_mean`.
+        """
+        if self.yhat_test is None or self._mcmc_state.y.dtype == bool:
+            return None
+        else:
+            return self.yhat_test.mean(axis=0)
+
+    @cached_property
+    def yhat_train(self) -> Float32[Array, 'ndpost n']:
+        """The conditional posterior mean at `x_train` for each MCMC iteration."""
+        x_train = self._mcmc_state.X
+        return self._predict(x_train)
 
-    @
-    def
-
-        yhat_train = self._predict(self._main_trace, x_train)
-        return self._transform_output(yhat_train, self.offset, self.scale)
+    @cached_property
+    def yhat_train_mean(self) -> Float32[Array, ' n'] | None:
+        """The marginal posterior mean at `x_train`.
 
-
-
-
+        Not defined with binary regression because it's error-prone, typically
+        the right thing to consider would be `prob_train_mean`.
+        """
+        if self._mcmc_state.y.dtype == bool:
+            return None
+        else:
+            return self.yhat_train.mean(axis=0)
 
-    def predict(
+    def predict(
+        self, x_test: Real[Array, 'p m'] | DataFrame
+    ) -> Float32[Array, 'ndpost m']:
         """
         Compute the posterior mean at `x_test` for each MCMC iteration.
 
         Parameters
         ----------
-        x_test
+        x_test
             The test predictors.
 
         Returns
         -------
-
-
+        The conditional posterior mean at `x_test` for each MCMC iteration.
+
+        Raises
+        ------
+        ValueError
+            If `x_test` has a different format than `x_train`.
         """
         x_test, x_test_fmt = self._process_predictor_input(x_test)
-
+        if x_test_fmt != self._x_train_fmt:
+            msg = f'Input format mismatch: {x_test_fmt=} != x_train_fmt={self._x_train_fmt!r}'
+            raise ValueError(msg)
         x_test = self._bin_predictors(x_test, self._splits)
-
-        return self._transform_output(yhat_test, self.offset, self.scale)
+        return self._predict(x_test)
 
     @staticmethod
-    def _process_predictor_input(x):
+    def _process_predictor_input(x) -> tuple[Shaped[Array, 'p n'], Any]:
         if hasattr(x, 'columns'):
             fmt = dict(kind='dataframe', columns=x.columns)
             x = x.to_numpy().T
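
The new `prob_*` properties above are the probit transform of the latent function draws (`ndtr` applied to `yhat_*`) and are only defined for binary regression. A sketch with hypothetical data, following the signature added in this diff:

import jax
from bartz.BART import gbart

key = jax.random.key(0)
X = jax.random.uniform(key, (3, 200))
y = X[0] + X[1] > 1.0                      # bool response, so type='pbart'

fit = gbart(X, y, type='pbart', ndpost=200, nskip=100, printevery=None, seed=1)
p_train = fit.prob_train_mean              # posterior mean of P(y=True | x) at the training points
latent = fit.predict(X)                    # latent-scale draws; apply jax.scipy.special.ndtr for probabilities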
@@ -288,19 +555,12 @@ class gbart:
         return x, fmt
 
     @staticmethod
-    def
-        assert fmt1 == fmt2
-
-    @staticmethod
-    def _process_response_input(y):
+    def _process_response_input(y) -> Shaped[Array, ' n']:
         if hasattr(y, 'to_numpy'):
-            fmt = dict(kind='series', name=y.name)
             y = y.to_numpy()
-        else:
-            fmt = dict(kind='array')
         y = jnp.asarray(y)
         assert y.ndim == 1
-        return y
+        return y
 
     @staticmethod
     def _check_same_length(x1, x2):
@@ -308,18 +568,29 @@ class gbart:
         assert get_length(x1) == get_length(x2)
 
     @staticmethod
-    def
-        x_train, y_train, sigest, sigdf, sigquant, lamda
-    ):
-        if
+    def _process_error_variance_settings(
+        x_train, y_train, sigest, sigdf, sigquant, lamda
+    ) -> tuple[Float32[Array, ''] | None, ...]:
+        if y_train.dtype == bool:
+            if sigest is not None:
+                msg = 'Let `sigest=None` for binary regression'
+                raise ValueError(msg)
+            if lamda is not None:
+                msg = 'Let `lamda=None` for binary regression'
+                raise ValueError(msg)
+            return None, None
+        elif lamda is not None:
+            if sigest is not None:
+                msg = 'Let `sigest=None` if `lamda` is specified'
+                raise ValueError(msg)
             return lamda, None
         else:
             if sigest is not None:
-                sigest2 = sigest
+                sigest2 = jnp.square(sigest)
             elif y_train.size < 2:
                 sigest2 = 1
             elif y_train.size <= x_train.shape[0]:
-                sigest2 = jnp.var(y_train
+                sigest2 = jnp.var(y_train)
             else:
                 x_centered = x_train.T - x_train.mean(axis=1)
                 y_centered = y_train - y_train.mean()
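
The `invgamma.ppf` call added in the next hunk implements the quantile matching described in the docstring: `lamda` is chosen so that the scaled inverse-chi-squared prior on the error variance puts probability `sigquant` below `sigest**2`. A standalone sketch of that computation (the function name and standalone framing are mine; the formula mirrors the lines added below):

from bartz.jaxext.scipy.stats import invgamma

def lamda_from_sigest(sigest, sigdf=3.0, sigquant=0.9):
    sigest2 = sigest**2
    alpha = sigdf / 2
    invchi2 = invgamma.ppf(sigquant, alpha) / 2   # sigquant-quantile of sigma^2 / (lamda * sigdf)
    return sigest2 / (invchi2 * sigdf)            # prior harmonic mean of the error variance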
@@ -329,182 +600,214 @@ class gbart:
             dof = len(y_train) - rank
             sigest2 = chisq / dof
             alpha = sigdf / 2
-            invchi2 =
+            invchi2 = invgamma.ppf(sigquant, alpha) / 2
             invchi2rid = invchi2 * sigdf
             return sigest2 / invchi2rid, jnp.sqrt(sigest2)
 
     @staticmethod
-    def
+    def _check_type_settings(y_train, type, w):  # noqa: A002
+        match type:
+            case 'wbart':
+                if y_train.dtype != jnp.float32:
+                    msg = (
+                        'Continuous regression requires y_train.dtype=float32,'
+                        f' got {y_train.dtype=} instead.'
+                    )
+                    raise TypeError(msg)
+            case 'pbart':
+                if w is not None:
+                    msg = 'Binary regression does not support weights, set `w=None`'
+                    raise ValueError(msg)
+                if y_train.dtype != bool:
+                    msg = (
+                        'Binary regression requires y_train.dtype=bool,'
+                        f' got {y_train.dtype=} instead.'
+                    )
+                    raise TypeError(msg)
+            case _:
+                msg = f'Invalid {type=}'
+                raise ValueError(msg)
+
+    @staticmethod
+    def _process_sparsity_settings(
+        x_train: Real[Array, 'p n'],
+        sparse: bool,
+        theta: FloatLike | None,
+        a: FloatLike,
+        b: FloatLike,
+        rho: FloatLike | None,
+    ) -> (
+        tuple[None, None, None, None]
+        | tuple[FloatLike, None, None, None]
+        | tuple[None, FloatLike, FloatLike, FloatLike]
+    ):
+        if not sparse:
+            return None, None, None, None
+        elif theta is not None:
+            return theta, None, None, None
+        else:
+            if rho is None:
+                p, _ = x_train.shape
+                rho = float(p)
+            return None, a, b, rho
+
+    @staticmethod
+    def _process_offset_settings(
+        y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
+        offset: float | Float32[Any, ''] | None,
+    ) -> Float32[Array, '']:
         if offset is not None:
-            return offset
+            return jnp.asarray(offset)
         elif y_train.size < 1:
-            return 0
+            return jnp.array(0.0)
         else:
-
+            mean = y_train.mean()
 
-
-
-
-            return
+            if y_train.dtype == bool:
+                bound = 1 / (1 + y_train.size)
+                mean = jnp.clip(mean, bound, 1 - bound)
+                return ndtri(mean)
             else:
-                return
+                return mean
 
     @staticmethod
-    def
-
+    def _process_leaf_sdev_settings(
+        y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
+        k: float,
+        ntree: int,
+        tau_num: FloatLike | None,
+    ):
+        if tau_num is None:
+            if y_train.dtype == bool:
+                tau_num = 3.0
+            elif y_train.size < 2:
+                tau_num = 1.0
+            else:
+                tau_num = (y_train.max() - y_train.min()) / 2
+
+        return tau_num / (k * math.sqrt(ntree))
+
+    @staticmethod
+    def _determine_splits(
+        x_train: Real[Array, 'p n'],
+        usequants: bool,
+        numcut: int,
+        xinfo: Float[Array, 'p n'] | None,
+    ) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
+        if xinfo is not None:
+            if xinfo.ndim != 2 or xinfo.shape[0] != x_train.shape[0]:
+                msg = f'{xinfo.shape=} different from expected ({x_train.shape[0]}, *)'
+                raise ValueError(msg)
+            return prepcovars.parse_xinfo(xinfo)
+        elif usequants:
             return prepcovars.quantilized_splits_from_matrix(x_train, numcut + 1)
         else:
             return prepcovars.uniform_splits_from_matrix(x_train, numcut + 1)
 
     @staticmethod
-    def _bin_predictors(x, splits):
+    def _bin_predictors(x, splits) -> UInt[Array, 'p n']:
         return prepcovars.bin_predictors(x, splits)
 
-    @staticmethod
-    def _transform_input(y, lamda, offset, scale):
-        y = (y - offset) / scale
-        lamda = lamda / (scale * scale)
-        return y, lamda
-
     @staticmethod
     def _setup_mcmc(
-        x_train,
-        y_train,
-
-
-
-
-
-
-
-
-
+        x_train: Real[Array, 'p n'],
+        y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
+        offset: Float32[Array, ''],
+        w: Float[Array, ' n'] | None,
+        max_split: UInt[Array, ' p'],
+        lamda: Float32[Array, ''] | None,
+        sigma_mu: FloatLike,
+        sigdf: FloatLike,
+        power: FloatLike,
+        base: FloatLike,
+        maxdepth: int,
+        ntree: int,
+        init_kw: dict[str, Any] | None,
+        rm_const: bool | None,
+        theta: FloatLike | None,
+        a: FloatLike | None,
+        b: FloatLike | None,
+        rho: FloatLike | None,
     ):
         depth = jnp.arange(maxdepth - 1)
         p_nonterminal = base / (1 + depth).astype(float) ** power
-
-
+
+        if y_train.dtype == bool:
+            sigma2_alpha = None
+            sigma2_beta = None
+        else:
+            sigma2_alpha = sigdf / 2
+            sigma2_beta = lamda * sigma2_alpha
+
         kw = dict(
             X=x_train,
-
+            # copy y_train because it's going to be donated in the mcmc loop
+            y=jnp.array(y_train),
+            offset=offset,
             error_scale=w,
             max_split=max_split,
             num_trees=ntree,
             p_nonterminal=p_nonterminal,
+            sigma_mu2=jnp.square(sigma_mu),
             sigma2_alpha=sigma2_alpha,
             sigma2_beta=sigma2_beta,
+            min_points_per_decision_node=10,
             min_points_per_leaf=5,
+            theta=theta,
+            a=a,
+            b=b,
+            rho=rho,
         )
-
-
+
+        if rm_const is None:
+            kw.update(filter_splitless_vars=False)
+        elif rm_const:
+            kw.update(filter_splitless_vars=True)
+        else:
+            n_empty = jnp.count_nonzero(max_split == 0)
+            if n_empty:
+                msg = f'There are {n_empty}/{max_split.size} predictors without decision rules'
+                raise ValueError(msg)
+            kw.update(filter_splitless_vars=False)
+
+        if init_kw is not None:
+            kw.update(init_kw)
+
         return mcmcstep.init(**kw)
 
     @staticmethod
-    def _run_mcmc(
+    def _run_mcmc(
+        mcmc_state: mcmcstep.State,
+        ndpost: int,
+        nskip: int,
+        keepevery: int,
+        printevery: int | None,
+        seed: int | Integer[Array, ''] | Key[Array, ''],
+        run_mcmc_kw: dict | None,
+        sparse: bool,
+    ):
+        # prepare random generator seed
         if isinstance(seed, jax.Array) and jnp.issubdtype(
             seed.dtype, jax.dtypes.prng_key
         ):
-            key = seed
+            key = seed.copy()
+            # copy because the inner loop in run_mcmc will donate the buffer
         else:
             key = jax.random.key(seed)
-        callback = mcmcloop.make_simple_print_callback(printevery)
-        return mcmcloop.run_mcmc(key, mcmc_state, nskip, ndpost, keepevery, callback)
 
-
-
-
-
-
-
-
-
-    @staticmethod
-    def _extract_sigma(trace, scale):
-        return scale * jnp.sqrt(trace['sigma2'])
-
-    def _show_tree(self, i_sample, i_tree, print_all=False):
-        from . import debug
-
-        trace = self._main_trace
-        leaf_tree = trace['leaf_trees'][i_sample, i_tree]
-        var_tree = trace['var_trees'][i_sample, i_tree]
-        split_tree = trace['split_trees'][i_sample, i_tree]
-        debug.print_tree(leaf_tree, var_tree, split_tree, print_all)
-
-    def _sigma_harmonic_mean(self, prior=False):
-        bart = self._mcmc_state
-        if prior:
-            alpha = bart['sigma2_alpha']
-            beta = bart['sigma2_beta']
-        else:
-            resid = bart['resid']
-            alpha = bart['sigma2_alpha'] + resid.size / 2
-            norm2 = jnp.dot(
-                resid, resid, preferred_element_type=bart['sigma2_beta'].dtype
+        # prepare arguments
+        kw = dict(n_burn=nskip, n_skip=keepevery, inner_loop_length=printevery)
+        kw.update(
+            mcmcloop.make_default_callback(
+                dot_every=None if printevery is None or printevery == 1 else 1,
+                report_every=printevery,
+                sparse_on_at=nskip // 2 if sparse else None,
             )
-        beta = bart['sigma2_beta'] + norm2 / 2
-        sigma2 = beta / alpha
-        return jnp.sqrt(sigma2) * self.scale
-
-    def _compare_resid(self):
-        bart = self._mcmc_state
-        resid1 = bart['resid']
-        yhat = grove.evaluate_forest(
-            bart['X'],
-            bart['leaf_trees'],
-            bart['var_trees'],
-            bart['split_trees'],
-            jnp.float32,
         )
-
-
-
-    def _avg_acc(self):
-        trace = self._main_trace
-
-        def acc(prefix):
-            acc = trace[f'{prefix}_acc_count']
-            prop = trace[f'{prefix}_prop_count']
-            return acc.sum() / prop.sum()
-
-        return acc('grow'), acc('prune')
-
-    def _avg_prop(self):
-        trace = self._main_trace
-
-        def prop(prefix):
-            return trace[f'{prefix}_prop_count'].sum()
-
-        pgrow = prop('grow')
-        pprune = prop('prune')
-        total = pgrow + pprune
-        return pgrow / total, pprune / total
-
-    def _avg_move(self):
-        agrow, aprune = self._avg_acc()
-        pgrow, pprune = self._avg_prop()
-        return agrow * pgrow, aprune * pprune
-
-    def _depth_distr(self):
-        from . import debug
-
-        trace = self._main_trace
-        split_trees = trace['split_trees']
-        return debug.trace_depth_distr(split_trees)
-
-    def _points_per_leaf_distr(self):
-        from . import debug
-
-        return debug.trace_points_per_leaf_distr(
-            self._main_trace, self._mcmc_state['X']
-        )
-
-    def _check_trees(self):
-        from . import debug
+        if run_mcmc_kw is not None:
+            kw.update(run_mcmc_kw)
 
-        return
+        return mcmcloop.run_mcmc(key, mcmc_state, ndpost, **kw)
 
-    def
-
-        bad_before = jnp.pad(bad[:-1], [(1, 0), (0, 0)])
-        return bad & ~bad_before
+    def _predict(self, x):
+        return mcmcloop.evaluate_trace(self._main_trace, x)