bartz 0.4.1__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/.DS_Store +0 -0
- bartz/BART.py +266 -113
- bartz/__init__.py +4 -12
- bartz/_version.py +1 -1
- bartz/debug.py +42 -16
- bartz/grove.py +62 -12
- bartz/jaxext.py +111 -37
- bartz/mcmcloop.py +419 -105
- bartz/mcmcstep.py +1528 -760
- bartz/prepcovars.py +25 -10
- {bartz-0.4.1.dist-info → bartz-0.6.0.dist-info}/METADATA +14 -16
- bartz-0.6.0.dist-info/RECORD +13 -0
- bartz-0.6.0.dist-info/WHEEL +4 -0
- bartz-0.4.1.dist-info/LICENSE +0 -21
- bartz-0.4.1.dist-info/RECORD +0 -13
- bartz-0.4.1.dist-info/WHEEL +0 -4
bartz/mcmcloop.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
# bartz/src/bartz/mcmcloop.py
|
|
2
2
|
#
|
|
3
|
-
# Copyright (c) 2024, Giacomo Petrillo
|
|
3
|
+
# Copyright (c) 2024-2025, Giacomo Petrillo
|
|
4
4
|
#
|
|
5
5
|
# This file is part of bartz.
|
|
6
6
|
#
|
|
@@ -22,154 +22,464 @@
|
|
|
22
22
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
23
23
|
# SOFTWARE.
|
|
24
24
|
|
|
25
|
-
"""
|
|
26
|
-
Functions that implement the full BART posterior MCMC loop.
|
|
27
|
-
"""
|
|
25
|
+
"""Functions that implement the full BART posterior MCMC loop."""
|
|
28
26
|
|
|
29
27
|
import functools
|
|
30
28
|
|
|
31
29
|
import jax
|
|
32
|
-
|
|
33
|
-
from jax import debug
|
|
30
|
+
import numpy
|
|
31
|
+
from jax import debug, lax, tree
|
|
34
32
|
from jax import numpy as jnp
|
|
35
|
-
from
|
|
33
|
+
from jaxtyping import Array, Real
|
|
36
34
|
|
|
37
|
-
from . import jaxext
|
|
38
|
-
from . import
|
|
39
|
-
from . import mcmcstep
|
|
35
|
+
from . import grove, jaxext, mcmcstep
|
|
36
|
+
from .mcmcstep import State
|
|
40
37
|
|
|
41
|
-
|
|
42
|
-
def
|
|
38
|
+
|
|
39
|
+
def default_onlymain_extractor(state: State) -> dict[str, Real[Array, 'samples *']]:
|
|
40
|
+
"""Extract variables for the main trace, to be used in `run_mcmc`."""
|
|
41
|
+
return dict(
|
|
42
|
+
leaf_trees=state.forest.leaf_trees,
|
|
43
|
+
var_trees=state.forest.var_trees,
|
|
44
|
+
split_trees=state.forest.split_trees,
|
|
45
|
+
offset=state.offset,
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def default_both_extractor(state: State) -> dict[str, Real[Array, 'samples *'] | None]:
|
|
50
|
+
"""Extract variables for main & burn-in traces, to be used in `run_mcmc`."""
|
|
51
|
+
return dict(
|
|
52
|
+
sigma2=state.sigma2,
|
|
53
|
+
grow_prop_count=state.forest.grow_prop_count,
|
|
54
|
+
grow_acc_count=state.forest.grow_acc_count,
|
|
55
|
+
prune_prop_count=state.forest.prune_prop_count,
|
|
56
|
+
prune_acc_count=state.forest.prune_acc_count,
|
|
57
|
+
log_likelihood=state.forest.log_likelihood,
|
|
58
|
+
log_trans_prior=state.forest.log_trans_prior,
|
|
59
|
+
)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def run_mcmc(
|
|
63
|
+
key,
|
|
64
|
+
bart,
|
|
65
|
+
n_save,
|
|
66
|
+
*,
|
|
67
|
+
n_burn=0,
|
|
68
|
+
n_skip=1,
|
|
69
|
+
inner_loop_length=None,
|
|
70
|
+
allow_overflow=False,
|
|
71
|
+
inner_callback=None,
|
|
72
|
+
outer_callback=None,
|
|
73
|
+
callback_state=None,
|
|
74
|
+
onlymain_extractor=default_onlymain_extractor,
|
|
75
|
+
both_extractor=default_both_extractor,
|
|
76
|
+
):
|
|
43
77
|
"""
|
|
44
78
|
Run the MCMC for the BART posterior.
|
|
45
79
|
|
|
46
80
|
Parameters
|
|
47
81
|
----------
|
|
82
|
+
key : jax.dtypes.prng_key array
|
|
83
|
+
A key for random number generation.
|
|
48
84
|
bart : dict
|
|
49
85
|
The initial MCMC state, as created and updated by the functions in
|
|
50
|
-
`bartz.mcmcstep`.
|
|
51
|
-
|
|
52
|
-
|
|
86
|
+
`bartz.mcmcstep`. The MCMC loop uses buffer donation to avoid copies,
|
|
87
|
+
so this variable is invalidated after running `run_mcmc`. Make a copy
|
|
88
|
+
beforehand to use it again.
|
|
53
89
|
n_save : int
|
|
54
90
|
The number of iterations to save.
|
|
55
|
-
|
|
91
|
+
n_burn : int, default 0
|
|
92
|
+
The number of initial iterations which are not saved.
|
|
93
|
+
n_skip : int, default 1
|
|
56
94
|
The number of iterations to skip between each saved iteration, plus 1.
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
95
|
+
The effective burn-in is ``n_burn + n_skip - 1``.
|
|
96
|
+
inner_loop_length : int, optional
|
|
97
|
+
The MCMC loop is split into an outer and an inner loop. The outer loop
|
|
98
|
+
is in Python, while the inner loop is in JAX. `inner_loop_length` is the
|
|
99
|
+
number of iterations of the inner loop to run for each iteration of the
|
|
100
|
+
outer loop. If not specified, the outer loop will iterate just once,
|
|
101
|
+
with all iterations done in a single inner loop run. The inner stride is
|
|
102
|
+
unrelated to the stride used for saving the trace.
|
|
103
|
+
allow_overflow : bool, default False
|
|
104
|
+
If `False`, `inner_loop_length` must be a divisor of the total number of
|
|
105
|
+
iterations ``n_burn + n_skip * n_save``. If `True` and
|
|
106
|
+
`inner_loop_length` is not a divisor, some of the MCMC iterations in the
|
|
107
|
+
last outer loop iteration will not be saved to the trace.
|
|
108
|
+
inner_callback : callable, optional
|
|
109
|
+
outer_callback : callable, optional
|
|
110
|
+
Arbitrary functions run during the loop after updating the state.
|
|
111
|
+
`inner_callback` is called after each update, while `outer_callback` is
|
|
112
|
+
called after completing an inner loop. The callbacks are invoked with
|
|
113
|
+
the following arguments, passed by keyword:
|
|
60
114
|
|
|
61
115
|
bart : dict
|
|
62
|
-
The
|
|
116
|
+
The MCMC state just after updating it.
|
|
63
117
|
burnin : bool
|
|
64
118
|
Whether the last iteration was in the burn-in phase.
|
|
119
|
+
overflow : bool
|
|
120
|
+
Whether the last iteration was in the overflow phase (iterations
|
|
121
|
+
not saved due to `inner_loop_length` not being a divisor of the
|
|
122
|
+
total number of iterations).
|
|
65
123
|
i_total : int
|
|
66
|
-
The index of the last iteration (0-based).
|
|
124
|
+
The index of the last MCMC iteration (0-based).
|
|
67
125
|
i_skip : int
|
|
68
|
-
The
|
|
69
|
-
|
|
126
|
+
The number of MCMC updates from the last saved state. The initial
|
|
127
|
+
state counts as saved, even if it's not copied into the trace.
|
|
128
|
+
callback_state : jax pytree
|
|
129
|
+
The callback state, initially set to the argument passed to
|
|
130
|
+
`run_mcmc`, afterwards to the value returned by the last invocation
|
|
131
|
+
of `inner_callback` or `outer_callback`.
|
|
70
132
|
n_burn, n_save, n_skip : int
|
|
71
133
|
The corresponding arguments as-is.
|
|
134
|
+
i_outer : int
|
|
135
|
+
The index of the last outer loop iteration (0-based).
|
|
136
|
+
inner_loop_length : int
|
|
137
|
+
The number of MCMC iterations in the inner loop.
|
|
72
138
|
|
|
73
|
-
|
|
74
|
-
available at the time the Python code is executed. Use the utilities
|
|
75
|
-
`jax.debug` to access the values at actual runtime.
|
|
76
|
-
|
|
77
|
-
The
|
|
139
|
+
`inner_callback` is called under the jax jit, so the argument values are
|
|
140
|
+
not available at the time the Python code is executed. Use the utilities
|
|
141
|
+
in `jax.debug` to access the values at actual runtime.
|
|
142
|
+
|
|
143
|
+
The callbacks must return two values:
|
|
144
|
+
|
|
145
|
+
bart : dict
|
|
146
|
+
A possibly modified MCMC state. To avoid modifying the state,
|
|
147
|
+
return the `bart` argument passed to the callback as-is.
|
|
148
|
+
callback_state : jax pytree
|
|
149
|
+
The new state to be passed on the next callback invocation.
|
|
150
|
+
|
|
151
|
+
For convenience, if a callback returns `None`, the states are not
|
|
152
|
+
updated.
|
|
153
|
+
callback_state : jax pytree, optional
|
|
154
|
+
The initial state for the callbacks.
|
|
155
|
+
onlymain_extractor : callable, optional
|
|
156
|
+
both_extractor : callable, optional
|
|
157
|
+
Functions that extract the variables to be saved respectively only in
|
|
158
|
+
the main trace and in both traces, given the MCMC state as argument.
|
|
159
|
+
Must return a pytree, and must be vmappable.
|
|
78
160
|
|
|
79
161
|
Returns
|
|
80
162
|
-------
|
|
81
163
|
bart : dict
|
|
82
164
|
The final MCMC state.
|
|
83
|
-
burnin_trace : dict
|
|
165
|
+
burnin_trace : dict of (n_burn, ...) arrays
|
|
84
166
|
The trace of the burn-in phase, containing the following subset of
|
|
85
167
|
fields from the `bart` dictionary, with an additional head index that
|
|
86
168
|
runs over MCMC iterations: 'sigma2', 'grow_prop_count',
|
|
87
|
-
'grow_acc_count', 'prune_prop_count', 'prune_acc_count'
|
|
88
|
-
|
|
169
|
+
'grow_acc_count', 'prune_prop_count', 'prune_acc_count' (or if specified
|
|
170
|
+
the fields in `tracevars_both`).
|
|
171
|
+
main_trace : dict of (n_save, ...) arrays
|
|
89
172
|
The trace of the main phase, containing the following subset of fields
|
|
90
|
-
from the `bart` dictionary, with an additional head index that runs
|
|
91
|
-
|
|
92
|
-
the fields in `
|
|
173
|
+
from the `bart` dictionary, with an additional head index that runs over
|
|
174
|
+
MCMC iterations: 'leaf_trees', 'var_trees', 'split_trees' (or if
|
|
175
|
+
specified the fields in `tracevars_onlymain`), plus the fields in
|
|
176
|
+
`burnin_trace`.
|
|
177
|
+
|
|
178
|
+
Raises
|
|
179
|
+
------
|
|
180
|
+
ValueError
|
|
181
|
+
If `inner_loop_length` is not a divisor of the total number of
|
|
182
|
+
iterations and `allow_overflow` is `False`.
|
|
183
|
+
|
|
184
|
+
Notes
|
|
185
|
+
-----
|
|
186
|
+
The number of MCMC updates is ``n_burn + n_skip * n_save``. The traces do
|
|
187
|
+
not include the initial state, and include the final state.
|
|
93
188
|
"""
|
|
94
189
|
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
190
|
+
def empty_trace(length, bart, extractor):
|
|
191
|
+
return jax.vmap(extractor, in_axes=None, out_axes=0, axis_size=length)(bart)
|
|
192
|
+
|
|
193
|
+
trace_both = empty_trace(n_burn + n_save, bart, both_extractor)
|
|
194
|
+
trace_onlymain = empty_trace(n_save, bart, onlymain_extractor)
|
|
195
|
+
|
|
196
|
+
# determine number of iterations for inner and outer loops
|
|
197
|
+
n_iters = n_burn + n_skip * n_save
|
|
198
|
+
if inner_loop_length is None:
|
|
199
|
+
inner_loop_length = n_iters
|
|
200
|
+
n_outer = n_iters // inner_loop_length
|
|
201
|
+
if n_iters % inner_loop_length:
|
|
202
|
+
if allow_overflow:
|
|
203
|
+
n_outer += 1
|
|
204
|
+
else:
|
|
205
|
+
raise ValueError(f'{n_iters=} is not divisible by {inner_loop_length=}')
|
|
206
|
+
|
|
207
|
+
carry = (bart, 0, key, trace_both, trace_onlymain, callback_state)
|
|
208
|
+
for i_outer in range(n_outer):
|
|
209
|
+
carry = _run_mcmc_inner_loop(
|
|
210
|
+
carry,
|
|
211
|
+
inner_loop_length,
|
|
212
|
+
inner_callback,
|
|
213
|
+
onlymain_extractor,
|
|
214
|
+
both_extractor,
|
|
215
|
+
n_burn,
|
|
216
|
+
n_save,
|
|
217
|
+
n_skip,
|
|
218
|
+
i_outer,
|
|
219
|
+
)
|
|
220
|
+
if outer_callback is not None:
|
|
221
|
+
bart, i_total, key, trace_both, trace_onlymain, callback_state = carry
|
|
222
|
+
i_total -= 1 # because i_total is updated at the end of the inner loop
|
|
223
|
+
i_skip = _compute_i_skip(i_total, n_burn, n_skip)
|
|
224
|
+
rt = outer_callback(
|
|
225
|
+
bart=bart,
|
|
226
|
+
burnin=i_total < n_burn,
|
|
227
|
+
overflow=i_total >= n_iters,
|
|
228
|
+
i_total=i_total,
|
|
229
|
+
i_skip=i_skip,
|
|
230
|
+
callback_state=callback_state,
|
|
231
|
+
n_burn=n_burn,
|
|
232
|
+
n_save=n_save,
|
|
233
|
+
n_skip=n_skip,
|
|
234
|
+
i_outer=i_outer,
|
|
235
|
+
inner_loop_length=inner_loop_length,
|
|
236
|
+
)
|
|
237
|
+
if rt is not None:
|
|
238
|
+
bart, callback_state = rt
|
|
239
|
+
i_total += 1
|
|
240
|
+
carry = (bart, i_total, key, trace_both, trace_onlymain, callback_state)
|
|
241
|
+
|
|
242
|
+
bart, _, _, trace_both, trace_onlymain, _ = carry
|
|
243
|
+
|
|
244
|
+
burnin_trace = tree.map(lambda x: x[:n_burn, ...], trace_both)
|
|
245
|
+
main_trace = tree.map(lambda x: x[n_burn:, ...], trace_both)
|
|
246
|
+
main_trace.update(trace_onlymain)
|
|
133
247
|
|
|
134
248
|
return bart, burnin_trace, main_trace
|
|
135
249
|
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
250
|
+
|
|
251
|
+
def _compute_i_skip(i_total, n_burn, n_skip):
|
|
252
|
+
burnin = i_total < n_burn
|
|
253
|
+
return jnp.where(
|
|
254
|
+
burnin,
|
|
255
|
+
i_total + 1,
|
|
256
|
+
(i_total + 1) % n_skip + jnp.where(i_total + 1 < n_skip, n_burn, 0),
|
|
257
|
+
)
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1, 2, 3, 4))
|
|
261
|
+
def _run_mcmc_inner_loop(
|
|
262
|
+
carry,
|
|
263
|
+
inner_loop_length,
|
|
264
|
+
inner_callback,
|
|
265
|
+
onlymain_extractor,
|
|
266
|
+
both_extractor,
|
|
267
|
+
n_burn,
|
|
268
|
+
n_save,
|
|
269
|
+
n_skip,
|
|
270
|
+
i_outer,
|
|
271
|
+
):
|
|
272
|
+
def loop(carry, _):
|
|
273
|
+
bart, i_total, key, trace_both, trace_onlymain, callback_state = carry
|
|
274
|
+
|
|
275
|
+
keys = jaxext.split(key)
|
|
276
|
+
key = keys.pop()
|
|
277
|
+
bart = mcmcstep.step(keys.pop(), bart)
|
|
278
|
+
|
|
279
|
+
burnin = i_total < n_burn
|
|
280
|
+
if inner_callback is not None:
|
|
281
|
+
i_skip = _compute_i_skip(i_total, n_burn, n_skip)
|
|
282
|
+
rt = inner_callback(
|
|
283
|
+
bart=bart,
|
|
284
|
+
burnin=burnin,
|
|
285
|
+
overflow=i_total >= n_burn + n_save * n_skip,
|
|
286
|
+
i_total=i_total,
|
|
287
|
+
i_skip=i_skip,
|
|
288
|
+
callback_state=callback_state,
|
|
289
|
+
n_burn=n_burn,
|
|
290
|
+
n_save=n_save,
|
|
291
|
+
n_skip=n_skip,
|
|
292
|
+
i_outer=i_outer,
|
|
293
|
+
inner_loop_length=inner_loop_length,
|
|
294
|
+
)
|
|
295
|
+
if rt is not None:
|
|
296
|
+
bart, callback_state = rt
|
|
297
|
+
|
|
298
|
+
i_onlymain = jnp.where(burnin, 0, (i_total - n_burn) // n_skip)
|
|
299
|
+
i_both = jnp.where(burnin, i_total, n_burn + i_onlymain)
|
|
300
|
+
|
|
301
|
+
def update_trace(index, trace, state):
|
|
302
|
+
def assign_at_index(trace_array, state_array):
|
|
303
|
+
if trace_array.size:
|
|
304
|
+
return trace_array.at[index, ...].set(state_array)
|
|
305
|
+
else:
|
|
306
|
+
# this handles the case where a trace is empty (e.g.,
|
|
307
|
+
# no burn-in) because jax refuses to index into an array
|
|
308
|
+
# of length 0
|
|
309
|
+
return trace_array
|
|
310
|
+
|
|
311
|
+
return tree.map(assign_at_index, trace, state)
|
|
312
|
+
|
|
313
|
+
trace_onlymain = update_trace(
|
|
314
|
+
i_onlymain, trace_onlymain, onlymain_extractor(bart)
|
|
315
|
+
)
|
|
316
|
+
trace_both = update_trace(i_both, trace_both, both_extractor(bart))
|
|
317
|
+
|
|
318
|
+
i_total += 1
|
|
319
|
+
carry = (bart, i_total, key, trace_both, trace_onlymain, callback_state)
|
|
320
|
+
return carry, None
|
|
321
|
+
|
|
322
|
+
carry, _ = lax.scan(loop, carry, None, inner_loop_length)
|
|
323
|
+
return carry
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
def make_print_callbacks(dot_every_inner=1, report_every_outer=1):
|
|
140
327
|
"""
|
|
141
|
-
|
|
328
|
+
Prepare logging callbacks for `run_mcmc`.
|
|
329
|
+
|
|
330
|
+
Prepare callbacks which print a dot on every iteration, and a longer
|
|
331
|
+
report outer loop iteration.
|
|
142
332
|
|
|
143
333
|
Parameters
|
|
144
334
|
----------
|
|
145
|
-
|
|
146
|
-
|
|
335
|
+
dot_every_inner : int, default 1
|
|
336
|
+
A dot is printed every `dot_every_inner` MCMC iterations.
|
|
337
|
+
report_every_outer : int, default 1
|
|
338
|
+
A report is printed every `report_every_outer` outer loop
|
|
339
|
+
iterations.
|
|
147
340
|
|
|
148
341
|
Returns
|
|
149
342
|
-------
|
|
150
|
-
|
|
151
|
-
A
|
|
343
|
+
kwargs : dict
|
|
344
|
+
A dictionary with the arguments to pass to `run_mcmc` as keyword
|
|
345
|
+
arguments to set up the callbacks.
|
|
346
|
+
|
|
347
|
+
Examples
|
|
348
|
+
--------
|
|
349
|
+
>>> run_mcmc(..., **make_print_callbacks())
|
|
152
350
|
"""
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
351
|
+
return dict(
|
|
352
|
+
inner_callback=_print_callback_inner,
|
|
353
|
+
outer_callback=_print_callback_outer,
|
|
354
|
+
callback_state=dict(
|
|
355
|
+
dot_every_inner=dot_every_inner, report_every_outer=report_every_outer
|
|
356
|
+
),
|
|
357
|
+
)
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
def _print_callback_inner(*, i_total, callback_state, **_):
|
|
361
|
+
dot_every_inner = callback_state['dot_every_inner']
|
|
362
|
+
if dot_every_inner is not None:
|
|
363
|
+
cond = (i_total + 1) % dot_every_inner == 0
|
|
364
|
+
debug.callback(_print_dot, cond)
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
def _print_dot(cond):
|
|
368
|
+
if cond:
|
|
369
|
+
print('.', end='', flush=True)
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
def _print_callback_outer(
|
|
373
|
+
*,
|
|
374
|
+
bart,
|
|
375
|
+
burnin,
|
|
376
|
+
overflow,
|
|
377
|
+
i_total,
|
|
378
|
+
n_burn,
|
|
379
|
+
n_save,
|
|
380
|
+
n_skip,
|
|
381
|
+
callback_state,
|
|
382
|
+
i_outer,
|
|
383
|
+
inner_loop_length,
|
|
384
|
+
**_,
|
|
385
|
+
):
|
|
386
|
+
report_every_outer = callback_state['report_every_outer']
|
|
387
|
+
if report_every_outer is not None:
|
|
388
|
+
dot_every_inner = callback_state['dot_every_inner']
|
|
389
|
+
if dot_every_inner is None:
|
|
390
|
+
newline = False
|
|
391
|
+
else:
|
|
392
|
+
newline = dot_every_inner < inner_loop_length
|
|
393
|
+
debug.callback(
|
|
394
|
+
_print_report,
|
|
395
|
+
cond=(i_outer + 1) % report_every_outer == 0,
|
|
396
|
+
newline=newline,
|
|
397
|
+
burnin=burnin,
|
|
398
|
+
overflow=overflow,
|
|
399
|
+
i_total=i_total,
|
|
400
|
+
n_iters=n_burn + n_save * n_skip,
|
|
401
|
+
grow_prop_count=bart.forest.grow_prop_count,
|
|
402
|
+
grow_acc_count=bart.forest.grow_acc_count,
|
|
403
|
+
prune_prop_count=bart.forest.prune_prop_count,
|
|
404
|
+
prune_acc_count=bart.forest.prune_acc_count,
|
|
405
|
+
prop_total=len(bart.forest.leaf_trees),
|
|
406
|
+
fill=grove.forest_fill(bart.forest.split_trees),
|
|
407
|
+
)
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
def _convert_jax_arrays_in_args(func):
|
|
411
|
+
"""Remove jax arrays from a function arguments.
|
|
412
|
+
|
|
413
|
+
Converts all jax.Array instances in the arguments to either Python scalars
|
|
414
|
+
or numpy arrays.
|
|
415
|
+
"""
|
|
416
|
+
|
|
417
|
+
def convert_jax_arrays(pytree):
|
|
418
|
+
def convert_jax_arrays(val):
|
|
419
|
+
if not isinstance(val, jax.Array):
|
|
420
|
+
return val
|
|
421
|
+
elif val.shape:
|
|
422
|
+
return numpy.array(val)
|
|
423
|
+
else:
|
|
424
|
+
return val.item()
|
|
425
|
+
|
|
426
|
+
return tree.map(convert_jax_arrays, pytree)
|
|
427
|
+
|
|
428
|
+
@functools.wraps(func)
|
|
429
|
+
def new_func(*args, **kw):
|
|
430
|
+
args = convert_jax_arrays(args)
|
|
431
|
+
kw = convert_jax_arrays(kw)
|
|
432
|
+
return func(*args, **kw)
|
|
433
|
+
|
|
434
|
+
return new_func
|
|
435
|
+
|
|
436
|
+
|
|
437
|
+
@_convert_jax_arrays_in_args
|
|
438
|
+
# convert all jax arrays in arguments because operations on them could lead to
|
|
439
|
+
# deadlock with the main thread
|
|
440
|
+
def _print_report(
|
|
441
|
+
*,
|
|
442
|
+
cond,
|
|
443
|
+
newline,
|
|
444
|
+
burnin,
|
|
445
|
+
overflow,
|
|
446
|
+
i_total,
|
|
447
|
+
n_iters,
|
|
448
|
+
grow_prop_count,
|
|
449
|
+
grow_acc_count,
|
|
450
|
+
prune_prop_count,
|
|
451
|
+
prune_acc_count,
|
|
452
|
+
prop_total,
|
|
453
|
+
fill,
|
|
454
|
+
):
|
|
455
|
+
if cond:
|
|
456
|
+
newline = '\n' if newline else ''
|
|
457
|
+
|
|
458
|
+
def acc_string(acc_count, prop_count):
|
|
459
|
+
if prop_count:
|
|
460
|
+
return f'{acc_count / prop_count:.0%}'
|
|
461
|
+
else:
|
|
462
|
+
return ' n/d'
|
|
463
|
+
|
|
464
|
+
grow_prop = grow_prop_count / prop_total
|
|
465
|
+
prune_prop = prune_prop_count / prop_total
|
|
466
|
+
grow_acc = acc_string(grow_acc_count, grow_prop_count)
|
|
467
|
+
prune_acc = acc_string(prune_acc_count, prune_prop_count)
|
|
468
|
+
|
|
469
|
+
if burnin:
|
|
470
|
+
flag = ' (burnin)'
|
|
471
|
+
elif overflow:
|
|
472
|
+
flag = ' (overflow)'
|
|
473
|
+
else:
|
|
474
|
+
flag = ''
|
|
475
|
+
|
|
476
|
+
print(
|
|
477
|
+
f'{newline}It {i_total + 1}/{n_iters} '
|
|
478
|
+
f'grow P={grow_prop:.0%} A={grow_acc}, '
|
|
479
|
+
f'prune P={prune_prop:.0%} A={prune_acc}, '
|
|
480
|
+
f'fill={fill:.0%}{flag}'
|
|
481
|
+
)
|
|
482
|
+
|
|
173
483
|
|
|
174
484
|
@jax.jit
|
|
175
485
|
def evaluate_trace(trace, X):
|
|
@@ -189,9 +499,13 @@ def evaluate_trace(trace, X):
|
|
|
189
499
|
The predictions for each iteration of the MCMC.
|
|
190
500
|
"""
|
|
191
501
|
evaluate_trees = functools.partial(grove.evaluate_forest, sum_trees=False)
|
|
192
|
-
evaluate_trees = jaxext.autobatch(evaluate_trees, 2
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
502
|
+
evaluate_trees = jaxext.autobatch(evaluate_trees, 2**29, (None, 0, 0, 0))
|
|
503
|
+
|
|
504
|
+
def loop(_, row):
|
|
505
|
+
values = evaluate_trees(
|
|
506
|
+
X, row['leaf_trees'], row['var_trees'], row['split_trees']
|
|
507
|
+
)
|
|
508
|
+
return None, row['offset'] + jnp.sum(values, axis=0, dtype=jnp.float32)
|
|
509
|
+
|
|
196
510
|
_, y = lax.scan(loop, None, trace)
|
|
197
511
|
return y
|