bartz 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/BART.py +196 -103
- bartz/__init__.py +1 -1
- bartz/_version.py +1 -1
- bartz/debug.py +1 -1
- bartz/grove.py +43 -2
- bartz/jaxext.py +82 -33
- bartz/mcmcloop.py +367 -114
- bartz/mcmcstep.py +1322 -807
- bartz/prepcovars.py +3 -1
- {bartz-0.5.0.dist-info → bartz-0.6.0.dist-info}/METADATA +7 -5
- bartz-0.6.0.dist-info/RECORD +13 -0
- {bartz-0.5.0.dist-info → bartz-0.6.0.dist-info}/WHEEL +1 -1
- bartz-0.5.0.dist-info/RECORD +0 -13
bartz/mcmcloop.py
CHANGED
|
@@ -22,57 +22,141 @@
|
|
|
22
22
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
23
23
|
# SOFTWARE.
|
|
24
24
|
|
|
25
|
-
"""
|
|
26
|
-
Functions that implement the full BART posterior MCMC loop.
|
|
27
|
-
"""
|
|
25
|
+
"""Functions that implement the full BART posterior MCMC loop."""
|
|
28
26
|
|
|
29
27
|
import functools
|
|
30
28
|
|
|
31
29
|
import jax
|
|
32
|
-
|
|
30
|
+
import numpy
|
|
31
|
+
from jax import debug, lax, tree
|
|
33
32
|
from jax import numpy as jnp
|
|
33
|
+
from jaxtyping import Array, Real
|
|
34
34
|
|
|
35
35
|
from . import grove, jaxext, mcmcstep
|
|
36
|
+
from .mcmcstep import State
|
|
36
37
|
|
|
37
38
|
|
|
38
|
-
|
|
39
|
-
|
|
39
|
+
def default_onlymain_extractor(state: State) -> dict[str, Real[Array, 'samples *']]:
    """Pick the variables that `run_mcmc` saves only in the main trace.

    Parameters
    ----------
    state : State
        The current MCMC state.

    Returns
    -------
    dict
        The forest arrays ``leaf_trees``, ``var_trees`` and ``split_trees``,
        plus the ``offset``, taken as-is from `state`.
    """
    forest = state.forest
    return {
        'leaf_trees': forest.leaf_trees,
        'var_trees': forest.var_trees,
        'split_trees': forest.split_trees,
        'offset': state.offset,
    }
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def default_both_extractor(state: State) -> dict[str, Real[Array, 'samples *'] | None]:
    """Pick the variables that `run_mcmc` saves in both burn-in and main traces.

    Parameters
    ----------
    state : State
        The current MCMC state.

    Returns
    -------
    dict
        ``sigma2`` from `state`, and the move counters
        (``grow_prop_count``, ``grow_acc_count``, ``prune_prop_count``,
        ``prune_acc_count``) plus ``log_likelihood`` and ``log_trans_prior``
        from ``state.forest``.
    """
    forest = state.forest
    return {
        'sigma2': state.sigma2,
        'grow_prop_count': forest.grow_prop_count,
        'grow_acc_count': forest.grow_acc_count,
        'prune_prop_count': forest.prune_prop_count,
        'prune_acc_count': forest.prune_acc_count,
        'log_likelihood': forest.log_likelihood,
        'log_trans_prior': forest.log_trans_prior,
    }
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def run_mcmc(
    key,
    bart,
    n_save,
    *,
    n_burn=0,
    n_skip=1,
    inner_loop_length=None,
    allow_overflow=False,
    inner_callback=None,
    outer_callback=None,
    callback_state=None,
    onlymain_extractor=default_onlymain_extractor,
    both_extractor=default_both_extractor,
):
    """
    Run the MCMC for the BART posterior.

    Parameters
    ----------
    key : jax.dtypes.prng_key array
        A key for random number generation.
    bart : State
        The initial MCMC state, as created and updated by the functions in
        `bartz.mcmcstep`. The MCMC loop uses buffer donation to avoid copies,
        so this variable is invalidated after running `run_mcmc`. Make a copy
        beforehand to use it again.
    n_save : int
        The number of iterations to save.
    n_burn : int, default 0
        The number of initial iterations which are not saved.
    n_skip : int, default 1
        The number of iterations to skip between each saved iteration, plus 1.
        The effective burn-in is ``n_burn + n_skip - 1``.
    inner_loop_length : int, optional
        The MCMC loop is split into an outer and an inner loop. The outer loop
        is in Python, while the inner loop is in JAX. `inner_loop_length` is the
        number of iterations of the inner loop to run for each iteration of the
        outer loop. If not specified, the outer loop will iterate just once,
        with all iterations done in a single inner loop run. The inner stride is
        unrelated to the stride used for saving the trace. If specified, it
        must be positive.
    allow_overflow : bool, default False
        If `False`, `inner_loop_length` must be a divisor of the total number of
        iterations ``n_burn + n_skip * n_save``. If `True` and
        `inner_loop_length` is not a divisor, some of the MCMC iterations in the
        last outer loop iteration will not be saved to the trace.
    inner_callback : callable, optional
    outer_callback : callable, optional
        Arbitrary functions run during the loop after updating the state.
        `inner_callback` is called after each update, while `outer_callback` is
        called after completing an inner loop. The callbacks are invoked with
        the following arguments, passed by keyword:

        bart : State
            The MCMC state just after updating it.
        burnin : bool
            Whether the last iteration was in the burn-in phase.
        overflow : bool
            Whether the last iteration was in the overflow phase (iterations
            not saved due to `inner_loop_length` not being a divisor of the
            total number of iterations).
        i_total : int
            The index of the last MCMC iteration (0-based).
        i_skip : int
            The number of MCMC updates from the last saved state. The initial
            state counts as saved, even if it's not copied into the trace.
        callback_state : jax pytree
            The callback state, initially set to the argument passed to
            `run_mcmc`, afterwards to the value returned by the last invocation
            of `inner_callback` or `outer_callback`.
        n_burn, n_save, n_skip : int
            The corresponding arguments as-is.
        i_outer : int
            The index of the last outer loop iteration (0-based).
        inner_loop_length : int
            The number of MCMC iterations in the inner loop.

        `inner_callback` is called under the jax jit, so the argument values are
        not available at the time the Python code is executed. Use the utilities
        in `jax.debug` to access the values at actual runtime.

        The callbacks must return two values:

        bart : State
            A possibly modified MCMC state. To avoid modifying the state,
            return the `bart` argument passed to the callback as-is.
        callback_state : jax pytree
            The new state to be passed on the next callback invocation.

        For convenience, if a callback returns `None`, the states are not
        updated.
    callback_state : jax pytree, optional
        The initial state for the callbacks.
    onlymain_extractor : callable, optional
    both_extractor : callable, optional
        Functions that extract the variables to be saved respectively only in
        the main trace and in both traces, given the MCMC state as argument.
        Each must return a dict of arrays (the main trace is formed by merging
        the two dicts), and must be vmappable.

    Returns
    -------
    bart : State
        The final MCMC state.
    burnin_trace : dict of (n_burn, ...) arrays
        The trace of the burn-in phase, with an additional head index that
        runs over MCMC iterations. By default it contains 'sigma2',
        'grow_prop_count', 'grow_acc_count', 'prune_prop_count',
        'prune_acc_count', 'log_likelihood' and 'log_trans_prior'; in
        general, the fields returned by `both_extractor`.
    main_trace : dict of (n_save, ...) arrays
        The trace of the main phase, with an additional head index that runs
        over MCMC iterations. By default it contains 'leaf_trees',
        'var_trees', 'split_trees' and 'offset' (in general, the fields
        returned by `onlymain_extractor`), plus the fields in `burnin_trace`.

    Raises
    ------
    ValueError
        If `inner_loop_length` is not positive, or if it is not a divisor of
        the total number of iterations and `allow_overflow` is `False`.

    Notes
    -----
    The number of MCMC updates is ``n_burn + n_skip * n_save`` (rounded up to
    a multiple of `inner_loop_length` if `allow_overflow`). The traces do
    not include the initial state, and include the final state.
    """

    def empty_trace(length, bart, extractor):
        # broadcast the extracted variables along a new leading axis of size
        # `length`; the content is overwritten by the loop
        return jax.vmap(extractor, in_axes=None, out_axes=0, axis_size=length)(bart)

    trace_both = empty_trace(n_burn + n_save, bart, both_extractor)
    trace_onlymain = empty_trace(n_save, bart, onlymain_extractor)

    # determine number of iterations for inner and outer loops
    n_iters = n_burn + n_skip * n_save
    if inner_loop_length is None:
        # a single outer iteration; `max` avoids dividing by zero below when
        # there are no iterations to run at all
        inner_loop_length = max(n_iters, 1)
    elif inner_loop_length <= 0:
        raise ValueError(f'{inner_loop_length=} must be positive')
    n_outer = n_iters // inner_loop_length
    if n_iters % inner_loop_length:
        if allow_overflow:
            n_outer += 1
        else:
            raise ValueError(f'{n_iters=} is not divisible by {inner_loop_length=}')

    carry = (bart, 0, key, trace_both, trace_onlymain, callback_state)
    for i_outer in range(n_outer):
        carry = _run_mcmc_inner_loop(
            carry,
            inner_loop_length,
            inner_callback,
            onlymain_extractor,
            both_extractor,
            n_burn,
            n_save,
            n_skip,
            i_outer,
        )
        if outer_callback is not None:
            bart, i_total, key, trace_both, trace_onlymain, callback_state = carry
            i_total -= 1  # because i_total is updated at the end of the inner loop
            i_skip = _compute_i_skip(i_total, n_burn, n_skip)
            rt = outer_callback(
                bart=bart,
                burnin=i_total < n_burn,
                overflow=i_total >= n_iters,
                i_total=i_total,
                i_skip=i_skip,
                callback_state=callback_state,
                n_burn=n_burn,
                n_save=n_save,
                n_skip=n_skip,
                i_outer=i_outer,
                inner_loop_length=inner_loop_length,
            )
            if rt is not None:
                bart, callback_state = rt
            i_total += 1
            carry = (bart, i_total, key, trace_both, trace_onlymain, callback_state)

    bart, _, _, trace_both, trace_onlymain, _ = carry

    # the first n_burn rows of the shared trace belong to the burn-in phase
    burnin_trace = tree.map(lambda x: x[:n_burn, ...], trace_both)
    main_trace = tree.map(lambda x: x[n_burn:, ...], trace_both)
    main_trace.update(trace_onlymain)

    return bart, burnin_trace, main_trace
|
|
111
249
|
|
|
112
|
-
trace_light = empty_trace(n_burn + n_save, bart, tracevars_light)
|
|
113
|
-
trace_heavy = empty_trace(n_save, bart, tracevars_heavy)
|
|
114
250
|
|
|
115
|
-
|
|
251
|
+
def _compute_i_skip(i_total, n_burn, n_skip):
|
|
252
|
+
burnin = i_total < n_burn
|
|
253
|
+
return jnp.where(
|
|
254
|
+
burnin,
|
|
255
|
+
i_total + 1,
|
|
256
|
+
(i_total + 1) % n_skip + jnp.where(i_total + 1 < n_skip, n_burn, 0),
|
|
257
|
+
)
|
|
116
258
|
|
|
117
|
-
carry = (bart, 0, key, trace_light, trace_heavy)
|
|
118
259
|
|
|
260
|
+
@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1, 2, 3, 4))
def _run_mcmc_inner_loop(
    carry,
    inner_loop_length,
    inner_callback,
    onlymain_extractor,
    both_extractor,
    n_burn,
    n_save,
    n_skip,
    i_outer,
):
    """
    Run one inner loop of `run_mcmc` under jit.

    Parameters
    ----------
    carry : tuple
        ``(bart, i_total, key, trace_both, trace_onlymain, callback_state)``,
        where `i_total` is the index of the next MCMC iteration to run. The
        carry is donated, so the caller must not use it after the call.
    inner_loop_length : int
        Number of MCMC iterations to run. Static argument.
    inner_callback : callable or None
        See `run_mcmc`. Static argument, invoked while tracing.
    onlymain_extractor, both_extractor : callable
        See `run_mcmc`. Static arguments.
    n_burn, n_save, n_skip : int
        See `run_mcmc`. Traced arguments.
    i_outer : int
        Index of the current outer loop iteration, forwarded to the callback.

    Returns
    -------
    carry : tuple
        The updated carry, with `i_total` advanced by `inner_loop_length`.
    """

    def loop(carry, _):
        bart, i_total, key, trace_both, trace_onlymain, callback_state = carry

        # one key is carried forward, the other is consumed by the MCMC step
        keys = jaxext.split(key)
        key = keys.pop()
        bart = mcmcstep.step(keys.pop(), bart)

        burnin = i_total < n_burn
        if inner_callback is not None:
            i_skip = _compute_i_skip(i_total, n_burn, n_skip)
            rt = inner_callback(
                bart=bart,
                burnin=burnin,
                overflow=i_total >= n_burn + n_save * n_skip,
                i_total=i_total,
                i_skip=i_skip,
                callback_state=callback_state,
                n_burn=n_burn,
                n_save=n_save,
                n_skip=n_skip,
                i_outer=i_outer,
                inner_loop_length=inner_loop_length,
            )
            if rt is not None:
                bart, callback_state = rt

        # Rows of the traces to write. In the main phase, all the iterations
        # between two saves write the same row, so the row ends up holding the
        # last of them, which is the one meant to be saved. During burn-in,
        # row 0 of the main-only trace is written and later overwritten by
        # the main phase. In the overflow phase, the rows are past the end of
        # the traces.
        i_onlymain = jnp.where(burnin, 0, (i_total - n_burn) // n_skip)
        i_both = jnp.where(burnin, i_total, n_burn + i_onlymain)

        def update_trace(index, trace, state):
            def assign_at_index(trace_array, state_array):
                if trace_array.size:
                    # mode='drop' turns out-of-bounds writes (overflow
                    # iterations) into explicit no-ops instead of relying on
                    # the default indexing mode
                    return trace_array.at[index, ...].set(state_array, mode='drop')
                else:
                    # an empty trace (e.g., no burn-in) is left as is,
                    # because jax does not allow indexing into an array
                    # of length 0
                    return trace_array

            return tree.map(assign_at_index, trace, state)

        trace_onlymain = update_trace(
            i_onlymain, trace_onlymain, onlymain_extractor(bart)
        )
        trace_both = update_trace(i_both, trace_both, both_extractor(bart))

        i_total += 1
        carry = (bart, i_total, key, trace_both, trace_onlymain, callback_state)
        return carry, None

    carry, _ = lax.scan(loop, carry, None, inner_loop_length)
    return carry
|
|
162
324
|
|
|
163
|
-
burnin_trace = tree.map(lambda x: x[:n_burn, ...], trace_light)
|
|
164
|
-
main_trace = tree.map(lambda x: x[n_burn:, ...], trace_light)
|
|
165
|
-
main_trace.update(trace_heavy)
|
|
166
325
|
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
@functools.lru_cache
|
|
171
|
-
# cache to make the callback function object unique, such that the jit
|
|
172
|
-
# of run_mcmc recognizes it
|
|
173
|
-
def make_simple_print_callback(printevery):
|
|
326
|
+
def make_print_callbacks(dot_every_inner=1, report_every_outer=1):
    """
    Prepare logging callbacks for `run_mcmc`.

    Prepare callbacks which print a dot every `dot_every_inner` MCMC
    iterations, and a longer report every `report_every_outer` outer loop
    iterations.

    Parameters
    ----------
    dot_every_inner : int or None, default 1
        A dot is printed every `dot_every_inner` MCMC iterations. If `None`,
        no dots are printed.
    report_every_outer : int or None, default 1
        A report is printed every `report_every_outer` outer loop
        iterations. If `None`, no reports are printed.

    Returns
    -------
    kwargs : dict
        A dictionary with the arguments to pass to `run_mcmc` as keyword
        arguments to set up the callbacks.

    Raises
    ------
    ValueError
        If `dot_every_inner` or `report_every_outer` is neither `None` nor
        positive (the values are used as moduli).

    Examples
    --------
    >>> run_mcmc(..., **make_print_callbacks())  # doctest: +SKIP
    """
    for name, every in (
        ('dot_every_inner', dot_every_inner),
        ('report_every_outer', report_every_outer),
    ):
        if every is not None and every <= 0:
            raise ValueError(f'{name}={every!r} must be positive or None')

    return dict(
        inner_callback=_print_callback_inner,
        outer_callback=_print_callback_outer,
        callback_state=dict(
            dot_every_inner=dot_every_inner, report_every_outer=report_every_outer
        ),
    )
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
def _print_callback_inner(*, i_total, callback_state, **_):
|
|
361
|
+
dot_every_inner = callback_state['dot_every_inner']
|
|
362
|
+
if dot_every_inner is not None:
|
|
363
|
+
cond = (i_total + 1) % dot_every_inner == 0
|
|
364
|
+
debug.callback(_print_dot, cond)
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
def _print_dot(cond):
|
|
368
|
+
if cond:
|
|
369
|
+
print('.', end='', flush=True)
|
|
187
370
|
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
371
|
+
|
|
372
|
+
def _print_callback_outer(
|
|
373
|
+
*,
|
|
374
|
+
bart,
|
|
375
|
+
burnin,
|
|
376
|
+
overflow,
|
|
377
|
+
i_total,
|
|
378
|
+
n_burn,
|
|
379
|
+
n_save,
|
|
380
|
+
n_skip,
|
|
381
|
+
callback_state,
|
|
382
|
+
i_outer,
|
|
383
|
+
inner_loop_length,
|
|
384
|
+
**_,
|
|
385
|
+
):
|
|
386
|
+
report_every_outer = callback_state['report_every_outer']
|
|
387
|
+
if report_every_outer is not None:
|
|
388
|
+
dot_every_inner = callback_state['dot_every_inner']
|
|
389
|
+
if dot_every_inner is None:
|
|
390
|
+
newline = False
|
|
391
|
+
else:
|
|
392
|
+
newline = dot_every_inner < inner_loop_length
|
|
196
393
|
debug.callback(
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
394
|
+
_print_report,
|
|
395
|
+
cond=(i_outer + 1) % report_every_outer == 0,
|
|
396
|
+
newline=newline,
|
|
397
|
+
burnin=burnin,
|
|
398
|
+
overflow=overflow,
|
|
399
|
+
i_total=i_total,
|
|
400
|
+
n_iters=n_burn + n_save * n_skip,
|
|
401
|
+
grow_prop_count=bart.forest.grow_prop_count,
|
|
402
|
+
grow_acc_count=bart.forest.grow_acc_count,
|
|
403
|
+
prune_prop_count=bart.forest.prune_prop_count,
|
|
404
|
+
prune_acc_count=bart.forest.prune_acc_count,
|
|
405
|
+
prop_total=len(bart.forest.leaf_trees),
|
|
406
|
+
fill=grove.forest_fill(bart.forest.split_trees),
|
|
206
407
|
)
|
|
207
408
|
|
|
208
|
-
return callback
|
|
209
409
|
|
|
410
|
+
def _convert_jax_arrays_in_args(func):
|
|
411
|
+
"""Remove jax arrays from a function arguments.
|
|
210
412
|
|
|
211
|
-
|
|
212
|
-
|
|
413
|
+
Converts all jax.Array instances in the arguments to either Python scalars
|
|
414
|
+
or numpy arrays.
|
|
415
|
+
"""
|
|
416
|
+
|
|
417
|
+
def convert_jax_arrays(pytree):
|
|
418
|
+
def convert_jax_arrays(val):
|
|
419
|
+
if not isinstance(val, jax.Array):
|
|
420
|
+
return val
|
|
421
|
+
elif val.shape:
|
|
422
|
+
return numpy.array(val)
|
|
423
|
+
else:
|
|
424
|
+
return val.item()
|
|
425
|
+
|
|
426
|
+
return tree.map(convert_jax_arrays, pytree)
|
|
427
|
+
|
|
428
|
+
@functools.wraps(func)
|
|
429
|
+
def new_func(*args, **kw):
|
|
430
|
+
args = convert_jax_arrays(args)
|
|
431
|
+
kw = convert_jax_arrays(kw)
|
|
432
|
+
return func(*args, **kw)
|
|
433
|
+
|
|
434
|
+
return new_func
|
|
435
|
+
|
|
436
|
+
|
|
437
|
+
@_convert_jax_arrays_in_args
# convert all jax arrays in arguments because operations on them could lead to
# deadlock with the main thread
def _print_report(
    *,
    cond,
    newline,
    burnin,
    overflow,
    i_total,
    n_iters,
    grow_prop_count,
    grow_acc_count,
    prune_prop_count,
    prune_acc_count,
    prop_total,
    fill,
):
    """Print a one-line MCMC progress report, if `cond` is true.

    The line shows the iteration count, the fraction of trees for which a
    grow/prune move was proposed (over `prop_total`) and its acceptance
    rate, and the forest `fill`. It is optionally preceded by a newline and
    followed by a burn-in/overflow flag.
    """
    if not cond:
        return

    def rate(accepted, proposed):
        # the acceptance rate is undefined when nothing was proposed
        if not proposed:
            return ' n/d'
        return f'{accepted / proposed:.0%}'

    grow_frac = grow_prop_count / prop_total
    prune_frac = prune_prop_count / prop_total
    grow_rate = rate(grow_acc_count, grow_prop_count)
    prune_rate = rate(prune_acc_count, prune_prop_count)

    lead = '\n' if newline else ''
    if burnin:
        tag = ' (burnin)'
    elif overflow:
        tag = ' (overflow)'
    else:
        tag = ''

    print(
        f'{lead}It {i_total + 1}/{n_iters} '
        f'grow P={grow_frac:.0%} A={grow_rate}, '
        f'prune P={prune_frac:.0%} A={prune_rate}, '
        f'fill={fill:.0%}{tag}'
    )
|
|
229
482
|
|
|
230
483
|
|
|
@@ -248,11 +501,11 @@ def evaluate_trace(trace, X):
|
|
|
248
501
|
evaluate_trees = functools.partial(grove.evaluate_forest, sum_trees=False)
|
|
249
502
|
evaluate_trees = jaxext.autobatch(evaluate_trees, 2**29, (None, 0, 0, 0))
|
|
250
503
|
|
|
251
|
-
def loop(_,
|
|
504
|
+
def loop(_, row):
|
|
252
505
|
values = evaluate_trees(
|
|
253
|
-
X,
|
|
506
|
+
X, row['leaf_trees'], row['var_trees'], row['split_trees']
|
|
254
507
|
)
|
|
255
|
-
return None, jnp.sum(values, axis=0, dtype=jnp.float32)
|
|
508
|
+
return None, row['offset'] + jnp.sum(values, axis=0, dtype=jnp.float32)
|
|
256
509
|
|
|
257
510
|
_, y = lax.scan(loop, None, trace)
|
|
258
511
|
return y
|