bartz 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/.DS_Store +0 -0
- bartz/BART.py +99 -39
- bartz/__init__.py +6 -14
- bartz/_version.py +1 -1
- bartz/debug.py +42 -16
- bartz/grove.py +20 -11
- bartz/jaxext.py +44 -38
- bartz/mcmcloop.py +119 -58
- bartz/mcmcstep.py +426 -173
- bartz/prepcovars.py +22 -9
- bartz-0.5.0.dist-info/METADATA +48 -0
- bartz-0.5.0.dist-info/RECORD +13 -0
- bartz-0.5.0.dist-info/WHEEL +4 -0
- bartz-0.4.0.dist-info/LICENSE +0 -21
- bartz-0.4.0.dist-info/METADATA +0 -77
- bartz-0.4.0.dist-info/RECORD +0 -13
- bartz-0.4.0.dist-info/WHEEL +0 -4
bartz/jaxext.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
# bartz/src/bartz/jaxext.py
|
|
2
2
|
#
|
|
3
|
-
# Copyright (c) 2024, Giacomo Petrillo
|
|
3
|
+
# Copyright (c) 2024-2025, Giacomo Petrillo
|
|
4
4
|
#
|
|
5
5
|
# This file is part of bartz.
|
|
6
6
|
#
|
|
@@ -26,11 +26,11 @@ import functools
|
|
|
26
26
|
import math
|
|
27
27
|
import warnings
|
|
28
28
|
|
|
29
|
-
from scipy import special
|
|
30
29
|
import jax
|
|
30
|
+
from jax import lax, tree_util
|
|
31
31
|
from jax import numpy as jnp
|
|
32
|
-
from
|
|
33
|
-
|
|
32
|
+
from scipy import special
|
|
33
|
+
|
|
34
34
|
|
|
35
35
|
def float_type(*args):
|
|
36
36
|
"""
|
|
@@ -39,62 +39,46 @@ def float_type(*args):
|
|
|
39
39
|
t = jnp.result_type(*args)
|
|
40
40
|
return jnp.sin(jnp.empty(0, t)).dtype
|
|
41
41
|
|
|
42
|
+
|
|
42
43
|
def castto(func, type):
|
|
43
44
|
@functools.wraps(func)
|
|
44
45
|
def newfunc(*args, **kw):
|
|
45
46
|
return func(*args, **kw).astype(type)
|
|
47
|
+
|
|
46
48
|
return newfunc
|
|
47
49
|
|
|
48
|
-
def pure_callback_ufunc(callback, dtype, *args, excluded=None, **kwargs):
|
|
49
|
-
""" version of `jax.pure_callback` that deals correctly with ufuncs,
|
|
50
|
-
see `<https://github.com/google/jax/issues/17187>`_ """
|
|
51
|
-
if excluded is None:
|
|
52
|
-
excluded = ()
|
|
53
|
-
shape = jnp.broadcast_shapes(*(
|
|
54
|
-
a.shape
|
|
55
|
-
for i, a in enumerate(args)
|
|
56
|
-
if i not in excluded
|
|
57
|
-
))
|
|
58
|
-
ndim = len(shape)
|
|
59
|
-
padded_args = [
|
|
60
|
-
a if i in excluded
|
|
61
|
-
else jnp.expand_dims(a, tuple(range(ndim - a.ndim)))
|
|
62
|
-
for i, a in enumerate(args)
|
|
63
|
-
]
|
|
64
|
-
result = jax.ShapeDtypeStruct(shape, dtype)
|
|
65
|
-
return jax.pure_callback(callback, result, *padded_args, vectorized=True, **kwargs)
|
|
66
|
-
|
|
67
|
-
# TODO when jax solves this, check version and piggyback on original if new
|
|
68
50
|
|
|
69
51
|
class scipy:
|
|
70
|
-
|
|
71
52
|
class special:
|
|
72
|
-
|
|
73
53
|
@functools.wraps(special.gammainccinv)
|
|
74
54
|
def gammainccinv(a, y):
|
|
75
55
|
a = jnp.asarray(a)
|
|
76
56
|
y = jnp.asarray(y)
|
|
57
|
+
shape = jnp.broadcast_shapes(a.shape, y.shape)
|
|
77
58
|
dtype = float_type(a.dtype, y.dtype)
|
|
59
|
+
dummy = jax.ShapeDtypeStruct(shape, dtype)
|
|
78
60
|
ufunc = castto(special.gammainccinv, dtype)
|
|
79
|
-
return
|
|
61
|
+
return jax.pure_callback(ufunc, dummy, a, y, vmap_method='expand_dims')
|
|
80
62
|
|
|
81
63
|
class stats:
|
|
82
|
-
|
|
83
64
|
class invgamma:
|
|
84
|
-
|
|
85
65
|
def ppf(q, a):
|
|
86
66
|
return 1 / scipy.special.gammainccinv(a, q)
|
|
87
67
|
|
|
88
|
-
|
|
68
|
+
|
|
89
69
|
def vmap_nodoc(fun, *args, **kw):
|
|
90
70
|
"""
|
|
91
|
-
|
|
71
|
+
Wrapper of `jax.vmap` that preserves the docstring of the input function.
|
|
72
|
+
|
|
73
|
+
This is useful if the docstring already takes into account that the
|
|
74
|
+
arguments have additional axes due to vmap.
|
|
92
75
|
"""
|
|
93
76
|
doc = fun.__doc__
|
|
94
77
|
fun = jax.vmap(fun, *args, **kw)
|
|
95
78
|
fun.__doc__ = doc
|
|
96
79
|
return fun
|
|
97
80
|
|
|
81
|
+
|
|
98
82
|
def huge_value(x):
|
|
99
83
|
"""
|
|
100
84
|
Return the maximum value that can be stored in `x`.
|
|
@@ -114,19 +98,21 @@ def huge_value(x):
|
|
|
114
98
|
else:
|
|
115
99
|
return jnp.inf
|
|
116
100
|
|
|
101
|
+
|
|
117
102
|
def minimal_unsigned_dtype(max_value):
|
|
118
103
|
"""
|
|
119
104
|
Return the smallest unsigned integer dtype that can represent a given
|
|
120
105
|
maximum value (inclusive).
|
|
121
106
|
"""
|
|
122
|
-
if max_value < 2
|
|
107
|
+
if max_value < 2**8:
|
|
123
108
|
return jnp.uint8
|
|
124
|
-
if max_value < 2
|
|
109
|
+
if max_value < 2**16:
|
|
125
110
|
return jnp.uint16
|
|
126
|
-
if max_value < 2
|
|
111
|
+
if max_value < 2**32:
|
|
127
112
|
return jnp.uint32
|
|
128
113
|
return jnp.uint64
|
|
129
114
|
|
|
115
|
+
|
|
130
116
|
def signed_to_unsigned(int_dtype):
|
|
131
117
|
"""
|
|
132
118
|
Map a signed integer type to its unsigned counterpart. Unsigned types are
|
|
@@ -144,12 +130,14 @@ def signed_to_unsigned(int_dtype):
|
|
|
144
130
|
if int_dtype == jnp.int64:
|
|
145
131
|
return jnp.uint64
|
|
146
132
|
|
|
133
|
+
|
|
147
134
|
def ensure_unsigned(x):
|
|
148
135
|
"""
|
|
149
136
|
If x has signed integer type, cast it to the unsigned dtype of the same size.
|
|
150
137
|
"""
|
|
151
138
|
return x.astype(signed_to_unsigned(x.dtype))
|
|
152
139
|
|
|
140
|
+
|
|
153
141
|
@functools.partial(jax.jit, static_argnums=(1,))
|
|
154
142
|
def unique(x, size, fill_value):
|
|
155
143
|
"""
|
|
@@ -177,15 +165,18 @@ def unique(x, size, fill_value):
|
|
|
177
165
|
if size == 0:
|
|
178
166
|
return jnp.empty(0, x.dtype), 0
|
|
179
167
|
x = jnp.sort(x)
|
|
168
|
+
|
|
180
169
|
def loop(carry, x):
|
|
181
170
|
i_out, i_in, last, out = carry
|
|
182
171
|
i_out = jnp.where(x == last, i_out, i_out + 1)
|
|
183
172
|
out = out.at[i_out].set(x)
|
|
184
173
|
return (i_out, i_in + 1, x, out), None
|
|
174
|
+
|
|
185
175
|
carry = 0, 0, x[0], jnp.full(size, fill_value, x.dtype)
|
|
186
176
|
(actual_length, _, _, out), _ = jax.lax.scan(loop, carry, x[:size])
|
|
187
177
|
return out, actual_length + 1
|
|
188
178
|
|
|
179
|
+
|
|
189
180
|
def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False):
|
|
190
181
|
"""
|
|
191
182
|
Batch a function such that each batch is smaller than a threshold.
|
|
@@ -222,6 +213,7 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
|
|
|
222
213
|
def check_no_nones(axes, tree):
|
|
223
214
|
def check_not_none(_, axis):
|
|
224
215
|
assert axis is not None
|
|
216
|
+
|
|
225
217
|
tree_util.tree_map(check_not_none, tree, axes)
|
|
226
218
|
|
|
227
219
|
def extract_size(axes, tree):
|
|
@@ -230,6 +222,7 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
|
|
|
230
222
|
return None
|
|
231
223
|
else:
|
|
232
224
|
return x.shape[axis]
|
|
225
|
+
|
|
233
226
|
sizes = tree_util.tree_map(get_size, tree, axes)
|
|
234
227
|
sizes, _ = tree_util.tree_flatten(sizes)
|
|
235
228
|
assert all(s == sizes[0] for s in sizes)
|
|
@@ -238,6 +231,7 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
|
|
|
238
231
|
def sum_nbytes(tree):
|
|
239
232
|
def nbytes(x):
|
|
240
233
|
return math.prod(x.shape) * x.dtype.itemsize
|
|
234
|
+
|
|
241
235
|
return tree_util.tree_reduce(lambda size, x: size + nbytes(x), tree, 0)
|
|
242
236
|
|
|
243
237
|
def next_divisor_small(dividend, min_divisor):
|
|
@@ -266,6 +260,7 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
|
|
|
266
260
|
return None
|
|
267
261
|
else:
|
|
268
262
|
return x
|
|
263
|
+
|
|
269
264
|
return tree_util.tree_map(pull_nonbatched, tree, axes), tree
|
|
270
265
|
|
|
271
266
|
def push_nonbatched(axes, tree, original_tree):
|
|
@@ -274,32 +269,38 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
|
|
|
274
269
|
return original_x
|
|
275
270
|
else:
|
|
276
271
|
return x
|
|
272
|
+
|
|
277
273
|
return tree_util.tree_map(push_nonbatched, original_tree, tree, axes)
|
|
278
274
|
|
|
279
275
|
def move_axes_out(axes, tree):
|
|
280
276
|
def move_axis_out(x, axis):
|
|
281
277
|
return jnp.moveaxis(x, axis, 0)
|
|
278
|
+
|
|
282
279
|
return tree_util.tree_map(move_axis_out, tree, axes)
|
|
283
280
|
|
|
284
281
|
def move_axes_in(axes, tree):
|
|
285
282
|
def move_axis_in(x, axis):
|
|
286
283
|
return jnp.moveaxis(x, 0, axis)
|
|
284
|
+
|
|
287
285
|
return tree_util.tree_map(move_axis_in, tree, axes)
|
|
288
286
|
|
|
289
287
|
def batch(tree, nbatches):
|
|
290
288
|
def batch(x):
|
|
291
289
|
return x.reshape((nbatches, x.shape[0] // nbatches) + x.shape[1:])
|
|
290
|
+
|
|
292
291
|
return tree_util.tree_map(batch, tree)
|
|
293
292
|
|
|
294
293
|
def unbatch(tree):
|
|
295
294
|
def unbatch(x):
|
|
296
295
|
return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
|
|
296
|
+
|
|
297
297
|
return tree_util.tree_map(unbatch, tree)
|
|
298
298
|
|
|
299
299
|
def check_same(tree1, tree2):
|
|
300
300
|
def check_same(x1, x2):
|
|
301
301
|
assert x1.shape == x2.shape
|
|
302
302
|
assert x1.dtype == x2.dtype
|
|
303
|
+
|
|
303
304
|
tree_util.tree_map(check_same, tree1, tree2)
|
|
304
305
|
|
|
305
306
|
initial_in_axes = in_axes
|
|
@@ -319,7 +320,9 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
|
|
|
319
320
|
args, nonbatched_args = pull_nonbatched(in_axes, args)
|
|
320
321
|
|
|
321
322
|
total_nbytes = sum_nbytes((args, example_result))
|
|
322
|
-
min_nbatches = total_nbytes // max_io_nbytes + bool(
|
|
323
|
+
min_nbatches = total_nbytes // max_io_nbytes + bool(
|
|
324
|
+
total_nbytes % max_io_nbytes
|
|
325
|
+
)
|
|
323
326
|
min_nbatches = max(1, min_nbatches)
|
|
324
327
|
nbatches = next_divisor(size, min_nbatches)
|
|
325
328
|
assert 1 <= nbatches <= max(1, size)
|
|
@@ -329,7 +332,9 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
|
|
|
329
332
|
batch_nbytes = total_nbytes // nbatches
|
|
330
333
|
if batch_nbytes > max_io_nbytes:
|
|
331
334
|
assert size == nbatches
|
|
332
|
-
warnings.warn(
|
|
335
|
+
warnings.warn(
|
|
336
|
+
f'batch_nbytes = {batch_nbytes} > max_io_nbytes = {max_io_nbytes}'
|
|
337
|
+
)
|
|
333
338
|
|
|
334
339
|
def loop(_, args):
|
|
335
340
|
args = move_axes_in(in_axes, args)
|
|
@@ -352,10 +357,11 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
|
|
|
352
357
|
|
|
353
358
|
return batched_func
|
|
354
359
|
|
|
360
|
+
|
|
355
361
|
@tree_util.register_pytree_node_class
|
|
356
362
|
class LeafDict(dict):
|
|
357
|
-
"""
|
|
358
|
-
values
|
|
363
|
+
"""dictionary that acts as a leaf in jax pytrees, to store compile-time
|
|
364
|
+
values"""
|
|
359
365
|
|
|
360
366
|
def tree_flatten(self):
|
|
361
367
|
return (), self
|
bartz/mcmcloop.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
# bartz/src/bartz/mcmcloop.py
|
|
2
2
|
#
|
|
3
|
-
# Copyright (c) 2024, Giacomo Petrillo
|
|
3
|
+
# Copyright (c) 2024-2025, Giacomo Petrillo
|
|
4
4
|
#
|
|
5
5
|
# This file is part of bartz.
|
|
6
6
|
#
|
|
@@ -29,22 +29,21 @@ Functions that implement the full BART posterior MCMC loop.
|
|
|
29
29
|
import functools
|
|
30
30
|
|
|
31
31
|
import jax
|
|
32
|
-
from jax import random
|
|
33
|
-
from jax import debug
|
|
32
|
+
from jax import debug, lax, random, tree
|
|
34
33
|
from jax import numpy as jnp
|
|
35
|
-
from jax import lax
|
|
36
34
|
|
|
37
|
-
from . import jaxext
|
|
38
|
-
from . import grove
|
|
39
|
-
from . import mcmcstep
|
|
35
|
+
from . import grove, jaxext, mcmcstep
|
|
40
36
|
|
|
41
|
-
|
|
42
|
-
|
|
37
|
+
|
|
38
|
+
@functools.partial(jax.jit, static_argnums=(2, 3, 4, 5))
|
|
39
|
+
def run_mcmc(key, bart, n_burn, n_save, n_skip, callback):
|
|
43
40
|
"""
|
|
44
41
|
Run the MCMC for the BART posterior.
|
|
45
42
|
|
|
46
43
|
Parameters
|
|
47
44
|
----------
|
|
45
|
+
key : jax.dtypes.prng_key array
|
|
46
|
+
The key for random number generation.
|
|
48
47
|
bart : dict
|
|
49
48
|
The initial MCMC state, as created and updated by the functions in
|
|
50
49
|
`bartz.mcmcstep`.
|
|
@@ -54,88 +53,123 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
|
|
|
54
53
|
The number of iterations to save.
|
|
55
54
|
n_skip : int
|
|
56
55
|
The number of iterations to skip between each saved iteration, plus 1.
|
|
56
|
+
The effective burn-in is ``n_burn + n_skip - 1``.
|
|
57
57
|
callback : callable
|
|
58
58
|
An arbitrary function run at each iteration, called with the following
|
|
59
59
|
arguments, passed by keyword:
|
|
60
60
|
|
|
61
61
|
bart : dict
|
|
62
|
-
The
|
|
62
|
+
The MCMC state just after updating it.
|
|
63
63
|
burnin : bool
|
|
64
64
|
Whether the last iteration was in the burn-in phase.
|
|
65
65
|
i_total : int
|
|
66
66
|
The index of the last iteration (0-based).
|
|
67
67
|
i_skip : int
|
|
68
|
-
The
|
|
69
|
-
|
|
68
|
+
The number of MCMC updates from the last saved state. The initial
|
|
69
|
+
state counts as saved, even if it's not copied into the trace.
|
|
70
70
|
n_burn, n_save, n_skip : int
|
|
71
71
|
The corresponding arguments as-is.
|
|
72
72
|
|
|
73
73
|
Since this function is called under the jax jit, the values are not
|
|
74
74
|
available at the time the Python code is executed. Use the utilities in
|
|
75
75
|
`jax.debug` to access the values at actual runtime.
|
|
76
|
-
key : jax.dtypes.prng_key array
|
|
77
|
-
The key for random number generation.
|
|
78
76
|
|
|
79
77
|
Returns
|
|
80
78
|
-------
|
|
81
79
|
bart : dict
|
|
82
80
|
The final MCMC state.
|
|
83
|
-
burnin_trace : dict
|
|
81
|
+
burnin_trace : dict of (n_burn, ...) arrays
|
|
84
82
|
The trace of the burn-in phase, containing the following subset of
|
|
85
83
|
fields from the `bart` dictionary, with an additional head index that
|
|
86
84
|
runs over MCMC iterations: 'sigma2', 'grow_prop_count',
|
|
87
85
|
'grow_acc_count', 'prune_prop_count', 'prune_acc_count'.
|
|
88
|
-
main_trace : dict
|
|
86
|
+
main_trace : dict of (n_save, ...) arrays
|
|
89
87
|
The trace of the main phase, containing the following subset of fields
|
|
90
88
|
from the `bart` dictionary, with an additional head index that runs
|
|
91
89
|
over MCMC iterations: 'leaf_trees', 'var_trees', 'split_trees', plus
|
|
92
90
|
the fields in `burnin_trace`.
|
|
91
|
+
|
|
92
|
+
Notes
|
|
93
|
+
-----
|
|
94
|
+
The number of MCMC updates is ``n_burn + n_skip * n_save``. The traces do
|
|
95
|
+
not include the initial state, and include the final state.
|
|
93
96
|
"""
|
|
94
97
|
|
|
95
|
-
|
|
98
|
+
tracevars_light = (
|
|
99
|
+
'sigma2',
|
|
100
|
+
'grow_prop_count',
|
|
101
|
+
'grow_acc_count',
|
|
102
|
+
'prune_prop_count',
|
|
103
|
+
'prune_acc_count',
|
|
104
|
+
'ratios',
|
|
105
|
+
)
|
|
106
|
+
tracevars_heavy = ('leaf_trees', 'var_trees', 'split_trees')
|
|
107
|
+
|
|
108
|
+
def empty_trace(length, bart, tracelist):
|
|
109
|
+
bart = {k: v for k, v in bart.items() if k in tracelist}
|
|
110
|
+
return jax.vmap(lambda x: x, in_axes=None, out_axes=0, axis_size=length)(bart)
|
|
96
111
|
|
|
97
|
-
|
|
112
|
+
trace_light = empty_trace(n_burn + n_save, bart, tracevars_light)
|
|
113
|
+
trace_heavy = empty_trace(n_save, bart, tracevars_heavy)
|
|
98
114
|
|
|
99
115
|
callback_kw = dict(n_burn=n_burn, n_save=n_save, n_skip=n_skip)
|
|
100
116
|
|
|
101
|
-
|
|
102
|
-
|
|
117
|
+
carry = (bart, 0, key, trace_light, trace_heavy)
|
|
118
|
+
|
|
119
|
+
def loop(carry, _):
|
|
120
|
+
bart, i_total, key, trace_light, trace_heavy = carry
|
|
121
|
+
|
|
103
122
|
key, subkey = random.split(key)
|
|
104
|
-
bart = mcmcstep.step(
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
123
|
+
bart = mcmcstep.step(subkey, bart)
|
|
124
|
+
|
|
125
|
+
burnin = i_total < n_burn
|
|
126
|
+
i_skip = jnp.where(
|
|
127
|
+
burnin,
|
|
128
|
+
i_total + 1,
|
|
129
|
+
(i_total + 1) % n_skip + jnp.where(i_total + 1 < n_skip, n_burn, 0),
|
|
130
|
+
)
|
|
131
|
+
callback(
|
|
132
|
+
bart=bart, burnin=burnin, i_total=i_total, i_skip=i_skip, **callback_kw
|
|
133
|
+
)
|
|
134
|
+
|
|
135
|
+
i_heavy = jnp.where(burnin, 0, (i_total - n_burn) // n_skip)
|
|
136
|
+
i_light = jnp.where(burnin, i_total, n_burn + i_heavy)
|
|
137
|
+
|
|
138
|
+
def update_trace(index, trace, bart):
|
|
139
|
+
bart = {k: v for k, v in bart.items() if k in trace}
|
|
140
|
+
|
|
141
|
+
def assign_at_index(trace_array, state_array):
|
|
142
|
+
if trace_array.size:
|
|
143
|
+
return trace_array.at[index, ...].set(state_array)
|
|
144
|
+
else:
|
|
145
|
+
# this handles the case where a trace is empty (e.g.,
|
|
146
|
+
# no burn-in) because jax refuses to index into an array
|
|
147
|
+
# of length 0
|
|
148
|
+
return trace_array
|
|
149
|
+
|
|
150
|
+
return tree.map(assign_at_index, trace, bart)
|
|
151
|
+
|
|
152
|
+
trace_heavy = update_trace(i_heavy, trace_heavy, bart)
|
|
153
|
+
trace_light = update_trace(i_light, trace_light, bart)
|
|
154
|
+
|
|
155
|
+
i_total += 1
|
|
156
|
+
carry = (bart, i_total, key, trace_light, trace_heavy)
|
|
157
|
+
return carry, None
|
|
158
|
+
|
|
159
|
+
carry, _ = lax.scan(loop, carry, None, n_burn + n_skip * n_save)
|
|
160
|
+
|
|
161
|
+
bart, _, _, trace_light, trace_heavy = carry
|
|
162
|
+
|
|
163
|
+
burnin_trace = tree.map(lambda x: x[:n_burn, ...], trace_light)
|
|
164
|
+
main_trace = tree.map(lambda x: x[n_burn:, ...], trace_light)
|
|
165
|
+
main_trace.update(trace_heavy)
|
|
133
166
|
|
|
134
167
|
return bart, burnin_trace, main_trace
|
|
135
168
|
|
|
169
|
+
|
|
136
170
|
@functools.lru_cache
|
|
137
|
-
|
|
138
|
-
|
|
171
|
+
# cache to make the callback function object unique, such that the jit
|
|
172
|
+
# of run_mcmc recognizes it
|
|
139
173
|
def make_simple_print_callback(printevery):
|
|
140
174
|
"""
|
|
141
175
|
Create a logging callback function for MCMC iterations.
|
|
@@ -150,6 +184,7 @@ def make_simple_print_callback(printevery):
|
|
|
150
184
|
callback : callable
|
|
151
185
|
A function in the format required by `run_mcmc`.
|
|
152
186
|
"""
|
|
187
|
+
|
|
153
188
|
def callback(*, bart, burnin, i_total, i_skip, n_burn, n_save, n_skip):
|
|
154
189
|
prop_total = len(bart['leaf_trees'])
|
|
155
190
|
grow_prop = bart['grow_prop_count'] / prop_total
|
|
@@ -158,18 +193,40 @@ def make_simple_print_callback(printevery):
|
|
|
158
193
|
prune_acc = bart['prune_acc_count'] / bart['prune_prop_count']
|
|
159
194
|
n_total = n_burn + n_save * n_skip
|
|
160
195
|
printcond = (i_total + 1) % printevery == 0
|
|
161
|
-
debug.callback(
|
|
196
|
+
debug.callback(
|
|
197
|
+
_simple_print_callback,
|
|
198
|
+
burnin,
|
|
199
|
+
i_total,
|
|
200
|
+
n_total,
|
|
201
|
+
grow_prop,
|
|
202
|
+
grow_acc,
|
|
203
|
+
prune_prop,
|
|
204
|
+
prune_acc,
|
|
205
|
+
printcond,
|
|
206
|
+
)
|
|
207
|
+
|
|
162
208
|
return callback
|
|
163
209
|
|
|
164
|
-
|
|
210
|
+
|
|
211
|
+
def _simple_print_callback(
|
|
212
|
+
burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printcond
|
|
213
|
+
):
|
|
165
214
|
if printcond:
|
|
166
215
|
burnin_flag = ' (burnin)' if burnin else ''
|
|
167
216
|
total_str = str(n_total)
|
|
168
217
|
ndigits = len(total_str)
|
|
169
|
-
i_str = str(i_total + 1).rjust(ndigits)
|
|
170
|
-
|
|
218
|
+
i_str = str(i_total.item() + 1).rjust(ndigits)
|
|
219
|
+
# I do i_total.item() + 1 instead of just i_total + 1 to solve a bug
|
|
220
|
+
# originating when jax is combined with some outdated dependencies. (I
|
|
221
|
+
# did not track down which dependencies exactly.) Doing .item() makes
|
|
222
|
+
# the + 1 operation be done by Python instead of by jax. The bug is that
|
|
223
|
+
# jax hangs completely, with a secondary thread blocked at this line.
|
|
224
|
+
print(
|
|
225
|
+
f'Iteration {i_str}/{total_str} '
|
|
171
226
|
f'P_grow={grow_prop:.2f} P_prune={prune_prop:.2f} '
|
|
172
|
-
f'A_grow={grow_acc:.2f} A_prune={prune_acc:.2f}{burnin_flag}'
|
|
227
|
+
f'A_grow={grow_acc:.2f} A_prune={prune_acc:.2f}{burnin_flag}'
|
|
228
|
+
)
|
|
229
|
+
|
|
173
230
|
|
|
174
231
|
@jax.jit
|
|
175
232
|
def evaluate_trace(trace, X):
|
|
@@ -189,9 +246,13 @@ def evaluate_trace(trace, X):
|
|
|
189
246
|
The predictions for each iteration of the MCMC.
|
|
190
247
|
"""
|
|
191
248
|
evaluate_trees = functools.partial(grove.evaluate_forest, sum_trees=False)
|
|
192
|
-
evaluate_trees = jaxext.autobatch(evaluate_trees, 2
|
|
249
|
+
evaluate_trees = jaxext.autobatch(evaluate_trees, 2**29, (None, 0, 0, 0))
|
|
250
|
+
|
|
193
251
|
def loop(_, state):
|
|
194
|
-
values = evaluate_trees(
|
|
252
|
+
values = evaluate_trees(
|
|
253
|
+
X, state['leaf_trees'], state['var_trees'], state['split_trees']
|
|
254
|
+
)
|
|
195
255
|
return None, jnp.sum(values, axis=0, dtype=jnp.float32)
|
|
256
|
+
|
|
196
257
|
_, y = lax.scan(loop, None, trace)
|
|
197
258
|
return y
|