bartz 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/BART.py DELETED
@@ -1,603 +0,0 @@
- # bartz/src/bartz/BART.py
- #
- # Copyright (c) 2024-2025, Giacomo Petrillo
- #
- # This file is part of bartz.
- #
- # Permission is hereby granted, free of charge, to any person obtaining a copy
- # of this software and associated documentation files (the "Software"), to deal
- # in the Software without restriction, including without limitation the rights
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- # copies of the Software, and to permit persons to whom the Software is
- # furnished to do so, subject to the following conditions:
- #
- # The above copyright notice and this permission notice shall be included in all
- # copies or substantial portions of the Software.
- #
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- # SOFTWARE.
-
- """Implement a user interface that mimics the R BART package."""
-
- import functools
- import math
- from typing import Any, Literal
-
- import jax
- import jax.numpy as jnp
- from jax.scipy.special import ndtri
- from jaxtyping import Array, Bool, Float, Float32
-
- from . import grove, jaxext, mcmcloop, mcmcstep, prepcovars
-
- FloatLike = float | Float[Any, '']
-
-
- class gbart:
-     """
-     Nonparametric regression with Bayesian Additive Regression Trees (BART).
-
-     Regress `y_train` on `x_train` with a latent mean function represented as
-     a sum of decision trees. The inference is carried out by sampling the
-     posterior distribution of the tree ensemble with an MCMC.
-
-     Parameters
-     ----------
-     x_train : array (p, n) or DataFrame
-         The training predictors.
-     y_train : array (n,) or Series
-         The training responses.
-     x_test : array (p, m) or DataFrame, optional
-         The test predictors.
-     type
-         The type of regression. 'wbart' for continuous regression, 'pbart' for
-         binary regression with probit link.
-     usequants : bool, default False
-         Whether to use predictors quantiles instead of a uniform grid to bin
-         predictors.
-     sigest : float, optional
-         An estimate of the residual standard deviation on `y_train`, used to set
-         `lamda`. If not specified, it is estimated by linear regression (with
-         intercept, and without taking into account `w`). If `y_train` has less
-         than two elements, it is set to 1. If n <= p, it is set to the standard
-         deviation of `y_train`. Ignored if `lamda` is specified.
-     sigdf : int, default 3
-         The degrees of freedom of the scaled inverse-chisquared prior on the
-         noise variance.
-     sigquant : float, default 0.9
-         The quantile of the prior on the noise variance that shall match
-         `sigest` to set the scale of the prior. Ignored if `lamda` is specified.
-     k : float, default 2
-         The inverse scale of the prior standard deviation on the latent mean
-         function, relative to half the observed range of `y_train`. If `y_train`
-         has less than two elements, `k` is ignored and the scale is set to 1.
-     power : float, default 2
-     base : float, default 0.95
-         Parameters of the prior on tree node generation. The probability that a
-         node at depth `d` (0-based) is non-terminal is ``base / (1 + d) **
-         power``.
-     lamda
-         The prior harmonic mean of the error variance. (The harmonic mean of x
-         is 1/mean(1/x).) If not specified, it is set based on `sigest` and
-         `sigquant`.
-     tau_num
-         The numerator in the expression that determines the prior standard
-         deviation of leaves. If not specified, default to ``(max(y_train) -
-         min(y_train)) / 2`` (or 1 if `y_train` has less than two elements) for
-         continuous regression, and 3 for binary regression.
-     offset
-         The prior mean of the latent mean function. If not specified, it is set
-         to the mean of `y_train` for continuous regression, and to
-         ``Phi^-1(mean(y_train))`` for binary regression. If `y_train` is empty,
-         `offset` is set to 0.
-     w : array (n,), optional
-         Coefficients that rescale the error standard deviation on each
-         datapoint. Not specifying `w` is equivalent to setting it to 1 for all
-         datapoints. Note: `w` is ignored in the automatic determination of
-         `sigest`, so either the weights should be O(1), or `sigest` should be
-         specified by the user.
-     ntree : int, default 200
-         The number of trees used to represent the latent mean function.
-     numcut : int, default 255
-         If `usequants` is `False`: the exact number of cutpoints used to bin the
-         predictors, ranging between the minimum and maximum observed values
-         (excluded).
-
-         If `usequants` is `True`: the maximum number of cutpoints to use for
-         binning the predictors. Each predictor is binned such that its
-         distribution in `x_train` is approximately uniform across bins. The
-         number of bins is at most the number of unique values appearing in
-         `x_train`, or ``numcut + 1``.
-
-         Before running the algorithm, the predictors are compressed to the
-         smallest integer type that fits the bin indices, so `numcut` is best set
-         to the maximum value of an unsigned integer type.
-     ndpost : int, default 1000
-         The number of MCMC samples to save, after burn-in.
-     nskip : int, default 100
-         The number of initial MCMC samples to discard as burn-in.
-     keepevery : int, default 1
-         The thinning factor for the MCMC samples, after burn-in.
-     printevery : int or None, default 100
-         The number of iterations (including thinned-away ones) between each log
-         line. Set to `None` to disable logging.
-
-         `printevery` has a few unexpected side effects. On cpu, interrupting
-         with ^C halts the MCMC only on the next log. And the total number of
-         iterations is a multiple of `printevery`, so if ``nskip + keepevery *
-         ndpost`` is not a multiple of `printevery`, some of the last iterations
-         will not be saved.
-     seed : int or jax random key, default 0
-         The seed for the random number generator.
-     maxdepth : int, default 6
-         The maximum depth of the trees. This is 1-based, so with the default
-         ``maxdepth=6``, the depths of the levels range from 0 to 5.
-     init_kw : dict
-         Additional arguments passed to `mcmcstep.init`.
-     run_mcmc_kw : dict
-         Additional arguments passed to `mcmcloop.run_mcmc`.
-
-     Attributes
-     ----------
-     yhat_train : array (ndpost, n)
-         The conditional posterior mean at `x_train` for each MCMC iteration.
-     yhat_train_mean : array (n,)
-         The marginal posterior mean at `x_train`.
-     yhat_test : array (ndpost, m)
-         The conditional posterior mean at `x_test` for each MCMC iteration.
-     yhat_test_mean : array (m,)
-         The marginal posterior mean at `x_test`.
-     sigma : array (ndpost,)
-         The standard deviation of the error.
-     first_sigma : array (nskip,)
-         The standard deviation of the error in the burn-in phase.
-     offset : float
-         The prior mean of the latent mean function.
-     sigest : float or None
-         The estimated standard deviation of the error used to set `lamda`.
-
-     Notes
-     -----
-     This interface imitates the function ``gbart`` from the R package `BART
-     <https://cran.r-project.org/package=BART>`_, but with these differences:
-
-     - If `x_train` and `x_test` are matrices, they have one predictor per row
-       instead of per column.
-     - If `type` is not specified, it is determined solely based on the data type
-       of `y_train`, and not on whether it contains only two unique values.
-     - If ``usequants=False``, R BART switches to quantiles anyway if there are
-       less predictor values than the required number of bins, while bartz
-       always follows the specification.
-     - The error variance parameter is called `lamda` instead of `lambda`.
-     - `rm_const` is always `False`.
-     - The default `numcut` is 255 instead of 100.
-     - A lot of functionality is missing (e.g., variable selection).
-     - There are some additional attributes, and some missing.
-     - The trees have a maximum depth.
-
-     """
-
-     def __init__(
-         self,
-         x_train,
-         y_train,
-         *,
-         x_test=None,
-         type: Literal['wbart', 'pbart'] = 'wbart',
-         usequants=False,
-         sigest=None,
-         sigdf=3,
-         sigquant=0.9,
-         k=2,
-         power=2,
-         base=0.95,
-         lamda: FloatLike | None = None,
-         tau_num: FloatLike | None = None,
-         offset: FloatLike | None = None,
-         w=None,
-         ntree=200,
-         numcut=255,
-         ndpost=1000,
-         nskip=100,
-         keepevery=1,
-         printevery=100,
-         seed=0,
-         maxdepth=6,
-         init_kw=None,
-         run_mcmc_kw=None,
-     ):
-         x_train, x_train_fmt = self._process_predictor_input(x_train)
-         y_train, _ = self._process_response_input(y_train)
-         self._check_same_length(x_train, y_train)
-         if w is not None:
-             w, _ = self._process_response_input(w)
-             self._check_same_length(x_train, w)
-
-         y_train = self._process_type_settings(y_train, type, w)
-         # from here onwards, the type is determined by y_train.dtype == bool
-         offset = self._process_offset_settings(y_train, offset)
-         sigma_mu = self._process_leaf_sdev_settings(y_train, k, ntree, tau_num)
-         lamda, sigest = self._process_error_variance_settings(
-             x_train, y_train, sigest, sigdf, sigquant, lamda
-         )
-
-         splits, max_split = self._determine_splits(x_train, usequants, numcut)
-         x_train = self._bin_predictors(x_train, splits)
-
-         mcmc_state = self._setup_mcmc(
-             x_train,
-             y_train,
-             offset,
-             w,
-             max_split,
-             lamda,
-             sigma_mu,
-             sigdf,
-             power,
-             base,
-             maxdepth,
-             ntree,
-             init_kw,
-         )
-         final_state, burnin_trace, main_trace = self._run_mcmc(
-             mcmc_state, ndpost, nskip, keepevery, printevery, seed, run_mcmc_kw
-         )
-
-         sigma = self._extract_sigma(main_trace)
-         first_sigma = self._extract_sigma(burnin_trace)
-
-         self.offset = final_state.offset  # from the state because of buffer donation
-         self.sigest = sigest
-         self.sigma = sigma
-         self.first_sigma = first_sigma
-
-         self._x_train_fmt = x_train_fmt
-         self._splits = splits
-         self._main_trace = main_trace
-         self._mcmc_state = final_state
-
-         if x_test is not None:
-             yhat_test = self.predict(x_test)
-             self.yhat_test = yhat_test
-             self.yhat_test_mean = yhat_test.mean(axis=0)
-
-     @functools.cached_property
-     def yhat_train(self):
-         x_train = self._mcmc_state.X
-         return self._predict(self._main_trace, x_train)
-
-     @functools.cached_property
-     def yhat_train_mean(self):
-         return self.yhat_train.mean(axis=0)
-
-     def predict(self, x_test):
-         """
-         Compute the posterior mean at `x_test` for each MCMC iteration.
-
-         Parameters
-         ----------
-         x_test : array (p, m) or DataFrame
-             The test predictors.
-
-         Returns
-         -------
-         yhat_test : array (ndpost, m)
-             The conditional posterior mean at `x_test` for each MCMC iteration.
-
-         Raises
-         ------
-         ValueError
-             If `x_test` has a different format than `x_train`.
-         """
-         x_test, x_test_fmt = self._process_predictor_input(x_test)
-         if x_test_fmt != self._x_train_fmt:
-             raise ValueError(
-                 f'Input format mismatch: {x_test_fmt=} != x_train_fmt={self._x_train_fmt!r}'
-             )
-         x_test = self._bin_predictors(x_test, self._splits)
-         return self._predict(self._main_trace, x_test)
-
-     @staticmethod
-     def _process_predictor_input(x):
-         if hasattr(x, 'columns'):
-             fmt = dict(kind='dataframe', columns=x.columns)
-             x = x.to_numpy().T
-         else:
-             fmt = dict(kind='array', num_covar=x.shape[0])
-         x = jnp.asarray(x)
-         assert x.ndim == 2
-         return x, fmt
-
-     @staticmethod
-     def _process_response_input(y):
-         if hasattr(y, 'to_numpy'):
-             fmt = dict(kind='series', name=y.name)
-             y = y.to_numpy()
-         else:
-             fmt = dict(kind='array')
-         y = jnp.asarray(y)
-         assert y.ndim == 1
-         return y, fmt
-
-     @staticmethod
-     def _check_same_length(x1, x2):
-         get_length = lambda x: x.shape[-1]
-         assert get_length(x1) == get_length(x2)
-
-     @staticmethod
-     def _process_error_variance_settings(
-         x_train, y_train, sigest, sigdf, sigquant, lamda
-     ) -> tuple[Float32[Array, ''] | None, ...]:
-         if y_train.dtype == bool:
-             if sigest is not None:
-                 raise ValueError('Let `sigest=None` for binary regression')
-             if lamda is not None:
-                 raise ValueError('Let `lamda=None` for binary regression')
-             return None, None
-         elif lamda is not None:
-             if sigest is not None:
-                 raise ValueError('Let `sigest=None` if `lamda` is specified')
-             return lamda, None
-         else:
-             if sigest is not None:
-                 sigest2 = jnp.square(sigest)
-             elif y_train.size < 2:
-                 sigest2 = 1
-             elif y_train.size <= x_train.shape[0]:
-                 sigest2 = jnp.var(y_train)
-             else:
-                 x_centered = x_train.T - x_train.mean(axis=1)
-                 y_centered = y_train - y_train.mean()
-                 # centering is equivalent to adding an intercept column
-                 _, chisq, rank, _ = jnp.linalg.lstsq(x_centered, y_centered)
-                 chisq = chisq.squeeze(0)
-                 dof = len(y_train) - rank
-                 sigest2 = chisq / dof
-             alpha = sigdf / 2
-             invchi2 = jaxext.scipy.stats.invgamma.ppf(sigquant, alpha) / 2
-             invchi2rid = invchi2 * sigdf
-             return sigest2 / invchi2rid, jnp.sqrt(sigest2)
-
-     @staticmethod
-     def _process_type_settings(y_train, type, w):
-         match type:
-             case 'wbart':
-                 if y_train.dtype != jnp.float32:
-                     raise TypeError(
-                         'Continuous regression requires y_train.dtype=float32,'
-                         f' got {y_train.dtype=} instead.'
-                     )
-             case 'pbart':
-                 if w is not None:
-                     raise ValueError(
-                         'Binary regression does not support weights, set `w=None`'
-                     )
-                 if y_train.dtype != bool:
-                     raise TypeError(
-                         'Binary regression requires y_train.dtype=bool,'
-                         f' got {y_train.dtype=} instead.'
-                     )
-             case _:
-                 raise ValueError(f'Invalid {type=}')
-
-         return y_train
-
-     @staticmethod
-     def _process_offset_settings(
-         y_train: Float32[Array, 'n'] | Bool[Array, 'n'],
-         offset: float | Float32[Any, ''] | None,
-     ) -> Float32[Array, '']:
-         if offset is not None:
-             return jnp.asarray(offset)
-         elif y_train.size < 1:
-             return jnp.array(0.0)
-         else:
-             mean = y_train.mean()
-
-         if y_train.dtype == bool:
-             return ndtri(mean)
-         else:
-             return mean
-
-     @staticmethod
-     def _process_leaf_sdev_settings(
-         y_train: Float32[Array, 'n'] | Bool[Array, 'n'],
-         k: float,
-         ntree: int,
-         tau_num: FloatLike | None,
-     ):
-         if tau_num is None:
-             if y_train.dtype == bool:
-                 tau_num = 3.0
-             elif y_train.size < 2:
-                 tau_num = 1.0
-             else:
-                 tau_num = (y_train.max() - y_train.min()) / 2
-
-         return tau_num / (k * math.sqrt(ntree))
-
-     @staticmethod
-     def _determine_splits(x_train, usequants, numcut):
-         if usequants:
-             return prepcovars.quantilized_splits_from_matrix(x_train, numcut + 1)
-         else:
-             return prepcovars.uniform_splits_from_matrix(x_train, numcut + 1)
-
-     @staticmethod
-     def _bin_predictors(x, splits):
-         return prepcovars.bin_predictors(x, splits)
-
-     @staticmethod
-     def _setup_mcmc(
-         x_train,
-         y_train,
-         offset,
-         w,
-         max_split,
-         lamda,
-         sigma_mu,
-         sigdf,
-         power,
-         base,
-         maxdepth,
-         ntree,
-         init_kw,
-     ):
-         depth = jnp.arange(maxdepth - 1)
-         p_nonterminal = base / (1 + depth).astype(float) ** power
-
-         if y_train.dtype == bool:
-             sigma2_alpha = None
-             sigma2_beta = None
-         else:
-             sigma2_alpha = sigdf / 2
-             sigma2_beta = lamda * sigma2_alpha
-
-         kw = dict(
-             X=x_train,
-             # copy y_train because it's going to be donated in the mcmc loop
-             y=jnp.array(y_train),
-             offset=offset,
-             error_scale=w,
-             max_split=max_split,
-             num_trees=ntree,
-             p_nonterminal=p_nonterminal,
-             sigma_mu2=jnp.square(sigma_mu),
-             sigma2_alpha=sigma2_alpha,
-             sigma2_beta=sigma2_beta,
-             min_points_per_leaf=5,
-         )
-         if init_kw is not None:
-             kw.update(init_kw)
-         return mcmcstep.init(**kw)
-
-     @staticmethod
-     def _run_mcmc(mcmc_state, ndpost, nskip, keepevery, printevery, seed, run_mcmc_kw):
-         if isinstance(seed, jax.Array) and jnp.issubdtype(
-             seed.dtype, jax.dtypes.prng_key
-         ):
-             key = seed.copy()
-             # copy because the inner loop in run_mcmc will donate the buffer
-         else:
-             key = jax.random.key(seed)
-
-         kw = dict(
-             n_burn=nskip,
-             n_skip=keepevery,
-             inner_loop_length=printevery,
-             allow_overflow=True,
-         )
-         if printevery is not None:
-             kw.update(mcmcloop.make_print_callbacks())
-         if run_mcmc_kw is not None:
-             kw.update(run_mcmc_kw)
-
-         return mcmcloop.run_mcmc(key, mcmc_state, ndpost, **kw)
-
-     @staticmethod
-     def _extract_sigma(trace) -> Float32[Array, 'trace_length'] | None:
-         if trace['sigma2'] is None:
-             return None
-         else:
-             return jnp.sqrt(trace['sigma2'])
-
-     @staticmethod
-     def _predict(trace, x):
-         return mcmcloop.evaluate_trace(trace, x)
-
-     def _show_tree(self, i_sample, i_tree, print_all=False):
-         from . import debug
-
-         trace = self._main_trace
-         leaf_tree = trace['leaf_trees'][i_sample, i_tree]
-         var_tree = trace['var_trees'][i_sample, i_tree]
-         split_tree = trace['split_trees'][i_sample, i_tree]
-         debug.print_tree(leaf_tree, var_tree, split_tree, print_all)
-
-     def _sigma_harmonic_mean(self, prior=False):
-         bart = self._mcmc_state
-         if prior:
-             alpha = bart['sigma2_alpha']
-             beta = bart['sigma2_beta']
-         else:
-             resid = bart['resid']
-             alpha = bart['sigma2_alpha'] + resid.size / 2
-             norm2 = jnp.dot(
-                 resid, resid, preferred_element_type=bart['sigma2_beta'].dtype
-             )
-             beta = bart['sigma2_beta'] + norm2 / 2
-         sigma2 = beta / alpha
-         return jnp.sqrt(sigma2)
-
-     def _compare_resid(self):
-         bart = self._mcmc_state
-         resid1 = bart.resid
-
-         trees = grove.evaluate_forest(
-             bart.X,
-             bart.forest.leaf_trees,
-             bart.forest.var_trees,
-             bart.forest.split_trees,
-             jnp.float32,  # TODO remove these configurable dtypes around
-         )
-
-         if bart.z is not None:
-             ref = bart.z
-         else:
-             ref = bart.y
-         resid2 = ref - (trees + bart.offset)
-
-         return resid1, resid2
-
-     def _avg_acc(self):
-         trace = self._main_trace
-
-         def acc(prefix):
-             acc = trace[f'{prefix}_acc_count']
-             prop = trace[f'{prefix}_prop_count']
-             return acc.sum() / prop.sum()
-
-         return acc('grow'), acc('prune')
-
-     def _avg_prop(self):
-         trace = self._main_trace
-
-         def prop(prefix):
-             return trace[f'{prefix}_prop_count'].sum()
-
-         pgrow = prop('grow')
-         pprune = prop('prune')
-         total = pgrow + pprune
-         return pgrow / total, pprune / total
-
-     def _avg_move(self):
-         agrow, aprune = self._avg_acc()
-         pgrow, pprune = self._avg_prop()
-         return agrow * pgrow, aprune * pprune
-
-     def _depth_distr(self):
-         from . import debug
-
-         trace = self._main_trace
-         split_trees = trace['split_trees']
-         return debug.trace_depth_distr(split_trees)
-
-     def _points_per_leaf_distr(self):
-         from . import debug
-
-         return debug.trace_points_per_leaf_distr(self._main_trace, self._mcmc_state.X)
-
-     def _check_trees(self):
-         from . import debug
-
-         return debug.check_trace(self._main_trace, self._mcmc_state)
-
-     def _tree_goes_bad(self):
-         bad = self._check_trees().astype(bool)
-         bad_before = jnp.pad(bad[:-1], [(1, 0), (0, 0)])
-         return bad & ~bad_before
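
For reference, here is a minimal usage sketch of the removed `gbart` interface, reconstructed from its docstring above. The import path `bartz.BART` corresponds to the deleted module as it existed in 0.6.0; the toy data and the argument values (`ndpost`, `nskip`) are illustrative assumptions, not part of the diff.

import jax
from bartz.BART import gbart  # module present in 0.6.0, removed by 0.8.0

# toy dataset: p=2 predictors, n=100 training points, m=50 test points
kx, ky, kt = jax.random.split(jax.random.key(0), 3)
x_train = jax.random.uniform(kx, (2, 100))  # (p, n): one predictor per row
y_train = x_train[0] - x_train[1] + 0.1 * jax.random.normal(ky, (100,))  # float32 -> 'wbart'
x_test = jax.random.uniform(kt, (2, 50))  # (p, m)

fit = gbart(x_train, y_train, x_test=x_test, ndpost=200, nskip=100, seed=0)
print(fit.yhat_test_mean.shape)  # (50,): marginal posterior mean at x_test
print(fit.sigma.shape)  # (200,): posterior draws of the error standard deviation

Per the Notes section of the docstring, predictors are laid out one per row rather than one per column, which is why the toy arrays above are shaped (p, n) and (p, m).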