bartz 0.2.1__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/BART.py +43 -18
- bartz/_version.py +1 -1
- bartz/grove.py +19 -14
- bartz/jaxext.py +48 -21
- bartz/mcmcloop.py +13 -15
- bartz/mcmcstep.py +681 -299
- bartz/prepcovars.py +43 -13
- bartz-0.3.0.dist-info/METADATA +77 -0
- bartz-0.3.0.dist-info/RECORD +13 -0
- bartz-0.2.1.dist-info/METADATA +0 -32
- bartz-0.2.1.dist-info/RECORD +0 -13
- {bartz-0.2.1.dist-info → bartz-0.3.0.dist-info}/LICENSE +0 -0
- {bartz-0.2.1.dist-info → bartz-0.3.0.dist-info}/WHEEL +0 -0
bartz/BART.py
CHANGED
@@ -10,10 +10,10 @@
 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 # copies of the Software, and to permit persons to whom the Software is
 # furnished to do so, subject to the following conditions:
-#
+#
 # The above copyright notice and this permission notice shall be included in all
 # copies or substantial portions of the Software.
-#
+#
 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
@@ -49,6 +49,9 @@ class gbart:
         The training responses.
     x_test : array (p, m) or DataFrame, optional
         The test predictors.
+    usequants : bool, default False
+        Whether to use predictors quantiles instead of a uniform grid to bin
+        predictors.
     sigest : float, optional
         An estimate of the residual standard deviation on `y_train`, used to
         set `lamda`. If not specified, it is estimated by linear regression.
@@ -82,10 +85,16 @@ class gbart:
     ntree : int, default 200
         The number of trees used to represent the latent mean function.
     numcut : int, default 255
-
-
-
-
+        If `usequants` is `False`: the exact number of cutpoints used to bin the
+        predictors, ranging between the minimum and maximum observed values
+        (excluded).
+
+        If `usequants` is `True`: the maximum number of cutpoints to use for
+        binning the predictors. Each predictor is binned such that its
+        distribution in `x_train` is approximately uniform across bins. The
+        number of bins is at most the number of unique values appearing in
+        `x_train`, or ``numcut + 1``.
+
         Before running the algorithm, the predictors are compressed to the
         smallest integer type that fits the bin indices, so `numcut` is best set
         to the maximum value of an unsigned integer type.
@@ -126,6 +135,8 @@ class gbart:
         The number of trees.
     maxdepth : int
         The maximum depth of the trees.
+    initkw : dict
+        Additional arguments passed to `mcmcstep.init`.
 
     Methods
     -------
@@ -133,21 +144,26 @@ class gbart:
 
     Notes
     -----
-    This interface imitates the function
+    This interface imitates the function ``gbart`` from the R package `BART
     <https://cran.r-project.org/package=BART>`_, but with these differences:
 
     - If `x_train` and `x_test` are matrices, they have one predictor per row
       instead of per column.
+    - If ``usequants=False``, R BART switches to quantiles anyway if there are
+      less predictor values than the required number of bins, while bartz
+      always follows the specification.
     - The error variance parameter is called `lamda` instead of `lambda`.
-    - `usequants` is always `True`.
     - `rm_const` is always `False`.
     - The default `numcut` is 255 instead of 100.
     - A lot of functionality is missing (variable selection, discrete response).
     - There are some additional attributes, and some missing.
+
+    The linear regression used to set `sigest` adds an intercept.
     """
 
     def __init__(self, x_train, y_train, *,
         x_test=None,
+        usequants=False,
         sigest=None,
         sigdf=3,
         sigquant=0.9,
@@ -164,24 +180,25 @@ class gbart:
         keepevery=1,
         printevery=100,
         seed=0,
+        initkw={},
     ):
 
         x_train, x_train_fmt = self._process_predictor_input(x_train)
-
+
         y_train, y_train_fmt = self._process_response_input(y_train)
         self._check_same_length(x_train, y_train)
-
+
         offset = self._process_offset_settings(y_train, offset)
         scale = self._process_scale_settings(y_train, k)
         lamda, sigest = self._process_noise_variance_settings(x_train, y_train, sigest, sigdf, sigquant, lamda, offset)
 
-        splits, max_split = self._determine_splits(x_train, numcut)
+        splits, max_split = self._determine_splits(x_train, usequants, numcut)
         x_train = self._bin_predictors(x_train, splits)
 
         y_train = self._transform_input(y_train, offset, scale)
         lamda_scaled = lamda / (scale * scale)
 
-        mcmc_state = self._setup_mcmc(x_train, y_train, max_split, lamda_scaled, sigdf, power, base, maxdepth, ntree)
+        mcmc_state = self._setup_mcmc(x_train, y_train, max_split, lamda_scaled, sigdf, power, base, maxdepth, ntree, initkw)
         final_state, burnin_trace, main_trace = self._run_mcmc(mcmc_state, ndpost, nskip, keepevery, printevery, seed)
 
         sigma = self._extract_sigma(main_trace, scale)
@@ -279,7 +296,10 @@ class gbart:
         elif y_train.size <= x_train.shape[0]:
             sigest2 = jnp.var(y_train - offset)
         else:
-
+            x_centered = x_train.T - x_train.mean(axis=1)
+            y_centered = y_train - y_train.mean()
+            # centering is equivalent to adding an intercept column
+            _, chisq, rank, _ = jnp.linalg.lstsq(x_centered, y_centered)
             chisq = chisq.squeeze(0)
             dof = len(y_train) - rank
             sigest2 = chisq / dof
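The new `else` branch above fits an ordinary least-squares regression on centered data; as its comment says, centering both `X` and `y` is equivalent to adding an intercept column. A minimal NumPy check of that equivalence on toy data (names and data here are illustrative, not part of bartz):

    import numpy as np

    rng = np.random.default_rng(0)
    X = rng.normal(size=(50, 3))                   # toy (n, p) design matrix
    y = X @ np.array([1.0, -2.0, 0.5]) + rng.normal(size=50)

    # fit with an explicit intercept column
    X_icpt = np.column_stack([np.ones(50), X])
    _, rss_icpt, _, _ = np.linalg.lstsq(X_icpt, y, rcond=None)

    # fit with both sides centered instead
    Xc = X - X.mean(axis=0)
    yc = y - y.mean()
    _, rss_centered, _, _ = np.linalg.lstsq(Xc, yc, rcond=None)

    # the residual sum of squares is the same either way
    assert np.allclose(rss_icpt, rss_centered)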
@@ -305,8 +325,11 @@ class gbart:
         return (y_train.max() - y_train.min()) / (2 * k)
 
     @staticmethod
-    def _determine_splits(x_train, numcut):
-
+    def _determine_splits(x_train, usequants, numcut):
+        if usequants:
+            return prepcovars.quantilized_splits_from_matrix(x_train, numcut + 1)
+        else:
+            return prepcovars.uniform_splits_from_matrix(x_train, numcut + 1)
 
     @staticmethod
     def _bin_predictors(x, splits):
@@ -317,12 +340,12 @@
         return (y - offset) / scale
 
     @staticmethod
-    def _setup_mcmc(x_train, y_train, max_split, lamda, sigdf, power, base, maxdepth, ntree):
+    def _setup_mcmc(x_train, y_train, max_split, lamda, sigdf, power, base, maxdepth, ntree, initkw):
         depth = jnp.arange(maxdepth - 1)
         p_nonterminal = base / (1 + depth).astype(float) ** power
         sigma2_alpha = sigdf / 2
         sigma2_beta = lamda * sigma2_alpha
-        return mcmcstep.init(
+        kw = dict(
             X=x_train,
             y=y_train,
             max_split=max_split,
@@ -332,6 +355,8 @@ class gbart:
             sigma2_beta=sigma2_beta,
             min_points_per_leaf=5,
         )
+        kw.update(initkw)
+        return mcmcstep.init(**kw)
 
     @staticmethod
     def _run_mcmc(mcmc_state, ndpost, nskip, keepevery, printevery, seed):
@@ -354,7 +379,7 @@
     def _extract_sigma(trace, scale):
         return scale * jnp.sqrt(trace['sigma2'])
 
-
+
     def _show_tree(self, i_sample, i_tree, print_all=False):
         from . import debug
         trace = self._main_trace
bartz/_version.py
CHANGED
@@ -1 +1 @@
-__version__ = '0.2.1'
+__version__ = '0.3.0'
bartz/grove.py
CHANGED
@@ -10,10 +10,10 @@
 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 # copies of the Software, and to permit persons to whom the Software is
 # furnished to do so, subject to the following conditions:
-#
+#
 # The above copyright notice and this permission notice shall be included in all
 # copies or substantial portions of the Software.
-#
+#
 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
@@ -114,7 +114,7 @@ def traverse_tree(x, var_tree, split_tree):
 
         split = split_tree[index]
         var = var_tree[index]
-
+
         leaf_found |= split == 0
         child_index = (index << 1) + (x[var] >= split)
         index = jnp.where(leaf_found, index, child_index)
@@ -147,7 +147,7 @@ def traverse_forest(X, var_trees, split_trees):
     """
     return traverse_tree(X, var_trees, split_trees)
 
-def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype):
+def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype=None, sum_trees=True):
     """
     Evaluate a ensemble of trees at an array of points.
 
@@ -162,21 +162,26 @@ def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype):
         The decision axes of the trees.
     split_trees : array (m, 2 ** (d - 1))
         The decision boundaries of the trees.
-    dtype : dtype
-        The dtype of the output.
+    dtype : dtype, optional
+        The dtype of the output. Ignored if `sum_trees` is `False`.
+    sum_trees : bool, default True
+        Whether to sum the values across trees.
 
     Returns
     -------
-    out : array (n,)
-        The sum of the values of the trees at the points in `X`.
+    out : array (n,) or (m, n)
+        The (sum of) the values of the trees at the points in `X`.
     """
     indices = traverse_forest(X, var_trees, split_trees)
     ntree, _ = leaf_trees.shape
-    tree_index = jnp.arange(ntree, dtype=jaxext.minimal_unsigned_dtype(ntree - 1))
-    leaves = leaf_trees[tree_index, indices]
-
-
-
+    tree_index = jnp.arange(ntree, dtype=jaxext.minimal_unsigned_dtype(ntree - 1))
+    leaves = leaf_trees[tree_index[:, None], indices]
+    if sum_trees:
+        return jnp.sum(leaves, axis=0, dtype=dtype)
+        # this sum suggests to swap the vmaps, but I think it's better for X
+        # copying to keep it that way
+    else:
+        return leaves
 
 def is_actual_leaf(split_tree, *, add_bottom_level=False):
     """
@@ -238,7 +243,7 @@ def tree_depths(tree_length):
     tree_length : int
         The length of the tree array, i.e., 2 ** d.
 
-    Returns
+    Returns
     -------
     depth : array (tree_length,)
         The depth of each node. The root node (index 1) has depth 0. The depth
bartz/jaxext.py
CHANGED
@@ -196,13 +196,14 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
         A jittable function with positional arguments only, with inputs and
         outputs pytrees of arrays.
     max_io_nbytes : int
-        The maximum number of input + output bytes in each batch
-
+        The maximum number of input + output bytes in each batch (excluding
+        unbatched arguments.)
+    in_axes : pytree of int or None, default 0
         A tree matching the structure of the function input, indicating along
         which axes each array should be batched. If a single integer, it is
-        used for all arrays.
+        used for all arrays. A `None` axis indicates to not batch an argument.
     out_axes : pytree of ints, default 0
-        The same for outputs.
+        The same for outputs (but non-batching is not allowed).
     return_nbatches : bool, default False
         If True, the number of batches is returned as a second output.
 
@@ -218,8 +219,18 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
             return tree_util.tree_map(lambda _: axes, tree)
         return tree_util.tree_map(lambda _, axis: axis, tree, axes)
 
+    def check_no_nones(axes, tree):
+        def check_not_none(_, axis):
+            assert axis is not None
+        tree_util.tree_map(check_not_none, tree, axes)
+
     def extract_size(axes, tree):
-
+        def get_size(x, axis):
+            if axis is None:
+                return None
+            else:
+                return x.shape[axis]
+        sizes = tree_util.tree_map(get_size, tree, axes)
         sizes, _ = tree_util.tree_flatten(sizes)
         assert all(s == sizes[0] for s in sizes)
         return sizes[0]
@@ -243,23 +254,37 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
         return dividend
 
    def next_divisor(dividend, min_divisor):
+        if dividend == 0:
+            return min_divisor
        if min_divisor * min_divisor <= dividend:
            return next_divisor_small(dividend, min_divisor)
        return next_divisor_large(dividend, min_divisor)
 
+    def pull_nonbatched(axes, tree):
+        def pull_nonbatched(x, axis):
+            if axis is None:
+                return None
+            else:
+                return x
+        return tree_util.tree_map(pull_nonbatched, tree, axes), tree
+
+    def push_nonbatched(axes, tree, original_tree):
+        def push_nonbatched(original_x, x, axis):
+            if axis is None:
+                return original_x
+            else:
+                return x
+        return tree_util.tree_map(push_nonbatched, original_tree, tree, axes)
+
    def move_axes_out(axes, tree):
-        def move_axis_out(
-
-
-            return x
-        return tree_util.tree_map(move_axis_out, axes, tree)
+        def move_axis_out(x, axis):
+            return jnp.moveaxis(x, axis, 0)
+        return tree_util.tree_map(move_axis_out, tree, axes)
 
    def move_axes_in(axes, tree):
-        def move_axis_in(
-
-
-            return x
-        return tree_util.tree_map(move_axis_in, axes, tree)
+        def move_axis_in(x, axis):
+            return jnp.moveaxis(x, 0, axis)
+        return tree_util.tree_map(move_axis_in, tree, axes)
 
    def batch(tree, nbatches):
        def batch(x):
@@ -287,16 +312,17 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
 
         in_axes = expand_axes(initial_in_axes, args)
         out_axes = expand_axes(initial_out_axes, example_result)
+        check_no_nones(out_axes, example_result)
+
+        size = extract_size((in_axes, out_axes), (args, example_result))
 
-        in_size = extract_size(in_axes, args)
-        out_size = extract_size(out_axes, example_result)
-        assert in_size == out_size
-        size = in_size
+        args, nonbatched_args = pull_nonbatched(in_axes, args)
 
-        total_nbytes = sum_nbytes(args
+        total_nbytes = sum_nbytes((args, example_result))
         min_nbatches = total_nbytes // max_io_nbytes + bool(total_nbytes % max_io_nbytes)
+        min_nbatches = max(1, min_nbatches)
         nbatches = next_divisor(size, min_nbatches)
-        assert 1 <= nbatches <= size
+        assert 1 <= nbatches <= max(1, size)
         assert size % nbatches == 0
         assert total_nbytes % nbatches == 0
 
@@ -307,6 +333,7 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
 
         def loop(_, args):
             args = move_axes_in(in_axes, args)
+            args = push_nonbatched(in_axes, args, nonbatched_args)
             result = func(*args)
             result = move_axes_out(out_axes, result)
             return None, result
bartz/mcmcloop.py
CHANGED
@@ -10,10 +10,10 @@
 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 # copies of the Software, and to permit persons to whom the Software is
 # furnished to do so, subject to the following conditions:
-#
+#
 # The above copyright notice and this permission notice shall be included in all
 # copies or substantial portions of the Software.
-#
+#
 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
@@ -34,8 +34,9 @@ from jax import debug
 from jax import numpy as jnp
 from jax import lax
 
-from . import
+from . import jaxext
 from . import grove
+from . import mcmcstep
 
 @functools.partial(jax.jit, static_argnums=(1, 2, 3, 4))
 def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
@@ -91,7 +92,7 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
         the fields in `burnin_trace`.
     """
 
-    tracelist_burnin = 'sigma2', 'grow_prop_count', 'grow_acc_count', 'prune_prop_count', 'prune_acc_count'
+    tracelist_burnin = 'sigma2', 'grow_prop_count', 'grow_acc_count', 'prune_prop_count', 'prune_acc_count', 'ratios'
 
     tracelist_main = tracelist_burnin + ('leaf_trees', 'var_trees', 'split_trees')
 
@@ -102,14 +103,11 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
         key, subkey = random.split(key)
         bart = mcmcstep.step(bart, subkey)
         callback(bart=bart, burnin=burnin, i_total=i_total, i_skip=i_skip, **callback_kw)
-        output = {key: bart[key] for key in tracelist}
+        output = {key: bart[key] for key in tracelist if key in bart}
         return (bart, i_total + 1, i_skip + 1, key), output
 
     def empty_trace(bart, tracelist):
-        return {
-            key: jnp.empty((0,) + bart[key].shape, bart[key].dtype)
-            for key in tracelist
-        }
+        return jax.vmap(lambda x: x, in_axes=None, out_axes=0, axis_size=0)(bart)
 
     if n_burn > 0:
         carry = bart, 0, 0, key
@@ -124,7 +122,7 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
         main_loop = functools.partial(inner_loop, tracelist=[], burnin=False)
         inner_carry = bart, i_total, 0, key
         (bart, i_total, _, key), _ = lax.scan(main_loop, inner_carry, None, n_skip)
-        output = {key: bart[key] for key in tracelist_main}
+        output = {key: bart[key] for key in tracelist_main if key in bart}
         return (bart, i_total, key), output
 
     if n_save > 0:
@@ -135,12 +133,9 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
 
     return bart, burnin_trace, main_trace
 
-# TODO I could add an argument callback_state to carry over state. This would allow e.g. accumulating counts. If I made the callback return the mcmc state, I could modify the mcmc from the callback.
-
 @functools.lru_cache
 # cache to make the callback function object unique, such that the jit
-# of run_mcmc recognizes it
-# printevery a runtime quantity
+# of run_mcmc recognizes it
 def make_simple_print_callback(printevery):
     """
     Create a logging callback function for MCMC iterations.
@@ -193,7 +188,10 @@ def evaluate_trace(trace, X):
     y : array (n_trace, n)
         The predictions for each iteration of the MCMC.
     """
+    evaluate_trees = functools.partial(grove.evaluate_forest, sum_trees=False)
+    evaluate_trees = jaxext.autobatch(evaluate_trees, 2 ** 29, (None, 0, 0, 0))
     def loop(_, state):
-
+        values = evaluate_trees(X, state['leaf_trees'], state['var_trees'], state['split_trees'])
+        return None, jnp.sum(values, axis=0, dtype=jnp.float32)
     _, y = lax.scan(loop, None, trace)
     return y