bartz 0.4.1__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/.DS_Store ADDED
Binary file
bartz/BART.py CHANGED
@@ -1,6 +1,6 @@
 # bartz/src/bartz/BART.py
 #
-# Copyright (c) 2024, Giacomo Petrillo
+# Copyright (c) 2024-2025, Giacomo Petrillo
 #
 # This file is part of bartz.
 #
@@ -22,16 +22,21 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
+"""Implement a user interface that mimics the R BART package."""
+
 import functools
+import math
+from typing import Any, Literal
 
 import jax
 import jax.numpy as jnp
+from jax.scipy.special import ndtri
+from jaxtyping import Array, Bool, Float, Float32
+
+from . import grove, jaxext, mcmcloop, mcmcstep, prepcovars
+
+FloatLike = float | Float[Any, '']
 
-from . import jaxext
-from . import grove
-from . import mcmcstep
-from . import mcmcloop
-from . import prepcovars
 
 class gbart:
     """
@@ -49,14 +54,18 @@ class gbart:
         The training responses.
     x_test : array (p, m) or DataFrame, optional
         The test predictors.
+    type
+        The type of regression. 'wbart' for continuous regression, 'pbart' for
+        binary regression with probit link.
     usequants : bool, default False
         Whether to use predictors quantiles instead of a uniform grid to bin
         predictors.
     sigest : float, optional
-        An estimate of the residual standard deviation on `y_train`, used to
-        set `lamda`. If not specified, it is estimated by linear regression.
-        If `y_train` has less than two elements, it is set to 1. If n <= p, it
-        is set to the variance of `y_train`. Ignored if `lamda` is specified.
+        An estimate of the residual standard deviation on `y_train`, used to set
+        `lamda`. If not specified, it is estimated by linear regression (with
+        intercept, and without taking into account `w`). If `y_train` has less
+        than two elements, it is set to 1. If n <= p, it is set to the standard
+        deviation of `y_train`. Ignored if `lamda` is specified.
     sigdf : int, default 3
         The degrees of freedom of the scaled inverse-chisquared prior on the
         noise variance.
@@ -72,16 +81,26 @@ class gbart:
         Parameters of the prior on tree node generation. The probability that a
         node at depth `d` (0-based) is non-terminal is ``base / (1 + d) **
         power``.
-    maxdepth : int, default 6
-        The maximum depth of the trees. This is 1-based, so with the default
-        ``maxdepth=6``, the depths of the levels range from 0 to 5.
-    lamda : float, optional
-        The scale of the prior on the noise variance. If ``lamda==1``, the
-        prior is an inverse chi-squared scaled to have harmonic mean 1. If
-        not specified, it is set based on `sigest` and `sigquant`.
-    offset : float, optional
+    lamda
+        The prior harmonic mean of the error variance. (The harmonic mean of x
+        is 1/mean(1/x).) If not specified, it is set based on `sigest` and
+        `sigquant`.
+    tau_num
+        The numerator in the expression that determines the prior standard
+        deviation of leaves. If not specified, default to ``(max(y_train) -
+        min(y_train)) / 2`` (or 1 if `y_train` has less than two elements) for
+        continuous regression, and 3 for binary regression.
+    offset
         The prior mean of the latent mean function. If not specified, it is set
-        to the mean of `y_train`. If `y_train` is empty, it is set to 0.
+        to the mean of `y_train` for continuous regression, and to
+        ``Phi^-1(mean(y_train))`` for binary regression. If `y_train` is empty,
+        `offset` is set to 0.
+    w : array (n,), optional
+        Coefficients that rescale the error standard deviation on each
+        datapoint. Not specifying `w` is equivalent to setting it to 1 for all
+        datapoints. Note: `w` is ignored in the automatic determination of
+        `sigest`, so either the weights should be O(1), or `sigest` should be
+        specified by the user.
     ntree : int, default 200
         The number of trees used to represent the latent mean function.
     numcut : int, default 255
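Two of the quantities documented in this hunk can be checked by hand: the non-terminal probability ``base / (1 + d) ** power`` and the probit-scale `offset` for binary regression. A small sketch with illustrative values (not taken from bartz itself):

```python
import jax.numpy as jnp
from jax.scipy.special import ndtri

# depth prior: with the defaults base=0.95, power=2, the probability that a
# node at depth d is non-terminal decays as 0.95, 0.95/4, 0.95/9, ...
depth = jnp.arange(4)
print(0.95 / (1 + depth).astype(float) ** 2)  # [0.95 0.2375 0.1056 0.0594]

# binary-regression offset: Phi^-1 of the observed frequency of True
y_train = jnp.array([True, True, True, False])  # mean = 0.75
print(ndtri(y_train.mean()))  # ~0.674, the normal quantile of 0.75
```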
@@ -104,10 +123,24 @@ class gbart:
         The number of initial MCMC samples to discard as burn-in.
     keepevery : int, default 1
         The thinning factor for the MCMC samples, after burn-in.
-    printevery : int, default 100
-        The number of iterations (including skipped ones) between each log.
+    printevery : int or None, default 100
+        The number of iterations (including thinned-away ones) between each log
+        line. Set to `None` to disable logging.
+
+        `printevery` has a few unexpected side effects. On cpu, interrupting
+        with ^C halts the MCMC only on the next log. And the total number of
+        iterations is a multiple of `printevery`, so if ``nskip + keepevery *
+        ndpost`` is not a multiple of `printevery`, some of the last iterations
+        will not be saved.
     seed : int or jax random key, default 0
         The seed for the random number generator.
+    maxdepth : int, default 6
+        The maximum depth of the trees. This is 1-based, so with the default
+        ``maxdepth=6``, the depths of the levels range from 0 to 5.
+    init_kw : dict
+        Additional arguments passed to `mcmcstep.init`.
+    run_mcmc_kw : dict
+        Additional arguments passed to `mcmcloop.run_mcmc`.
 
     Attributes
     ----------
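The docstring's warning about `printevery` is just modular arithmetic: the sampler advances in blocks of `printevery` iterations (it becomes `inner_loop_length` later in this diff), so the total iteration count is rounded to a multiple of it. A sketch of that bookkeeping (the round-up direction is an assumption consistent with the docstring and the `allow_overflow=True` default that appears below):

```python
import math

nskip, keepevery, ndpost, printevery = 1000, 5, 999, 100
needed = nskip + keepevery * ndpost                  # 5995 iterations requested
total = math.ceil(needed / printevery) * printevery  # 6000 actually run
print(total - needed)  # 5 trailing iterations whose draws are not saved
```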
@@ -125,22 +158,8 @@ class gbart:
         The standard deviation of the error in the burn-in phase.
     offset : float
         The prior mean of the latent mean function.
-    scale : float
-        The prior standard deviation of the latent mean function.
-    lamda : float
-        The prior harmonic mean of the error variance.
     sigest : float or None
         The estimated standard deviation of the error used to set `lamda`.
-    ntree : int
-        The number of trees.
-    maxdepth : int
-        The maximum depth of the trees.
-    initkw : dict
-        Additional arguments passed to `mcmcstep.init`.
-
-    Methods
-    -------
-    predict
 
     Notes
     -----
@@ -149,20 +168,27 @@ class gbart:
 
     - If `x_train` and `x_test` are matrices, they have one predictor per row
       instead of per column.
+    - If `type` is not specified, it is determined solely based on the data type
+      of `y_train`, and not on whether it contains only two unique values.
     - If ``usequants=False``, R BART switches to quantiles anyway if there are
       less predictor values than the required number of bins, while bartz
       always follows the specification.
     - The error variance parameter is called `lamda` instead of `lambda`.
     - `rm_const` is always `False`.
     - The default `numcut` is 255 instead of 100.
-    - A lot of functionality is missing (variable selection, discrete response).
+    - A lot of functionality is missing (e.g., variable selection).
     - There are some additional attributes, and some missing.
+    - The trees have a maximum depth.
 
-    - The linear regression used to set `sigest` adds an intercept.
     """
 
-    def __init__(self, x_train, y_train, *,
+    def __init__(
+        self,
+        x_train,
+        y_train,
+        *,
         x_test=None,
+        type: Literal['wbart', 'pbart'] = 'wbart',
         usequants=False,
         sigest=None,
         sigdf=3,
@@ -170,9 +196,10 @@ class gbart:
         k=2,
         power=2,
         base=0.95,
-        maxdepth=6,
-        lamda=None,
-        offset=None,
+        lamda: FloatLike | None = None,
+        tau_num: FloatLike | None = None,
+        offset: FloatLike | None = None,
+        w=None,
         ntree=200,
         numcut=255,
         ndpost=1000,
@@ -180,36 +207,52 @@ class gbart:
         keepevery=1,
         printevery=100,
         seed=0,
-        initkw={},
-    ):
-
+        maxdepth=6,
+        init_kw=None,
+        run_mcmc_kw=None,
+    ):
         x_train, x_train_fmt = self._process_predictor_input(x_train)
-
-        y_train, y_train_fmt = self._process_response_input(y_train)
+        y_train, _ = self._process_response_input(y_train)
         self._check_same_length(x_train, y_train)
+        if w is not None:
+            w, _ = self._process_response_input(w)
+            self._check_same_length(x_train, w)
 
+        y_train = self._process_type_settings(y_train, type, w)
+        # from here onwards, the type is determined by y_train.dtype == bool
         offset = self._process_offset_settings(y_train, offset)
-        scale = self._process_scale_settings(y_train, k)
-        lamda, sigest = self._process_noise_variance_settings(x_train, y_train, sigest, sigdf, sigquant, lamda, offset)
+        sigma_mu = self._process_leaf_sdev_settings(y_train, k, ntree, tau_num)
+        lamda, sigest = self._process_error_variance_settings(
+            x_train, y_train, sigest, sigdf, sigquant, lamda
+        )
 
         splits, max_split = self._determine_splits(x_train, usequants, numcut)
         x_train = self._bin_predictors(x_train, splits)
 
-        y_train = self._transform_input(y_train, offset, scale)
-        lamda_scaled = lamda / (scale * scale)
-
-        mcmc_state = self._setup_mcmc(x_train, y_train, max_split, lamda_scaled, sigdf, power, base, maxdepth, ntree, initkw)
-        final_state, burnin_trace, main_trace = self._run_mcmc(mcmc_state, ndpost, nskip, keepevery, printevery, seed)
+        mcmc_state = self._setup_mcmc(
+            x_train,
+            y_train,
+            offset,
+            w,
+            max_split,
+            lamda,
+            sigma_mu,
+            sigdf,
+            power,
+            base,
+            maxdepth,
+            ntree,
+            init_kw,
+        )
+        final_state, burnin_trace, main_trace = self._run_mcmc(
+            mcmc_state, ndpost, nskip, keepevery, printevery, seed, run_mcmc_kw
+        )
 
-        sigma = self._extract_sigma(main_trace, scale)
-        first_sigma = self._extract_sigma(burnin_trace, scale)
+        sigma = self._extract_sigma(main_trace)
+        first_sigma = self._extract_sigma(burnin_trace)
 
-        self.offset = offset
-        self.scale = scale
-        self.lamda = lamda
+        self.offset = final_state.offset  # from the state because of buffer donation
         self.sigest = sigest
-        self.ntree = ntree
-        self.maxdepth = maxdepth
         self.sigma = sigma
         self.first_sigma = first_sigma
 
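For orientation, here is a minimal end-to-end sketch of the interface as restructured above. The synthetic data and all values are illustrative; note that, per the new `_process_type_settings` below, continuous regression requires `float32` responses and binary regression requires boolean ones:

```python
import numpy as np

from bartz.BART import gbart

rng = np.random.default_rng(0)
p, n = 5, 100
x_train = rng.standard_normal((p, n)).astype(np.float32)  # one predictor per ROW
y_cont = (x_train[0] + 0.1 * rng.standard_normal(n)).astype(np.float32)

# continuous regression (type='wbart' is the default)
fit = gbart(x_train, y_cont, ndpost=100, nskip=100, printevery=None)
print(fit.yhat_train.shape)  # (ndpost, n): posterior draws of the mean function

# binary probit regression: boolean responses, no weights, no sigest/lamda
fit_bin = gbart(x_train, y_cont > 0, type='pbart', ndpost=100, nskip=100, printevery=None)
```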
@@ -225,9 +268,8 @@ class gbart:
 
     @functools.cached_property
     def yhat_train(self):
-        x_train = self._mcmc_state['X']
-        yhat_train = self._predict(self._main_trace, x_train)
-        return self._transform_output(yhat_train, self.offset, self.scale)
+        x_train = self._mcmc_state.X
+        return self._predict(self._main_trace, x_train)
 
     @functools.cached_property
     def yhat_train_mean(self):
@@ -239,19 +281,26 @@ class gbart:
 
         Parameters
         ----------
-        x_test : array (m, p) or DataFrame
+        x_test : array (p, m) or DataFrame
             The test predictors.
 
         Returns
         -------
         yhat_test : array (ndpost, m)
             The conditional posterior mean at `x_test` for each MCMC iteration.
+
+        Raises
+        ------
+        ValueError
+            If `x_test` has a different format than `x_train`.
         """
         x_test, x_test_fmt = self._process_predictor_input(x_test)
-        self._check_compatible_formats(x_test_fmt, self._x_train_fmt)
+        if x_test_fmt != self._x_train_fmt:
+            raise ValueError(
+                f'Input format mismatch: {x_test_fmt=} != x_train_fmt={self._x_train_fmt!r}'
+            )
         x_test = self._bin_predictors(x_test, self._splits)
-        yhat_test = self._predict(self._main_trace, x_test)
-        return self._transform_output(yhat_test, self.offset, self.scale)
+        return self._predict(self._main_trace, x_test)
 
     @staticmethod
     def _process_predictor_input(x):
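The format check in `predict` now raises a `ValueError` instead of tripping an assert. Continuing the fitting sketch above (and assuming pandas is installed), the rule is simply that `x_test` must arrive in the same container format as `x_train`:

```python
import pandas as pd

x_new = rng.standard_normal((p, 10)).astype(np.float32)
print(fit.predict(x_new).shape)  # ok, same format as training: (ndpost, 10)

try:
    fit.predict(pd.DataFrame(x_new.T))  # DataFrame where a matrix is expected
except ValueError as exc:
    print(exc)  # Input format mismatch: ...
```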
@@ -264,10 +313,6 @@ class gbart:
         assert x.ndim == 2
         return x, fmt
 
-    @staticmethod
-    def _check_compatible_formats(fmt1, fmt2):
-        assert fmt1 == fmt2
-
     @staticmethod
     def _process_response_input(y):
         if hasattr(y, 'to_numpy'):
@@ -285,20 +330,30 @@ class gbart:
         assert get_length(x1) == get_length(x2)
 
     @staticmethod
-    def _process_noise_variance_settings(x_train, y_train, sigest, sigdf, sigquant, lamda, offset):
-        if lamda is not None:
+    def _process_error_variance_settings(
+        x_train, y_train, sigest, sigdf, sigquant, lamda
+    ) -> tuple[Float32[Array, ''] | None, ...]:
+        if y_train.dtype == bool:
+            if sigest is not None:
+                raise ValueError('Let `sigest=None` for binary regression')
+            if lamda is not None:
+                raise ValueError('Let `lamda=None` for binary regression')
+            return None, None
+        elif lamda is not None:
+            if sigest is not None:
+                raise ValueError('Let `sigest=None` if `lamda` is specified')
             return lamda, None
         else:
             if sigest is not None:
-                sigest2 = sigest * sigest
+                sigest2 = jnp.square(sigest)
             elif y_train.size < 2:
                 sigest2 = 1
             elif y_train.size <= x_train.shape[0]:
-                sigest2 = jnp.var(y_train - offset)
+                sigest2 = jnp.var(y_train)
             else:
                 x_centered = x_train.T - x_train.mean(axis=1)
                 y_centered = y_train - y_train.mean()
-            # centering is equivalent to adding an intercept column
+                # centering is equivalent to adding an intercept column
                 _, chisq, rank, _ = jnp.linalg.lstsq(x_centered, y_centered)
                 chisq = chisq.squeeze(0)
                 dof = len(y_train) - rank
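For reference, the least-squares branch above can be reproduced standalone: centering both sides is equivalent to fitting with an intercept, `chisq` is the sum of squared residuals, and (in lines elided between this hunk and the next) the variance estimate is plausibly the usual residual sum of squares over the degrees of freedom. A sketch of that computation, not a bartz API:

```python
import jax.numpy as jnp

def estimate_sigma(x_train, y_train):
    # centering is equivalent to adding an intercept column
    x_centered = x_train.T - x_train.mean(axis=1)
    y_centered = y_train - y_train.mean()
    # chisq: sum of squared residuals of the least-squares fit
    _, chisq, rank, _ = jnp.linalg.lstsq(x_centered, y_centered)
    dof = len(y_train) - rank  # residual degrees of freedom
    return jnp.sqrt(chisq.squeeze(0) / dof)  # residual standard deviation
```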
@@ -309,20 +364,62 @@ class gbart:
         return sigest2 / invchi2rid, jnp.sqrt(sigest2)
 
     @staticmethod
-    def _process_offset_settings(y_train, offset):
+    def _process_type_settings(y_train, type, w):
+        match type:
+            case 'wbart':
+                if y_train.dtype != jnp.float32:
+                    raise TypeError(
+                        'Continuous regression requires y_train.dtype=float32,'
+                        f' got {y_train.dtype=} instead.'
+                    )
+            case 'pbart':
+                if w is not None:
+                    raise ValueError(
+                        'Binary regression does not support weights, set `w=None`'
+                    )
+                if y_train.dtype != bool:
+                    raise TypeError(
+                        'Binary regression requires y_train.dtype=bool,'
+                        f' got {y_train.dtype=} instead.'
+                    )
+            case _:
+                raise ValueError(f'Invalid {type=}')
+
+        return y_train
+
+    @staticmethod
+    def _process_offset_settings(
+        y_train: Float32[Array, 'n'] | Bool[Array, 'n'],
+        offset: float | Float32[Any, ''] | None,
+    ) -> Float32[Array, '']:
         if offset is not None:
-            return offset
+            return jnp.asarray(offset)
         elif y_train.size < 1:
-            return 0
+            return jnp.array(0.0)
         else:
-            return y_train.mean()
+            mean = y_train.mean()
 
-    @staticmethod
-    def _process_scale_settings(y_train, k):
-        if y_train.size < 2:
-            return 1
+        if y_train.dtype == bool:
+            return ndtri(mean)
         else:
-            return (y_train.max() - y_train.min()) / (2 * k)
+            return mean
+
+    @staticmethod
+    def _process_leaf_sdev_settings(
+        y_train: Float32[Array, 'n'] | Bool[Array, 'n'],
+        k: float,
+        ntree: int,
+        tau_num: FloatLike | None,
+    ):
+        if tau_num is None:
+            if y_train.dtype == bool:
+                tau_num = 3.0
+            elif y_train.size < 2:
+                tau_num = 1.0
+            else:
+                tau_num = (y_train.max() - y_train.min()) / 2
+
+        return tau_num / (k * math.sqrt(ntree))
 
     @staticmethod
     def _determine_splits(x_train, usequants, numcut):
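The return line of `_process_leaf_sdev_settings` is the whole leaf prior: each leaf gets standard deviation ``tau_num / (k * sqrt(ntree))``, so the sum of `ntree` independent leaves has standard deviation ``tau_num / k``, with `k` controlling how many prior standard deviations of the forest fit in half the response range. A worked check with the continuous-regression defaults (illustrative numbers):

```python
import math

ntree, k = 200, 2
y_min, y_max = 0.0, 10.0
tau_num = (y_max - y_min) / 2                # default: half the response range
sigma_mu = tau_num / (k * math.sqrt(ntree))  # per-leaf prior sdev, ~0.177
forest_sdev = sigma_mu * math.sqrt(ntree)    # sdev of the sum of 200 leaves
print(forest_sdev)                           # 2.5 == tau_num / k
```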
@@ -336,52 +433,86 @@ class gbart:
         return prepcovars.bin_predictors(x, splits)
 
     @staticmethod
-    def _transform_input(y, offset, scale):
-        return (y - offset) / scale
-
-    @staticmethod
-    def _setup_mcmc(x_train, y_train, max_split, lamda, sigdf, power, base, maxdepth, ntree, initkw):
+    def _setup_mcmc(
+        x_train,
+        y_train,
+        offset,
+        w,
+        max_split,
+        lamda,
+        sigma_mu,
+        sigdf,
+        power,
+        base,
+        maxdepth,
+        ntree,
+        init_kw,
+    ):
         depth = jnp.arange(maxdepth - 1)
         p_nonterminal = base / (1 + depth).astype(float) ** power
-        sigma2_alpha = sigdf / 2
-        sigma2_beta = lamda * sigma2_alpha
+
+        if y_train.dtype == bool:
+            sigma2_alpha = None
+            sigma2_beta = None
+        else:
+            sigma2_alpha = sigdf / 2
+            sigma2_beta = lamda * sigma2_alpha
+
         kw = dict(
             X=x_train,
-            y=y_train,
+            # copy y_train because it's going to be donated in the mcmc loop
+            y=jnp.array(y_train),
+            offset=offset,
+            error_scale=w,
             max_split=max_split,
             num_trees=ntree,
             p_nonterminal=p_nonterminal,
+            sigma_mu2=jnp.square(sigma_mu),
             sigma2_alpha=sigma2_alpha,
             sigma2_beta=sigma2_beta,
             min_points_per_leaf=5,
         )
-        kw.update(initkw)
+        if init_kw is not None:
+            kw.update(init_kw)
         return mcmcstep.init(**kw)
 
     @staticmethod
-    def _run_mcmc(mcmc_state, ndpost, nskip, keepevery, printevery, seed):
-        if isinstance(seed, jax.Array) and jnp.issubdtype(seed.dtype, jax.dtypes.prng_key):
-            key = seed
+    def _run_mcmc(mcmc_state, ndpost, nskip, keepevery, printevery, seed, run_mcmc_kw):
+        if isinstance(seed, jax.Array) and jnp.issubdtype(
+            seed.dtype, jax.dtypes.prng_key
+        ):
+            key = seed.copy()
+            # copy because the inner loop in run_mcmc will donate the buffer
         else:
             key = jax.random.key(seed)
-        callback = mcmcloop.make_simple_print_callback(printevery)
-        return mcmcloop.run_mcmc(mcmc_state, nskip, ndpost, keepevery, callback, key)
 
-    @staticmethod
-    def _predict(trace, x):
-        return mcmcloop.evaluate_trace(trace, x)
+        kw = dict(
+            n_burn=nskip,
+            n_skip=keepevery,
+            inner_loop_length=printevery,
+            allow_overflow=True,
+        )
+        if printevery is not None:
+            kw.update(mcmcloop.make_print_callbacks())
+        if run_mcmc_kw is not None:
+            kw.update(run_mcmc_kw)
 
-    @staticmethod
-    def _transform_output(y, offset, scale):
-        return offset + scale * y
+        return mcmcloop.run_mcmc(key, mcmc_state, ndpost, **kw)
 
     @staticmethod
-    def _extract_sigma(trace, scale):
-        return scale * jnp.sqrt(trace['sigma2'])
+    def _extract_sigma(trace) -> Float32[Array, 'trace_length'] | None:
+        if trace['sigma2'] is None:
+            return None
+        else:
+            return jnp.sqrt(trace['sigma2'])
 
+    @staticmethod
+    def _predict(trace, x):
+        return mcmcloop.evaluate_trace(trace, x)
 
     def _show_tree(self, i_sample, i_tree, print_all=False):
         from . import debug
+
         trace = self._main_trace
         leaf_tree = trace['leaf_trees'][i_sample, i_tree]
         var_tree = trace['var_trees'][i_sample, i_tree]
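The `seed.copy()` above exists because the inner loop of `run_mcmc` donates its input buffers; a caller-supplied key would otherwise be invalidated. Both documented seeding styles hinge on the same dtype check. A small standalone sketch:

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)  # new-style typed PRNG key, accepted as `seed`
assert isinstance(key, jax.Array) and jnp.issubdtype(key.dtype, jax.dtypes.prng_key)

safe = key.copy()  # a copy survives even if the original buffer is donated
```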
@@ -396,30 +527,49 @@ class gbart:
         else:
             resid = bart['resid']
         alpha = bart['sigma2_alpha'] + resid.size / 2
-        norm2 = jnp.dot(resid, resid, preferred_element_type=bart['sigma2_beta'].dtype)
+        norm2 = jnp.dot(
+            resid, resid, preferred_element_type=bart['sigma2_beta'].dtype
+        )
         beta = bart['sigma2_beta'] + norm2 / 2
         sigma2 = beta / alpha
-        return jnp.sqrt(sigma2) * self.scale
+        return jnp.sqrt(sigma2)
 
     def _compare_resid(self):
         bart = self._mcmc_state
-        resid1 = bart['resid']
-        yhat = grove.evaluate_forest(bart['X'], bart['leaf_trees'], bart['var_trees'], bart['split_trees'], jnp.float32)
-        resid2 = bart['y'] - yhat
+        resid1 = bart.resid
+
+        trees = grove.evaluate_forest(
+            bart.X,
+            bart.forest.leaf_trees,
+            bart.forest.var_trees,
+            bart.forest.split_trees,
+            jnp.float32,  # TODO remove these configurable dtypes around
+        )
+
+        if bart.z is not None:
+            ref = bart.z
+        else:
+            ref = bart.y
+        resid2 = ref - (trees + bart.offset)
+
         return resid1, resid2
 
     def _avg_acc(self):
         trace = self._main_trace
+
         def acc(prefix):
             acc = trace[f'{prefix}_acc_count']
             prop = trace[f'{prefix}_prop_count']
             return acc.sum() / prop.sum()
+
         return acc('grow'), acc('prune')
 
     def _avg_prop(self):
         trace = self._main_trace
+
         def prop(prefix):
             return trace[f'{prefix}_prop_count'].sum()
+
         pgrow = prop('grow')
         pprune = prop('prune')
         total = pgrow + pprune
@@ -432,16 +582,19 @@ class gbart:
 
     def _depth_distr(self):
         from . import debug
+
         trace = self._main_trace
         split_trees = trace['split_trees']
         return debug.trace_depth_distr(split_trees)
 
     def _points_per_leaf_distr(self):
         from . import debug
-        return debug.trace_points_per_leaf_distr(self._main_trace, self._mcmc_state['X'])
+
+        return debug.trace_points_per_leaf_distr(self._main_trace, self._mcmc_state.X)
 
     def _check_trees(self):
         from . import debug
+
         return debug.check_trace(self._main_trace, self._mcmc_state)
 
     def _tree_goes_bad(self):
bartz/__init__.py CHANGED
@@ -1,6 +1,6 @@
 # bartz/src/bartz/__init__.py
 #
-# Copyright (c) 2024, Giacomo Petrillo
+# Copyright (c) 2024-2025, Giacomo Petrillo
 #
 # This file is part of bartz.
 #
@@ -23,18 +23,10 @@
 # SOFTWARE.
 
 """
-Super-fast BART (Bayesian Additive Regression Trees) in Python
+Super-fast BART (Bayesian Additive Regression Trees) in Python.
 
 See the manual at https://gattocrucco.github.io/bartz/docs
 """
 
-from ._version import __version__
-
-from . import BART
-
-from . import debug
-from . import grove
-from . import mcmcstep
-from . import mcmcloop
-from . import prepcovars
-from . import jaxext
+from . import BART, debug, grove, jaxext, mcmcloop, mcmcstep, prepcovars  # noqa: F401
+from ._version import __version__  # noqa: F401
bartz/_version.py CHANGED
@@ -1 +1 @@
-__version__ = '0.4.1'
+__version__ = '0.6.0'