bartz 0.5.0__tar.gz → 0.6.0__tar.gz

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
@@ -1,14 +1,16 @@
  Metadata-Version: 2.4
  Name: bartz
- Version: 0.5.0
+ Version: 0.6.0
  Summary: Super-fast BART (Bayesian Additive Regression Trees) in Python
  Author: Giacomo Petrillo
  Author-email: Giacomo Petrillo <info@giacomopetrillo.com>
  License-Expression: MIT
- Requires-Dist: jax>=0.4.35,<1
- Requires-Dist: jaxlib>=0.4.35,<1
- Requires-Dist: numpy>=1.25.2,<3
- Requires-Dist: scipy>=1.11.4,<2
+ Requires-Dist: equinox>=0.12.2
+ Requires-Dist: jax>=0.4.35
+ Requires-Dist: jaxlib>=0.4.35
+ Requires-Dist: jaxtyping>=0.3.2
+ Requires-Dist: numpy>=1.25.2
+ Requires-Dist: scipy>=1.11.4
  Requires-Python: >=3.10
  Project-URL: Documentation, https://gattocrucco.github.io/bartz/docs-dev
  Project-URL: Homepage, https://github.com/Gattocrucco/bartz
@@ -28,7 +28,7 @@ build-backend = "uv_build"

  [project]
  name = "bartz"
- version = "0.5.0"
+ version = "0.6.0"
  description = "Super-fast BART (Bayesian Additive Regression Trees) in Python"
  authors = [
      {name = "Giacomo Petrillo", email = "info@giacomopetrillo.com"},
@@ -36,14 +36,13 @@ authors = [
  license = "MIT"
  readme = "README.md"
  requires-python = ">=3.10"
- packages = [
-     { include = "bartz", from = "src" },
- ]
  dependencies = [
-     "jax >=0.4.35,<1",
-     "jaxlib >=0.4.35,<1",
-     "numpy >=1.25.2,<3",
-     "scipy >=1.11.4,<2",
+     "equinox>=0.12.2",
+     "jax>=0.4.35",
+     "jaxlib>=0.4.35",
+     "jaxtyping>=0.3.2",
+     "numpy>=1.25.2",
+     "scipy>=1.11.4",
  ]

  [project.urls]
@@ -57,8 +56,8 @@ only-local = [
      "ipython>=8.36.0",
      "matplotlib>=3.10.3",
      "matplotlib-label-lines>=0.8.1",
-     "polars[pandas,pyarrow]>=1.29.0",
      "pre-commit>=4.2.0",
+     "pydoclint>=0.6.6",
      "ruff>=0.11.9",
      "scikit-learn>=1.6.1",
      "tomli>=2.2.1",
@@ -71,12 +70,15 @@ ci = [
      "myst-parser>=4.0.1",
      "numpydoc>=1.8.0",
      "packaging>=25.0",
+     "polars[pandas,pyarrow]>=1.29.0",
      "pytest>=8.3.5",
      "pytest-timeout>=2.4.0",
      "sphinx>=8.1.3",
+     "sphinx-autodoc-typehints>=3.0.1",
  ]

  [tool.pytest.ini_options]
+ cache_dir = "config/pytest_cache"
  testpaths = ["tests"]
  filterwarnings = [
      'error:scatter inputs have incompatible types.*',
@@ -85,8 +87,9 @@ addopts = [
      "-r xXfE",
      "--pdbcls=IPython.terminal.debugger:TerminalPdb",
      "--durations=3",
+     "--verbose",
  ]
- timeout = 32
+ timeout = 64
  timeout_method = "thread" # when jax hangs, signals do not work

  # I wanted to use `--import-mode=importlib`, but it breaks importing submodules,
@@ -101,6 +104,7 @@ show_missing = true

  [tool.coverage.html]
  show_contexts = true
+ directory = "_site/coverage"

  [tool.coverage.paths]
  # the first path in each list must be the source directory in the machine that's
@@ -129,6 +133,7 @@ local = [

  [tool.ruff]
  exclude = [".asv", "*.ipynb"]
+ cache-dir = "config/ruff_cache"

  [tool.ruff.format]
  quote-style = "single"
@@ -138,12 +143,43 @@ select = [
      "B", # bugbear: grab bag of additional stuff
      "UP", # pyupgrade: fix some outdated idioms
      "I", # isort: sort and reformat import statements
-     "F", # flake8
+     "F", # pyflakes
+     "D", # pydocstyle
+     "PT", # flake8-pytest-style
  ]
  ignore = [
-     "B028", # warn with stacklevel = 2
+     "B028", # warn with stacklevel = 2
+     "D105", # Missing docstring in magic method
+     "F722", # Syntax error in forward annotation. I ignore this because jaxtyping uses strings for shapes rather than as deferred annotations.
+     "F821", # Undefined name. I ignore this because of the shape strings in jaxtyping.
+     "UP037", # Remove quotes from type annotation. I ignore this because of jaxtyping.
  ]

+ [tool.ruff.lint.per-file-ignores]
+ "{config/*,benchmarks/*,docs/*,src/bartz/debug.py,tests/rbartpackages/*,tests/__init__.py}" = [
+     "D100", # Missing docstring in public module
+     "D101", # Missing docstring in public class
+     "D102", # Missing docstring in public method
+     "D103", # Missing docstring in public function
+     "D104", # Missing docstring in public package
+ ]
+
+ [tool.ruff.lint.pydocstyle]
+ convention = "numpy"
+ ignore-decorators = ["functools.cached_property"]
+
+ [tool.pydoclint]
+ baseline = "config/pydoclint-baseline.txt"
+ auto-regenerate-baseline = true
+ arg-type-hints-in-signature = true
+ arg-type-hints-in-docstring = false
+ check-return-types = false
+ check-yield-types = false
+ treat-property-methods-as-class-attributes = true
+ check-style-mismatch = true
+ show-filenames-in-every-violation-message = true
+ check-class-attributes = false

  [tool.uv]
  python-downloads = "never"
  python-preference = "only-system"
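Aside on the jaxtyping-related ignores above: jaxtyping writes array shapes as strings inside the annotation subscript, and pyflakes parses those strings as forward-reference annotations. A minimal illustration with a hypothetical function (not part of the package):

    from jaxtyping import Array, Float32

    # 'trees n' is a jaxtyping shape spec, not a deferred annotation: pyflakes
    # sees invalid syntax in it (F722), and a lone axis name like 'n' parses
    # but looks like an undefined name (F821), hence both codes are disabled.
    def forest_mean(leaves: Float32[Array, 'trees n']) -> Float32[Array, ' n']:
        return leaves.mean(axis=0)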
@@ -22,13 +22,21 @@
  # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  # SOFTWARE.

+ """Implement a user interface that mimics the R BART package."""
+
  import functools
+ import math
+ from typing import Any, Literal

  import jax
  import jax.numpy as jnp
+ from jax.scipy.special import ndtri
+ from jaxtyping import Array, Bool, Float, Float32

  from . import grove, jaxext, mcmcloop, mcmcstep, prepcovars

+ FloatLike = float | Float[Any, '']
+

  class gbart:
      """
@@ -46,6 +54,9 @@ class gbart:
          The training responses.
      x_test : array (p, m) or DataFrame, optional
          The test predictors.
+     type
+         The type of regression: 'wbart' for continuous regression, 'pbart'
+         for binary regression with a probit link.
      usequants : bool, default False
          Whether to use predictor quantiles instead of a uniform grid to bin
          predictors.
@@ -70,16 +81,20 @@ class gbart:
          Parameters of the prior on tree node generation. The probability that a
          node at depth `d` (0-based) is non-terminal is ``base / (1 + d) **
          power``.
-     maxdepth : int, default 6
-         The maximum depth of the trees. This is 1-based, so with the default
-         ``maxdepth=6``, the depths of the levels range from 0 to 5.
-     lamda : float, optional
-         The scale of the prior on the noise variance. If ``lamda==1``, the
-         prior is an inverse chi-squared scaled to have harmonic mean 1. If
-         not specified, it is set based on `sigest` and `sigquant`.
-     offset : float, optional
+     lamda
+         The prior harmonic mean of the error variance. (The harmonic mean of x
+         is 1/mean(1/x).) If not specified, it is set based on `sigest` and
+         `sigquant`.
+     tau_num
+         The numerator in the expression that determines the prior standard
+         deviation of leaves. If not specified, it defaults to ``(max(y_train) -
+         min(y_train)) / 2`` (or 1 if `y_train` has fewer than two elements) for
+         continuous regression, and 3 for binary regression.
+     offset
          The prior mean of the latent mean function. If not specified, it is set
-         to the mean of `y_train`. If `y_train` is empty, it is set to 0.
+         to the mean of `y_train` for continuous regression, and to
+         ``Phi^-1(mean(y_train))`` for binary regression. If `y_train` is empty,
+         `offset` is set to 0.
      w : array (n,), optional
          Coefficients that rescale the error standard deviation on each
          datapoint. Not specifying `w` is equivalent to setting it to 1 for all
@@ -108,12 +123,24 @@ class gbart:
          The number of initial MCMC samples to discard as burn-in.
      keepevery : int, default 1
          The thinning factor for the MCMC samples, after burn-in.
-     printevery : int, default 100
-         The number of iterations (including skipped ones) between each log.
+     printevery : int or None, default 100
+         The number of iterations (including thinned-away ones) between each log
+         line. Set to `None` to disable logging.
+
+         `printevery` has a few unexpected side effects. On CPU, interrupting
+         with ^C halts the MCMC only at the next log, and the total number of
+         iterations is a multiple of `printevery`, so if ``nskip + keepevery *
+         ndpost`` is not a multiple of `printevery`, some of the last iterations
+         will not be saved.
      seed : int or jax random key, default 0
          The seed for the random number generator.
-     initkw : dict
+     maxdepth : int, default 6
+         The maximum depth of the trees. This is 1-based, so with the default
+         ``maxdepth=6``, the depths of the levels range from 0 to 5.
+     init_kw : dict
          Additional arguments passed to `mcmcstep.init`.
+     run_mcmc_kw : dict
+         Additional arguments passed to `mcmcloop.run_mcmc`.

      Attributes
      ----------
@@ -131,20 +158,8 @@ class gbart:
          The standard deviation of the error in the burn-in phase.
      offset : float
          The prior mean of the latent mean function.
-     scale : float
-         The prior standard deviation of the latent mean function.
-     lamda : float
-         The prior harmonic mean of the error variance.
      sigest : float or None
          The estimated standard deviation of the error used to set `lamda`.
-     ntree : int
-         The number of trees.
-     maxdepth : int
-         The maximum depth of the trees.
-
-     Methods
-     -------
-     predict

      Notes
      -----
@@ -153,14 +168,17 @@ class gbart:

      - If `x_train` and `x_test` are matrices, they have one predictor per row
        instead of per column.
+     - If `type` is not specified, it is determined solely by the data type of
+       `y_train`, not by whether it contains only two unique values.
      - If ``usequants=False``, R BART switches to quantiles anyway if there are
        fewer predictor values than the required number of bins, while bartz
        always follows the specification.
      - The error variance parameter is called `lamda` instead of `lambda`.
      - `rm_const` is always `False`.
      - The default `numcut` is 255 instead of 100.
-     - A lot of functionality is missing (variable selection, discrete response).
+     - A lot of functionality is missing (e.g., variable selection).
      - There are some additional attributes, and some missing.
+     - The trees have a maximum depth.

      """

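To make the new interface concrete, here is a hedged sketch of a binary-regression fit (the `bartz.BART` import path and the toy data are assumptions, not part of this diff):

    import numpy as np
    from bartz.BART import gbart

    rng = np.random.default_rng(0)
    X = rng.standard_normal((5, 100))  # (p, n): one predictor per row
    y = X[0] + X[1] > 0                # boolean response selects the probit model

    # type='pbart' requires y_train.dtype == bool and w=None
    fit = gbart(X, y, type='pbart', ndpost=200, nskip=100, printevery=None)
    yhat = fit.yhat_train_mean         # mean of yhat_train over the posterior draws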
@@ -170,6 +188,7 @@ class gbart:
          y_train,
          *,
          x_test=None,
+         type: Literal['wbart', 'pbart'] = 'wbart',
          usequants=False,
          sigest=None,
          sigdf=3,
@@ -177,9 +196,9 @@ class gbart:
          k=2,
          power=2,
          base=0.95,
-         maxdepth=6,
-         lamda=None,
-         offset=None,
+         lamda: FloatLike | None = None,
+         tau_num: FloatLike | None = None,
+         offset: FloatLike | None = None,
          w=None,
          ntree=200,
          numcut=255,
@@ -188,7 +207,9 @@ class gbart:
          keepevery=1,
          printevery=100,
          seed=0,
-         initkw=None,
+         maxdepth=6,
+         init_kw=None,
+         run_mcmc_kw=None,
      ):
          x_train, x_train_fmt = self._process_predictor_input(x_train)
          y_train, _ = self._process_response_input(y_train)
@@ -197,42 +218,41 @@ class gbart:
          w, _ = self._process_response_input(w)
          self._check_same_length(x_train, w)

+         y_train = self._process_type_settings(y_train, type, w)
+         # from here onwards, the type is determined by y_train.dtype == bool
          offset = self._process_offset_settings(y_train, offset)
-         scale = self._process_scale_settings(y_train, k)
-         lamda, sigest = self._process_noise_variance_settings(
-             x_train, y_train, sigest, sigdf, sigquant, lamda, offset
+         sigma_mu = self._process_leaf_sdev_settings(y_train, k, ntree, tau_num)
+         lamda, sigest = self._process_error_variance_settings(
+             x_train, y_train, sigest, sigdf, sigquant, lamda
          )

          splits, max_split = self._determine_splits(x_train, usequants, numcut)
          x_train = self._bin_predictors(x_train, splits)
-         y_train, lamda_scaled = self._transform_input(y_train, lamda, offset, scale)

          mcmc_state = self._setup_mcmc(
              x_train,
              y_train,
+             offset,
              w,
              max_split,
-             lamda_scaled,
+             lamda,
+             sigma_mu,
              sigdf,
              power,
              base,
              maxdepth,
              ntree,
-             initkw,
+             init_kw,
          )
          final_state, burnin_trace, main_trace = self._run_mcmc(
-             mcmc_state, ndpost, nskip, keepevery, printevery, seed
+             mcmc_state, ndpost, nskip, keepevery, printevery, seed, run_mcmc_kw
          )

-         sigma = self._extract_sigma(main_trace, scale)
-         first_sigma = self._extract_sigma(burnin_trace, scale)
+         sigma = self._extract_sigma(main_trace)
+         first_sigma = self._extract_sigma(burnin_trace)

-         self.offset = offset
-         self.scale = scale
-         self.lamda = lamda
+         self.offset = final_state.offset  # from the state because of buffer donation
          self.sigest = sigest
-         self.ntree = ntree
-         self.maxdepth = maxdepth
          self.sigma = sigma
          self.first_sigma = first_sigma
@@ -248,9 +268,8 @@ class gbart:

      @functools.cached_property
      def yhat_train(self):
-         x_train = self._mcmc_state['X']
-         yhat_train = self._predict(self._main_trace, x_train)
-         return self._transform_output(yhat_train, self.offset, self.scale)
+         x_train = self._mcmc_state.X
+         return self._predict(self._main_trace, x_train)

      @functools.cached_property
      def yhat_train_mean(self):
@@ -269,12 +288,19 @@ class gbart:
          -------
          yhat_test : array (ndpost, m)
              The conditional posterior mean at `x_test` for each MCMC iteration.
+
+         Raises
+         ------
+         ValueError
+             If `x_test` has a different format than `x_train`.
          """
          x_test, x_test_fmt = self._process_predictor_input(x_test)
-         self._check_compatible_formats(x_test_fmt, self._x_train_fmt)
+         if x_test_fmt != self._x_train_fmt:
+             raise ValueError(
+                 f'Input format mismatch: {x_test_fmt=} != x_train_fmt={self._x_train_fmt!r}'
+             )
          x_test = self._bin_predictors(x_test, self._splits)
-         yhat_test = self._predict(self._main_trace, x_test)
-         return self._transform_output(yhat_test, self.offset, self.scale)
+         return self._predict(self._main_trace, x_test)

      @staticmethod
      def _process_predictor_input(x):
@@ -287,10 +313,6 @@ class gbart:
          assert x.ndim == 2
          return x, fmt

-     @staticmethod
-     def _check_compatible_formats(fmt1, fmt2):
-         assert fmt1 == fmt2
-
      @staticmethod
      def _process_response_input(y):
          if hasattr(y, 'to_numpy'):
@@ -308,18 +330,26 @@ class gbart:
          assert get_length(x1) == get_length(x2)

      @staticmethod
-     def _process_noise_variance_settings(
-         x_train, y_train, sigest, sigdf, sigquant, lamda, offset
-     ):
-         if lamda is not None:
+     def _process_error_variance_settings(
+         x_train, y_train, sigest, sigdf, sigquant, lamda
+     ) -> tuple[Float32[Array, ''] | None, ...]:
+         if y_train.dtype == bool:
+             if sigest is not None:
+                 raise ValueError('Let `sigest=None` for binary regression')
+             if lamda is not None:
+                 raise ValueError('Let `lamda=None` for binary regression')
+             return None, None
+         elif lamda is not None:
+             if sigest is not None:
+                 raise ValueError('Let `sigest=None` if `lamda` is specified')
              return lamda, None
          else:
              if sigest is not None:
-                 sigest2 = sigest * sigest
+                 sigest2 = jnp.square(sigest)
              elif y_train.size < 2:
                  sigest2 = 1
              elif y_train.size <= x_train.shape[0]:
-                 sigest2 = jnp.var(y_train - offset)
+                 sigest2 = jnp.var(y_train)
              else:
                  x_centered = x_train.T - x_train.mean(axis=1)
                  y_centered = y_train - y_train.mean()
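The rewritten method also spells out which error-variance settings are mutually exclusive. Continuing the hedged sketch from earlier (with `y_cont` a float response and `y_bool` a boolean one, both hypothetical), these calls would now fail fast:

    gbart(X, y_cont, lamda=1.0, sigest=0.5)    # ValueError: Let `sigest=None` if `lamda` is specified
    gbart(X, y_bool, type='pbart', lamda=1.0)  # ValueError: Let `lamda=None` for binary regression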
@@ -334,20 +364,62 @@ class gbart:
              return sigest2 / invchi2rid, jnp.sqrt(sigest2)

      @staticmethod
-     def _process_offset_settings(y_train, offset):
+     def _process_type_settings(y_train, type, w):
+         match type:
+             case 'wbart':
+                 if y_train.dtype != jnp.float32:
+                     raise TypeError(
+                         'Continuous regression requires y_train.dtype=float32,'
+                         f' got {y_train.dtype=} instead.'
+                     )
+             case 'pbart':
+                 if w is not None:
+                     raise ValueError(
+                         'Binary regression does not support weights, set `w=None`'
+                     )
+                 if y_train.dtype != bool:
+                     raise TypeError(
+                         'Binary regression requires y_train.dtype=bool,'
+                         f' got {y_train.dtype=} instead.'
+                     )
+             case _:
+                 raise ValueError(f'Invalid {type=}')
+
+         return y_train
+
+     @staticmethod
+     def _process_offset_settings(
+         y_train: Float32[Array, 'n'] | Bool[Array, 'n'],
+         offset: float | Float32[Any, ''] | None,
+     ) -> Float32[Array, '']:
          if offset is not None:
-             return offset
+             return jnp.asarray(offset)
          elif y_train.size < 1:
-             return 0
+             return jnp.array(0.0)
          else:
-             return y_train.mean()
+             mean = y_train.mean()

-     @staticmethod
-     def _process_scale_settings(y_train, k):
-         if y_train.size < 2:
-             return 1
+         if y_train.dtype == bool:
+             return ndtri(mean)
          else:
-             return (y_train.max() - y_train.min()) / (2 * k)
+             return mean
+
+     @staticmethod
+     def _process_leaf_sdev_settings(
+         y_train: Float32[Array, 'n'] | Bool[Array, 'n'],
+         k: float,
+         ntree: int,
+         tau_num: FloatLike | None,
+     ):
+         if tau_num is None:
+             if y_train.dtype == bool:
+                 tau_num = 3.0
+             elif y_train.size < 2:
+                 tau_num = 1.0
+             else:
+                 tau_num = (y_train.max() - y_train.min()) / 2
+
+         return tau_num / (k * math.sqrt(ntree))

      @staticmethod
      def _determine_splits(x_train, usequants, numcut):
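Worked example of the two defaulting rules above, with illustrative numbers (default k=2, ntree=200, and a binary response whose base rate is 73.1%):

    import math
    from jax.scipy.special import ndtri

    offset = ndtri(0.731)                      # ≈ 0.615, probit of the base rate
    tau_num = 3.0                              # binary-regression default
    sigma_mu = tau_num / (2 * math.sqrt(200))  # ≈ 0.106, prior sd of each leaf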
@@ -360,67 +432,83 @@ class gbart:
      @staticmethod
      def _bin_predictors(x, splits):
          return prepcovars.bin_predictors(x, splits)

-     @staticmethod
-     def _transform_input(y, lamda, offset, scale):
-         y = (y - offset) / scale
-         lamda = lamda / (scale * scale)
-         return y, lamda
-
      @staticmethod
      def _setup_mcmc(
          x_train,
          y_train,
+         offset,
          w,
          max_split,
          lamda,
+         sigma_mu,
          sigdf,
          power,
          base,
          maxdepth,
          ntree,
-         initkw,
+         init_kw,
      ):
          depth = jnp.arange(maxdepth - 1)
          p_nonterminal = base / (1 + depth).astype(float) ** power
-         sigma2_alpha = sigdf / 2
-         sigma2_beta = lamda * sigma2_alpha
+
+         if y_train.dtype == bool:
+             sigma2_alpha = None
+             sigma2_beta = None
+         else:
+             sigma2_alpha = sigdf / 2
+             sigma2_beta = lamda * sigma2_alpha
+
          kw = dict(
              X=x_train,
-             y=y_train,
+             # copy y_train because it's going to be donated in the mcmc loop
+             y=jnp.array(y_train),
+             offset=offset,
              error_scale=w,
              max_split=max_split,
              num_trees=ntree,
              p_nonterminal=p_nonterminal,
+             sigma_mu2=jnp.square(sigma_mu),
              sigma2_alpha=sigma2_alpha,
              sigma2_beta=sigma2_beta,
              min_points_per_leaf=5,
          )
-         if initkw is not None:
-             kw.update(initkw)
+         if init_kw is not None:
+             kw.update(init_kw)
          return mcmcstep.init(**kw)

      @staticmethod
-     def _run_mcmc(mcmc_state, ndpost, nskip, keepevery, printevery, seed):
+     def _run_mcmc(mcmc_state, ndpost, nskip, keepevery, printevery, seed, run_mcmc_kw):
          if isinstance(seed, jax.Array) and jnp.issubdtype(
              seed.dtype, jax.dtypes.prng_key
          ):
-             key = seed
+             key = seed.copy()
+             # copy because the inner loop in run_mcmc will donate the buffer
          else:
              key = jax.random.key(seed)
-         callback = mcmcloop.make_simple_print_callback(printevery)
-         return mcmcloop.run_mcmc(key, mcmc_state, nskip, ndpost, keepevery, callback)

-     @staticmethod
-     def _predict(trace, x):
-         return mcmcloop.evaluate_trace(trace, x)
+         kw = dict(
+             n_burn=nskip,
+             n_skip=keepevery,
+             inner_loop_length=printevery,
+             allow_overflow=True,
+         )
+         if printevery is not None:
+             kw.update(mcmcloop.make_print_callbacks())
+         if run_mcmc_kw is not None:
+             kw.update(run_mcmc_kw)
+
+         return mcmcloop.run_mcmc(key, mcmc_state, ndpost, **kw)

      @staticmethod
-     def _transform_output(y, offset, scale):
-         return offset + scale * y
+     def _extract_sigma(trace) -> Float32[Array, 'trace_length'] | None:
+         if trace['sigma2'] is None:
+             return None
+         else:
+             return jnp.sqrt(trace['sigma2'])

      @staticmethod
-     def _extract_sigma(trace, scale):
-         return scale * jnp.sqrt(trace['sigma2'])
+     def _predict(trace, x):
+         return mcmcloop.evaluate_trace(trace, x)

      def _show_tree(self, i_sample, i_tree, print_all=False):
          from . import debug
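Note how `printevery` now doubles as the length of the compiled inner loop, with `allow_overflow=True` letting the total iteration count round up to a multiple of it — exactly the caveat documented under `printevery` above. The new `run_mcmc_kw` hook lets callers override this plumbing; a hedged sketch reusing the earlier toy data:

    # logging stays off, but the jitted inner loop is shortened to 50 iterations
    fit = gbart(X, y, type='pbart', ndpost=1000, nskip=500,
                printevery=None, run_mcmc_kw=dict(inner_loop_length=50))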
@@ -444,19 +532,26 @@ class gbart:
          )
          beta = bart['sigma2_beta'] + norm2 / 2
          sigma2 = beta / alpha
-         return jnp.sqrt(sigma2) * self.scale
+         return jnp.sqrt(sigma2)

      def _compare_resid(self):
          bart = self._mcmc_state
-         resid1 = bart['resid']
-         yhat = grove.evaluate_forest(
-             bart['X'],
-             bart['leaf_trees'],
-             bart['var_trees'],
-             bart['split_trees'],
-             jnp.float32,
+         resid1 = bart.resid
+
+         trees = grove.evaluate_forest(
+             bart.X,
+             bart.forest.leaf_trees,
+             bart.forest.var_trees,
+             bart.forest.split_trees,
+             jnp.float32,  # TODO remove these configurable dtypes
          )
-         resid2 = bart['y'] - yhat
+
+         if bart.z is not None:
+             ref = bart.z
+         else:
+             ref = bart.y
+         resid2 = ref - (trees + bart.offset)
+
          return resid1, resid2

      def _avg_acc(self):
@@ -495,9 +590,7 @@ class gbart:
      def _points_per_leaf_distr(self):
          from . import debug

-         return debug.trace_points_per_leaf_distr(
-             self._main_trace, self._mcmc_state['X']
-         )
+         return debug.trace_points_per_leaf_distr(self._main_trace, self._mcmc_state.X)

      def _check_trees(self):
          from . import debug
@@ -23,7 +23,7 @@
  # SOFTWARE.

  """
- Super-fast BART (Bayesian Additive Regression Trees) in Python
+ Super-fast BART (Bayesian Additive Regression Trees) in Python.

  See the manual at https://gattocrucco.github.io/bartz/docs
  """
@@ -0,0 +1 @@
+ __version__ = '0.6.0'
@@ -180,5 +180,5 @@ def check_trace(trace, state):
          trace['leaf_trees'],
          trace['var_trees'],
          trace['split_trees'],
-         state['max_split'],
+         state.max_split,
      )