bartz 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/BART.py +582 -279
- bartz/__init__.py +3 -3
- bartz/_version.py +1 -1
- bartz/debug.py +1259 -79
- bartz/grove.py +168 -81
- bartz/jaxext/__init__.py +213 -0
- bartz/jaxext/_autobatch.py +238 -0
- bartz/jaxext/scipy/__init__.py +25 -0
- bartz/jaxext/scipy/special.py +240 -0
- bartz/jaxext/scipy/stats.py +36 -0
- bartz/mcmcloop.py +568 -158
- bartz/mcmcstep.py +1722 -926
- bartz/prepcovars.py +142 -44
- {bartz-0.5.0.dist-info → bartz-0.7.0.dist-info}/METADATA +6 -5
- bartz-0.7.0.dist-info/RECORD +17 -0
- {bartz-0.5.0.dist-info → bartz-0.7.0.dist-info}/WHEEL +1 -1
- bartz/jaxext.py +0 -374
- bartz-0.5.0.dist-info/RECORD +0 -13
bartz/mcmcloop.py
CHANGED
|
@@ -22,72 +22,222 @@
|
|
|
22
22
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
23
23
|
# SOFTWARE.
|
|
24
24
|
|
|
25
|
-
"""
|
|
26
|
-
|
|
25
|
+
"""Functions that implement the full BART posterior MCMC loop.
|
|
26
|
+
|
|
27
|
+
The entry points are `run_mcmc` and `make_default_callback`.
|
|
27
28
|
"""
|
|
28
29
|
|
|
29
|
-
import
|
|
30
|
+
from collections.abc import Callable
|
|
31
|
+
from dataclasses import fields, replace
|
|
32
|
+
from functools import partial, wraps
|
|
33
|
+
from typing import Any, Protocol
|
|
30
34
|
|
|
31
35
|
import jax
|
|
32
|
-
|
|
36
|
+
import numpy
|
|
37
|
+
from equinox import Module
|
|
38
|
+
from jax import debug, lax, tree
|
|
33
39
|
from jax import numpy as jnp
|
|
40
|
+
from jax.nn import softmax
|
|
41
|
+
from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, PyTree, Shaped, UInt
|
|
42
|
+
|
|
43
|
+
from bartz import grove, jaxext, mcmcstep
|
|
44
|
+
from bartz.mcmcstep import State
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class BurninTrace(Module):
|
|
48
|
+
"""MCMC trace with only diagnostic values."""
|
|
49
|
+
|
|
50
|
+
sigma2: Float32[Array, '*trace_length'] | None
|
|
51
|
+
theta: Float32[Array, '*trace_length'] | None
|
|
52
|
+
grow_prop_count: Int32[Array, '*trace_length']
|
|
53
|
+
grow_acc_count: Int32[Array, '*trace_length']
|
|
54
|
+
prune_prop_count: Int32[Array, '*trace_length']
|
|
55
|
+
prune_acc_count: Int32[Array, '*trace_length']
|
|
56
|
+
log_likelihood: Float32[Array, '*trace_length'] | None
|
|
57
|
+
log_trans_prior: Float32[Array, '*trace_length'] | None
|
|
58
|
+
|
|
59
|
+
@classmethod
|
|
60
|
+
def from_state(cls, state: State) -> 'BurninTrace':
|
|
61
|
+
"""Create a single-item burn-in trace from a MCMC state."""
|
|
62
|
+
return cls(
|
|
63
|
+
sigma2=state.sigma2,
|
|
64
|
+
theta=state.forest.theta,
|
|
65
|
+
grow_prop_count=state.forest.grow_prop_count,
|
|
66
|
+
grow_acc_count=state.forest.grow_acc_count,
|
|
67
|
+
prune_prop_count=state.forest.prune_prop_count,
|
|
68
|
+
prune_acc_count=state.forest.prune_acc_count,
|
|
69
|
+
log_likelihood=state.forest.log_likelihood,
|
|
70
|
+
log_trans_prior=state.forest.log_trans_prior,
|
|
71
|
+
)
|
|
34
72
|
|
|
35
|
-
|
|
73
|
+
|
|
74
|
+
class MainTrace(BurninTrace):
|
|
75
|
+
"""MCMC trace with trees and diagnostic values."""
|
|
76
|
+
|
|
77
|
+
leaf_tree: Float32[Array, '*trace_length 2**d']
|
|
78
|
+
var_tree: UInt[Array, '*trace_length 2**(d-1)']
|
|
79
|
+
split_tree: UInt[Array, '*trace_length 2**(d-1)']
|
|
80
|
+
offset: Float32[Array, '*trace_length']
|
|
81
|
+
varprob: Float32[Array, '*trace_length p'] | None
|
|
82
|
+
|
|
83
|
+
@classmethod
|
|
84
|
+
def from_state(cls, state: State) -> 'MainTrace':
|
|
85
|
+
"""Create a single-item main trace from a MCMC state."""
|
|
86
|
+
# compute varprob
|
|
87
|
+
log_s = state.forest.log_s
|
|
88
|
+
if log_s is None:
|
|
89
|
+
varprob = None
|
|
90
|
+
else:
|
|
91
|
+
varprob = softmax(log_s, where=state.forest.max_split.astype(bool))
|
|
92
|
+
|
|
93
|
+
return cls(
|
|
94
|
+
leaf_tree=state.forest.leaf_tree,
|
|
95
|
+
var_tree=state.forest.var_tree,
|
|
96
|
+
split_tree=state.forest.split_tree,
|
|
97
|
+
offset=state.offset,
|
|
98
|
+
varprob=varprob,
|
|
99
|
+
**vars(BurninTrace.from_state(state)),
|
|
100
|
+
)
|
|
36
101
|
|
|
37
102
|
|
|
38
|
-
|
|
39
|
-
|
|
103
|
+
CallbackState = PyTree[Any, 'T']
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class Callback(Protocol):
|
|
107
|
+
"""Callback type for `run_mcmc`."""
|
|
108
|
+
|
|
109
|
+
def __call__(
|
|
110
|
+
self,
|
|
111
|
+
*,
|
|
112
|
+
key: Key[Array, ''],
|
|
113
|
+
bart: State,
|
|
114
|
+
burnin: Bool[Array, ''],
|
|
115
|
+
i_total: Int32[Array, ''],
|
|
116
|
+
i_skip: Int32[Array, ''],
|
|
117
|
+
callback_state: CallbackState,
|
|
118
|
+
n_burn: Int32[Array, ''],
|
|
119
|
+
n_save: Int32[Array, ''],
|
|
120
|
+
n_skip: Int32[Array, ''],
|
|
121
|
+
i_outer: Int32[Array, ''],
|
|
122
|
+
inner_loop_length: int,
|
|
123
|
+
) -> tuple[State, CallbackState] | None:
|
|
124
|
+
"""Do an arbitrary action after an iteration of the MCMC.
|
|
125
|
+
|
|
126
|
+
Parameters
|
|
127
|
+
----------
|
|
128
|
+
key
|
|
129
|
+
A key for random number generation.
|
|
130
|
+
bart
|
|
131
|
+
The MCMC state just after updating it.
|
|
132
|
+
burnin
|
|
133
|
+
Whether the last iteration was in the burn-in phase.
|
|
134
|
+
i_total
|
|
135
|
+
The index of the last MCMC iteration (0-based).
|
|
136
|
+
i_skip
|
|
137
|
+
The number of MCMC updates since the last saved state. The initial
|
|
138
|
+
state counts as saved, even if it's not copied into the trace.
|
|
139
|
+
callback_state
|
|
140
|
+
The callback state, initially set to the argument passed to
|
|
141
|
+
`run_mcmc`, afterwards to the value returned by the last invocation
|
|
142
|
+
of the callback.
|
|
143
|
+
n_burn
|
|
144
|
+
n_save
|
|
145
|
+
n_skip
|
|
146
|
+
The corresponding `run_mcmc` arguments as-is.
|
|
147
|
+
i_outer
|
|
148
|
+
The index of the last outer loop iteration (0-based).
|
|
149
|
+
inner_loop_length
|
|
150
|
+
The number of MCMC iterations in the inner loop.
|
|
151
|
+
|
|
152
|
+
Returns
|
|
153
|
+
-------
|
|
154
|
+
bart : State
|
|
155
|
+
A possibly modified MCMC state. To avoid modifying the state,
|
|
156
|
+
return the `bart` argument passed to the callback as-is.
|
|
157
|
+
callback_state : CallbackState
|
|
158
|
+
The new state to be passed on the next callback invocation.
|
|
159
|
+
|
|
160
|
+
Notes
|
|
161
|
+
-----
|
|
162
|
+
For convenience, the callback may return `None`, and the states won't
|
|
163
|
+
be updated.
|
|
164
|
+
"""
|
|
165
|
+
...
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
class _Carry(Module):
|
|
169
|
+
"""Carry used in the loop in `run_mcmc`."""
|
|
170
|
+
|
|
171
|
+
bart: State
|
|
172
|
+
i_total: Int32[Array, '']
|
|
173
|
+
key: Key[Array, '']
|
|
174
|
+
burnin_trace: PyTree[Shaped[Array, 'n_burn *']]
|
|
175
|
+
main_trace: PyTree[Shaped[Array, 'n_save *']]
|
|
176
|
+
callback_state: CallbackState
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def run_mcmc(
|
|
180
|
+
key: Key[Array, ''],
|
|
181
|
+
bart: State,
|
|
182
|
+
n_save: int,
|
|
183
|
+
*,
|
|
184
|
+
n_burn: int = 0,
|
|
185
|
+
n_skip: int = 1,
|
|
186
|
+
inner_loop_length: int | None = None,
|
|
187
|
+
callback: Callback | None = None,
|
|
188
|
+
callback_state: CallbackState = None,
|
|
189
|
+
burnin_extractor: Callable[[State], PyTree] = BurninTrace.from_state,
|
|
190
|
+
main_extractor: Callable[[State], PyTree] = MainTrace.from_state,
|
|
191
|
+
) -> tuple[State, PyTree[Shaped[Array, 'n_burn *']], PyTree[Shaped[Array, 'n_save *']]]:
|
|
40
192
|
"""
|
|
41
193
|
Run the MCMC for the BART posterior.
|
|
42
194
|
|
|
43
195
|
Parameters
|
|
44
196
|
----------
|
|
45
|
-
key
|
|
46
|
-
|
|
47
|
-
bart
|
|
197
|
+
key
|
|
198
|
+
A key for random number generation.
|
|
199
|
+
bart
|
|
48
200
|
The initial MCMC state, as created and updated by the functions in
|
|
49
|
-
`bartz.mcmcstep`.
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
n_save
|
|
201
|
+
`bartz.mcmcstep`. The MCMC loop uses buffer donation to avoid copies,
|
|
202
|
+
so this variable is invalidated after running `run_mcmc`. Make a copy
|
|
203
|
+
beforehand to use it again.
|
|
204
|
+
n_save
|
|
53
205
|
The number of iterations to save.
|
|
54
|
-
|
|
206
|
+
n_burn
|
|
207
|
+
The number of initial (burn-in) iterations, which are saved only to the burn-in trace and not to the main trace.
|
|
208
|
+
n_skip
|
|
55
209
|
The number of iterations to skip between each saved iteration, plus 1.
|
|
56
210
|
The effective burn-in is ``n_burn + n_skip - 1``.
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
211
|
+
inner_loop_length
|
|
212
|
+
The MCMC loop is split into an outer and an inner loop. The outer loop
|
|
213
|
+
is in Python, while the inner loop is in JAX. `inner_loop_length` is the
|
|
214
|
+
number of iterations of the inner loop to run for each iteration of the
|
|
215
|
+
outer loop. If not specified, the outer loop will iterate just once,
|
|
216
|
+
with all iterations done in a single inner loop run. The inner stride is
|
|
217
|
+
unrelated to `n_skip`, the stride used for saving the trace.
|
|
218
|
+
callback
|
|
219
|
+
An arbitrary function run during the loop after updating the state. For
|
|
220
|
+
the signature, see `Callback`. The callback is called under the jax jit,
|
|
221
|
+
so the argument values are not available at the time the Python code is
|
|
222
|
+
executed. Use the utilities in `jax.debug` to access the values at
|
|
223
|
+
actual runtime. The callback may return new values for the MCMC state
|
|
224
|
+
and the callback state.
|
|
225
|
+
callback_state
|
|
226
|
+
The initial custom state for the callback.
|
|
227
|
+
burnin_extractor
|
|
228
|
+
main_extractor
|
|
229
|
+
Functions that extract the variables to be saved respectively in the
|
|
230
|
+
burn-in trace and in the main trace, given the MCMC state as argument.
|
|
231
|
+
Must return a pytree, and must be vmappable.
|
|
76
232
|
|
|
77
233
|
Returns
|
|
78
234
|
-------
|
|
79
|
-
bart :
|
|
235
|
+
bart : State
|
|
80
236
|
The final MCMC state.
|
|
81
|
-
burnin_trace :
|
|
82
|
-
The trace of the burn-in phase
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
'grow_acc_count', 'prune_prop_count', 'prune_acc_count'.
|
|
86
|
-
main_trace : dict of (n_save, ...) arrays
|
|
87
|
-
The trace of the main phase, containing the following subset of fields
|
|
88
|
-
from the `bart` dictionary, with an additional head index that runs
|
|
89
|
-
over MCMC iterations: 'leaf_trees', 'var_trees', 'split_trees', plus
|
|
90
|
-
the fields in `burnin_trace`.
|
|
237
|
+
burnin_trace : PyTree[Shaped[Array, 'n_burn *']]
|
|
238
|
+
The trace of the burn-in phase. For the default layout, see `BurninTrace`.
|
|
239
|
+
main_trace : PyTree[Shaped[Array, 'n_save *']]
|
|
240
|
+
The trace of the main phase. For the default layout, see `MainTrace`.
|
|
91
241
|
|
|
92
242
|
Notes
|
|
93
243
|
-----
|
|
@@ -95,164 +245,424 @@ def run_mcmc(key, bart, n_burn, n_save, n_skip, callback):
|
|
|
95
245
|
not include the initial state, and include the final state.
|
|
96
246
|
"""
|
|
97
247
|
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
248
|
+
def empty_trace(length, bart, extractor):
|
|
249
|
+
return jax.vmap(extractor, in_axes=None, out_axes=0, axis_size=length)(bart)
|
|
250
|
+
|
|
251
|
+
burnin_trace = empty_trace(n_burn, bart, burnin_extractor)
|
|
252
|
+
main_trace = empty_trace(n_save, bart, main_extractor)
|
|
253
|
+
|
|
254
|
+
# determine number of iterations for inner and outer loops
|
|
255
|
+
n_iters = n_burn + n_skip * n_save
|
|
256
|
+
if inner_loop_length is None:
|
|
257
|
+
inner_loop_length = n_iters
|
|
258
|
+
if inner_loop_length:
|
|
259
|
+
n_outer = n_iters // inner_loop_length + bool(n_iters % inner_loop_length)
|
|
260
|
+
else:
|
|
261
|
+
n_outer = 1
|
|
262
|
+
# setting to 0 would make for a clean noop, but it's useful to keep the
|
|
263
|
+
# same code path for benchmarking and testing
|
|
264
|
+
|
|
265
|
+
carry = _Carry(bart, jnp.int32(0), key, burnin_trace, main_trace, callback_state)
|
|
266
|
+
for i_outer in range(n_outer):
|
|
267
|
+
carry = _run_mcmc_inner_loop(
|
|
268
|
+
carry,
|
|
269
|
+
inner_loop_length,
|
|
270
|
+
callback,
|
|
271
|
+
burnin_extractor,
|
|
272
|
+
main_extractor,
|
|
273
|
+
n_burn,
|
|
274
|
+
n_save,
|
|
275
|
+
n_skip,
|
|
276
|
+
i_outer,
|
|
277
|
+
n_iters,
|
|
278
|
+
)
|
|
114
279
|
|
|
115
|
-
|
|
280
|
+
return carry.bart, carry.burnin_trace, carry.main_trace
|
|
116
281
|
|
|
117
|
-
carry = (bart, 0, key, trace_light, trace_heavy)
|
|
118
282
|
|
|
119
|
-
|
|
120
|
-
|
|
283
|
+
def _compute_i_skip(
|
|
284
|
+
i_total: Int32[Array, ''], n_burn: Int32[Array, ''], n_skip: Int32[Array, '']
|
|
285
|
+
) -> Int32[Array, '']:
|
|
286
|
+
"""Compute the `i_skip` argument passed to `callback`."""
|
|
287
|
+
burnin = i_total < n_burn
|
|
288
|
+
return jnp.where(
|
|
289
|
+
burnin,
|
|
290
|
+
i_total + 1,
|
|
291
|
+
(i_total - n_burn + 1) % n_skip
|
|
292
|
+
+ jnp.where(i_total - n_burn + 1 < n_skip, n_burn, 0),
|
|
293
|
+
)
|
|
121
294
|
|
|
122
|
-
key, subkey = random.split(key)
|
|
123
|
-
bart = mcmcstep.step(subkey, bart)
|
|
124
295
|
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
296
|
+
@partial(jax.jit, donate_argnums=(0,), static_argnums=(1, 2, 3, 4))
|
|
297
|
+
def _run_mcmc_inner_loop(
|
|
298
|
+
carry: _Carry,
|
|
299
|
+
inner_loop_length: int,
|
|
300
|
+
callback: Callback | None,
|
|
301
|
+
burnin_extractor: Callable[[State], PyTree],
|
|
302
|
+
main_extractor: Callable[[State], PyTree],
|
|
303
|
+
n_burn: Int32[Array, ''],
|
|
304
|
+
n_save: Int32[Array, ''],
|
|
305
|
+
n_skip: Int32[Array, ''],
|
|
306
|
+
i_outer: Int32[Array, ''],
|
|
307
|
+
n_iters: Int32[Array, ''],
|
|
308
|
+
):
|
|
309
|
+
def loop_impl(carry: _Carry) -> _Carry:
|
|
310
|
+
"""Loop body to run if i_total < n_iters."""
|
|
311
|
+
# split random key
|
|
312
|
+
keys = jaxext.split(carry.key, 3)
|
|
313
|
+
carry = replace(carry, key=keys.pop())
|
|
314
|
+
|
|
315
|
+
# update state
|
|
316
|
+
carry = replace(carry, bart=mcmcstep.step(keys.pop(), carry.bart))
|
|
317
|
+
|
|
318
|
+
burnin = carry.i_total < n_burn
|
|
319
|
+
|
|
320
|
+
# invoke callback
|
|
321
|
+
if callback is not None:
|
|
322
|
+
i_skip = _compute_i_skip(carry.i_total, n_burn, n_skip)
|
|
323
|
+
rt = callback(
|
|
324
|
+
key=keys.pop(),
|
|
325
|
+
bart=carry.bart,
|
|
326
|
+
burnin=burnin,
|
|
327
|
+
i_total=carry.i_total,
|
|
328
|
+
i_skip=i_skip,
|
|
329
|
+
callback_state=carry.callback_state,
|
|
330
|
+
n_burn=n_burn,
|
|
331
|
+
n_save=n_save,
|
|
332
|
+
n_skip=n_skip,
|
|
333
|
+
i_outer=i_outer,
|
|
334
|
+
inner_loop_length=inner_loop_length,
|
|
335
|
+
)
|
|
336
|
+
if rt is not None:
|
|
337
|
+
bart, callback_state = rt
|
|
338
|
+
carry = replace(carry, bart=bart, callback_state=callback_state)
|
|
339
|
+
|
|
340
|
+
def save_to_burnin_trace() -> tuple[PyTree, PyTree]:
|
|
341
|
+
return _pytree_at_set(
|
|
342
|
+
carry.burnin_trace, carry.i_total, burnin_extractor(carry.bart)
|
|
343
|
+
), carry.main_trace
|
|
344
|
+
|
|
345
|
+
def save_to_main_trace() -> tuple[PyTree, PyTree]:
|
|
346
|
+
idx = (carry.i_total - n_burn) // n_skip
|
|
347
|
+
return carry.burnin_trace, _pytree_at_set(
|
|
348
|
+
carry.main_trace, idx, main_extractor(carry.bart)
|
|
349
|
+
)
|
|
350
|
+
|
|
351
|
+
# save state to trace
|
|
352
|
+
burnin_trace, main_trace = lax.cond(
|
|
353
|
+
burnin, save_to_burnin_trace, save_to_main_trace
|
|
130
354
|
)
|
|
131
|
-
|
|
132
|
-
|
|
355
|
+
return replace(
|
|
356
|
+
carry,
|
|
357
|
+
i_total=carry.i_total + 1,
|
|
358
|
+
burnin_trace=burnin_trace,
|
|
359
|
+
main_trace=main_trace,
|
|
133
360
|
)
|
|
134
361
|
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
def update_trace(index, trace, bart):
|
|
139
|
-
bart = {k: v for k, v in bart.items() if k in trace}
|
|
362
|
+
def loop_noop(carry: _Carry) -> _Carry:
|
|
363
|
+
"""Loop body to run if i_total >= n_iters; it does nothing."""
|
|
364
|
+
return carry
|
|
140
365
|
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
return trace_array.at[index, ...].set(state_array)
|
|
144
|
-
else:
|
|
145
|
-
# this handles the case where a trace is empty (e.g.,
|
|
146
|
-
# no burn-in) because jax refuses to index into an array
|
|
147
|
-
# of length 0
|
|
148
|
-
return trace_array
|
|
149
|
-
|
|
150
|
-
return tree.map(assign_at_index, trace, bart)
|
|
151
|
-
|
|
152
|
-
trace_heavy = update_trace(i_heavy, trace_heavy, bart)
|
|
153
|
-
trace_light = update_trace(i_light, trace_light, bart)
|
|
154
|
-
|
|
155
|
-
i_total += 1
|
|
156
|
-
carry = (bart, i_total, key, trace_light, trace_heavy)
|
|
366
|
+
def loop(carry: _Carry, _) -> tuple[_Carry, None]:
|
|
367
|
+
carry = lax.cond(carry.i_total < n_iters, loop_impl, loop_noop, carry)
|
|
157
368
|
return carry, None
|
|
158
369
|
|
|
159
|
-
carry, _ = lax.scan(loop, carry, None,
|
|
370
|
+
carry, _ = lax.scan(loop, carry, None, inner_loop_length)
|
|
371
|
+
return carry
|
|
160
372
|
|
|
161
|
-
bart, _, _, trace_light, trace_heavy = carry
|
|
162
373
|
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
374
|
+
def _pytree_at_set(
|
|
375
|
+
dest: PyTree[Array, ' T'], index: Int32[Array, ''], val: PyTree[Array]
|
|
376
|
+
) -> PyTree[Array, ' T']:
|
|
377
|
+
"""Map ``dest.at[index].set(val)`` over pytrees."""
|
|
166
378
|
|
|
167
|
-
|
|
379
|
+
def at_set(dest, val):
|
|
380
|
+
if dest.size:
|
|
381
|
+
return dest.at[index, ...].set(val)
|
|
382
|
+
else:
|
|
383
|
+
# this handles the case where an array is empty because jax refuses
|
|
384
|
+
# to index into an array of length 0, even if just in the abstract
|
|
385
|
+
return dest
|
|
168
386
|
|
|
387
|
+
return tree.map(at_set, dest, val)
|
|
169
388
|
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
389
|
+
|
|
390
|
+
def make_default_callback(
|
|
391
|
+
*,
|
|
392
|
+
dot_every: int | Integer[Array, ''] | None = 1,
|
|
393
|
+
report_every: int | Integer[Array, ''] | None = 100,
|
|
394
|
+
sparse_on_at: int | Integer[Array, ''] | None = None,
|
|
395
|
+
) -> dict[str, Any]:
|
|
174
396
|
"""
|
|
175
|
-
|
|
397
|
+
Prepare a default callback for `run_mcmc`.
|
|
398
|
+
|
|
399
|
+
The callback prints a dot every `dot_every` iterations, and a longer
|
|
400
|
+
report every `report_every` iterations, and can do variable selection.
|
|
176
401
|
|
|
177
402
|
Parameters
|
|
178
403
|
----------
|
|
179
|
-
|
|
180
|
-
|
|
404
|
+
dot_every
|
|
405
|
+
A dot is printed every `dot_every` MCMC iterations, `None` to disable.
|
|
406
|
+
report_every
|
|
407
|
+
A one line report is printed every `report_every` MCMC iterations,
|
|
408
|
+
`None` to disable.
|
|
409
|
+
sparse_on_at
|
|
410
|
+
If specified, variable selection is activated starting from this
|
|
411
|
+
iteration. If `None`, variable selection is not used.
|
|
181
412
|
|
|
182
413
|
Returns
|
|
183
414
|
-------
|
|
184
|
-
|
|
185
|
-
|
|
415
|
+
A dictionary with the arguments to pass to `run_mcmc` as keyword arguments to set up the callback.
|
|
416
|
+
|
|
417
|
+
Examples
|
|
418
|
+
--------
|
|
419
|
+
>>> run_mcmc(..., **make_default_callback())
|
|
420
|
+
"""
|
|
421
|
+
|
|
422
|
+
def asarray_or_none(val: None | Any) -> None | Array:
|
|
423
|
+
return None if val is None else jnp.asarray(val)
|
|
424
|
+
|
|
425
|
+
def callback(*, bart, callback_state, **kwargs):
|
|
426
|
+
print_state, sparse_state = callback_state
|
|
427
|
+
bart, _ = sparse_callback(callback_state=sparse_state, bart=bart, **kwargs)
|
|
428
|
+
print_callback(callback_state=print_state, bart=bart, **kwargs)
|
|
429
|
+
return bart, callback_state
|
|
430
|
+
# here I assume that the callbacks don't update their states
|
|
431
|
+
|
|
432
|
+
return dict(
|
|
433
|
+
callback=callback,
|
|
434
|
+
callback_state=(
|
|
435
|
+
PrintCallbackState(
|
|
436
|
+
asarray_or_none(dot_every), asarray_or_none(report_every)
|
|
437
|
+
),
|
|
438
|
+
SparseCallbackState(asarray_or_none(sparse_on_at)),
|
|
439
|
+
),
|
|
440
|
+
)
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
class PrintCallbackState(Module):
|
|
444
|
+
"""State for `print_callback`.
|
|
445
|
+
|
|
446
|
+
Parameters
|
|
447
|
+
----------
|
|
448
|
+
dot_every
|
|
449
|
+
A dot is printed every `dot_every` MCMC iterations, `None` to disable.
|
|
450
|
+
report_every
|
|
451
|
+
A one line report is printed every `report_every` MCMC iterations,
|
|
452
|
+
`None` to disable.
|
|
186
453
|
"""
|
|
187
454
|
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
455
|
+
dot_every: Int32[Array, ''] | None
|
|
456
|
+
report_every: Int32[Array, ''] | None
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
def print_callback(
|
|
460
|
+
*,
|
|
461
|
+
bart: State,
|
|
462
|
+
burnin: Bool[Array, ''],
|
|
463
|
+
i_total: Int32[Array, ''],
|
|
464
|
+
n_burn: Int32[Array, ''],
|
|
465
|
+
n_save: Int32[Array, ''],
|
|
466
|
+
n_skip: Int32[Array, ''],
|
|
467
|
+
callback_state: PrintCallbackState,
|
|
468
|
+
**_,
|
|
469
|
+
):
|
|
470
|
+
"""Print a dot and/or a report periodically during the MCMC."""
|
|
471
|
+
if callback_state.dot_every is not None:
|
|
472
|
+
cond = (i_total + 1) % callback_state.dot_every == 0
|
|
473
|
+
lax.cond(
|
|
474
|
+
cond,
|
|
475
|
+
lambda: debug.callback(lambda: print('.', end='', flush=True)), # noqa: T201
|
|
476
|
+
# logging can't do in-line printing so I'll stick to print
|
|
477
|
+
lambda: None,
|
|
206
478
|
)
|
|
207
479
|
|
|
208
|
-
|
|
480
|
+
if callback_state.report_every is not None:
|
|
481
|
+
|
|
482
|
+
def print_report():
|
|
483
|
+
debug.callback(
|
|
484
|
+
_print_report,
|
|
485
|
+
newline=callback_state.dot_every is not None,
|
|
486
|
+
burnin=burnin,
|
|
487
|
+
i_total=i_total,
|
|
488
|
+
n_iters=n_burn + n_save * n_skip,
|
|
489
|
+
grow_prop_count=bart.forest.grow_prop_count,
|
|
490
|
+
grow_acc_count=bart.forest.grow_acc_count,
|
|
491
|
+
prune_prop_count=bart.forest.prune_prop_count,
|
|
492
|
+
prune_acc_count=bart.forest.prune_acc_count,
|
|
493
|
+
prop_total=len(bart.forest.leaf_tree),
|
|
494
|
+
fill=grove.forest_fill(bart.forest.split_tree),
|
|
495
|
+
)
|
|
496
|
+
|
|
497
|
+
cond = (i_total + 1) % callback_state.report_every == 0
|
|
498
|
+
lax.cond(cond, print_report, lambda: None)
|
|
499
|
+
|
|
500
|
+
|
|
501
|
+
def _convert_jax_arrays_in_args(func: Callable) -> Callable:
|
|
502
|
+
"""Remove jax arrays from a function arguments.
|
|
503
|
+
|
|
504
|
+
Converts all `jax.Array` instances in the arguments to either Python scalars
|
|
505
|
+
or numpy arrays.
|
|
506
|
+
"""
|
|
507
|
+
|
|
508
|
+
def convert_jax_arrays(pytree: PyTree) -> PyTree:
|
|
509
|
+
def convert_jax_arrays(val: Any) -> Any:
|
|
510
|
+
if not isinstance(val, jax.Array):
|
|
511
|
+
return val
|
|
512
|
+
elif val.shape:
|
|
513
|
+
return numpy.array(val)
|
|
514
|
+
else:
|
|
515
|
+
return val.item()
|
|
516
|
+
|
|
517
|
+
return tree.map(convert_jax_arrays, pytree)
|
|
518
|
+
|
|
519
|
+
@wraps(func)
|
|
520
|
+
def new_func(*args, **kw):
|
|
521
|
+
args = convert_jax_arrays(args)
|
|
522
|
+
kw = convert_jax_arrays(kw)
|
|
523
|
+
return func(*args, **kw)
|
|
524
|
+
|
|
525
|
+
return new_func
|
|
526
|
+
|
|
527
|
+
|
|
528
|
+
@_convert_jax_arrays_in_args
|
|
529
|
+
# convert all jax arrays in arguments because operations on them could lead to
|
|
530
|
+
# deadlock with the main thread
|
|
531
|
+
def _print_report(
|
|
532
|
+
*,
|
|
533
|
+
newline: bool,
|
|
534
|
+
burnin: bool,
|
|
535
|
+
i_total: int,
|
|
536
|
+
n_iters: int,
|
|
537
|
+
grow_prop_count: int,
|
|
538
|
+
grow_acc_count: int,
|
|
539
|
+
prune_prop_count: int,
|
|
540
|
+
prune_acc_count: int,
|
|
541
|
+
prop_total: int,
|
|
542
|
+
fill: float,
|
|
543
|
+
):
|
|
544
|
+
"""Print the report for `print_callback`."""
|
|
545
|
+
|
|
546
|
+
def acc_string(acc_count, prop_count):
|
|
547
|
+
if prop_count:
|
|
548
|
+
return f'{acc_count / prop_count:.0%}'
|
|
549
|
+
else:
|
|
550
|
+
return 'n/d'
|
|
551
|
+
|
|
552
|
+
grow_prop = grow_prop_count / prop_total
|
|
553
|
+
prune_prop = prune_prop_count / prop_total
|
|
554
|
+
grow_acc = acc_string(grow_acc_count, grow_prop_count)
|
|
555
|
+
prune_acc = acc_string(prune_acc_count, prune_prop_count)
|
|
556
|
+
|
|
557
|
+
prefix = '\n' if newline else ''
|
|
558
|
+
suffix = ' (burnin)' if burnin else ''
|
|
559
|
+
|
|
560
|
+
print( # noqa: T201, see print_callback for why not logging
|
|
561
|
+
f'{prefix}It {i_total + 1}/{n_iters} '
|
|
562
|
+
f'grow P={grow_prop:.0%} A={grow_acc}, '
|
|
563
|
+
f'prune P={prune_prop:.0%} A={prune_acc}, '
|
|
564
|
+
f'fill={fill:.0%}{suffix}'
|
|
565
|
+
)
|
|
566
|
+
|
|
567
|
+
|
|
568
|
+
class SparseCallbackState(Module):
|
|
569
|
+
"""State for `sparse_callback`.
|
|
570
|
+
|
|
571
|
+
Parameters
|
|
572
|
+
----------
|
|
573
|
+
sparse_on_at
|
|
574
|
+
If specified, variable selection is activated starting from this
|
|
575
|
+
iteration. If `None`, variable selection is not used.
|
|
576
|
+
"""
|
|
577
|
+
|
|
578
|
+
sparse_on_at: Int32[Array, ''] | None
|
|
209
579
|
|
|
210
580
|
|
|
211
|
-
def
|
|
212
|
-
|
|
581
|
+
def sparse_callback(
|
|
582
|
+
*,
|
|
583
|
+
key: Key[Array, ''],
|
|
584
|
+
bart: State,
|
|
585
|
+
i_total: Int32[Array, ''],
|
|
586
|
+
callback_state: SparseCallbackState,
|
|
587
|
+
**_,
|
|
213
588
|
):
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
# originating when jax is combined with some outdated dependencies. (I
|
|
221
|
-
# did not track down which dependencies exactly.) Doing .item() makes
|
|
222
|
-
# the + 1 operation be done by Python instead of by jax. The bug is that
|
|
223
|
-
# jax hangs completely, with a secondary thread blocked at this line.
|
|
224
|
-
print(
|
|
225
|
-
f'Iteration {i_str}/{total_str} '
|
|
226
|
-
f'P_grow={grow_prop:.2f} P_prune={prune_prop:.2f} '
|
|
227
|
-
f'A_grow={grow_acc:.2f} A_prune={prune_acc:.2f}{burnin_flag}'
|
|
589
|
+
"""Perform variable selection, see `mcmcstep.step_sparse`."""
|
|
590
|
+
if callback_state.sparse_on_at is not None:
|
|
591
|
+
bart = lax.cond(
|
|
592
|
+
i_total < callback_state.sparse_on_at,
|
|
593
|
+
lambda: bart,
|
|
594
|
+
lambda: mcmcstep.step_sparse(key, bart),
|
|
228
595
|
)
|
|
596
|
+
return bart, callback_state
|
|
597
|
+
|
|
598
|
+
|
|
599
|
+
class Trace(grove.TreeHeaps, Protocol):
|
|
600
|
+
"""Protocol for a MCMC trace."""
|
|
601
|
+
|
|
602
|
+
offset: Float32[Array, ' trace_length']
|
|
603
|
+
|
|
604
|
+
|
|
605
|
+
class TreesTrace(Module):
|
|
606
|
+
"""Implementation of `bartz.grove.TreeHeaps` for an MCMC trace."""
|
|
607
|
+
|
|
608
|
+
leaf_tree: Float32[Array, 'trace_length num_trees 2**d']
|
|
609
|
+
var_tree: UInt[Array, 'trace_length num_trees 2**(d-1)']
|
|
610
|
+
split_tree: UInt[Array, 'trace_length num_trees 2**(d-1)']
|
|
611
|
+
|
|
612
|
+
@classmethod
|
|
613
|
+
def from_dataclass(cls, obj: grove.TreeHeaps):
|
|
614
|
+
"""Create a `TreesTrace` from any `bartz.grove.TreeHeaps`."""
|
|
615
|
+
return cls(**{f.name: getattr(obj, f.name) for f in fields(cls)})
|
|
229
616
|
|
|
230
617
|
|
|
231
618
|
@jax.jit
|
|
232
|
-
def evaluate_trace(
|
|
619
|
+
def evaluate_trace(
|
|
620
|
+
trace: Trace, X: UInt[Array, 'p n']
|
|
621
|
+
) -> Float32[Array, 'trace_length n']:
|
|
233
622
|
"""
|
|
234
623
|
Compute predictions for all iterations of the BART MCMC.
|
|
235
624
|
|
|
236
625
|
Parameters
|
|
237
626
|
----------
|
|
238
|
-
trace
|
|
627
|
+
trace
|
|
239
628
|
A trace of the BART MCMC, as returned by `run_mcmc`.
|
|
240
|
-
X
|
|
629
|
+
X
|
|
241
630
|
The predictors matrix, with `p` predictors and `n` observations.
|
|
242
631
|
|
|
243
632
|
Returns
|
|
244
633
|
-------
|
|
245
|
-
|
|
246
|
-
The predictions for each iteration of the MCMC.
|
|
634
|
+
The predictions for each iteration of the MCMC.
|
|
247
635
|
"""
|
|
248
|
-
evaluate_trees =
|
|
249
|
-
evaluate_trees = jaxext.autobatch(evaluate_trees, 2**29, (None, 0
|
|
636
|
+
evaluate_trees = partial(grove.evaluate_forest, sum_trees=False)
|
|
637
|
+
evaluate_trees = jaxext.autobatch(evaluate_trees, 2**29, (None, 0))
|
|
638
|
+
trees = TreesTrace.from_dataclass(trace)
|
|
250
639
|
|
|
251
|
-
def loop(_,
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
)
|
|
255
|
-
return None, jnp.sum(values, axis=0, dtype=jnp.float32)
|
|
640
|
+
def loop(_, item):
|
|
641
|
+
offset, trees = item
|
|
642
|
+
values = evaluate_trees(X, trees)
|
|
643
|
+
return None, offset + jnp.sum(values, axis=0, dtype=jnp.float32)
|
|
256
644
|
|
|
257
|
-
_, y = lax.scan(loop, None, trace)
|
|
645
|
+
_, y = lax.scan(loop, None, (trace.offset, trees))
|
|
258
646
|
return y
|
|
647
|
+
|
|
648
|
+
|
|
649
|
+
@partial(jax.jit, static_argnums=(0,))
|
|
650
|
+
def compute_varcount(
|
|
651
|
+
p: int, trace: grove.TreeHeaps
|
|
652
|
+
) -> Int32[Array, 'trace_length {p}']:
|
|
653
|
+
"""
|
|
654
|
+
Count how many times each predictor is used in each MCMC state.
|
|
655
|
+
|
|
656
|
+
Parameters
|
|
657
|
+
----------
|
|
658
|
+
p
|
|
659
|
+
The number of predictors.
|
|
660
|
+
trace
|
|
661
|
+
A trace of the BART MCMC, as returned by `run_mcmc`.
|
|
662
|
+
|
|
663
|
+
Returns
|
|
664
|
+
-------
|
|
665
|
+
Histogram of predictor usage in each MCMC state.
|
|
666
|
+
"""
|
|
667
|
+
vmapped_var_histogram = jax.vmap(grove.var_histogram, in_axes=(None, 0, 0))
|
|
668
|
+
return vmapped_var_histogram(p, trace.var_tree, trace.split_tree)
|