bartz 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/.DS_Store +0 -0
- bartz/BART.py +99 -39
- bartz/__init__.py +3 -11
- bartz/_version.py +1 -1
- bartz/debug.py +42 -16
- bartz/grove.py +20 -11
- bartz/jaxext.py +41 -16
- bartz/mcmcloop.py +119 -58
- bartz/mcmcstep.py +426 -173
- bartz/prepcovars.py +22 -9
- {bartz-0.4.1.dist-info → bartz-0.5.0.dist-info}/METADATA +12 -16
- bartz-0.5.0.dist-info/RECORD +13 -0
- bartz-0.5.0.dist-info/WHEEL +4 -0
- bartz-0.4.1.dist-info/LICENSE +0 -21
- bartz-0.4.1.dist-info/RECORD +0 -13
- bartz-0.4.1.dist-info/WHEEL +0 -4
bartz/jaxext.py
CHANGED
|
@@ -26,11 +26,11 @@ import functools
|
|
|
26
26
|
import math
|
|
27
27
|
import warnings
|
|
28
28
|
|
|
29
|
-
from scipy import special
|
|
30
29
|
import jax
|
|
30
|
+
from jax import lax, tree_util
|
|
31
31
|
from jax import numpy as jnp
|
|
32
|
-
from
|
|
33
|
-
|
|
32
|
+
from scipy import special
|
|
33
|
+
|
|
34
34
|
|
|
35
35
|
def float_type(*args):
|
|
36
36
|
"""
|
|
@@ -39,16 +39,17 @@ def float_type(*args):
|
|
|
39
39
|
t = jnp.result_type(*args)
|
|
40
40
|
return jnp.sin(jnp.empty(0, t)).dtype
|
|
41
41
|
|
|
42
|
+
|
|
42
43
|
def castto(func, type):
|
|
43
44
|
@functools.wraps(func)
|
|
44
45
|
def newfunc(*args, **kw):
|
|
45
46
|
return func(*args, **kw).astype(type)
|
|
47
|
+
|
|
46
48
|
return newfunc
|
|
47
49
|
|
|
48
|
-
class scipy:
|
|
49
50
|
|
|
51
|
+
class scipy:
|
|
50
52
|
class special:
|
|
51
|
-
|
|
52
53
|
@functools.wraps(special.gammainccinv)
|
|
53
54
|
def gammainccinv(a, y):
|
|
54
55
|
a = jnp.asarray(a)
|
|
@@ -60,22 +61,24 @@ class scipy:
|
|
|
60
61
|
return jax.pure_callback(ufunc, dummy, a, y, vmap_method='expand_dims')
|
|
61
62
|
|
|
62
63
|
class stats:
|
|
63
|
-
|
|
64
64
|
class invgamma:
|
|
65
|
-
|
|
66
65
|
def ppf(q, a):
|
|
67
66
|
return 1 / scipy.special.gammainccinv(a, q)
|
|
68
67
|
|
|
69
|
-
|
|
68
|
+
|
|
70
69
|
def vmap_nodoc(fun, *args, **kw):
|
|
71
70
|
"""
|
|
72
|
-
|
|
71
|
+
Wrapper of `jax.vmap` that preserves the docstring of the input function.
|
|
72
|
+
|
|
73
|
+
This is useful if the docstring already takes into account that the
|
|
74
|
+
arguments have additional axes due to vmap.
|
|
73
75
|
"""
|
|
74
76
|
doc = fun.__doc__
|
|
75
77
|
fun = jax.vmap(fun, *args, **kw)
|
|
76
78
|
fun.__doc__ = doc
|
|
77
79
|
return fun
|
|
78
80
|
|
|
81
|
+
|
|
79
82
|
def huge_value(x):
|
|
80
83
|
"""
|
|
81
84
|
Return the maximum value that can be stored in `x`.
|
|
@@ -95,19 +98,21 @@ def huge_value(x):
|
|
|
95
98
|
else:
|
|
96
99
|
return jnp.inf
|
|
97
100
|
|
|
101
|
+
|
|
98
102
|
def minimal_unsigned_dtype(max_value):
|
|
99
103
|
"""
|
|
100
104
|
Return the smallest unsigned integer dtype that can represent a given
|
|
101
105
|
maximum value (inclusive).
|
|
102
106
|
"""
|
|
103
|
-
if max_value < 2
|
|
107
|
+
if max_value < 2**8:
|
|
104
108
|
return jnp.uint8
|
|
105
|
-
if max_value < 2
|
|
109
|
+
if max_value < 2**16:
|
|
106
110
|
return jnp.uint16
|
|
107
|
-
if max_value < 2
|
|
111
|
+
if max_value < 2**32:
|
|
108
112
|
return jnp.uint32
|
|
109
113
|
return jnp.uint64
|
|
110
114
|
|
|
115
|
+
|
|
111
116
|
def signed_to_unsigned(int_dtype):
|
|
112
117
|
"""
|
|
113
118
|
Map a signed integer type to its unsigned counterpart. Unsigned types are
|
|
@@ -125,12 +130,14 @@ def signed_to_unsigned(int_dtype):
|
|
|
125
130
|
if int_dtype == jnp.int64:
|
|
126
131
|
return jnp.uint64
|
|
127
132
|
|
|
133
|
+
|
|
128
134
|
def ensure_unsigned(x):
|
|
129
135
|
"""
|
|
130
136
|
If x has signed integer type, cast it to the unsigned dtype of the same size.
|
|
131
137
|
"""
|
|
132
138
|
return x.astype(signed_to_unsigned(x.dtype))
|
|
133
139
|
|
|
140
|
+
|
|
134
141
|
@functools.partial(jax.jit, static_argnums=(1,))
|
|
135
142
|
def unique(x, size, fill_value):
|
|
136
143
|
"""
|
|
@@ -158,15 +165,18 @@ def unique(x, size, fill_value):
|
|
|
158
165
|
if size == 0:
|
|
159
166
|
return jnp.empty(0, x.dtype), 0
|
|
160
167
|
x = jnp.sort(x)
|
|
168
|
+
|
|
161
169
|
def loop(carry, x):
|
|
162
170
|
i_out, i_in, last, out = carry
|
|
163
171
|
i_out = jnp.where(x == last, i_out, i_out + 1)
|
|
164
172
|
out = out.at[i_out].set(x)
|
|
165
173
|
return (i_out, i_in + 1, x, out), None
|
|
174
|
+
|
|
166
175
|
carry = 0, 0, x[0], jnp.full(size, fill_value, x.dtype)
|
|
167
176
|
(actual_length, _, _, out), _ = jax.lax.scan(loop, carry, x[:size])
|
|
168
177
|
return out, actual_length + 1
|
|
169
178
|
|
|
179
|
+
|
|
170
180
|
def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False):
|
|
171
181
|
"""
|
|
172
182
|
Batch a function such that each batch is smaller than a threshold.
|
|
@@ -203,6 +213,7 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
|
|
|
203
213
|
def check_no_nones(axes, tree):
|
|
204
214
|
def check_not_none(_, axis):
|
|
205
215
|
assert axis is not None
|
|
216
|
+
|
|
206
217
|
tree_util.tree_map(check_not_none, tree, axes)
|
|
207
218
|
|
|
208
219
|
def extract_size(axes, tree):
|
|
@@ -211,6 +222,7 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
|
|
|
211
222
|
return None
|
|
212
223
|
else:
|
|
213
224
|
return x.shape[axis]
|
|
225
|
+
|
|
214
226
|
sizes = tree_util.tree_map(get_size, tree, axes)
|
|
215
227
|
sizes, _ = tree_util.tree_flatten(sizes)
|
|
216
228
|
assert all(s == sizes[0] for s in sizes)
|
|
@@ -219,6 +231,7 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
|
|
|
219
231
|
def sum_nbytes(tree):
|
|
220
232
|
def nbytes(x):
|
|
221
233
|
return math.prod(x.shape) * x.dtype.itemsize
|
|
234
|
+
|
|
222
235
|
return tree_util.tree_reduce(lambda size, x: size + nbytes(x), tree, 0)
|
|
223
236
|
|
|
224
237
|
def next_divisor_small(dividend, min_divisor):
|
|
@@ -247,6 +260,7 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
|
|
|
247
260
|
return None
|
|
248
261
|
else:
|
|
249
262
|
return x
|
|
263
|
+
|
|
250
264
|
return tree_util.tree_map(pull_nonbatched, tree, axes), tree
|
|
251
265
|
|
|
252
266
|
def push_nonbatched(axes, tree, original_tree):
|
|
@@ -255,32 +269,38 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
|
|
|
255
269
|
return original_x
|
|
256
270
|
else:
|
|
257
271
|
return x
|
|
272
|
+
|
|
258
273
|
return tree_util.tree_map(push_nonbatched, original_tree, tree, axes)
|
|
259
274
|
|
|
260
275
|
def move_axes_out(axes, tree):
|
|
261
276
|
def move_axis_out(x, axis):
|
|
262
277
|
return jnp.moveaxis(x, axis, 0)
|
|
278
|
+
|
|
263
279
|
return tree_util.tree_map(move_axis_out, tree, axes)
|
|
264
280
|
|
|
265
281
|
def move_axes_in(axes, tree):
|
|
266
282
|
def move_axis_in(x, axis):
|
|
267
283
|
return jnp.moveaxis(x, 0, axis)
|
|
284
|
+
|
|
268
285
|
return tree_util.tree_map(move_axis_in, tree, axes)
|
|
269
286
|
|
|
270
287
|
def batch(tree, nbatches):
|
|
271
288
|
def batch(x):
|
|
272
289
|
return x.reshape((nbatches, x.shape[0] // nbatches) + x.shape[1:])
|
|
290
|
+
|
|
273
291
|
return tree_util.tree_map(batch, tree)
|
|
274
292
|
|
|
275
293
|
def unbatch(tree):
|
|
276
294
|
def unbatch(x):
|
|
277
295
|
return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
|
|
296
|
+
|
|
278
297
|
return tree_util.tree_map(unbatch, tree)
|
|
279
298
|
|
|
280
299
|
def check_same(tree1, tree2):
|
|
281
300
|
def check_same(x1, x2):
|
|
282
301
|
assert x1.shape == x2.shape
|
|
283
302
|
assert x1.dtype == x2.dtype
|
|
303
|
+
|
|
284
304
|
tree_util.tree_map(check_same, tree1, tree2)
|
|
285
305
|
|
|
286
306
|
initial_in_axes = in_axes
|
|
@@ -300,7 +320,9 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
|
|
|
300
320
|
args, nonbatched_args = pull_nonbatched(in_axes, args)
|
|
301
321
|
|
|
302
322
|
total_nbytes = sum_nbytes((args, example_result))
|
|
303
|
-
min_nbatches = total_nbytes // max_io_nbytes + bool(
|
|
323
|
+
min_nbatches = total_nbytes // max_io_nbytes + bool(
|
|
324
|
+
total_nbytes % max_io_nbytes
|
|
325
|
+
)
|
|
304
326
|
min_nbatches = max(1, min_nbatches)
|
|
305
327
|
nbatches = next_divisor(size, min_nbatches)
|
|
306
328
|
assert 1 <= nbatches <= max(1, size)
|
|
@@ -310,7 +332,9 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
|
|
|
310
332
|
batch_nbytes = total_nbytes // nbatches
|
|
311
333
|
if batch_nbytes > max_io_nbytes:
|
|
312
334
|
assert size == nbatches
|
|
313
|
-
warnings.warn(
|
|
335
|
+
warnings.warn(
|
|
336
|
+
f'batch_nbytes = {batch_nbytes} > max_io_nbytes = {max_io_nbytes}'
|
|
337
|
+
)
|
|
314
338
|
|
|
315
339
|
def loop(_, args):
|
|
316
340
|
args = move_axes_in(in_axes, args)
|
|
@@ -333,10 +357,11 @@ def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
|
|
|
333
357
|
|
|
334
358
|
return batched_func
|
|
335
359
|
|
|
360
|
+
|
|
336
361
|
@tree_util.register_pytree_node_class
|
|
337
362
|
class LeafDict(dict):
|
|
338
|
-
"""
|
|
339
|
-
values
|
|
363
|
+
"""dictionary that acts as a leaf in jax pytrees, to store compile-time
|
|
364
|
+
values"""
|
|
340
365
|
|
|
341
366
|
def tree_flatten(self):
|
|
342
367
|
return (), self
|
bartz/mcmcloop.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
# bartz/src/bartz/mcmcloop.py
|
|
2
2
|
#
|
|
3
|
-
# Copyright (c) 2024, Giacomo Petrillo
|
|
3
|
+
# Copyright (c) 2024-2025, Giacomo Petrillo
|
|
4
4
|
#
|
|
5
5
|
# This file is part of bartz.
|
|
6
6
|
#
|
|
@@ -29,22 +29,21 @@ Functions that implement the full BART posterior MCMC loop.
|
|
|
29
29
|
import functools
|
|
30
30
|
|
|
31
31
|
import jax
|
|
32
|
-
from jax import random
|
|
33
|
-
from jax import debug
|
|
32
|
+
from jax import debug, lax, random, tree
|
|
34
33
|
from jax import numpy as jnp
|
|
35
|
-
from jax import lax
|
|
36
34
|
|
|
37
|
-
from . import jaxext
|
|
38
|
-
from . import grove
|
|
39
|
-
from . import mcmcstep
|
|
35
|
+
from . import grove, jaxext, mcmcstep
|
|
40
36
|
|
|
41
|
-
|
|
42
|
-
|
|
37
|
+
|
|
38
|
+
@functools.partial(jax.jit, static_argnums=(2, 3, 4, 5))
|
|
39
|
+
def run_mcmc(key, bart, n_burn, n_save, n_skip, callback):
|
|
43
40
|
"""
|
|
44
41
|
Run the MCMC for the BART posterior.
|
|
45
42
|
|
|
46
43
|
Parameters
|
|
47
44
|
----------
|
|
45
|
+
key : jax.dtypes.prng_key array
|
|
46
|
+
The key for random number generation.
|
|
48
47
|
bart : dict
|
|
49
48
|
The initial MCMC state, as created and updated by the functions in
|
|
50
49
|
`bartz.mcmcstep`.
|
|
@@ -54,88 +53,123 @@ def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
|
|
|
54
53
|
The number of iterations to save.
|
|
55
54
|
n_skip : int
|
|
56
55
|
The number of iterations to skip between each saved iteration, plus 1.
|
|
56
|
+
The effective burn-in is ``n_burn + n_skip - 1``.
|
|
57
57
|
callback : callable
|
|
58
58
|
An arbitrary function run at each iteration, called with the following
|
|
59
59
|
arguments, passed by keyword:
|
|
60
60
|
|
|
61
61
|
bart : dict
|
|
62
|
-
The
|
|
62
|
+
The MCMC state just after updating it.
|
|
63
63
|
burnin : bool
|
|
64
64
|
Whether the last iteration was in the burn-in phase.
|
|
65
65
|
i_total : int
|
|
66
66
|
The index of the last iteration (0-based).
|
|
67
67
|
i_skip : int
|
|
68
|
-
The
|
|
69
|
-
|
|
68
|
+
The number of MCMC updates from the last saved state. The initial
|
|
69
|
+
state counts as saved, even if it's not copied into the trace.
|
|
70
70
|
n_burn, n_save, n_skip : int
|
|
71
71
|
The corresponding arguments as-is.
|
|
72
72
|
|
|
73
73
|
Since this function is called under the jax jit, the values are not
|
|
74
74
|
available at the time the Python code is executed. Use the utilities in
|
|
75
75
|
`jax.debug` to access the values at actual runtime.
|
|
76
|
-
key : jax.dtypes.prng_key array
|
|
77
|
-
The key for random number generation.
|
|
78
76
|
|
|
79
77
|
Returns
|
|
80
78
|
-------
|
|
81
79
|
bart : dict
|
|
82
80
|
The final MCMC state.
|
|
83
|
-
burnin_trace : dict
|
|
81
|
+
burnin_trace : dict of (n_burn, ...) arrays
|
|
84
82
|
The trace of the burn-in phase, containing the following subset of
|
|
85
83
|
fields from the `bart` dictionary, with an additional head index that
|
|
86
84
|
runs over MCMC iterations: 'sigma2', 'grow_prop_count',
|
|
87
85
|
'grow_acc_count', 'prune_prop_count', 'prune_acc_count'.
|
|
88
|
-
main_trace : dict
|
|
86
|
+
main_trace : dict of (n_save, ...) arrays
|
|
89
87
|
The trace of the main phase, containing the following subset of fields
|
|
90
88
|
from the `bart` dictionary, with an additional head index that runs
|
|
91
89
|
over MCMC iterations: 'leaf_trees', 'var_trees', 'split_trees', plus
|
|
92
90
|
the fields in `burnin_trace`.
|
|
91
|
+
|
|
92
|
+
Notes
|
|
93
|
+
-----
|
|
94
|
+
The number of MCMC updates is ``n_burn + n_skip * n_save``. The traces do
|
|
95
|
+
not include the initial state, and include the final state.
|
|
93
96
|
"""
|
|
94
97
|
|
|
95
|
-
|
|
98
|
+
tracevars_light = (
|
|
99
|
+
'sigma2',
|
|
100
|
+
'grow_prop_count',
|
|
101
|
+
'grow_acc_count',
|
|
102
|
+
'prune_prop_count',
|
|
103
|
+
'prune_acc_count',
|
|
104
|
+
'ratios',
|
|
105
|
+
)
|
|
106
|
+
tracevars_heavy = ('leaf_trees', 'var_trees', 'split_trees')
|
|
107
|
+
|
|
108
|
+
def empty_trace(length, bart, tracelist):
|
|
109
|
+
bart = {k: v for k, v in bart.items() if k in tracelist}
|
|
110
|
+
return jax.vmap(lambda x: x, in_axes=None, out_axes=0, axis_size=length)(bart)
|
|
96
111
|
|
|
97
|
-
|
|
112
|
+
trace_light = empty_trace(n_burn + n_save, bart, tracevars_light)
|
|
113
|
+
trace_heavy = empty_trace(n_save, bart, tracevars_heavy)
|
|
98
114
|
|
|
99
115
|
callback_kw = dict(n_burn=n_burn, n_save=n_save, n_skip=n_skip)
|
|
100
116
|
|
|
101
|
-
|
|
102
|
-
|
|
117
|
+
carry = (bart, 0, key, trace_light, trace_heavy)
|
|
118
|
+
|
|
119
|
+
def loop(carry, _):
|
|
120
|
+
bart, i_total, key, trace_light, trace_heavy = carry
|
|
121
|
+
|
|
103
122
|
key, subkey = random.split(key)
|
|
104
|
-
bart = mcmcstep.step(
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
123
|
+
bart = mcmcstep.step(subkey, bart)
|
|
124
|
+
|
|
125
|
+
burnin = i_total < n_burn
|
|
126
|
+
i_skip = jnp.where(
|
|
127
|
+
burnin,
|
|
128
|
+
i_total + 1,
|
|
129
|
+
(i_total + 1) % n_skip + jnp.where(i_total + 1 < n_skip, n_burn, 0),
|
|
130
|
+
)
|
|
131
|
+
callback(
|
|
132
|
+
bart=bart, burnin=burnin, i_total=i_total, i_skip=i_skip, **callback_kw
|
|
133
|
+
)
|
|
134
|
+
|
|
135
|
+
i_heavy = jnp.where(burnin, 0, (i_total - n_burn) // n_skip)
|
|
136
|
+
i_light = jnp.where(burnin, i_total, n_burn + i_heavy)
|
|
137
|
+
|
|
138
|
+
def update_trace(index, trace, bart):
|
|
139
|
+
bart = {k: v for k, v in bart.items() if k in trace}
|
|
140
|
+
|
|
141
|
+
def assign_at_index(trace_array, state_array):
|
|
142
|
+
if trace_array.size:
|
|
143
|
+
return trace_array.at[index, ...].set(state_array)
|
|
144
|
+
else:
|
|
145
|
+
# this handles the case where a trace is empty (e.g.,
|
|
146
|
+
# no burn-in) because jax refuses to index into an array
|
|
147
|
+
# of length 0
|
|
148
|
+
return trace_array
|
|
149
|
+
|
|
150
|
+
return tree.map(assign_at_index, trace, bart)
|
|
151
|
+
|
|
152
|
+
trace_heavy = update_trace(i_heavy, trace_heavy, bart)
|
|
153
|
+
trace_light = update_trace(i_light, trace_light, bart)
|
|
154
|
+
|
|
155
|
+
i_total += 1
|
|
156
|
+
carry = (bart, i_total, key, trace_light, trace_heavy)
|
|
157
|
+
return carry, None
|
|
158
|
+
|
|
159
|
+
carry, _ = lax.scan(loop, carry, None, n_burn + n_skip * n_save)
|
|
160
|
+
|
|
161
|
+
bart, _, _, trace_light, trace_heavy = carry
|
|
162
|
+
|
|
163
|
+
burnin_trace = tree.map(lambda x: x[:n_burn, ...], trace_light)
|
|
164
|
+
main_trace = tree.map(lambda x: x[n_burn:, ...], trace_light)
|
|
165
|
+
main_trace.update(trace_heavy)
|
|
133
166
|
|
|
134
167
|
return bart, burnin_trace, main_trace
|
|
135
168
|
|
|
169
|
+
|
|
136
170
|
@functools.lru_cache
|
|
137
|
-
|
|
138
|
-
|
|
171
|
+
# cache to make the callback function object unique, such that the jit
|
|
172
|
+
# of run_mcmc recognizes it
|
|
139
173
|
def make_simple_print_callback(printevery):
|
|
140
174
|
"""
|
|
141
175
|
Create a logging callback function for MCMC iterations.
|
|
@@ -150,6 +184,7 @@ def make_simple_print_callback(printevery):
|
|
|
150
184
|
callback : callable
|
|
151
185
|
A function in the format required by `run_mcmc`.
|
|
152
186
|
"""
|
|
187
|
+
|
|
153
188
|
def callback(*, bart, burnin, i_total, i_skip, n_burn, n_save, n_skip):
|
|
154
189
|
prop_total = len(bart['leaf_trees'])
|
|
155
190
|
grow_prop = bart['grow_prop_count'] / prop_total
|
|
@@ -158,18 +193,40 @@ def make_simple_print_callback(printevery):
|
|
|
158
193
|
prune_acc = bart['prune_acc_count'] / bart['prune_prop_count']
|
|
159
194
|
n_total = n_burn + n_save * n_skip
|
|
160
195
|
printcond = (i_total + 1) % printevery == 0
|
|
161
|
-
debug.callback(
|
|
196
|
+
debug.callback(
|
|
197
|
+
_simple_print_callback,
|
|
198
|
+
burnin,
|
|
199
|
+
i_total,
|
|
200
|
+
n_total,
|
|
201
|
+
grow_prop,
|
|
202
|
+
grow_acc,
|
|
203
|
+
prune_prop,
|
|
204
|
+
prune_acc,
|
|
205
|
+
printcond,
|
|
206
|
+
)
|
|
207
|
+
|
|
162
208
|
return callback
|
|
163
209
|
|
|
164
|
-
|
|
210
|
+
|
|
211
|
+
def _simple_print_callback(
|
|
212
|
+
burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printcond
|
|
213
|
+
):
|
|
165
214
|
if printcond:
|
|
166
215
|
burnin_flag = ' (burnin)' if burnin else ''
|
|
167
216
|
total_str = str(n_total)
|
|
168
217
|
ndigits = len(total_str)
|
|
169
|
-
i_str = str(i_total + 1).rjust(ndigits)
|
|
170
|
-
|
|
218
|
+
i_str = str(i_total.item() + 1).rjust(ndigits)
|
|
219
|
+
# I do i_total.item() + 1 instead of just i_total + 1 to solve a bug
|
|
220
|
+
# originating when jax is combined with some outdated dependencies. (I
|
|
221
|
+
# did not track down which dependencies exactly.) Doing .item() makes
|
|
222
|
+
# the + 1 operation be done by Python instead of by jax. The bug is that
|
|
223
|
+
# jax hangs completely, with a secondary thread blocked at this line.
|
|
224
|
+
print(
|
|
225
|
+
f'Iteration {i_str}/{total_str} '
|
|
171
226
|
f'P_grow={grow_prop:.2f} P_prune={prune_prop:.2f} '
|
|
172
|
-
f'A_grow={grow_acc:.2f} A_prune={prune_acc:.2f}{burnin_flag}'
|
|
227
|
+
f'A_grow={grow_acc:.2f} A_prune={prune_acc:.2f}{burnin_flag}'
|
|
228
|
+
)
|
|
229
|
+
|
|
173
230
|
|
|
174
231
|
@jax.jit
|
|
175
232
|
def evaluate_trace(trace, X):
|
|
@@ -189,9 +246,13 @@ def evaluate_trace(trace, X):
|
|
|
189
246
|
The predictions for each iteration of the MCMC.
|
|
190
247
|
"""
|
|
191
248
|
evaluate_trees = functools.partial(grove.evaluate_forest, sum_trees=False)
|
|
192
|
-
evaluate_trees = jaxext.autobatch(evaluate_trees, 2
|
|
249
|
+
evaluate_trees = jaxext.autobatch(evaluate_trees, 2**29, (None, 0, 0, 0))
|
|
250
|
+
|
|
193
251
|
def loop(_, state):
|
|
194
|
-
values = evaluate_trees(
|
|
252
|
+
values = evaluate_trees(
|
|
253
|
+
X, state['leaf_trees'], state['var_trees'], state['split_trees']
|
|
254
|
+
)
|
|
195
255
|
return None, jnp.sum(values, axis=0, dtype=jnp.float32)
|
|
256
|
+
|
|
196
257
|
_, y = lax.scan(loop, None, trace)
|
|
197
258
|
return y
|