bartz 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

bartz/BART.py CHANGED
@@ -22,17 +22,73 @@
22
22
  # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
23
  # SOFTWARE.
24
24
 
25
- import functools
25
+ """Implement a class `gbart` that mimics the R BART package."""
26
+
27
+ import math
28
+ from collections.abc import Sequence
29
+ from functools import cached_property
30
+ from typing import Any, Literal, Protocol
26
31
 
27
32
  import jax
28
33
  import jax.numpy as jnp
34
+ from equinox import Module, field
35
+ from jax.scipy.special import ndtr
36
+ from jaxtyping import (
37
+ Array,
38
+ Bool,
39
+ Float,
40
+ Float32,
41
+ Int32,
42
+ Integer,
43
+ Key,
44
+ Real,
45
+ Shaped,
46
+ UInt,
47
+ )
48
+ from numpy import ndarray
49
+
50
+ from bartz import mcmcloop, mcmcstep, prepcovars
51
+ from bartz.jaxext.scipy.special import ndtri
52
+ from bartz.jaxext.scipy.stats import invgamma
53
+
54
+ FloatLike = float | Float[Any, '']
55
+
56
+
57
+ class DataFrame(Protocol):
58
+ """DataFrame duck-type for `gbart`.
59
+
60
+ Attributes
61
+ ----------
62
+ columns : Sequence[str]
63
+ The names of the columns.
64
+ """
65
+
66
+ columns: Sequence[str]
67
+
68
+ def to_numpy(self) -> ndarray:
69
+ """Convert the dataframe to a 2d numpy array with columns on the second axis."""
70
+ ...
29
71
 
30
- from . import grove, jaxext, mcmcloop, mcmcstep, prepcovars
31
72
 
73
+ class Series(Protocol):
74
+ """Series duck-type for `gbart`.
32
75
 
33
- class gbart:
76
+ Attributes
77
+ ----------
78
+ name : str | None
79
+ The name of the series.
34
80
  """
35
- Nonparametric regression with Bayesian Additive Regression Trees (BART).
81
+
82
+ name: str | None
83
+
84
+ def to_numpy(self) -> ndarray:
85
+ """Convert the series to a 1d numpy array."""
86
+ ...
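These Protocol classes are structural ("duck") types: any object exposing the listed attributes works, so pandas objects qualify without this module importing pandas. A minimal sketch of that compatibility (pandas is an assumption here, not a dependency declared in this diff):

    import pandas as pd

    df = pd.DataFrame({'x1': [0.1, 0.2, 0.3], 'x2': [1.0, 2.0, 3.0]})
    assert list(df.columns) == ['x1', 'x2']          # `columns`, as in the DataFrame protocol
    assert df.to_numpy().shape == (3, 2)             # 2d array, columns on the second axis

    s = pd.Series([0.5, 1.5, 2.5], name='y')
    assert s.name == 'y' and s.to_numpy().ndim == 1  # `name` and 1d `to_numpy()`, as in Series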
87
+
88
+
89
+ class gbart(Module):
90
+ R"""
91
+ Nonparametric regression with Bayesian Additive Regression Trees (BART) [2]_.
36
92
 
37
93
  Regress `y_train` on `x_train` with a latent mean function represented as
38
94
  a sum of decision trees. The inference is carried out by sampling the
@@ -40,55 +96,108 @@ class gbart:
40
96
 
41
97
  Parameters
42
98
  ----------
43
- x_train : array (p, n) or DataFrame
99
+ x_train
44
100
  The training predictors.
45
- y_train : array (n,) or Series
101
+ y_train
46
102
  The training responses.
47
- x_test : array (p, m) or DataFrame, optional
103
+ x_test
48
104
  The test predictors.
49
- usequants : bool, default False
105
+ type
106
+ The type of regression. 'wbart' for continuous regression, 'pbart' for
107
+ binary regression with probit link.
108
+ sparse
109
+ Whether to activate variable selection on the predictors as done in
110
+ [1]_.
111
+ theta
112
+ a
113
+ b
114
+ rho
115
+ Hyperparameters of the sparsity prior used for variable selection.
116
+
117
+ The prior distribution on the choice of predictor for each decision rule
118
+ is
119
+
120
+ .. math::
121
+ (s_1, \ldots, s_p) \sim
122
+ \operatorname{Dirichlet}(\mathtt{theta}/p, \ldots, \mathtt{theta}/p).
123
+
124
+ If `theta` is not specified, it's a priori distributed according to
125
+
126
+ .. math::
127
+ \frac{\mathtt{theta}}{\mathtt{theta} + \mathtt{rho}} \sim
128
+ \operatorname{Beta}(\mathtt{a}, \mathtt{b}).
129
+
130
+ If not specified, `rho` is set to the number of predictors p. To tune
131
+ the prior, consider setting a lower `rho` to prefer more sparsity.
132
+ If setting `theta` directly, it should be in the ballpark of p or lower
133
+ as well.
134
+ xinfo
135
+ A matrix with the cutpoints to use to bin each predictor. If not
136
+ specified, it is generated automatically according to `usequants` and
137
+ `numcut`.
138
+
139
+ Each row shall contain a sorted list of cutpoints for a predictor. If
140
+ there are fewer cutpoints than the number of columns in the matrix,
141
+ fill the remaining cells with NaN.
142
+
143
+ `xinfo` shall be a matrix even if `x_train` is a dataframe.
144
+ usequants
50
145
  Whether to use predictor quantiles instead of a uniform grid to bin
51
- predictors.
52
- sigest : float, optional
146
+ predictors. Ignored if `xinfo` is specified.
147
+ rm_const
148
+ How to treat predictors with no associated decision rules (i.e., there
149
+ are no available cutpoints for that predictor). If `True` (default),
150
+ they are ignored. If `False`, an error is raised if there are any. If
151
+ `None`, no check is performed, and the output of the MCMC may not make
152
+ sense if there are predictors without cutpoints. The option `None` is
153
+ provided only to allow jax tracing.
154
+ sigest
53
155
  An estimate of the residual standard deviation on `y_train`, used to set
54
156
  `lamda`. If not specified, it is estimated by linear regression (with
55
157
  intercept, and without taking into account `w`). If `y_train` has fewer
56
158
  than two elements, it is set to 1. If n <= p, it is set to the standard
57
159
  deviation of `y_train`. Ignored if `lamda` is specified.
58
- sigdf : int, default 3
160
+ sigdf
59
161
  The degrees of freedom of the scaled inverse-chisquared prior on the
60
162
  noise variance.
61
- sigquant : float, default 0.9
163
+ sigquant
62
164
  The quantile of the prior on the noise variance that shall match
63
165
  `sigest` to set the scale of the prior. Ignored if `lamda` is specified.
64
- k : float, default 2
166
+ k
65
167
  The inverse scale of the prior standard deviation on the latent mean
66
168
  function, relative to half the observed range of `y_train`. If `y_train`
67
169
  has less than two elements, `k` is ignored and the scale is set to 1.
68
- power : float, default 2
69
- base : float, default 0.95
170
+ power
171
+ base
70
172
  Parameters of the prior on tree node generation. The probability that a
71
173
  node at depth `d` (0-based) is non-terminal is ``base / (1 + d) **
72
174
  power``.
73
- maxdepth : int, default 6
74
- The maximum depth of the trees. This is 1-based, so with the default
75
- ``maxdepth=6``, the depths of the levels range from 0 to 5.
76
- lamda : float, optional
77
- The scale of the prior on the noise variance. If ``lamda==1``, the
78
- prior is an inverse chi-squared scaled to have harmonic mean 1. If
79
- not specified, it is set based on `sigest` and `sigquant`.
80
- offset : float, optional
175
+ lamda
176
+ The prior harmonic mean of the error variance. (The harmonic mean of x
177
+ is 1/mean(1/x).) If not specified, it is set based on `sigest` and
178
+ `sigquant`.
179
+ tau_num
180
+ The numerator in the expression that determines the prior standard
181
+ deviation of leaves. If not specified, defaults to ``(max(y_train) -
182
+ min(y_train)) / 2`` (or 1 if `y_train` has fewer than two elements) for
183
+ continuous regression, and 3 for binary regression.
184
+ offset
81
185
  The prior mean of the latent mean function. If not specified, it is set
82
- to the mean of `y_train`. If `y_train` is empty, it is set to 0.
83
- w : array (n,), optional
186
+ to the mean of `y_train` for continuous regression, and to
187
+ ``Phi^-1(mean(y_train))`` for binary regression. If `y_train` is empty,
188
+ `offset` is set to 0. With binary regression, if `y_train` is all
189
+ `False` or `True`, it is set to ``Phi^-1(1/(n+1))`` or
190
+ ``Phi^-1(n/(n+1))``, respectively.
191
+ w
84
192
  Coefficients that rescale the error standard deviation on each
85
193
  datapoint. Not specifying `w` is equivalent to setting it to 1 for all
86
194
  datapoints. Note: `w` is ignored in the automatic determination of
87
195
  `sigest`, so either the weights should be O(1), or `sigest` should be
88
196
  specified by the user.
89
- ntree : int, default 200
90
- The number of trees used to represent the latent mean function.
91
- numcut : int, default 255
197
+ ntree
198
+ The number of trees used to represent the latent mean function. By
199
+ default 200 for continuous regression and 50 for binary regression.
200
+ numcut
92
201
  If `usequants` is `False`: the exact number of cutpoints used to bin the
93
202
  predictors, ranging between the minimum and maximum observed values
94
203
  (excluded).
@@ -101,50 +210,43 @@ class gbart:
101
210
 
102
211
  Before running the algorithm, the predictors are compressed to the
103
212
  smallest integer type that fits the bin indices, so `numcut` is best set
104
- to the maximum value of an unsigned integer type.
105
- ndpost : int, default 1000
213
+ to the maximum value of an unsigned integer type, like 255.
214
+
215
+ Ignored if `xinfo` is specified.
216
+ ndpost
106
217
  The number of MCMC samples to save, after burn-in.
107
- nskip : int, default 100
218
+ nskip
108
219
  The number of initial MCMC samples to discard as burn-in.
109
- keepevery : int, default 1
110
- The thinning factor for the MCMC samples, after burn-in.
111
- printevery : int, default 100
112
- The number of iterations (including skipped ones) between each log.
113
- seed : int or jax random key, default 0
220
+ keepevery
221
+ The thinning factor for the MCMC samples, after burn-in. By default, 1
222
+ for continuous regression and 10 for binary regression.
223
+ printevery
224
+ The number of iterations (including thinned-away ones) between each log
225
+ line. Set to `None` to disable logging.
226
+
227
+ `printevery` has a few unexpected side effects. On cpu, interrupting
228
+ with ^C halts the MCMC only on the next log. And the total number of
229
+ iterations is a multiple of `printevery`, so if ``nskip + keepevery *
230
+ ndpost`` is not a multiple of `printevery`, some of the last iterations
231
+ will not be saved.
232
+ seed
114
233
  The seed for the random number generator.
115
- initkw : dict
116
- Additional arguments passed to `mcmcstep.init`.
234
+ maxdepth
235
+ The maximum depth of the trees. This is 1-based, so with the default
236
+ ``maxdepth=6``, the depths of the levels range from 0 to 5.
237
+ init_kw
238
+ Additional arguments passed to `bartz.mcmcstep.init`.
239
+ run_mcmc_kw
240
+ Additional arguments passed to `bartz.mcmcloop.run_mcmc`.
117
241
 
118
242
  Attributes
119
243
  ----------
120
- yhat_train : array (ndpost, n)
121
- The conditional posterior mean at `x_train` for each MCMC iteration.
122
- yhat_train_mean : array (n,)
123
- The marginal posterior mean at `x_train`.
124
- yhat_test : array (ndpost, m)
125
- The conditional posterior mean at `x_test` for each MCMC iteration.
126
- yhat_test_mean : array (m,)
127
- The marginal posterior mean at `x_test`.
128
- sigma : array (ndpost,)
129
- The standard deviation of the error.
130
- first_sigma : array (nskip,)
131
- The standard deviation of the error in the burn-in phase.
132
- offset : float
244
+ offset : Float32[Array, '']
133
245
  The prior mean of the latent mean function.
134
- scale : float
135
- The prior standard deviation of the latent mean function.
136
- lamda : float
137
- The prior harmonic mean of the error variance.
138
- sigest : float or None
246
+ sigest : Float32[Array, ''] | None
139
247
  The estimated standard deviation of the error used to set `lamda`.
140
- ntree : int
141
- The number of trees.
142
- maxdepth : int
143
- The maximum depth of the trees.
144
-
145
- Methods
146
- -------
147
- predict
248
+ yhat_test : Float32[Array, 'ndpost m'] | None
249
+ The conditional posterior mean at `x_test` for each MCMC iteration.
148
250
 
149
251
  Notes
150
252
  -----
@@ -156,128 +258,293 @@ class gbart:
156
258
  - If ``usequants=False``, R BART switches to quantiles anyway if there are
157
259
  less predictor values than the required number of bins, while bartz
158
260
  always follows the specification.
261
+ - Some functionality is missing.
159
262
  - The error variance parameter is called `lamda` instead of `lambda`.
160
- - `rm_const` is always `False`.
161
- - The default `numcut` is 255 instead of 100.
162
- - A lot of functionality is missing (variable selection, discrete response).
163
263
  - There are some additional attributes, and some missing.
264
+ - The trees have a maximum depth.
265
+ - `rm_const` refers to predictors without decision rules instead of
266
+ predictors that are constant in `x_train`.
267
+ - If `rm_const=True` and some variables are dropped, the predictors
268
+ matrix/dataframe passed to `predict` should still include them.
164
269
 
270
+ References
271
+ ----------
272
+ .. [1] Linero, Antonio R. (2018). “Bayesian Regression Trees for
273
+ High-Dimensional Prediction and Variable Selection”. In: Journal of the
274
+ American Statistical Association 113.522, pp. 626-636.
275
+ .. [2] Chipman, Hugh A., Edward I. George and Robert E. McCulloch (2010).
276
+ "BART: Bayesian additive regression trees". In: The Annals of Applied
277
+ Statistics 4.1, pp. 266-298.
165
278
  """
166
279
 
280
+ _main_trace: mcmcloop.MainTrace
281
+ _burnin_trace: mcmcloop.BurninTrace
282
+ _mcmc_state: mcmcstep.State
283
+ _splits: Real[Array, 'p max_num_splits']
284
+ _x_train_fmt: Any = field(static=True)
285
+
286
+ ndpost: int = field(static=True)
287
+ offset: Float32[Array, '']
288
+ sigest: Float32[Array, ''] | None = None
289
+ yhat_test: Float32[Array, 'ndpost m'] | None = None
290
+
167
291
  def __init__(
168
292
  self,
169
- x_train,
170
- y_train,
293
+ x_train: Real[Array, 'p n'] | DataFrame,
294
+ y_train: Bool[Array, ' n'] | Float32[Array, ' n'] | Series,
171
295
  *,
172
- x_test=None,
173
- usequants=False,
174
- sigest=None,
175
- sigdf=3,
176
- sigquant=0.9,
177
- k=2,
178
- power=2,
179
- base=0.95,
180
- maxdepth=6,
181
- lamda=None,
182
- offset=None,
183
- w=None,
184
- ntree=200,
185
- numcut=255,
186
- ndpost=1000,
187
- nskip=100,
188
- keepevery=1,
189
- printevery=100,
190
- seed=0,
191
- initkw=None,
296
+ x_test: Real[Array, 'p m'] | DataFrame | None = None,
297
+ type: Literal['wbart', 'pbart'] = 'wbart', # noqa: A002
298
+ sparse: bool = False,
299
+ theta: FloatLike | None = None,
300
+ a: FloatLike = 0.5,
301
+ b: FloatLike = 1.0,
302
+ rho: FloatLike | None = None,
303
+ xinfo: Float[Array, 'p n'] | None = None,
304
+ usequants: bool = False,
305
+ rm_const: bool | None = True,
306
+ sigest: FloatLike | None = None,
307
+ sigdf: FloatLike = 3.0,
308
+ sigquant: FloatLike = 0.9,
309
+ k: FloatLike = 2.0,
310
+ power: FloatLike = 2.0,
311
+ base: FloatLike = 0.95,
312
+ lamda: FloatLike | None = None,
313
+ tau_num: FloatLike | None = None,
314
+ offset: FloatLike | None = None,
315
+ w: Float[Array, ' n'] | None = None,
316
+ ntree: int | None = None,
317
+ numcut: int = 100,
318
+ ndpost: int = 1000,
319
+ nskip: int = 100,
320
+ keepevery: int | None = None,
321
+ printevery: int | None = 100,
322
+ seed: int | Key[Array, ''] = 0,
323
+ maxdepth: int = 6,
324
+ init_kw: dict | None = None,
325
+ run_mcmc_kw: dict | None = None,
192
326
  ):
327
+ # check data and put it in the right format
193
328
  x_train, x_train_fmt = self._process_predictor_input(x_train)
194
- y_train, _ = self._process_response_input(y_train)
329
+ y_train = self._process_response_input(y_train)
195
330
  self._check_same_length(x_train, y_train)
196
331
  if w is not None:
197
- w, _ = self._process_response_input(w)
332
+ w = self._process_response_input(w)
198
333
  self._check_same_length(x_train, w)
199
334
 
335
+ # check data types are correct for continuous/binary regression
336
+ self._check_type_settings(y_train, type, w)
337
+ # from here onwards, the type is determined by y_train.dtype == bool
338
+
339
+ # set defaults that depend on type of regression
340
+ if ntree is None:
341
+ ntree = 50 if y_train.dtype == bool else 200
342
+ if keepevery is None:
343
+ keepevery = 10 if y_train.dtype == bool else 1
344
+
345
+ # process sparsity settings
346
+ theta, a, b, rho = self._process_sparsity_settings(
347
+ x_train, sparse, theta, a, b, rho
348
+ )
349
+
350
+ # process "standardization" settings
200
351
  offset = self._process_offset_settings(y_train, offset)
201
- scale = self._process_scale_settings(y_train, k)
202
- lamda, sigest = self._process_noise_variance_settings(
203
- x_train, y_train, sigest, sigdf, sigquant, lamda, offset
352
+ sigma_mu = self._process_leaf_sdev_settings(y_train, k, ntree, tau_num)
353
+ lamda, sigest = self._process_error_variance_settings(
354
+ x_train, y_train, sigest, sigdf, sigquant, lamda
204
355
  )
205
356
 
206
- splits, max_split = self._determine_splits(x_train, usequants, numcut)
357
+ # determine splits
358
+ splits, max_split = self._determine_splits(x_train, usequants, numcut, xinfo)
207
359
  x_train = self._bin_predictors(x_train, splits)
208
- y_train, lamda_scaled = self._transform_input(y_train, lamda, offset, scale)
209
360
 
210
- mcmc_state = self._setup_mcmc(
361
+ # setup and run mcmc
362
+ initial_state = self._setup_mcmc(
211
363
  x_train,
212
364
  y_train,
365
+ offset,
213
366
  w,
214
367
  max_split,
215
- lamda_scaled,
368
+ lamda,
369
+ sigma_mu,
216
370
  sigdf,
217
371
  power,
218
372
  base,
219
373
  maxdepth,
220
374
  ntree,
221
- initkw,
375
+ init_kw,
376
+ rm_const,
377
+ theta,
378
+ a,
379
+ b,
380
+ rho,
222
381
  )
223
382
  final_state, burnin_trace, main_trace = self._run_mcmc(
224
- mcmc_state, ndpost, nskip, keepevery, printevery, seed
383
+ initial_state,
384
+ ndpost,
385
+ nskip,
386
+ keepevery,
387
+ printevery,
388
+ seed,
389
+ run_mcmc_kw,
390
+ sparse,
225
391
  )
226
392
 
227
- sigma = self._extract_sigma(main_trace, scale)
228
- first_sigma = self._extract_sigma(burnin_trace, scale)
229
-
230
- self.offset = offset
231
- self.scale = scale
232
- self.lamda = lamda
393
+ # set public attributes
394
+ self.offset = final_state.offset # from the state because of buffer donation
395
+ self.ndpost = ndpost
233
396
  self.sigest = sigest
234
- self.ntree = ntree
235
- self.maxdepth = maxdepth
236
- self.sigma = sigma
237
- self.first_sigma = first_sigma
238
397
 
239
- self._x_train_fmt = x_train_fmt
240
- self._splits = splits
398
+ # set private attributes
241
399
  self._main_trace = main_trace
400
+ self._burnin_trace = burnin_trace
242
401
  self._mcmc_state = final_state
402
+ self._splits = splits
403
+ self._x_train_fmt = x_train_fmt
243
404
 
405
+ # predict at test points
244
406
  if x_test is not None:
245
- yhat_test = self.predict(x_test)
246
- self.yhat_test = yhat_test
247
- self.yhat_test_mean = yhat_test.mean(axis=0)
407
+ self.yhat_test = self.predict(x_test)
408
+
409
+ @cached_property
410
+ def prob_test(self) -> Float32[Array, 'ndpost m'] | None:
411
+ """The posterior probability of y being True at `x_test` for each MCMC iteration."""
412
+ if self.yhat_test is None or self._mcmc_state.y.dtype != bool:
413
+ return None
414
+ else:
415
+ return ndtr(self.yhat_test)
416
+
417
+ @cached_property
418
+ def prob_test_mean(self) -> Float32[Array, ' m'] | None:
419
+ """The marginal posterior probability of y being True at `x_test`."""
420
+ if self.prob_test is None:
421
+ return None
422
+ else:
423
+ return self.prob_test.mean(axis=0)
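The probability properties (`prob_test`, `prob_test_mean`, and the training counterparts below) are the quantities to report for binary regression; a hedged sketch with synthetic data, import path assumed as in the earlier example:

    import numpy as np
    from bartz.BART import gbart

    rng = np.random.default_rng(1)
    p, n, m = 3, 200, 50
    x_train = rng.standard_normal((p, n)).astype(np.float32)
    x_test = rng.standard_normal((p, m)).astype(np.float32)
    y_train = x_train[0] + 0.5 * rng.standard_normal(n) > 0   # boolean responses

    fit = gbart(x_train, y_train, x_test=x_test, type='pbart', seed=1)
    probs = fit.prob_test_mean   # posterior P(y=True | x) at x_test, shape (m,)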
424
+
425
+ @cached_property
426
+ def prob_train(self) -> Float32[Array, 'ndpost n'] | None:
427
+ """The posterior probability of y being True at `x_train` for each MCMC iteration."""
428
+ if self._mcmc_state.y.dtype == bool:
429
+ return ndtr(self.yhat_train)
430
+ else:
431
+ return None
432
+
433
+ @cached_property
434
+ def prob_train_mean(self) -> Float32[Array, ' n'] | None:
435
+ """The marginal posterior probability of y being True at `x_train`."""
436
+ if self.prob_train is None:
437
+ return None
438
+ else:
439
+ return self.prob_train.mean(axis=0)
440
+
441
+ @cached_property
442
+ def sigma(self) -> Float32[Array, ' nskip+ndpost'] | None:
443
+ """The standard deviation of the error, including burn-in samples."""
444
+ if self._burnin_trace.sigma2 is None:
445
+ return None
446
+ else:
447
+ assert self._main_trace.sigma2 is not None
448
+ return jnp.sqrt(
449
+ jnp.concatenate([self._burnin_trace.sigma2, self._main_trace.sigma2])
450
+ )
451
+
452
+ @cached_property
453
+ def sigma_mean(self) -> Float32[Array, ''] | None:
454
+ """The mean of `sigma`, only over the post-burnin samples."""
455
+ if self.sigma is None:
456
+ return None
457
+ else:
458
+ return self.sigma[len(self.sigma) - self.ndpost :].mean(axis=0)
459
+
460
+ @cached_property
461
+ def varcount(self) -> Int32[Array, 'ndpost p']:
462
+ """Histogram of predictor usage for decision rules in the trees."""
463
+ return mcmcloop.compute_varcount(
464
+ self._mcmc_state.forest.max_split.size, self._main_trace
465
+ )
466
+
467
+ @cached_property
468
+ def varcount_mean(self) -> Float32[Array, ' p']:
469
+ """Average of `varcount` across MCMC iterations."""
470
+ return self.varcount.mean(axis=0)
471
+
472
+ @cached_property
473
+ def varprob(self) -> Float32[Array, 'ndpost p']:
474
+ """Posterior samples of the probability of choosing each predictor for a decision rule."""
475
+ varprob = self._main_trace.varprob
476
+ if varprob is None:
477
+ max_split = self._mcmc_state.forest.max_split
478
+ p = max_split.size
479
+ peff = jnp.count_nonzero(max_split)
480
+ varprob = jnp.where(max_split, 1 / peff, 0)
481
+ varprob = jnp.broadcast_to(varprob, (self.ndpost, p))
482
+ return varprob
483
+
484
+ @cached_property
485
+ def varprob_mean(self) -> Float32[Array, ' p']:
486
+ """The marginal posterior probability of each predictor being chosen for a decision rule."""
487
+ return self.varprob.mean(axis=0)
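With `sparse=True`, `varcount_mean` and `varprob_mean` summarize which predictors the trees actually use; a hedged sketch of ranking predictors by posterior inclusion probability (synthetic data, import path assumed as above):

    import numpy as np
    from bartz.BART import gbart

    rng = np.random.default_rng(2)
    p, n = 20, 300
    x_train = rng.standard_normal((p, n)).astype(np.float32)
    y_train = (2.0 * x_train[0] - x_train[1]).astype(np.float32)  # only 2 informative predictors

    fit = gbart(x_train, y_train, sparse=True, seed=2)
    ranking = np.argsort(-np.asarray(fit.varprob_mean))  # predictor indices, most used first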
488
+
489
+ @cached_property
490
+ def yhat_test_mean(self) -> Float32[Array, ' m'] | None:
491
+ """The marginal posterior mean at `x_test`.
492
+
493
+ Not defined for binary regression because it's error-prone; typically
494
+ the right quantity to consider is `prob_test_mean`.
495
+ """
496
+ if self.yhat_test is None or self._mcmc_state.y.dtype == bool:
497
+ return None
498
+ else:
499
+ return self.yhat_test.mean(axis=0)
500
+
501
+ @cached_property
502
+ def yhat_train(self) -> Float32[Array, 'ndpost n']:
503
+ """The conditional posterior mean at `x_train` for each MCMC iteration."""
504
+ x_train = self._mcmc_state.X
505
+ return self._predict(x_train)
248
506
 
249
- @functools.cached_property
250
- def yhat_train(self):
251
- x_train = self._mcmc_state['X']
252
- yhat_train = self._predict(self._main_trace, x_train)
253
- return self._transform_output(yhat_train, self.offset, self.scale)
507
+ @cached_property
508
+ def yhat_train_mean(self) -> Float32[Array, ' n'] | None:
509
+ """The marginal posterior mean at `x_train`.
254
510
 
255
- @functools.cached_property
256
- def yhat_train_mean(self):
257
- return self.yhat_train.mean(axis=0)
511
+ Not defined for binary regression because it's error-prone; typically
512
+ the right quantity to consider is `prob_train_mean`.
513
+ """
514
+ if self._mcmc_state.y.dtype == bool:
515
+ return None
516
+ else:
517
+ return self.yhat_train.mean(axis=0)
258
518
 
259
- def predict(self, x_test):
519
+ def predict(
520
+ self, x_test: Real[Array, 'p m'] | DataFrame
521
+ ) -> Float32[Array, 'ndpost m']:
260
522
  """
261
523
  Compute the posterior mean at `x_test` for each MCMC iteration.
262
524
 
263
525
  Parameters
264
526
  ----------
265
- x_test : array (p, m) or DataFrame
527
+ x_test
266
528
  The test predictors.
267
529
 
268
530
  Returns
269
531
  -------
270
- yhat_test : array (ndpost, m)
271
- The conditional posterior mean at `x_test` for each MCMC iteration.
532
+ The conditional posterior mean at `x_test` for each MCMC iteration.
533
+
534
+ Raises
535
+ ------
536
+ ValueError
537
+ If `x_test` has a different format than `x_train`.
272
538
  """
273
539
  x_test, x_test_fmt = self._process_predictor_input(x_test)
274
- self._check_compatible_formats(x_test_fmt, self._x_train_fmt)
540
+ if x_test_fmt != self._x_train_fmt:
541
+ msg = f'Input format mismatch: {x_test_fmt=} != x_train_fmt={self._x_train_fmt!r}'
542
+ raise ValueError(msg)
275
543
  x_test = self._bin_predictors(x_test, self._splits)
276
- yhat_test = self._predict(self._main_trace, x_test)
277
- return self._transform_output(yhat_test, self.offset, self.scale)
544
+ return self._predict(x_test)
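Continuing the continuous-regression sketch given after the class docstring (illustrative, not part of the diff): `x_test` must use the same layout as `x_train`, otherwise the ValueError above is raised.

    x_new = rng.standard_normal((p, 10)).astype(np.float32)  # same (p, m) array layout as x_train
    draws = fit.predict(x_new)      # shape (ndpost, 10), one row per posterior sample
    point = draws.mean(axis=0)      # posterior-mean prediction at the new points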
278
545
 
279
546
  @staticmethod
280
- def _process_predictor_input(x):
547
+ def _process_predictor_input(x) -> tuple[Shaped[Array, 'p n'], Any]:
281
548
  if hasattr(x, 'columns'):
282
549
  fmt = dict(kind='dataframe', columns=x.columns)
283
550
  x = x.to_numpy().T
@@ -288,19 +555,12 @@ class gbart:
288
555
  return x, fmt
289
556
 
290
557
  @staticmethod
291
- def _check_compatible_formats(fmt1, fmt2):
292
- assert fmt1 == fmt2
293
-
294
- @staticmethod
295
- def _process_response_input(y):
558
+ def _process_response_input(y) -> Shaped[Array, ' n']:
296
559
  if hasattr(y, 'to_numpy'):
297
- fmt = dict(kind='series', name=y.name)
298
560
  y = y.to_numpy()
299
- else:
300
- fmt = dict(kind='array')
301
561
  y = jnp.asarray(y)
302
562
  assert y.ndim == 1
303
- return y, fmt
563
+ return y
304
564
 
305
565
  @staticmethod
306
566
  def _check_same_length(x1, x2):
@@ -308,18 +568,29 @@ class gbart:
308
568
  assert get_length(x1) == get_length(x2)
309
569
 
310
570
  @staticmethod
311
- def _process_noise_variance_settings(
312
- x_train, y_train, sigest, sigdf, sigquant, lamda, offset
313
- ):
314
- if lamda is not None:
571
+ def _process_error_variance_settings(
572
+ x_train, y_train, sigest, sigdf, sigquant, lamda
573
+ ) -> tuple[Float32[Array, ''] | None, ...]:
574
+ if y_train.dtype == bool:
575
+ if sigest is not None:
576
+ msg = 'Let `sigest=None` for binary regression'
577
+ raise ValueError(msg)
578
+ if lamda is not None:
579
+ msg = 'Let `lamda=None` for binary regression'
580
+ raise ValueError(msg)
581
+ return None, None
582
+ elif lamda is not None:
583
+ if sigest is not None:
584
+ msg = 'Let `sigest=None` if `lamda` is specified'
585
+ raise ValueError(msg)
315
586
  return lamda, None
316
587
  else:
317
588
  if sigest is not None:
318
- sigest2 = sigest * sigest
589
+ sigest2 = jnp.square(sigest)
319
590
  elif y_train.size < 2:
320
591
  sigest2 = 1
321
592
  elif y_train.size <= x_train.shape[0]:
322
- sigest2 = jnp.var(y_train - offset)
593
+ sigest2 = jnp.var(y_train)
323
594
  else:
324
595
  x_centered = x_train.T - x_train.mean(axis=1)
325
596
  y_centered = y_train - y_train.mean()
@@ -329,182 +600,214 @@ class gbart:
329
600
  dof = len(y_train) - rank
330
601
  sigest2 = chisq / dof
331
602
  alpha = sigdf / 2
332
- invchi2 = jaxext.scipy.stats.invgamma.ppf(sigquant, alpha) / 2
603
+ invchi2 = invgamma.ppf(sigquant, alpha) / 2
333
604
  invchi2rid = invchi2 * sigdf
334
605
  return sigest2 / invchi2rid, jnp.sqrt(sigest2)
335
606
 
336
607
  @staticmethod
337
- def _process_offset_settings(y_train, offset):
608
+ def _check_type_settings(y_train, type, w): # noqa: A002
609
+ match type:
610
+ case 'wbart':
611
+ if y_train.dtype != jnp.float32:
612
+ msg = (
613
+ 'Continuous regression requires y_train.dtype=float32,'
614
+ f' got {y_train.dtype=} instead.'
615
+ )
616
+ raise TypeError(msg)
617
+ case 'pbart':
618
+ if w is not None:
619
+ msg = 'Binary regression does not support weights, set `w=None`'
620
+ raise ValueError(msg)
621
+ if y_train.dtype != bool:
622
+ msg = (
623
+ 'Binary regression requires y_train.dtype=bool,'
624
+ f' got {y_train.dtype=} instead.'
625
+ )
626
+ raise TypeError(msg)
627
+ case _:
628
+ msg = f'Invalid {type=}'
629
+ raise ValueError(msg)
630
+
631
+ @staticmethod
632
+ def _process_sparsity_settings(
633
+ x_train: Real[Array, 'p n'],
634
+ sparse: bool,
635
+ theta: FloatLike | None,
636
+ a: FloatLike,
637
+ b: FloatLike,
638
+ rho: FloatLike | None,
639
+ ) -> (
640
+ tuple[None, None, None, None]
641
+ | tuple[FloatLike, None, None, None]
642
+ | tuple[None, FloatLike, FloatLike, FloatLike]
643
+ ):
644
+ if not sparse:
645
+ return None, None, None, None
646
+ elif theta is not None:
647
+ return theta, None, None, None
648
+ else:
649
+ if rho is None:
650
+ p, _ = x_train.shape
651
+ rho = float(p)
652
+ return None, a, b, rho
653
+
654
+ @staticmethod
655
+ def _process_offset_settings(
656
+ y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
657
+ offset: float | Float32[Any, ''] | None,
658
+ ) -> Float32[Array, '']:
338
659
  if offset is not None:
339
- return offset
660
+ return jnp.asarray(offset)
340
661
  elif y_train.size < 1:
341
- return 0
662
+ return jnp.array(0.0)
342
663
  else:
343
- return y_train.mean()
664
+ mean = y_train.mean()
344
665
 
345
- @staticmethod
346
- def _process_scale_settings(y_train, k):
347
- if y_train.size < 2:
348
- return 1
666
+ if y_train.dtype == bool:
667
+ bound = 1 / (1 + y_train.size)
668
+ mean = jnp.clip(mean, bound, 1 - bound)
669
+ return ndtri(mean)
349
670
  else:
350
- return (y_train.max() - y_train.min()) / (2 * k)
671
+ return mean
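A worked instance of the clipping above (a sketch; `ndtri` is the standard-normal quantile Phi^-1 imported at the top of this file): with 9 observations that are all `True`, the empirical mean 1 is clipped to 1 - 1/(1+9) = 0.9, so the offset becomes Phi^-1(0.9) ≈ 1.28, the ``Phi^-1(n/(n+1))`` rule stated in the class docstring.

    import jax.numpy as jnp
    from bartz.jaxext.scipy.special import ndtri

    n = 9
    bound = 1 / (1 + n)
    mean = jnp.clip(jnp.array(1.0), bound, 1 - bound)  # all-True responses -> clipped to 0.9
    offset = ndtri(mean)                               # ~1.2816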
351
672
 
352
673
  @staticmethod
353
- def _determine_splits(x_train, usequants, numcut):
354
- if usequants:
674
+ def _process_leaf_sdev_settings(
675
+ y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
676
+ k: float,
677
+ ntree: int,
678
+ tau_num: FloatLike | None,
679
+ ):
680
+ if tau_num is None:
681
+ if y_train.dtype == bool:
682
+ tau_num = 3.0
683
+ elif y_train.size < 2:
684
+ tau_num = 1.0
685
+ else:
686
+ tau_num = (y_train.max() - y_train.min()) / 2
687
+
688
+ return tau_num / (k * math.sqrt(ntree))
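A quick numeric check of the formula above (made-up numbers, not from the source): with the defaults `k=2` and `ntree=200`, a response range of 10 gives `tau_num = 5` and a leaf prior standard deviation of about 0.18.

    import math

    tau_num = 10.0 / 2                           # default numerator: half the observed range
    sigma_mu = tau_num / (2.0 * math.sqrt(200))  # ~0.177, prior sd of each leaf value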
689
+
690
+ @staticmethod
691
+ def _determine_splits(
692
+ x_train: Real[Array, 'p n'],
693
+ usequants: bool,
694
+ numcut: int,
695
+ xinfo: Float[Array, 'p n'] | None,
696
+ ) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
697
+ if xinfo is not None:
698
+ if xinfo.ndim != 2 or xinfo.shape[0] != x_train.shape[0]:
699
+ msg = f'{xinfo.shape=} different from expected ({x_train.shape[0]}, *)'
700
+ raise ValueError(msg)
701
+ return prepcovars.parse_xinfo(xinfo)
702
+ elif usequants:
355
703
  return prepcovars.quantilized_splits_from_matrix(x_train, numcut + 1)
356
704
  else:
357
705
  return prepcovars.uniform_splits_from_matrix(x_train, numcut + 1)
358
706
 
359
707
  @staticmethod
360
- def _bin_predictors(x, splits):
708
+ def _bin_predictors(x, splits) -> UInt[Array, 'p n']:
361
709
  return prepcovars.bin_predictors(x, splits)
362
710
 
363
- @staticmethod
364
- def _transform_input(y, lamda, offset, scale):
365
- y = (y - offset) / scale
366
- lamda = lamda / (scale * scale)
367
- return y, lamda
368
-
369
711
  @staticmethod
370
712
  def _setup_mcmc(
371
- x_train,
372
- y_train,
373
- w,
374
- max_split,
375
- lamda,
376
- sigdf,
377
- power,
378
- base,
379
- maxdepth,
380
- ntree,
381
- initkw,
713
+ x_train: Real[Array, 'p n'],
714
+ y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
715
+ offset: Float32[Array, ''],
716
+ w: Float[Array, ' n'] | None,
717
+ max_split: UInt[Array, ' p'],
718
+ lamda: Float32[Array, ''] | None,
719
+ sigma_mu: FloatLike,
720
+ sigdf: FloatLike,
721
+ power: FloatLike,
722
+ base: FloatLike,
723
+ maxdepth: int,
724
+ ntree: int,
725
+ init_kw: dict[str, Any] | None,
726
+ rm_const: bool | None,
727
+ theta: FloatLike | None,
728
+ a: FloatLike | None,
729
+ b: FloatLike | None,
730
+ rho: FloatLike | None,
382
731
  ):
383
732
  depth = jnp.arange(maxdepth - 1)
384
733
  p_nonterminal = base / (1 + depth).astype(float) ** power
385
- sigma2_alpha = sigdf / 2
386
- sigma2_beta = lamda * sigma2_alpha
734
+
735
+ if y_train.dtype == bool:
736
+ sigma2_alpha = None
737
+ sigma2_beta = None
738
+ else:
739
+ sigma2_alpha = sigdf / 2
740
+ sigma2_beta = lamda * sigma2_alpha
741
+
387
742
  kw = dict(
388
743
  X=x_train,
389
- y=y_train,
744
+ # copy y_train because it's going to be donated in the mcmc loop
745
+ y=jnp.array(y_train),
746
+ offset=offset,
390
747
  error_scale=w,
391
748
  max_split=max_split,
392
749
  num_trees=ntree,
393
750
  p_nonterminal=p_nonterminal,
751
+ sigma_mu2=jnp.square(sigma_mu),
394
752
  sigma2_alpha=sigma2_alpha,
395
753
  sigma2_beta=sigma2_beta,
754
+ min_points_per_decision_node=10,
396
755
  min_points_per_leaf=5,
756
+ theta=theta,
757
+ a=a,
758
+ b=b,
759
+ rho=rho,
397
760
  )
398
- if initkw is not None:
399
- kw.update(initkw)
761
+
762
+ if rm_const is None:
763
+ kw.update(filter_splitless_vars=False)
764
+ elif rm_const:
765
+ kw.update(filter_splitless_vars=True)
766
+ else:
767
+ n_empty = jnp.count_nonzero(max_split == 0)
768
+ if n_empty:
769
+ msg = f'There are {n_empty}/{max_split.size} predictors without decision rules'
770
+ raise ValueError(msg)
771
+ kw.update(filter_splitless_vars=False)
772
+
773
+ if init_kw is not None:
774
+ kw.update(init_kw)
775
+
400
776
  return mcmcstep.init(**kw)
401
777
 
402
778
  @staticmethod
403
- def _run_mcmc(mcmc_state, ndpost, nskip, keepevery, printevery, seed):
779
+ def _run_mcmc(
780
+ mcmc_state: mcmcstep.State,
781
+ ndpost: int,
782
+ nskip: int,
783
+ keepevery: int,
784
+ printevery: int | None,
785
+ seed: int | Integer[Array, ''] | Key[Array, ''],
786
+ run_mcmc_kw: dict | None,
787
+ sparse: bool,
788
+ ):
789
+ # prepare random generator seed
404
790
  if isinstance(seed, jax.Array) and jnp.issubdtype(
405
791
  seed.dtype, jax.dtypes.prng_key
406
792
  ):
407
- key = seed
793
+ key = seed.copy()
794
+ # copy because the inner loop in run_mcmc will donate the buffer
408
795
  else:
409
796
  key = jax.random.key(seed)
410
- callback = mcmcloop.make_simple_print_callback(printevery)
411
- return mcmcloop.run_mcmc(key, mcmc_state, nskip, ndpost, keepevery, callback)
412
797
 
413
- @staticmethod
414
- def _predict(trace, x):
415
- return mcmcloop.evaluate_trace(trace, x)
416
-
417
- @staticmethod
418
- def _transform_output(y, offset, scale):
419
- return offset + scale * y
420
-
421
- @staticmethod
422
- def _extract_sigma(trace, scale):
423
- return scale * jnp.sqrt(trace['sigma2'])
424
-
425
- def _show_tree(self, i_sample, i_tree, print_all=False):
426
- from . import debug
427
-
428
- trace = self._main_trace
429
- leaf_tree = trace['leaf_trees'][i_sample, i_tree]
430
- var_tree = trace['var_trees'][i_sample, i_tree]
431
- split_tree = trace['split_trees'][i_sample, i_tree]
432
- debug.print_tree(leaf_tree, var_tree, split_tree, print_all)
433
-
434
- def _sigma_harmonic_mean(self, prior=False):
435
- bart = self._mcmc_state
436
- if prior:
437
- alpha = bart['sigma2_alpha']
438
- beta = bart['sigma2_beta']
439
- else:
440
- resid = bart['resid']
441
- alpha = bart['sigma2_alpha'] + resid.size / 2
442
- norm2 = jnp.dot(
443
- resid, resid, preferred_element_type=bart['sigma2_beta'].dtype
798
+ # prepare arguments
799
+ kw = dict(n_burn=nskip, n_skip=keepevery, inner_loop_length=printevery)
800
+ kw.update(
801
+ mcmcloop.make_default_callback(
802
+ dot_every=None if printevery is None or printevery == 1 else 1,
803
+ report_every=printevery,
804
+ sparse_on_at=nskip // 2 if sparse else None,
444
805
  )
445
- beta = bart['sigma2_beta'] + norm2 / 2
446
- sigma2 = beta / alpha
447
- return jnp.sqrt(sigma2) * self.scale
448
-
449
- def _compare_resid(self):
450
- bart = self._mcmc_state
451
- resid1 = bart['resid']
452
- yhat = grove.evaluate_forest(
453
- bart['X'],
454
- bart['leaf_trees'],
455
- bart['var_trees'],
456
- bart['split_trees'],
457
- jnp.float32,
458
806
  )
459
- resid2 = bart['y'] - yhat
460
- return resid1, resid2
461
-
462
- def _avg_acc(self):
463
- trace = self._main_trace
464
-
465
- def acc(prefix):
466
- acc = trace[f'{prefix}_acc_count']
467
- prop = trace[f'{prefix}_prop_count']
468
- return acc.sum() / prop.sum()
469
-
470
- return acc('grow'), acc('prune')
471
-
472
- def _avg_prop(self):
473
- trace = self._main_trace
474
-
475
- def prop(prefix):
476
- return trace[f'{prefix}_prop_count'].sum()
477
-
478
- pgrow = prop('grow')
479
- pprune = prop('prune')
480
- total = pgrow + pprune
481
- return pgrow / total, pprune / total
482
-
483
- def _avg_move(self):
484
- agrow, aprune = self._avg_acc()
485
- pgrow, pprune = self._avg_prop()
486
- return agrow * pgrow, aprune * pprune
487
-
488
- def _depth_distr(self):
489
- from . import debug
490
-
491
- trace = self._main_trace
492
- split_trees = trace['split_trees']
493
- return debug.trace_depth_distr(split_trees)
494
-
495
- def _points_per_leaf_distr(self):
496
- from . import debug
497
-
498
- return debug.trace_points_per_leaf_distr(
499
- self._main_trace, self._mcmc_state['X']
500
- )
501
-
502
- def _check_trees(self):
503
- from . import debug
807
+ if run_mcmc_kw is not None:
808
+ kw.update(run_mcmc_kw)
504
809
 
505
- return debug.check_trace(self._main_trace, self._mcmc_state)
810
+ return mcmcloop.run_mcmc(key, mcmc_state, ndpost, **kw)
506
811
 
507
- def _tree_goes_bad(self):
508
- bad = self._check_trees().astype(bool)
509
- bad_before = jnp.pad(bad[:-1], [(1, 0), (0, 0)])
510
- return bad & ~bad_before
812
+ def _predict(self, x):
813
+ return mcmcloop.evaluate_trace(self._main_trace, x)