bartz 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/BART.py +464 -254
- bartz/__init__.py +2 -2
- bartz/_version.py +1 -1
- bartz/debug.py +1259 -79
- bartz/grove.py +139 -93
- bartz/jaxext/__init__.py +213 -0
- bartz/jaxext/_autobatch.py +238 -0
- bartz/jaxext/scipy/__init__.py +25 -0
- bartz/jaxext/scipy/special.py +240 -0
- bartz/jaxext/scipy/stats.py +36 -0
- bartz/mcmcloop.py +468 -311
- bartz/mcmcstep.py +734 -453
- bartz/prepcovars.py +139 -43
- {bartz-0.6.0.dist-info → bartz-0.7.0.dist-info}/METADATA +2 -3
- bartz-0.7.0.dist-info/RECORD +17 -0
- {bartz-0.6.0.dist-info → bartz-0.7.0.dist-info}/WHEEL +1 -1
- bartz/jaxext.py +0 -423
- bartz-0.6.0.dist-info/RECORD +0 -13
bartz/mcmcloop.py
CHANGED
|
@@ -22,164 +22,222 @@
|
|
|
22
22
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
23
23
|
# SOFTWARE.
|
|
24
24
|
|
|
25
|
-
"""Functions that implement the full BART posterior MCMC loop.
|
|
25
|
+
"""Functions that implement the full BART posterior MCMC loop.
|
|
26
26
|
|
|
27
|
-
|
|
27
|
+
The entry points are `run_mcmc` and `make_default_callback`.
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
from collections.abc import Callable
|
|
31
|
+
from dataclasses import fields, replace
|
|
32
|
+
from functools import partial, wraps
|
|
33
|
+
from typing import Any, Protocol
|
|
28
34
|
|
|
29
35
|
import jax
|
|
30
36
|
import numpy
|
|
37
|
+
from equinox import Module
|
|
31
38
|
from jax import debug, lax, tree
|
|
32
39
|
from jax import numpy as jnp
|
|
33
|
-
from
|
|
40
|
+
from jax.nn import softmax
|
|
41
|
+
from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, PyTree, Shaped, UInt
|
|
42
|
+
|
|
43
|
+
from bartz import grove, jaxext, mcmcstep
|
|
44
|
+
from bartz.mcmcstep import State
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class BurninTrace(Module):
|
|
48
|
+
"""MCMC trace with only diagnostic values."""
|
|
49
|
+
|
|
50
|
+
sigma2: Float32[Array, '*trace_length'] | None
|
|
51
|
+
theta: Float32[Array, '*trace_length'] | None
|
|
52
|
+
grow_prop_count: Int32[Array, '*trace_length']
|
|
53
|
+
grow_acc_count: Int32[Array, '*trace_length']
|
|
54
|
+
prune_prop_count: Int32[Array, '*trace_length']
|
|
55
|
+
prune_acc_count: Int32[Array, '*trace_length']
|
|
56
|
+
log_likelihood: Float32[Array, '*trace_length'] | None
|
|
57
|
+
log_trans_prior: Float32[Array, '*trace_length'] | None
|
|
58
|
+
|
|
59
|
+
@classmethod
|
|
60
|
+
def from_state(cls, state: State) -> 'BurninTrace':
|
|
61
|
+
"""Create a single-item burn-in trace from an MCMC state."""
|
|
62
|
+
return cls(
|
|
63
|
+
sigma2=state.sigma2,
|
|
64
|
+
theta=state.forest.theta,
|
|
65
|
+
grow_prop_count=state.forest.grow_prop_count,
|
|
66
|
+
grow_acc_count=state.forest.grow_acc_count,
|
|
67
|
+
prune_prop_count=state.forest.prune_prop_count,
|
|
68
|
+
prune_acc_count=state.forest.prune_acc_count,
|
|
69
|
+
log_likelihood=state.forest.log_likelihood,
|
|
70
|
+
log_trans_prior=state.forest.log_trans_prior,
|
|
71
|
+
)
|
|
34
72
|
|
|
35
|
-
from . import grove, jaxext, mcmcstep
|
|
36
|
-
from .mcmcstep import State
|
|
37
73
|
|
|
74
|
+
class MainTrace(BurninTrace):
|
|
75
|
+
"""MCMC trace with trees and diagnostic values."""
|
|
38
76
|
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
split_trees=state.forest.split_trees,
|
|
45
|
-
offset=state.offset,
|
|
46
|
-
)
|
|
77
|
+
leaf_tree: Float32[Array, '*trace_length 2**d']
|
|
78
|
+
var_tree: UInt[Array, '*trace_length 2**(d-1)']
|
|
79
|
+
split_tree: UInt[Array, '*trace_length 2**(d-1)']
|
|
80
|
+
offset: Float32[Array, '*trace_length']
|
|
81
|
+
varprob: Float32[Array, '*trace_length p'] | None
|
|
47
82
|
|
|
83
|
+
@classmethod
|
|
84
|
+
def from_state(cls, state: State) -> 'MainTrace':
|
|
85
|
+
"""Create a single-item main trace from an MCMC state."""
|
|
86
|
+
# compute varprob
|
|
87
|
+
log_s = state.forest.log_s
|
|
88
|
+
if log_s is None:
|
|
89
|
+
varprob = None
|
|
90
|
+
else:
|
|
91
|
+
varprob = softmax(log_s, where=state.forest.max_split.astype(bool))
|
|
92
|
+
|
|
93
|
+
return cls(
|
|
94
|
+
leaf_tree=state.forest.leaf_tree,
|
|
95
|
+
var_tree=state.forest.var_tree,
|
|
96
|
+
split_tree=state.forest.split_tree,
|
|
97
|
+
offset=state.offset,
|
|
98
|
+
varprob=varprob,
|
|
99
|
+
**vars(BurninTrace.from_state(state)),
|
|
100
|
+
)
|
|
48
101
|
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
102
|
+
|
|
103
|
+
CallbackState = PyTree[Any, 'T']
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class Callback(Protocol):
|
|
107
|
+
"""Callback type for `run_mcmc`."""
|
|
108
|
+
|
|
109
|
+
def __call__(
|
|
110
|
+
self,
|
|
111
|
+
*,
|
|
112
|
+
key: Key[Array, ''],
|
|
113
|
+
bart: State,
|
|
114
|
+
burnin: Bool[Array, ''],
|
|
115
|
+
i_total: Int32[Array, ''],
|
|
116
|
+
i_skip: Int32[Array, ''],
|
|
117
|
+
callback_state: CallbackState,
|
|
118
|
+
n_burn: Int32[Array, ''],
|
|
119
|
+
n_save: Int32[Array, ''],
|
|
120
|
+
n_skip: Int32[Array, ''],
|
|
121
|
+
i_outer: Int32[Array, ''],
|
|
122
|
+
inner_loop_length: int,
|
|
123
|
+
) -> tuple[State, CallbackState] | None:
|
|
124
|
+
"""Do an arbitrary action after an iteration of the MCMC.
|
|
125
|
+
|
|
126
|
+
Parameters
|
|
127
|
+
----------
|
|
128
|
+
key
|
|
129
|
+
A key for random number generation.
|
|
130
|
+
bart
|
|
131
|
+
The MCMC state just after updating it.
|
|
132
|
+
burnin
|
|
133
|
+
Whether the last iteration was in the burn-in phase.
|
|
134
|
+
i_total
|
|
135
|
+
The index of the last MCMC iteration (0-based).
|
|
136
|
+
i_skip
|
|
137
|
+
The number of MCMC updates from the last saved state. The initial
|
|
138
|
+
state counts as saved, even if it's not copied into the trace.
|
|
139
|
+
callback_state
|
|
140
|
+
The callback state, initially set to the argument passed to
|
|
141
|
+
`run_mcmc`, afterwards to the value returned by the last invocation
|
|
142
|
+
of the callback.
|
|
143
|
+
n_burn
|
|
144
|
+
n_save
|
|
145
|
+
n_skip
|
|
146
|
+
The corresponding `run_mcmc` arguments as-is.
|
|
147
|
+
i_outer
|
|
148
|
+
The index of the last outer loop iteration (0-based).
|
|
149
|
+
inner_loop_length
|
|
150
|
+
The number of MCMC iterations in the inner loop.
|
|
151
|
+
|
|
152
|
+
Returns
|
|
153
|
+
-------
|
|
154
|
+
bart : State
|
|
155
|
+
A possibly modified MCMC state. To avoid modifying the state,
|
|
156
|
+
return the `bart` argument passed to the callback as-is.
|
|
157
|
+
callback_state : CallbackState
|
|
158
|
+
The new state to be passed on the next callback invocation.
|
|
159
|
+
|
|
160
|
+
Notes
|
|
161
|
+
-----
|
|
162
|
+
For convenience, the callback may return `None`, and the states won't
|
|
163
|
+
be updated.
|
|
164
|
+
"""
|
|
165
|
+
...
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
class _Carry(Module):
|
|
169
|
+
"""Carry used in the loop in `run_mcmc`."""
|
|
170
|
+
|
|
171
|
+
bart: State
|
|
172
|
+
i_total: Int32[Array, '']
|
|
173
|
+
key: Key[Array, '']
|
|
174
|
+
burnin_trace: PyTree[Shaped[Array, 'n_burn *']]
|
|
175
|
+
main_trace: PyTree[Shaped[Array, 'n_save *']]
|
|
176
|
+
callback_state: CallbackState
|
|
60
177
|
|
|
61
178
|
|
|
62
179
|
def run_mcmc(
|
|
63
|
-
key,
|
|
64
|
-
bart,
|
|
65
|
-
n_save,
|
|
180
|
+
key: Key[Array, ''],
|
|
181
|
+
bart: State,
|
|
182
|
+
n_save: int,
|
|
66
183
|
*,
|
|
67
|
-
n_burn=0,
|
|
68
|
-
n_skip=1,
|
|
69
|
-
inner_loop_length=None,
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
both_extractor=default_both_extractor,
|
|
76
|
-
):
|
|
184
|
+
n_burn: int = 0,
|
|
185
|
+
n_skip: int = 1,
|
|
186
|
+
inner_loop_length: int | None = None,
|
|
187
|
+
callback: Callback | None = None,
|
|
188
|
+
callback_state: CallbackState = None,
|
|
189
|
+
burnin_extractor: Callable[[State], PyTree] = BurninTrace.from_state,
|
|
190
|
+
main_extractor: Callable[[State], PyTree] = MainTrace.from_state,
|
|
191
|
+
) -> tuple[State, PyTree[Shaped[Array, 'n_burn *']], PyTree[Shaped[Array, 'n_save *']]]:
|
|
77
192
|
"""
|
|
78
193
|
Run the MCMC for the BART posterior.
|
|
79
194
|
|
|
80
195
|
Parameters
|
|
81
196
|
----------
|
|
82
|
-
key
|
|
197
|
+
key
|
|
83
198
|
A key for random number generation.
|
|
84
|
-
bart
|
|
199
|
+
bart
|
|
85
200
|
The initial MCMC state, as created and updated by the functions in
|
|
86
201
|
`bartz.mcmcstep`. The MCMC loop uses buffer donation to avoid copies,
|
|
87
202
|
so this variable is invalidated after running `run_mcmc`. Make a copy
|
|
88
203
|
beforehand to use it again.
|
|
89
|
-
n_save
|
|
204
|
+
n_save
|
|
90
205
|
The number of iterations to save.
|
|
91
|
-
n_burn
|
|
206
|
+
n_burn
|
|
92
207
|
The number of initial iterations which are not saved.
|
|
93
|
-
n_skip
|
|
208
|
+
n_skip
|
|
94
209
|
The number of iterations to skip between each saved iteration, plus 1.
|
|
95
210
|
The effective burn-in is ``n_burn + n_skip - 1``.
|
|
96
|
-
inner_loop_length
|
|
211
|
+
inner_loop_length
|
|
97
212
|
The MCMC loop is split into an outer and an inner loop. The outer loop
|
|
98
213
|
is in Python, while the inner loop is in JAX. `inner_loop_length` is the
|
|
99
214
|
number of iterations of the inner loop to run for each iteration of the
|
|
100
215
|
outer loop. If not specified, the outer loop will iterate just once,
|
|
101
216
|
with all iterations done in a single inner loop run. The inner stride is
|
|
102
217
|
unrelated to the stride used for saving the trace.
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
bart : dict
|
|
116
|
-
The MCMC state just after updating it.
|
|
117
|
-
burnin : bool
|
|
118
|
-
Whether the last iteration was in the burn-in phase.
|
|
119
|
-
overflow : bool
|
|
120
|
-
Whether the last iteration was in the overflow phase (iterations
|
|
121
|
-
not saved due to `inner_loop_length` not being a divisor of the
|
|
122
|
-
total number of iterations).
|
|
123
|
-
i_total : int
|
|
124
|
-
The index of the last MCMC iteration (0-based).
|
|
125
|
-
i_skip : int
|
|
126
|
-
The number of MCMC updates from the last saved state. The initial
|
|
127
|
-
state counts as saved, even if it's not copied into the trace.
|
|
128
|
-
callback_state : jax pytree
|
|
129
|
-
The callback state, initially set to the argument passed to
|
|
130
|
-
`run_mcmc`, afterwards to the value returned by the last invocation
|
|
131
|
-
of `inner_callback` or `outer_callback`.
|
|
132
|
-
n_burn, n_save, n_skip : int
|
|
133
|
-
The corresponding arguments as-is.
|
|
134
|
-
i_outer : int
|
|
135
|
-
The index of the last outer loop iteration (0-based).
|
|
136
|
-
inner_loop_length : int
|
|
137
|
-
The number of MCMC iterations in the inner loop.
|
|
138
|
-
|
|
139
|
-
`inner_callback` is called under the jax jit, so the argument values are
|
|
140
|
-
not available at the time the Python code is executed. Use the utilities
|
|
141
|
-
in `jax.debug` to access the values at actual runtime.
|
|
142
|
-
|
|
143
|
-
The callbacks must return two values:
|
|
144
|
-
|
|
145
|
-
bart : dict
|
|
146
|
-
A possibly modified MCMC state. To avoid modifying the state,
|
|
147
|
-
return the `bart` argument passed to the callback as-is.
|
|
148
|
-
callback_state : jax pytree
|
|
149
|
-
The new state to be passed on the next callback invocation.
|
|
150
|
-
|
|
151
|
-
For convenience, if a callback returns `None`, the states are not
|
|
152
|
-
updated.
|
|
153
|
-
callback_state : jax pytree, optional
|
|
154
|
-
The initial state for the callbacks.
|
|
155
|
-
onlymain_extractor : callable, optional
|
|
156
|
-
both_extractor : callable, optional
|
|
218
|
+
callback
|
|
219
|
+
An arbitrary function run during the loop after updating the state. For
|
|
220
|
+
the signature, see `Callback`. The callback is called under the jax jit,
|
|
221
|
+
so the argument values are not available at the time the Python code is
|
|
222
|
+
executed. Use the utilities in `jax.debug` to access the values at
|
|
223
|
+
actual runtime. The callback may return new values for the MCMC state
|
|
224
|
+
and the callback state.
|
|
225
|
+
callback_state
|
|
226
|
+
The initial custom state for the callback.
|
|
227
|
+
burnin_extractor
|
|
228
|
+
main_extractor
|
|
157
229
|
Functions that extract the variables to be saved respectively in the
|
|
158
230
|
burn-in trace and in the main trace, given the MCMC state as argument.
|
|
159
231
|
Must return a pytree, and must be vmappable.
|
|
160
232
|
|
|
161
233
|
Returns
|
|
162
234
|
-------
|
|
163
|
-
bart :
|
|
235
|
+
bart : State
|
|
164
236
|
The final MCMC state.
|
|
165
|
-
burnin_trace :
|
|
166
|
-
The trace of the burn-in phase
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
'grow_acc_count', 'prune_prop_count', 'prune_acc_count' (or if specified
|
|
170
|
-
the fields in `tracevars_both`).
|
|
171
|
-
main_trace : dict of (n_save, ...) arrays
|
|
172
|
-
The trace of the main phase, containing the following subset of fields
|
|
173
|
-
from the `bart` dictionary, with an additional head index that runs over
|
|
174
|
-
MCMC iterations: 'leaf_trees', 'var_trees', 'split_trees' (or if
|
|
175
|
-
specified the fields in `tracevars_onlymain`), plus the fields in
|
|
176
|
-
`burnin_trace`.
|
|
177
|
-
|
|
178
|
-
Raises
|
|
179
|
-
------
|
|
180
|
-
ValueError
|
|
181
|
-
If `inner_loop_length` is not a divisor of the total number of
|
|
182
|
-
iterations and `allow_overflow` is `False`.
|
|
237
|
+
burnin_trace : PyTree[Shaped[Array, 'n_burn *']]
|
|
238
|
+
The trace of the burn-in phase. For the default layout, see `BurninTrace`.
|
|
239
|
+
main_trace : PyTree[Shaped[Array, 'n_save *']]
|
|
240
|
+
The trace of the main phase. For the default layout, see `MainTrace`.
|
|
183
241
|
|
|
184
242
|
Notes
|
|
185
243
|
-----
|
|
@@ -190,102 +248,85 @@ def run_mcmc(
|
|
|
190
248
|
def empty_trace(length, bart, extractor):
|
|
191
249
|
return jax.vmap(extractor, in_axes=None, out_axes=0, axis_size=length)(bart)
|
|
192
250
|
|
|
193
|
-
|
|
194
|
-
|
|
251
|
+
burnin_trace = empty_trace(n_burn, bart, burnin_extractor)
|
|
252
|
+
main_trace = empty_trace(n_save, bart, main_extractor)
|
|
195
253
|
|
|
196
254
|
# determine number of iterations for inner and outer loops
|
|
197
255
|
n_iters = n_burn + n_skip * n_save
|
|
198
256
|
if inner_loop_length is None:
|
|
199
257
|
inner_loop_length = n_iters
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
carry = (bart, 0, key,
|
|
258
|
+
if inner_loop_length:
|
|
259
|
+
n_outer = n_iters // inner_loop_length + bool(n_iters % inner_loop_length)
|
|
260
|
+
else:
|
|
261
|
+
n_outer = 1
|
|
262
|
+
# setting to 0 would make for a clean noop, but it's useful to keep the
|
|
263
|
+
# same code path for benchmarking and testing
|
|
264
|
+
|
|
265
|
+
carry = _Carry(bart, jnp.int32(0), key, burnin_trace, main_trace, callback_state)
|
|
208
266
|
for i_outer in range(n_outer):
|
|
209
267
|
carry = _run_mcmc_inner_loop(
|
|
210
268
|
carry,
|
|
211
269
|
inner_loop_length,
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
270
|
+
callback,
|
|
271
|
+
burnin_extractor,
|
|
272
|
+
main_extractor,
|
|
215
273
|
n_burn,
|
|
216
274
|
n_save,
|
|
217
275
|
n_skip,
|
|
218
276
|
i_outer,
|
|
277
|
+
n_iters,
|
|
219
278
|
)
|
|
220
|
-
if outer_callback is not None:
|
|
221
|
-
bart, i_total, key, trace_both, trace_onlymain, callback_state = carry
|
|
222
|
-
i_total -= 1 # because i_total is updated at the end of the inner loop
|
|
223
|
-
i_skip = _compute_i_skip(i_total, n_burn, n_skip)
|
|
224
|
-
rt = outer_callback(
|
|
225
|
-
bart=bart,
|
|
226
|
-
burnin=i_total < n_burn,
|
|
227
|
-
overflow=i_total >= n_iters,
|
|
228
|
-
i_total=i_total,
|
|
229
|
-
i_skip=i_skip,
|
|
230
|
-
callback_state=callback_state,
|
|
231
|
-
n_burn=n_burn,
|
|
232
|
-
n_save=n_save,
|
|
233
|
-
n_skip=n_skip,
|
|
234
|
-
i_outer=i_outer,
|
|
235
|
-
inner_loop_length=inner_loop_length,
|
|
236
|
-
)
|
|
237
|
-
if rt is not None:
|
|
238
|
-
bart, callback_state = rt
|
|
239
|
-
i_total += 1
|
|
240
|
-
carry = (bart, i_total, key, trace_both, trace_onlymain, callback_state)
|
|
241
|
-
|
|
242
|
-
bart, _, _, trace_both, trace_onlymain, _ = carry
|
|
243
279
|
|
|
244
|
-
|
|
245
|
-
main_trace = tree.map(lambda x: x[n_burn:, ...], trace_both)
|
|
246
|
-
main_trace.update(trace_onlymain)
|
|
280
|
+
return carry.bart, carry.burnin_trace, carry.main_trace
|
|
247
281
|
|
|
248
|
-
return bart, burnin_trace, main_trace
|
|
249
282
|
|
|
250
|
-
|
|
251
|
-
|
|
283
|
+
def _compute_i_skip(
|
|
284
|
+
i_total: Int32[Array, ''], n_burn: Int32[Array, ''], n_skip: Int32[Array, '']
|
|
285
|
+
) -> Int32[Array, '']:
|
|
286
|
+
"""Compute the `i_skip` argument passed to `callback`."""
|
|
252
287
|
burnin = i_total < n_burn
|
|
253
288
|
return jnp.where(
|
|
254
289
|
burnin,
|
|
255
290
|
i_total + 1,
|
|
256
|
-
(i_total + 1) % n_skip
|
|
291
|
+
(i_total - n_burn + 1) % n_skip
|
|
292
|
+
+ jnp.where(i_total - n_burn + 1 < n_skip, n_burn, 0),
|
|
257
293
|
)
|
|
258
294
|
|
|
259
295
|
|
|
260
|
-
@
|
|
296
|
+
@partial(jax.jit, donate_argnums=(0,), static_argnums=(1, 2, 3, 4))
|
|
261
297
|
def _run_mcmc_inner_loop(
|
|
262
|
-
carry,
|
|
263
|
-
inner_loop_length,
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
n_burn,
|
|
268
|
-
n_save,
|
|
269
|
-
n_skip,
|
|
270
|
-
i_outer,
|
|
298
|
+
carry: _Carry,
|
|
299
|
+
inner_loop_length: int,
|
|
300
|
+
callback: Callback | None,
|
|
301
|
+
burnin_extractor: Callable[[State], PyTree],
|
|
302
|
+
main_extractor: Callable[[State], PyTree],
|
|
303
|
+
n_burn: Int32[Array, ''],
|
|
304
|
+
n_save: Int32[Array, ''],
|
|
305
|
+
n_skip: Int32[Array, ''],
|
|
306
|
+
i_outer: Int32[Array, ''],
|
|
307
|
+
n_iters: Int32[Array, ''],
|
|
271
308
|
):
|
|
272
|
-
def
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
keys = jaxext.split(key)
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
309
|
+
def loop_impl(carry: _Carry) -> _Carry:
|
|
310
|
+
"""Loop body to run if i_total < n_iters."""
|
|
311
|
+
# split random key
|
|
312
|
+
keys = jaxext.split(carry.key, 3)
|
|
313
|
+
carry = replace(carry, key=keys.pop())
|
|
314
|
+
|
|
315
|
+
# update state
|
|
316
|
+
carry = replace(carry, bart=mcmcstep.step(keys.pop(), carry.bart))
|
|
317
|
+
|
|
318
|
+
burnin = carry.i_total < n_burn
|
|
319
|
+
|
|
320
|
+
# invoke callback
|
|
321
|
+
if callback is not None:
|
|
322
|
+
i_skip = _compute_i_skip(carry.i_total, n_burn, n_skip)
|
|
323
|
+
rt = callback(
|
|
324
|
+
key=keys.pop(),
|
|
325
|
+
bart=carry.bart,
|
|
284
326
|
burnin=burnin,
|
|
285
|
-
|
|
286
|
-
i_total=i_total,
|
|
327
|
+
i_total=carry.i_total,
|
|
287
328
|
i_skip=i_skip,
|
|
288
|
-
callback_state=callback_state,
|
|
329
|
+
callback_state=carry.callback_state,
|
|
289
330
|
n_burn=n_burn,
|
|
290
331
|
n_save=n_save,
|
|
291
332
|
n_skip=n_skip,
|
|
@@ -294,128 +335,178 @@ def _run_mcmc_inner_loop(
|
|
|
294
335
|
)
|
|
295
336
|
if rt is not None:
|
|
296
337
|
bart, callback_state = rt
|
|
338
|
+
carry = replace(carry, bart=bart, callback_state=callback_state)
|
|
297
339
|
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
def assign_at_index(trace_array, state_array):
|
|
303
|
-
if trace_array.size:
|
|
304
|
-
return trace_array.at[index, ...].set(state_array)
|
|
305
|
-
else:
|
|
306
|
-
# this handles the case where a trace is empty (e.g.,
|
|
307
|
-
# no burn-in) because jax refuses to index into an array
|
|
308
|
-
# of length 0
|
|
309
|
-
return trace_array
|
|
340
|
+
def save_to_burnin_trace() -> tuple[PyTree, PyTree]:
|
|
341
|
+
return _pytree_at_set(
|
|
342
|
+
carry.burnin_trace, carry.i_total, burnin_extractor(carry.bart)
|
|
343
|
+
), carry.main_trace
|
|
310
344
|
|
|
311
|
-
|
|
345
|
+
def save_to_main_trace() -> tuple[PyTree, PyTree]:
|
|
346
|
+
idx = (carry.i_total - n_burn) // n_skip
|
|
347
|
+
return carry.burnin_trace, _pytree_at_set(
|
|
348
|
+
carry.main_trace, idx, main_extractor(carry.bart)
|
|
349
|
+
)
|
|
312
350
|
|
|
313
|
-
|
|
314
|
-
|
|
351
|
+
# save state to trace
|
|
352
|
+
burnin_trace, main_trace = lax.cond(
|
|
353
|
+
burnin, save_to_burnin_trace, save_to_main_trace
|
|
354
|
+
)
|
|
355
|
+
return replace(
|
|
356
|
+
carry,
|
|
357
|
+
i_total=carry.i_total + 1,
|
|
358
|
+
burnin_trace=burnin_trace,
|
|
359
|
+
main_trace=main_trace,
|
|
315
360
|
)
|
|
316
|
-
trace_both = update_trace(i_both, trace_both, both_extractor(bart))
|
|
317
361
|
|
|
318
|
-
|
|
319
|
-
|
|
362
|
+
def loop_noop(carry: _Carry) -> _Carry:
|
|
363
|
+
"""Loop body to run if i_total >= n_iters; it does nothing."""
|
|
364
|
+
return carry
|
|
365
|
+
|
|
366
|
+
def loop(carry: _Carry, _) -> tuple[_Carry, None]:
|
|
367
|
+
carry = lax.cond(carry.i_total < n_iters, loop_impl, loop_noop, carry)
|
|
320
368
|
return carry, None
|
|
321
369
|
|
|
322
370
|
carry, _ = lax.scan(loop, carry, None, inner_loop_length)
|
|
323
371
|
return carry
|
|
324
372
|
|
|
325
373
|
|
|
326
|
-
def
|
|
374
|
+
def _pytree_at_set(
|
|
375
|
+
dest: PyTree[Array, ' T'], index: Int32[Array, ''], val: PyTree[Array]
|
|
376
|
+
) -> PyTree[Array, ' T']:
|
|
377
|
+
"""Map ``dest.at[index].set(val)`` over pytrees."""
|
|
378
|
+
|
|
379
|
+
def at_set(dest, val):
|
|
380
|
+
if dest.size:
|
|
381
|
+
return dest.at[index, ...].set(val)
|
|
382
|
+
else:
|
|
383
|
+
# this handles the case where an array is empty because jax refuses
|
|
384
|
+
# to index into an array of length 0, even if just in the abstract
|
|
385
|
+
return dest
|
|
386
|
+
|
|
387
|
+
return tree.map(at_set, dest, val)
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
def make_default_callback(
|
|
391
|
+
*,
|
|
392
|
+
dot_every: int | Integer[Array, ''] | None = 1,
|
|
393
|
+
report_every: int | Integer[Array, ''] | None = 100,
|
|
394
|
+
sparse_on_at: int | Integer[Array, ''] | None = None,
|
|
395
|
+
) -> dict[str, Any]:
|
|
327
396
|
"""
|
|
328
|
-
Prepare
|
|
397
|
+
Prepare a default callback for `run_mcmc`.
|
|
329
398
|
|
|
330
|
-
|
|
331
|
-
report outer loop iteration.
|
|
399
|
+
The callback prints a dot every `dot_every` iterations and a one-line
|
|
400
|
+
report every `report_every` iterations, and can do variable selection.
|
|
332
401
|
|
|
333
402
|
Parameters
|
|
334
403
|
----------
|
|
335
|
-
|
|
336
|
-
A dot is printed every `
|
|
337
|
-
|
|
338
|
-
A report is printed every `
|
|
339
|
-
|
|
404
|
+
dot_every
|
|
405
|
+
A dot is printed every `dot_every` MCMC iterations, `None` to disable.
|
|
406
|
+
report_every
|
|
407
|
+
A one line report is printed every `report_every` MCMC iterations,
|
|
408
|
+
`None` to disable.
|
|
409
|
+
sparse_on_at
|
|
410
|
+
If specified, variable selection is activated starting from this
|
|
411
|
+
iteration. If `None`, variable selection is not used.
|
|
340
412
|
|
|
341
413
|
Returns
|
|
342
414
|
-------
|
|
343
|
-
|
|
344
|
-
A dictionary with the arguments to pass to `run_mcmc` as keyword
|
|
345
|
-
arguments to set up the callbacks.
|
|
415
|
+
A dictionary with the arguments to pass to `run_mcmc` as keyword arguments to set up the callback.
|
|
346
416
|
|
|
347
417
|
Examples
|
|
348
418
|
--------
|
|
349
|
-
>>> run_mcmc(..., **
|
|
419
|
+
>>> run_mcmc(..., **make_default_callback())
|
|
350
420
|
"""
|
|
421
|
+
|
|
422
|
+
def asarray_or_none(val: None | Any) -> None | Array:
|
|
423
|
+
return None if val is None else jnp.asarray(val)
|
|
424
|
+
|
|
425
|
+
def callback(*, bart, callback_state, **kwargs):
|
|
426
|
+
print_state, sparse_state = callback_state
|
|
427
|
+
bart, _ = sparse_callback(callback_state=sparse_state, bart=bart, **kwargs)
|
|
428
|
+
print_callback(callback_state=print_state, bart=bart, **kwargs)
|
|
429
|
+
return bart, callback_state
|
|
430
|
+
# here I assume that the callbacks don't update their states
|
|
431
|
+
|
|
351
432
|
return dict(
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
433
|
+
callback=callback,
|
|
434
|
+
callback_state=(
|
|
435
|
+
PrintCallbackState(
|
|
436
|
+
asarray_or_none(dot_every), asarray_or_none(report_every)
|
|
437
|
+
),
|
|
438
|
+
SparseCallbackState(asarray_or_none(sparse_on_at)),
|
|
356
439
|
),
|
|
357
440
|
)
|
|
358
441
|
|
|
359
442
|
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
if dot_every_inner is not None:
|
|
363
|
-
cond = (i_total + 1) % dot_every_inner == 0
|
|
364
|
-
debug.callback(_print_dot, cond)
|
|
443
|
+
class PrintCallbackState(Module):
|
|
444
|
+
"""State for `print_callback`.
|
|
365
445
|
|
|
446
|
+
Parameters
|
|
447
|
+
----------
|
|
448
|
+
dot_every
|
|
449
|
+
A dot is printed every `dot_every` MCMC iterations, `None` to disable.
|
|
450
|
+
report_every
|
|
451
|
+
A one line report is printed every `report_every` MCMC iterations,
|
|
452
|
+
`None` to disable.
|
|
453
|
+
"""
|
|
366
454
|
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
print('.', end='', flush=True)
|
|
455
|
+
dot_every: Int32[Array, ''] | None
|
|
456
|
+
report_every: Int32[Array, ''] | None
|
|
370
457
|
|
|
371
458
|
|
|
372
|
-
def
|
|
459
|
+
def print_callback(
|
|
373
460
|
*,
|
|
374
|
-
bart,
|
|
375
|
-
burnin,
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
callback_state,
|
|
382
|
-
i_outer,
|
|
383
|
-
inner_loop_length,
|
|
461
|
+
bart: State,
|
|
462
|
+
burnin: Bool[Array, ''],
|
|
463
|
+
i_total: Int32[Array, ''],
|
|
464
|
+
n_burn: Int32[Array, ''],
|
|
465
|
+
n_save: Int32[Array, ''],
|
|
466
|
+
n_skip: Int32[Array, ''],
|
|
467
|
+
callback_state: PrintCallbackState,
|
|
384
468
|
**_,
|
|
385
469
|
):
|
|
386
|
-
|
|
387
|
-
if
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
_print_report,
|
|
395
|
-
cond=(i_outer + 1) % report_every_outer == 0,
|
|
396
|
-
newline=newline,
|
|
397
|
-
burnin=burnin,
|
|
398
|
-
overflow=overflow,
|
|
399
|
-
i_total=i_total,
|
|
400
|
-
n_iters=n_burn + n_save * n_skip,
|
|
401
|
-
grow_prop_count=bart.forest.grow_prop_count,
|
|
402
|
-
grow_acc_count=bart.forest.grow_acc_count,
|
|
403
|
-
prune_prop_count=bart.forest.prune_prop_count,
|
|
404
|
-
prune_acc_count=bart.forest.prune_acc_count,
|
|
405
|
-
prop_total=len(bart.forest.leaf_trees),
|
|
406
|
-
fill=grove.forest_fill(bart.forest.split_trees),
|
|
470
|
+
"""Print a dot and/or a report periodically during the MCMC."""
|
|
471
|
+
if callback_state.dot_every is not None:
|
|
472
|
+
cond = (i_total + 1) % callback_state.dot_every == 0
|
|
473
|
+
lax.cond(
|
|
474
|
+
cond,
|
|
475
|
+
lambda: debug.callback(lambda: print('.', end='', flush=True)), # noqa: T201
|
|
476
|
+
# logging can't do in-line printing so I'll stick to print
|
|
477
|
+
lambda: None,
|
|
407
478
|
)
|
|
408
479
|
|
|
480
|
+
if callback_state.report_every is not None:
|
|
481
|
+
|
|
482
|
+
def print_report():
|
|
483
|
+
debug.callback(
|
|
484
|
+
_print_report,
|
|
485
|
+
newline=callback_state.dot_every is not None,
|
|
486
|
+
burnin=burnin,
|
|
487
|
+
i_total=i_total,
|
|
488
|
+
n_iters=n_burn + n_save * n_skip,
|
|
489
|
+
grow_prop_count=bart.forest.grow_prop_count,
|
|
490
|
+
grow_acc_count=bart.forest.grow_acc_count,
|
|
491
|
+
prune_prop_count=bart.forest.prune_prop_count,
|
|
492
|
+
prune_acc_count=bart.forest.prune_acc_count,
|
|
493
|
+
prop_total=len(bart.forest.leaf_tree),
|
|
494
|
+
fill=grove.forest_fill(bart.forest.split_tree),
|
|
495
|
+
)
|
|
496
|
+
|
|
497
|
+
cond = (i_total + 1) % callback_state.report_every == 0
|
|
498
|
+
lax.cond(cond, print_report, lambda: None)
|
|
409
499
|
|
|
410
|
-
|
|
500
|
+
|
|
501
|
+
def _convert_jax_arrays_in_args(func: Callable) -> Callable:
|
|
411
502
|
"""Remove jax arrays from a function's arguments.
|
|
412
503
|
|
|
413
|
-
Converts all jax.Array instances in the arguments to either Python scalars
|
|
504
|
+
Converts all `jax.Array` instances in the arguments to either Python scalars
|
|
414
505
|
or numpy arrays.
|
|
415
506
|
"""
|
|
416
507
|
|
|
417
|
-
def convert_jax_arrays(pytree):
|
|
418
|
-
def convert_jax_arrays(val):
|
|
508
|
+
def convert_jax_arrays(pytree: PyTree) -> PyTree:
|
|
509
|
+
def convert_jax_arrays(val: Any) -> Any:
|
|
419
510
|
if not isinstance(val, jax.Array):
|
|
420
511
|
return val
|
|
421
512
|
elif val.shape:
|
|
@@ -425,7 +516,7 @@ def _convert_jax_arrays_in_args(func):
|
|
|
425
516
|
|
|
426
517
|
return tree.map(convert_jax_arrays, pytree)
|
|
427
518
|
|
|
428
|
-
@
|
|
519
|
+
@wraps(func)
|
|
429
520
|
def new_func(*args, **kw):
|
|
430
521
|
args = convert_jax_arrays(args)
|
|
431
522
|
kw = convert_jax_arrays(kw)
|
|
@@ -439,73 +530,139 @@ def _convert_jax_arrays_in_args(func):
|
|
|
439
530
|
# deadlock with the main thread
|
|
440
531
|
def _print_report(
|
|
441
532
|
*,
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
prop_total,
|
|
453
|
-
fill,
|
|
533
|
+
newline: bool,
|
|
534
|
+
burnin: bool,
|
|
535
|
+
i_total: int,
|
|
536
|
+
n_iters: int,
|
|
537
|
+
grow_prop_count: int,
|
|
538
|
+
grow_acc_count: int,
|
|
539
|
+
prune_prop_count: int,
|
|
540
|
+
prune_acc_count: int,
|
|
541
|
+
prop_total: int,
|
|
542
|
+
fill: float,
|
|
454
543
|
):
|
|
455
|
-
|
|
456
|
-
newline = '\n' if newline else ''
|
|
544
|
+
"""Print the report for `print_callback`."""
|
|
457
545
|
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
546
|
+
def acc_string(acc_count, prop_count):
|
|
547
|
+
if prop_count:
|
|
548
|
+
return f'{acc_count / prop_count:.0%}'
|
|
549
|
+
else:
|
|
550
|
+
return 'n/d'
|
|
463
551
|
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
552
|
+
grow_prop = grow_prop_count / prop_total
|
|
553
|
+
prune_prop = prune_prop_count / prop_total
|
|
554
|
+
grow_acc = acc_string(grow_acc_count, grow_prop_count)
|
|
555
|
+
prune_acc = acc_string(prune_acc_count, prune_prop_count)
|
|
556
|
+
|
|
557
|
+
prefix = '\n' if newline else ''
|
|
558
|
+
suffix = ' (burnin)' if burnin else ''
|
|
559
|
+
|
|
560
|
+
print( # noqa: T201, see print_callback for why not logging
|
|
561
|
+
f'{prefix}It {i_total + 1}/{n_iters} '
|
|
562
|
+
f'grow P={grow_prop:.0%} A={grow_acc}, '
|
|
563
|
+
f'prune P={prune_prop:.0%} A={prune_acc}, '
|
|
564
|
+
f'fill={fill:.0%}{suffix}'
|
|
565
|
+
)
|
|
468
566
|
|
|
469
|
-
if burnin:
|
|
470
|
-
flag = ' (burnin)'
|
|
471
|
-
elif overflow:
|
|
472
|
-
flag = ' (overflow)'
|
|
473
|
-
else:
|
|
474
|
-
flag = ''
|
|
475
567
|
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
568
|
+
class SparseCallbackState(Module):
|
|
569
|
+
"""State for `sparse_callback`.
|
|
570
|
+
|
|
571
|
+
Parameters
|
|
572
|
+
----------
|
|
573
|
+
sparse_on_at
|
|
574
|
+
If specified, variable selection is activated starting from this
|
|
575
|
+
iteration. If `None`, variable selection is not used.
|
|
576
|
+
"""
|
|
577
|
+
|
|
578
|
+
sparse_on_at: Int32[Array, ''] | None
|
|
579
|
+
|
|
580
|
+
|
|
581
|
+
def sparse_callback(
|
|
582
|
+
*,
|
|
583
|
+
key: Key[Array, ''],
|
|
584
|
+
bart: State,
|
|
585
|
+
i_total: Int32[Array, ''],
|
|
586
|
+
callback_state: SparseCallbackState,
|
|
587
|
+
**_,
|
|
588
|
+
):
|
|
589
|
+
"""Perform variable selection, see `mcmcstep.step_sparse`."""
|
|
590
|
+
if callback_state.sparse_on_at is not None:
|
|
591
|
+
bart = lax.cond(
|
|
592
|
+
i_total < callback_state.sparse_on_at,
|
|
593
|
+
lambda: bart,
|
|
594
|
+
lambda: mcmcstep.step_sparse(key, bart),
|
|
481
595
|
)
|
|
596
|
+
return bart, callback_state
|
|
597
|
+
|
|
598
|
+
|
|
599
|
+
class Trace(grove.TreeHeaps, Protocol):
|
|
600
|
+
"""Protocol for an MCMC trace."""
|
|
601
|
+
|
|
602
|
+
offset: Float32[Array, ' trace_length']
|
|
603
|
+
|
|
604
|
+
|
|
605
|
+
class TreesTrace(Module):
|
|
606
|
+
"""Implementation of `bartz.grove.TreeHeaps` for an MCMC trace."""
|
|
607
|
+
|
|
608
|
+
leaf_tree: Float32[Array, 'trace_length num_trees 2**d']
|
|
609
|
+
var_tree: UInt[Array, 'trace_length num_trees 2**(d-1)']
|
|
610
|
+
split_tree: UInt[Array, 'trace_length num_trees 2**(d-1)']
|
|
611
|
+
|
|
612
|
+
@classmethod
|
|
613
|
+
def from_dataclass(cls, obj: grove.TreeHeaps):
|
|
614
|
+
"""Create a `TreesTrace` from any `bartz.grove.TreeHeaps`."""
|
|
615
|
+
return cls(**{f.name: getattr(obj, f.name) for f in fields(cls)})
|
|
482
616
|
|
|
483
617
|
|
|
484
618
|
@jax.jit
|
|
485
|
-
def evaluate_trace(
|
|
619
|
+
def evaluate_trace(
|
|
620
|
+
trace: Trace, X: UInt[Array, 'p n']
|
|
621
|
+
) -> Float32[Array, 'trace_length n']:
|
|
486
622
|
"""
|
|
487
623
|
Compute predictions for all iterations of the BART MCMC.
|
|
488
624
|
|
|
489
625
|
Parameters
|
|
490
626
|
----------
|
|
491
|
-
trace
|
|
627
|
+
trace
|
|
492
628
|
A trace of the BART MCMC, as returned by `run_mcmc`.
|
|
493
|
-
X
|
|
629
|
+
X
|
|
494
630
|
The predictors matrix, with `p` predictors and `n` observations.
|
|
495
631
|
|
|
496
632
|
Returns
|
|
497
633
|
-------
|
|
498
|
-
|
|
499
|
-
The predictions for each iteration of the MCMC.
|
|
634
|
+
The predictions for each iteration of the MCMC.
|
|
500
635
|
"""
|
|
501
|
-
evaluate_trees =
|
|
502
|
-
evaluate_trees = jaxext.autobatch(evaluate_trees, 2**29, (None, 0
|
|
636
|
+
evaluate_trees = partial(grove.evaluate_forest, sum_trees=False)
|
|
637
|
+
evaluate_trees = jaxext.autobatch(evaluate_trees, 2**29, (None, 0))
|
|
638
|
+
trees = TreesTrace.from_dataclass(trace)
|
|
503
639
|
|
|
504
|
-
def loop(_,
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
)
|
|
508
|
-
return None, row['offset'] + jnp.sum(values, axis=0, dtype=jnp.float32)
|
|
640
|
+
def loop(_, item):
|
|
641
|
+
offset, trees = item
|
|
642
|
+
values = evaluate_trees(X, trees)
|
|
643
|
+
return None, offset + jnp.sum(values, axis=0, dtype=jnp.float32)
|
|
509
644
|
|
|
510
|
-
_, y = lax.scan(loop, None, trace)
|
|
645
|
+
_, y = lax.scan(loop, None, (trace.offset, trees))
|
|
511
646
|
return y
|
|
647
|
+
|
|
648
|
+
|
|
649
|
+
@partial(jax.jit, static_argnums=(0,))
|
|
650
|
+
def compute_varcount(
|
|
651
|
+
p: int, trace: grove.TreeHeaps
|
|
652
|
+
) -> Int32[Array, 'trace_length {p}']:
|
|
653
|
+
"""
|
|
654
|
+
Count how many times each predictor is used in each MCMC state.
|
|
655
|
+
|
|
656
|
+
Parameters
|
|
657
|
+
----------
|
|
658
|
+
p
|
|
659
|
+
The number of predictors.
|
|
660
|
+
trace
|
|
661
|
+
A trace of the BART MCMC, as returned by `run_mcmc`.
|
|
662
|
+
|
|
663
|
+
Returns
|
|
664
|
+
-------
|
|
665
|
+
Histogram of predictor usage in each MCMC state.
|
|
666
|
+
"""
|
|
667
|
+
vmapped_var_histogram = jax.vmap(grove.var_histogram, in_axes=(None, 0, 0))
|
|
668
|
+
return vmapped_var_histogram(p, trace.var_tree, trace.split_tree)
|