bartz 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/BART.py +464 -254
- bartz/__init__.py +2 -2
- bartz/_version.py +1 -1
- bartz/debug.py +1259 -79
- bartz/grove.py +139 -93
- bartz/jaxext/__init__.py +213 -0
- bartz/jaxext/_autobatch.py +238 -0
- bartz/jaxext/scipy/__init__.py +25 -0
- bartz/jaxext/scipy/special.py +240 -0
- bartz/jaxext/scipy/stats.py +36 -0
- bartz/mcmcloop.py +468 -311
- bartz/mcmcstep.py +734 -453
- bartz/prepcovars.py +139 -43
- {bartz-0.6.0.dist-info → bartz-0.7.0.dist-info}/METADATA +2 -3
- bartz-0.7.0.dist-info/RECORD +17 -0
- {bartz-0.6.0.dist-info → bartz-0.7.0.dist-info}/WHEEL +1 -1
- bartz/jaxext.py +0 -423
- bartz-0.6.0.dist-info/RECORD +0 -13
bartz/BART.py
CHANGED
@@ -22,25 +22,73 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.

-"""Implement a
+"""Implement a class `gbart` that mimics the R BART package."""

-import functools
 import math
-from
+from collections.abc import Sequence
+from functools import cached_property
+from typing import Any, Literal, Protocol

 import jax
 import jax.numpy as jnp
-from
-from
-
-
+from equinox import Module, field
+from jax.scipy.special import ndtr
+from jaxtyping import (
+    Array,
+    Bool,
+    Float,
+    Float32,
+    Int32,
+    Integer,
+    Key,
+    Real,
+    Shaped,
+    UInt,
+)
+from numpy import ndarray
+
+from bartz import mcmcloop, mcmcstep, prepcovars
+from bartz.jaxext.scipy.special import ndtri
+from bartz.jaxext.scipy.stats import invgamma

 FloatLike = float | Float[Any, '']


-class
+class DataFrame(Protocol):
+    """DataFrame duck-type for `gbart`.
+
+    Attributes
+    ----------
+    columns : Sequence[str]
+        The names of the columns.
     """
-
+
+    columns: Sequence[str]
+
+    def to_numpy(self) -> ndarray:
+        """Convert the dataframe to a 2d numpy array with columns on the second axis."""
+        ...
+
+
+class Series(Protocol):
+    """Series duck-type for `gbart`.
+
+    Attributes
+    ----------
+    name : str | None
+        The name of the series.
+    """
+
+    name: str | None
+
+    def to_numpy(self) -> ndarray:
+        """Convert the series to a 1d numpy array."""
+        ...
+
+
+class gbart(Module):
+    R"""
+    Nonparametric regression with Bayesian Additive Regression Trees (BART) [2]_.

     Regress `y_train` on `x_train` with a latent mean function represented as
     a sum of decision trees. The inference is carried out by sampling the
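The new `DataFrame` and `Series` types are `typing.Protocol`s, i.e. structural duck-types: any object exposing the right members qualifies, so pandas or polars objects pass without bartz importing either library. A minimal sketch of the mechanism; `FakeFrame` is hypothetical, and `runtime_checkable` is added only to demonstrate the check (the protocols in the diff don't declare it):

```python
from collections.abc import Sequence
from typing import Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class DataFrame(Protocol):
    """Same shape as the protocol in the diff, made runtime-checkable for the demo."""

    columns: Sequence[str]

    def to_numpy(self) -> np.ndarray: ...


class FakeFrame:
    """Hypothetical stand-in for a pandas/polars dataframe."""

    def __init__(self, data: dict[str, list[float]]):
        self.columns = list(data)
        self._data = data

    def to_numpy(self) -> np.ndarray:
        # columns on the second axis, as the protocol's docstring requires
        return np.array([self._data[c] for c in self.columns]).T


frame = FakeFrame({'x0': [1.0, 2.0], 'x1': [3.0, 4.0]})
assert isinstance(frame, DataFrame)  # structural match, no inheritance involved
```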
@@ -48,36 +96,79 @@ class gbart:

     Parameters
     ----------
-    x_train
+    x_train
         The training predictors.
-    y_train
+    y_train
         The training responses.
-    x_test
+    x_test
         The test predictors.
     type
         The type of regression. 'wbart' for continuous regression, 'pbart' for
         binary regression with probit link.
-
+    sparse
+        Whether to activate variable selection on the predictors as done in
+        [1]_.
+    theta
+    a
+    b
+    rho
+        Hyperparameters of the sparsity prior used for variable selection.
+
+        The prior distribution on the choice of predictor for each decision rule
+        is
+
+        .. math::
+            (s_1, \ldots, s_p) \sim
+                \operatorname{Dirichlet}(\mathtt{theta}/p, \ldots, \mathtt{theta}/p).
+
+        If `theta` is not specified, it's a priori distributed according to
+
+        .. math::
+            \frac{\mathtt{theta}}{\mathtt{theta} + \mathtt{rho}} \sim
+                \operatorname{Beta}(\mathtt{a}, \mathtt{b}).
+
+        If not specified, `rho` is set to the number of predictors p. To tune
+        the prior, consider setting a lower `rho` to prefer more sparsity.
+        If setting `theta` directly, it should be in the ballpark of p or lower
+        as well.
+    xinfo
+        A matrix with the cutpoints to use to bin each predictor. If not
+        specified, it is generated automatically according to `usequants` and
+        `numcut`.
+
+        Each row shall contain a sorted list of cutpoints for a predictor. If
+        there are fewer cutpoints than the number of columns in the matrix,
+        fill the remaining cells with NaN.
+
+        `xinfo` shall be a matrix even if `x_train` is a dataframe.
+    usequants
         Whether to use predictors quantiles instead of a uniform grid to bin
-        predictors.
-
+        predictors. Ignored if `xinfo` is specified.
+    rm_const
+        How to treat predictors with no associated decision rules (i.e., there
+        are no available cutpoints for that predictor). If `True` (default),
+        they are ignored. If `False`, an error is raised if there are any. If
+        `None`, no check is performed, and the output of the MCMC may not make
+        sense if there are predictors without cutpoints. The option `None` is
+        provided only to allow jax tracing.
+    sigest
         An estimate of the residual standard deviation on `y_train`, used to set
         `lamda`. If not specified, it is estimated by linear regression (with
         intercept, and without taking into account `w`). If `y_train` has less
         than two elements, it is set to 1. If n <= p, it is set to the standard
         deviation of `y_train`. Ignored if `lamda` is specified.
-    sigdf
+    sigdf
         The degrees of freedom of the scaled inverse-chisquared prior on the
         noise variance.
-    sigquant
+    sigquant
         The quantile of the prior on the noise variance that shall match
         `sigest` to set the scale of the prior. Ignored if `lamda` is specified.
-    k
+    k
         The inverse scale of the prior standard deviation on the latent mean
         function, relative to half the observed range of `y_train`. If `y_train`
         has less than two elements, `k` is ignored and the scale is set to 1.
-    power
-    base
+    power
+    base
         Parameters of the prior on tree node generation. The probability that a
         node at depth `d` (0-based) is non-terminal is ``base / (1 + d) **
         power``.
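The Dirichlet prior above concentrates on sparse weight vectors as `theta` shrinks, because each of the p components gets concentration `theta/p`. A quick numerical illustration of that effect (all numbers arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
p = 100  # number of predictors

for theta in (float(p), 1.0):
    # (s_1, ..., s_p) ~ Dirichlet(theta/p, ..., theta/p)
    s = rng.dirichlet(np.full(p, theta / p), size=1000)
    # how much predictor-choice probability the 5 largest components carry
    top5 = np.sort(s, axis=1)[:, -5:].sum(axis=1).mean()
    print(f'theta={theta:g}: top-5 share ~ {top5:.0%}')
```

With `theta = p` (uniform Dirichlet) the mass spreads over all predictors; with `theta = 1` most draws put nearly all their mass on a handful of predictors, which is why lowering `theta` or `rho` prefers more sparsity.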
@@ -94,16 +185,19 @@ class gbart:
         The prior mean of the latent mean function. If not specified, it is set
         to the mean of `y_train` for continuous regression, and to
         ``Phi^-1(mean(y_train))`` for binary regression. If `y_train` is empty,
-        `offset` is set to 0.
-
+        `offset` is set to 0. With binary regression, if `y_train` is all
+        `False` or `True`, it is set to ``Phi^-1(1/(n+1))`` or
+        ``Phi^-1(n/(n+1))``, respectively.
+    w
         Coefficients that rescale the error standard deviation on each
         datapoint. Not specifying `w` is equivalent to setting it to 1 for all
         datapoints. Note: `w` is ignored in the automatic determination of
         `sigest`, so either the weights should be O(1), or `sigest` should be
         specified by the user.
-    ntree
-        The number of trees used to represent the latent mean function.
-
+    ntree
+        The number of trees used to represent the latent mean function. By
+        default 200 for continuous regression and 50 for binary regression.
+    numcut
         If `usequants` is `False`: the exact number of cutpoints used to bin the
         predictors, ranging between the minimum and maximum observed values
         (excluded).
@@ -116,14 +210,17 @@ class gbart:

         Before running the algorithm, the predictors are compressed to the
         smallest integer type that fits the bin indices, so `numcut` is best set
-        to the maximum value of an unsigned integer type.
-
+        to the maximum value of an unsigned integer type, like 255.
+
+        Ignored if `xinfo` is specified.
+    ndpost
         The number of MCMC samples to save, after burn-in.
-    nskip
+    nskip
         The number of initial MCMC samples to discard as burn-in.
-    keepevery
-        The thinning factor for the MCMC samples, after burn-in.
-
+    keepevery
+        The thinning factor for the MCMC samples, after burn-in. By default, 1
+        for continuous regression and 10 for binary regression.
+    printevery
         The number of iterations (including thinned-away ones) between each log
         line. Set to `None` to disable logging.

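The advice to set `numcut` to the maximum of an unsigned integer type is about the compression step: after binning, each predictor value is stored as a bin index, and with `numcut=255` the indices fit exactly in one `uint8` byte. A sketch of the idea, not bartz's actual binning code (that lives in `prepcovars`):

```python
import numpy as np

numcut = 255
x = np.random.default_rng(0).normal(size=1000)

# numcut equally spaced cutpoints strictly between min and max
cuts = np.linspace(x.min(), x.max(), numcut + 2)[1:-1]
bins = np.searchsorted(cuts, x)           # bin indices in [0, numcut]
assert bins.max() <= np.iinfo(np.uint8).max
compressed = bins.astype(np.uint8)        # one byte per value instead of eight
```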
@@ -132,34 +229,24 @@ class gbart:
         iterations is a multiple of `printevery`, so if ``nskip + keepevery *
         ndpost`` is not a multiple of `printevery`, some of the last iterations
         will not be saved.
-    seed
+    seed
         The seed for the random number generator.
-    maxdepth
+    maxdepth
         The maximum depth of the trees. This is 1-based, so with the default
         ``maxdepth=6``, the depths of the levels range from 0 to 5.
-    init_kw
-        Additional arguments passed to `mcmcstep.init`.
-    run_mcmc_kw
-        Additional arguments passed to `mcmcloop.run_mcmc`.
+    init_kw
+        Additional arguments passed to `bartz.mcmcstep.init`.
+    run_mcmc_kw
+        Additional arguments passed to `bartz.mcmcloop.run_mcmc`.

     Attributes
     ----------
-
-        The conditional posterior mean at `x_train` for each MCMC iteration.
-    yhat_train_mean : array (n,)
-        The marginal posterior mean at `x_train`.
-    yhat_test : array (ndpost, m)
-        The conditional posterior mean at `x_test` for each MCMC iteration.
-    yhat_test_mean : array (m,)
-        The marginal posterior mean at `x_test`.
-    sigma : array (ndpost,)
-        The standard deviation of the error.
-    first_sigma : array (nskip,)
-        The standard deviation of the error in the burn-in phase.
-    offset : float
+    offset : Float32[Array, '']
         The prior mean of the latent mean function.
-    sigest :
+    sigest : Float32[Array, ''] | None
         The estimated standard deviation of the error used to set `lamda`.
+    yhat_test : Float32[Array, 'ndpost m'] | None
+        The conditional posterior mean at `x_test` for each MCMC iteration.

     Notes
     -----
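The caveat about `printevery` can be made concrete. Assuming the run is trimmed down to a whole number of `printevery`-sized chunks, as the docstring describes (the exact bookkeeping lives in `mcmcloop.run_mcmc`, and the variable names here are mine):

```python
ndpost, nskip, keepevery = 1000, 100, 3
total = nskip + keepevery * ndpost   # 3100 iterations requested

for printevery in (100, 300):
    run = total - total % printevery       # whole chunks only
    kept = (run - nskip) // keepevery      # post-burn-in samples actually saved
    print(f'{printevery=}: run {run} of {total} iterations, keep {kept}/{ndpost}')
```

With `printevery=100` all 3100 iterations run and all 1000 samples are kept; with `printevery=300` only 3000 iterations run and the last samples are dropped.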
@@ -168,68 +255,111 @@ class gbart:

     - If `x_train` and `x_test` are matrices, they have one predictor per row
       instead of per column.
-    - If `type` is not specified, it is determined solely based on the data type
-      of `y_train`, and not on whether it contains only two unique values.
     - If ``usequants=False``, R BART switches to quantiles anyway if there are
       less predictor values than the required number of bins, while bartz
       always follows the specification.
+    - Some functionality is missing.
     - The error variance parameter is called `lamda` instead of `lambda`.
-    - `rm_const` is always `False`.
-    - The default `numcut` is 255 instead of 100.
-    - A lot of functionality is missing (e.g., variable selection).
     - There are some additional attributes, and some missing.
     - The trees have a maximum depth.
+    - `rm_const` refers to predictors without decision rules instead of
+      predictors that are constant in `x_train`.
+    - If `rm_const=True` and some variables are dropped, the predictors
+      matrix/dataframe passed to `predict` should still include them.

+    References
+    ----------
+    .. [1] Linero, Antonio R. (2018). “Bayesian Regression Trees for
+       High-Dimensional Prediction and Variable Selection”. In: Journal of the
+       American Statistical Association 113.522, pp. 626-636.
+    .. [2] Hugh A. Chipman, Edward I. George, Robert E. McCulloch "BART:
+       Bayesian additive regression trees," The Annals of Applied Statistics,
+       Ann. Appl. Stat. 4(1), 266-298, (March 2010).
     """

+    _main_trace: mcmcloop.MainTrace
+    _burnin_trace: mcmcloop.BurninTrace
+    _mcmc_state: mcmcstep.State
+    _splits: Real[Array, 'p max_num_splits']
+    _x_train_fmt: Any = field(static=True)
+
+    ndpost: int = field(static=True)
+    offset: Float32[Array, '']
+    sigest: Float32[Array, ''] | None = None
+    yhat_test: Float32[Array, 'ndpost m'] | None = None
+
     def __init__(
         self,
-        x_train,
-        y_train,
+        x_train: Real[Array, 'p n'] | DataFrame,
+        y_train: Bool[Array, ' n'] | Float32[Array, ' n'] | Series,
         *,
-        x_test=None,
-        type: Literal['wbart', 'pbart'] = 'wbart',
-
-
-
-
-
-
-
+        x_test: Real[Array, 'p m'] | DataFrame | None = None,
+        type: Literal['wbart', 'pbart'] = 'wbart',  # noqa: A002
+        sparse: bool = False,
+        theta: FloatLike | None = None,
+        a: FloatLike = 0.5,
+        b: FloatLike = 1.0,
+        rho: FloatLike | None = None,
+        xinfo: Float[Array, 'p n'] | None = None,
+        usequants: bool = False,
+        rm_const: bool | None = True,
+        sigest: FloatLike | None = None,
+        sigdf: FloatLike = 3.0,
+        sigquant: FloatLike = 0.9,
+        k: FloatLike = 2.0,
+        power: FloatLike = 2.0,
+        base: FloatLike = 0.95,
         lamda: FloatLike | None = None,
         tau_num: FloatLike | None = None,
         offset: FloatLike | None = None,
-        w=None,
-        ntree=
-        numcut=
-        ndpost=1000,
-        nskip=100,
-        keepevery=
-        printevery=100,
-        seed=0,
-        maxdepth=6,
-        init_kw=None,
-        run_mcmc_kw=None,
+        w: Float[Array, ' n'] | None = None,
+        ntree: int | None = None,
+        numcut: int = 100,
+        ndpost: int = 1000,
+        nskip: int = 100,
+        keepevery: int | None = None,
+        printevery: int | None = 100,
+        seed: int | Key[Array, ''] = 0,
+        maxdepth: int = 6,
+        init_kw: dict | None = None,
+        run_mcmc_kw: dict | None = None,
     ):
+        # check data and put it in the right format
         x_train, x_train_fmt = self._process_predictor_input(x_train)
-        y_train
+        y_train = self._process_response_input(y_train)
         self._check_same_length(x_train, y_train)
         if w is not None:
-            w
+            w = self._process_response_input(w)
             self._check_same_length(x_train, w)

-
+        # check data types are correct for continuous/binary regression
+        self._check_type_settings(y_train, type, w)
         # from here onwards, the type is determined by y_train.dtype == bool
+
+        # set defaults that depend on type of regression
+        if ntree is None:
+            ntree = 50 if y_train.dtype == bool else 200
+        if keepevery is None:
+            keepevery = 10 if y_train.dtype == bool else 1
+
+        # process sparsity settings
+        theta, a, b, rho = self._process_sparsity_settings(
+            x_train, sparse, theta, a, b, rho
+        )
+
+        # process "standardization" settings
         offset = self._process_offset_settings(y_train, offset)
         sigma_mu = self._process_leaf_sdev_settings(y_train, k, ntree, tau_num)
         lamda, sigest = self._process_error_variance_settings(
             x_train, y_train, sigest, sigdf, sigquant, lamda
         )

-
+        # determine splits
+        splits, max_split = self._determine_splits(x_train, usequants, numcut, xinfo)
         x_train = self._bin_predictors(x_train, splits)

-
+        # setup and run mcmc
+        initial_state = self._setup_mcmc(
             x_train,
             y_train,
             offset,
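Per the new defaults logic, a boolean `y_train` with `type='pbart'` switches `ntree` to 50 and `keepevery` to 10 automatically. A minimal usage sketch with synthetic data (argument values arbitrary, import path per this wheel's layout):

```python
import numpy as np

from bartz.BART import gbart

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 200))  # p=5 predictors, n=200 points: one predictor per ROW
y = X[0] + X[1] > 0            # boolean response -> probit link

fit = gbart(X, y, type='pbart', ndpost=100, nskip=100, printevery=None)
# with a bool y_train, ntree defaulted to 50 and keepevery to 10
print(fit.prob_train_mean)     # marginal posterior P(y=True) at the training points
print(fit.yhat_train_mean)     # None for binary regression, by design
```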
@@ -243,51 +373,163 @@ class gbart:
             maxdepth,
             ntree,
             init_kw,
+            rm_const,
+            theta,
+            a,
+            b,
+            rho,
         )
         final_state, burnin_trace, main_trace = self._run_mcmc(
-
+            initial_state,
+            ndpost,
+            nskip,
+            keepevery,
+            printevery,
+            seed,
+            run_mcmc_kw,
+            sparse,
         )

-
-        first_sigma = self._extract_sigma(burnin_trace)
-
+        # set public attributes
         self.offset = final_state.offset  # from the state because of buffer donation
+        self.ndpost = ndpost
         self.sigest = sigest
-        self.sigma = sigma
-        self.first_sigma = first_sigma

-
-        self._splits = splits
+        # set private attributes
         self._main_trace = main_trace
+        self._burnin_trace = burnin_trace
         self._mcmc_state = final_state
+        self._splits = splits
+        self._x_train_fmt = x_train_fmt

+        # predict at test points
         if x_test is not None:
-            yhat_test = self.predict(x_test)
-
-
+            self.yhat_test = self.predict(x_test)
+
+    @cached_property
+    def prob_test(self) -> Float32[Array, 'ndpost m'] | None:
+        """The posterior probability of y being True at `x_test` for each MCMC iteration."""
+        if self.yhat_test is None or self._mcmc_state.y.dtype != bool:
+            return None
+        else:
+            return ndtr(self.yhat_test)
+
+    @cached_property
+    def prob_test_mean(self) -> Float32[Array, ' m'] | None:
+        """The marginal posterior probability of y being True at `x_test`."""
+        if self.prob_test is None:
+            return None
+        else:
+            return self.prob_test.mean(axis=0)
+
+    @cached_property
+    def prob_train(self) -> Float32[Array, 'ndpost n'] | None:
+        """The posterior probability of y being True at `x_train` for each MCMC iteration."""
+        if self._mcmc_state.y.dtype == bool:
+            return ndtr(self.yhat_train)
+        else:
+            return None
+
+    @cached_property
+    def prob_train_mean(self) -> Float32[Array, ' n'] | None:
+        """The marginal posterior probability of y being True at `x_train`."""
+        if self.prob_train is None:
+            return None
+        else:
+            return self.prob_train.mean(axis=0)
+
+    @cached_property
+    def sigma(self) -> Float32[Array, ' nskip+ndpost'] | None:
+        """The standard deviation of the error, including burn-in samples."""
+        if self._burnin_trace.sigma2 is None:
+            return None
+        else:
+            assert self._main_trace.sigma2 is not None
+            return jnp.sqrt(
+                jnp.concatenate([self._burnin_trace.sigma2, self._main_trace.sigma2])
+            )
+
+    @cached_property
+    def sigma_mean(self) -> Float32[Array, ''] | None:
+        """The mean of `sigma`, only over the post-burnin samples."""
+        if self.sigma is None:
+            return None
+        else:
+            return self.sigma[len(self.sigma) - self.ndpost :].mean(axis=0)

-    @
-    def
+    @cached_property
+    def varcount(self) -> Int32[Array, 'ndpost p']:
+        """Histogram of predictor usage for decision rules in the trees."""
+        return mcmcloop.compute_varcount(
+            self._mcmc_state.forest.max_split.size, self._main_trace
+        )
+
+    @cached_property
+    def varcount_mean(self) -> Float32[Array, ' p']:
+        """Average of `varcount` across MCMC iterations."""
+        return self.varcount.mean(axis=0)
+
+    @cached_property
+    def varprob(self) -> Float32[Array, 'ndpost p']:
+        """Posterior samples of the probability of choosing each predictor for a decision rule."""
+        varprob = self._main_trace.varprob
+        if varprob is None:
+            max_split = self._mcmc_state.forest.max_split
+            p = max_split.size
+            peff = jnp.count_nonzero(max_split)
+            varprob = jnp.where(max_split, 1 / peff, 0)
+            varprob = jnp.broadcast_to(varprob, (self.ndpost, p))
+        return varprob
+
+    @cached_property
+    def varprob_mean(self) -> Float32[Array, ' p']:
+        """The marginal posterior probability of each predictor being chosen for a decision rule."""
+        return self.varprob.mean(axis=0)
+
+    @cached_property
+    def yhat_test_mean(self) -> Float32[Array, ' m'] | None:
+        """The marginal posterior mean at `x_test`.
+
+        Not defined with binary regression because it's error-prone; typically
+        the right thing to consider would be `prob_test_mean`.
+        """
+        if self.yhat_test is None or self._mcmc_state.y.dtype == bool:
+            return None
+        else:
+            return self.yhat_test.mean(axis=0)
+
+    @cached_property
+    def yhat_train(self) -> Float32[Array, 'ndpost n']:
+        """The conditional posterior mean at `x_train` for each MCMC iteration."""
         x_train = self._mcmc_state.X
-        return self._predict(
+        return self._predict(x_train)

-    @
-    def yhat_train_mean(self):
-
+    @cached_property
+    def yhat_train_mean(self) -> Float32[Array, ' n'] | None:
+        """The marginal posterior mean at `x_train`.

-
+        Not defined with binary regression because it's error-prone; typically
+        the right thing to consider would be `prob_train_mean`.
+        """
+        if self._mcmc_state.y.dtype == bool:
+            return None
+        else:
+            return self.yhat_train.mean(axis=0)
+
+    def predict(
+        self, x_test: Real[Array, 'p m'] | DataFrame
+    ) -> Float32[Array, 'ndpost m']:
         """
         Compute the posterior mean at `x_test` for each MCMC iteration.

         Parameters
         ----------
-        x_test
+        x_test
             The test predictors.

         Returns
         -------
-
-            The conditional posterior mean at `x_test` for each MCMC iteration.
+        The conditional posterior mean at `x_test` for each MCMC iteration.

         Raises
         ------
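The summary attributes that 0.6.0 computed eagerly in `__init__` are now lazy `cached_property`s, derived from the stored traces on first access. A sketch of reading them off a fitted continuous-regression model (synthetic data, same hypothetical setup as above):

```python
import numpy as np

from bartz.BART import gbart

rng = np.random.default_rng(0)
X = rng.normal(size=(3, 100))
y = (X[0] - X[2] + 0.1 * rng.normal(size=100)).astype(np.float32)

fit = gbart(X, y, ndpost=200, nskip=100, printevery=None)
print(fit.sigma.shape)       # (nskip + ndpost,) = (300,): burn-in now included
print(fit.sigma_mean)        # averaged over the post-burn-in part only
print(fit.varcount_mean)     # mean decision-rule count per predictor
print(fit.yhat_train_mean)   # posterior-mean fit at the training points
```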
@@ -296,14 +538,13 @@ class gbart:
         """
         x_test, x_test_fmt = self._process_predictor_input(x_test)
         if x_test_fmt != self._x_train_fmt:
-
-
-            )
+            msg = f'Input format mismatch: {x_test_fmt=} != x_train_fmt={self._x_train_fmt!r}'
+            raise ValueError(msg)
         x_test = self._bin_predictors(x_test, self._splits)
-        return self._predict(
+        return self._predict(x_test)

     @staticmethod
-    def _process_predictor_input(x):
+    def _process_predictor_input(x) -> tuple[Shaped[Array, 'p n'], Any]:
         if hasattr(x, 'columns'):
             fmt = dict(kind='dataframe', columns=x.columns)
             x = x.to_numpy().T
@@ -314,15 +555,12 @@ class gbart:
         return x, fmt

     @staticmethod
-    def _process_response_input(y):
+    def _process_response_input(y) -> Shaped[Array, ' n']:
         if hasattr(y, 'to_numpy'):
-            fmt = dict(kind='series', name=y.name)
             y = y.to_numpy()
-        else:
-            fmt = dict(kind='array')
         y = jnp.asarray(y)
         assert y.ndim == 1
-        return y
+        return y

     @staticmethod
     def _check_same_length(x1, x2):
@@ -335,13 +573,16 @@ class gbart:
     ) -> tuple[Float32[Array, ''] | None, ...]:
         if y_train.dtype == bool:
             if sigest is not None:
-
+                msg = 'Let `sigest=None` for binary regression'
+                raise ValueError(msg)
             if lamda is not None:
-
+                msg = 'Let `lamda=None` for binary regression'
+                raise ValueError(msg)
             return None, None
         elif lamda is not None:
             if sigest is not None:
-
+                msg = 'Let `sigest=None` if `lamda` is specified'
+                raise ValueError(msg)
             return lamda, None
         else:
             if sigest is not None:
@@ -359,37 +600,60 @@ class gbart:
                 dof = len(y_train) - rank
                 sigest2 = chisq / dof
             alpha = sigdf / 2
-            invchi2 =
+            invchi2 = invgamma.ppf(sigquant, alpha) / 2
             invchi2rid = invchi2 * sigdf
             return sigest2 / invchi2rid, jnp.sqrt(sigest2)

     @staticmethod
-    def
+    def _check_type_settings(y_train, type, w):  # noqa: A002
         match type:
             case 'wbart':
                 if y_train.dtype != jnp.float32:
-
+                    msg = (
                         'Continuous regression requires y_train.dtype=float32,'
                         f' got {y_train.dtype=} instead.'
                     )
+                    raise TypeError(msg)
             case 'pbart':
                 if w is not None:
-
-
-                    )
+                    msg = 'Binary regression does not support weights, set `w=None`'
+                    raise ValueError(msg)
                 if y_train.dtype != bool:
-
+                    msg = (
                         'Binary regression requires y_train.dtype=bool,'
                         f' got {y_train.dtype=} instead.'
                     )
+                    raise TypeError(msg)
             case _:
-
+                msg = f'Invalid {type=}'
+                raise ValueError(msg)

-
+    @staticmethod
+    def _process_sparsity_settings(
+        x_train: Real[Array, 'p n'],
+        sparse: bool,
+        theta: FloatLike | None,
+        a: FloatLike,
+        b: FloatLike,
+        rho: FloatLike | None,
+    ) -> (
+        tuple[None, None, None, None]
+        | tuple[FloatLike, None, None, None]
+        | tuple[None, FloatLike, FloatLike, FloatLike]
+    ):
+        if not sparse:
+            return None, None, None, None
+        elif theta is not None:
+            return theta, None, None, None
+        else:
+            if rho is None:
+                p, _ = x_train.shape
+                rho = float(p)
+            return None, a, b, rho

     @staticmethod
     def _process_offset_settings(
-        y_train: Float32[Array, 'n'] | Bool[Array, 'n'],
+        y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
         offset: float | Float32[Any, ''] | None,
     ) -> Float32[Array, '']:
         if offset is not None:
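The new `invgamma.ppf` line is the standard BART calibration of `lamda`: the scaled inverse-chisquared prior on the error variance is scaled so that a fraction `sigquant` of its prior mass falls below `sigest**2`. The same computation with `scipy.stats.invgamma` standing in for `bartz.jaxext.scipy.stats.invgamma` (a sketch of the scaling, not bartz's code; the numbers are arbitrary):

```python
from scipy import stats

sigest2 = 1.3**2   # residual variance estimate
sigdf = 3.0        # prior degrees of freedom
sigquant = 0.9     # quantile of the prior to pin at sigest2

alpha = sigdf / 2
invchi2 = stats.invgamma.ppf(sigquant, alpha) / 2
invchi2rid = invchi2 * sigdf
lamda = sigest2 / invchi2rid

# check: sigma2 ~ lamda * sigdf * InvChi2(sigdf) is the same as
# sigma2 ~ InvGamma(sigdf/2, scale=lamda * sigdf / 2), and by construction
# P(sigma2 < sigest2) = sigquant
q = stats.invgamma.cdf(sigest2, alpha, scale=lamda * sigdf / 2)
assert abs(q - sigquant) < 1e-9
```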
@@ -400,13 +664,15 @@ class gbart:
         mean = y_train.mean()

         if y_train.dtype == bool:
+            bound = 1 / (1 + y_train.size)
+            mean = jnp.clip(mean, bound, 1 - bound)
             return ndtri(mean)
         else:
             return mean

     @staticmethod
     def _process_leaf_sdev_settings(
-        y_train: Float32[Array, 'n'] | Bool[Array, 'n'],
+        y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
         k: float,
         ntree: int,
         tau_num: FloatLike | None,
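The two added lines guard the probit transform: an all-`False` or all-`True` boolean `y_train` has mean 0 or 1, where `ndtri` returns infinity, so the mean is first clipped to [1/(n+1), n/(n+1)], matching the updated `offset` docstring. The guard in isolation, with `scipy.special.ndtri` in place of the bartz version:

```python
import numpy as np
from scipy.special import ndtri

y = np.zeros(20, dtype=bool)   # degenerate data: all False
mean = y.mean()                # 0.0 -> ndtri(0.0) would be -inf

bound = 1 / (1 + y.size)
clipped = np.clip(mean, bound, 1 - bound)
offset = ndtri(clipped)        # finite: Phi^-1(1/21) ~ -1.67
print(offset)
```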
@@ -422,31 +688,46 @@ class gbart:
         return tau_num / (k * math.sqrt(ntree))

     @staticmethod
-    def _determine_splits(
-
+    def _determine_splits(
+        x_train: Real[Array, 'p n'],
+        usequants: bool,
+        numcut: int,
+        xinfo: Float[Array, 'p n'] | None,
+    ) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
+        if xinfo is not None:
+            if xinfo.ndim != 2 or xinfo.shape[0] != x_train.shape[0]:
+                msg = f'{xinfo.shape=} different from expected ({x_train.shape[0]}, *)'
+                raise ValueError(msg)
+            return prepcovars.parse_xinfo(xinfo)
+        elif usequants:
             return prepcovars.quantilized_splits_from_matrix(x_train, numcut + 1)
         else:
             return prepcovars.uniform_splits_from_matrix(x_train, numcut + 1)

     @staticmethod
-    def _bin_predictors(x, splits):
+    def _bin_predictors(x, splits) -> UInt[Array, 'p n']:
         return prepcovars.bin_predictors(x, splits)

     @staticmethod
     def _setup_mcmc(
-        x_train,
-        y_train,
-        offset,
-        w,
-        max_split,
-        lamda,
-        sigma_mu,
-        sigdf,
-        power,
-        base,
-        maxdepth,
-        ntree,
-        init_kw,
+        x_train: Real[Array, 'p n'],
+        y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
+        offset: Float32[Array, ''],
+        w: Float[Array, ' n'] | None,
+        max_split: UInt[Array, ' p'],
+        lamda: Float32[Array, ''] | None,
+        sigma_mu: FloatLike,
+        sigdf: FloatLike,
+        power: FloatLike,
+        base: FloatLike,
+        maxdepth: int,
+        ntree: int,
+        init_kw: dict[str, Any] | None,
+        rm_const: bool | None,
+        theta: FloatLike | None,
+        a: FloatLike | None,
+        b: FloatLike | None,
+        rho: FloatLike | None,
     ):
         depth = jnp.arange(maxdepth - 1)
         p_nonterminal = base / (1 + depth).astype(float) ** power
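Per the new `xinfo` branch in `_determine_splits`, a user-supplied cutpoint matrix needs one row per predictor, each row sorted and padded with NaN to a common width. A sketch of building one by hand (cutpoint values arbitrary):

```python
import numpy as np

# cutpoints per predictor; different counts are allowed
cuts = [
    [0.5],             # predictor 0: one cutpoint
    [-1.0, 0.0, 1.0],  # predictor 1: three cutpoints
]
width = max(len(c) for c in cuts)
xinfo = np.full((len(cuts), width), np.nan)
for i, c in enumerate(cuts):
    xinfo[i, : len(c)] = c   # sorted cutpoints first, NaN padding after

# xinfo can now be passed as gbart(..., xinfo=xinfo);
# usequants and numcut are then ignored
```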
@@ -470,14 +751,42 @@ class gbart:
             sigma_mu2=jnp.square(sigma_mu),
             sigma2_alpha=sigma2_alpha,
             sigma2_beta=sigma2_beta,
+            min_points_per_decision_node=10,
             min_points_per_leaf=5,
+            theta=theta,
+            a=a,
+            b=b,
+            rho=rho,
         )
+
+        if rm_const is None:
+            kw.update(filter_splitless_vars=False)
+        elif rm_const:
+            kw.update(filter_splitless_vars=True)
+        else:
+            n_empty = jnp.count_nonzero(max_split == 0)
+            if n_empty:
+                msg = f'There are {n_empty}/{max_split.size} predictors without decision rules'
+                raise ValueError(msg)
+            kw.update(filter_splitless_vars=False)
+
         if init_kw is not None:
             kw.update(init_kw)
+
         return mcmcstep.init(**kw)

     @staticmethod
-    def _run_mcmc(
+    def _run_mcmc(
+        mcmc_state: mcmcstep.State,
+        ndpost: int,
+        nskip: int,
+        keepevery: int,
+        printevery: int | None,
+        seed: int | Integer[Array, ''] | Key[Array, ''],
+        run_mcmc_kw: dict | None,
+        sparse: bool,
+    ):
+        # prepare random generator seed
         if isinstance(seed, jax.Array) and jnp.issubdtype(
             seed.dtype, jax.dtypes.prng_key
         ):
@@ -486,118 +795,19 @@ class gbart:
         else:
             key = jax.random.key(seed)

-
-
-
-
-
+        # prepare arguments
+        kw = dict(n_burn=nskip, n_skip=keepevery, inner_loop_length=printevery)
+        kw.update(
+            mcmcloop.make_default_callback(
+                dot_every=None if printevery is None or printevery == 1 else 1,
+                report_every=printevery,
+                sparse_on_at=nskip // 2 if sparse else None,
+            )
         )
-        if printevery is not None:
-            kw.update(mcmcloop.make_print_callbacks())
         if run_mcmc_kw is not None:
             kw.update(run_mcmc_kw)

         return mcmcloop.run_mcmc(key, mcmc_state, ndpost, **kw)

-
-
-        if trace['sigma2'] is None:
-            return None
-        else:
-            return jnp.sqrt(trace['sigma2'])
-
-    @staticmethod
-    def _predict(trace, x):
-        return mcmcloop.evaluate_trace(trace, x)
-
-    def _show_tree(self, i_sample, i_tree, print_all=False):
-        from . import debug
-
-        trace = self._main_trace
-        leaf_tree = trace['leaf_trees'][i_sample, i_tree]
-        var_tree = trace['var_trees'][i_sample, i_tree]
-        split_tree = trace['split_trees'][i_sample, i_tree]
-        debug.print_tree(leaf_tree, var_tree, split_tree, print_all)
-
-    def _sigma_harmonic_mean(self, prior=False):
-        bart = self._mcmc_state
-        if prior:
-            alpha = bart['sigma2_alpha']
-            beta = bart['sigma2_beta']
-        else:
-            resid = bart['resid']
-            alpha = bart['sigma2_alpha'] + resid.size / 2
-            norm2 = jnp.dot(
-                resid, resid, preferred_element_type=bart['sigma2_beta'].dtype
-            )
-            beta = bart['sigma2_beta'] + norm2 / 2
-        sigma2 = beta / alpha
-        return jnp.sqrt(sigma2)
-
-    def _compare_resid(self):
-        bart = self._mcmc_state
-        resid1 = bart.resid
-
-        trees = grove.evaluate_forest(
-            bart.X,
-            bart.forest.leaf_trees,
-            bart.forest.var_trees,
-            bart.forest.split_trees,
-            jnp.float32,  # TODO remove these configurable dtypes around
-        )
-
-        if bart.z is not None:
-            ref = bart.z
-        else:
-            ref = bart.y
-        resid2 = ref - (trees + bart.offset)
-
-        return resid1, resid2
-
-    def _avg_acc(self):
-        trace = self._main_trace
-
-        def acc(prefix):
-            acc = trace[f'{prefix}_acc_count']
-            prop = trace[f'{prefix}_prop_count']
-            return acc.sum() / prop.sum()
-
-        return acc('grow'), acc('prune')
-
-    def _avg_prop(self):
-        trace = self._main_trace
-
-        def prop(prefix):
-            return trace[f'{prefix}_prop_count'].sum()
-
-        pgrow = prop('grow')
-        pprune = prop('prune')
-        total = pgrow + pprune
-        return pgrow / total, pprune / total
-
-    def _avg_move(self):
-        agrow, aprune = self._avg_acc()
-        pgrow, pprune = self._avg_prop()
-        return agrow * pgrow, aprune * pprune
-
-    def _depth_distr(self):
-        from . import debug
-
-        trace = self._main_trace
-        split_trees = trace['split_trees']
-        return debug.trace_depth_distr(split_trees)
-
-    def _points_per_leaf_distr(self):
-        from . import debug
-
-        return debug.trace_points_per_leaf_distr(self._main_trace, self._mcmc_state.X)
-
-    def _check_trees(self):
-        from . import debug
-
-        return debug.check_trace(self._main_trace, self._mcmc_state)
-
-    def _tree_goes_bad(self):
-        bad = self._check_trees().astype(bool)
-        bad_before = jnp.pad(bad[:-1], [(1, 0), (0, 0)])
-        return bad & ~bad_before
+    def _predict(self, x):
+        return mcmcloop.evaluate_trace(self._main_trace, x)