bartz 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/.DS_Store +0 -0
- bartz/BART/__init__.py +27 -0
- bartz/BART/_gbart.py +522 -0
- bartz/__init__.py +6 -4
- bartz/_interface.py +937 -0
- bartz/_profiler.py +318 -0
- bartz/_version.py +1 -1
- bartz/debug.py +1217 -82
- bartz/grove.py +205 -103
- bartz/jaxext/__init__.py +287 -0
- bartz/jaxext/_autobatch.py +444 -0
- bartz/jaxext/scipy/__init__.py +25 -0
- bartz/jaxext/scipy/special.py +239 -0
- bartz/jaxext/scipy/stats.py +36 -0
- bartz/mcmcloop.py +662 -314
- bartz/mcmcstep/__init__.py +35 -0
- bartz/mcmcstep/_moves.py +904 -0
- bartz/mcmcstep/_state.py +1114 -0
- bartz/mcmcstep/_step.py +1603 -0
- bartz/prepcovars.py +140 -44
- bartz/testing/__init__.py +29 -0
- bartz/testing/_dgp.py +442 -0
- {bartz-0.6.0.dist-info → bartz-0.8.0.dist-info}/METADATA +18 -13
- bartz-0.8.0.dist-info/RECORD +25 -0
- {bartz-0.6.0.dist-info → bartz-0.8.0.dist-info}/WHEEL +1 -1
- bartz/BART.py +0 -603
- bartz/jaxext.py +0 -423
- bartz/mcmcstep.py +0 -2335
- bartz-0.6.0.dist-info/RECORD +0 -13
bartz/_interface.py
ADDED
@@ -0,0 +1,937 @@
+# bartz/src/bartz/_interface.py
+#
+# Copyright (c) 2025-2026, The Bartz Contributors
+#
+# This file is part of bartz.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+"""Main high-level interface of the package."""
+
+import math
+from collections.abc import Sequence
+from functools import cached_property
+from typing import Any, Literal, Protocol, TypedDict
+
+import jax
+import jax.numpy as jnp
+from equinox import Module, field
+from jax import Device, device_put, jit, make_mesh
+from jax.lax import collapse
+from jax.scipy.special import ndtr
+from jax.sharding import AxisType, Mesh
+from jaxtyping import (
+    Array,
+    Bool,
+    Float,
+    Float32,
+    Int32,
+    Integer,
+    Key,
+    Real,
+    Shaped,
+    UInt,
+)
+from numpy import ndarray
+
+from bartz import mcmcloop, mcmcstep, prepcovars
+from bartz.jaxext import is_key
+from bartz.jaxext.scipy.special import ndtri
+from bartz.jaxext.scipy.stats import invgamma
+from bartz.mcmcloop import compute_varcount, evaluate_trace, run_mcmc
+from bartz.mcmcstep import make_p_nonterminal
+from bartz.mcmcstep._state import get_num_chains
+
+FloatLike = float | Float[Any, '']
+
+
+class DataFrame(Protocol):
+    """DataFrame duck-type for `Bart`."""
+
+    columns: Sequence[str]
+    """The names of the columns."""
+
+    def to_numpy(self) -> ndarray:
+        """Convert the dataframe to a 2d numpy array with columns on the second axis."""
+        ...
+
+
+class Series(Protocol):
+    """Series duck-type for `Bart`."""
+
+    name: str | None
+    """The name of the series."""
+
+    def to_numpy(self) -> ndarray:
+        """Convert the series to a 1d numpy array."""
+        ...
+
+
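These protocols are purely structural: anything that exposes `columns` and `to_numpy()` (a pandas DataFrame, for instance) can be passed where `Bart` expects a dataframe. As an editorial aside, here is a minimal stand-in satisfying the `DataFrame` duck-type, to make the contract concrete; the class and data are illustrative, not part of bartz:

    import numpy as np

    class TinyFrame:
        """Illustrative object satisfying the DataFrame protocol above."""

        def __init__(self, data: dict[str, list[float]]):
            self._data = data
            self.columns = list(data)  # the names of the columns

        def to_numpy(self) -> np.ndarray:
            # columns on the second axis, as the protocol requires
            return np.column_stack([self._data[c] for c in self.columns])

    frame = TinyFrame({'x0': [0.1, 0.5, 0.9], 'x1': [1.0, 2.0, 3.0]})
    assert frame.to_numpy().shape == (3, 2)  # (n, p): one column per predictor
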
+class Bart(Module):
+    R"""
+    Nonparametric regression with Bayesian Additive Regression Trees (BART) [2]_.
+
+    Regress `y_train` on `x_train` with a latent mean function represented as
+    a sum of decision trees. The inference is carried out by sampling the
+    posterior distribution of the tree ensemble with an MCMC.
+
+    Parameters
+    ----------
+    x_train
+        The training predictors.
+    y_train
+        The training responses.
+    x_test
+        The test predictors.
+    type
+        The type of regression. 'wbart' for continuous regression, 'pbart' for
+        binary regression with probit link.
+    sparse
+        Whether to activate variable selection on the predictors as done in
+        [1]_.
+    theta
+    a
+    b
+    rho
+        Hyperparameters of the sparsity prior used for variable selection.
+
+        The prior distribution on the choice of predictor for each decision rule
+        is
+
+        .. math::
+            (s_1, \ldots, s_p) \sim
+            \operatorname{Dirichlet}(\mathtt{theta}/p, \ldots, \mathtt{theta}/p).
+
+        If `theta` is not specified, it's a priori distributed according to
+
+        .. math::
+            \frac{\mathtt{theta}}{\mathtt{theta} + \mathtt{rho}} \sim
+            \operatorname{Beta}(\mathtt{a}, \mathtt{b}).
+
+        If not specified, `rho` is set to the number of predictors p. To tune
+        the prior, consider setting a lower `rho` to prefer more sparsity.
+        If setting `theta` directly, it should be in the ballpark of p or lower
+        as well.
+    xinfo
+        A matrix with the cutpoints to use to bin each predictor. If not
+        specified, it is generated automatically according to `usequants` and
+        `numcut`.
+
+        Each row shall contain a sorted list of cutpoints for a predictor. If
+        there are fewer cutpoints than the number of columns in the matrix,
+        fill the remaining cells with NaN.
+
+        `xinfo` shall be a matrix even if `x_train` is a dataframe.
+    usequants
+        Whether to use predictor quantiles instead of a uniform grid to bin
+        predictors. Ignored if `xinfo` is specified.
+    rm_const
+        How to treat predictors with no associated decision rules (i.e., there
+        are no available cutpoints for that predictor). If `True` (default),
+        they are ignored. If `False`, an error is raised if there are any.
+    sigest
+        An estimate of the residual standard deviation on `y_train`, used to set
+        `lamda`. If not specified, it is estimated by linear regression (with
+        intercept, and without taking into account `w`). If `y_train` has fewer
+        than two elements, it is set to 1. If n <= p, it is set to the standard
+        deviation of `y_train`. Ignored if `lamda` is specified.
+    sigdf
+        The degrees of freedom of the scaled inverse-chisquared prior on the
+        noise variance.
+    sigquant
+        The quantile of the prior on the noise variance that shall match
+        `sigest` to set the scale of the prior. Ignored if `lamda` is specified.
+    k
+        The inverse scale of the prior standard deviation on the latent mean
+        function, relative to half the observed range of `y_train`. If `y_train`
+        has fewer than two elements, `k` is ignored and the scale is set to 1.
+    power
+    base
+        Parameters of the prior on tree node generation. The probability that a
+        node at depth `d` (0-based) is non-terminal is ``base / (1 + d) **
+        power``.
+    lamda
+        The prior harmonic mean of the error variance. (The harmonic mean of x
+        is 1/mean(1/x).) If not specified, it is set based on `sigest` and
+        `sigquant`.
+    tau_num
+        The numerator in the expression that determines the prior standard
+        deviation of leaves. If not specified, defaults to ``(max(y_train) -
+        min(y_train)) / 2`` (or 1 if `y_train` has fewer than two elements) for
+        continuous regression, and 3 for binary regression.
+    offset
+        The prior mean of the latent mean function. If not specified, it is set
+        to the mean of `y_train` for continuous regression, and to
+        ``Phi^-1(mean(y_train))`` for binary regression. If `y_train` is empty,
+        `offset` is set to 0. With binary regression, if `y_train` is all
+        `False` or `True`, it is set to ``Phi^-1(1/(n+1))`` or
+        ``Phi^-1(n/(n+1))``, respectively.
+    w
+        Coefficients that rescale the error standard deviation on each
+        datapoint. Not specifying `w` is equivalent to setting it to 1 for all
+        datapoints. Note: `w` is ignored in the automatic determination of
+        `sigest`, so either the weights should be O(1), or `sigest` should be
+        specified by the user.
+    ntree
+        The number of trees used to represent the latent mean function. By
+        default 200 for continuous regression and 50 for binary regression.
+    numcut
+        If `usequants` is `False`: the exact number of cutpoints used to bin the
+        predictors, ranging between the minimum and maximum observed values
+        (excluded).
+
+        If `usequants` is `True`: the maximum number of cutpoints to use for
+        binning the predictors. Each predictor is binned such that its
+        distribution in `x_train` is approximately uniform across bins. The
+        number of bins is at most ``numcut + 1``, and no more than the number
+        of unique values appearing in `x_train`.
+
+        Before running the algorithm, the predictors are compressed to the
+        smallest integer type that fits the bin indices, so `numcut` is best set
+        to the maximum value of an unsigned integer type, like 255.
+
+        Ignored if `xinfo` is specified.
+    ndpost
+        The number of MCMC samples to save, after burn-in. `ndpost` is the
+        total number of samples across all chains. `ndpost` is rounded up to the
+        first multiple of `num_chains`.
+    nskip
+        The number of initial MCMC samples to discard as burn-in. This number
+        of samples is discarded from each chain.
+    keepevery
+        The thinning factor for the MCMC samples, after burn-in. By default, 1
+        for continuous regression and 10 for binary regression.
+    printevery
+        The number of iterations (including thinned-away ones) between each log
+        line. Set to `None` to disable logging. ^C interrupts the MCMC only
+        every `printevery` iterations, so with logging disabled there is no
+        convenient way to kill the MCMC.
+    num_chains
+        The number of independent Markov chains to run. By default only one
+        chain is run.
+
+        The difference between not specifying `num_chains` and setting it to 1
+        is that in the latter case the object attributes and some methods
+        will have an explicit chain axis of size 1.
+    num_chain_devices
+        The number of devices to spread the chains across. Must be a divisor of
+        `num_chains`. Each device will run a fraction of the chains.
+    num_data_devices
+        The number of devices to split datapoints across. Must be a divisor of
+        `n`. This is useful only with very high `n`, about > 1_000_000.
+
+        If both `num_chain_devices` and `num_data_devices` are specified, the
+        total number of devices used is the product of the two.
+    devices
+        One or more devices to run the MCMC on. If not specified, the
+        computation will follow the placement of the input arrays. If a list of
+        devices, this argument can be longer than the number of devices needed.
+    seed
+        The seed for the random number generator.
+    maxdepth
+        The maximum depth of the trees. This is 1-based, so with the default
+        ``maxdepth=6``, the depths of the levels range from 0 to 5.
+    init_kw
+        Additional arguments passed to `bartz.mcmcstep.init`.
+    run_mcmc_kw
+        Additional arguments passed to `bartz.mcmcloop.run_mcmc`.
+
+    References
+    ----------
+    .. [1] Linero, Antonio R. (2018). "Bayesian Regression Trees for
+       High-Dimensional Prediction and Variable Selection". In: Journal of the
+       American Statistical Association 113.522, pp. 626-636.
+    .. [2] Chipman, Hugh A., Edward I. George, and Robert E. McCulloch (2010).
+       "BART: Bayesian additive regression trees". In: The Annals of Applied
+       Statistics 4.1, pp. 266-298.
+    """
+
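The docstring above is the full user-facing contract, so a minimal end-to-end run may help to fix ideas. This is a hedged sketch against the signature shown below, with synthetic data and default priors, not an example shipped with the package (the public import path for `Bart` may differ from this private module name):

    import jax
    import jax.numpy as jnp
    from bartz._interface import Bart

    # synthetic continuous-regression data: p = 3 predictors, n = 100 datapoints
    kx, ke = jax.random.split(jax.random.key(0))
    x_train = jax.random.uniform(kx, (3, 100))  # note the (p, n) layout
    y_train = jnp.sin(2 * x_train[0]) + 0.1 * jax.random.normal(ke, (100,))
    y_train = y_train.astype(jnp.float32)  # 'wbart' requires float32 responses

    fit = Bart(x_train, y_train, ndpost=500, nskip=200, printevery=None, seed=42)
    post_mean = fit.yhat_train_mean       # (n,) marginal posterior mean at x_train
    draws = fit.predict(x_train[:, :10])  # (ndpost, 10) posterior draws at new points
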
+    _main_trace: mcmcloop.MainTrace
+    _burnin_trace: mcmcloop.BurninTrace
+    _mcmc_state: mcmcstep.State
+    _splits: Real[Array, 'p max_num_splits']
+    _x_train_fmt: Any = field(static=True)
+
+    offset: Float32[Array, '']
+    """The prior mean of the latent mean function."""
+
+    sigest: Float32[Array, ''] | None = None
+    """The estimated standard deviation of the error used to set `lamda`."""
+
+    yhat_test: Float32[Array, 'ndpost m'] | None = None
+    """The conditional posterior mean at `x_test` for each MCMC iteration."""
+
+    def __init__(
+        self,
+        x_train: Real[Array, 'p n'] | DataFrame,
+        y_train: Bool[Array, ' n'] | Float32[Array, ' n'] | Series,
+        *,
+        x_test: Real[Array, 'p m'] | DataFrame | None = None,
+        type: Literal['wbart', 'pbart'] = 'wbart',  # noqa: A002
+        sparse: bool = False,
+        theta: FloatLike | None = None,
+        a: FloatLike = 0.5,
+        b: FloatLike = 1.0,
+        rho: FloatLike | None = None,
+        xinfo: Float[Array, 'p n'] | None = None,
+        usequants: bool = False,
+        rm_const: bool = True,
+        sigest: FloatLike | None = None,
+        sigdf: FloatLike = 3.0,
+        sigquant: FloatLike = 0.9,
+        k: FloatLike = 2.0,
+        power: FloatLike = 2.0,
+        base: FloatLike = 0.95,
+        lamda: FloatLike | None = None,
+        tau_num: FloatLike | None = None,
+        offset: FloatLike | None = None,
+        w: Float[Array, ' n'] | Series | None = None,
+        ntree: int | None = None,
+        numcut: int = 100,
+        ndpost: int = 1000,
+        nskip: int = 100,
+        keepevery: int | None = None,
+        printevery: int | None = 100,
+        num_chains: int | None = None,
+        num_chain_devices: int | None = None,
+        num_data_devices: int | None = None,
+        devices: Device | Sequence[Device] | None = None,
+        seed: int | Key[Array, ''] = 0,
+        maxdepth: int = 6,
+        init_kw: dict | None = None,
+        run_mcmc_kw: dict | None = None,
+    ):
+        # check data and put it in the right format
+        x_train, x_train_fmt = self._process_predictor_input(x_train)
+        y_train = self._process_response_input(y_train)
+        self._check_same_length(x_train, y_train)
+        if w is not None:
+            w = self._process_response_input(w)
+            self._check_same_length(x_train, w)
+
+        # check data types are correct for continuous/binary regression
+        self._check_type_settings(y_train, type, w)
+        # from here onwards, the type is determined by y_train.dtype == bool
+
+        # set defaults that depend on type of regression
+        if ntree is None:
+            ntree = 50 if y_train.dtype == bool else 200
+        if keepevery is None:
+            keepevery = 10 if y_train.dtype == bool else 1
+
+        # process sparsity settings
+        theta, a, b, rho = self._process_sparsity_settings(
+            x_train, sparse, theta, a, b, rho
+        )
+
+        # process "standardization" settings
+        offset = self._process_offset_settings(y_train, offset)
+        sigma_mu = self._process_leaf_sdev_settings(y_train, k, ntree, tau_num)
+        lamda, sigest = self._process_error_variance_settings(
+            x_train, y_train, sigest, sigdf, sigquant, lamda
+        )
+
+        # determine splits
+        splits, max_split = self._determine_splits(x_train, usequants, numcut, xinfo)
+        x_train = self._bin_predictors(x_train, splits)
+
+        # setup and run mcmc
+        initial_state = self._setup_mcmc(
+            x_train,
+            y_train,
+            offset,
+            w,
+            max_split,
+            lamda,
+            sigma_mu,
+            sigdf,
+            power,
+            base,
+            maxdepth,
+            ntree,
+            init_kw,
+            rm_const,
+            theta,
+            a,
+            b,
+            rho,
+            num_chains,
+            num_chain_devices,
+            num_data_devices,
+            devices,
+            sparse,
+            nskip,
+        )
+        final_state, burnin_trace, main_trace = self._run_mcmc(
+            initial_state, ndpost, nskip, keepevery, printevery, seed, run_mcmc_kw
+        )
+
+        # set public attributes
+        self.offset = final_state.offset  # from the state because of buffer donation
+        self.sigest = sigest
+
+        # set private attributes
+        self._main_trace = main_trace
+        self._burnin_trace = burnin_trace
+        self._mcmc_state = final_state
+        self._splits = splits
+        self._x_train_fmt = x_train_fmt
+
+        # predict at test points
+        if x_test is not None:
+            self.yhat_test = self.predict(x_test)
+
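Everything in `__init__` above keys off `y_train.dtype == bool` once `_check_type_settings` has validated the `type` argument. The probit counterpart of a continuous fit would look like the following sketch (made-up data; defaults quoted from the docstring above):

    import jax
    from bartz._interface import Bart

    x_train = jax.random.uniform(jax.random.key(1), (3, 200))
    y_train = x_train[0] + x_train[1] > 1.0  # bool responses select the probit model

    fit = Bart(
        x_train,
        y_train,
        type='pbart',  # must agree with y_train.dtype == bool
        ndpost=200,
        nskip=100,
        printevery=None,
    )
    # with 'pbart' the defaults switch to ntree=50 and keepevery=10, and the
    # probability accessors defined below become available:
    p_train = fit.prob_train_mean        # (n,) posterior mean of P(y = True | x)
    assert fit.yhat_train_mean is None   # deliberately undefined for binary regression
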
+    @property
+    def ndpost(self):
+        """The total number of posterior samples after burn-in across all chains.
+
+        May be larger than the initialization argument `ndpost` if it was not
+        divisible by the number of chains.
+        """
+        return self._main_trace.grow_prop_count.size
+
+    @cached_property
+    def prob_test(self) -> Float32[Array, 'ndpost m'] | None:
+        """The posterior probability of y being True at `x_test` for each MCMC iteration."""
+        if self.yhat_test is None or self._mcmc_state.y.dtype != bool:
+            return None
+        else:
+            return ndtr(self.yhat_test)
+
+    @cached_property
+    def prob_test_mean(self) -> Float32[Array, ' m'] | None:
+        """The marginal posterior probability of y being True at `x_test`."""
+        if self.prob_test is None:
+            return None
+        else:
+            return self.prob_test.mean(axis=0)
+
+    @cached_property
+    def prob_train(self) -> Float32[Array, 'ndpost n'] | None:
+        """The posterior probability of y being True at `x_train` for each MCMC iteration."""
+        if self._mcmc_state.y.dtype == bool:
+            return ndtr(self.yhat_train)
+        else:
+            return None
+
+    @cached_property
+    def prob_train_mean(self) -> Float32[Array, ' n'] | None:
+        """The marginal posterior probability of y being True at `x_train`."""
+        if self.prob_train is None:
+            return None
+        else:
+            return self.prob_train.mean(axis=0)
+
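The four `prob_*` accessors above are the probit link applied samplewise: `prob = Phi(yhat)`, with `Phi` the standard normal CDF. A self-contained illustration of the mapping, independent of any fitted model:

    import jax.numpy as jnp
    from jax.scipy.special import ndtr

    latent = jnp.array([-1.5, 0.0, 1.5])  # values of the latent mean function
    probs = ndtr(latent)                  # probit link: standard normal CDF
    print(probs)                          # approximately [0.0668 0.5 0.9332]
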
+    @cached_property
+    def sigma(
+        self,
+    ) -> (
+        Float32[Array, ' nskip+ndpost']
+        | Float32[Array, 'nskip+ndpost/mc_cores mc_cores']
+        | None
+    ):
+        """The standard deviation of the error, including burn-in samples."""
+        if self._burnin_trace.error_cov_inv is None:
+            return None
+        assert self._main_trace.error_cov_inv is not None
+        return jnp.sqrt(
+            jnp.reciprocal(
+                jnp.concatenate(
+                    [
+                        self._burnin_trace.error_cov_inv.T,
+                        self._main_trace.error_cov_inv.T,
+                    ],
+                    axis=0,
+                    # error_cov_inv has shape (chains? samples) in the trace
+                )
+            )
+        )
+
+    @cached_property
+    def sigma_(self) -> Float32[Array, 'ndpost'] | None:
+        """The standard deviation of the error, only over the post-burnin samples and flattened."""
+        error_cov_inv = self._main_trace.error_cov_inv
+        if error_cov_inv is None:
+            return None
+        else:
+            return jnp.sqrt(jnp.reciprocal(error_cov_inv)).reshape(-1)
+
+    @cached_property
+    def sigma_mean(self) -> Float32[Array, ''] | None:
+        """The mean of `sigma`, only over the post-burnin samples."""
+        if self.sigma_ is None:
+            return None
+        return self.sigma_.mean()
+
+    @cached_property
+    def varcount(self) -> Int32[Array, 'ndpost p']:
+        """Histogram of predictor usage for decision rules in the trees."""
+        p = self._mcmc_state.forest.max_split.size
+        varcount: Int32[Array, '*chains samples p']
+        varcount = compute_varcount(p, self._main_trace)
+        return collapse(varcount, 0, -1)
+
+    @cached_property
+    def varcount_mean(self) -> Float32[Array, ' p']:
+        """Average of `varcount` across MCMC iterations."""
+        return self.varcount.mean(axis=0)
+
+    @cached_property
+    def varprob(self) -> Float32[Array, 'ndpost p']:
+        """Posterior samples of the probability of choosing each predictor for a decision rule."""
+        max_split = self._mcmc_state.forest.max_split
+        p = max_split.size
+        varprob = self._main_trace.varprob
+        if varprob is None:
+            peff = jnp.count_nonzero(max_split)
+            varprob = jnp.where(max_split, 1 / peff, 0)
+            varprob = jnp.broadcast_to(varprob, (self.ndpost, p))
+        else:
+            varprob = varprob.reshape(-1, p)
+        return varprob
+
+    @cached_property
+    def varprob_mean(self) -> Float32[Array, ' p']:
+        """The marginal posterior probability of each predictor being chosen for a decision rule."""
+        return self.varprob.mean(axis=0)
+
+    @cached_property
+    def yhat_test_mean(self) -> Float32[Array, ' m'] | None:
+        """The marginal posterior mean at `x_test`.
+
+        Not defined with binary regression because it's error-prone; typically
+        the right thing to consider would be `prob_test_mean`.
+        """
+        if self.yhat_test is None or self._mcmc_state.y.dtype == bool:
+            return None
+        else:
+            return self.yhat_test.mean(axis=0)
+
+    @cached_property
+    def yhat_train(self) -> Float32[Array, 'ndpost n']:
+        """The conditional posterior mean at `x_train` for each MCMC iteration."""
+        x_train = self._mcmc_state.X
+        return self._predict(x_train)
+
+    @cached_property
+    def yhat_train_mean(self) -> Float32[Array, ' n'] | None:
+        """The marginal posterior mean at `x_train`.
+
+        Not defined with binary regression because it's error-prone; typically
+        the right thing to consider would be `prob_train_mean`.
+        """
+        if self._mcmc_state.y.dtype == bool:
+            return None
+        else:
+            return self.yhat_train.mean(axis=0)
+
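Besides the pointwise means exposed above, the raw draws support any posterior summary. For instance, pointwise 90% credible bounds can be computed from an `(ndpost, m)` matrix like `yhat_test`; the array below is a synthetic stand-in so the snippet runs on its own:

    import jax
    import jax.numpy as jnp

    # stand-in for fit.yhat_test: 1000 posterior draws at m = 5 test points
    draws = jax.random.normal(jax.random.key(0), (1000, 5))
    lo, hi = jnp.quantile(draws, jnp.array([0.05, 0.95]), axis=0)
    print(lo.shape, hi.shape)  # (5,) (5,): pointwise 90% credible bounds
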
+    def predict(
+        self, x_test: Real[Array, 'p m'] | DataFrame
+    ) -> Float32[Array, 'ndpost m']:
+        """
+        Compute the posterior mean at `x_test` for each MCMC iteration.
+
+        Parameters
+        ----------
+        x_test
+            The test predictors.
+
+        Returns
+        -------
+        The conditional posterior mean at `x_test` for each MCMC iteration.
+
+        Raises
+        ------
+        ValueError
+            If `x_test` has a different format than `x_train`.
+        """
+        x_test, x_test_fmt = self._process_predictor_input(x_test)
+        if x_test_fmt != self._x_train_fmt:
+            msg = f'Input format mismatch: {x_test_fmt=} != x_train_fmt={self._x_train_fmt!r}'
+            raise ValueError(msg)
+        x_test = self._bin_predictors(x_test, self._splits)
+        return self._predict(x_test)
+
+    @staticmethod
+    def _process_predictor_input(x) -> tuple[Shaped[Array, 'p n'], Any]:
+        if hasattr(x, 'columns'):
+            fmt = dict(kind='dataframe', columns=x.columns)
+            x = x.to_numpy().T
+        else:
+            fmt = dict(kind='array', num_covar=x.shape[0])
+        x = jnp.asarray(x)
+        assert x.ndim == 2
+        return x, fmt
+
+    @staticmethod
+    def _process_response_input(y) -> Shaped[Array, ' n']:
+        if hasattr(y, 'to_numpy'):
+            y = y.to_numpy()
+        y = jnp.asarray(y)
+        assert y.ndim == 1
+        return y
+
+    @staticmethod
+    def _check_same_length(x1, x2):
+        get_length = lambda x: x.shape[-1]
+        assert get_length(x1) == get_length(x2)
+
+    @classmethod
+    def _process_error_variance_settings(
+        cls, x_train, y_train, sigest, sigdf, sigquant, lamda
+    ) -> tuple[Float32[Array, ''] | None, ...]:
+        """Return (lamda, sigest)."""
+        if y_train.dtype == bool:
+            if sigest is not None:
+                msg = 'Let `sigest=None` for binary regression'
+                raise ValueError(msg)
+            if lamda is not None:
+                msg = 'Let `lamda=None` for binary regression'
+                raise ValueError(msg)
+            return None, None
+        elif lamda is not None:
+            if sigest is not None:
+                msg = 'Let `sigest=None` if `lamda` is specified'
+                raise ValueError(msg)
+            return lamda, None
+        else:
+            if sigest is not None:
+                sigest2 = jnp.square(sigest)
+            elif y_train.size < 2:
+                sigest2 = 1
+            elif y_train.size <= x_train.shape[0]:
+                sigest2 = jnp.var(y_train)
+            else:
+                sigest2 = cls._linear_regression(x_train, y_train)
+            alpha = sigdf / 2
+            invchi2 = invgamma.ppf(sigquant, alpha) / 2
+            invchi2rid = invchi2 * sigdf
+            return sigest2 / invchi2rid, jnp.sqrt(sigest2)
+
+    @staticmethod
+    @jit
+    def _linear_regression(
+        x_train: Shaped[Array, 'p n'], y_train: Float32[Array, ' n']
+    ):
+        """Return the error variance estimated with OLS with intercept."""
+        x_centered = x_train.T - x_train.mean(axis=1)
+        y_centered = y_train - y_train.mean()
+        # centering is equivalent to adding an intercept column
+        _, chisq, rank, _ = jnp.linalg.lstsq(x_centered, y_centered)
+        chisq = chisq.squeeze(0)
+        dof = len(y_train) - rank
+        return chisq / dof
+
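The closing arithmetic of `_process_error_variance_settings` calibrates `lamda` so that `sigest` sits exactly at the `sigquant` quantile of the scaled inverse-chi-squared prior on the error variance. A quick numerical check of that identity, assuming bartz's `invgamma.ppf` matches scipy's parameterization (which the drop-in import suggests):

    from scipy.stats import invgamma

    sigest, sigdf, sigquant = 0.8, 3.0, 0.9
    alpha = sigdf / 2
    lamda = sigest**2 / (invgamma.ppf(sigquant, alpha) / 2 * sigdf)
    # the implied prior is sigma^2 ~ InvGamma(alpha=sigdf/2, beta=lamda*sigdf/2),
    # so sigest^2 should land exactly at the sigquant quantile:
    beta = lamda * sigdf / 2
    assert abs(invgamma.cdf(sigest**2, alpha, scale=beta) - sigquant) < 1e-9
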
+    @staticmethod
+    def _check_type_settings(y_train, type, w):  # noqa: A002
+        match type:
+            case 'wbart':
+                if y_train.dtype != jnp.float32:
+                    msg = (
+                        'Continuous regression requires y_train.dtype=float32,'
+                        f' got {y_train.dtype=} instead.'
+                    )
+                    raise TypeError(msg)
+            case 'pbart':
+                if w is not None:
+                    msg = 'Binary regression does not support weights, set `w=None`'
+                    raise ValueError(msg)
+                if y_train.dtype != bool:
+                    msg = (
+                        'Binary regression requires y_train.dtype=bool,'
+                        f' got {y_train.dtype=} instead.'
+                    )
+                    raise TypeError(msg)
+            case _:
+                msg = f'Invalid {type=}'
+                raise ValueError(msg)
+
+    @staticmethod
+    def _process_sparsity_settings(
+        x_train: Real[Array, 'p n'],
+        sparse: bool,
+        theta: FloatLike | None,
+        a: FloatLike,
+        b: FloatLike,
+        rho: FloatLike | None,
+    ) -> (
+        tuple[None, None, None, None]
+        | tuple[FloatLike, None, None, None]
+        | tuple[None, FloatLike, FloatLike, FloatLike]
+    ):
+        """Return (theta, a, b, rho)."""
+        if not sparse:
+            return None, None, None, None
+        elif theta is not None:
+            return theta, None, None, None
+        else:
+            if rho is None:
+                p, _ = x_train.shape
+                rho = float(p)
+            return None, a, b, rho
+
+    @staticmethod
+    def _process_offset_settings(
+        y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
+        offset: float | Float32[Any, ''] | None,
+    ) -> Float32[Array, '']:
+        """Return offset."""
+        if offset is not None:
+            return jnp.asarray(offset)
+        elif y_train.size < 1:
+            return jnp.array(0.0)
+        else:
+            mean = y_train.mean()
+
+            if y_train.dtype == bool:
+                bound = 1 / (1 + y_train.size)
+                mean = jnp.clip(mean, bound, 1 - bound)
+                return ndtri(mean)
+            else:
+                return mean
+
+    @staticmethod
+    def _process_leaf_sdev_settings(
+        y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
+        k: float,
+        ntree: int,
+        tau_num: FloatLike | None,
+    ):
+        """Return sigma_mu."""
+        if tau_num is None:
+            if y_train.dtype == bool:
+                tau_num = 3.0
+            elif y_train.size < 2:
+                tau_num = 1.0
+            else:
+                tau_num = (y_train.max() - y_train.min()) / 2
+
+        return tau_num / (k * math.sqrt(ntree))
+
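The formula returned by `_process_leaf_sdev_settings` shrinks the leaf prior scale with the ensemble size, so that the sum of `ntree` independent leaves keeps a fixed prior standard deviation of `tau_num / k`. A worked instance (values illustrative):

    import math

    tau_num, k, ntree = 1.0, 2.0, 200  # e.g. continuous y_train ranging over [-1, 1]
    sigma_mu = tau_num / (k * math.sqrt(ntree))
    print(round(sigma_mu, 5))                     # 0.03536 per-leaf prior sdev
    print(round(math.sqrt(ntree) * sigma_mu, 3))  # 0.5 = tau_num / k for the whole sum
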
+    @staticmethod
+    def _determine_splits(
+        x_train: Real[Array, 'p n'],
+        usequants: bool,
+        numcut: int,
+        xinfo: Float[Array, 'p n'] | None,
+    ) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
+        if xinfo is not None:
+            if xinfo.ndim != 2 or xinfo.shape[0] != x_train.shape[0]:
+                msg = f'{xinfo.shape=} different from expected ({x_train.shape[0]}, *)'
+                raise ValueError(msg)
+            return prepcovars.parse_xinfo(xinfo)
+        elif usequants:
+            return prepcovars.quantilized_splits_from_matrix(x_train, numcut + 1)
+        else:
+            return prepcovars.uniform_splits_from_matrix(x_train, numcut + 1)
+
+    @staticmethod
+    def _bin_predictors(
+        x: Real[Array, 'p n'], splits: Real[Array, 'p max_num_splits']
+    ) -> UInt[Array, 'p n']:
+        return prepcovars.bin_predictors(x, splits)
+
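The actual split-finding and binning live in `prepcovars`; as a rough sketch of what quantile-based splits amount to for a single predictor (an illustration of the idea, not the `prepcovars` implementation):

    import jax
    import jax.numpy as jnp

    x = jax.random.normal(jax.random.key(0), (1000,))  # one predictor, n = 1000
    numcut = 7
    quantiles = jnp.linspace(0, 1, numcut + 2)[1:-1]   # interior quantiles only
    splits = jnp.quantile(x, quantiles)                # sorted cutpoints
    bins = jnp.searchsorted(splits, x)                 # bin index in 0..numcut
    counts = jnp.bincount(bins, length=numcut + 1)     # roughly uniform occupancy
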
+    @staticmethod
+    def _setup_mcmc(
+        x_train: Real[Array, 'p n'],
+        y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
+        offset: Float32[Array, ''],
+        w: Float[Array, ' n'] | None,
+        max_split: UInt[Array, ' p'],
+        lamda: Float32[Array, ''] | None,
+        sigma_mu: FloatLike,
+        sigdf: FloatLike,
+        power: FloatLike,
+        base: FloatLike,
+        maxdepth: int,
+        ntree: int,
+        init_kw: dict[str, Any] | None,
+        rm_const: bool,
+        theta: FloatLike | None,
+        a: FloatLike | None,
+        b: FloatLike | None,
+        rho: FloatLike | None,
+        num_chains: int | None,
+        num_chain_devices: int | None,
+        num_data_devices: int | None,
+        devices: Device | Sequence[Device] | None,
+        sparse: bool,
+        nskip: int,
+    ):
+        p_nonterminal = make_p_nonterminal(maxdepth, base, power)
+
+        if y_train.dtype == bool:
+            error_cov_df = None
+            error_cov_scale = None
+        else:
+            assert lamda is not None
+            # inverse gamma prior: alpha = df / 2, beta = scale / 2
+            error_cov_df = sigdf
+            error_cov_scale = lamda * sigdf
+
+        # process device settings
+        device_kw, device = process_device_settings(
+            y_train, num_chains, num_chain_devices, num_data_devices, devices
+        )
+
+        kw: dict = dict(
+            X=x_train,
+            # copy y_train because it's going to be donated in the mcmc loop
+            y=jnp.array(y_train),
+            offset=offset,
+            error_scale=w,
+            max_split=max_split,
+            num_trees=ntree,
+            p_nonterminal=p_nonterminal,
+            leaf_prior_cov_inv=jnp.reciprocal(jnp.square(sigma_mu)),
+            error_cov_df=error_cov_df,
+            error_cov_scale=error_cov_scale,
+            min_points_per_decision_node=10,
+            min_points_per_leaf=5,
+            theta=theta,
+            a=a,
+            b=b,
+            rho=rho,
+            sparse_on_at=nskip // 2 if sparse else None,
+            **device_kw,
+        )
+
+        if rm_const:
+            n_empty = jnp.sum(max_split == 0).item()
+            kw.update(filter_splitless_vars=n_empty)
+
+        if init_kw is not None:
+            kw.update(init_kw)
+
+        state = mcmcstep.init(**kw)
+
+        # put state on device if requested explicitly by the user
+        if device is not None:
+            state = device_put(state, device, donate=True)
+
+        return state
+
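Note how `_setup_mcmc` merges `init_kw` last, so user-supplied entries win over everything assembled above. That makes it possible to nudge internals the public signature doesn't expose, for example the split-node constraints; a sketch using keyword names visible in the kw dict above:

    import jax
    import jax.numpy as jnp
    from bartz._interface import Bart

    x_train = jax.random.uniform(jax.random.key(2), (3, 100))
    y_train = jnp.sum(x_train, axis=0).astype(jnp.float32)

    fit = Bart(
        x_train,
        y_train,
        printevery=None,
        # entries here land on top of the kw dict assembled in _setup_mcmc:
        init_kw=dict(min_points_per_decision_node=20, min_points_per_leaf=10),
    )
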
+    @classmethod
+    def _run_mcmc(
+        cls,
+        mcmc_state: mcmcstep.State,
+        ndpost: int,
+        nskip: int,
+        keepevery: int,
+        printevery: int | None,
+        seed: int | Integer[Array, ''] | Key[Array, ''],
+        run_mcmc_kw: dict | None,
+    ) -> tuple[mcmcstep.State, mcmcloop.BurninTrace, mcmcloop.MainTrace]:
+        # prepare random generator seed
+        if is_key(seed):
+            key = jnp.copy(seed)
+        else:
+            key = jax.random.key(seed)
+
+        # round up ndpost
+        num_chains = get_num_chains(mcmc_state)
+        if num_chains is None:
+            num_chains = 1
+        n_save = ndpost // num_chains + bool(ndpost % num_chains)
+
+        # prepare arguments
+        kw: dict = dict(n_burn=nskip, n_skip=keepevery, inner_loop_length=printevery)
+        kw.update(
+            mcmcloop.make_default_callback(
+                mcmc_state,
+                dot_every=None if printevery is None or printevery == 1 else 1,
+                report_every=printevery,
+            )
+        )
+        if run_mcmc_kw is not None:
+            kw.update(run_mcmc_kw)
+
+        return run_mcmc(key, mcmc_state, n_save, **kw)
+
+    def _predict(self, x: UInt[Array, 'p m']) -> Float32[Array, 'ndpost m']:
+        """Evaluate trees on already quantized `x`."""
+        out = evaluate_trace(x, self._main_trace)
+        return collapse(out, 0, -1)
+
+
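`_run_mcmc` accepts either a plain integer or a JAX PRNG key as `seed`, and copies a user-provided key before use since downstream buffers may be donated. A sketch of the two equivalent seeding styles in isolation:

    import jax
    import jax.numpy as jnp

    key_from_int = jax.random.key(0)  # what _run_mcmc builds from seed=0
    key_explicit = jax.random.key(0)  # or pass a key you manage yourself
    safe = jnp.copy(key_explicit)     # the copy guards the caller's key
    assert jnp.all(jax.random.key_data(safe) == jax.random.key_data(key_from_int))
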
+class DeviceKwArgs(TypedDict):
+    num_chains: int | None
+    mesh: Mesh | None
+    target_platform: Literal['cpu', 'gpu'] | None
+
+
+def process_device_settings(
+    y_train: Array,
+    num_chains: int | None,
+    num_chain_devices: int | None,
+    num_data_devices: int | None,
+    devices: Device | Sequence[Device] | None,
+) -> tuple[DeviceKwArgs, Device | None]:
+    """Return the arguments for `mcmcstep.init` related to devices, and an optional device where to put the state."""
+    # determine devices
+    if devices is not None:
+        if not hasattr(devices, '__len__'):
+            devices = (devices,)
+        device = devices[0]
+        platform = device.platform
+    elif hasattr(y_train, 'platform'):
+        platform = y_train.platform()
+        device = None
+        # set device=None because if the devices were not specified explicitly
+        # we may be in the case where computation will follow data placement,
+        # do not disturb jax as the user may be playing with vmap, jit, reshard...
+        devices = jax.devices(platform)
+    else:
+        msg = 'not possible to infer device from `y_train`, please set `devices`'
+        raise ValueError(msg)
+
+    # create mesh
+    if num_chain_devices is None and num_data_devices is None:
+        mesh = None
+    else:
+        mesh = dict()
+        if num_chain_devices is not None:
+            mesh.update(chains=num_chain_devices)
+        if num_data_devices is not None:
+            mesh.update(data=num_data_devices)
+        mesh = make_mesh(
+            axis_shapes=tuple(mesh.values()),
+            axis_names=tuple(mesh),
+            axis_types=(AxisType.Auto,) * len(mesh),
+            devices=devices,
+        )
+        device = None
+        # set device=None because `mcmcstep.init` will `device_put` with the
+        # mesh already, we don't want to undo its work
+
+    # prepare arguments to `init`
+    settings = DeviceKwArgs(
+        num_chains=num_chains,
+        mesh=mesh,
+        target_platform=None
+        if mesh is not None or hasattr(y_train, 'platform')
+        else platform,
+        # here we don't take into account the case where the user has set both
+        # batch sizes; since the user has to be playing with `init_kw` to do
+        # that, we'll let `init` throw the error and the user set
+        # `target_platform` themselves so they have a clearer idea how the
+        # thing works.
+    )
+
+    return settings, device
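To make the device plumbing concrete, here is how the chain mesh built by `process_device_settings` would be requested through the public constructor. This is a sketch, assuming a host with at least 2 visible devices (on CPU that can be forced with XLA_FLAGS=--xla_force_host_platform_device_count=2), not tested output:

    import jax
    import jax.numpy as jnp
    from bartz._interface import Bart

    x_train = jax.random.uniform(jax.random.key(0), (3, 100))
    y_train = jnp.sum(x_train, axis=0).astype(jnp.float32)

    # 4 chains spread over 2 devices -> 2 chains per device
    fit = Bart(
        x_train,
        y_train,
        num_chains=4,
        num_chain_devices=2,
        devices=jax.devices()[:2],  # may list more devices than needed
        ndpost=400,
        nskip=100,
        printevery=None,
    )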