rbartpackages 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rbartpackages/BART.py ADDED
@@ -0,0 +1,335 @@
1
+ # rbartpackages/src/rbartpackages/BART.py
2
+ #
3
+ # Copyright (c) 2024-2026, The rbartpackages Contributors
4
+ #
5
+ # This file is part of rbartpackages.
6
+ #
7
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ # of this software and associated documentation files (the "Software"), to deal
9
+ # in the Software without restriction, including without limitation the rights
10
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ # copies of the Software, and to permit persons to whom the Software is
12
+ # furnished to do so, subject to the following conditions:
13
+ #
14
+ # The above copyright notice and this permission notice shall be included in all
15
+ # copies or substantial portions of the Software.
16
+ #
17
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ # SOFTWARE.
24
+
25
+ """Wrapper for the R package BART."""
26
+
27
+ # ruff: noqa: ANN002, ANN003
28
+
29
+ from functools import partial
30
+ from typing import NamedTuple, TypedDict, cast
31
+
32
+ import numpy as np
33
+ from jaxtyping import AbstractDtype, Bool, Float64, Int32
34
+ from numpy import ndarray
35
+ from rpy2 import robjects
36
+ from rpy2.rlike.container import NamedList
37
+
38
+ from rbartpackages._base import RObjectBase, fork_safe_native_threads, rmethod
39
+
40
+
41
class TreeDraws(TypedDict, total=True):
    """Layout of the dict that `mc_gbart.__init__` stores in `treedraws`."""

    cutpoints: dict[int | str, Float64[ndarray, ' numcut[i]']]
    """Candidate split values for each predictor, keyed by column name or position."""

    trees: str
    """The posterior tree ensemble as BART's text serialization (consumed by `predict`)."""
49
+
50
+
51
class PredictBinary(TypedDict, total=True):
    """Dict returned by `predict` when the fit is binary (`pbart`/`lbart`)."""

    yhat_test: Float64[ndarray, 'ndpost m']
    """Draws of the latent function at each test point."""

    prob_test: Float64[ndarray, 'ndpost m']
    """Draws of the success probability (probit/logit inverse link of `yhat_test`)."""

    prob_test_mean: Float64[ndarray, ' m']
    """Average of `prob_test` over the posterior draws."""

    binaryOffset: float
    """Centering constant applied on the latent scale."""
65
+
66
+
67
class String(AbstractDtype):
    """jaxtyping dtype for fixed-width unicode arrays (``numpy.str_``)."""

    # pattern over numpy dtype strings such as '<U5'
    dtypes = '<U\\d+'
71
+
72
+
73
class ProcTime(NamedTuple):
    """Named view of the five timings reported by R's `proc.time`."""

    user_self: float
    """User-mode CPU seconds used by the R process itself."""

    sys_self: float
    """Kernel-mode CPU seconds used by the R process itself."""

    elapsed: float
    """Real (wall-clock) seconds that passed."""

    user_child: float
    """User-mode CPU seconds used by forked children (`mc.gbart` workers)."""

    sys_child: float
    """Kernel-mode CPU seconds used by forked children."""
90
+
91
+
92
class mc_gbart(RObjectBase):  # noqa: D101 because the R doc is added automatically
    _rfuncname = 'BART::mc.gbart'

    LPML: float
    """Log pseudo-marginal likelihood; unstable for BART.

    Computed unconditionally, even with no burn-in. R gets it wrong for binary
    `mc.gbart` fits with ``mc_cores > 1``, because the per-chain probabilities
    are not merged before the computation.
    """

    hostname: Bool[ndarray, ' mc_cores'] | String[ndarray, ' mc_cores']
    """Per-chain hostname when fitted with ``hostname=True``, otherwise per-chain ``False``."""

    ndpost: int
    """Count of posterior draws retained after burn-in and thinning."""

    offset: float
    """Centering constant for the response (on the link scale for binary fits)."""

    prob_test: None | Float64[ndarray, 'ndpost/mc_cores m'] = None
    """Success-probability draws at the test points (binary outcomes only).

    With ``mc_cores > 1`` `mc.gbart` never merges the chains here, so only the
    first chain's draws are present.
    """

    prob_test_mean: None | Float64[ndarray, ' m'] = None
    """Average of `prob_test` over the draws."""

    prob_train: None | Float64[ndarray, 'ndpost/mc_cores n'] = None
    """Success-probability draws at the training points (binary outcomes only).

    With ``mc_cores > 1`` `mc.gbart` never merges the chains here, so only the
    first chain's draws are present.
    """

    prob_train_mean: None | Float64[ndarray, ' n'] = None
    """Average of `prob_train` over the draws."""

    proc_time: ProcTime
    """Fit timing as reported by R's `proc.time`."""

    rm_const: Int32[ndarray, '<=p']
    """0-based positions of the `x_train` columns that were kept (constant ones dropped).

    With ``mc_cores=1`` `mc.gbart` renumbers the kept columns as ``0 .. kept-1``,
    so the identity of the dropped original columns is lost.
    """

    sigma: (
        Float64[ndarray, ' nskip+ndpost*keepevery']
        | Float64[ndarray, 'nskip+ndpost*keepevery/mc_cores mc_cores']
        | None
    ) = None
    """Error-SD draws for continuous outcomes (one column per chain for `mc.gbart`).

    There is one draw for every MCMC iteration, so burn-in and the iterations
    discarded by thinning are both included.
    """

    sigma_mean: float | None = None
    """Average of the first `ndpost` post-burn-in `sigma` draws (continuous only)."""

    treedraws: TreeDraws
    """Posterior trees: the per-variable cutpoint grid plus the serialized ensemble."""

    varcount: Int32[ndarray, 'ndpost p']
    """Number of splits on each variable in each draw, summed across trees."""

    varcount_mean: Float64[ndarray, ' p']
    """Per-variable average of `varcount` over the draws."""

    varprob: Float64[ndarray, 'ndpost p']
    """Per-draw splitting probability of each variable."""

    varprob_mean: Float64[ndarray, ' p']
    """Per-variable average of `varprob` over the draws."""

    yhat_test: Float64[ndarray, 'ndpost m']
    """Posterior function draws at the test points (latent scale for binary fits).

    Never ``None``: R's `cgbart` always allocates it, so a fit without test
    data carries an empty array. For `mc.gbart` that empty array holds only the
    first chain's rows, because the chains are merged only when test data is
    supplied.
    """

    yhat_test_mean: Float64[ndarray, ' m'] | None = None
    """Average of `yhat_test` (continuous fits with test data only)."""

    yhat_train: Float64[ndarray, 'ndpost n']
    """Posterior function draws at the training points (latent scale for binary fits)."""

    yhat_train_mean: Float64[ndarray, ' n'] | None = None
    """Average of `yhat_train` (continuous only)."""

    def __init__(self, *args, **kw) -> None:
        # mc.gbart forks the R process through parallel::mcparallel. Keeping
        # the native thread pools at a single thread while it forks stops
        # libgomp from deadlocking in the children.
        with fork_safe_native_threads():
            super().__init__(*args, **kw)

        # R returns length-1 vectors; unwrap them into Python scalars
        self.LPML = self.LPML.item()
        self.ndpost = self.ndpost.astype(int).item()
        self.offset = self.offset.item()
        self.proc_time = ProcTime(*(float(t) for t in self.proc_time))

        # Convert R's rm.const to 0-based indices of the kept columns.
        codes = self.rm_const
        if np.all(codes < 0):
            # Negative entries name the dropped constant columns of the
            # original design, while varcount only covers the kept ones.
            _, n_kept = self.varcount.shape
            n_total = n_kept + codes.size
            keep = np.ones(n_total, dtype=bool)
            keep[-codes - 1] = False
            self.rm_const = np.flatnonzero(keep).astype(np.int32)
        elif np.all(codes > 0):
            # plain 1-based positions of the kept columns
            self.rm_const -= 1
        else:  # pragma: no cover - R gives all-positive or all-negative indices
            msg = 'failed to parse rm.const because indices change sign'
            raise ValueError(msg)

        if self.sigma_mean is not None:
            self.sigma_mean = self.sigma_mean.item()

        # turn R's treedraws list into a plain dict
        raw = cast(NamedList, self.treedraws)
        grid: dict[int | str, ndarray] = {}
        for pos, entry in enumerate(raw.getbyname('cutpoints').items()):
            key = pos if entry.name is None else entry.name.item()
            grid[key] = entry.value
        self.treedraws = {'cutpoints': grid, 'trees': raw.getbyname('trees').item()}

    @partial(rmethod, rname='predict')
    def _predict(self, newdata: Float64[ndarray, 'm p'], *args, **kwargs) -> object:
        """Call R's `predict`; returns a matrix (continuous) or a list (binary)."""
        ...

    def predict(
        self, newdata: Float64[ndarray, 'm p'], *args, **kwargs
    ) -> Float64[ndarray, 'ndpost m'] | PredictBinary:
        """Predict at new points.

        A continuous (`wbart`) fit yields the matrix of posterior
        latent-function draws. A binary (`pbart`/`lbart`) fit makes R return a
        list, which is converted here into a `PredictBinary` dict.

        When an `mc.gbart` fit with ``mc_cores > 1`` dropped constant columns,
        R miscounts the kept columns and does not update the header of the
        serialized ensemble, so the result holds only the first chain's draws.
        """
        raw = self._predict(newdata, *args, **kwargs)
        if hasattr(raw, 'items'):
            # binary fit: R's named list becomes a dict, with the dots in
            # R component names replaced by underscores
            named = cast(NamedList, raw)
            result = {}
            for entry in named.items():
                result[str(entry.name).replace('.', '_')] = entry.value
            result['binaryOffset'] = result['binaryOffset'].item()
            return result
        # continuous fit: R already produced the draws matrix
        return raw
256
+
257
+
258
class bartModelMatrix(RObjectBase):  # noqa: D101 because the R doc is added automatically
    _rfuncname = 'BART::bartModelMatrix'

    X: Float64[ndarray, 'N p']
    """Design matrix: vectors and data frames coerced to numeric, factors expanded to indicators."""

    numcut: Int32[ndarray, ' p']
    """Number of cutpoints selected for each column."""

    rm_const: Int32[ndarray, '<=p']
    """0-based positions of the non-constant columns of the expanded design.

    Positions refer to the columns of `X` before any removal. Passing
    ``rm.const=True`` drops the constant columns from `X`, `numcut` and
    `xinfo`; the default only detects them.
    """

    xinfo: Float64[ndarray, 'p numcut']
    """Cutpoint grid per column, padded with NaN up to the largest cut count."""

    grp: Int32[ndarray, ' p'] | Float64[ndarray, ' 1'] | None
    """1-based index of the input column each output column comes from (a factor
    expands to one indicator column per level); ``None`` for matrix input."""

    def __new__(cls, *args, **kw) -> Float64[ndarray, 'N p'] | RObjectBase:
        """Mirror R: a bare matrix for ``numcut=0``, otherwise a populated instance."""
        # The matrix-or-list choice has to be made here because __init__ cannot
        # change the constructor's return value; when the plain matrix is
        # returned (not an instance), Python skips __init__ entirely.
        obj = super().__new__(cls)
        obj._robject = obj._invoke_rfunc(args, kw)
        if not obj._has_named_components(obj._robject):
            return obj._r2py(obj._robject)
        return obj

    def __init__(self, *args, **kw) -> None:
        # Reached only for the named-list result (numcut > 0). __new__ has
        # already run R and stored `_robject`, so fill the attributes from it
        # instead of calling super().__init__, which would run R again.
        self._set_attrs_from_robject()

        # matrix input yields grp = R NULL; expose it as None
        if self.grp is robjects.NULL:
            self.grp = None

        codes = self.rm_const
        if np.all(codes < 0):
            # Negative entries flag detected-constant columns of the
            # pre-removal design. Whether X actually lost them depends on the
            # rm.const argument (5th parameter of R's bartModelMatrix), so
            # recover it from the call.
            positional = args[4] if len(args) > 4 else False
            removed = kw.get('rm.const', kw.get('rm_const', positional))
            _, n_cols = self.X.shape
            n_total = n_cols + (codes.size if removed else 0)
            keep = np.ones(n_total, dtype=bool)
            keep[-codes - 1] = False
            self.rm_const = np.flatnonzero(keep).astype(np.int32)
        elif np.all(codes > 0):
            # plain 1-based positions of the non-constant columns
            self.rm_const -= 1
        else:  # pragma: no cover - R gives all-positive or all-negative indices
            msg = 'failed to parse rm.const because indices change sign'
            raise ValueError(msg)
320
+
321
+
322
class gbart(mc_gbart):  # noqa: D101 because the R doc is added automatically
    _rfuncname = 'BART::gbart'

    # single-chain narrowings of the attribute types inherited from mc_gbart

    hostname: Bool[ndarray, ' 1'] | String[ndarray, ' 1']
    """Host that ran the fit when fitted with ``hostname=True``, otherwise ``False``."""

    prob_test: None | Float64[ndarray, 'ndpost m'] = None
    """Success-probability draws at the test points (binary outcomes only)."""

    prob_train: None | Float64[ndarray, 'ndpost n'] = None
    """Success-probability draws at the training points (binary outcomes only)."""

    sigma: Float64[ndarray, ' nskip+ndpost*keepevery'] | None = None
    """Error-SD draws for every MCMC iteration, including burn-in (continuous only)."""
rbartpackages/BART3.py ADDED
@@ -0,0 +1,368 @@
1
+ # rbartpackages/src/rbartpackages/BART3.py
2
+ #
3
+ # Copyright (c) 2025-2026, The rbartpackages Contributors
4
+ #
5
+ # This file is part of rbartpackages.
6
+ #
7
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ # of this software and associated documentation files (the "Software"), to deal
9
+ # in the Software without restriction, including without limitation the rights
10
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ # copies of the Software, and to permit persons to whom the Software is
12
+ # furnished to do so, subject to the following conditions:
13
+ #
14
+ # The above copyright notice and this permission notice shall be included in all
15
+ # copies or substantial portions of the Software.
16
+ #
17
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ # SOFTWARE.
24
+
25
+ """Wrapper for the R package BART3."""
26
+
27
+ # ruff: noqa: ANN002, ANN003
28
+
29
+ from functools import partial
30
+ from typing import NamedTuple, TypedDict, cast
31
+
32
+ import numpy as np
33
+ from jaxtyping import AbstractDtype, Float64, Int32
34
+ from numpy import ndarray
35
+ from rpy2 import robjects
36
+ from rpy2.rlike.container import NamedList
37
+
38
+ from rbartpackages._base import RObjectBase, fork_safe_native_threads, rmethod
39
+
40
+
41
class TreeDraws(TypedDict, total=True):
    """Layout of the dict that `mc_gbart.__init__` stores in `treedraws`."""

    cutpoints: dict[int | str, Float64[ndarray, ' numcut[i]']]
    """Candidate split values for each predictor, keyed by column name or position."""

    trees: str
    """The posterior tree ensemble as BART's text serialization (consumed by `predict`)."""
49
+
50
+
51
class PredictBinary(TypedDict, total=True):
    """Dict returned by `predict` when the fit is binary (`pbart`/`lbart`)."""

    yhat_test: Float64[ndarray, 'ndpost m']
    """Draws of the latent function at each test point."""

    prob_test: Float64[ndarray, 'ndpost m']
    """Draws of the success probability (probit/logit inverse link of `yhat_test`)."""

    prob_test_mean: Float64[ndarray, ' m']
    """Average of `prob_test` over the posterior draws."""

    prob_test_lower: Float64[ndarray, ' m']
    """Lower `probs` quantile of `prob_test` (2.5% by default)."""

    prob_test_upper: Float64[ndarray, ' m']
    """Upper `probs` quantile of `prob_test` (97.5% by default)."""

    binaryOffset: float
    """Centering constant applied on the latent scale."""
71
+
72
+
73
class String(AbstractDtype):
    """jaxtyping dtype for fixed-width unicode arrays (``numpy.str_``)."""

    # pattern over numpy dtype strings such as '<U5'
    dtypes = '<U\\d+'
77
+
78
+
79
class ProcTime(NamedTuple):
    """Named view of the five timings reported by R's `proc.time`."""

    user_self: float
    """User-mode CPU seconds used by the R process itself."""

    sys_self: float
    """Kernel-mode CPU seconds used by the R process itself."""

    elapsed: float
    """Real (wall-clock) seconds that passed."""

    user_child: float
    """User-mode CPU seconds used by forked children (`mc.gbart` workers)."""

    sys_child: float
    """Kernel-mode CPU seconds used by forked children."""
96
+
97
+
98
class mc_gbart(RObjectBase):  # noqa: D101 because the R doc is added automatically
    _rfuncname = 'BART3::mc.gbart'

    LPML: float | None = None
    """Log pseudo-marginal likelihood; ``None`` when there is no burn-in. Unstable for BART."""

    accept: (
        Float64[ndarray, ' nskip+ndpost*keepevery']
        | Float64[ndarray, 'nskip+ndpost*keepevery/mc_cores mc_cores']
    )
    """Metropolis-Hastings acceptance rate per iteration (one column per chain for `mc.gbart`).

    Stored for every MCMC iteration, thinned-away ones included (whereas
    `sigma` holds only the burn-in plus the retained draws).
    """

    chains: int
    """Number of MCMC chains, i.e. the `mc_cores` actually used."""

    grp: Float64[ndarray, ' p']
    """Group index of each column for the sparse (DART) variable-selection prior."""

    ndpost: int
    """Count of posterior draws retained after burn-in and thinning."""

    offset: float
    """Centering constant for the response (on the link scale for binary fits)."""

    prob_test: None | Float64[ndarray, 'ndpost m'] = None
    """Success-probability draws at the test points (binary outcomes only)."""

    prob_test_lower: Float64[ndarray, ' m'] | None = None
    """Lower `probs` quantile of `prob_test` (2.5% by default)."""

    prob_test_mean: None | Float64[ndarray, ' m'] = None
    """Average of `prob_test` over the draws."""

    prob_test_upper: Float64[ndarray, ' m'] | None = None
    """Upper `probs` quantile of `prob_test` (97.5% by default)."""

    prob_train: None | Float64[ndarray, 'ndpost n'] = None
    """Success-probability draws at the training points (binary outcomes only)."""

    prob_train_mean: None | Float64[ndarray, ' n'] = None
    """Average of `prob_train` over the draws."""

    proc_time: ProcTime
    """Fit timing as reported by R's `proc.time`."""

    rho: float
    """Concentration of the sparse (DART) prior; ``sum(1/grp)`` by default."""

    rm_const: Int32[ndarray, '<=p']
    """0-based positions of the `x_train` columns that were kept (constant ones dropped)."""

    sigest: float | None = None
    """Rough residual SD used to calibrate the sigma prior (continuous only).

    ``None`` for binary outcomes. ``nan`` when the `mc.gbart` ``mc_cores > 1``
    bug replaces it with a logical missing value.
    """

    sigma: (
        Float64[ndarray, ' nskip+ndpost']
        | Float64[ndarray, 'nskip+ndpost/mc_cores mc_cores']
        | None
    ) = None
    """Error-SD draws including burn-in, continuous only (one column per chain for `mc.gbart`)."""

    sigma_: Float64[ndarray, ' ndpost'] | None = None
    """Retained `sigma` draws without burn-in; ``None`` when there is no burn-in."""

    sigma_mean: float | None = None
    """Average of `sigma_`; `sigest` is used instead when no draws are retained."""

    treedraws: TreeDraws
    """Posterior trees: the per-variable cutpoint grid plus the serialized ensemble."""

    varcount: Int32[ndarray, 'ndpost p']
    """Number of splits on each variable in each draw, summed across trees."""

    varcount_mean: Float64[ndarray, ' p']
    """Per-variable average of `varcount` over the draws."""

    varprob: Float64[ndarray, 'ndpost p']
    """Per-draw splitting probability of each variable."""

    varprob_mean: Float64[ndarray, ' p']
    """Per-variable average of `varprob` over the draws."""

    x_test: Float64[ndarray, ' m <=p'] | None = None
    """Test design matrix as used (imputed, factors expanded, constant columns dropped)."""

    x_train: Float64[ndarray, ' n <=p']
    """Training design matrix as used (original scale, not binned; constant columns dropped)."""

    yhat_test: Float64[ndarray, 'ndpost m']
    """Posterior function draws at the test points (latent scale for binary fits).

    Never ``None``: R's `cgbart` always allocates it, so a fit without test data
    carries an empty ``(ndpost, 0)`` array. The derived
    `yhat_test_mean`/`lower`/`upper` are different: R fills them only when test
    data is supplied.
    """

    yhat_test_lower: Float64[ndarray, ' m'] | None = None
    """Lower `probs` quantile of `yhat_test` (2.5% by default, continuous only)."""

    yhat_test_mean: Float64[ndarray, ' m'] | None = None
    """Average of `yhat_test` over the draws."""

    yhat_test_upper: Float64[ndarray, ' m'] | None = None
    """Upper `probs` quantile of `yhat_test` (97.5% by default, continuous only)."""

    yhat_train: Float64[ndarray, 'ndpost n']
    """Posterior function draws at the training points (latent scale for binary fits)."""

    yhat_train_lower: Float64[ndarray, ' n'] | None = None
    """Lower `probs` quantile of `yhat_train` (2.5% by default, continuous only)."""

    yhat_train_mean: Float64[ndarray, ' n'] | None = None
    """Average of `yhat_train` over the draws."""

    yhat_train_upper: Float64[ndarray, ' n'] | None = None
    """Upper `probs` quantile of `yhat_train` (97.5% by default, continuous only)."""

    def __init__(self, *args, **kw) -> None:
        # mc.gbart forks the R process through parallel::mcparallel. Keeping
        # the native thread pools at a single thread while it forks stops
        # libgomp from deadlocking in the children.
        with fork_safe_native_threads():
            super().__init__(*args, **kw)

        # R returns length-1 vectors; unwrap them into Python scalars
        self.chains = self.chains.item()
        self.ndpost = self.ndpost.astype(int).item()
        self.offset = self.offset.item()
        self.proc_time = ProcTime(*(float(t) for t in self.proc_time))
        self.rho = self.rho.item()

        # Convert R's rm.const to 0-based indices of the kept columns.
        codes = self.rm_const
        if np.all(codes < 0):
            # Negative entries name the dropped constant columns of the
            # original design, while varcount only covers the kept ones.
            _, n_kept = self.varcount.shape
            n_total = n_kept + codes.size
            keep = np.ones(n_total, dtype=bool)
            keep[-codes - 1] = False
            self.rm_const = np.flatnonzero(keep).astype(np.int32)
        elif np.all(codes > 0):
            # plain 1-based positions of the kept columns
            self.rm_const -= 1
        else:  # pragma: no cover - R gives all-positive or all-negative indices
            msg = 'failed to parse rm.const because indices change sign'
            raise ValueError(msg)

        # optional scalars: unwrap only the ones R actually returned
        if self.LPML is not None:
            self.LPML = self.LPML.item()
        if self.sigest is not None:
            # BART3 bug: with mc_cores > 1, mc.gbart overwrites sigest with its
            # logical-NA default rather than the estimate; surface that as nan.
            is_logical_na = self.sigest.dtype == bool
            self.sigest = float('nan') if is_logical_na else self.sigest.item()
        if self.sigma_mean is not None:
            self.sigma_mean = self.sigma_mean.item()

        # turn R's treedraws list into a plain dict
        raw = cast(NamedList, self.treedraws)
        grid: dict[int | str, ndarray] = {}
        for pos, entry in enumerate(raw.getbyname('cutpoints').items()):
            key = pos if entry.name is None else entry.name.item()
            grid[key] = entry.value
        self.treedraws = {'cutpoints': grid, 'trees': raw.getbyname('trees').item()}

    @partial(rmethod, rname='predict')
    def _predict(self, newdata: Float64[ndarray, 'm p'], *args, **kwargs) -> object:
        """Call R's `predict`; returns a matrix (continuous) or a list (binary)."""
        ...

    def predict(
        self, newdata: Float64[ndarray, 'm p'], *args, **kwargs
    ) -> Float64[ndarray, 'ndpost m'] | PredictBinary:
        """Predict at new points.

        A continuous (`wbart`) fit yields the matrix of posterior
        latent-function draws. A binary (`pbart`/`lbart`) fit makes R return a
        list, which is converted here into a `PredictBinary` dict.
        """
        raw = self._predict(newdata, *args, **kwargs)
        if hasattr(raw, 'items'):
            # binary fit: R's named list becomes a dict, with the dots in
            # R component names replaced by underscores
            named = cast(NamedList, raw)
            result = {}
            for entry in named.items():
                result[str(entry.name).replace('.', '_')] = entry.value
            result['binaryOffset'] = result['binaryOffset'].item()
            return result
        # continuous fit: R already produced the draws matrix
        return raw
296
+
297
+
298
class bartModelMatrix(RObjectBase):  # noqa: D101 because the R doc is added automatically
    _rfuncname = 'BART3::bartModelMatrix'

    X: Float64[ndarray, 'N p']
    """Design matrix: vectors and data frames coerced to numeric, factors expanded to indicators."""

    numcut: Int32[ndarray, ' p']
    """Number of cutpoints selected for each column."""

    rm_const: Int32[ndarray, '<=p']
    """0-based positions of the non-constant columns of the expanded design.

    Positions refer to the columns of `X` before any removal. Passing
    ``rm.const=True`` drops the constant columns from `X`, `numcut` and
    `xinfo`; the default only detects them.
    """

    xinfo: Float64[ndarray, 'p numcut']
    """Cutpoint grid per column, padded with NaN up to the largest cut count."""

    grp: Float64[ndarray, ' p'] | None
    """Group size of each expanded factor's indicator columns, or None when there are no factors."""

    def __new__(cls, *args, **kw) -> Float64[ndarray, 'N p'] | RObjectBase:
        """Mirror R: a bare matrix for ``numcut=0``, otherwise a populated instance."""
        # The matrix-or-list choice has to be made here because __init__ cannot
        # change the constructor's return value; when the plain matrix is
        # returned (not an instance), Python skips __init__ entirely.
        obj = super().__new__(cls)
        obj._robject = obj._invoke_rfunc(args, kw)
        if not obj._has_named_components(obj._robject):
            return obj._r2py(obj._robject)
        return obj

    def __init__(self, *args, **kw) -> None:
        # Reached only for the named-list result (numcut > 0). __new__ has
        # already run R and stored `_robject`, so fill the attributes from it
        # instead of calling super().__init__, which would run R again.
        self._set_attrs_from_robject()

        # grp is R NULL unless the input had factor columns; expose it as None
        if self.grp is robjects.NULL:
            self.grp = None

        codes = self.rm_const
        if np.all(codes < 0):
            # Negative entries flag detected-constant columns of the
            # pre-removal design. Whether X actually lost them depends on the
            # rm.const argument (5th parameter of R's bartModelMatrix), so
            # recover it from the call.
            positional = args[4] if len(args) > 4 else False
            removed = kw.get('rm.const', kw.get('rm_const', positional))
            _, n_cols = self.X.shape
            n_total = n_cols + (codes.size if removed else 0)
            keep = np.ones(n_total, dtype=bool)
            keep[-codes - 1] = False
            self.rm_const = np.flatnonzero(keep).astype(np.int32)
        elif np.all(codes > 0):
            # plain 1-based positions of the non-constant columns
            self.rm_const -= 1
        else:  # pragma: no cover - R gives all-positive or all-negative indices
            msg = 'failed to parse rm.const because indices change sign'
            raise ValueError(msg)
359
+
360
+
361
class gbart(mc_gbart):  # noqa: D101 because the R doc is added automatically
    _rfuncname = 'BART3::gbart'

    # single-chain narrowings of the attribute types inherited from mc_gbart

    accept: Float64[ndarray, ' nskip+ndpost*keepevery']
    """Metropolis-Hastings acceptance rate, one entry per MCMC iteration."""

    sigma: Float64[ndarray, ' nskip+ndpost'] | None = None
    """Error-SD draws, burn-in included (continuous only)."""
rbartpackages/__init__.py ADDED
@@ -0,0 +1,35 @@
1
+ # rbartpackages/src/rbartpackages/__init__.py
2
+ #
3
+ # Copyright (c) 2025-2026, The rbartpackages Contributors
4
+ #
5
+ # This file is part of rbartpackages.
6
+ #
7
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ # of this software and associated documentation files (the "Software"), to deal
9
+ # in the Software without restriction, including without limitation the rights
10
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ # copies of the Software, and to permit persons to whom the Software is
12
+ # furnished to do so, subject to the following conditions:
13
+ #
14
+ # The above copyright notice and this permission notice shall be included in all
15
+ # copies or substantial portions of the Software.
16
+ #
17
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ # SOFTWARE.
24
+
25
+ """
26
+ Python wrappers of R BART packages.
27
+
28
+ Each wrapped R package has its own submodule (``BART``, ``BART3``,
29
+ ``bartMachine``, ``dbarts``); import the one you need, e.g.
30
+ ``from rbartpackages import BART3``. Importing a wrapper requires the
31
+ corresponding R package to be installed, because the class docstrings are pulled
32
+ from the R documentation at import time.
33
+ """
34
+
35
+ from rbartpackages._version import __version__, __version_info__ # noqa: F401
rbartpackages/_base.py ADDED
@@ -0,0 +1,293 @@
1
+ # rbartpackages/src/rbartpackages/_base.py
2
+ #
3
+ # Copyright (c) 2024-2026, The rbartpackages Contributors
4
+ #
5
+ # This file is part of rbartpackages.
6
+ #
7
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ # of this software and associated documentation files (the "Software"), to deal
9
+ # in the Software without restriction, including without limitation the rights
10
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ # copies of the Software, and to permit persons to whom the Software is
12
+ # furnished to do so, subject to the following conditions:
13
+ #
14
+ # The above copyright notice and this permission notice shall be included in all
15
+ # copies or substantial portions of the Software.
16
+ #
17
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ # SOFTWARE.
24
+
25
+ import ctypes
26
+ from collections.abc import Callable, Iterable, Iterator, Mapping
27
+ from contextlib import contextmanager
28
+ from functools import wraps
29
+ from re import fullmatch, match
30
+ from textwrap import indent
31
+ from typing import Any
32
+
33
+ import numpy as np
34
+ from rpy2 import robjects
35
+ from rpy2.robjects import BoolVector, conversion, numpy2ri
36
+ from rpy2.robjects.help import Package
37
+ from rpy2.robjects.methods import RS4
38
+
39
# Optional pandas support: rpy2's pandas converter when it can be imported,
# otherwise an empty placeholder converter, so that the converter sum built in
# `RObjectBase` works either way.
try:
    from rpy2.robjects import pandas2ri
except ImportError:  # pragma: no cover - optional dep always present in CI
    PANDAS_CONVERTER = conversion.Converter('pandas')
else:
    PANDAS_CONVERTER = pandas2ri.converter

# Optional polars support: polars objects are turned into pandas objects and
# then handed to rpy2's pandas converter, so both packages must be importable.
POLARS_CONVERTER = conversion.Converter('polars')
try:
    import polars as pl
    from rpy2.robjects import pandas2ri
except ImportError:  # pragma: no cover - optional dep always present in CI
    pass
else:

    def polars_to_r(df: pl.DataFrame) -> object:
        """Convert a polars DataFrame (or Series) to R by way of pandas."""
        return pandas2ri.py2rpy(df.to_pandas())

    for _polars_type in (pl.DataFrame, pl.Series):
        POLARS_CONVERTER.py2rpy.register(_polars_type, polars_to_r)
    del _polars_type
63
+
64
# Optional jax support: jax arrays are converted to numpy arrays and passed on
# to rpy2's numpy converter; 0-d arrays are unwrapped to numpy scalars first.
JAX_CONVERTER = conversion.Converter('jax')
try:
    import jax
except ImportError:  # pragma: no cover - optional dep always present in CI
    pass
else:

    def jax_to_r(x: jax.Array) -> object:
        """Convert a jax array to R through numpy."""
        as_numpy = np.asarray(x)
        return numpy2ri.py2rpy(as_numpy[()] if as_numpy.ndim == 0 else as_numpy)

    JAX_CONVERTER.py2rpy.register(jax.Array, jax_to_r)
79
+
80
# numpy arrays: use rpy2's own numpy converter as is
NUMPY_CONVERTER = numpy2ri.converter


def bool_vector_to_python(x: BoolVector) -> np.ndarray[Any, np.dtype[np.bool_]]:
    """Convert an R logical vector to a numpy boolean array.

    rpy2's numpy converter does not appear to cover `BoolVector`, so this gets
    a converter of its own.
    """
    return np.array(x, dtype=bool)


def dict_to_r(x: dict[str, Any]) -> robjects.ListVector:
    """Convert a Python dictionary to an R `ListVector`."""
    return robjects.ListVector(x)


BOOL_VECTOR_CONVERTER = conversion.Converter('bool_vector')
BOOL_VECTOR_CONVERTER.rpy2py.register(BoolVector, bool_vector_to_python)

DICT_CONVERTER = conversion.Converter('dict')
DICT_CONVERTER.py2rpy.register(dict, dict_to_r)
102
+
103
# Regular expression for a syntactic R name: a letter, or a dot not followed by
# a digit, then any run of letters, digits, dots and underscores.
R_IDENTIFIER = r'(?:[a-zA-Z]|\.(?![0-9]))[a-zA-Z0-9._]*'

# In-process native thread pools to cap before R forks. R's
# `parallel::mcparallel` (used by the `mc.*` BART functions) forks, but GNU
# libgomp is not fork-safe: a forked child that enters an OpenMP parallel region
# hangs forever on a barrier because the worker threads do not survive the fork.
# The threaded OpenBLAS that R's LAPACK calls (e.g. `summary(lm(...))` for the
# `sigest` default) dispatches through libgomp, so a child deadlocks there.
# Running these pools single-threaded across the fork stops the thread team from
# being started at all, sidestepping the deadlock. Each entry is a (getter,
# setter) pair of C symbols; missing ones (e.g. a single-threaded reference BLAS)
# are skipped.
NATIVE_THREAD_POOLS = (
    ('omp_get_max_threads', 'omp_set_num_threads'),
    ('openblas_get_num_threads', 'openblas_set_num_threads'),
)


@contextmanager
def fork_safe_native_threads() -> Iterator[None]:
    """Cap OpenMP/OpenBLAS thread pools at one thread for the duration.

    Workaround for the deadlock that hangs the children forked by R's
    ``parallel::mcparallel`` when GNU libgomp has a live thread pool (see
    `NATIVE_THREAD_POOLS`). The previous thread counts are restored on exit,
    also when the body raises or when capping fails partway through.

    If the process-wide symbol namespace cannot be opened with
    ``ctypes.CDLL(None)`` (not available on every platform), nothing is
    capped and the body runs unchanged.
    """
    try:
        handle = ctypes.CDLL(None)
    except (OSError, TypeError):
        # no process-wide symbol namespace to search: nothing can be capped
        handle = None
    pools = NATIVE_THREAD_POOLS if handle is not None else ()

    saved: list[tuple[Any, int]] = []
    try:
        # capping happens inside the `try` so that pools capped before a
        # failure are still restored by the `finally`
        for getter_name, setter_name in pools:
            # ctypes raises AttributeError for missing symbols; skip those pools
            getter = getattr(handle, getter_name, None)
            setter = getattr(handle, setter_name, None)
            if getter is None or setter is None:
                continue
            getter.argtypes = ()
            getter.restype = ctypes.c_int
            setter.argtypes = (ctypes.c_int,)
            setter.restype = None  # the setters return void
            saved.append((setter, getter()))
            setter(1)
        yield
    finally:
        # Undo in reverse order: if one setter is implemented on top of another
        # (e.g. an OpenMP-built OpenBLAS), each pool still ends up with exactly
        # the count it had before.
        for setter, nthreads in reversed(saved):
            setter(nthreads)
146
+
147
+
148
class RObjectBase:
    """
    Base class for Python wrappers of R object creators.

    Subclasses should define the class attribute `_rfuncname`, and declare
    stub methods decorated with `rmethod`.

    _rfuncname : str
        An R function in the format ``'<package>::<function>'``. The function is
        called with the initialization arguments, converted to R objects, and is
        expected to return an R object. The attributes of the R object are
        converted to equivalent Python values and set as attributes of the
        Python object. The R object itself is assigned to the member `_robject`.
    """

    # converter chain used for every Python <-> R conversion done by wrappers
    _converter = (
        robjects.default_converter
        + PANDAS_CONVERTER
        + POLARS_CONVERTER
        + NUMPY_CONVERTER
        + BOOL_VECTOR_CONVERTER
        + JAX_CONVERTER
        + DICT_CONVERTER
    )
    _convctx = conversion.localconverter(_converter)

    def _py2r(self, x: object) -> object:
        """Convert `x` to R; wrapper instances are unwrapped to their R object."""
        if isinstance(x, __class__):
            return x._robject  # noqa: SLF001, same-class access
        with self._convctx:
            return self._converter.py2rpy(x)

    def _r2py(self, x: object) -> object:
        """Convert the R object `x` to Python with `_converter`."""
        with self._convctx:
            return self._converter.rpy2py(x)

    def _args2r(self, args: Iterable[Any]) -> tuple[Any, ...]:
        """Convert positional arguments to R."""
        return tuple(map(self._py2r, args))

    def _kw2r(self, kw: Mapping[str, Any]) -> dict[str, Any]:
        """Convert keyword argument values to R, keeping the keys."""
        return {key: self._py2r(value) for key, value in kw.items()}

    _rfuncname: str = NotImplemented

    @property
    def _library(self) -> str:
        """Parse `_rfuncname` to get the library. Also checks `_rfuncname` is valid."""
        # `\Z` instead of `$`: `$` would also accept a trailing newline, and the
        # library name is interpolated into R code
        pattern = rf'^({R_IDENTIFIER})::({R_IDENTIFIER})\Z'
        m = match(pattern, self._rfuncname)
        if m is None:
            msg = f'Invalid _rfuncname: {self._rfuncname}.'
            raise ValueError(msg)
        return m.group(1)

    @staticmethod
    def _has_named_components(obj: object) -> bool:
        """Whether `obj` exposes named components to set as attributes.

        Only an R named list qualifies. A bare matrix (as `bartModelMatrix`
        gives with ``numcut=0``) is excluded by the `ListVector` check: rpy2
        reports a matrix's dimnames as ``names``, so the names check alone
        would not cut it out.
        """
        names = getattr(obj, 'names', None)
        return (
            isinstance(obj, robjects.vectors.ListVector)
            and names is not None
            and names is not robjects.NULL
        )

    def _invoke_rfunc(self, args: Iterable[Any], kw: Mapping[str, Any]) -> object:
        """Load the namespace and call `_rfuncname` on the converted arguments."""
        robjects.r(f'loadNamespace("{self._library}")')
        func = robjects.r(self._rfuncname)
        return func(*self._args2r(args), **self._kw2r(kw))

    def _set_attrs_from_robject(self) -> None:
        """Set the named components of `self._robject` as Python attributes.

        Dots in R component names become underscores in the attribute names.
        """
        if self._has_named_components(self._robject):
            for s, v in self._robject.items():
                setattr(self, s.replace('.', '_'), self._r2py(v))

    def __init__(self, *args: Any, **kw: Any) -> None:
        self._robject = self._invoke_rfunc(args, kw)
        self._set_attrs_from_robject()

    def __init_subclass__(cls, **kw: Any) -> None:
        """Automatically add R documentation to subclasses."""
        # chain up so that other classes in the MRO get their hook called too
        super().__init_subclass__(**kw)
        parts = cls._rfuncname.split('::')
        if len(parts) != 2:  # noqa: PLR2004, '<package>::<function>'
            msg = f'Invalid _rfuncname: {cls._rfuncname}.'
            raise ValueError(msg)
        library, name = parts
        page = Package(library).fetch(name)
        if cls.__doc__ is None:
            cls.__doc__ = ''
        # the R help text is plain text, not valid RST: append it as a literal
        # block so docutils renders it verbatim instead of misparsing it
        cls.__doc__ += '\nR documentation::\n\n' + indent(page.to_docstring(), '    ')
243
+
244
+
245
+ def rmethod(meth: Callable, *, rname: str | None = None) -> Callable:
246
+ """Automatically implement a method using the correspoding R method.
247
+
248
+ Parameters
249
+ ----------
250
+ meth
251
+ A method in a subclass of `RObjectBase`.
252
+ rname
253
+ The name of the method in R. If not specified, use the name of `meth`.
254
+
255
+ Returns
256
+ -------
257
+ methimpl
258
+ An implementation of the method that calls the R method. The original
259
+ implementation of meth is completely discarded.
260
+
261
+ Examples
262
+ --------
263
+ >>> class MyRObject(RObjectBase):
264
+ ... _rfuncname = 'mypackage::myfunction'
265
+ ... @partial(rmethod, rname='my.method')
266
+ ... def my_method(self, arg1: int, arg2: str):
267
+ ... ...
268
+ """
269
+ if rname is None:
270
+ rname = meth.__name__
271
+
272
+ # I can't automatically add a docstring to the method because the R class
273
+ # can be determined at runtime
274
+
275
+ @wraps(meth)
276
+ def impl(self: RObjectBase, *args: Any, **kw: Any) -> object:
277
+ if isinstance(self._robject, RS4):
278
+ func = robjects.r['$'](self._robject, rname)
279
+ out = func(*self._args2r(args), **self._kw2r(kw))
280
+
281
+ else:
282
+ if not fullmatch(R_IDENTIFIER, rname):
283
+ msg = f'Invalid R method name: {rname}'
284
+ raise ValueError(msg)
285
+ rclass = self._robject.rclass[0]
286
+ func = robjects.r(
287
+ f'getS3method("{rname}", "{rclass}", envir = asNamespace("{self._library}"))'
288
+ )
289
+ out = func(self._robject, *self._args2r(args), **self._kw2r(kw))
290
+
291
+ return self._r2py(out)
292
+
293
+ return impl
@@ -0,0 +1,2 @@
1
# Release version, as a string and as the equivalent tuple of ints.
__version__ = '0.1.0'
__version_info__ = tuple(int(part) for part in __version__.split('.'))
@@ -0,0 +1,58 @@
1
+ # rbartpackages/src/rbartpackages/bartMachine.py
2
+ #
3
+ # Copyright (c) 2025-2026, The rbartpackages Contributors
4
+ #
5
+ # This file is part of rbartpackages.
6
+ #
7
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ # of this software and associated documentation files (the "Software"), to deal
9
+ # in the Software without restriction, including without limitation the rights
10
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ # copies of the Software, and to permit persons to whom the Software is
12
+ # furnished to do so, subject to the following conditions:
13
+ #
14
+ # The above copyright notice and this permission notice shall be included in all
15
+ # copies or substantial portions of the Software.
16
+ #
17
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ # SOFTWARE.
24
+
25
+ """Python wrapper of the R package bartMachine."""
26
+
27
+ # ruff: noqa: D102, ANN201, ANN002, ANN003
28
+
29
+ from rpy2 import robjects
30
+
31
+ from rbartpackages._base import RObjectBase, rmethod
32
+
33
+
34
class bartMachine(RObjectBase):  # noqa: D101, because the doc is pulled from R
    _rfuncname = 'bartMachine::bartMachine'

    def __init__(
        self, *args, num_cores: int | None = None, megabytes: int = 5000, **kw
    ) -> None:
        # JVM options: the heap size comes from `megabytes`, and the Vector API
        # module is needed because bartMachine uses the (incubating) Vector API
        # on JDK 16+; without it fitting raises NoClassDefFoundError.
        # NOTE(review): presumably these options only take effect if the JVM
        # has not been started yet in this R session; confirm with rJava docs.
        java_parameters = f'"-Xmx{megabytes:d}m", "--add-modules=jdk.incubator.vector"'
        robjects.r(f'options(java.parameters = c({java_parameters}))')
        robjects.r('loadNamespace("bartMachine")')
        if num_cores is not None:
            ncores = int(num_cores)
            robjects.r(f'bartMachine::set_bart_machine_num_cores({ncores})')
        super().__init__(*args, **kw)

    @rmethod
    def predict(self, *args, **kw):
        """Delegate to the R method ``predict`` via `rmethod`."""

    @rmethod
    def get_posterior(self, *args, **kw):
        """Delegate to the R method ``get_posterior`` via `rmethod`."""

    @rmethod
    def get_sigsqs(self, *args, **kw):
        """Delegate to the R method ``get_sigsqs`` via `rmethod`."""
@@ -0,0 +1,152 @@
1
+ # rbartpackages/src/rbartpackages/dbarts.py
2
+ #
3
+ # Copyright (c) 2025-2026, The rbartpackages Contributors
4
+ #
5
+ # This file is part of rbartpackages.
6
+ #
7
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ # of this software and associated documentation files (the "Software"), to deal
9
+ # in the Software without restriction, including without limitation the rights
10
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ # copies of the Software, and to permit persons to whom the Software is
12
+ # furnished to do so, subject to the following conditions:
13
+ #
14
+ # The above copyright notice and this permission notice shall be included in all
15
+ # copies or substantial portions of the Software.
16
+ #
17
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ # SOFTWARE.
24
+
25
+ """Python wrapper of the R package `dbarts`."""
26
+
27
+ # ruff: noqa: D101, D102, ANN201, ANN002, ANN003
28
+
29
+ from rpy2 import robjects
30
+
31
+ from rbartpackages._base import RObjectBase, rmethod
32
+
33
+
34
class bart(RObjectBase):
    """
    Python interface to dbarts::bart.

    The `splitprobs` parameter, in its named-numeric-vector form, has to be
    given as a Python dictionary mapping names to probabilities.
    """

    _rfuncname = 'dbarts::bart'
    # keyword under which the split probabilities are passed; subclasses whose
    # R function spells it differently override this
    _split_probs = 'splitprobs'

    def __init__(self, *args, **kw) -> None:
        probs = kw.get(self._split_probs)
        if isinstance(probs, dict):
            # build a named R numeric vector rather than letting the generic
            # dict conversion turn the dictionary into an R list
            numeric = robjects.FloatVector(list(probs.values()))
            kw[self._split_probs] = robjects.r('setNames')(numeric, list(probs.keys()))
        super().__init__(*args, **kw)

    @rmethod
    def predict(self, *args, **kw):
        """Delegate to the R method ``predict`` via `rmethod`."""

    @rmethod
    def extract(self, *args, **kw):
        """Delegate to the R method ``extract`` via `rmethod`."""

    @rmethod
    def fitted(self, *args, **kw):
        """Delegate to the R method ``fitted`` via `rmethod`."""
64
+
65
+
66
class bart2(bart):
    """
    Python interface to dbarts::bart2.

    The `split_probs` parameter, in its named-numeric-vector form, has to be
    given as a Python dictionary mapping names to probabilities.
    """

    _rfuncname = 'dbarts::bart2'
    _split_probs = 'split_probs'

    def __init__(self, formula: str, *args, **kw) -> None:
        # the first argument is a model formula given as a string: wrap it in
        # an R formula object, then let `bart.__init__` handle `split_probs`
        super().__init__(robjects.Formula(formula), *args, **kw)


class rbart_vi(bart2):
    """
    Python interface to dbarts::rbart_vi.

    The `split_probs` parameter, in its named-numeric-vector form, has to be
    given as a Python dictionary mapping names to probabilities.
    """

    # formula handling and `_split_probs` are inherited from `bart2`
    _rfuncname = 'dbarts::rbart_vi'
91
+
92
+
93
class dbarts(RObjectBase):
    # The class docstring is pulled from the R documentation at class creation.
    # Every method below is a stub: `rmethod` replaces it with a call to the R
    # method of the same name on the wrapped object.
    _rfuncname = 'dbarts::dbarts'

    @rmethod
    def run(self, *args, **kw):
        """Delegate to the R method ``run`` via `rmethod`."""

    @rmethod
    def sampleTreesFromPrior(self, *args, **kw):
        """Delegate to the R method ``sampleTreesFromPrior`` via `rmethod`."""

    @rmethod
    def sampleNodeParametersFromPrior(self, *args, **kw):
        """Delegate to the R method ``sampleNodeParametersFromPrior`` via `rmethod`."""

    @rmethod
    def copy(self, *args, **kw):
        """Delegate to the R method ``copy`` via `rmethod`."""

    @rmethod
    def show(self, *args, **kw):
        """Delegate to the R method ``show`` via `rmethod`."""

    @rmethod
    def predict(self, *args, **kw):
        """Delegate to the R method ``predict`` via `rmethod`."""

    @rmethod
    def setControl(self, *args, **kw):
        """Delegate to the R method ``setControl`` via `rmethod`."""

    @rmethod
    def setModel(self, *args, **kw):
        """Delegate to the R method ``setModel`` via `rmethod`."""

    @rmethod
    def setData(self, *args, **kw):
        """Delegate to the R method ``setData`` via `rmethod`."""

    @rmethod
    def setResponse(self, *args, **kw):
        """Delegate to the R method ``setResponse`` via `rmethod`."""

    @rmethod
    def setOffset(self, *args, **kw):
        """Delegate to the R method ``setOffset`` via `rmethod`."""

    @rmethod
    def setSigma(self, *args, **kw):
        """Delegate to the R method ``setSigma`` via `rmethod`."""

    @rmethod
    def setPredictor(self, *args, **kw):
        """Delegate to the R method ``setPredictor`` via `rmethod`."""

    @rmethod
    def setTestPredictor(self, *args, **kw):
        """Delegate to the R method ``setTestPredictor`` via `rmethod`."""

    @rmethod
    def setTestPredictorAndOffset(self, *args, **kw):
        """Delegate to the R method ``setTestPredictorAndOffset`` via `rmethod`."""

    @rmethod
    def setTestOffset(self, *args, **kw):
        """Delegate to the R method ``setTestOffset`` via `rmethod`."""

    @rmethod
    def printTrees(self, *args, **kw):
        """Delegate to the R method ``printTrees`` via `rmethod`."""

    @rmethod
    def plotTree(self, *args, **kw):
        """Delegate to the R method ``plotTree`` via `rmethod`."""


class dbartsControl(RObjectBase):
    # Thin wrapper: the R object returned by dbarts::dbartsControl is kept in
    # `_robject`; the class docstring is pulled from the R documentation.
    _rfuncname = 'dbarts::dbartsControl'
@@ -0,0 +1,65 @@
1
+ Metadata-Version: 2.4
2
+ Name: rbartpackages
3
+ Version: 0.1.0
4
+ Summary: Python wrappers of R BART packages via rpy2
5
+ Author: Giacomo Petrillo
6
+ Author-email: Giacomo Petrillo <info@giacomopetrillo.com>
7
+ License-Expression: MIT
8
+ License-File: LICENSE
9
+ Requires-Dist: rpy2>=3.6.0
10
+ Requires-Dist: numpy>=2.2.6
11
+ Requires-Dist: jaxtyping>=0.3.2
12
+ Requires-Dist: jax>=0.6.1 ; extra == 'jax'
13
+ Requires-Dist: pandas>=2.2.3 ; extra == 'pandas'
14
+ Requires-Dist: polars>=1.30.0 ; extra == 'polars'
15
+ Requires-Dist: pandas>=2.2.3 ; extra == 'polars'
16
+ Requires-Python: >=3.10
17
+ Project-URL: Homepage, https://github.com/bartz-org/rbartpackages
18
+ Project-URL: Documentation, https://bartz-org.github.io/rbartpackages/docs-dev
19
+ Project-URL: Issues, https://github.com/bartz-org/rbartpackages/issues
20
+ Provides-Extra: jax
21
+ Provides-Extra: pandas
22
+ Provides-Extra: polars
23
+ Description-Content-Type: text/markdown
24
+
25
+ # rbartpackages
26
+
27
+ Python wrappers of R BART (Bayesian Additive Regression Trees) packages, built on [rpy2](https://rpy2.github.io).
28
+
29
+ `rbartpackages` lets you call several R BART implementations from Python with a uniform, lightly-typed interface: arguments are converted to R, the fitted R object's components become Python attributes, and the original R documentation is attached to each wrapper class. It currently wraps:
30
+
31
+ - [`BART`](https://cran.r-project.org/package=BART)
32
+ - [`BART3`](https://github.com/rsparapa/bnptools) (the development superset of `BART`)
33
+ - [`bartMachine`](https://cran.r-project.org/package=bartMachine)
34
+ - [`dbarts`](https://cran.r-project.org/package=dbarts)
35
+
36
+ ## Installation
37
+
38
+ ```sh
39
+ pip install rbartpackages
40
+ ```
41
+
42
+ You also need R with the package(s) you want to use installed (`BART`, `dbarts`, `bartMachine` from CRAN; `BART3` from `rsparapa/bnptools` on GitHub). `bartMachine` additionally requires Java. Optional extras `pandas`, `polars`, and `jax` enable passing those array/frame types directly. See the documentation for details.
43
+
44
+ ## Usage
45
+
46
+ ```python
47
+ import numpy as np
48
+ from rbartpackages import BART3
49
+
50
+ x_train = np.random.randn(100, 5)
51
+ y_train = x_train[:, 0] + 0.1 * np.random.randn(100)
52
+
53
+ bart = BART3.gbart(x_train=x_train, y_train=y_train, ndpost=200)
54
+ y_pred = bart.predict(x_train) # shape (ndpost, n)
55
+ ```
56
+
57
+ R argument names with dots are passed with underscores (`x.train` → `x_train`).
58
+
59
+ ## Links
60
+
61
+ - [Documentation](https://bartz-org.github.io/rbartpackages/docs-dev)
62
+ - [Repository](https://github.com/bartz-org/rbartpackages)
63
+ - [List of BART packages](https://bartz-org.github.io/bartz/docs-dev/pkglist.html) (maintained in the bartz docs)
64
+
65
+ These wrappers originated in the [bartz](https://github.com/bartz-org/bartz) project, where they are used to validate against reference R implementations.
@@ -0,0 +1,11 @@
1
+ rbartpackages/BART.py,sha256=Hs227Xk4NomTkUj0zxtrUlky1TQdtVcrfD5maZFdD9A,13154
2
+ rbartpackages/BART3.py,sha256=_OlOYiXXWQA0HGk6MqDmv-epa81bmP4LC5G2TKimSyc,14499
3
+ rbartpackages/__init__.py,sha256=gRLv9gu9gJ_ry7dts9TKfJh8gH77mlqQYl2JTE1_q1s,1648
4
+ rbartpackages/_base.py,sha256=23KeW39IoFxkZ-t1U9rDwoqJCdgOdDM4gf5L79OLw7o,10521
5
+ rbartpackages/_version.py,sha256=Y4LHeOXBqKMcOffQfycAar5j-JEc-ObkXZE04PES0Tw,51
6
+ rbartpackages/bartMachine.py,sha256=bZPHksSM6w7ebNrfT2XovOLjNmzliuf64OcsBAwZ7Fg,2311
7
+ rbartpackages/dbarts.py,sha256=WuhrFmqfzZCAeo6G6cp5BdYGX9MyIP2giTYgiUJK11c,4078
8
+ rbartpackages-0.1.0.dist-info/licenses/LICENSE,sha256=pouECvzJhnd4fBwt1rCQYLgRpjnCdMAEpqOE8Pwn-sI,1092
9
+ rbartpackages-0.1.0.dist-info/WHEEL,sha256=f5fWSvWsg5Knq5GWa6t1nJIug0Tqo69GqAWD_9LbBKw,81
10
+ rbartpackages-0.1.0.dist-info/METADATA,sha256=zIkothg8jK6V4dyC3zEqgfg-pH9hweqZuq3-g245CbA,2702
11
+ rbartpackages-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: uv 0.11.16
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024-2026 The rbartpackages Contributors
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.