rbartpackages 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rbartpackages/BART.py +335 -0
- rbartpackages/BART3.py +368 -0
- rbartpackages/__init__.py +35 -0
- rbartpackages/_base.py +293 -0
- rbartpackages/_version.py +2 -0
- rbartpackages/bartMachine.py +58 -0
- rbartpackages/dbarts.py +152 -0
- rbartpackages-0.1.0.dist-info/METADATA +65 -0
- rbartpackages-0.1.0.dist-info/RECORD +11 -0
- rbartpackages-0.1.0.dist-info/WHEEL +4 -0
- rbartpackages-0.1.0.dist-info/licenses/LICENSE +21 -0
rbartpackages/BART.py
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
1
|
+
# rbartpackages/src/rbartpackages/BART.py
|
|
2
|
+
#
|
|
3
|
+
# Copyright (c) 2024-2026, The rbartpackages Contributors
|
|
4
|
+
#
|
|
5
|
+
# This file is part of rbartpackages.
|
|
6
|
+
#
|
|
7
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
8
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
9
|
+
# in the Software without restriction, including without limitation the rights
|
|
10
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
11
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
12
|
+
# furnished to do so, subject to the following conditions:
|
|
13
|
+
#
|
|
14
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
15
|
+
# copies or substantial portions of the Software.
|
|
16
|
+
#
|
|
17
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
18
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
19
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
20
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
21
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
22
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
23
|
+
# SOFTWARE.
|
|
24
|
+
|
|
25
|
+
"""Wrapper for the R package BART."""
|
|
26
|
+
|
|
27
|
+
# ruff: noqa: ANN002, ANN003
|
|
28
|
+
|
|
29
|
+
from functools import partial
|
|
30
|
+
from typing import NamedTuple, TypedDict, cast
|
|
31
|
+
|
|
32
|
+
import numpy as np
|
|
33
|
+
from jaxtyping import AbstractDtype, Bool, Float64, Int32
|
|
34
|
+
from numpy import ndarray
|
|
35
|
+
from rpy2 import robjects
|
|
36
|
+
from rpy2.rlike.container import NamedList
|
|
37
|
+
|
|
38
|
+
from rbartpackages._base import RObjectBase, fork_safe_native_threads, rmethod
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class TreeDraws(TypedDict):
    """Dictionary layout of the `treedraws` attribute of `mc_gbart`."""

    cutpoints: dict[int | str, Float64[ndarray, ' numcut[i]']]
    """Grid of candidate split points of each variable; the keys are column names or, for unnamed columns, column positions."""

    trees: str
    """The posterior tree ensemble as text in BART's serialization format (this is what `predict` reads)."""
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class PredictBinary(TypedDict):
    """Dictionary returned by `predict` when the fit is binary (`pbart`/`lbart`)."""

    yhat_test: Float64[ndarray, 'ndpost m']
    """Draws of the posterior latent function evaluated at the test points."""

    prob_test: Float64[ndarray, 'ndpost m']
    """Draws of the success probability, i.e. `yhat_test` mapped through the inverse probit/logit."""

    prob_test_mean: Float64[ndarray, ' m']
    """`prob_test` averaged over the posterior draws."""

    binaryOffset: float
    """Centering value of the data, expressed on the latent scale."""
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class String(AbstractDtype):
    """jaxtyping dtype standing for `numpy.str_` data."""

    # regular expression over dtype strings, e.g. '<U5'
    dtypes = r'<U\d+'
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class ProcTime(NamedTuple):
    """The five timings returned by R's `proc.time`, as a Python named tuple."""

    user_self: float
    """User-mode CPU time, in seconds, charged to the R process itself."""

    sys_self: float
    """System-mode (kernel) CPU time, in seconds, charged to the R process itself."""

    elapsed: float
    """Elapsed wall-clock time, in seconds."""

    user_child: float
    """User-mode CPU time, in seconds, of forked child processes (the `mc.gbart` workers)."""

    sys_child: float
    """System-mode CPU time, in seconds, of forked child processes."""
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class mc_gbart(RObjectBase):  # noqa: D101 because the R doc is added automatically
    _rfuncname = 'BART::mc.gbart'

    LPML: float
    """Log pseudo-marginal likelihood, an unstable quantity for BART.

    R computes it in every case, burn-in or not. For binary `mc.gbart` fits
    with ``mc_cores > 1`` R gets it wrong, because the probabilities of the
    chains are not combined before the computation.
    """

    hostname: Bool[ndarray, ' mc_cores'] | String[ndarray, ' mc_cores']
    """Hostname of each chain when fitted with ``hostname=True``, otherwise ``False`` for each chain."""

    ndpost: int
    """How many posterior draws are kept once burn-in and thinning are done."""

    offset: float
    """Centering value of the response data (on the link scale for binary)."""

    prob_test: None | Float64[ndarray, 'ndpost/mc_cores m'] = None
    """Draws of the success probability at the test points (binary outcomes only).

    With ``mc_cores > 1``, `mc.gbart` forgets to combine the chains, so only
    the draws of the first chain are left.
    """

    prob_test_mean: None | Float64[ndarray, ' m'] = None
    """`prob_test` averaged over the posterior draws."""

    prob_train: None | Float64[ndarray, 'ndpost/mc_cores n'] = None
    """Draws of the success probability at the training points (binary outcomes only).

    With ``mc_cores > 1``, `mc.gbart` forgets to combine the chains, so only
    the draws of the first chain are left.
    """

    prob_train_mean: None | Float64[ndarray, ' n'] = None
    """`prob_train` averaged over the posterior draws."""

    proc_time: ProcTime
    """How long the fit took, as measured by R's `proc.time`."""

    rm_const: Int32[ndarray, '<=p']
    """0-based indices of the columns of `x_train` that were kept (constant columns are dropped).

    With ``mc_cores=1``, `mc.gbart` renumbers the kept columns as
    ``0 .. kept-1``, so which original columns were dropped is lost.
    """

    sigma: (
        Float64[ndarray, ' nskip+ndpost*keepevery']
        | Float64[ndarray, 'nskip+ndpost*keepevery/mc_cores mc_cores']
        | None
    ) = None
    """Draws of the error SD, for continuous outcomes only (per chain for `mc.gbart`).

    There is one draw for each MCMC iteration, so the burn-in and the
    iterations discarded by thinning are part of it.
    """

    sigma_mean: float | None = None
    """Average of the first `ndpost` `sigma` draws following burn-in (continuous only)."""

    treedraws: TreeDraws
    """The sampled trees: the cutpoint grid of each variable plus the serialized ensemble."""

    varcount: Int32[ndarray, 'ndpost p']
    """For each draw, how many splits use each variable, summed over the trees."""

    varcount_mean: Float64[ndarray, ' p']
    """`varcount` averaged over the posterior draws, per variable."""

    varprob: Float64[ndarray, 'ndpost p']
    """For each draw, the probability of picking each variable for a split."""

    varprob_mean: Float64[ndarray, ' p']
    """`varprob` averaged over the posterior draws, per variable."""

    yhat_test: Float64[ndarray, 'ndpost m']
    """Draws of the posterior function at the test points (latent scale for binary).

    This attribute is never missing: R's `cgbart` allocates it in any case, so
    in absence of test data it is an empty array instead of ``None`` (for
    `mc.gbart` it then has the rows of the first chain only, as the chains
    are combined just when there is test data).
    """

    yhat_test_mean: Float64[ndarray, ' m'] | None = None
    """`yhat_test` averaged over the posterior draws (continuous with test data only)."""

    yhat_train: Float64[ndarray, 'ndpost n']
    """Draws of the posterior function at the training points (latent scale for binary)."""

    yhat_train_mean: Float64[ndarray, ' n'] | None = None
    """`yhat_train` averaged over the posterior draws (continuous only)."""

    def __init__(self, *args, **kw) -> None:
        # Building the object runs mc.gbart, which forks through
        # parallel::mcparallel. Native thread pools are held to a single
        # thread across the fork, to stay clear of a libgomp deadlock in the
        # children.
        with fork_safe_native_threads():
            super().__init__(*args, **kw)

        # turn the attributes into plain Python values
        self.LPML = self.LPML.item()
        self.ndpost = self.ndpost.astype(int).item()
        self.offset = self.offset.item()
        self.proc_time = ProcTime(*(float(t) for t in self.proc_time))

        dropped_listed = np.all(self.rm_const < 0)
        if not dropped_listed and not np.all(self.rm_const > 0):  # pragma: no cover
            # R gives all-positive or all-negative indices
            msg = 'failed to parse rm.const because indices change sign'
            raise ValueError(msg)

        if dropped_listed:
            # Negative values are R's way of listing the dropped constant
            # columns, as indices into the original design matrix; varcount
            # instead has one column per kept variable.
            _, n_kept = self.varcount.shape
            n_total = n_kept + self.rm_const.size
            is_kept = np.ones(n_total, bool)
            is_kept[-self.rm_const - 1] = False
            self.rm_const = np.arange(n_total, dtype=np.int32)[is_kept]
        else:
            # positive values are the kept columns, 1-based
            self.rm_const -= 1

        if self.sigma_mean is not None:
            self.sigma_mean = self.sigma_mean.item()

        raw_draws = cast(NamedList, self.treedraws)
        grids: NamedList = raw_draws.getbyname('cutpoints')
        by_variable = {}
        for position, entry in enumerate(grids.items()):
            # unnamed entries are keyed by their position
            key = position if entry.name is None else entry.name.item()
            by_variable[key] = entry.value
        self.treedraws = {
            'cutpoints': by_variable,
            'trees': raw_draws.getbyname('trees').item(),
        }

    @partial(rmethod, rname='predict')
    def _predict(self, newdata: Float64[ndarray, 'm p'], *args, **kwargs) -> object:
        """Invoke R's `predict`, which gives a matrix (continuous) or a list (binary)."""
        ...

    def predict(
        self, newdata: Float64[ndarray, 'm p'], *args, **kwargs
    ) -> Float64[ndarray, 'ndpost m'] | PredictBinary:
        """Compute predictions.

        A continuous (`wbart`) fit yields the matrix of draws of the posterior
        latent function. A binary (`pbart`/`lbart`) fit makes R return a list,
        which is handed back as a `PredictBinary` dict.

        If the fit comes from `mc.gbart` with ``mc_cores > 1`` and constant
        columns were dropped, R miscounts the kept columns and does not update
        the header of the serialized ensemble, so the result has the draws of
        the first chain only.
        """
        raw = self._predict(newdata, *args, **kwargs)
        if hasattr(raw, 'items'):
            # binary: R's list arrives as a NamedList, make it a dict of arrays
            components = cast(NamedList, raw)
            result = {
                str(entry.name).replace('.', '_'): entry.value
                for entry in components.items()
            }
            result['binaryOffset'] = result['binaryOffset'].item()
            return result

        # continuous: the output is a matrix already
        return raw
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
class bartModelMatrix(RObjectBase):  # noqa: D101 because the R doc is added automatically
    _rfuncname = 'BART::bartModelMatrix'

    X: Float64[ndarray, 'N p']
    """The design matrix; vectors and data frames are coerced to numeric, factors are expanded to indicators."""

    numcut: Int32[ndarray, ' p']
    """How many cutpoints were chosen for each column."""

    rm_const: Int32[ndarray, '<=p']
    """0-based indices of the columns of the expanded design that are not constant.

    They index the columns of `X` as they are before removal: with
    ``rm.const=True`` the constant columns are taken out of `X`, `numcut` and
    `xinfo`, whereas by default they are only detected.
    """

    xinfo: Float64[ndarray, 'p numcut']
    """Cutpoint grid of each column, padded with NaN up to the largest number of cuts."""

    grp: Int32[ndarray, ' p'] | Float64[ndarray, ' 1'] | None
    """For each output column, the 1-based index of the input column it comes
    from (a factor expands to one indicator column per level); ``None`` for
    matrix input."""

    def __new__(cls, *args, **kw) -> Float64[ndarray, 'N p'] | RObjectBase:
        """Mirror R's return value: the plain matrix for ``numcut=0``, otherwise a filled-in instance."""
        # The matrix-or-list decision has to be taken here, since __init__
        # has no say on what the constructor returns; handing back something
        # that is not an instance (the matrix) makes Python skip __init__.
        instance = super().__new__(cls)
        instance._robject = instance._invoke_rfunc(args, kw)
        if not instance._has_named_components(instance._robject):
            return instance._r2py(instance._robject)
        return instance

    def __init__(self, *args, **kw) -> None:
        # This runs only in the named-list case (numcut > 0). R was already
        # invoked by __new__, which saved the result in `_robject`; calling
        # super().__init__ would invoke R again, so the components are just
        # exposed as attributes.
        self._set_attrs_from_robject()

        # matrix input makes grp an R NULL, shown to the user as None
        if self.grp is robjects.NULL:
            self.grp = None

        constants_listed = np.all(self.rm_const < 0)
        if not constants_listed and not np.all(self.rm_const > 0):  # pragma: no cover
            # R gives all-positive or all-negative indices
            msg = 'failed to parse rm.const because indices change sign'
            raise ValueError(msg)

        if not constants_listed:
            # positive values are the non-constant columns, 1-based
            self.rm_const -= 1
            return

        # Negative values are R's way of flagging the columns detected as
        # constant, as indices into the design matrix before removal. They
        # are also taken out of X only if the rm.const argument asks for it,
        # so look the argument up in the call (it is the 5th parameter of R's
        # bartModelMatrix).
        if 'rm.const' in kw:
            removed = kw['rm.const']
        elif 'rm_const' in kw:
            removed = kw['rm_const']
        elif len(args) > 4:
            removed = args[4]
        else:
            removed = False

        _, n_cols = self.X.shape
        n_total = n_cols
        if removed:
            n_total += self.rm_const.size
        not_constant = np.ones(n_total, bool)
        not_constant[-self.rm_const - 1] = False
        self.rm_const = np.arange(n_total, dtype=np.int32)[not_constant]
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
class gbart(mc_gbart):  # noqa: D101 because the R doc is added automatically
    _rfuncname = 'BART::gbart'

    hostname: Bool[ndarray, ' 1'] | String[ndarray, ' 1']
    """Name of the host that ran the fit when fitted with ``hostname=True``, otherwise ``False``."""

    prob_test: None | Float64[ndarray, 'ndpost m'] = None
    """Draws of the success probability at the test points (binary outcomes only)."""

    prob_train: None | Float64[ndarray, 'ndpost n'] = None
    """Draws of the success probability at the training points (binary outcomes only)."""

    sigma: Float64[ndarray, ' nskip+ndpost*keepevery'] | None = None
    """Draws of the error SD at each MCMC iteration, the burn-in ones too (continuous only)."""
|
rbartpackages/BART3.py
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
1
|
+
# rbartpackages/src/rbartpackages/BART3.py
|
|
2
|
+
#
|
|
3
|
+
# Copyright (c) 2025-2026, The rbartpackages Contributors
|
|
4
|
+
#
|
|
5
|
+
# This file is part of rbartpackages.
|
|
6
|
+
#
|
|
7
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
8
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
9
|
+
# in the Software without restriction, including without limitation the rights
|
|
10
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
11
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
12
|
+
# furnished to do so, subject to the following conditions:
|
|
13
|
+
#
|
|
14
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
15
|
+
# copies or substantial portions of the Software.
|
|
16
|
+
#
|
|
17
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
18
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
19
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
20
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
21
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
22
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
23
|
+
# SOFTWARE.
|
|
24
|
+
|
|
25
|
+
"""Wrapper for the R package BART3."""
|
|
26
|
+
|
|
27
|
+
# ruff: noqa: ANN002, ANN003
|
|
28
|
+
|
|
29
|
+
from functools import partial
|
|
30
|
+
from typing import NamedTuple, TypedDict, cast
|
|
31
|
+
|
|
32
|
+
import numpy as np
|
|
33
|
+
from jaxtyping import AbstractDtype, Float64, Int32
|
|
34
|
+
from numpy import ndarray
|
|
35
|
+
from rpy2 import robjects
|
|
36
|
+
from rpy2.rlike.container import NamedList
|
|
37
|
+
|
|
38
|
+
from rbartpackages._base import RObjectBase, fork_safe_native_threads, rmethod
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class TreeDraws(TypedDict):
    """Dictionary layout of the `treedraws` attribute of `mc_gbart`."""

    cutpoints: dict[int | str, Float64[ndarray, ' numcut[i]']]
    """Grid of candidate split points of each variable; the keys are column names or, for unnamed columns, column positions."""

    trees: str
    """The posterior tree ensemble as text in BART's serialization format (this is what `predict` reads)."""
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class PredictBinary(TypedDict):
    """Dictionary returned by `predict` when the fit is binary (`pbart`/`lbart`)."""

    yhat_test: Float64[ndarray, 'ndpost m']
    """Draws of the posterior latent function evaluated at the test points."""

    prob_test: Float64[ndarray, 'ndpost m']
    """Draws of the success probability, i.e. `yhat_test` mapped through the inverse probit/logit."""

    prob_test_mean: Float64[ndarray, ' m']
    """`prob_test` averaged over the posterior draws."""

    prob_test_lower: Float64[ndarray, ' m']
    """The lower of the `probs` quantiles of `prob_test` (2.5% by default)."""

    prob_test_upper: Float64[ndarray, ' m']
    """The upper of the `probs` quantiles of `prob_test` (97.5% by default)."""

    binaryOffset: float
    """Centering value of the data, expressed on the latent scale."""
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class String(AbstractDtype):
    """jaxtyping dtype standing for `numpy.str_` data."""

    # regular expression over dtype strings, e.g. '<U5'
    dtypes = r'<U\d+'
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class ProcTime(NamedTuple):
    """The five timings returned by R's `proc.time`, as a Python named tuple."""

    user_self: float
    """User-mode CPU time, in seconds, charged to the R process itself."""

    sys_self: float
    """System-mode (kernel) CPU time, in seconds, charged to the R process itself."""

    elapsed: float
    """Elapsed wall-clock time, in seconds."""

    user_child: float
    """User-mode CPU time, in seconds, of forked child processes (the `mc.gbart` workers)."""

    sys_child: float
    """System-mode CPU time, in seconds, of forked child processes."""
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class mc_gbart(RObjectBase):  # noqa: D101 because the R doc is added automatically
    _rfuncname = 'BART3::mc.gbart'

    LPML: float | None = None
    """Log pseudo-marginal likelihood, ``None`` if there is no burn-in. An unstable quantity for BART."""

    accept: (
        Float64[ndarray, ' nskip+ndpost*keepevery']
        | Float64[ndarray, 'nskip+ndpost*keepevery/mc_cores mc_cores']
    )
    """Metropolis-Hastings acceptance rate at each iteration (per chain for `mc.gbart`).

    It covers all the MCMC iterations, the ones discarded by thinning too
    (`sigma` instead keeps only the burn-in and the retained draws).
    """

    chains: int
    """How many MCMC chains there are, i.e. the `mc_cores` actually used."""

    grp: Float64[ndarray, ' p']
    """For each column, its group index for the sparse (DART) variable-selection prior."""

    ndpost: int
    """How many posterior draws are kept once burn-in and thinning are done."""

    offset: float
    """Centering value of the response data (on the link scale for binary)."""

    prob_test: None | Float64[ndarray, 'ndpost m'] = None
    """Draws of the success probability at the test points (binary outcomes only)."""

    prob_test_lower: Float64[ndarray, ' m'] | None = None
    """The lower of the `probs` quantiles of `prob_test` (2.5% by default)."""

    prob_test_mean: None | Float64[ndarray, ' m'] = None
    """`prob_test` averaged over the posterior draws."""

    prob_test_upper: Float64[ndarray, ' m'] | None = None
    """The upper of the `probs` quantiles of `prob_test` (97.5% by default)."""

    prob_train: None | Float64[ndarray, 'ndpost n'] = None
    """Draws of the success probability at the training points (binary outcomes only)."""

    prob_train_mean: None | Float64[ndarray, ' n'] = None
    """`prob_train` averaged over the posterior draws."""

    proc_time: ProcTime
    """How long the fit took, as measured by R's `proc.time`."""

    rho: float
    """Concentration parameter of the sparse (DART) prior, ``sum(1/grp)`` by default."""

    rm_const: Int32[ndarray, '<=p']
    """0-based indices of the columns of `x_train` that were kept (constant columns are dropped)."""

    sigest: float | None = None
    """Rough residual SD on which the prior on sigma is set (continuous only).

    It is ``None`` for binary outcomes, and ``nan`` when the bug of `mc.gbart`
    with ``mc_cores > 1`` overwrites it with a logical missing value.
    """

    sigma: (
        Float64[ndarray, ' nskip+ndpost']
        | Float64[ndarray, 'nskip+ndpost/mc_cores mc_cores']
        | None
    ) = None
    """Draws of the error SD, burn-in included, for continuous outcomes only (per chain for `mc.gbart`)."""

    sigma_: Float64[ndarray, ' ndpost'] | None = None
    """The kept `sigma` draws, without the burn-in; ``None`` if there is no burn-in."""

    sigma_mean: float | None = None
    """Average of `sigma_`; `sigest` takes its place when no draws are kept."""

    treedraws: TreeDraws
    """The sampled trees: the cutpoint grid of each variable plus the serialized ensemble."""

    varcount: Int32[ndarray, 'ndpost p']
    """For each draw, how many splits use each variable, summed over the trees."""

    varcount_mean: Float64[ndarray, ' p']
    """`varcount` averaged over the posterior draws, per variable."""

    varprob: Float64[ndarray, 'ndpost p']
    """For each draw, the probability of picking each variable for a split."""

    varprob_mean: Float64[ndarray, ' p']
    """`varprob` averaged over the posterior draws, per variable."""

    x_test: Float64[ndarray, ' m <=p'] | None = None
    """The test design matrix in the form actually used (imputed, factors expanded, constant columns dropped)."""

    x_train: Float64[ndarray, ' n <=p']
    """The training design matrix in the form actually used (original scale, not binned; constant columns dropped)."""

    yhat_test: Float64[ndarray, 'ndpost m']
    """Draws of the posterior function at the test points (latent scale for binary).

    This attribute is never missing: R's `cgbart` allocates it in any case, so
    in absence of test data it is an empty ``(ndpost, 0)`` array instead of
    ``None`` (differently from the derived `yhat_test_mean`/`lower`/`upper`,
    filled by R only when test data is given).
    """

    yhat_test_lower: Float64[ndarray, ' m'] | None = None
    """The lower of the `probs` quantiles of `yhat_test` (2.5% by default, continuous only)."""

    yhat_test_mean: Float64[ndarray, ' m'] | None = None
    """`yhat_test` averaged over the posterior draws."""

    yhat_test_upper: Float64[ndarray, ' m'] | None = None
    """The upper of the `probs` quantiles of `yhat_test` (97.5% by default, continuous only)."""

    yhat_train: Float64[ndarray, 'ndpost n']
    """Draws of the posterior function at the training points (latent scale for binary)."""

    yhat_train_lower: Float64[ndarray, ' n'] | None = None
    """The lower of the `probs` quantiles of `yhat_train` (2.5% by default, continuous only)."""

    yhat_train_mean: Float64[ndarray, ' n'] | None = None
    """`yhat_train` averaged over the posterior draws."""

    yhat_train_upper: Float64[ndarray, ' n'] | None = None
    """The upper of the `probs` quantiles of `yhat_train` (97.5% by default, continuous only)."""

    def __init__(self, *args, **kw) -> None:
        # Building the object runs mc.gbart, which forks through
        # parallel::mcparallel. Native thread pools are held to a single
        # thread across the fork, to stay clear of a libgomp deadlock in the
        # children.
        with fork_safe_native_threads():
            super().__init__(*args, **kw)

        # turn the attributes into plain Python values
        self.chains = self.chains.item()
        self.ndpost = self.ndpost.astype(int).item()
        self.offset = self.offset.item()
        self.proc_time = ProcTime(*(float(t) for t in self.proc_time))
        self.rho = self.rho.item()

        dropped_listed = np.all(self.rm_const < 0)
        if not dropped_listed and not np.all(self.rm_const > 0):  # pragma: no cover
            # R gives all-positive or all-negative indices
            msg = 'failed to parse rm.const because indices change sign'
            raise ValueError(msg)

        if dropped_listed:
            # Negative values are R's way of listing the dropped constant
            # columns, as indices into the original design matrix; varcount
            # instead has one column per kept variable.
            _, n_kept = self.varcount.shape
            n_total = n_kept + self.rm_const.size
            is_kept = np.ones(n_total, bool)
            is_kept[-self.rm_const - 1] = False
            self.rm_const = np.arange(n_total, dtype=np.int32)[is_kept]
        else:
            # positive values are the kept columns, 1-based
            self.rm_const -= 1

        if self.LPML is not None:
            self.LPML = self.LPML.item()

        if self.sigest is None:
            pass
        elif self.sigest.dtype == bool:
            # BART3 bug: with mc_cores > 1, mc.gbart puts the logical-NA
            # default in sigest in place of the estimate.
            self.sigest = float('nan')
        else:
            self.sigest = self.sigest.item()

        if self.sigma_mean is not None:
            self.sigma_mean = self.sigma_mean.item()

        raw_draws = cast(NamedList, self.treedraws)
        grids: NamedList = raw_draws.getbyname('cutpoints')
        by_variable = {}
        for position, entry in enumerate(grids.items()):
            # unnamed entries are keyed by their position
            key = position if entry.name is None else entry.name.item()
            by_variable[key] = entry.value
        self.treedraws = {
            'cutpoints': by_variable,
            'trees': raw_draws.getbyname('trees').item(),
        }

    @partial(rmethod, rname='predict')
    def _predict(self, newdata: Float64[ndarray, 'm p'], *args, **kwargs) -> object:
        """Invoke R's `predict`, which gives a matrix (continuous) or a list (binary)."""
        ...

    def predict(
        self, newdata: Float64[ndarray, 'm p'], *args, **kwargs
    ) -> Float64[ndarray, 'ndpost m'] | PredictBinary:
        """Compute predictions.

        A continuous (`wbart`) fit yields the matrix of draws of the posterior
        latent function. A binary (`pbart`/`lbart`) fit makes R return a list,
        which is handed back as a `PredictBinary` dict.
        """
        raw = self._predict(newdata, *args, **kwargs)
        if hasattr(raw, 'items'):
            # binary: R's list arrives as a NamedList, make it a dict of arrays
            components = cast(NamedList, raw)
            result = {
                str(entry.name).replace('.', '_'): entry.value
                for entry in components.items()
            }
            result['binaryOffset'] = result['binaryOffset'].item()
            return result

        # continuous: the output is a matrix already
        return raw
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
class bartModelMatrix(RObjectBase):  # noqa: D101 because the R doc is added automatically
    _rfuncname = 'BART3::bartModelMatrix'

    X: Float64[ndarray, 'N p']
    """The design matrix; vectors and data frames are coerced to numeric, factors are expanded to indicators."""

    numcut: Int32[ndarray, ' p']
    """How many cutpoints were chosen for each column."""

    rm_const: Int32[ndarray, '<=p']
    """0-based indices of the columns of the expanded design that are not constant.

    They index the columns of `X` as they are before removal: with
    ``rm.const=True`` the constant columns are taken out of `X`, `numcut` and
    `xinfo`, whereas by default they are only detected.
    """

    xinfo: Float64[ndarray, 'p numcut']
    """Cutpoint grid of each column, padded with NaN up to the largest number of cuts."""

    grp: Float64[ndarray, ' p'] | None
    """Size of the group of indicator columns of each expanded factor, or None when there are no factors."""

    def __new__(cls, *args, **kw) -> Float64[ndarray, 'N p'] | RObjectBase:
        """Mirror R's return value: the plain matrix for ``numcut=0``, otherwise a filled-in instance."""
        # The matrix-or-list decision has to be taken here, since __init__
        # has no say on what the constructor returns; handing back something
        # that is not an instance (the matrix) makes Python skip __init__.
        instance = super().__new__(cls)
        instance._robject = instance._invoke_rfunc(args, kw)
        if not instance._has_named_components(instance._robject):
            return instance._r2py(instance._robject)
        return instance

    def __init__(self, *args, **kw) -> None:
        # This runs only in the named-list case (numcut > 0). R was already
        # invoked by __new__, which saved the result in `_robject`; calling
        # super().__init__ would invoke R again, so the components are just
        # exposed as attributes.
        self._set_attrs_from_robject()

        # without factor columns in the input grp is an R NULL, shown to the
        # user as None
        if self.grp is robjects.NULL:
            self.grp = None

        constants_listed = np.all(self.rm_const < 0)
        if not constants_listed and not np.all(self.rm_const > 0):  # pragma: no cover
            # R gives all-positive or all-negative indices
            msg = 'failed to parse rm.const because indices change sign'
            raise ValueError(msg)

        if not constants_listed:
            # positive values are the non-constant columns, 1-based
            self.rm_const -= 1
            return

        # Negative values are R's way of flagging the columns detected as
        # constant, as indices into the design matrix before removal. They
        # are also taken out of X only if the rm.const argument asks for it,
        # so look the argument up in the call (it is the 5th parameter of R's
        # bartModelMatrix).
        if 'rm.const' in kw:
            removed = kw['rm.const']
        elif 'rm_const' in kw:
            removed = kw['rm_const']
        elif len(args) > 4:
            removed = args[4]
        else:
            removed = False

        _, n_cols = self.X.shape
        n_total = n_cols
        if removed:
            n_total += self.rm_const.size
        not_constant = np.ones(n_total, bool)
        not_constant[-self.rm_const - 1] = False
        self.rm_const = np.arange(n_total, dtype=np.int32)[not_constant]
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
class gbart(mc_gbart):  # noqa: D101 because the R doc is added automatically
    # Wrapper of the R function BART3::gbart. Only the R function name and the
    # two attribute declarations below are specific to this class; everything
    # else is inherited from `mc_gbart`.
    _rfuncname = 'BART3::gbart'

    accept: Float64[ndarray, ' nskip+ndpost*keepevery']
    """Per-iteration Metropolis-Hastings acceptance rate (every MCMC iteration)."""

    # Class-level default of None; presumably it stays None when the R result
    # has no `sigma` component (non-continuous outcome) - confirm against
    # `mc_gbart`.
    sigma: Float64[ndarray, ' nskip+ndpost'] | None = None
    """Error-SD draws including burn-in (continuous only)."""
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
# rbartpackages/src/rbartpackages/__init__.py
|
|
2
|
+
#
|
|
3
|
+
# Copyright (c) 2025-2026, The rbartpackages Contributors
|
|
4
|
+
#
|
|
5
|
+
# This file is part of rbartpackages.
|
|
6
|
+
#
|
|
7
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
8
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
9
|
+
# in the Software without restriction, including without limitation the rights
|
|
10
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
11
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
12
|
+
# furnished to do so, subject to the following conditions:
|
|
13
|
+
#
|
|
14
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
15
|
+
# copies or substantial portions of the Software.
|
|
16
|
+
#
|
|
17
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
18
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
19
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
20
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
21
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
22
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
23
|
+
# SOFTWARE.
|
|
24
|
+
|
|
25
|
+
"""
|
|
26
|
+
Python wrappers of R BART packages.
|
|
27
|
+
|
|
28
|
+
Each wrapped R package has its own submodule (``BART``, ``BART3``,
|
|
29
|
+
``bartMachine``, ``dbarts``); import the one you need, e.g.
|
|
30
|
+
``from rbartpackages import BART3``. Importing a wrapper requires the
|
|
31
|
+
corresponding R package to be installed, because the class docstrings are pulled
|
|
32
|
+
from the R documentation at import time.
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
from rbartpackages._version import __version__, __version_info__ # noqa: F401
|
rbartpackages/_base.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
1
|
+
# rbartpackages/src/rbartpackages/_base.py
|
|
2
|
+
#
|
|
3
|
+
# Copyright (c) 2024-2026, The rbartpackages Contributors
|
|
4
|
+
#
|
|
5
|
+
# This file is part of rbartpackages.
|
|
6
|
+
#
|
|
7
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
8
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
9
|
+
# in the Software without restriction, including without limitation the rights
|
|
10
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
11
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
12
|
+
# furnished to do so, subject to the following conditions:
|
|
13
|
+
#
|
|
14
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
15
|
+
# copies or substantial portions of the Software.
|
|
16
|
+
#
|
|
17
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
18
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
19
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
20
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
21
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
22
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
23
|
+
# SOFTWARE.
|
|
24
|
+
|
|
25
|
+
import ctypes
|
|
26
|
+
from collections.abc import Callable, Iterable, Iterator, Mapping
|
|
27
|
+
from contextlib import contextmanager
|
|
28
|
+
from functools import wraps
|
|
29
|
+
from re import fullmatch, match
|
|
30
|
+
from textwrap import indent
|
|
31
|
+
from typing import Any
|
|
32
|
+
|
|
33
|
+
import numpy as np
|
|
34
|
+
from rpy2 import robjects
|
|
35
|
+
from rpy2.robjects import BoolVector, conversion, numpy2ri
|
|
36
|
+
from rpy2.robjects.help import Package
|
|
37
|
+
from rpy2.robjects.methods import RS4
|
|
38
|
+
|
|
39
|
+
# converter for pandas
# The name is first bound to a fresh converter with no rules registered here,
# so it always exists; it is replaced by rpy2's pandas converter only when
# `pandas2ri` can be imported.
PANDAS_CONVERTER = conversion.Converter('pandas')
try:
    from rpy2.robjects import pandas2ri
except ImportError:  # pragma: no cover - optional dep always present in CI
    pass
else:
    PANDAS_CONVERTER = pandas2ri.converter
|
|
47
|
+
|
|
48
|
+
# converter for polars
# The converter is created unconditionally; its rules are registered only when
# both polars and rpy2's pandas support can be imported, because the conversion
# below goes through pandas.
POLARS_CONVERTER = conversion.Converter('polars')
try:
    import polars as pl
    from rpy2.robjects import pandas2ri
except ImportError:  # pragma: no cover - optional dep always present in CI
    pass
else:

    def polars_to_r(df: pl.DataFrame | pl.Series) -> object:
        """Convert a polars frame or series to R by way of pandas."""
        # Registered below for both DataFrame and Series; each has `to_pandas`.
        df = df.to_pandas()
        return pandas2ri.py2rpy(df)

    POLARS_CONVERTER.py2rpy.register(pl.DataFrame, polars_to_r)
    POLARS_CONVERTER.py2rpy.register(pl.Series, polars_to_r)
|
|
63
|
+
|
|
64
|
+
# converter for jax
# The converter always exists; the jax rule is added only if jax is importable.
JAX_CONVERTER = conversion.Converter('jax')
try:
    import jax
except ImportError:  # pragma: no cover - optional dep always present in CI
    pass
else:

    def jax_to_r(x: jax.Array) -> object:
        """Convert a jax array to R by way of numpy."""
        host = np.asarray(x)
        # A 0-d array is unwrapped with `[()]` so that the numpy converter
        # receives a numpy scalar instead of a 0-d array.
        payload = host[()] if host.ndim == 0 else host
        return numpy2ri.py2rpy(payload)

    JAX_CONVERTER.py2rpy.register(jax.Array, jax_to_r)
|
|
79
|
+
|
|
80
|
+
# converter for numpy
NUMPY_CONVERTER = numpy2ri.converter


# converter for BoolVector (why isn't it in the numpy converter?)
def bool_vector_to_python(x: BoolVector) -> np.ndarray[Any, np.dtype[np.bool_]]:
    """Turn an rpy2 `BoolVector` into a numpy array of booleans."""
    return np.array(x, dtype=bool)


# R -> Python rule: every BoolVector goes through the function above.
BOOL_VECTOR_CONVERTER = conversion.Converter('bool_vector')
BOOL_VECTOR_CONVERTER.rpy2py.register(BoolVector, bool_vector_to_python)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
# converter for python dictionaries
DICT_CONVERTER = conversion.Converter('dict')


def dict_to_r(x: dict[str, Any]) -> robjects.ListVector:
    """Wrap a Python dictionary in an rpy2 `ListVector` (an R list)."""
    return robjects.ListVector(x)


# Python -> R rule: every `dict` argument goes through the function above.
DICT_CONVERTER.py2rpy.register(dict, dict_to_r)
|
|
102
|
+
|
|
103
|
+
# Regex fragment for an ASCII R identifier: it starts with a letter, or with a
# dot that is not followed by a digit, and continues with any number of
# letters, digits, dots and underscores. It has no anchors of its own; callers
# anchor it (`^...$` or `fullmatch`) before pasting names into R code.
R_IDENTIFIER = r'(?:[a-zA-Z]|\.(?![0-9]))[a-zA-Z0-9._]*'
|
|
104
|
+
|
|
105
|
+
# In-process native thread pools to cap before R forks. R's
# `parallel::mcparallel` (used by the `mc.*` BART functions) forks, but GNU
# libgomp is not fork-safe: a forked child that enters an OpenMP parallel region
# hangs forever on a barrier because the worker threads do not survive the fork.
# The threaded OpenBLAS that R's LAPACK calls (e.g. `summary(lm(...))` for the
# `sigest` default) dispatches through libgomp, so a child deadlocks there.
# Running these pools single-threaded across the fork stops the thread team from
# being started at all, sidestepping the deadlock. Each entry is a (getter,
# setter) pair of C symbols; missing ones (e.g. a single-threaded reference BLAS)
# are skipped.
NATIVE_THREAD_POOLS = (
    ('omp_get_max_threads', 'omp_set_num_threads'),
    ('openblas_get_num_threads', 'openblas_set_num_threads'),
)


@contextmanager
def fork_safe_native_threads() -> Iterator[None]:
    """Cap OpenMP/OpenBLAS thread pools at one thread for the duration.

    Workaround for the deadlock that hangs the children forked by R's
    ``parallel::mcparallel`` when GNU libgomp has a live thread pool (see
    `NATIVE_THREAD_POOLS`). The previous thread counts are restored on exit,
    including when the body raises.

    The capping is best effort: a pool whose getter or setter symbol is not
    loaded in the process is skipped, and if the process symbols cannot be
    opened at all (``ctypes.CDLL(None)`` is not supported on Windows) the
    context manager does nothing.

    Yields
    ------
    None
    """
    try:
        # CDLL(None) gives access to the symbols already loaded in the process
        handle = ctypes.CDLL(None)
    except (OSError, TypeError):
        # nothing can be looked up, so there is nothing to cap or restore
        handle = None

    saved = []
    try:
        # The capping happens inside the `try`, so that pools capped before an
        # unexpected failure are still restored by the `finally` clause.
        for getter_name, setter_name in NATIVE_THREAD_POOLS if handle else ():
            try:
                getter = getattr(handle, getter_name)
                setter = getattr(handle, setter_name)
            except AttributeError:
                continue
            getter.argtypes = ()
            getter.restype = ctypes.c_int
            setter.argtypes = (ctypes.c_int,)
            # both setters are declared `void` in C; do not read a return value
            setter.restype = None
            previous = getter()
            setter(1)
            saved.append((setter, previous))
        yield
    finally:
        for setter, nthreads in saved:
            setter(nthreads)
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
class RObjectBase:
    """
    Base class for Python wrappers of R functions that create objects.

    Subclasses should define the class attribute `_rfuncname`, and may declare
    stub methods decorated with `rmethod`.

    _rfuncname : str
        An R function in the format ``'<package>::<function>'``. The function
        is called with the initialization arguments, converted to R objects.
        The R object it returns is assigned to the member `_robject`; when it
        is a named R list, its components are converted to equivalent Python
        values and set as attributes of the Python object, with dots in the
        names replaced by underscores.
    """

    _rfuncname: str = NotImplemented

    # Conversion rules: rpy2's defaults extended with the module's converters.
    _converter = (
        robjects.default_converter
        + PANDAS_CONVERTER
        + POLARS_CONVERTER
        + NUMPY_CONVERTER
        + BOOL_VECTOR_CONVERTER
        + JAX_CONVERTER
        + DICT_CONVERTER
    )
    _convctx = conversion.localconverter(_converter)

    def _py2r(self, x: object) -> object:
        """Convert `x` to R; a wrapper instance is replaced by its R object."""
        if not isinstance(x, __class__):
            with self._convctx:
                return self._converter.py2rpy(x)
        return x._robject  # noqa: SLF001, same-class access

    def _r2py(self, x: object) -> object:
        """Convert the R object `x` to Python with the class converter."""
        with self._convctx:
            converted = self._converter.rpy2py(x)
        return converted

    def _args2r(self, args: Iterable[Any]) -> tuple[Any, ...]:
        """Convert positional arguments to R."""
        return tuple(self._py2r(arg) for arg in args)

    def _kw2r(self, kw: Mapping[str, Any]) -> dict[str, Any]:
        """Convert the values of keyword arguments to R, keeping the keys."""
        converted = {}
        for key, value in kw.items():
            converted[key] = self._py2r(value)
        return converted

    @property
    def _library(self) -> str:
        """Parse `_rfuncname` to get the library. Also checks `_rfuncname` is valid."""
        found = match(rf'^({R_IDENTIFIER})::({R_IDENTIFIER})$', self._rfuncname)
        if found is not None:
            return found.group(1)
        msg = f'Invalid _rfuncname: {self._rfuncname}.'
        raise ValueError(msg)

    @staticmethod
    def _has_named_components(obj: object) -> bool:
        """Whether `obj` exposes named components to set as attributes.

        Only an R named list qualifies. Having names is not enough: rpy2
        reports the dimnames of a bare matrix (as `bartModelMatrix` gives with
        ``numcut=0``) as ``names``, so the `ListVector` check is what rules
        the matrix out.
        """
        names = getattr(obj, 'names', None)
        if not isinstance(obj, robjects.vectors.ListVector):
            return False
        if names is None:
            return False
        return names is not robjects.NULL

    def _invoke_rfunc(self, args: Iterable[Any], kw: Mapping[str, Any]) -> object:
        """Load the namespace and call `_rfuncname` on the converted arguments."""
        namespace = self._library
        robjects.r(f'loadNamespace("{namespace}")')
        rfunc = robjects.r(self._rfuncname)
        r_args = self._args2r(args)
        r_kw = self._kw2r(kw)
        return rfunc(*r_args, **r_kw)

    def _set_attrs_from_robject(self) -> None:
        """Set the named components of `self._robject` as Python attributes."""
        if not self._has_named_components(self._robject):
            return
        for name, component in self._robject.items():
            setattr(self, name.replace('.', '_'), self._r2py(component))

    def __init__(self, *args: Any, **kw: Any) -> None:
        self._robject = self._invoke_rfunc(args, kw)
        self._set_attrs_from_robject()

    def __init_subclass__(cls, **kw: Any) -> None:
        """Automatically add R documentation to subclasses."""
        library, name = cls._rfuncname.split('::')
        help_page = Package(library).fetch(name)
        own_doc = '' if cls.__doc__ is None else cls.__doc__
        # The R help is plain text rather than valid RST, so it goes into a
        # literal block, which docutils renders verbatim.
        cls.__doc__ = own_doc + '\nR documentation::\n\n' + indent(help_page.to_docstring(), ' ')
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
def rmethod(meth: Callable, *, rname: str | None = None) -> Callable:
    """Automatically implement a method using the corresponding R method.

    When the wrapped R object is one that rpy2 exposes as `RS4`, the R method
    is looked up on the object with R's ``$`` operator. Otherwise it is
    resolved with ``getS3method`` for the first class of the R object, in the
    namespace of the wrapper's R package, and is passed the R object as first
    argument. Arguments are converted to R and the result back to Python.

    Parameters
    ----------
    meth
        A method in a subclass of `RObjectBase`.
    rname
        The name of the method in R. If not specified, use the name of `meth`.

    Returns
    -------
    methimpl
        An implementation of the method that calls the R method. The original
        implementation of meth is completely discarded.

    Raises
    ------
    ValueError
        Raised by the returned method, on the ``getS3method`` path, if the R
        method name is not a valid R identifier.

    Examples
    --------
    >>> class MyRObject(RObjectBase):
    ...     _rfuncname = 'mypackage::myfunction'
    ...     @partial(rmethod, rname='my.method')
    ...     def my_method(self, arg1: int, arg2: str):
    ...         ...
    """
    r_name = meth.__name__ if rname is None else rname

    # No R documentation is attached to the method: the R class of the object,
    # and hence the R method that applies, is only known at call time.

    @wraps(meth)
    def impl(self: RObjectBase, *args: Any, **kw: Any) -> object:
        robject = self._robject

        if isinstance(robject, RS4):
            bound = robjects.r['$'](robject, r_name)
            result = bound(*self._args2r(args), **self._kw2r(kw))
            return self._r2py(result)

        # The name is pasted into R code below, so it is validated first.
        if not fullmatch(R_IDENTIFIER, r_name):
            msg = f'Invalid R method name: {r_name}'
            raise ValueError(msg)
        rclass = robject.rclass[0]
        s3method = robjects.r(
            f'getS3method("{r_name}", "{rclass}", envir = asNamespace("{self._library}"))'
        )
        result = s3method(robject, *self._args2r(args), **self._kw2r(kw))
        return self._r2py(result)

    return impl
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
# rbartpackages/src/rbartpackages/bartMachine.py
|
|
2
|
+
#
|
|
3
|
+
# Copyright (c) 2025-2026, The rbartpackages Contributors
|
|
4
|
+
#
|
|
5
|
+
# This file is part of rbartpackages.
|
|
6
|
+
#
|
|
7
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
8
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
9
|
+
# in the Software without restriction, including without limitation the rights
|
|
10
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
11
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
12
|
+
# furnished to do so, subject to the following conditions:
|
|
13
|
+
#
|
|
14
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
15
|
+
# copies or substantial portions of the Software.
|
|
16
|
+
#
|
|
17
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
18
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
19
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
20
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
21
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
22
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
23
|
+
# SOFTWARE.
|
|
24
|
+
|
|
25
|
+
"""Python wrapper of the R package bartMachine."""
|
|
26
|
+
|
|
27
|
+
# ruff: noqa: D102, ANN201, ANN002, ANN003
|
|
28
|
+
|
|
29
|
+
from rpy2 import robjects
|
|
30
|
+
|
|
31
|
+
from rbartpackages._base import RObjectBase, rmethod
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class bartMachine(RObjectBase):  # noqa: D101, because the doc is pulled from R
    _rfuncname = 'bartMachine::bartMachine'

    def __init__(
        self, *args, num_cores: int | None = None, megabytes: int = 5000, **kw
    ) -> None:
        # The JVM options are set before the namespace is loaded. Besides the
        # heap size, the JVM needs the (incubating) Vector API module, which
        # bartMachine uses on JDK 16+; without it fitting raises
        # NoClassDefFoundError.
        jvm_options = (
            f'options(java.parameters = c("-Xmx{megabytes:d}m", '
            '"--add-modules=jdk.incubator.vector"))'
        )
        robjects.r(jvm_options)
        robjects.r('loadNamespace("bartMachine")')

        # The core count is an R-side setting, applied only when requested.
        if num_cores is not None:
            cores_call = f'bartMachine::set_bart_machine_num_cores({int(num_cores)})'
            robjects.r(cores_call)

        super().__init__(*args, **kw)

    # Stubs: `rmethod` discards the bodies and forwards to the R methods.

    @rmethod
    def predict(self, *args, **kw): ...

    @rmethod
    def get_posterior(self, *args, **kw): ...

    @rmethod
    def get_sigsqs(self, *args, **kw): ...
|
rbartpackages/dbarts.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
# rbartpackages/src/rbartpackages/dbarts.py
|
|
2
|
+
#
|
|
3
|
+
# Copyright (c) 2025-2026, The rbartpackages Contributors
|
|
4
|
+
#
|
|
5
|
+
# This file is part of rbartpackages.
|
|
6
|
+
#
|
|
7
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
8
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
9
|
+
# in the Software without restriction, including without limitation the rights
|
|
10
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
11
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
12
|
+
# furnished to do so, subject to the following conditions:
|
|
13
|
+
#
|
|
14
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
15
|
+
# copies or substantial portions of the Software.
|
|
16
|
+
#
|
|
17
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
18
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
19
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
20
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
21
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
22
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
23
|
+
# SOFTWARE.
|
|
24
|
+
|
|
25
|
+
"""Python wrapper of the R package `dbarts`."""
|
|
26
|
+
|
|
27
|
+
# ruff: noqa: D101, D102, ANN201, ANN002, ANN003
|
|
28
|
+
|
|
29
|
+
from rpy2 import robjects
|
|
30
|
+
|
|
31
|
+
from rbartpackages._base import RObjectBase, rmethod
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class bart(RObjectBase):
    """
    Python interface to dbarts::bart.

    The named numeric vector form of the `splitprobs` parameter must be
    specified as a dictionary in Python.
    """

    _rfuncname = 'dbarts::bart'
    # Name of the keyword argument holding the split probabilities;
    # subclasses override it when the R function spells it differently.
    _split_probs = 'splitprobs'

    def __init__(self, *args, **kw) -> None:
        key = self._split_probs
        probs = kw.get(key)
        if isinstance(probs, dict):
            # A dict stands for an R named numeric vector: the values become
            # the vector and the keys are attached with R's `setNames`.
            unnamed = robjects.FloatVector(list(probs.values()))
            kw[key] = robjects.r('setNames')(unnamed, list(probs.keys()))

        super().__init__(*args, **kw)

    # Stubs: `rmethod` discards the bodies and forwards to the R methods.

    @rmethod
    def predict(self, *args, **kw): ...

    @rmethod
    def extract(self, *args, **kw): ...

    @rmethod
    def fitted(self, *args, **kw): ...
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class bart2(bart):
    """
    Python interface to dbarts::bart2.

    The named numeric vector form of the `split_probs` parameter must be
    specified as a dictionary in Python.
    """

    _rfuncname = 'dbarts::bart2'
    _split_probs = 'split_probs'

    def __init__(self, formula: str, *args, **kw) -> None:
        # The formula string is wrapped in an rpy2 `Formula` before the call.
        super().__init__(robjects.Formula(formula), *args, **kw)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class rbart_vi(bart2):
    """
    Python interface to dbarts::rbart_vi.

    The named numeric vector form of the `split_probs` parameter must be
    specified as a dictionary in Python.
    """

    # Only the R function differs from `bart2`; the constructor (formula
    # string as first argument) and the `split_probs` keyword are inherited.
    _rfuncname = 'dbarts::rbart_vi'
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class dbarts(RObjectBase):
    # Wrapper of the R function dbarts::dbarts. Every method below is a stub:
    # `rmethod` discards its body and replaces it with a call to the R method
    # of the same name on the wrapped R object.
    _rfuncname = 'dbarts::dbarts'

    @rmethod
    def run(self, *args, **kw): ...

    @rmethod
    def sampleTreesFromPrior(self, *args, **kw): ...

    @rmethod
    def sampleNodeParametersFromPrior(self, *args, **kw): ...

    @rmethod
    def copy(self, *args, **kw): ...

    @rmethod
    def show(self, *args, **kw): ...

    @rmethod
    def predict(self, *args, **kw): ...

    @rmethod
    def setControl(self, *args, **kw): ...

    @rmethod
    def setModel(self, *args, **kw): ...

    @rmethod
    def setData(self, *args, **kw): ...

    @rmethod
    def setResponse(self, *args, **kw): ...

    @rmethod
    def setOffset(self, *args, **kw): ...

    @rmethod
    def setSigma(self, *args, **kw): ...

    @rmethod
    def setPredictor(self, *args, **kw): ...

    @rmethod
    def setTestPredictor(self, *args, **kw): ...

    @rmethod
    def setTestPredictorAndOffset(self, *args, **kw): ...

    @rmethod
    def setTestOffset(self, *args, **kw): ...

    @rmethod
    def printTrees(self, *args, **kw): ...

    @rmethod
    def plotTree(self, *args, **kw): ...
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class dbartsControl(RObjectBase):
    # Wrapper of the R function dbarts::dbartsControl; construction and
    # attribute handling are entirely inherited from `RObjectBase`.
    _rfuncname = 'dbarts::dbartsControl'
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: rbartpackages
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Python wrappers of R BART packages via rpy2
|
|
5
|
+
Author: Giacomo Petrillo
|
|
6
|
+
Author-email: Giacomo Petrillo <info@giacomopetrillo.com>
|
|
7
|
+
License-Expression: MIT
|
|
8
|
+
License-File: LICENSE
|
|
9
|
+
Requires-Dist: rpy2>=3.6.0
|
|
10
|
+
Requires-Dist: numpy>=2.2.6
|
|
11
|
+
Requires-Dist: jaxtyping>=0.3.2
|
|
12
|
+
Requires-Dist: jax>=0.6.1 ; extra == 'jax'
|
|
13
|
+
Requires-Dist: pandas>=2.2.3 ; extra == 'pandas'
|
|
14
|
+
Requires-Dist: polars>=1.30.0 ; extra == 'polars'
|
|
15
|
+
Requires-Dist: pandas>=2.2.3 ; extra == 'polars'
|
|
16
|
+
Requires-Python: >=3.10
|
|
17
|
+
Project-URL: Homepage, https://github.com/bartz-org/rbartpackages
|
|
18
|
+
Project-URL: Documentation, https://bartz-org.github.io/rbartpackages/docs-dev
|
|
19
|
+
Project-URL: Issues, https://github.com/bartz-org/rbartpackages/issues
|
|
20
|
+
Provides-Extra: jax
|
|
21
|
+
Provides-Extra: pandas
|
|
22
|
+
Provides-Extra: polars
|
|
23
|
+
Description-Content-Type: text/markdown
|
|
24
|
+
|
|
25
|
+
# rbartpackages
|
|
26
|
+
|
|
27
|
+
Python wrappers of R BART (Bayesian Additive Regression Trees) packages, built on [rpy2](https://rpy2.github.io).
|
|
28
|
+
|
|
29
|
+
`rbartpackages` lets you call several R BART implementations from Python with a uniform, lightly-typed interface: arguments are converted to R, the fitted R object's components become Python attributes, and the original R documentation is attached to each wrapper class. It currently wraps:
|
|
30
|
+
|
|
31
|
+
- [`BART`](https://cran.r-project.org/package=BART)
|
|
32
|
+
- [`BART3`](https://github.com/rsparapa/bnptools) (the development superset of `BART`)
|
|
33
|
+
- [`bartMachine`](https://cran.r-project.org/package=bartMachine)
|
|
34
|
+
- [`dbarts`](https://cran.r-project.org/package=dbarts)
|
|
35
|
+
|
|
36
|
+
## Installation
|
|
37
|
+
|
|
38
|
+
```sh
|
|
39
|
+
pip install rbartpackages
|
|
40
|
+
```
|
|
41
|
+
|
|
42
|
+
You also need R with the package(s) you want to use installed (`BART`, `dbarts`, `bartMachine` from CRAN; `BART3` from `rsparapa/bnptools` on GitHub). `bartMachine` additionally requires Java. Optional extras `pandas`, `polars`, and `jax` enable passing those array/frame types directly. See the documentation for details.
|
|
43
|
+
|
|
44
|
+
## Usage
|
|
45
|
+
|
|
46
|
+
```python
|
|
47
|
+
import numpy as np
|
|
48
|
+
from rbartpackages import BART3
|
|
49
|
+
|
|
50
|
+
x_train = np.random.randn(100, 5)
|
|
51
|
+
y_train = x_train[:, 0] + 0.1 * np.random.randn(100)
|
|
52
|
+
|
|
53
|
+
bart = BART3.gbart(x_train=x_train, y_train=y_train, ndpost=200)
|
|
54
|
+
y_pred = bart.predict(x_train) # shape (ndpost, n)
|
|
55
|
+
```
|
|
56
|
+
|
|
57
|
+
R argument names with dots are passed with underscores (`x.train` → `x_train`).
|
|
58
|
+
|
|
59
|
+
## Links
|
|
60
|
+
|
|
61
|
+
- [Documentation](https://bartz-org.github.io/rbartpackages/docs-dev)
|
|
62
|
+
- [Repository](https://github.com/bartz-org/rbartpackages)
|
|
63
|
+
- [List of BART packages](https://bartz-org.github.io/bartz/docs-dev/pkglist.html) (maintained in the bartz docs)
|
|
64
|
+
|
|
65
|
+
These wrappers originated in the [bartz](https://github.com/bartz-org/bartz) project, where they are used to validate against reference R implementations.
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
rbartpackages/BART.py,sha256=Hs227Xk4NomTkUj0zxtrUlky1TQdtVcrfD5maZFdD9A,13154
|
|
2
|
+
rbartpackages/BART3.py,sha256=_OlOYiXXWQA0HGk6MqDmv-epa81bmP4LC5G2TKimSyc,14499
|
|
3
|
+
rbartpackages/__init__.py,sha256=gRLv9gu9gJ_ry7dts9TKfJh8gH77mlqQYl2JTE1_q1s,1648
|
|
4
|
+
rbartpackages/_base.py,sha256=23KeW39IoFxkZ-t1U9rDwoqJCdgOdDM4gf5L79OLw7o,10521
|
|
5
|
+
rbartpackages/_version.py,sha256=Y4LHeOXBqKMcOffQfycAar5j-JEc-ObkXZE04PES0Tw,51
|
|
6
|
+
rbartpackages/bartMachine.py,sha256=bZPHksSM6w7ebNrfT2XovOLjNmzliuf64OcsBAwZ7Fg,2311
|
|
7
|
+
rbartpackages/dbarts.py,sha256=WuhrFmqfzZCAeo6G6cp5BdYGX9MyIP2giTYgiUJK11c,4078
|
|
8
|
+
rbartpackages-0.1.0.dist-info/licenses/LICENSE,sha256=pouECvzJhnd4fBwt1rCQYLgRpjnCdMAEpqOE8Pwn-sI,1092
|
|
9
|
+
rbartpackages-0.1.0.dist-info/WHEEL,sha256=f5fWSvWsg5Knq5GWa6t1nJIug0Tqo69GqAWD_9LbBKw,81
|
|
10
|
+
rbartpackages-0.1.0.dist-info/METADATA,sha256=zIkothg8jK6V4dyC3zEqgfg-pH9hweqZuq3-g245CbA,2702
|
|
11
|
+
rbartpackages-0.1.0.dist-info/RECORD,,
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2024-2026 The rbartpackages Contributors
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|