bartz 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/BART.py CHANGED
@@ -22,25 +22,73 @@
22
22
  # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
23
  # SOFTWARE.
24
24
 
25
- """Implement a user interface that mimics the R BART package."""
25
+ """Implement a class `gbart` that mimics the R BART package."""
26
26
 
27
- import functools
28
27
  import math
29
- from typing import Any, Literal
28
+ from collections.abc import Sequence
29
+ from functools import cached_property
30
+ from typing import Any, Literal, Protocol
30
31
 
31
32
  import jax
32
33
  import jax.numpy as jnp
33
- from jax.scipy.special import ndtri
34
- from jaxtyping import Array, Bool, Float, Float32
35
-
36
- from . import grove, jaxext, mcmcloop, mcmcstep, prepcovars
34
+ from equinox import Module, field
35
+ from jax.scipy.special import ndtr
36
+ from jaxtyping import (
37
+ Array,
38
+ Bool,
39
+ Float,
40
+ Float32,
41
+ Int32,
42
+ Integer,
43
+ Key,
44
+ Real,
45
+ Shaped,
46
+ UInt,
47
+ )
48
+ from numpy import ndarray
49
+
50
+ from bartz import mcmcloop, mcmcstep, prepcovars
51
+ from bartz.jaxext.scipy.special import ndtri
52
+ from bartz.jaxext.scipy.stats import invgamma
37
53
 
38
54
  FloatLike = float | Float[Any, '']
39
55
 
40
56
 
41
- class gbart:
57
+ class DataFrame(Protocol):
58
+ """DataFrame duck-type for `gbart`.
59
+
60
+ Attributes
61
+ ----------
62
+ columns : Sequence[str]
63
+ The names of the columns.
42
64
  """
43
- Nonparametric regression with Bayesian Additive Regression Trees (BART).
65
+
66
+ columns: Sequence[str]
67
+
68
+ def to_numpy(self) -> ndarray:
69
+ """Convert the dataframe to a 2d numpy array with columns on the second axis."""
70
+ ...
71
+
72
+
73
+ class Series(Protocol):
74
+ """Series duck-type for `gbart`.
75
+
76
+ Attributes
77
+ ----------
78
+ name : str | None
79
+ The name of the series.
80
+ """
81
+
82
+ name: str | None
83
+
84
+ def to_numpy(self) -> ndarray:
85
+ """Convert the series to a 1d numpy array."""
86
+ ...
87
+
88
+
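For illustration, the two Protocol classes above are purely structural: any object exposing the listed attributes and a `to_numpy` method is accepted, with no inheritance required. A minimal stand-in (hypothetical, not part of the package; pandas objects satisfy the same protocol):

    import numpy as np

    class TinyFrame:
        """Toy object satisfying the DataFrame protocol above."""

        def __init__(self, data: dict[str, list[float]]):
            self.columns = list(data)  # Sequence[str], as the protocol requires
            self._data = data

        def to_numpy(self) -> np.ndarray:
            # 2d array with columns on the second axis
            return np.column_stack([self._data[c] for c in self.columns])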
89
+ class gbart(Module):
90
+ R"""
91
+ Nonparametric regression with Bayesian Additive Regression Trees (BART) [2]_.
44
92
 
45
93
  Regress `y_train` on `x_train` with a latent mean function represented as
46
94
  a sum of decision trees. The inference is carried out by sampling the
@@ -48,36 +96,79 @@ class gbart:
48
96
 
49
97
  Parameters
50
98
  ----------
51
- x_train : array (p, n) or DataFrame
99
+ x_train
52
100
  The training predictors.
53
- y_train : array (n,) or Series
101
+ y_train
54
102
  The training responses.
55
- x_test : array (p, m) or DataFrame, optional
103
+ x_test
56
104
  The test predictors.
57
105
  type
58
106
  The type of regression. 'wbart' for continuous regression, 'pbart' for
59
107
  binary regression with probit link.
60
- usequants : bool, default False
108
+ sparse
109
+ Whether to activate variable selection on the predictors as done in
110
+ [1]_.
111
+ theta
112
+ a
113
+ b
114
+ rho
115
+ Hyperparameters of the sparsity prior used for variable selection.
116
+
117
+ The prior distribution on the choice of predictor for each decision rule
118
+ is
119
+
120
+ .. math::
121
+ (s_1, \ldots, s_p) \sim
122
+ \operatorname{Dirichlet}(\mathtt{theta}/p, \ldots, \mathtt{theta}/p).
123
+
124
+ If `theta` is not specified, it's a priori distributed according to
125
+
126
+ .. math::
127
+ \frac{\mathtt{theta}}{\mathtt{theta} + \mathtt{rho}} \sim
128
+ \operatorname{Beta}(\mathtt{a}, \mathtt{b}).
129
+
130
+ If not specified, `rho` is set to the number of predictors p. To tune
131
+ the prior, consider setting a lower `rho` to prefer more sparsity.
132
+ If setting `theta` directly, it should be in the ballpark of p or lower
133
+ as well.
134
+ xinfo
135
+ A matrix with the cutpoints to use to bin each predictor. If not
136
+ specified, it is generated automatically according to `usequants` and
137
+ `numcut`.
138
+
139
+ Each row shall contain a sorted list of cutpoints for a predictor. If
140
+ there are fewer cutpoints than the number of columns in the matrix,
141
+ fill the remaining cells with NaN.
142
+
143
+ `xinfo` shall be a matrix even if `x_train` is a dataframe.
144
+ usequants
61
145
  Whether to use predictors quantiles instead of a uniform grid to bin
62
- predictors.
63
- sigest : float, optional
146
+ predictors. Ignored if `xinfo` is specified.
147
+ rm_const
148
+ How to treat predictors with no associated decision rules (i.e., there
149
+ are no available cutpoints for that predictor). If `True` (default),
150
+ they are ignored. If `False`, an error is raised if there are any. If
151
+ `None`, no check is performed, and the output of the MCMC may not make
152
+ sense if there are predictors without cutpoints. The option `None` is
153
+ provided only to allow jax tracing.
154
+ sigest
64
155
  An estimate of the residual standard deviation on `y_train`, used to set
65
156
  `lamda`. If not specified, it is estimated by linear regression (with
66
157
  intercept, and without taking into account `w`). If `y_train` has less
67
158
  than two elements, it is set to 1. If n <= p, it is set to the standard
68
159
  deviation of `y_train`. Ignored if `lamda` is specified.
69
- sigdf : int, default 3
160
+ sigdf
70
161
  The degrees of freedom of the scaled inverse-chisquared prior on the
71
162
  noise variance.
72
- sigquant : float, default 0.9
163
+ sigquant
73
164
  The quantile of the prior on the noise variance that shall match
74
165
  `sigest` to set the scale of the prior. Ignored if `lamda` is specified.
75
- k : float, default 2
166
+ k
76
167
  The inverse scale of the prior standard deviation on the latent mean
77
168
  function, relative to half the observed range of `y_train`. If `y_train`
78
169
  has less than two elements, `k` is ignored and the scale is set to 1.
79
- power : float, default 2
80
- base : float, default 0.95
170
+ power
171
+ base
81
172
  Parameters of the prior on tree node generation. The probability that a
82
173
  node at depth `d` (0-based) is non-terminal is ``base / (1 + d) **
83
174
  power``.
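As an illustration of the `sparse` and `xinfo` options documented above, a sketch of assembling a NaN-padded cutpoint matrix and enabling variable selection; the cutpoint values are invented, only the keyword names come from the parameter list:

    import numpy as np
    from bartz.BART import gbart

    # hypothetical cutpoints: 3 for predictor 0, 1 for predictor 1
    cuts = [[0.25, 0.5, 0.75], [0.0]]
    width = max(len(c) for c in cuts)
    xinfo = np.full((len(cuts), width), np.nan)
    for i, c in enumerate(cuts):
        xinfo[i, : len(c)] = sorted(c)

    # usequants and numcut are ignored when xinfo is given;
    # sparse=True activates the Dirichlet prior on predictor choice
    # fit = gbart(x_train, y_train, xinfo=xinfo, sparse=True)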
@@ -94,16 +185,19 @@ class gbart:
94
185
  The prior mean of the latent mean function. If not specified, it is set
95
186
  to the mean of `y_train` for continuous regression, and to
96
187
  ``Phi^-1(mean(y_train))`` for binary regression. If `y_train` is empty,
97
- `offset` is set to 0.
98
- w : array (n,), optional
188
+ `offset` is set to 0. With binary regression, if `y_train` is all
189
+ `False` or `True`, it is set to ``Phi^-1(1/(n+1))`` or
190
+ ``Phi^-1(n/(n+1))``, respectively.
191
+ w
99
192
  Coefficients that rescale the error standard deviation on each
100
193
  datapoint. Not specifying `w` is equivalent to setting it to 1 for all
101
194
  datapoints. Note: `w` is ignored in the automatic determination of
102
195
  `sigest`, so either the weights should be O(1), or `sigest` should be
103
196
  specified by the user.
104
- ntree : int, default 200
105
- The number of trees used to represent the latent mean function.
106
- numcut : int, default 255
197
+ ntree
198
+ The number of trees used to represent the latent mean function. By
199
+ default 200 for continuous regression and 50 for binary regression.
200
+ numcut
107
201
  If `usequants` is `False`: the exact number of cutpoints used to bin the
108
202
  predictors, ranging between the minimum and maximum observed values
109
203
  (excluded).
@@ -116,14 +210,17 @@ class gbart:
116
210
 
117
211
  Before running the algorithm, the predictors are compressed to the
118
212
  smallest integer type that fits the bin indices, so `numcut` is best set
119
- to the maximum value of an unsigned integer type.
120
- ndpost : int, default 1000
213
+ to the maximum value of an unsigned integer type, like 255.
214
+
215
+ Ignored if `xinfo` is specified.
216
+ ndpost
121
217
  The number of MCMC samples to save, after burn-in.
122
- nskip : int, default 100
218
+ nskip
123
219
  The number of initial MCMC samples to discard as burn-in.
124
- keepevery : int, default 1
125
- The thinning factor for the MCMC samples, after burn-in.
126
- printevery : int or None, default 100
220
+ keepevery
221
+ The thinning factor for the MCMC samples, after burn-in. By default, 1
222
+ for continuous regression and 10 for binary regression.
223
+ printevery
127
224
  The number of iterations (including thinned-away ones) between each log
128
225
  line. Set to `None` to disable logging.
129
226
 
@@ -132,34 +229,24 @@ class gbart:
132
229
  iterations is a multiple of `printevery`, so if ``nskip + keepevery *
133
230
  ndpost`` is not a multiple of `printevery`, some of the last iterations
134
231
  will not be saved.
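A worked example of this bookkeeping with the defaults documented above (numbers purely illustrative):

    nskip, keepevery, ndpost, printevery = 100, 1, 1000, 100
    total = nskip + keepevery * ndpost         # 1100 iterations requested
    kept = (total // printevery) * printevery  # 1100, already a multiple of 100
    # with ndpost=1005 the request would be 1105 iterations, and the
    # last 1105 - 1100 = 5 would not be saved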
135
- seed : int or jax random key, default 0
232
+ seed
136
233
  The seed for the random number generator.
137
- maxdepth : int, default 6
234
+ maxdepth
138
235
  The maximum depth of the trees. This is 1-based, so with the default
139
236
  ``maxdepth=6``, the depths of the levels range from 0 to 5.
140
- init_kw : dict
141
- Additional arguments passed to `mcmcstep.init`.
142
- run_mcmc_kw : dict
143
- Additional arguments passed to `mcmcloop.run_mcmc`.
237
+ init_kw
238
+ Additional arguments passed to `bartz.mcmcstep.init`.
239
+ run_mcmc_kw
240
+ Additional arguments passed to `bartz.mcmcloop.run_mcmc`.
144
241
 
145
242
  Attributes
146
243
  ----------
147
- yhat_train : array (ndpost, n)
148
- The conditional posterior mean at `x_train` for each MCMC iteration.
149
- yhat_train_mean : array (n,)
150
- The marginal posterior mean at `x_train`.
151
- yhat_test : array (ndpost, m)
152
- The conditional posterior mean at `x_test` for each MCMC iteration.
153
- yhat_test_mean : array (m,)
154
- The marginal posterior mean at `x_test`.
155
- sigma : array (ndpost,)
156
- The standard deviation of the error.
157
- first_sigma : array (nskip,)
158
- The standard deviation of the error in the burn-in phase.
159
- offset : float
244
+ offset : Float32[Array, '']
160
245
  The prior mean of the latent mean function.
161
- sigest : float or None
246
+ sigest : Float32[Array, ''] | None
162
247
  The estimated standard deviation of the error used to set `lamda`.
248
+ yhat_test : Float32[Array, 'ndpost m'] | None
249
+ The conditional posterior mean at `x_test` for each MCMC iteration.
163
250
 
164
251
  Notes
165
252
  -----
@@ -168,68 +255,111 @@ class gbart:
168
255
 
169
256
  - If `x_train` and `x_test` are matrices, they have one predictor per row
170
257
  instead of per column.
171
- - If `type` is not specified, it is determined solely based on the data type
172
- of `y_train`, and not on whether it contains only two unique values.
173
258
  - If ``usequants=False``, R BART switches to quantiles anyway if there are
174
259
  less predictor values than the required number of bins, while bartz
175
260
  always follows the specification.
261
+ - Some functionality of the R interface is not implemented.
176
262
  - The error variance parameter is called `lamda` instead of `lambda`.
177
- - `rm_const` is always `False`.
178
- - The default `numcut` is 255 instead of 100.
179
- - A lot of functionality is missing (e.g., variable selection).
180
263
  - There are some additional attributes, and some missing.
181
264
  - The trees have a maximum depth.
265
+ - `rm_const` refers to predictors without decision rules instead of
266
+ predictors that are constant in `x_train`.
267
+ - If `rm_const=True` and some variables are dropped, the predictors
268
+ matrix/dataframe passed to `predict` should still include them.
182
269
 
270
+ References
271
+ ----------
272
+ .. [1] Linero, Antonio R. (2018). “Bayesian Regression Trees for
273
+ High-Dimensional Prediction and Variable Selection”. In: Journal of the
274
+ American Statistical Association 113.522, pp. 626-636.
275
+ .. [2] Chipman, Hugh A., Edward I. George and Robert E. McCulloch (2010).
276
+ “BART: Bayesian additive regression trees”. In: The Annals of Applied
277
+ Statistics 4.1, pp. 266-298.
183
278
  """
184
279
 
280
+ _main_trace: mcmcloop.MainTrace
281
+ _burnin_trace: mcmcloop.BurninTrace
282
+ _mcmc_state: mcmcstep.State
283
+ _splits: Real[Array, 'p max_num_splits']
284
+ _x_train_fmt: Any = field(static=True)
285
+
286
+ ndpost: int = field(static=True)
287
+ offset: Float32[Array, '']
288
+ sigest: Float32[Array, ''] | None = None
289
+ yhat_test: Float32[Array, 'ndpost m'] | None = None
290
+
185
291
  def __init__(
186
292
  self,
187
- x_train,
188
- y_train,
293
+ x_train: Real[Array, 'p n'] | DataFrame,
294
+ y_train: Bool[Array, ' n'] | Float32[Array, ' n'] | Series,
189
295
  *,
190
- x_test=None,
191
- type: Literal['wbart', 'pbart'] = 'wbart',
192
- usequants=False,
193
- sigest=None,
194
- sigdf=3,
195
- sigquant=0.9,
196
- k=2,
197
- power=2,
198
- base=0.95,
296
+ x_test: Real[Array, 'p m'] | DataFrame | None = None,
297
+ type: Literal['wbart', 'pbart'] = 'wbart', # noqa: A002
298
+ sparse: bool = False,
299
+ theta: FloatLike | None = None,
300
+ a: FloatLike = 0.5,
301
+ b: FloatLike = 1.0,
302
+ rho: FloatLike | None = None,
303
+ xinfo: Float[Array, 'p n'] | None = None,
304
+ usequants: bool = False,
305
+ rm_const: bool | None = True,
306
+ sigest: FloatLike | None = None,
307
+ sigdf: FloatLike = 3.0,
308
+ sigquant: FloatLike = 0.9,
309
+ k: FloatLike = 2.0,
310
+ power: FloatLike = 2.0,
311
+ base: FloatLike = 0.95,
199
312
  lamda: FloatLike | None = None,
200
313
  tau_num: FloatLike | None = None,
201
314
  offset: FloatLike | None = None,
202
- w=None,
203
- ntree=200,
204
- numcut=255,
205
- ndpost=1000,
206
- nskip=100,
207
- keepevery=1,
208
- printevery=100,
209
- seed=0,
210
- maxdepth=6,
211
- init_kw=None,
212
- run_mcmc_kw=None,
315
+ w: Float[Array, ' n'] | None = None,
316
+ ntree: int | None = None,
317
+ numcut: int = 100,
318
+ ndpost: int = 1000,
319
+ nskip: int = 100,
320
+ keepevery: int | None = None,
321
+ printevery: int | None = 100,
322
+ seed: int | Key[Array, ''] = 0,
323
+ maxdepth: int = 6,
324
+ init_kw: dict | None = None,
325
+ run_mcmc_kw: dict | None = None,
213
326
  ):
327
+ # check data and put it in the right format
214
328
  x_train, x_train_fmt = self._process_predictor_input(x_train)
215
- y_train, _ = self._process_response_input(y_train)
329
+ y_train = self._process_response_input(y_train)
216
330
  self._check_same_length(x_train, y_train)
217
331
  if w is not None:
218
- w, _ = self._process_response_input(w)
332
+ w = self._process_response_input(w)
219
333
  self._check_same_length(x_train, w)
220
334
 
221
- y_train = self._process_type_settings(y_train, type, w)
335
+ # check data types are correct for continuous/binary regression
336
+ self._check_type_settings(y_train, type, w)
222
337
  # from here onwards, the type is determined by y_train.dtype == bool
338
+
339
+ # set defaults that depend on type of regression
340
+ if ntree is None:
341
+ ntree = 50 if y_train.dtype == bool else 200
342
+ if keepevery is None:
343
+ keepevery = 10 if y_train.dtype == bool else 1
344
+
345
+ # process sparsity settings
346
+ theta, a, b, rho = self._process_sparsity_settings(
347
+ x_train, sparse, theta, a, b, rho
348
+ )
349
+
350
+ # process "standardization" settings
223
351
  offset = self._process_offset_settings(y_train, offset)
224
352
  sigma_mu = self._process_leaf_sdev_settings(y_train, k, ntree, tau_num)
225
353
  lamda, sigest = self._process_error_variance_settings(
226
354
  x_train, y_train, sigest, sigdf, sigquant, lamda
227
355
  )
228
356
 
229
- splits, max_split = self._determine_splits(x_train, usequants, numcut)
357
+ # determine splits
358
+ splits, max_split = self._determine_splits(x_train, usequants, numcut, xinfo)
230
359
  x_train = self._bin_predictors(x_train, splits)
231
360
 
232
- mcmc_state = self._setup_mcmc(
361
+ # setup and run mcmc
362
+ initial_state = self._setup_mcmc(
233
363
  x_train,
234
364
  y_train,
235
365
  offset,
@@ -243,51 +373,163 @@ class gbart:
243
373
  maxdepth,
244
374
  ntree,
245
375
  init_kw,
376
+ rm_const,
377
+ theta,
378
+ a,
379
+ b,
380
+ rho,
246
381
  )
247
382
  final_state, burnin_trace, main_trace = self._run_mcmc(
248
- mcmc_state, ndpost, nskip, keepevery, printevery, seed, run_mcmc_kw
383
+ initial_state,
384
+ ndpost,
385
+ nskip,
386
+ keepevery,
387
+ printevery,
388
+ seed,
389
+ run_mcmc_kw,
390
+ sparse,
249
391
  )
250
392
 
251
- sigma = self._extract_sigma(main_trace)
252
- first_sigma = self._extract_sigma(burnin_trace)
253
-
393
+ # set public attributes
254
394
  self.offset = final_state.offset # from the state because of buffer donation
395
+ self.ndpost = ndpost
255
396
  self.sigest = sigest
256
- self.sigma = sigma
257
- self.first_sigma = first_sigma
258
397
 
259
- self._x_train_fmt = x_train_fmt
260
- self._splits = splits
398
+ # set private attributes
261
399
  self._main_trace = main_trace
400
+ self._burnin_trace = burnin_trace
262
401
  self._mcmc_state = final_state
402
+ self._splits = splits
403
+ self._x_train_fmt = x_train_fmt
263
404
 
405
+ # predict at test points
264
406
  if x_test is not None:
265
- yhat_test = self.predict(x_test)
266
- self.yhat_test = yhat_test
267
- self.yhat_test_mean = yhat_test.mean(axis=0)
407
+ self.yhat_test = self.predict(x_test)
408
+
409
+ @cached_property
410
+ def prob_test(self) -> Float32[Array, 'ndpost m'] | None:
411
+ """The posterior probability of y being True at `x_test` for each MCMC iteration."""
412
+ if self.yhat_test is None or self._mcmc_state.y.dtype != bool:
413
+ return None
414
+ else:
415
+ return ndtr(self.yhat_test)
416
+
417
+ @cached_property
418
+ def prob_test_mean(self) -> Float32[Array, ' m'] | None:
419
+ """The marginal posterior probability of y being True at `x_test`."""
420
+ if self.prob_test is None:
421
+ return None
422
+ else:
423
+ return self.prob_test.mean(axis=0)
424
+
425
+ @cached_property
426
+ def prob_train(self) -> Float32[Array, 'ndpost n'] | None:
427
+ """The posterior probability of y being True at `x_train` for each MCMC iteration."""
428
+ if self._mcmc_state.y.dtype == bool:
429
+ return ndtr(self.yhat_train)
430
+ else:
431
+ return None
432
+
433
+ @cached_property
434
+ def prob_train_mean(self) -> Float32[Array, ' n'] | None:
435
+ """The marginal posterior probability of y being True at `x_train`."""
436
+ if self.prob_train is None:
437
+ return None
438
+ else:
439
+ return self.prob_train.mean(axis=0)
440
+
441
+ @cached_property
442
+ def sigma(self) -> Float32[Array, ' nskip+ndpost'] | None:
443
+ """The standard deviation of the error, including burn-in samples."""
444
+ if self._burnin_trace.sigma2 is None:
445
+ return None
446
+ else:
447
+ assert self._main_trace.sigma2 is not None
448
+ return jnp.sqrt(
449
+ jnp.concatenate([self._burnin_trace.sigma2, self._main_trace.sigma2])
450
+ )
451
+
452
+ @cached_property
453
+ def sigma_mean(self) -> Float32[Array, ''] | None:
454
+ """The mean of `sigma`, only over the post-burnin samples."""
455
+ if self.sigma is None:
456
+ return None
457
+ else:
458
+ return self.sigma[len(self.sigma) - self.ndpost :].mean(axis=0)
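Because `sigma` now includes the burn-in draws, the post-burn-in part is simply the trailing `ndpost` entries, which is what `sigma_mean` averages. A small helper sketch (not part of the class):

    def post_burnin_sigma(fit):
        """Error sd draws excluding burn-in; None for binary regression."""
        if fit.sigma is None:
            return None
        return fit.sigma[-fit.ndpost:]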
268
459
 
269
- @functools.cached_property
270
- def yhat_train(self):
460
+ @cached_property
461
+ def varcount(self) -> Int32[Array, 'ndpost p']:
462
+ """Histogram of predictor usage for decision rules in the trees."""
463
+ return mcmcloop.compute_varcount(
464
+ self._mcmc_state.forest.max_split.size, self._main_trace
465
+ )
466
+
467
+ @cached_property
468
+ def varcount_mean(self) -> Float32[Array, ' p']:
469
+ """Average of `varcount` across MCMC iterations."""
470
+ return self.varcount.mean(axis=0)
471
+
472
+ @cached_property
473
+ def varprob(self) -> Float32[Array, 'ndpost p']:
474
+ """Posterior samples of the probability of choosing each predictor for a decision rule."""
475
+ varprob = self._main_trace.varprob
476
+ if varprob is None:
477
+ max_split = self._mcmc_state.forest.max_split
478
+ p = max_split.size
479
+ peff = jnp.count_nonzero(max_split)
480
+ varprob = jnp.where(max_split, 1 / peff, 0)
481
+ varprob = jnp.broadcast_to(varprob, (self.ndpost, p))
482
+ return varprob
483
+
484
+ @cached_property
485
+ def varprob_mean(self) -> Float32[Array, ' p']:
486
+ """The marginal posterior probability of each predictor being chosen for a decision rule."""
487
+ return self.varprob.mean(axis=0)
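The `varcount_mean` and `varprob_mean` summaries above lend themselves to a quick ranking of predictors; a sketch with a hypothetical helper:

    import numpy as np

    def rank_predictors(fit, names=None):
        """Order predictors by posterior inclusion weight, most used first."""
        weights = np.asarray(fit.varprob_mean)
        if names is None:
            names = [f'x{i}' for i in range(weights.size)]
        order = np.argsort(weights)[::-1]
        return [(names[i], float(weights[i])) for i in order]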
488
+
489
+ @cached_property
490
+ def yhat_test_mean(self) -> Float32[Array, ' m'] | None:
491
+ """The marginal posterior mean at `x_test`.
492
+
493
+ Not defined for binary regression because it is error-prone; typically
494
+ the right quantity to consider is `prob_test_mean`.
495
+ """
496
+ if self.yhat_test is None or self._mcmc_state.y.dtype == bool:
497
+ return None
498
+ else:
499
+ return self.yhat_test.mean(axis=0)
500
+
501
+ @cached_property
502
+ def yhat_train(self) -> Float32[Array, 'ndpost n']:
503
+ """The conditional posterior mean at `x_train` for each MCMC iteration."""
271
504
  x_train = self._mcmc_state.X
272
- return self._predict(self._main_trace, x_train)
505
+ return self._predict(x_train)
273
506
 
274
- @functools.cached_property
275
- def yhat_train_mean(self):
276
- return self.yhat_train.mean(axis=0)
507
+ @cached_property
508
+ def yhat_train_mean(self) -> Float32[Array, ' n'] | None:
509
+ """The marginal posterior mean at `x_train`.
277
510
 
278
- def predict(self, x_test):
511
+ Not defined for binary regression because it is error-prone; typically
512
+ the right quantity to consider is `prob_train_mean`.
513
+ """
514
+ if self._mcmc_state.y.dtype == bool:
515
+ return None
516
+ else:
517
+ return self.yhat_train.mean(axis=0)
518
+
519
+ def predict(
520
+ self, x_test: Real[Array, 'p m'] | DataFrame
521
+ ) -> Float32[Array, 'ndpost m']:
279
522
  """
280
523
  Compute the posterior mean at `x_test` for each MCMC iteration.
281
524
 
282
525
  Parameters
283
526
  ----------
284
- x_test : array (p, m) or DataFrame
527
+ x_test
285
528
  The test predictors.
286
529
 
287
530
  Returns
288
531
  -------
289
- yhat_test : array (ndpost, m)
290
- The conditional posterior mean at `x_test` for each MCMC iteration.
532
+ The conditional posterior mean at `x_test` for each MCMC iteration.
291
533
 
292
534
  Raises
293
535
  ------
@@ -296,14 +538,13 @@ class gbart:
296
538
  """
297
539
  x_test, x_test_fmt = self._process_predictor_input(x_test)
298
540
  if x_test_fmt != self._x_train_fmt:
299
- raise ValueError(
300
- f'Input format mismatch: {x_test_fmt=} != x_train_fmt={self._x_train_fmt!r}'
301
- )
541
+ msg = f'Input format mismatch: {x_test_fmt=} != x_train_fmt={self._x_train_fmt!r}'
542
+ raise ValueError(msg)
302
543
  x_test = self._bin_predictors(x_test, self._splits)
303
- return self._predict(self._main_trace, x_test)
544
+ return self._predict(x_test)
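A short sketch of using `predict` on held-out data; the test predictors must be in the same format as the training ones (matrix or dataframe with the same columns), otherwise the ValueError above is raised. The helper below is illustrative only:

    from jax.scipy.special import ndtr

    def posterior_point_predictions(fit, x_test, binary=False):
        """Average the (ndpost, m) draws; return probabilities if binary."""
        yhat = fit.predict(x_test)
        return ndtr(yhat).mean(axis=0) if binary else yhat.mean(axis=0)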
304
545
 
305
546
  @staticmethod
306
- def _process_predictor_input(x):
547
+ def _process_predictor_input(x) -> tuple[Shaped[Array, 'p n'], Any]:
307
548
  if hasattr(x, 'columns'):
308
549
  fmt = dict(kind='dataframe', columns=x.columns)
309
550
  x = x.to_numpy().T
@@ -314,15 +555,12 @@ class gbart:
314
555
  return x, fmt
315
556
 
316
557
  @staticmethod
317
- def _process_response_input(y):
558
+ def _process_response_input(y) -> Shaped[Array, ' n']:
318
559
  if hasattr(y, 'to_numpy'):
319
- fmt = dict(kind='series', name=y.name)
320
560
  y = y.to_numpy()
321
- else:
322
- fmt = dict(kind='array')
323
561
  y = jnp.asarray(y)
324
562
  assert y.ndim == 1
325
- return y, fmt
563
+ return y
326
564
 
327
565
  @staticmethod
328
566
  def _check_same_length(x1, x2):
@@ -335,13 +573,16 @@ class gbart:
335
573
  ) -> tuple[Float32[Array, ''] | None, ...]:
336
574
  if y_train.dtype == bool:
337
575
  if sigest is not None:
338
- raise ValueError('Let `sigest=None` for binary regression')
576
+ msg = 'Let `sigest=None` for binary regression'
577
+ raise ValueError(msg)
339
578
  if lamda is not None:
340
- raise ValueError('Let `lamda=None` for binary regression')
579
+ msg = 'Let `lamda=None` for binary regression'
580
+ raise ValueError(msg)
341
581
  return None, None
342
582
  elif lamda is not None:
343
583
  if sigest is not None:
344
- raise ValueError('Let `sigest=None` if `lamda` is specified')
584
+ msg = 'Let `sigest=None` if `lamda` is specified'
585
+ raise ValueError(msg)
345
586
  return lamda, None
346
587
  else:
347
588
  if sigest is not None:
@@ -359,37 +600,60 @@ class gbart:
359
600
  dof = len(y_train) - rank
360
601
  sigest2 = chisq / dof
361
602
  alpha = sigdf / 2
362
- invchi2 = jaxext.scipy.stats.invgamma.ppf(sigquant, alpha) / 2
603
+ invchi2 = invgamma.ppf(sigquant, alpha) / 2
363
604
  invchi2rid = invchi2 * sigdf
364
605
  return sigest2 / invchi2rid, jnp.sqrt(sigest2)
365
606
 
366
607
  @staticmethod
367
- def _process_type_settings(y_train, type, w):
608
+ def _check_type_settings(y_train, type, w): # noqa: A002
368
609
  match type:
369
610
  case 'wbart':
370
611
  if y_train.dtype != jnp.float32:
371
- raise TypeError(
612
+ msg = (
372
613
  'Continuous regression requires y_train.dtype=float32,'
373
614
  f' got {y_train.dtype=} instead.'
374
615
  )
616
+ raise TypeError(msg)
375
617
  case 'pbart':
376
618
  if w is not None:
377
- raise ValueError(
378
- 'Binary regression does not support weights, set `w=None`'
379
- )
619
+ msg = 'Binary regression does not support weights, set `w=None`'
620
+ raise ValueError(msg)
380
621
  if y_train.dtype != bool:
381
- raise TypeError(
622
+ msg = (
382
623
  'Binary regression requires y_train.dtype=bool,'
383
624
  f' got {y_train.dtype=} instead.'
384
625
  )
626
+ raise TypeError(msg)
385
627
  case _:
386
- raise ValueError(f'Invalid {type=}')
628
+ msg = f'Invalid {type=}'
629
+ raise ValueError(msg)
387
630
 
388
- return y_train
631
+ @staticmethod
632
+ def _process_sparsity_settings(
633
+ x_train: Real[Array, 'p n'],
634
+ sparse: bool,
635
+ theta: FloatLike | None,
636
+ a: FloatLike,
637
+ b: FloatLike,
638
+ rho: FloatLike | None,
639
+ ) -> (
640
+ tuple[None, None, None, None]
641
+ | tuple[FloatLike, None, None, None]
642
+ | tuple[None, FloatLike, FloatLike, FloatLike]
643
+ ):
644
+ if not sparse:
645
+ return None, None, None, None
646
+ elif theta is not None:
647
+ return theta, None, None, None
648
+ else:
649
+ if rho is None:
650
+ p, _ = x_train.shape
651
+ rho = float(p)
652
+ return None, a, b, rho
389
653
 
390
654
  @staticmethod
391
655
  def _process_offset_settings(
392
- y_train: Float32[Array, 'n'] | Bool[Array, 'n'],
656
+ y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
393
657
  offset: float | Float32[Any, ''] | None,
394
658
  ) -> Float32[Array, '']:
395
659
  if offset is not None:
@@ -400,13 +664,15 @@ class gbart:
400
664
  mean = y_train.mean()
401
665
 
402
666
  if y_train.dtype == bool:
667
+ bound = 1 / (1 + y_train.size)
668
+ mean = jnp.clip(mean, bound, 1 - bound)
403
669
  return ndtri(mean)
404
670
  else:
405
671
  return mean
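A worked example of the clipping above: with n = 10 observations that are all True, the mean is clipped to 1 - 1/11 = 10/11 before the probit inverse, matching the ``Phi^-1(n/(n+1))`` rule in the docstring (numbers illustrative; scipy used here in place of the jax ndtri):

    from scipy.special import ndtri

    n = 10
    bound = 1 / (1 + n)                      # 1/11
    mean = min(max(1.0, bound), 1 - bound)   # an all-True y_train clips to 10/11
    offset = ndtri(mean)                     # about 1.335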
406
672
 
407
673
  @staticmethod
408
674
  def _process_leaf_sdev_settings(
409
- y_train: Float32[Array, 'n'] | Bool[Array, 'n'],
675
+ y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
410
676
  k: float,
411
677
  ntree: int,
412
678
  tau_num: FloatLike | None,
@@ -422,31 +688,46 @@ class gbart:
422
688
  return tau_num / (k * math.sqrt(ntree))
423
689
 
424
690
  @staticmethod
425
- def _determine_splits(x_train, usequants, numcut):
426
- if usequants:
691
+ def _determine_splits(
692
+ x_train: Real[Array, 'p n'],
693
+ usequants: bool,
694
+ numcut: int,
695
+ xinfo: Float[Array, 'p n'] | None,
696
+ ) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
697
+ if xinfo is not None:
698
+ if xinfo.ndim != 2 or xinfo.shape[0] != x_train.shape[0]:
699
+ msg = f'{xinfo.shape=} different from expected ({x_train.shape[0]}, *)'
700
+ raise ValueError(msg)
701
+ return prepcovars.parse_xinfo(xinfo)
702
+ elif usequants:
427
703
  return prepcovars.quantilized_splits_from_matrix(x_train, numcut + 1)
428
704
  else:
429
705
  return prepcovars.uniform_splits_from_matrix(x_train, numcut + 1)
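To illustrate the two automatic grids (the exact outputs are internal to `prepcovars`, so this only mimics the documented behaviour of `numcut`): with values spanning [0, 1] and numcut=3, a uniform grid places cutpoints strictly between the minimum and maximum, while usequants=True uses empirical quantiles instead.

    import numpy as np

    values = np.linspace(0.0, 1.0, 21)
    numcut = 3
    uniform_cuts = np.linspace(values.min(), values.max(), numcut + 2)[1:-1]
    # array([0.25, 0.5, 0.75]); the min and max themselves are excluded
    quantile_cuts = np.quantile(values, np.linspace(0.0, 1.0, numcut + 2)[1:-1])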
430
706
 
431
707
  @staticmethod
432
- def _bin_predictors(x, splits):
708
+ def _bin_predictors(x, splits) -> UInt[Array, 'p n']:
433
709
  return prepcovars.bin_predictors(x, splits)
434
710
 
435
711
  @staticmethod
436
712
  def _setup_mcmc(
437
- x_train,
438
- y_train,
439
- offset,
440
- w,
441
- max_split,
442
- lamda,
443
- sigma_mu,
444
- sigdf,
445
- power,
446
- base,
447
- maxdepth,
448
- ntree,
449
- init_kw,
713
+ x_train: Real[Array, 'p n'],
714
+ y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
715
+ offset: Float32[Array, ''],
716
+ w: Float[Array, ' n'] | None,
717
+ max_split: UInt[Array, ' p'],
718
+ lamda: Float32[Array, ''] | None,
719
+ sigma_mu: FloatLike,
720
+ sigdf: FloatLike,
721
+ power: FloatLike,
722
+ base: FloatLike,
723
+ maxdepth: int,
724
+ ntree: int,
725
+ init_kw: dict[str, Any] | None,
726
+ rm_const: bool | None,
727
+ theta: FloatLike | None,
728
+ a: FloatLike | None,
729
+ b: FloatLike | None,
730
+ rho: FloatLike | None,
450
731
  ):
451
732
  depth = jnp.arange(maxdepth - 1)
452
733
  p_nonterminal = base / (1 + depth).astype(float) ** power
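With the defaults base=0.95 and power=2, the line above yields non-terminal probabilities 0.95, 0.2375, 0.106, ... for depths 0, 1, 2, ...; a quick check in plain Python:

    base, power = 0.95, 2.0
    p_nonterminal = [base / (1 + d) ** power for d in range(5)]
    # [0.95, 0.2375, 0.1055..., 0.059375, 0.038]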
@@ -470,14 +751,42 @@ class gbart:
470
751
  sigma_mu2=jnp.square(sigma_mu),
471
752
  sigma2_alpha=sigma2_alpha,
472
753
  sigma2_beta=sigma2_beta,
754
+ min_points_per_decision_node=10,
473
755
  min_points_per_leaf=5,
756
+ theta=theta,
757
+ a=a,
758
+ b=b,
759
+ rho=rho,
474
760
  )
761
+
762
+ if rm_const is None:
763
+ kw.update(filter_splitless_vars=False)
764
+ elif rm_const:
765
+ kw.update(filter_splitless_vars=True)
766
+ else:
767
+ n_empty = jnp.count_nonzero(max_split == 0)
768
+ if n_empty:
769
+ msg = f'There are {n_empty}/{max_split.size} predictors without decision rules'
770
+ raise ValueError(msg)
771
+ kw.update(filter_splitless_vars=False)
772
+
475
773
  if init_kw is not None:
476
774
  kw.update(init_kw)
775
+
477
776
  return mcmcstep.init(**kw)
478
777
 
479
778
  @staticmethod
480
- def _run_mcmc(mcmc_state, ndpost, nskip, keepevery, printevery, seed, run_mcmc_kw):
779
+ def _run_mcmc(
780
+ mcmc_state: mcmcstep.State,
781
+ ndpost: int,
782
+ nskip: int,
783
+ keepevery: int,
784
+ printevery: int | None,
785
+ seed: int | Integer[Array, ''] | Key[Array, ''],
786
+ run_mcmc_kw: dict | None,
787
+ sparse: bool,
788
+ ):
789
+ # prepare random generator seed
481
790
  if isinstance(seed, jax.Array) and jnp.issubdtype(
482
791
  seed.dtype, jax.dtypes.prng_key
483
792
  ):
@@ -486,118 +795,19 @@ class gbart:
486
795
  else:
487
796
  key = jax.random.key(seed)
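Either form of `seed` documented above works: a plain integer, or an existing jax random key when the caller manages key splitting itself (a sketch; `gbart` and the data are assumed to be in scope):

    import jax

    key = jax.random.key(42)
    key, subkey = jax.random.split(key)
    # fit = gbart(x_train, y_train, seed=subkey)   # equivalent to an int seed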
488
797
 
489
- kw = dict(
490
- n_burn=nskip,
491
- n_skip=keepevery,
492
- inner_loop_length=printevery,
493
- allow_overflow=True,
798
+ # prepare arguments
799
+ kw = dict(n_burn=nskip, n_skip=keepevery, inner_loop_length=printevery)
800
+ kw.update(
801
+ mcmcloop.make_default_callback(
802
+ dot_every=None if printevery is None or printevery == 1 else 1,
803
+ report_every=printevery,
804
+ sparse_on_at=nskip // 2 if sparse else None,
805
+ )
494
806
  )
495
- if printevery is not None:
496
- kw.update(mcmcloop.make_print_callbacks())
497
807
  if run_mcmc_kw is not None:
498
808
  kw.update(run_mcmc_kw)
499
809
 
500
810
  return mcmcloop.run_mcmc(key, mcmc_state, ndpost, **kw)
501
811
 
502
- @staticmethod
503
- def _extract_sigma(trace) -> Float32[Array, 'trace_length'] | None:
504
- if trace['sigma2'] is None:
505
- return None
506
- else:
507
- return jnp.sqrt(trace['sigma2'])
508
-
509
- @staticmethod
510
- def _predict(trace, x):
511
- return mcmcloop.evaluate_trace(trace, x)
512
-
513
- def _show_tree(self, i_sample, i_tree, print_all=False):
514
- from . import debug
515
-
516
- trace = self._main_trace
517
- leaf_tree = trace['leaf_trees'][i_sample, i_tree]
518
- var_tree = trace['var_trees'][i_sample, i_tree]
519
- split_tree = trace['split_trees'][i_sample, i_tree]
520
- debug.print_tree(leaf_tree, var_tree, split_tree, print_all)
521
-
522
- def _sigma_harmonic_mean(self, prior=False):
523
- bart = self._mcmc_state
524
- if prior:
525
- alpha = bart['sigma2_alpha']
526
- beta = bart['sigma2_beta']
527
- else:
528
- resid = bart['resid']
529
- alpha = bart['sigma2_alpha'] + resid.size / 2
530
- norm2 = jnp.dot(
531
- resid, resid, preferred_element_type=bart['sigma2_beta'].dtype
532
- )
533
- beta = bart['sigma2_beta'] + norm2 / 2
534
- sigma2 = beta / alpha
535
- return jnp.sqrt(sigma2)
536
-
537
- def _compare_resid(self):
538
- bart = self._mcmc_state
539
- resid1 = bart.resid
540
-
541
- trees = grove.evaluate_forest(
542
- bart.X,
543
- bart.forest.leaf_trees,
544
- bart.forest.var_trees,
545
- bart.forest.split_trees,
546
- jnp.float32, # TODO remove these configurable dtypes around
547
- )
548
-
549
- if bart.z is not None:
550
- ref = bart.z
551
- else:
552
- ref = bart.y
553
- resid2 = ref - (trees + bart.offset)
554
-
555
- return resid1, resid2
556
-
557
- def _avg_acc(self):
558
- trace = self._main_trace
559
-
560
- def acc(prefix):
561
- acc = trace[f'{prefix}_acc_count']
562
- prop = trace[f'{prefix}_prop_count']
563
- return acc.sum() / prop.sum()
564
-
565
- return acc('grow'), acc('prune')
566
-
567
- def _avg_prop(self):
568
- trace = self._main_trace
569
-
570
- def prop(prefix):
571
- return trace[f'{prefix}_prop_count'].sum()
572
-
573
- pgrow = prop('grow')
574
- pprune = prop('prune')
575
- total = pgrow + pprune
576
- return pgrow / total, pprune / total
577
-
578
- def _avg_move(self):
579
- agrow, aprune = self._avg_acc()
580
- pgrow, pprune = self._avg_prop()
581
- return agrow * pgrow, aprune * pprune
582
-
583
- def _depth_distr(self):
584
- from . import debug
585
-
586
- trace = self._main_trace
587
- split_trees = trace['split_trees']
588
- return debug.trace_depth_distr(split_trees)
589
-
590
- def _points_per_leaf_distr(self):
591
- from . import debug
592
-
593
- return debug.trace_points_per_leaf_distr(self._main_trace, self._mcmc_state.X)
594
-
595
- def _check_trees(self):
596
- from . import debug
597
-
598
- return debug.check_trace(self._main_trace, self._mcmc_state)
599
-
600
- def _tree_goes_bad(self):
601
- bad = self._check_trees().astype(bool)
602
- bad_before = jnp.pad(bad[:-1], [(1, 0), (0, 0)])
603
- return bad & ~bad_before
812
+ def _predict(self, x):
813
+ return mcmcloop.evaluate_trace(self._main_trace, x)