bartz 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/.DS_Store CHANGED
File without changes
bartz/BART/__init__.py ADDED
@@ -0,0 +1,27 @@
+ # bartz/src/bartz/BART/__init__.py
+ #
+ # Copyright (c) 2026, The Bartz Contributors
+ #
+ # This file is part of bartz.
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+
+ """Implement classes `mc_gbart` and `gbart` that mimic the R BART3 package."""
+
+ from bartz.BART._gbart import gbart, mc_gbart  # noqa: F401
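
As context for the file above: the new `bartz.BART` subpackage mirrors the R BART3 naming, and the package root re-exports `BART` (see the `bartz/__init__.py` change at the end of this diff). A minimal import sketch, not part of the diff:

    from bartz.BART import gbart, mc_gbart

    # equivalently, through the root package, which re-exports `BART`:
    import bartz
    assert bartz.BART.mc_gbart is mc_gbart
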
bartz/BART/_gbart.py ADDED
@@ -0,0 +1,522 @@
+ # bartz/src/bartz/BART/_gbart.py
+ #
+ # Copyright (c) 2024-2026, The Bartz Contributors
+ #
+ # This file is part of bartz.
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+
+ """Implement classes `mc_gbart` and `gbart` that mimic the R BART3 package."""
+
+ from collections.abc import Mapping
+ from functools import cached_property
+ from os import cpu_count
+ from types import MappingProxyType
+ from typing import Any, Literal
+ from warnings import warn
+
+ from equinox import Module
+ from jax import device_count
+ from jax import numpy as jnp
+ from jaxtyping import Array, Bool, Float, Float32, Int32, Key, Real
+
+ from bartz import mcmcloop, mcmcstep
+ from bartz._interface import Bart, DataFrame, FloatLike, Series
+ from bartz.jaxext import get_default_device
+
+
+ class mc_gbart(Module):
+     R"""
+     Nonparametric regression with Bayesian Additive Regression Trees (BART) [2]_.
+
+     Regress `y_train` on `x_train` with a latent mean function represented as
+     a sum of decision trees. The inference is carried out by sampling the
+     posterior distribution of the tree ensemble with an MCMC.
+
+     Parameters
+     ----------
+     x_train
+         The training predictors.
+     y_train
+         The training responses.
+     x_test
+         The test predictors.
+     type
+         The type of regression. 'wbart' for continuous regression, 'pbart' for
+         binary regression with probit link.
+     sparse
+         Whether to activate variable selection on the predictors as done in
+         [1]_.
+     theta
+     a
+     b
+     rho
+         Hyperparameters of the sparsity prior used for variable selection.
+
+         The prior distribution on the choice of predictor for each decision rule
+         is
+
+         .. math::
+             (s_1, \ldots, s_p) \sim
+             \operatorname{Dirichlet}(\mathtt{theta}/p, \ldots, \mathtt{theta}/p).
+
+         If `theta` is not specified, it's a priori distributed according to
+
+         .. math::
+             \frac{\mathtt{theta}}{\mathtt{theta} + \mathtt{rho}} \sim
+             \operatorname{Beta}(\mathtt{a}, \mathtt{b}).
+
+         If not specified, `rho` is set to the number of predictors p. To tune
+         the prior, consider setting a lower `rho` to prefer more sparsity.
+         If setting `theta` directly, it should be in the ballpark of p or lower
+         as well.
+     xinfo
+         A matrix with the cutpoints to use to bin each predictor. If not
+         specified, it is generated automatically according to `usequants` and
+         `numcut`.
+
+         Each row shall contain a sorted list of cutpoints for a predictor. If
+         there are fewer cutpoints than the number of columns in the matrix,
+         fill the remaining cells with NaN.
+
+         `xinfo` shall be a matrix even if `x_train` is a dataframe.
+     usequants
+         Whether to use predictor quantiles instead of a uniform grid to bin
+         predictors. Ignored if `xinfo` is specified.
+     rm_const
+         How to treat predictors with no associated decision rules (i.e., there
+         are no available cutpoints for that predictor). If `True` (default),
+         they are ignored. If `False`, an error is raised if there are any.
+     sigest
+         An estimate of the residual standard deviation on `y_train`, used to set
+         `lamda`. If not specified, it is estimated by linear regression (with
+         intercept, and without taking into account `w`). If `y_train` has fewer
+         than two elements, it is set to 1. If n <= p, it is set to the standard
+         deviation of `y_train`. Ignored if `lamda` is specified.
+     sigdf
+         The degrees of freedom of the scaled inverse chi-squared prior on the
+         noise variance.
+     sigquant
+         The quantile of the prior on the noise variance that shall match
+         `sigest` to set the scale of the prior. Ignored if `lamda` is specified.
+     k
+         The inverse scale of the prior standard deviation on the latent mean
+         function, relative to half the observed range of `y_train`. If `y_train`
+         has fewer than two elements, `k` is ignored and the scale is set to 1.
+     power
+     base
+         Parameters of the prior on tree node generation. The probability that a
+         node at depth `d` (0-based) is non-terminal is ``base / (1 + d) **
+         power``.
+     lamda
+         The prior harmonic mean of the error variance. (The harmonic mean of x
+         is 1/mean(1/x).) If not specified, it is set based on `sigest` and
+         `sigquant`.
+     tau_num
+         The numerator in the expression that determines the prior standard
+         deviation of leaves. If not specified, defaults to ``(max(y_train) -
+         min(y_train)) / 2`` (or 1 if `y_train` has fewer than two elements) for
+         continuous regression, and 3 for binary regression.
+     offset
+         The prior mean of the latent mean function. If not specified, it is set
+         to the mean of `y_train` for continuous regression, and to
+         ``Phi^-1(mean(y_train))`` for binary regression. If `y_train` is empty,
+         `offset` is set to 0. With binary regression, if `y_train` is all
+         `False` or `True`, it is set to ``Phi^-1(1/(n+1))`` or
+         ``Phi^-1(n/(n+1))``, respectively.
+     w
+         Coefficients that rescale the error standard deviation on each
+         datapoint. Not specifying `w` is equivalent to setting it to 1 for all
+         datapoints. Note: `w` is ignored in the automatic determination of
+         `sigest`, so either the weights should be O(1), or `sigest` should be
+         specified by the user.
+     ntree
+         The number of trees used to represent the latent mean function. By
+         default 200 for continuous regression and 50 for binary regression.
+     numcut
+         If `usequants` is `False`: the exact number of cutpoints used to bin the
+         predictors, ranging between the minimum and maximum observed values
+         (excluded).
+
+         If `usequants` is `True`: the maximum number of cutpoints to use for
+         binning the predictors. Each predictor is binned such that its
+         distribution in `x_train` is approximately uniform across bins. The
+         number of bins is at most the number of unique values appearing in
+         `x_train`, or ``numcut + 1``.
+
+         Before running the algorithm, the predictors are compressed to the
+         smallest integer type that fits the bin indices, so `numcut` is best set
+         to the maximum value of an unsigned integer type, like 255.
+
+         Ignored if `xinfo` is specified.
+     ndpost
+         The number of MCMC samples to save, after burn-in. `ndpost` is the
+         total number of samples across all chains. `ndpost` is rounded up to the
+         next multiple of `mc_cores`.
+     nskip
+         The number of initial MCMC samples to discard as burn-in. This number
+         of samples is discarded from each chain.
+     keepevery
+         The thinning factor for the MCMC samples, after burn-in. By default, 1
+         for continuous regression and 10 for binary regression.
+     printevery
+         The number of iterations (including thinned-away ones) between each log
+         line. Set to `None` to disable logging. ^C interrupts the MCMC only
+         every `printevery` iterations, so with logging disabled it's impossible
+         to kill the MCMC conveniently.
+     mc_cores
+         The number of independent MCMC chains.
+     seed
+         The seed for the random number generator.
+     bart_kwargs
+         Additional arguments passed to `bartz.Bart`.
+
+     Notes
+     -----
+     This interface imitates the function ``mc_gbart`` from the R package `BART3
+     <https://github.com/rsparapa/bnptools>`_, but with these differences:
+
+     - If `x_train` and `x_test` are matrices, they have one predictor per row
+       instead of per column.
+     - If ``usequants=False``, R BART3 switches to quantiles anyway if there are
+       fewer predictor values than the required number of bins, while bartz
+       always follows the specification.
+     - Some functionality is missing.
+     - The error variance parameter is called `lamda` instead of `lambda`.
+     - There are some additional attributes, and some missing.
+     - The trees have a maximum depth of 6.
+     - `rm_const` refers to predictors without decision rules instead of
+       predictors that are constant in `x_train`.
+     - If `rm_const=True` and some variables are dropped, the predictors
+       matrix/dataframe passed to `predict` should still include them.
+
+     References
+     ----------
+     .. [1] Linero, Antonio R. (2018). "Bayesian Regression Trees for
+        High-Dimensional Prediction and Variable Selection". In: Journal of the
+        American Statistical Association 113.522, pp. 626-636.
+     .. [2] Chipman, Hugh A., Edward I. George and Robert E. McCulloch (2010).
+        "BART: Bayesian Additive Regression Trees". In: The Annals of Applied
+        Statistics 4.1, pp. 266-298.
+     """
+
+     _bart: Bart
+
+     def __init__(
+         self,
+         x_train: Real[Array, 'p n'] | DataFrame,
+         y_train: Bool[Array, ' n'] | Float32[Array, ' n'] | Series,
+         *,
+         x_test: Real[Array, 'p m'] | DataFrame | None = None,
+         type: Literal['wbart', 'pbart'] = 'wbart',  # noqa: A002
+         sparse: bool = False,
+         theta: FloatLike | None = None,
+         a: FloatLike = 0.5,
+         b: FloatLike = 1.0,
+         rho: FloatLike | None = None,
+         xinfo: Float[Array, 'p n'] | None = None,
+         usequants: bool = False,
+         rm_const: bool = True,
+         sigest: FloatLike | None = None,
+         sigdf: FloatLike = 3.0,
+         sigquant: FloatLike = 0.9,
+         k: FloatLike = 2.0,
+         power: FloatLike = 2.0,
+         base: FloatLike = 0.95,
+         lamda: FloatLike | None = None,
+         tau_num: FloatLike | None = None,
+         offset: FloatLike | None = None,
+         w: Float[Array, ' n'] | None = None,
+         ntree: int | None = None,
+         numcut: int = 100,
+         ndpost: int = 1000,
+         nskip: int = 100,
+         keepevery: int | None = None,
+         printevery: int | None = 100,
+         mc_cores: int = 2,
+         seed: int | Key[Array, ''] = 0,
+         bart_kwargs: Mapping = MappingProxyType({}),
+     ):
+         kwargs: dict = dict(
+             x_train=x_train,
+             y_train=y_train,
+             x_test=x_test,
+             type=type,
+             sparse=sparse,
+             theta=theta,
+             a=a,
+             b=b,
+             rho=rho,
+             xinfo=xinfo,
+             usequants=usequants,
+             rm_const=rm_const,
+             sigest=sigest,
+             sigdf=sigdf,
+             sigquant=sigquant,
+             k=k,
+             power=power,
+             base=base,
+             lamda=lamda,
+             tau_num=tau_num,
+             offset=offset,
+             w=w,
+             ntree=ntree,
+             numcut=numcut,
+             ndpost=ndpost,
+             nskip=nskip,
+             keepevery=keepevery,
+             printevery=printevery,
+             seed=seed,
+             maxdepth=6,
+             **process_mc_cores(y_train, mc_cores),
+         )
+         kwargs.update(bart_kwargs)
+         self._bart = Bart(**kwargs)
+
+     # Public attributes from Bart
+
+     @property
+     def ndpost(self) -> int:
+         """The number of MCMC samples saved, after burn-in."""
+         return self._bart.ndpost
+
+     @property
+     def offset(self) -> Float32[Array, '']:
+         """The prior mean of the latent mean function."""
+         return self._bart.offset
+
+     @property
+     def sigest(self) -> Float32[Array, ''] | None:
+         """The estimated standard deviation of the error used to set `lamda`."""
+         return self._bart.sigest
+
+     @property
+     def yhat_test(self) -> Float32[Array, 'ndpost m'] | None:
+         """The conditional posterior mean at `x_test` for each MCMC iteration."""
+         return self._bart.yhat_test
+
+     # Private attributes from Bart
+
+     @property
+     def _main_trace(self) -> mcmcloop.MainTrace:
+         return self._bart._main_trace  # noqa: SLF001
+
+     @property
+     def _burnin_trace(self) -> mcmcloop.BurninTrace:
+         return self._bart._burnin_trace  # noqa: SLF001
+
+     @property
+     def _mcmc_state(self) -> mcmcstep.State:
+         return self._bart._mcmc_state  # noqa: SLF001
+
+     @property
+     def _splits(self) -> Real[Array, 'p max_num_splits']:
+         return self._bart._splits  # noqa: SLF001
+
+     @property
+     def _x_train_fmt(self) -> Any:
+         return self._bart._x_train_fmt  # noqa: SLF001
+
+     # Cached properties from Bart
+
+     @cached_property
+     def prob_test(self) -> Float32[Array, 'ndpost m'] | None:
+         """The posterior probability of y being True at `x_test` for each MCMC iteration."""
+         return self._bart.prob_test
+
+     @cached_property
+     def prob_test_mean(self) -> Float32[Array, ' m'] | None:
+         """The marginal posterior probability of y being True at `x_test`."""
+         return self._bart.prob_test_mean
+
+     @cached_property
+     def prob_train(self) -> Float32[Array, 'ndpost n'] | None:
+         """The posterior probability of y being True at `x_train` for each MCMC iteration."""
+         return self._bart.prob_train
+
+     @cached_property
+     def prob_train_mean(self) -> Float32[Array, ' n'] | None:
+         """The marginal posterior probability of y being True at `x_train`."""
+         return self._bart.prob_train_mean
+
+     @cached_property
+     def sigma(
+         self,
+     ) -> (
+         Float32[Array, ' nskip+ndpost']
+         | Float32[Array, 'nskip+ndpost/mc_cores mc_cores']
+         | None
+     ):
+         """The standard deviation of the error, including burn-in samples."""
+         return self._bart.sigma
+
+     @cached_property
+     def sigma_(self) -> Float32[Array, 'ndpost'] | None:
+         """The standard deviation of the error, only over the post-burnin samples and flattened."""
+         return self._bart.sigma_
+
+     @cached_property
+     def sigma_mean(self) -> Float32[Array, ''] | None:
+         """The mean of `sigma`, only over the post-burnin samples."""
+         return self._bart.sigma_mean
+
+     @cached_property
+     def varcount(self) -> Int32[Array, 'ndpost p']:
+         """Histogram of predictor usage for decision rules in the trees."""
+         return self._bart.varcount
+
+     @cached_property
+     def varcount_mean(self) -> Float32[Array, ' p']:
+         """Average of `varcount` across MCMC iterations."""
+         return self._bart.varcount_mean
+
+     @cached_property
+     def varprob(self) -> Float32[Array, 'ndpost p']:
+         """Posterior samples of the probability of choosing each predictor for a decision rule."""
+         return self._bart.varprob
+
+     @cached_property
+     def varprob_mean(self) -> Float32[Array, ' p']:
+         """The marginal posterior probability of each predictor being chosen for a decision rule."""
+         return self._bart.varprob_mean
+
+     @cached_property
+     def yhat_test_mean(self) -> Float32[Array, ' m'] | None:
+         """The marginal posterior mean at `x_test`.
+
+         Not defined with binary regression because it's error-prone; typically
+         the right thing to consider is `prob_test_mean`.
+         """
+         return self._bart.yhat_test_mean
+
+     @cached_property
+     def yhat_train(self) -> Float32[Array, 'ndpost n']:
+         """The conditional posterior mean at `x_train` for each MCMC iteration."""
+         return self._bart.yhat_train
+
+     @cached_property
+     def yhat_train_mean(self) -> Float32[Array, ' n'] | None:
+         """The marginal posterior mean at `x_train`.
+
+         Not defined with binary regression because it's error-prone; typically
+         the right thing to consider is `prob_train_mean`.
+         """
+         return self._bart.yhat_train_mean
+
+     # Public methods from Bart
+
+     def predict(
+         self, x_test: Real[Array, 'p m'] | DataFrame
+     ) -> Float32[Array, 'ndpost m']:
+         """
+         Compute the posterior mean at `x_test` for each MCMC iteration.
+
+         Parameters
+         ----------
+         x_test
+             The test predictors.
+
+         Returns
+         -------
+         The conditional posterior mean at `x_test` for each MCMC iteration.
+         """
+         return self._bart.predict(x_test)
+
+
+ class gbart(mc_gbart):
+     """Subclass of `mc_gbart` that forces `mc_cores=1`."""
+
+     def __init__(self, *args, **kwargs):
+         if 'mc_cores' in kwargs:
+             msg = "gbart.__init__() got an unexpected keyword argument 'mc_cores'"
+             raise TypeError(msg)
+         kwargs.update(mc_cores=1)
+         super().__init__(*args, **kwargs)
+
+
+ def process_mc_cores(y_train: Array | Any, mc_cores: int) -> dict[str, Any]:
+     """Determine the arguments to pass to `Bart` to configure multiple chains."""
+     # one chain: leave the default configuration, which is num_chains=None
+     if abs(mc_cores) == 1:
+         return {}
+
+     # determine whether we are on cpu; this call may raise an exception
+     platform = get_platform(y_train, mc_cores)
+
+     # set the num_chains argument
+     mc_cores = abs(mc_cores)
+     kwargs = dict(num_chains=mc_cores)
+
+     # if on cpu, try to shard the chains across multiple virtual cpus
+     if platform == 'cpu':
+         # determine number of logical cpu cores
+         num_cores = cpu_count()
+         assert num_cores is not None, 'could not determine number of cpu cores'
+
+         # find the largest number of shards that evenly divides the chains
+         for num_shards in range(num_cores, 0, -1):
+             if mc_cores % num_shards == 0:
+                 break
+
+         # handle the case where there are fewer jax cpu devices than that
+         if num_shards > 1:
+             num_jax_cpus = device_count('cpu')
+             if num_jax_cpus < num_shards:
+                 for new_num_shards in range(num_jax_cpus, 0, -1):
+                     if mc_cores % new_num_shards == 0:
+                         break
+                 msg = (
+                     f'`mc_gbart` would like to shard {mc_cores} chains across '
+                     f'{num_shards} virtual jax cpu devices, but jax is set up '
+                     f'with only {num_jax_cpus} cpu devices, so it will use '
+                     f'{new_num_shards} devices instead. To enable '
+                     'parallelization, please increase the limit with '
+                     '`jax.config.update("jax_num_cpu_devices", <num_devices>)`.'
+                 )
+                 warn(msg)
+                 num_shards = new_num_shards
+
+         # set the number of shards
+         if num_shards > 1:
+             kwargs.update(num_chain_devices=num_shards)
+
+     return kwargs
+
+
+ def get_platform(y_train: Array | Any, mc_cores: int) -> str:
+     """Get the platform for `process_mc_cores` from `y_train` or the default device."""
+     if isinstance(y_train, Array) and hasattr(y_train, 'platform'):
+         return y_train.platform()
+     elif (
+         not isinstance(y_train, Array) and hasattr(jnp.zeros(()), 'platform')
+         # this condition means: y_train is not an array, but we are not under
+         # jit, so y_train is going to be converted to an array on the default
+         # device
+     ) or mc_cores < 0:
+         return get_default_device().platform
+     else:
+         msg = (
+             'Could not determine the platform from `y_train`, maybe `mc_gbart` '
+             'was used within a `jax.jit`ted function? The platform is needed to '
+             'determine whether the computation is going to run on CPU to '
+             'automatically shard the chains across multiple virtual CPU '
+             'devices. To acknowledge this problem and circumvent it '
+             'by using the current default jax device, negate `mc_cores`.'
+         )
+         raise RuntimeError(msg)
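
To illustrate the new interface, here is a minimal usage sketch (not part of the diff; the synthetic data and the device count are illustrative, and the config knob is the one named in the warning emitted by `process_mc_cores`):

    import jax

    # raise the virtual cpu device count before any jax computation, so that
    # process_mc_cores can shard 4 chains across 4 devices without warning
    jax.config.update('jax_num_cpu_devices', 4)

    from jax import numpy as jnp, random

    from bartz.BART import mc_gbart

    # toy dataset: p=3 predictors per row, n=100 observations per column
    key = random.key(0)
    X = random.uniform(random.fold_in(key, 0), (3, 100))
    y = X[0] + jnp.sin(3 * X[1]) + 0.1 * random.normal(random.fold_in(key, 1), (100,))

    fit = mc_gbart(X, y, ndpost=1000, nskip=100, mc_cores=4, seed=0)
    print(fit.sigma_mean)   # posterior mean of the error standard deviation
    yhat = fit.predict(X)   # (ndpost, n) array of posterior means

Note that passing a negative `mc_cores` (e.g., ``mc_cores=-4``) makes `get_platform` skip platform detection and fall back to the default jax device, which is the escape hatch its error message documents for use under `jax.jit`.
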
bartz/__init__.py CHANGED
@@ -1,6 +1,6 @@
  # bartz/src/bartz/__init__.py
  #
- # Copyright (c) 2024-2025, Giacomo Petrillo
+ # Copyright (c) 2024-2025, The Bartz Contributors
  #
  # This file is part of bartz.
  #
@@ -25,8 +25,10 @@
  """
  Super-fast BART (Bayesian Additive Regression Trees) in Python.

- See the manual at https://gattocrucco.github.io/bartz/docs
+ See the manual at https://bartz-org.github.io/bartz/docs
  """

  from bartz import BART, grove, jaxext, mcmcloop, mcmcstep, prepcovars  # noqa: F401
+ from bartz._interface import Bart  # noqa: F401
+ from bartz._profiler import profile_mode  # noqa: F401
  from bartz._version import __version__  # noqa: F401
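
After this change, the new names are importable from the package root; a minimal sketch of what 0.8.0 makes available (not part of the diff):

    # `Bart` is the lower-level interface that `mc_gbart` wraps (cf. the
    # `bart_kwargs` parameter above); `profile_mode` is re-exported from the
    # new private `bartz._profiler` module.
    from bartz import Bart, profile_mode
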