bartz 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/_interface.py ADDED
@@ -0,0 +1,937 @@
1
+ # bartz/src/bartz/_interface.py
2
+ #
3
+ # Copyright (c) 2025-2026, The Bartz Contributors
4
+ #
5
+ # This file is part of bartz.
6
+ #
7
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ # of this software and associated documentation files (the "Software"), to deal
9
+ # in the Software without restriction, including without limitation the rights
10
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ # copies of the Software, and to permit persons to whom the Software is
12
+ # furnished to do so, subject to the following conditions:
13
+ #
14
+ # The above copyright notice and this permission notice shall be included in all
15
+ # copies or substantial portions of the Software.
16
+ #
17
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ # SOFTWARE.
24
+
25
+ """Main high-level interface of the package."""
26
+
27
+ import math
28
+ from collections.abc import Sequence
29
+ from functools import cached_property
30
+ from typing import Any, Literal, Protocol, TypedDict
31
+
32
+ import jax
33
+ import jax.numpy as jnp
34
+ from equinox import Module, field
35
+ from jax import Device, device_put, jit, make_mesh
36
+ from jax.lax import collapse
37
+ from jax.scipy.special import ndtr
38
+ from jax.sharding import AxisType, Mesh
39
+ from jaxtyping import (
40
+ Array,
41
+ Bool,
42
+ Float,
43
+ Float32,
44
+ Int32,
45
+ Integer,
46
+ Key,
47
+ Real,
48
+ Shaped,
49
+ UInt,
50
+ )
51
+ from numpy import ndarray
52
+
53
+ from bartz import mcmcloop, mcmcstep, prepcovars
54
+ from bartz.jaxext import is_key
55
+ from bartz.jaxext.scipy.special import ndtri
56
+ from bartz.jaxext.scipy.stats import invgamma
57
+ from bartz.mcmcloop import compute_varcount, evaluate_trace, run_mcmc
58
+ from bartz.mcmcstep import make_p_nonterminal
59
+ from bartz.mcmcstep._state import get_num_chains
60
+
61
+ FloatLike = float | Float[Any, '']
62
+
63
+
64
+ class DataFrame(Protocol):
65
+ """DataFrame duck-type for `Bart`."""
66
+
67
+ columns: Sequence[str]
68
+ """The names of the columns."""
69
+
70
+ def to_numpy(self) -> ndarray:
71
+ """Convert the dataframe to a 2d numpy array with columns on the second axis."""
72
+ ...
73
+
74
+
75
+ class Series(Protocol):
76
+ """Series duck-type for `Bart`."""
77
+
78
+ name: str | None
79
+ """The name of the series."""
80
+
81
+ def to_numpy(self) -> ndarray:
82
+ """Convert the series to a 1d numpy array."""
83
+ ...
84
+
85
+
86
+ class Bart(Module):
87
+ R"""
88
+ Nonparametric regression with Bayesian Additive Regression Trees (BART) [2]_.
89
+
90
+ Regress `y_train` on `x_train` with a latent mean function represented as
91
+ a sum of decision trees. The inference is carried out by sampling the
92
+ posterior distribution of the tree ensemble with an MCMC.
93
+
94
+ Parameters
95
+ ----------
96
+ x_train
97
+ The training predictors.
98
+ y_train
99
+ The training responses.
100
+ x_test
101
+ The test predictors.
102
+ type
103
+ The type of regression. 'wbart' for continuous regression, 'pbart' for
104
+ binary regression with probit link.
105
+ sparse
106
+ Whether to activate variable selection on the predictors as done in
107
+ [1]_.
108
+ theta
109
+ a
110
+ b
111
+ rho
112
+ Hyperparameters of the sparsity prior used for variable selection.
113
+
114
+ The prior distribution on the choice of predictor for each decision rule
115
+ is
116
+
117
+ .. math::
118
+ (s_1, \ldots, s_p) \sim
119
+ \operatorname{Dirichlet}(\mathtt{theta}/p, \ldots, \mathtt{theta}/p).
120
+
121
+ If `theta` is not specified, it's a priori distributed according to
122
+
123
+ .. math::
124
+ \frac{\mathtt{theta}}{\mathtt{theta} + \mathtt{rho}} \sim
125
+ \operatorname{Beta}(\mathtt{a}, \mathtt{b}).
126
+
127
+ If not specified, `rho` is set to the number of predictors p. To tune
128
+ the prior, consider setting a lower `rho` to prefer more sparsity.
129
+ If setting `theta` directly, it should be in the ballpark of p or lower
130
+ as well.
131
+ xinfo
132
+ A matrix with the cutpoints to use to bin each predictor. If not
133
+ specified, it is generated automatically according to `usequants` and
134
+ `numcut`.
135
+
136
+ Each row shall contain a sorted list of cutpoints for a predictor. If
137
+ there are fewer cutpoints than the number of columns in the matrix,
138
+ fill the remaining cells with NaN.
139
+
140
+ `xinfo` shall be a matrix even if `x_train` is a dataframe.
141
+ usequants
142
+ Whether to use predictor quantiles instead of a uniform grid to bin
143
+ predictors. Ignored if `xinfo` is specified.
144
+ rm_const
145
+ How to treat predictors with no associated decision rules (i.e., there
146
+ are no available cutpoints for that predictor). If `True` (default),
147
+ they are ignored. If `False`, an error is raised if there are any.
148
+ sigest
149
+ An estimate of the residual standard deviation on `y_train`, used to set
150
+ `lamda`. If not specified, it is estimated by linear regression (with
151
+ intercept, and without taking into account `w`). If `y_train` has less
152
+ than two elements, it is set to 1. If n <= p, it is set to the standard
153
+ deviation of `y_train`. Ignored if `lamda` is specified.
154
+ sigdf
155
+ The degrees of freedom of the scaled inverse-chisquared prior on the
156
+ noise variance.
157
+ sigquant
158
+ The quantile of the prior on the noise variance that shall match
159
+ `sigest` to set the scale of the prior. Ignored if `lamda` is specified.
160
+ k
161
+ The inverse scale of the prior standard deviation on the latent mean
162
+ function, relative to half the observed range of `y_train`. If `y_train`
163
+ has fewer than two elements, `k` is ignored and the scale is set to 1.
164
+ power
165
+ base
166
+ Parameters of the prior on tree node generation. The probability that a
167
+ node at depth `d` (0-based) is non-terminal is ``base / (1 + d) **
168
+ power``.
169
+ lamda
170
+ The prior harmonic mean of the error variance. (The harmonic mean of x
171
+ is 1/mean(1/x).) If not specified, it is set based on `sigest` and
172
+ `sigquant`.
173
+ tau_num
174
+ The numerator in the expression that determines the prior standard
175
+ deviation of leaves. If not specified, it defaults to ``(max(y_train) -
176
+ min(y_train)) / 2`` (or 1 if `y_train` has fewer than two elements) for
177
+ continuous regression, and 3 for binary regression.
178
+ offset
179
+ The prior mean of the latent mean function. If not specified, it is set
180
+ to the mean of `y_train` for continuous regression, and to
181
+ ``Phi^-1(mean(y_train))`` for binary regression. If `y_train` is empty,
182
+ `offset` is set to 0. With binary regression, if `y_train` is all
183
+ `False` or `True`, it is set to ``Phi^-1(1/(n+1))`` or
184
+ ``Phi^-1(n/(n+1))``, respectively.
185
+ w
186
+ Coefficients that rescale the error standard deviation on each
187
+ datapoint. Not specifying `w` is equivalent to setting it to 1 for all
188
+ datapoints. Note: `w` is ignored in the automatic determination of
189
+ `sigest`, so either the weights should be O(1), or `sigest` should be
190
+ specified by the user.
191
+ ntree
192
+ The number of trees used to represent the latent mean function. By
193
+ default 200 for continuous regression and 50 for binary regression.
194
+ numcut
195
+ If `usequants` is `False`: the exact number of cutpoints used to bin the
196
+ predictors, ranging between the minimum and maximum observed values
197
+ (excluded).
198
+
199
+ If `usequants` is `True`: the maximum number of cutpoints to use for
200
+ binning the predictors. Each predictor is binned such that its
201
+ distribution in `x_train` is approximately uniform across bins. The
202
+ number of bins is at most ``numcut + 1``, or the number of unique values
203
+ appearing in `x_train`, whichever is smaller.
204
+
205
+ Before running the algorithm, the predictors are compressed to the
206
+ smallest integer type that fits the bin indices, so `numcut` is best set
207
+ to the maximum value of an unsigned integer type, like 255.
208
+
209
+ Ignored if `xinfo` is specified.
210
+ ndpost
211
+ The number of MCMC samples to save, after burn-in. `ndpost` is the
212
+ total number of samples across all chains. `ndpost` is rounded up to the
213
+ next multiple of `num_chains`.
214
+ nskip
215
+ The number of initial MCMC samples to discard as burn-in. This number
216
+ of samples is discarded from each chain.
217
+ keepevery
218
+ The thinning factor for the MCMC samples, after burn-in. By default, 1
219
+ for continuous regression and 10 for binary regression.
220
+ printevery
221
+ The number of iterations (including thinned-away ones) between each log
222
+ line. Set to `None` to disable logging. ^C interrupts the MCMC only
223
+ every `printevery` iterations, so with logging disabled there is no
224
+ convenient way to interrupt the MCMC.
225
+ num_chains
226
+ The number of independent Markov chains to run. By default only one
227
+ chain is run.
228
+
229
+ The difference between not specifying `num_chains` and setting it to 1
230
+ is that, in the latter case, the object attributes and some method
231
+ outputs have an explicit chain axis of size 1.
232
+ num_chain_devices
233
+ The number of devices to spread the chains across. Must be a divisor of
234
+ `num_chains`. Each device will run a fraction of the chains.
235
+ num_data_devices
236
+ The number of devices to split datapoints across. Must be a divisor of
237
+ `n`. This is useful only with very large `n`, roughly above 1_000_000.
238
+
239
+ If both `num_chain_devices` and `num_data_devices` are specified, the total
240
+ number of devices used is the product of the two.
241
+ devices
242
+ One or more devices to run the MCMC on. If not specified, the
243
+ computation will follow the placement of the input arrays. If a list of
244
+ devices is passed, it may be longer than the number of devices needed.
245
+ seed
246
+ The seed for the random number generator.
247
+ maxdepth
248
+ The maximum depth of the trees. This is 1-based, so with the default
249
+ ``maxdepth=6``, the depths of the levels range from 0 to 5.
250
+ init_kw
251
+ Additional arguments passed to `bartz.mcmcstep.init`.
252
+ run_mcmc_kw
253
+ Additional arguments passed to `bartz.mcmcloop.run_mcmc`.
254
+
255
+ References
256
+ ----------
257
+ .. [1] Linero, Antonio R. (2018). “Bayesian Regression Trees for
258
+ High-Dimensional Prediction and Variable Selection”. In: Journal of the
259
+ American Statistical Association 113.522, pp. 626-636.
260
+ .. [2] Chipman, Hugh A., Edward I. George and Robert E. McCulloch (2010).
261
+ “BART: Bayesian Additive Regression Trees”. In: The Annals of Applied
262
+ Statistics 4.1, pp. 266-298.
263
+ """
264
+
265
+ _main_trace: mcmcloop.MainTrace
266
+ _burnin_trace: mcmcloop.BurninTrace
267
+ _mcmc_state: mcmcstep.State
268
+ _splits: Real[Array, 'p max_num_splits']
269
+ _x_train_fmt: Any = field(static=True)
270
+
271
+ offset: Float32[Array, '']
272
+ """The prior mean of the latent mean function."""
273
+
274
+ sigest: Float32[Array, ''] | None = None
275
+ """The estimated standard deviation of the error used to set `lamda`."""
276
+
277
+ yhat_test: Float32[Array, 'ndpost m'] | None = None
278
+ """The conditional posterior mean at `x_test` for each MCMC iteration."""
279
+
280
+ def __init__(
281
+ self,
282
+ x_train: Real[Array, 'p n'] | DataFrame,
283
+ y_train: Bool[Array, ' n'] | Float32[Array, ' n'] | Series,
284
+ *,
285
+ x_test: Real[Array, 'p m'] | DataFrame | None = None,
286
+ type: Literal['wbart', 'pbart'] = 'wbart', # noqa: A002
287
+ sparse: bool = False,
288
+ theta: FloatLike | None = None,
289
+ a: FloatLike = 0.5,
290
+ b: FloatLike = 1.0,
291
+ rho: FloatLike | None = None,
292
+ xinfo: Float[Array, 'p n'] | None = None,
293
+ usequants: bool = False,
294
+ rm_const: bool = True,
295
+ sigest: FloatLike | None = None,
296
+ sigdf: FloatLike = 3.0,
297
+ sigquant: FloatLike = 0.9,
298
+ k: FloatLike = 2.0,
299
+ power: FloatLike = 2.0,
300
+ base: FloatLike = 0.95,
301
+ lamda: FloatLike | None = None,
302
+ tau_num: FloatLike | None = None,
303
+ offset: FloatLike | None = None,
304
+ w: Float[Array, ' n'] | Series | None = None,
305
+ ntree: int | None = None,
306
+ numcut: int = 100,
307
+ ndpost: int = 1000,
308
+ nskip: int = 100,
309
+ keepevery: int | None = None,
310
+ printevery: int | None = 100,
311
+ num_chains: int | None = None,
312
+ num_chain_devices: int | None = None,
313
+ num_data_devices: int | None = None,
314
+ devices: Device | Sequence[Device] | None = None,
315
+ seed: int | Key[Array, ''] = 0,
316
+ maxdepth: int = 6,
317
+ init_kw: dict | None = None,
318
+ run_mcmc_kw: dict | None = None,
319
+ ):
320
+ # check data and put it in the right format
321
+ x_train, x_train_fmt = self._process_predictor_input(x_train)
322
+ y_train = self._process_response_input(y_train)
323
+ self._check_same_length(x_train, y_train)
324
+ if w is not None:
325
+ w = self._process_response_input(w)
326
+ self._check_same_length(x_train, w)
327
+
328
+ # check data types are correct for continuous/binary regression
329
+ self._check_type_settings(y_train, type, w)
330
+ # from here onwards, the type is determined by y_train.dtype == bool
331
+
332
+ # set defaults that depend on type of regression
333
+ if ntree is None:
334
+ ntree = 50 if y_train.dtype == bool else 200
335
+ if keepevery is None:
336
+ keepevery = 10 if y_train.dtype == bool else 1
337
+
338
+ # process sparsity settings
339
+ theta, a, b, rho = self._process_sparsity_settings(
340
+ x_train, sparse, theta, a, b, rho
341
+ )
342
+
343
+ # process "standardization" settings
344
+ offset = self._process_offset_settings(y_train, offset)
345
+ sigma_mu = self._process_leaf_sdev_settings(y_train, k, ntree, tau_num)
346
+ lamda, sigest = self._process_error_variance_settings(
347
+ x_train, y_train, sigest, sigdf, sigquant, lamda
348
+ )
349
+
350
+ # determine splits
351
+ splits, max_split = self._determine_splits(x_train, usequants, numcut, xinfo)
352
+ x_train = self._bin_predictors(x_train, splits)
353
+
354
+ # setup and run mcmc
355
+ initial_state = self._setup_mcmc(
356
+ x_train,
357
+ y_train,
358
+ offset,
359
+ w,
360
+ max_split,
361
+ lamda,
362
+ sigma_mu,
363
+ sigdf,
364
+ power,
365
+ base,
366
+ maxdepth,
367
+ ntree,
368
+ init_kw,
369
+ rm_const,
370
+ theta,
371
+ a,
372
+ b,
373
+ rho,
374
+ num_chains,
375
+ num_chain_devices,
376
+ num_data_devices,
377
+ devices,
378
+ sparse,
379
+ nskip,
380
+ )
381
+ final_state, burnin_trace, main_trace = self._run_mcmc(
382
+ initial_state, ndpost, nskip, keepevery, printevery, seed, run_mcmc_kw
383
+ )
384
+
385
+ # set public attributes
386
+ self.offset = final_state.offset # from the state because of buffer donation
387
+ self.sigest = sigest
388
+
389
+ # set private attributes
390
+ self._main_trace = main_trace
391
+ self._burnin_trace = burnin_trace
392
+ self._mcmc_state = final_state
393
+ self._splits = splits
394
+ self._x_train_fmt = x_train_fmt
395
+
396
+ # predict at test points
397
+ if x_test is not None:
398
+ self.yhat_test = self.predict(x_test)
399
+
400
+ @property
401
+ def ndpost(self):
402
+ """The total number of posterior samples after burn-in across all chains.
403
+
404
+ May be larger than the initialization argument `ndpost` if it was not
405
+ divisible by the number of chains.
406
+ """
407
+ return self._main_trace.grow_prop_count.size
408
+
409
+ @cached_property
410
+ def prob_test(self) -> Float32[Array, 'ndpost m'] | None:
411
+ """The posterior probability of y being True at `x_test` for each MCMC iteration."""
412
+ if self.yhat_test is None or self._mcmc_state.y.dtype != bool:
413
+ return None
414
+ else:
415
+ return ndtr(self.yhat_test)
416
+
417
+ @cached_property
418
+ def prob_test_mean(self) -> Float32[Array, ' m'] | None:
419
+ """The marginal posterior probability of y being True at `x_test`."""
420
+ if self.prob_test is None:
421
+ return None
422
+ else:
423
+ return self.prob_test.mean(axis=0)
424
+
425
+ @cached_property
426
+ def prob_train(self) -> Float32[Array, 'ndpost n'] | None:
427
+ """The posterior probability of y being True at `x_train` for each MCMC iteration."""
428
+ if self._mcmc_state.y.dtype == bool:
429
+ return ndtr(self.yhat_train)
430
+ else:
431
+ return None
432
+
433
+ @cached_property
434
+ def prob_train_mean(self) -> Float32[Array, ' n'] | None:
435
+ """The marginal posterior probability of y being True at `x_train`."""
436
+ if self.prob_train is None:
437
+ return None
438
+ else:
439
+ return self.prob_train.mean(axis=0)
440
+
441
+ @cached_property
442
+ def sigma(
443
+ self,
444
+ ) -> (
445
+ Float32[Array, ' nskip+ndpost']
446
+ | Float32[Array, 'nskip+ndpost/mc_cores mc_cores']
447
+ | None
448
+ ):
449
+ """The standard deviation of the error, including burn-in samples."""
450
+ if self._burnin_trace.error_cov_inv is None:
451
+ return None
452
+ assert self._main_trace.error_cov_inv is not None
453
+ return jnp.sqrt(
454
+ jnp.reciprocal(
455
+ jnp.concatenate(
456
+ [
457
+ self._burnin_trace.error_cov_inv.T,
458
+ self._main_trace.error_cov_inv.T,
459
+ ],
460
+ axis=0,
461
+ # error_cov_inv has shape (chains? samples) in the trace
462
+ )
463
+ )
464
+ )
465
+
466
+ @cached_property
467
+ def sigma_(self) -> Float32[Array, 'ndpost'] | None:
468
+ """The standard deviation of the error, only over the post-burnin samples and flattened."""
469
+ error_cov_inv = self._main_trace.error_cov_inv
470
+ if error_cov_inv is None:
471
+ return None
472
+ else:
473
+ return jnp.sqrt(jnp.reciprocal(error_cov_inv)).reshape(-1)
474
+
475
+ @cached_property
476
+ def sigma_mean(self) -> Float32[Array, ''] | None:
477
+ """The mean of `sigma`, only over the post-burnin samples."""
478
+ if self.sigma_ is None:
479
+ return None
480
+ return self.sigma_.mean()
481
+
482
+ @cached_property
483
+ def varcount(self) -> Int32[Array, 'ndpost p']:
484
+ """Histogram of predictor usage for decision rules in the trees."""
485
+ p = self._mcmc_state.forest.max_split.size
486
+ varcount: Int32[Array, '*chains samples p']
487
+ varcount = compute_varcount(p, self._main_trace)
488
+ return collapse(varcount, 0, -1)
489
+
490
+ @cached_property
491
+ def varcount_mean(self) -> Float32[Array, ' p']:
492
+ """Average of `varcount` across MCMC iterations."""
493
+ return self.varcount.mean(axis=0)
494
+
495
+ @cached_property
496
+ def varprob(self) -> Float32[Array, 'ndpost p']:
497
+ """Posterior samples of the probability of choosing each predictor for a decision rule."""
498
+ max_split = self._mcmc_state.forest.max_split
499
+ p = max_split.size
500
+ varprob = self._main_trace.varprob
501
+ if varprob is None:
502
+ peff = jnp.count_nonzero(max_split)
503
+ varprob = jnp.where(max_split, 1 / peff, 0)
504
+ varprob = jnp.broadcast_to(varprob, (self.ndpost, p))
505
+ else:
506
+ varprob = varprob.reshape(-1, p)
507
+ return varprob
508
+
509
+ @cached_property
510
+ def varprob_mean(self) -> Float32[Array, ' p']:
511
+ """The marginal posterior probability of each predictor being chosen for a decision rule."""
512
+ return self.varprob.mean(axis=0)
513
+
514
+ @cached_property
515
+ def yhat_test_mean(self) -> Float32[Array, ' m'] | None:
516
+ """The marginal posterior mean at `x_test`.
517
+
518
+ Not defined for binary regression because it is error-prone; typically
519
+ the right quantity to consider is `prob_test_mean`.
520
+ """
521
+ if self.yhat_test is None or self._mcmc_state.y.dtype == bool:
522
+ return None
523
+ else:
524
+ return self.yhat_test.mean(axis=0)
525
+
526
+ @cached_property
527
+ def yhat_train(self) -> Float32[Array, 'ndpost n']:
528
+ """The conditional posterior mean at `x_train` for each MCMC iteration."""
529
+ x_train = self._mcmc_state.X
530
+ return self._predict(x_train)
531
+
532
+ @cached_property
533
+ def yhat_train_mean(self) -> Float32[Array, ' n'] | None:
534
+ """The marginal posterior mean at `x_train`.
535
+
536
+ Not defined for binary regression because it is error-prone; typically
537
+ the right quantity to consider is `prob_train_mean`.
538
+ """
539
+ if self._mcmc_state.y.dtype == bool:
540
+ return None
541
+ else:
542
+ return self.yhat_train.mean(axis=0)
543
+
544
+ def predict(
545
+ self, x_test: Real[Array, 'p m'] | DataFrame
546
+ ) -> Float32[Array, 'ndpost m']:
547
+ """
548
+ Compute the posterior mean at `x_test` for each MCMC iteration.
549
+
550
+ Parameters
551
+ ----------
552
+ x_test
553
+ The test predictors.
554
+
555
+ Returns
556
+ -------
557
+ The conditional posterior mean at `x_test` for each MCMC iteration.
558
+
559
+ Raises
560
+ ------
561
+ ValueError
562
+ If `x_test` has a different format than `x_train`.
563
+ """
564
+ x_test, x_test_fmt = self._process_predictor_input(x_test)
565
+ if x_test_fmt != self._x_train_fmt:
566
+ msg = f'Input format mismatch: {x_test_fmt=} != x_train_fmt={self._x_train_fmt!r}'
567
+ raise ValueError(msg)
568
+ x_test = self._bin_predictors(x_test, self._splits)
569
+ return self._predict(x_test)
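A short sketch of the format check performed by `predict` (illustrative; `pandas` is just one provider of the `DataFrame` duck type, and the tiny dataset is made up):

    import jax.numpy as jnp
    import pandas as pd
    from bartz._interface import Bart

    df_train = pd.DataFrame({'age': [1.0, 2.0, 3.0, 4.0], 'dose': [0.1, 0.4, 0.2, 0.3]})
    y = jnp.array([0.5, 0.7, 0.9, 1.1])
    bart = Bart(df_train, y, ndpost=10, nskip=10, printevery=None)

    bart.predict(df_train)  # ok: same input kind and columns as x_train
    try:
        bart.predict(df_train.to_numpy().T)  # plain array: different format
    except ValueError as exc:
        print(exc)  # "Input format mismatch: ..."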
570
+
571
+ @staticmethod
572
+ def _process_predictor_input(x) -> tuple[Shaped[Array, 'p n'], Any]:
573
+ if hasattr(x, 'columns'):
574
+ fmt = dict(kind='dataframe', columns=x.columns)
575
+ x = x.to_numpy().T
576
+ else:
577
+ fmt = dict(kind='array', num_covar=x.shape[0])
578
+ x = jnp.asarray(x)
579
+ assert x.ndim == 2
580
+ return x, fmt
581
+
582
+ @staticmethod
583
+ def _process_response_input(y) -> Shaped[Array, ' n']:
584
+ if hasattr(y, 'to_numpy'):
585
+ y = y.to_numpy()
586
+ y = jnp.asarray(y)
587
+ assert y.ndim == 1
588
+ return y
589
+
590
+ @staticmethod
591
+ def _check_same_length(x1, x2):
592
+ get_length = lambda x: x.shape[-1]
593
+ assert get_length(x1) == get_length(x2)
594
+
595
+ @classmethod
596
+ def _process_error_variance_settings(
597
+ cls, x_train, y_train, sigest, sigdf, sigquant, lamda
598
+ ) -> tuple[Float32[Array, ''] | None, ...]:
599
+ """Return (lamda, sigest)."""
600
+ if y_train.dtype == bool:
601
+ if sigest is not None:
602
+ msg = 'Let `sigest=None` for binary regression'
603
+ raise ValueError(msg)
604
+ if lamda is not None:
605
+ msg = 'Let `lamda=None` for binary regression'
606
+ raise ValueError(msg)
607
+ return None, None
608
+ elif lamda is not None:
609
+ if sigest is not None:
610
+ msg = 'Let `sigest=None` if `lamda` is specified'
611
+ raise ValueError(msg)
612
+ return lamda, None
613
+ else:
614
+ if sigest is not None:
615
+ sigest2 = jnp.square(sigest)
616
+ elif y_train.size < 2:
617
+ sigest2 = 1
618
+ elif y_train.size <= x_train.shape[0]:
619
+ sigest2 = jnp.var(y_train)
620
+ else:
621
+ sigest2 = cls._linear_regression(x_train, y_train)
622
+ alpha = sigdf / 2
623
+ invchi2 = invgamma.ppf(sigquant, alpha) / 2
624
+ invchi2rid = invchi2 * sigdf
625
+ return sigest2 / invchi2rid, jnp.sqrt(sigest2)
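A cross-check of the `lamda` calibration above, done with scipy rather than the package's own `invgamma` (an illustrative sketch with made-up numbers): the returned `lamda` makes the prior ``sigma^2 ~ InvGamma(sigdf/2, lamda*sigdf/2)`` put probability `sigquant` below ``sigest**2``.

    from scipy.stats import invgamma

    sigest2, sigdf, sigquant = 2.5, 3.0, 0.9
    alpha = sigdf / 2
    lamda = sigest2 / (invgamma.ppf(sigquant, alpha) / 2 * sigdf)  # same formula as above
    assert abs(invgamma.cdf(sigest2, alpha, scale=lamda * sigdf / 2) - sigquant) < 1e-9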
626
+
627
+ @staticmethod
628
+ @jit
629
+ def _linear_regression(
630
+ x_train: Shaped[Array, 'p n'], y_train: Float32[Array, ' n']
631
+ ):
632
+ """Return the error variance estimated with OLS with intercept."""
633
+ x_centered = x_train.T - x_train.mean(axis=1)
634
+ y_centered = y_train - y_train.mean()
635
+ # centering is equivalent to adding an intercept column
636
+ _, chisq, rank, _ = jnp.linalg.lstsq(x_centered, y_centered)
637
+ chisq = chisq.squeeze(0)
638
+ dof = len(y_train) - rank
639
+ return chisq / dof
640
+
641
+ @staticmethod
642
+ def _check_type_settings(y_train, type, w): # noqa: A002
643
+ match type:
644
+ case 'wbart':
645
+ if y_train.dtype != jnp.float32:
646
+ msg = (
647
+ 'Continuous regression requires y_train.dtype=float32,'
648
+ f' got {y_train.dtype=} instead.'
649
+ )
650
+ raise TypeError(msg)
651
+ case 'pbart':
652
+ if w is not None:
653
+ msg = 'Binary regression does not support weights, set `w=None`'
654
+ raise ValueError(msg)
655
+ if y_train.dtype != bool:
656
+ msg = (
657
+ 'Binary regression requires y_train.dtype=bool,'
658
+ f' got {y_train.dtype=} instead.'
659
+ )
660
+ raise TypeError(msg)
661
+ case _:
662
+ msg = f'Invalid {type=}'
663
+ raise ValueError(msg)
664
+
665
+ @staticmethod
666
+ def _process_sparsity_settings(
667
+ x_train: Real[Array, 'p n'],
668
+ sparse: bool,
669
+ theta: FloatLike | None,
670
+ a: FloatLike,
671
+ b: FloatLike,
672
+ rho: FloatLike | None,
673
+ ) -> (
674
+ tuple[None, None, None, None]
675
+ | tuple[FloatLike, None, None, None]
676
+ | tuple[None, FloatLike, FloatLike, FloatLike]
677
+ ):
678
+ """Return (theta, a, b, rho)."""
679
+ if not sparse:
680
+ return None, None, None, None
681
+ elif theta is not None:
682
+ return theta, None, None, None
683
+ else:
684
+ if rho is None:
685
+ p, _ = x_train.shape
686
+ rho = float(p)
687
+ return None, a, b, rho
688
+
689
+ @staticmethod
690
+ def _process_offset_settings(
691
+ y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
692
+ offset: float | Float32[Any, ''] | None,
693
+ ) -> Float32[Array, '']:
694
+ """Return offset."""
695
+ if offset is not None:
696
+ return jnp.asarray(offset)
697
+ elif y_train.size < 1:
698
+ return jnp.array(0.0)
699
+ else:
700
+ mean = y_train.mean()
701
+
702
+ if y_train.dtype == bool:
703
+ bound = 1 / (1 + y_train.size)
704
+ mean = jnp.clip(mean, bound, 1 - bound)
705
+ return ndtri(mean)
706
+ else:
707
+ return mean
708
+
709
+ @staticmethod
710
+ def _process_leaf_sdev_settings(
711
+ y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
712
+ k: float,
713
+ ntree: int,
714
+ tau_num: FloatLike | None,
715
+ ):
716
+ """Return sigma_mu."""
717
+ if tau_num is None:
718
+ if y_train.dtype == bool:
719
+ tau_num = 3.0
720
+ elif y_train.size < 2:
721
+ tau_num = 1.0
722
+ else:
723
+ tau_num = (y_train.max() - y_train.min()) / 2
724
+
725
+ return tau_num / (k * math.sqrt(ntree))
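A quick sanity check of the scaling above (made-up numbers): each leaf gets prior standard deviation ``tau_num / (k * sqrt(ntree))``, so the sum of `ntree` independent leaves has standard deviation ``tau_num / k``, i.e. half the observed range of `y_train` divided by `k` in the default continuous setting.

    import math

    tau_num, k, ntree = 4.0, 2.0, 200  # e.g. max(y_train) - min(y_train) == 8
    sigma_mu = tau_num / (k * math.sqrt(ntree))
    assert math.isclose(math.sqrt(ntree) * sigma_mu, tau_num / k)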
726
+
727
+ @staticmethod
728
+ def _determine_splits(
729
+ x_train: Real[Array, 'p n'],
730
+ usequants: bool,
731
+ numcut: int,
732
+ xinfo: Float[Array, 'p n'] | None,
733
+ ) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
734
+ if xinfo is not None:
735
+ if xinfo.ndim != 2 or xinfo.shape[0] != x_train.shape[0]:
736
+ msg = f'{xinfo.shape=} different from expected ({x_train.shape[0]}, *)'
737
+ raise ValueError(msg)
738
+ return prepcovars.parse_xinfo(xinfo)
739
+ elif usequants:
740
+ return prepcovars.quantilized_splits_from_matrix(x_train, numcut + 1)
741
+ else:
742
+ return prepcovars.uniform_splits_from_matrix(x_train, numcut + 1)
743
+
744
+ @staticmethod
745
+ def _bin_predictors(
746
+ x: Real[Array, 'p n'], splits: Real[Array, 'p max_num_splits']
747
+ ) -> UInt[Array, 'p n']:
748
+ return prepcovars.bin_predictors(x, splits)
749
+
750
+ @staticmethod
751
+ def _setup_mcmc(
752
+ x_train: Real[Array, 'p n'],
753
+ y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
754
+ offset: Float32[Array, ''],
755
+ w: Float[Array, ' n'] | None,
756
+ max_split: UInt[Array, ' p'],
757
+ lamda: Float32[Array, ''] | None,
758
+ sigma_mu: FloatLike,
759
+ sigdf: FloatLike,
760
+ power: FloatLike,
761
+ base: FloatLike,
762
+ maxdepth: int,
763
+ ntree: int,
764
+ init_kw: dict[str, Any] | None,
765
+ rm_const: bool,
766
+ theta: FloatLike | None,
767
+ a: FloatLike | None,
768
+ b: FloatLike | None,
769
+ rho: FloatLike | None,
770
+ num_chains: int | None,
771
+ num_chain_devices: int | None,
772
+ num_data_devices: int | None,
773
+ devices: Device | Sequence[Device] | None,
774
+ sparse: bool,
775
+ nskip: int,
776
+ ):
777
+ p_nonterminal = make_p_nonterminal(maxdepth, base, power)
778
+
779
+ if y_train.dtype == bool:
780
+ error_cov_df = None
781
+ error_cov_scale = None
782
+ else:
783
+ assert lamda is not None
784
+ # inverse gamma prior: alpha = df / 2, beta = scale / 2
785
+ error_cov_df = sigdf
786
+ error_cov_scale = lamda * sigdf
787
+
788
+ # process device settings
789
+ device_kw, device = process_device_settings(
790
+ y_train, num_chains, num_chain_devices, num_data_devices, devices
791
+ )
792
+
793
+ kw: dict = dict(
794
+ X=x_train,
795
+ # copy y_train because it's going to be donated in the mcmc loop
796
+ y=jnp.array(y_train),
797
+ offset=offset,
798
+ error_scale=w,
799
+ max_split=max_split,
800
+ num_trees=ntree,
801
+ p_nonterminal=p_nonterminal,
802
+ leaf_prior_cov_inv=jnp.reciprocal(jnp.square(sigma_mu)),
803
+ error_cov_df=error_cov_df,
804
+ error_cov_scale=error_cov_scale,
805
+ min_points_per_decision_node=10,
806
+ min_points_per_leaf=5,
807
+ theta=theta,
808
+ a=a,
809
+ b=b,
810
+ rho=rho,
811
+ sparse_on_at=nskip // 2 if sparse else None,
812
+ **device_kw,
813
+ )
814
+
815
+ if rm_const:
816
+ n_empty = jnp.sum(max_split == 0).item()
817
+ kw.update(filter_splitless_vars=n_empty)
818
+
819
+ if init_kw is not None:
820
+ kw.update(init_kw)
821
+
822
+ state = mcmcstep.init(**kw)
823
+
824
+ # put state on device if requested explicitly by the user
825
+ if device is not None:
826
+ state = device_put(state, device, donate=True)
827
+
828
+ return state
829
+
830
+ @classmethod
831
+ def _run_mcmc(
832
+ cls,
833
+ mcmc_state: mcmcstep.State,
834
+ ndpost: int,
835
+ nskip: int,
836
+ keepevery: int,
837
+ printevery: int | None,
838
+ seed: int | Integer[Array, ''] | Key[Array, ''],
839
+ run_mcmc_kw: dict | None,
840
+ ) -> tuple[mcmcstep.State, mcmcloop.BurninTrace, mcmcloop.MainTrace]:
841
+ # prepare random generator seed
842
+ if is_key(seed):
843
+ key = jnp.copy(seed)
844
+ else:
845
+ key = jax.random.key(seed)
846
+
847
+ # round up ndpost
848
+ num_chains = get_num_chains(mcmc_state)
849
+ if num_chains is None:
850
+ num_chains = 1
851
+ n_save = ndpost // num_chains + bool(ndpost % num_chains)
852
+
853
+ # prepare arguments
854
+ kw: dict = dict(n_burn=nskip, n_skip=keepevery, inner_loop_length=printevery)
855
+ kw.update(
856
+ mcmcloop.make_default_callback(
857
+ mcmc_state,
858
+ dot_every=None if printevery is None or printevery == 1 else 1,
859
+ report_every=printevery,
860
+ )
861
+ )
862
+ if run_mcmc_kw is not None:
863
+ kw.update(run_mcmc_kw)
864
+
865
+ return run_mcmc(key, mcmc_state, n_save, **kw)
866
+
867
+ def _predict(self, x: UInt[Array, 'p m']) -> Float32[Array, 'ndpost m']:
868
+ """Evaluate trees on already quantized `x`."""
869
+ out = evaluate_trace(x, self._main_trace)
870
+ return collapse(out, 0, -1)
871
+
872
+
873
+ class DeviceKwArgs(TypedDict):
874
+ num_chains: int | None
875
+ mesh: Mesh | None
876
+ target_platform: Literal['cpu', 'gpu'] | None
877
+
878
+
879
+ def process_device_settings(
880
+ y_train: Array,
881
+ num_chains: int | None,
882
+ num_chain_devices: int | None,
883
+ num_data_devices: int | None,
884
+ devices: Device | Sequence[Device] | None,
885
+ ) -> tuple[DeviceKwArgs, Device | None]:
886
+ """Return the arguments for `mcmcstep.init` related to devices, and an optional device where to put the state."""
887
+ # determine devices
888
+ if devices is not None:
889
+ if not hasattr(devices, '__len__'):
890
+ devices = (devices,)
891
+ device = devices[0]
892
+ platform = device.platform
893
+ elif hasattr(y_train, 'platform'):
894
+ platform = y_train.platform()
895
+ device = None
896
+ # set device=None because if the devices were not specified explicitly
897
+ # we may be in the case where computation will follow data placement,
898
+ # do not disturb jax as the user may be playing with vmap, jit, reshard...
899
+ devices = jax.devices(platform)
900
+ else:
901
+ msg = 'not possible to infer device from `y_train`, please set `devices`'
902
+ raise ValueError(msg)
903
+
904
+ # create mesh
905
+ if num_chain_devices is None and num_data_devices is None:
906
+ mesh = None
907
+ else:
908
+ mesh = dict()
909
+ if num_chain_devices is not None:
910
+ mesh.update(chains=num_chain_devices)
911
+ if num_data_devices is not None:
912
+ mesh.update(data=num_data_devices)
913
+ mesh = make_mesh(
914
+ axis_shapes=tuple(mesh.values()),
915
+ axis_names=tuple(mesh),
916
+ axis_types=(AxisType.Auto,) * len(mesh),
917
+ devices=devices,
918
+ )
919
+ device = None
920
+ # set device=None because `mcmcstep.init` will `device_put` with the
921
+ # mesh already, we don't want to undo its work
922
+
923
+ # prepare arguments to `init`
924
+ settings = DeviceKwArgs(
925
+ num_chains=num_chains,
926
+ mesh=mesh,
927
+ target_platform=None
928
+ if mesh is not None or hasattr(y_train, 'platform')
929
+ else platform,
930
+ # here we don't take into account the case where the user has set both
931
+ # batch sizes; since the user has to be playing with `init_kw` to do
932
+ # that, we'll let `init` throw the error and the user set
933
+ # `target_platform` themselves so they have a clearer idea how the
934
+ # thing works.
935
+ )
936
+
937
+ return settings, device
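A sketch of how the chain and device arguments handled above might be combined (illustrative; it assumes at least two visible JAX devices and uses made-up data):

    import jax
    import jax.numpy as jnp
    from bartz._interface import Bart

    devs = jax.devices()  # e.g. two GPUs, or CPU devices exposed via XLA flags
    t = jnp.linspace(0.0, 1.0, 100)
    x = jnp.stack([t, jnp.cos(5 * t)])  # shape (p, n) = (2, 100)
    y = jnp.sin(3 * t)

    bart = Bart(
        x, y,
        num_chains=4,          # four independent chains; ndpost is split across them
        num_chain_devices=2,   # two chains per device
        devices=devs[:2],
        ndpost=400, nskip=100, seed=0,
    )
    # with num_chains set, the traces and some outputs carry an explicit chain axis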