bartz 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
-# bartz/src/bartz/BART.py
+# bartz/src/bartz/_interface.py
 #
-# Copyright (c) 2024-2025, Giacomo Petrillo
+# Copyright (c) 2025-2026, The Bartz Contributors
 #
 # This file is part of bartz.
 #
@@ -22,17 +22,20 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
-"""Implement a class `gbart` that mimics the R BART package."""
+"""Main high-level interface of the package."""
 
 import math
 from collections.abc import Sequence
 from functools import cached_property
-from typing import Any, Literal, Protocol
+from typing import Any, Literal, Protocol, TypedDict
 
 import jax
 import jax.numpy as jnp
 from equinox import Module, field
+from jax import Device, device_put, jit, make_mesh
+from jax.lax import collapse
 from jax.scipy.special import ndtr
+from jax.sharding import AxisType, Mesh
 from jaxtyping import (
     Array,
     Bool,
@@ -48,22 +51,21 @@ from jaxtyping import (
 from numpy import ndarray
 
 from bartz import mcmcloop, mcmcstep, prepcovars
+from bartz.jaxext import is_key
 from bartz.jaxext.scipy.special import ndtri
 from bartz.jaxext.scipy.stats import invgamma
+from bartz.mcmcloop import compute_varcount, evaluate_trace, run_mcmc
+from bartz.mcmcstep import make_p_nonterminal
+from bartz.mcmcstep._state import get_num_chains
 
 FloatLike = float | Float[Any, '']
 
 
 class DataFrame(Protocol):
-    """DataFrame duck-type for `gbart`.
-
-    Attributes
-    ----------
-    columns : Sequence[str]
-        The names of the columns.
-    """
+    """DataFrame duck-type for `Bart`."""
 
     columns: Sequence[str]
+    """The names of the columns."""
 
     def to_numpy(self) -> ndarray:
         """Convert the dataframe to a 2d numpy array with columns on the second axis."""
@@ -71,22 +73,17 @@ class DataFrame(Protocol):
 
 
 class Series(Protocol):
-    """Series duck-type for `gbart`.
-
-    Attributes
-    ----------
-    name : str | None
-        The name of the series.
-    """
+    """Series duck-type for `Bart`."""
 
     name: str | None
+    """The name of the series."""
 
     def to_numpy(self) -> ndarray:
         """Convert the series to a 1d numpy array."""
         ...
 
 
-class gbart(Module):
+class Bart(Module):
     R"""
     Nonparametric regression with Bayesian Additive Regression Trees (BART) [2]_.
 
@@ -147,10 +144,7 @@ class gbart(Module):
     rm_const
         How to treat predictors with no associated decision rules (i.e., there
         are no available cutpoints for that predictor). If `True` (default),
-        they are ignored. If `False`, an error is raised if there are any. If
-        `None`, no check is performed, and the output of the MCMC may not make
-        sense if there are predictors without cutpoints. The option `None` is
-        provided only to allow jax tracing.
+        they are ignored. If `False`, an error is raised if there are any.
     sigest
         An estimate of the residual standard deviation on `y_train`, used to set
         `lamda`. If not specified, it is estimated by linear regression (with
@@ -214,21 +208,40 @@ class gbart(Module):
 
         Ignored if `xinfo` is specified.
     ndpost
-        The number of MCMC samples to save, after burn-in.
+        The number of MCMC samples to save, after burn-in. `ndpost` is the
+        total number of samples across all chains. `ndpost` is rounded up to
+        the next multiple of `num_chains`.
     nskip
-        The number of initial MCMC samples to discard as burn-in.
+        The number of initial MCMC samples to discard as burn-in. This number
+        of samples is discarded from each chain.
     keepevery
         The thinning factor for the MCMC samples, after burn-in. By default, 1
         for continuous regression and 10 for binary regression.
     printevery
         The number of iterations (including thinned-away ones) between each log
-        line. Set to `None` to disable logging.
-
-        `printevery` has a few unexpected side effects. On cpu, interrupting
-        with ^C halts the MCMC only on the next log. And the total number of
-        iterations is a multiple of `printevery`, so if ``nskip + keepevery *
-        ndpost`` is not a multiple of `printevery`, some of the last iterations
-        will not be saved.
+        line. Set to `None` to disable logging. ^C interrupts the MCMC only
+        every `printevery` iterations, so with logging disabled there is no
+        convenient way to stop the MCMC early.
+    num_chains
+        The number of independent Markov chains to run. By default only one
+        chain is run.
+
+        The difference between leaving `num_chains` unspecified and setting it
+        to 1 is that in the latter case the object attributes and some method
+        outputs carry an explicit chain axis of size 1.
+    num_chain_devices
+        The number of devices to spread the chains across. Must be a divisor
+        of `num_chains`. Each device will run a fraction of the chains.
+    num_data_devices
+        The number of devices to split datapoints across. Must be a divisor of
+        `n`. This is useful only with very high `n`, roughly `n` > 1,000,000.
+
+        If both `num_chain_devices` and `num_data_devices` are specified, the
+        total number of devices used is the product of the two.
+    devices
+        One or more devices to run the MCMC on. If not specified, the
+        computation will follow the placement of the input arrays. If a list of
+        devices, this argument can be longer than the number of devices needed.
     seed
         The seed for the random number generator.
     maxdepth
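Note: a minimal sketch of how the new multi-chain arguments fit together, assuming `Bart` is re-exported at the package root (toy data, illustrative settings):

```python
import jax.numpy as jnp

from bartz import Bart  # assumption: the package re-exports Bart

# toy data in the (p, n) layout used by bartz
x_train = jnp.linspace(0.0, 1.0, 100).reshape(1, -1)
y_train = jnp.sin(4.0 * x_train[0])

fit = Bart(
    x_train,
    y_train,
    ndpost=1000,   # total draws across all chains, rounded up
    nskip=100,     # burn-in discarded from each chain
    num_chains=4,  # per-draw outputs gain an explicit chain axis
    seed=0,
)
```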
@@ -239,34 +252,6 @@ class gbart(Module):
     run_mcmc_kw
         Additional arguments passed to `bartz.mcmcloop.run_mcmc`.
 
-    Attributes
-    ----------
-    offset : Float32[Array, '']
-        The prior mean of the latent mean function.
-    sigest : Float32[Array, ''] | None
-        The estimated standard deviation of the error used to set `lamda`.
-    yhat_test : Float32[Array, 'ndpost m'] | None
-        The conditional posterior mean at `x_test` for each MCMC iteration.
-
-    Notes
-    -----
-    This interface imitates the function ``gbart`` from the R package `BART
-    <https://cran.r-project.org/package=BART>`_, but with these differences:
-
-    - If `x_train` and `x_test` are matrices, they have one predictor per row
-      instead of per column.
-    - If ``usequants=False``, R BART switches to quantiles anyway if there are
-      less predictor values than the required number of bins, while bartz
-      always follows the specification.
-    - Some functionality is missing.
-    - The error variance parameter is called `lamda` instead of `lambda`.
-    - There are some additional attributes, and some missing.
-    - The trees have a maximum depth.
-    - `rm_const` refers to predictors without decision rules instead of
-      predictors that are constant in `x_train`.
-    - If `rm_const=True` and some variables are dropped, the predictors
-      matrix/dataframe passed to `predict` should still include them.
-
     References
     ----------
 
@@ -283,10 +268,14 @@ class gbart(Module):
     _splits: Real[Array, 'p max_num_splits']
     _x_train_fmt: Any = field(static=True)
 
-    ndpost: int = field(static=True)
     offset: Float32[Array, '']
+    """The prior mean of the latent mean function."""
+
     sigest: Float32[Array, ''] | None = None
+    """The estimated standard deviation of the error used to set `lamda`."""
+
     yhat_test: Float32[Array, 'ndpost m'] | None = None
+    """The conditional posterior mean at `x_test` for each MCMC iteration."""
 
     def __init__(
         self,
@@ -302,7 +291,7 @@
         rho: FloatLike | None = None,
         xinfo: Float[Array, 'p n'] | None = None,
         usequants: bool = False,
-        rm_const: bool | None = True,
+        rm_const: bool = True,
         sigest: FloatLike | None = None,
         sigdf: FloatLike = 3.0,
         sigquant: FloatLike = 0.9,
@@ -312,13 +301,17 @@
         lamda: FloatLike | None = None,
         tau_num: FloatLike | None = None,
         offset: FloatLike | None = None,
-        w: Float[Array, ' n'] | None = None,
+        w: Float[Array, ' n'] | Series | None = None,
         ntree: int | None = None,
         numcut: int = 100,
         ndpost: int = 1000,
         nskip: int = 100,
         keepevery: int | None = None,
         printevery: int | None = 100,
+        num_chains: int | None = None,
+        num_chain_devices: int | None = None,
+        num_data_devices: int | None = None,
+        devices: Device | Sequence[Device] | None = None,
         seed: int | Key[Array, ''] = 0,
         maxdepth: int = 6,
         init_kw: dict | None = None,
@@ -378,21 +371,19 @@
             a,
             b,
             rho,
+            num_chains,
+            num_chain_devices,
+            num_data_devices,
+            devices,
+            sparse,
+            nskip,
         )
         final_state, burnin_trace, main_trace = self._run_mcmc(
-            initial_state,
-            ndpost,
-            nskip,
-            keepevery,
-            printevery,
-            seed,
-            run_mcmc_kw,
-            sparse,
+            initial_state, ndpost, nskip, keepevery, printevery, seed, run_mcmc_kw
         )
 
         # set public attributes
         self.offset = final_state.offset  # from the state because of buffer donation
-        self.ndpost = ndpost
         self.sigest = sigest
 
         # set private attributes
@@ -406,6 +397,15 @@
         if x_test is not None:
             self.yhat_test = self.predict(x_test)
 
+    @property
+    def ndpost(self):
+        """The total number of posterior samples after burn-in across all chains.
+
+        May be larger than the initialization argument `ndpost` if it was not
+        divisible by the number of chains.
+        """
+        return self._main_trace.grow_prop_count.size
+
     @cached_property
     def prob_test(self) -> Float32[Array, 'ndpost m'] | None:
         """The posterior probability of y being True at `x_test` for each MCMC iteration."""
@@ -439,30 +439,53 @@
         return self.prob_train.mean(axis=0)
 
     @cached_property
-    def sigma(self) -> Float32[Array, ' nskip+ndpost'] | None:
+    def sigma(
+        self,
+    ) -> (
+        Float32[Array, ' nskip+ndpost']
+        | Float32[Array, 'nskip+ndpost/mc_cores mc_cores']
+        | None
+    ):
         """The standard deviation of the error, including burn-in samples."""
-        if self._burnin_trace.sigma2 is None:
+        if self._burnin_trace.error_cov_inv is None:
             return None
-        else:
-            assert self._main_trace.sigma2 is not None
-            return jnp.sqrt(
-                jnp.concatenate([self._burnin_trace.sigma2, self._main_trace.sigma2])
+        assert self._main_trace.error_cov_inv is not None
+        return jnp.sqrt(
+            jnp.reciprocal(
+                jnp.concatenate(
+                    [
+                        self._burnin_trace.error_cov_inv.T,
+                        self._main_trace.error_cov_inv.T,
+                    ],
+                    axis=0,
+                    # error_cov_inv has shape (chains? samples) in the trace
+                )
             )
+        )
+
+    @cached_property
+    def sigma_(self) -> Float32[Array, 'ndpost'] | None:
+        """The standard deviation of the error, only over the post-burnin samples and flattened."""
+        error_cov_inv = self._main_trace.error_cov_inv
+        if error_cov_inv is None:
+            return None
+        else:
+            return jnp.sqrt(jnp.reciprocal(error_cov_inv)).reshape(-1)
 
     @cached_property
     def sigma_mean(self) -> Float32[Array, ''] | None:
         """The mean of `sigma`, only over the post-burnin samples."""
-        if self.sigma is None:
+        if self.sigma_ is None:
             return None
-        else:
-            return self.sigma[len(self.sigma) - self.ndpost :].mean(axis=0)
+        return self.sigma_.mean()
 
     @cached_property
     def varcount(self) -> Int32[Array, 'ndpost p']:
         """Histogram of predictor usage for decision rules in the trees."""
-        return mcmcloop.compute_varcount(
-            self._mcmc_state.forest.max_split.size, self._main_trace
-        )
+        p = self._mcmc_state.forest.max_split.size
+        varcount: Int32[Array, '*chains samples p']
+        varcount = compute_varcount(p, self._main_trace)
+        return collapse(varcount, 0, -1)
 
     @cached_property
     def varcount_mean(self) -> Float32[Array, ' p']:
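Note: the trace now stores the inverse error variance (`error_cov_inv`) instead of `sigma2`, so the error standard deviation is recovered as `sqrt(1/x)`. A standalone check with made-up values:

```python
import jax.numpy as jnp

# error_cov_inv plays the role of 1 / sigma^2 in the new trace layout
error_cov_inv = jnp.array([4.0, 1.0, 0.25])  # hypothetical posterior draws
sigma = jnp.sqrt(jnp.reciprocal(error_cov_inv))
print(sigma)  # [0.5 1.  2. ]
```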
@@ -472,13 +495,15 @@ class gbart(Module):
     @cached_property
     def varprob(self) -> Float32[Array, 'ndpost p']:
         """Posterior samples of the probability of choosing each predictor for a decision rule."""
+        max_split = self._mcmc_state.forest.max_split
+        p = max_split.size
         varprob = self._main_trace.varprob
         if varprob is None:
-            max_split = self._mcmc_state.forest.max_split
-            p = max_split.size
             peff = jnp.count_nonzero(max_split)
             varprob = jnp.where(max_split, 1 / peff, 0)
             varprob = jnp.broadcast_to(varprob, (self.ndpost, p))
+        else:
+            varprob = varprob.reshape(-1, p)
         return varprob
 
     @cached_property
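Note: when the trace carries no `varprob` samples, the fallback above is a uniform distribution over the predictors that actually have cutpoints. A standalone sketch with hypothetical counts:

```python
import jax.numpy as jnp

# predictors with max_split == 0 have no cutpoints and get probability 0
max_split = jnp.array([5, 0, 3])
peff = jnp.count_nonzero(max_split)  # 2 effective predictors
varprob = jnp.where(max_split, 1 / peff, 0)
print(varprob)  # [0.5 0.  0.5]
```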
@@ -567,10 +592,11 @@
         get_length = lambda x: x.shape[-1]
         assert get_length(x1) == get_length(x2)
 
-    @staticmethod
+    @classmethod
     def _process_error_variance_settings(
-        x_train, y_train, sigest, sigdf, sigquant, lamda
+        cls, x_train, y_train, sigest, sigdf, sigquant, lamda
     ) -> tuple[Float32[Array, ''] | None, ...]:
+        """Return (lamda, sigest)."""
         if y_train.dtype == bool:
             if sigest is not None:
                 msg = 'Let `sigest=None` for binary regression'
@@ -592,18 +618,26 @@
         elif y_train.size <= x_train.shape[0]:
             sigest2 = jnp.var(y_train)
         else:
-            x_centered = x_train.T - x_train.mean(axis=1)
-            y_centered = y_train - y_train.mean()
-            # centering is equivalent to adding an intercept column
-            _, chisq, rank, _ = jnp.linalg.lstsq(x_centered, y_centered)
-            chisq = chisq.squeeze(0)
-            dof = len(y_train) - rank
-            sigest2 = chisq / dof
+            sigest2 = cls._linear_regression(x_train, y_train)
         alpha = sigdf / 2
         invchi2 = invgamma.ppf(sigquant, alpha) / 2
         invchi2rid = invchi2 * sigdf
         return sigest2 / invchi2rid, jnp.sqrt(sigest2)
 
+    @staticmethod
+    @jit
+    def _linear_regression(
+        x_train: Shaped[Array, 'p n'], y_train: Float32[Array, ' n']
+    ):
+        """Return the error variance estimated with OLS with intercept."""
+        x_centered = x_train.T - x_train.mean(axis=1)
+        y_centered = y_train - y_train.mean()
+        # centering is equivalent to adding an intercept column
+        _, chisq, rank, _ = jnp.linalg.lstsq(x_centered, y_centered)
+        chisq = chisq.squeeze(0)
+        dof = len(y_train) - rank
+        return chisq / dof
+
     @staticmethod
     def _check_type_settings(y_train, type, w):  # noqa: A002
         match type:
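Note: the new `_linear_regression` helper leans on `jnp.linalg.lstsq` returning the residual sum of squares. A self-contained sketch of the same estimator on synthetic data:

```python
import jax
import jax.numpy as jnp

# synthetic data in bartz's (p, n) layout
x = jax.random.normal(jax.random.key(0), (3, 50))
y = x[0] + 0.1 * jax.random.normal(jax.random.key(1), (50,))

x_centered = x.T - x.mean(axis=1)  # centering == implicit intercept column
y_centered = y - y.mean()
_, chisq, rank, _ = jnp.linalg.lstsq(x_centered, y_centered)
sigest2 = chisq.squeeze(0) / (len(y) - rank)  # estimated error variance
```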
@@ -641,6 +675,7 @@
         | tuple[FloatLike, None, None, None]
         | tuple[None, FloatLike, FloatLike, FloatLike]
     ):
+        """Return (theta, a, b, rho)."""
         if not sparse:
             return None, None, None, None
         elif theta is not None:
@@ -656,6 +691,7 @@
         y_train: Float32[Array, ' n'] | Bool[Array, ' n'],
         offset: float | Float32[Any, ''] | None,
     ) -> Float32[Array, '']:
+        """Return offset."""
         if offset is not None:
             return jnp.asarray(offset)
         elif y_train.size < 1:
@@ -677,6 +713,7 @@
         ntree: int,
         tau_num: FloatLike | None,
     ):
+        """Return sigma_mu."""
         if tau_num is None:
             if y_train.dtype == bool:
                 tau_num = 3.0
@@ -705,7 +742,9 @@
         return prepcovars.uniform_splits_from_matrix(x_train, numcut + 1)
 
     @staticmethod
-    def _bin_predictors(x, splits) -> UInt[Array, 'p n']:
+    def _bin_predictors(
+        x: Real[Array, 'p n'], splits: Real[Array, 'p max_num_splits']
+    ) -> UInt[Array, 'p n']:
         return prepcovars.bin_predictors(x, splits)
 
     @staticmethod
@@ -723,23 +762,35 @@
         maxdepth: int,
         ntree: int,
         init_kw: dict[str, Any] | None,
-        rm_const: bool | None,
+        rm_const: bool,
         theta: FloatLike | None,
         a: FloatLike | None,
         b: FloatLike | None,
         rho: FloatLike | None,
+        num_chains: int | None,
+        num_chain_devices: int | None,
+        num_data_devices: int | None,
+        devices: Device | Sequence[Device] | None,
+        sparse: bool,
+        nskip: int,
     ):
-        depth = jnp.arange(maxdepth - 1)
-        p_nonterminal = base / (1 + depth).astype(float) ** power
+        p_nonterminal = make_p_nonterminal(maxdepth, base, power)
 
         if y_train.dtype == bool:
-            sigma2_alpha = None
-            sigma2_beta = None
+            error_cov_df = None
+            error_cov_scale = None
         else:
-            sigma2_alpha = sigdf / 2
-            sigma2_beta = lamda * sigma2_alpha
+            assert lamda is not None
+            # inverse gamma prior: alpha = df / 2, beta = scale / 2
+            error_cov_df = sigdf
+            error_cov_scale = lamda * sigdf
+
+        # process device settings
+        device_kw, device = process_device_settings(
+            y_train, num_chains, num_chain_devices, num_data_devices, devices
+        )
 
-        kw = dict(
+        kw: dict = dict(
             X=x_train,
             # copy y_train because it's going to be donated in the mcmc loop
             y=jnp.array(y_train),
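Note: the inline comment above fixes the mapping between the two parameterizations; the old `(sigma2_alpha, sigma2_beta)` and the new `(error_cov_df, error_cov_scale)` describe the same inverse gamma prior. A quick consistency check:

```python
sigdf, lamda = 3.0, 0.5  # example prior settings

# old parameterization
sigma2_alpha = sigdf / 2
sigma2_beta = lamda * sigma2_alpha

# new parameterization, mapped back through alpha = df / 2, beta = scale / 2
error_cov_df = sigdf
error_cov_scale = lamda * sigdf
assert error_cov_df / 2 == sigma2_alpha
assert error_cov_scale / 2 == sigma2_beta
```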
@@ -748,35 +799,37 @@
             max_split=max_split,
             num_trees=ntree,
             p_nonterminal=p_nonterminal,
-            sigma_mu2=jnp.square(sigma_mu),
-            sigma2_alpha=sigma2_alpha,
-            sigma2_beta=sigma2_beta,
+            leaf_prior_cov_inv=jnp.reciprocal(jnp.square(sigma_mu)),
+            error_cov_df=error_cov_df,
+            error_cov_scale=error_cov_scale,
             min_points_per_decision_node=10,
             min_points_per_leaf=5,
             theta=theta,
             a=a,
             b=b,
             rho=rho,
+            sparse_on_at=nskip // 2 if sparse else None,
+            **device_kw,
         )
 
-        if rm_const is None:
-            kw.update(filter_splitless_vars=False)
-        elif rm_const:
-            kw.update(filter_splitless_vars=True)
-        else:
-            n_empty = jnp.count_nonzero(max_split == 0)
-            if n_empty:
-                msg = f'There are {n_empty}/{max_split.size} predictors without decision rules'
-                raise ValueError(msg)
-            kw.update(filter_splitless_vars=False)
+        if rm_const:
+            n_empty = jnp.sum(max_split == 0).item()
+            kw.update(filter_splitless_vars=n_empty)
 
         if init_kw is not None:
             kw.update(init_kw)
 
-        return mcmcstep.init(**kw)
+        state = mcmcstep.init(**kw)
 
-    @staticmethod
+        # put state on device if requested explicitly by the user
+        if device is not None:
+            state = device_put(state, device, donate=True)
+
+        return state
+
+    @classmethod
     def _run_mcmc(
+        cls,
         mcmc_state: mcmcstep.State,
         ndpost: int,
         nskip: int,
@@ -784,30 +837,101 @@
         printevery: int | None,
         seed: int | Integer[Array, ''] | Key[Array, ''],
         run_mcmc_kw: dict | None,
-        sparse: bool,
-    ):
+    ) -> tuple[mcmcstep.State, mcmcloop.BurninTrace, mcmcloop.MainTrace]:
         # prepare random generator seed
-        if isinstance(seed, jax.Array) and jnp.issubdtype(
-            seed.dtype, jax.dtypes.prng_key
-        ):
-            key = seed.copy()
-            # copy because the inner loop in run_mcmc will donate the buffer
+        if is_key(seed):
+            key = jnp.copy(seed)
         else:
             key = jax.random.key(seed)
 
+        # round up ndpost
+        num_chains = get_num_chains(mcmc_state)
+        if num_chains is None:
+            num_chains = 1
+        n_save = ndpost // num_chains + bool(ndpost % num_chains)
+
         # prepare arguments
-        kw = dict(n_burn=nskip, n_skip=keepevery, inner_loop_length=printevery)
+        kw: dict = dict(n_burn=nskip, n_skip=keepevery, inner_loop_length=printevery)
         kw.update(
             mcmcloop.make_default_callback(
+                mcmc_state,
                 dot_every=None if printevery is None or printevery == 1 else 1,
                 report_every=printevery,
-                sparse_on_at=nskip // 2 if sparse else None,
             )
         )
        if run_mcmc_kw is not None:
            kw.update(run_mcmc_kw)
 
-        return mcmcloop.run_mcmc(key, mcmc_state, ndpost, **kw)
-
-    def _predict(self, x):
-        return mcmcloop.evaluate_trace(self._main_trace, x)
+        return run_mcmc(key, mcmc_state, n_save, **kw)
+
+    def _predict(self, x: UInt[Array, 'p m']) -> Float32[Array, 'ndpost m']:
+        """Evaluate trees on already quantized `x`."""
+        out = evaluate_trace(x, self._main_trace)
+        return collapse(out, 0, -1)
+
+
+class DeviceKwArgs(TypedDict):
+    num_chains: int | None
+    mesh: Mesh | None
+    target_platform: Literal['cpu', 'gpu'] | None
+
+
+def process_device_settings(
+    y_train: Array,
+    num_chains: int | None,
+    num_chain_devices: int | None,
+    num_data_devices: int | None,
+    devices: Device | Sequence[Device] | None,
+) -> tuple[DeviceKwArgs, Device | None]:
+    """Return the arguments for `mcmcstep.init` related to devices, and an optional device where to put the state."""
+    # determine devices
+    if devices is not None:
+        if not hasattr(devices, '__len__'):
+            devices = (devices,)
+        device = devices[0]
+        platform = device.platform
+    elif hasattr(y_train, 'platform'):
+        platform = y_train.platform()
+        device = None
+        # set device=None because if the devices were not specified explicitly
+        # we may be in the case where computation will follow data placement,
+        # do not disturb jax as the user may be playing with vmap, jit, reshard...
+        devices = jax.devices(platform)
+    else:
+        msg = 'not possible to infer device from `y_train`, please set `devices`'
+        raise ValueError(msg)
+
+    # create mesh
+    if num_chain_devices is None and num_data_devices is None:
+        mesh = None
+    else:
+        mesh = dict()
+        if num_chain_devices is not None:
+            mesh.update(chains=num_chain_devices)
+        if num_data_devices is not None:
+            mesh.update(data=num_data_devices)
+        mesh = make_mesh(
+            axis_shapes=tuple(mesh.values()),
+            axis_names=tuple(mesh),
+            axis_types=(AxisType.Auto,) * len(mesh),
+            devices=devices,
+        )
+        device = None
+        # set device=None because `mcmcstep.init` will `device_put` with the
+        # mesh already, we don't want to undo its work
+
+    # prepare arguments to `init`
+    settings = DeviceKwArgs(
+        num_chains=num_chains,
+        mesh=mesh,
+        target_platform=None
+        if mesh is not None or hasattr(y_train, 'platform')
+        else platform,
+        # here we don't take into account the case where the user has set both
+        # batch sizes; since the user has to be playing with `init_kw` to do
+        # that, we'll let `init` throw the error and the user set
+        # `target_platform` themselves so they have a clearer idea how the
+        # thing works.
+    )
+
+    return settings, device
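Note: for reference, a standalone sketch of the mesh that `process_device_settings` builds for `num_chain_devices=4, num_data_devices=2`. The XLA flag forces 8 virtual CPU devices so the sketch runs on any machine; it is an illustration aid, not part of bartz:

```python
import os

# must be set before JAX initializes its backend
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax
from jax.sharding import AxisType

mesh = jax.make_mesh(
    axis_shapes=(4, 2),  # chain devices x data devices
    axis_names=('chains', 'data'),
    axis_types=(AxisType.Auto, AxisType.Auto),
    devices=jax.devices('cpu'),  # consumes 4 * 2 = 8 devices
)
print(mesh.shape)  # {'chains': 4, 'data': 2}
```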