bartz 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/.DS_Store +0 -0
- bartz/BART/__init__.py +27 -0
- bartz/BART/_gbart.py +522 -0
- bartz/__init__.py +6 -4
- bartz/_interface.py +937 -0
- bartz/_profiler.py +318 -0
- bartz/_version.py +1 -1
- bartz/debug.py +1217 -82
- bartz/grove.py +205 -103
- bartz/jaxext/__init__.py +287 -0
- bartz/jaxext/_autobatch.py +444 -0
- bartz/jaxext/scipy/__init__.py +25 -0
- bartz/jaxext/scipy/special.py +239 -0
- bartz/jaxext/scipy/stats.py +36 -0
- bartz/mcmcloop.py +662 -314
- bartz/mcmcstep/__init__.py +35 -0
- bartz/mcmcstep/_moves.py +904 -0
- bartz/mcmcstep/_state.py +1114 -0
- bartz/mcmcstep/_step.py +1603 -0
- bartz/prepcovars.py +140 -44
- bartz/testing/__init__.py +29 -0
- bartz/testing/_dgp.py +442 -0
- {bartz-0.6.0.dist-info → bartz-0.8.0.dist-info}/METADATA +18 -13
- bartz-0.8.0.dist-info/RECORD +25 -0
- {bartz-0.6.0.dist-info → bartz-0.8.0.dist-info}/WHEEL +1 -1
- bartz/BART.py +0 -603
- bartz/jaxext.py +0 -423
- bartz/mcmcstep.py +0 -2335
- bartz-0.6.0.dist-info/RECORD +0 -13
bartz/mcmcloop.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
# bartz/src/bartz/mcmcloop.py
|
|
2
2
|
#
|
|
3
|
-
# Copyright (c) 2024-
|
|
3
|
+
# Copyright (c) 2024-2026, The Bartz Contributors
|
|
4
4
|
#
|
|
5
5
|
# This file is part of bartz.
|
|
6
6
|
#
|
|
@@ -22,268 +22,416 @@
|
|
|
22
22
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
23
23
|
# SOFTWARE.
|
|
24
24
|
|
|
25
|
-
"""Functions that implement the full BART posterior MCMC loop.
|
|
25
|
+
"""Functions that implement the full BART posterior MCMC loop.
|
|
26
26
|
|
|
27
|
-
|
|
27
|
+
The entry points are `run_mcmc` and `make_default_callback`.
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
from collections.abc import Callable
|
|
31
|
+
from dataclasses import fields
|
|
32
|
+
from functools import partial, wraps
|
|
33
|
+
from math import floor
|
|
34
|
+
from typing import Any, Protocol
|
|
28
35
|
|
|
29
36
|
import jax
|
|
30
37
|
import numpy
|
|
31
|
-
from
|
|
38
|
+
from equinox import Module
|
|
39
|
+
from jax import (
|
|
40
|
+
NamedSharding,
|
|
41
|
+
ShapeDtypeStruct,
|
|
42
|
+
debug,
|
|
43
|
+
device_put,
|
|
44
|
+
eval_shape,
|
|
45
|
+
jit,
|
|
46
|
+
tree,
|
|
47
|
+
)
|
|
32
48
|
from jax import numpy as jnp
|
|
33
|
-
from
|
|
49
|
+
from jax.nn import softmax
|
|
50
|
+
from jax.sharding import Mesh, PartitionSpec
|
|
51
|
+
from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, PyTree, Shaped, UInt
|
|
52
|
+
|
|
53
|
+
from bartz import jaxext, mcmcstep
|
|
54
|
+
from bartz._profiler import (
|
|
55
|
+
cond_if_not_profiling,
|
|
56
|
+
get_profile_mode,
|
|
57
|
+
jit_if_not_profiling,
|
|
58
|
+
scan_if_not_profiling,
|
|
59
|
+
)
|
|
60
|
+
from bartz.grove import TreeHeaps, evaluate_forest, forest_fill, var_histogram
|
|
61
|
+
from bartz.jaxext import autobatch
|
|
62
|
+
from bartz.mcmcstep import State
|
|
63
|
+
from bartz.mcmcstep._state import chain_vmap_axes, field, get_axis_size, get_num_chains
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class BurninTrace(Module):
|
|
67
|
+
"""MCMC trace with only diagnostic values."""
|
|
68
|
+
|
|
69
|
+
error_cov_inv: (
|
|
70
|
+
Float32[Array, '*chains_and_samples']
|
|
71
|
+
| Float32[Array, '*chains_and_samples k k']
|
|
72
|
+
| None
|
|
73
|
+
) = field(chains=True)
|
|
74
|
+
theta: Float32[Array, '*chains_and_samples'] | None = field(chains=True)
|
|
75
|
+
grow_prop_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
|
|
76
|
+
grow_acc_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
|
|
77
|
+
prune_prop_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
|
|
78
|
+
prune_acc_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
|
|
79
|
+
log_likelihood: Float32[Array, '*chains_and_samples'] | None = field(chains=True)
|
|
80
|
+
log_trans_prior: Float32[Array, '*chains_and_samples'] | None = field(chains=True)
|
|
81
|
+
|
|
82
|
+
@classmethod
|
|
83
|
+
def from_state(cls, state: State) -> 'BurninTrace':
|
|
84
|
+
"""Create a single-item burn-in trace from a MCMC state."""
|
|
85
|
+
return cls(
|
|
86
|
+
error_cov_inv=state.error_cov_inv,
|
|
87
|
+
theta=state.forest.theta,
|
|
88
|
+
grow_prop_count=state.forest.grow_prop_count,
|
|
89
|
+
grow_acc_count=state.forest.grow_acc_count,
|
|
90
|
+
prune_prop_count=state.forest.prune_prop_count,
|
|
91
|
+
prune_acc_count=state.forest.prune_acc_count,
|
|
92
|
+
log_likelihood=state.forest.log_likelihood,
|
|
93
|
+
log_trans_prior=state.forest.log_trans_prior,
|
|
94
|
+
)
|
|
34
95
|
|
|
35
|
-
from . import grove, jaxext, mcmcstep
|
|
36
|
-
from .mcmcstep import State
|
|
37
96
|
|
|
97
|
+
class MainTrace(BurninTrace):
|
|
98
|
+
"""MCMC trace with trees and diagnostic values."""
|
|
99
|
+
|
|
100
|
+
leaf_tree: (
|
|
101
|
+
Float32[Array, '*chains_and_samples 2**d']
|
|
102
|
+
| Float32[Array, '*chains_and_samples k 2**d']
|
|
103
|
+
) = field(chains=True)
|
|
104
|
+
var_tree: UInt[Array, '*chains_and_samples 2**(d-1)'] = field(chains=True)
|
|
105
|
+
split_tree: UInt[Array, '*chains_and_samples 2**(d-1)'] = field(chains=True)
|
|
106
|
+
offset: Float32[Array, '*samples'] | Float32[Array, '*samples k']
|
|
107
|
+
varprob: Float32[Array, '*chains_and_samples p'] | None = field(chains=True)
|
|
108
|
+
|
|
109
|
+
@classmethod
|
|
110
|
+
def from_state(cls, state: State) -> 'MainTrace':
|
|
111
|
+
"""Create a single-item main trace from a MCMC state."""
|
|
112
|
+
# compute varprob
|
|
113
|
+
log_s = state.forest.log_s
|
|
114
|
+
if log_s is None:
|
|
115
|
+
varprob = None
|
|
116
|
+
else:
|
|
117
|
+
varprob = softmax(log_s, where=state.forest.max_split.astype(bool))
|
|
118
|
+
|
|
119
|
+
return cls(
|
|
120
|
+
leaf_tree=state.forest.leaf_tree,
|
|
121
|
+
var_tree=state.forest.var_tree,
|
|
122
|
+
split_tree=state.forest.split_tree,
|
|
123
|
+
offset=state.offset,
|
|
124
|
+
varprob=varprob,
|
|
125
|
+
**vars(BurninTrace.from_state(state)),
|
|
126
|
+
)
|
|
38
127
|
|
|
39
|
-
def default_onlymain_extractor(state: State) -> dict[str, Real[Array, 'samples *']]:
|
|
40
|
-
"""Extract variables for the main trace, to be used in `run_mcmc`."""
|
|
41
|
-
return dict(
|
|
42
|
-
leaf_trees=state.forest.leaf_trees,
|
|
43
|
-
var_trees=state.forest.var_trees,
|
|
44
|
-
split_trees=state.forest.split_trees,
|
|
45
|
-
offset=state.offset,
|
|
46
|
-
)
|
|
47
128
|
|
|
129
|
+
CallbackState = PyTree[Any, 'T']
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
class Callback(Protocol):
|
|
133
|
+
"""Callback type for `run_mcmc`."""
|
|
134
|
+
|
|
135
|
+
def __call__(
|
|
136
|
+
self,
|
|
137
|
+
*,
|
|
138
|
+
key: Key[Array, ''],
|
|
139
|
+
bart: State,
|
|
140
|
+
burnin: Bool[Array, ''],
|
|
141
|
+
i_total: Int32[Array, ''],
|
|
142
|
+
i_skip: Int32[Array, ''],
|
|
143
|
+
callback_state: CallbackState,
|
|
144
|
+
n_burn: Int32[Array, ''],
|
|
145
|
+
n_save: Int32[Array, ''],
|
|
146
|
+
n_skip: Int32[Array, ''],
|
|
147
|
+
i_outer: Int32[Array, ''],
|
|
148
|
+
inner_loop_length: int,
|
|
149
|
+
) -> tuple[State, CallbackState] | None:
|
|
150
|
+
"""Do an arbitrary action after an iteration of the MCMC.
|
|
151
|
+
|
|
152
|
+
Parameters
|
|
153
|
+
----------
|
|
154
|
+
key
|
|
155
|
+
A key for random number generation.
|
|
156
|
+
bart
|
|
157
|
+
The MCMC state just after updating it.
|
|
158
|
+
burnin
|
|
159
|
+
Whether the last iteration was in the burn-in phase.
|
|
160
|
+
i_total
|
|
161
|
+
The index of the last MCMC iteration (0-based).
|
|
162
|
+
i_skip
|
|
163
|
+
The number of MCMC updates from the last saved state. The initial
|
|
164
|
+
state counts as saved, even if it's not copied into the trace.
|
|
165
|
+
callback_state
|
|
166
|
+
The callback state, initially set to the argument passed to
|
|
167
|
+
`run_mcmc`, afterwards to the value returned by the last invocation
|
|
168
|
+
of the callback.
|
|
169
|
+
n_burn
|
|
170
|
+
n_save
|
|
171
|
+
n_skip
|
|
172
|
+
The corresponding `run_mcmc` arguments as-is.
|
|
173
|
+
i_outer
|
|
174
|
+
The index of the last outer loop iteration (0-based).
|
|
175
|
+
inner_loop_length
|
|
176
|
+
The number of MCMC iterations in the inner loop.
|
|
177
|
+
|
|
178
|
+
Returns
|
|
179
|
+
-------
|
|
180
|
+
bart : State
|
|
181
|
+
A possibly modified MCMC state. To avoid modifying the state,
|
|
182
|
+
return the `bart` argument passed to the callback as-is.
|
|
183
|
+
callback_state : CallbackState
|
|
184
|
+
The new state to be passed on the next callback invocation.
|
|
48
185
|
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
186
|
+
Notes
|
|
187
|
+
-----
|
|
188
|
+
For convenience, the callback may return `None`, and the states won't
|
|
189
|
+
be updated.
|
|
190
|
+
"""
|
|
191
|
+
...
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
class _Carry(Module):
|
|
195
|
+
"""Carry used in the loop in `run_mcmc`."""
|
|
196
|
+
|
|
197
|
+
bart: State
|
|
198
|
+
i_total: Int32[Array, '']
|
|
199
|
+
key: Key[Array, '']
|
|
200
|
+
burnin_trace: PyTree[
|
|
201
|
+
Shaped[Array, 'n_burn ...'] | Shaped[Array, 'num_chains n_burn ...']
|
|
202
|
+
]
|
|
203
|
+
main_trace: PyTree[
|
|
204
|
+
Shaped[Array, 'n_save ...'] | Shaped[Array, 'num_chains n_save ...']
|
|
205
|
+
]
|
|
206
|
+
callback_state: CallbackState
|
|
60
207
|
|
|
61
208
|
|
|
62
209
|
def run_mcmc(
|
|
63
|
-
key,
|
|
64
|
-
bart,
|
|
65
|
-
n_save,
|
|
210
|
+
key: Key[Array, ''],
|
|
211
|
+
bart: State,
|
|
212
|
+
n_save: int,
|
|
66
213
|
*,
|
|
67
|
-
n_burn=0,
|
|
68
|
-
n_skip=1,
|
|
69
|
-
inner_loop_length=None,
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
214
|
+
n_burn: int = 0,
|
|
215
|
+
n_skip: int = 1,
|
|
216
|
+
inner_loop_length: int | None = None,
|
|
217
|
+
callback: Callback | None = None,
|
|
218
|
+
callback_state: CallbackState = None,
|
|
219
|
+
burnin_extractor: Callable[[State], PyTree] = BurninTrace.from_state,
|
|
220
|
+
main_extractor: Callable[[State], PyTree] = MainTrace.from_state,
|
|
221
|
+
) -> tuple[
|
|
222
|
+
State,
|
|
223
|
+
PyTree[Shaped[Array, 'n_burn ...'] | Shaped[Array, 'num_chains n_burn ...']],
|
|
224
|
+
PyTree[Shaped[Array, 'n_save ...'] | Shaped[Array, 'num_chains n_save ...']],
|
|
225
|
+
]:
|
|
77
226
|
"""
|
|
78
227
|
Run the MCMC for the BART posterior.
|
|
79
228
|
|
|
80
229
|
Parameters
|
|
81
230
|
----------
|
|
82
|
-
key
|
|
231
|
+
key
|
|
83
232
|
A key for random number generation.
|
|
84
|
-
bart
|
|
233
|
+
bart
|
|
85
234
|
The initial MCMC state, as created and updated by the functions in
|
|
86
235
|
`bartz.mcmcstep`. The MCMC loop uses buffer donation to avoid copies,
|
|
87
236
|
so this variable is invalidated after running `run_mcmc`. Make a copy
|
|
88
237
|
beforehand to use it again.
|
|
89
|
-
n_save
|
|
238
|
+
n_save
|
|
90
239
|
The number of iterations to save.
|
|
91
|
-
n_burn
|
|
240
|
+
n_burn
|
|
92
241
|
The number of initial iterations which are not saved.
|
|
93
|
-
n_skip
|
|
242
|
+
n_skip
|
|
94
243
|
The number of iterations to skip between each saved iteration, plus 1.
|
|
95
244
|
The effective burn-in is ``n_burn + n_skip - 1``.
|
|
96
|
-
inner_loop_length
|
|
245
|
+
inner_loop_length
|
|
97
246
|
The MCMC loop is split into an outer and an inner loop. The outer loop
|
|
98
247
|
is in Python, while the inner loop is in JAX. `inner_loop_length` is the
|
|
99
248
|
number of iterations of the inner loop to run for each iteration of the
|
|
100
249
|
outer loop. If not specified, the outer loop will iterate just once,
|
|
101
250
|
with all iterations done in a single inner loop run. The inner stride is
|
|
102
251
|
unrelated to the stride used for saving the trace.
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
burnin : bool
|
|
118
|
-
Whether the last iteration was in the burn-in phase.
|
|
119
|
-
overflow : bool
|
|
120
|
-
Whether the last iteration was in the overflow phase (iterations
|
|
121
|
-
not saved due to `inner_loop_length` not being a divisor of the
|
|
122
|
-
total number of iterations).
|
|
123
|
-
i_total : int
|
|
124
|
-
The index of the last MCMC iteration (0-based).
|
|
125
|
-
i_skip : int
|
|
126
|
-
The number of MCMC updates from the last saved state. The initial
|
|
127
|
-
state counts as saved, even if it's not copied into the trace.
|
|
128
|
-
callback_state : jax pytree
|
|
129
|
-
The callback state, initially set to the argument passed to
|
|
130
|
-
`run_mcmc`, afterwards to the value returned by the last invocation
|
|
131
|
-
of `inner_callback` or `outer_callback`.
|
|
132
|
-
n_burn, n_save, n_skip : int
|
|
133
|
-
The corresponding arguments as-is.
|
|
134
|
-
i_outer : int
|
|
135
|
-
The index of the last outer loop iteration (0-based).
|
|
136
|
-
inner_loop_length : int
|
|
137
|
-
The number of MCMC iterations in the inner loop.
|
|
138
|
-
|
|
139
|
-
`inner_callback` is called under the jax jit, so the argument values are
|
|
140
|
-
not available at the time the Python code is executed. Use the utilities
|
|
141
|
-
in `jax.debug` to access the values at actual runtime.
|
|
142
|
-
|
|
143
|
-
The callbacks must return two values:
|
|
144
|
-
|
|
145
|
-
bart : dict
|
|
146
|
-
A possibly modified MCMC state. To avoid modifying the state,
|
|
147
|
-
return the `bart` argument passed to the callback as-is.
|
|
148
|
-
callback_state : jax pytree
|
|
149
|
-
The new state to be passed on the next callback invocation.
|
|
150
|
-
|
|
151
|
-
For convenience, if a callback returns `None`, the states are not
|
|
152
|
-
updated.
|
|
153
|
-
callback_state : jax pytree, optional
|
|
154
|
-
The initial state for the callbacks.
|
|
155
|
-
onlymain_extractor : callable, optional
|
|
156
|
-
both_extractor : callable, optional
|
|
157
|
-
Functions that extract the variables to be saved respectively only in
|
|
158
|
-
the main trace and in both traces, given the MCMC state as argument.
|
|
159
|
-
Must return a pytree, and must be vmappable.
|
|
252
|
+
callback
|
|
253
|
+
An arbitrary function run during the loop after updating the state. For
|
|
254
|
+
the signature, see `Callback`. The callback is called under the jax jit,
|
|
255
|
+
so the argument values are not available at the time the Python code is
|
|
256
|
+
executed. Use the utilities in `jax.debug` to access the values at
|
|
257
|
+
actual runtime. The callback may return new values for the MCMC state
|
|
258
|
+
and the callback state.
|
|
259
|
+
callback_state
|
|
260
|
+
The initial custom state for the callback.
|
|
261
|
+
burnin_extractor
|
|
262
|
+
main_extractor
|
|
263
|
+
Functions that extract the variables to be saved respectively in the
|
|
264
|
+
burnin trace and main traces, given the MCMC state as argument. Must
|
|
265
|
+
return a pytree, and must be vmappable.
|
|
160
266
|
|
|
161
267
|
Returns
|
|
162
268
|
-------
|
|
163
|
-
bart :
|
|
269
|
+
bart : State
|
|
164
270
|
The final MCMC state.
|
|
165
|
-
burnin_trace :
|
|
166
|
-
The trace of the burn-in phase
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
'grow_acc_count', 'prune_prop_count', 'prune_acc_count' (or if specified
|
|
170
|
-
the fields in `tracevars_both`).
|
|
171
|
-
main_trace : dict of (n_save, ...) arrays
|
|
172
|
-
The trace of the main phase, containing the following subset of fields
|
|
173
|
-
from the `bart` dictionary, with an additional head index that runs over
|
|
174
|
-
MCMC iterations: 'leaf_trees', 'var_trees', 'split_trees' (or if
|
|
175
|
-
specified the fields in `tracevars_onlymain`), plus the fields in
|
|
176
|
-
`burnin_trace`.
|
|
271
|
+
burnin_trace : PyTree[Shaped[Array, 'n_burn *']]
|
|
272
|
+
The trace of the burn-in phase. For the default layout, see `BurninTrace`.
|
|
273
|
+
main_trace : PyTree[Shaped[Array, 'n_save *']]
|
|
274
|
+
The trace of the main phase. For the default layout, see `MainTrace`.
|
|
177
275
|
|
|
178
276
|
Raises
|
|
179
277
|
------
|
|
180
|
-
|
|
181
|
-
If `
|
|
182
|
-
|
|
278
|
+
RuntimeError
|
|
279
|
+
If `run_mcmc` detects it's being invoked in a `jit`-wrapped context and
|
|
280
|
+
with settings that would create unrolled loops in the trace.
|
|
183
281
|
|
|
184
282
|
Notes
|
|
185
283
|
-----
|
|
186
284
|
The number of MCMC updates is ``n_burn + n_skip * n_save``. The traces do
|
|
187
285
|
not include the initial state, and include the final state.
|
|
188
286
|
"""
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
trace_both = empty_trace(n_burn + n_save, bart, both_extractor)
|
|
194
|
-
trace_onlymain = empty_trace(n_save, bart, onlymain_extractor)
|
|
287
|
+
# create empty traces
|
|
288
|
+
burnin_trace = _empty_trace(n_burn, bart, burnin_extractor)
|
|
289
|
+
main_trace = _empty_trace(n_save, bart, main_extractor)
|
|
195
290
|
|
|
196
291
|
# determine number of iterations for inner and outer loops
|
|
197
292
|
n_iters = n_burn + n_skip * n_save
|
|
198
293
|
if inner_loop_length is None:
|
|
199
294
|
inner_loop_length = n_iters
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
295
|
+
if inner_loop_length:
|
|
296
|
+
n_outer = n_iters // inner_loop_length + bool(n_iters % inner_loop_length)
|
|
297
|
+
else:
|
|
298
|
+
n_outer = 1
|
|
299
|
+
# setting to 0 would make for a clean noop, but it's useful to keep the
|
|
300
|
+
# same code path for benchmarking and testing
|
|
301
|
+
|
|
302
|
+
# error if under jit and there are unrolled loops or profile mode is on
|
|
303
|
+
under_jit = not hasattr(jnp.empty(0), 'platform')
|
|
304
|
+
if under_jit and (n_outer > 1 or get_profile_mode()):
|
|
305
|
+
msg = (
|
|
306
|
+
'`run_mcmc` was called within a jit-compiled function and '
|
|
307
|
+
'there are either more than 1 outer loops or profile mode is active, '
|
|
308
|
+
'please either do not jit, set `inner_loop_length=None`, or disable '
|
|
309
|
+
'profile mode.'
|
|
310
|
+
)
|
|
311
|
+
raise RuntimeError(msg)
|
|
312
|
+
|
|
313
|
+
replicate = partial(_replicate, mesh=bart.config.mesh)
|
|
314
|
+
carry = _Carry(
|
|
315
|
+
bart,
|
|
316
|
+
replicate(jnp.int32(0)),
|
|
317
|
+
replicate(key),
|
|
318
|
+
burnin_trace,
|
|
319
|
+
main_trace,
|
|
320
|
+
callback_state,
|
|
321
|
+
)
|
|
322
|
+
_run_mcmc_inner_loop._fun.reset_call_counter() # noqa: SLF001
|
|
208
323
|
for i_outer in range(n_outer):
|
|
209
324
|
carry = _run_mcmc_inner_loop(
|
|
210
325
|
carry,
|
|
211
326
|
inner_loop_length,
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
327
|
+
callback,
|
|
328
|
+
burnin_extractor,
|
|
329
|
+
main_extractor,
|
|
215
330
|
n_burn,
|
|
216
331
|
n_save,
|
|
217
332
|
n_skip,
|
|
218
333
|
i_outer,
|
|
334
|
+
n_iters,
|
|
219
335
|
)
|
|
220
|
-
if outer_callback is not None:
|
|
221
|
-
bart, i_total, key, trace_both, trace_onlymain, callback_state = carry
|
|
222
|
-
i_total -= 1 # because i_total is updated at the end of the inner loop
|
|
223
|
-
i_skip = _compute_i_skip(i_total, n_burn, n_skip)
|
|
224
|
-
rt = outer_callback(
|
|
225
|
-
bart=bart,
|
|
226
|
-
burnin=i_total < n_burn,
|
|
227
|
-
overflow=i_total >= n_iters,
|
|
228
|
-
i_total=i_total,
|
|
229
|
-
i_skip=i_skip,
|
|
230
|
-
callback_state=callback_state,
|
|
231
|
-
n_burn=n_burn,
|
|
232
|
-
n_save=n_save,
|
|
233
|
-
n_skip=n_skip,
|
|
234
|
-
i_outer=i_outer,
|
|
235
|
-
inner_loop_length=inner_loop_length,
|
|
236
|
-
)
|
|
237
|
-
if rt is not None:
|
|
238
|
-
bart, callback_state = rt
|
|
239
|
-
i_total += 1
|
|
240
|
-
carry = (bart, i_total, key, trace_both, trace_onlymain, callback_state)
|
|
241
336
|
|
|
242
|
-
bart,
|
|
337
|
+
return carry.bart, carry.burnin_trace, carry.main_trace
|
|
243
338
|
|
|
244
|
-
burnin_trace = tree.map(lambda x: x[:n_burn, ...], trace_both)
|
|
245
|
-
main_trace = tree.map(lambda x: x[n_burn:, ...], trace_both)
|
|
246
|
-
main_trace.update(trace_onlymain)
|
|
247
339
|
|
|
248
|
-
|
|
340
|
+
def _replicate(x: Array, mesh: Mesh | None) -> Array:
|
|
341
|
+
if mesh is None:
|
|
342
|
+
return x
|
|
343
|
+
else:
|
|
344
|
+
return device_put(x, NamedSharding(mesh, PartitionSpec()))
|
|
249
345
|
|
|
250
346
|
|
|
251
|
-
|
|
347
|
+
@partial(jit, static_argnums=(0, 2))
|
|
348
|
+
def _empty_trace(
|
|
349
|
+
length: int, bart: State, extractor: Callable[[State], PyTree]
|
|
350
|
+
) -> PyTree:
|
|
351
|
+
num_chains = get_num_chains(bart)
|
|
352
|
+
if num_chains is None:
|
|
353
|
+
out_axes = 0
|
|
354
|
+
else:
|
|
355
|
+
example_output = eval_shape(extractor, bart)
|
|
356
|
+
chain_axes = chain_vmap_axes(example_output)
|
|
357
|
+
out_axes = tree.map(
|
|
358
|
+
lambda a: 0 if a is None else 1, chain_axes, is_leaf=lambda a: a is None
|
|
359
|
+
)
|
|
360
|
+
return jax.vmap(extractor, in_axes=None, out_axes=out_axes, axis_size=length)(bart)
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
@jit
|
|
364
|
+
def _compute_i_skip(
|
|
365
|
+
i_total: Int32[Array, ''], n_burn: Int32[Array, ''], n_skip: Int32[Array, '']
|
|
366
|
+
) -> Int32[Array, '']:
|
|
367
|
+
"""Compute the `i_skip` argument passed to `callback`."""
|
|
252
368
|
burnin = i_total < n_burn
|
|
253
369
|
return jnp.where(
|
|
254
370
|
burnin,
|
|
255
371
|
i_total + 1,
|
|
256
|
-
(i_total + 1) % n_skip
|
|
372
|
+
(i_total - n_burn + 1) % n_skip
|
|
373
|
+
+ jnp.where(i_total - n_burn + 1 < n_skip, n_burn, 0),
|
|
257
374
|
)
|
|
258
375
|
|
|
259
376
|
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
both_extractor,
|
|
267
|
-
n_burn,
|
|
268
|
-
n_save,
|
|
269
|
-
n_skip,
|
|
270
|
-
i_outer,
|
|
271
|
-
):
|
|
272
|
-
def loop(carry, _):
|
|
273
|
-
bart, i_total, key, trace_both, trace_onlymain, callback_state = carry
|
|
377
|
+
class _CallCounter:
|
|
378
|
+
"""Wrap a callable to check it's not called more than once."""
|
|
379
|
+
|
|
380
|
+
def __init__(self, func: Callable) -> None:
|
|
381
|
+
self.func = func
|
|
382
|
+
self.n_calls = 0
|
|
274
383
|
|
|
275
|
-
|
|
384
|
+
def reset_call_counter(self) -> None:
|
|
385
|
+
"""Reset the call counter."""
|
|
386
|
+
self.n_calls = 0
|
|
387
|
+
|
|
388
|
+
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
|
389
|
+
if self.n_calls and not get_profile_mode():
|
|
390
|
+
msg = (
|
|
391
|
+
'The inner loop of `run_mcmc` was traced more than once, '
|
|
392
|
+
'which indicates a double compilation of the MCMC code. This '
|
|
393
|
+
'probably depends on the input state having different type from the '
|
|
394
|
+
'output state. Check the input is in a format that is the '
|
|
395
|
+
'same jax would output, e.g., all arrays and scalars are jax '
|
|
396
|
+
'arrays, with the right shardings.'
|
|
397
|
+
)
|
|
398
|
+
raise RuntimeError(msg)
|
|
399
|
+
self.n_calls += 1
|
|
400
|
+
return self.func(*args, **kwargs)
|
|
401
|
+
|
|
402
|
+
|
|
403
|
+
@partial(jit_if_not_profiling, donate_argnums=(0,), static_argnums=(1, 2, 3, 4))
|
|
404
|
+
@_CallCounter
|
|
405
|
+
def _run_mcmc_inner_loop(
|
|
406
|
+
carry: _Carry,
|
|
407
|
+
inner_loop_length: int,
|
|
408
|
+
callback: Callback | None,
|
|
409
|
+
burnin_extractor: Callable[[State], PyTree],
|
|
410
|
+
main_extractor: Callable[[State], PyTree],
|
|
411
|
+
n_burn: Int32[Array, ''],
|
|
412
|
+
n_save: Int32[Array, ''],
|
|
413
|
+
n_skip: Int32[Array, ''],
|
|
414
|
+
i_outer: Int32[Array, ''],
|
|
415
|
+
n_iters: Int32[Array, ''],
|
|
416
|
+
) -> _Carry:
|
|
417
|
+
def loop_impl(carry: _Carry) -> _Carry:
|
|
418
|
+
"""Loop body to run if i_total < n_iters."""
|
|
419
|
+
# split random key
|
|
420
|
+
keys = jaxext.split(carry.key, 3)
|
|
276
421
|
key = keys.pop()
|
|
277
|
-
bart = mcmcstep.step(keys.pop(), bart)
|
|
278
422
|
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
423
|
+
# update state
|
|
424
|
+
bart = mcmcstep.step(keys.pop(), carry.bart)
|
|
425
|
+
|
|
426
|
+
# invoke callback
|
|
427
|
+
callback_state = carry.callback_state
|
|
428
|
+
if callback is not None:
|
|
429
|
+
i_skip = _compute_i_skip(carry.i_total, n_burn, n_skip)
|
|
430
|
+
rt = callback(
|
|
431
|
+
key=keys.pop(),
|
|
283
432
|
bart=bart,
|
|
284
|
-
burnin=
|
|
285
|
-
|
|
286
|
-
i_total=i_total,
|
|
433
|
+
burnin=carry.i_total < n_burn,
|
|
434
|
+
i_total=carry.i_total,
|
|
287
435
|
i_skip=i_skip,
|
|
288
436
|
callback_state=callback_state,
|
|
289
437
|
n_burn=n_burn,
|
|
@@ -295,137 +443,240 @@ def _run_mcmc_inner_loop(
|
|
|
295
443
|
if rt is not None:
|
|
296
444
|
bart, callback_state = rt
|
|
297
445
|
|
|
298
|
-
|
|
299
|
-
|
|
446
|
+
# save to trace
|
|
447
|
+
burnin_trace, main_trace = _save_state_to_trace(
|
|
448
|
+
carry.burnin_trace,
|
|
449
|
+
carry.main_trace,
|
|
450
|
+
burnin_extractor,
|
|
451
|
+
main_extractor,
|
|
452
|
+
bart,
|
|
453
|
+
carry.i_total,
|
|
454
|
+
n_burn,
|
|
455
|
+
n_skip,
|
|
456
|
+
)
|
|
300
457
|
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
return trace_array
|
|
458
|
+
return _Carry(
|
|
459
|
+
bart=bart,
|
|
460
|
+
i_total=carry.i_total + 1,
|
|
461
|
+
key=key,
|
|
462
|
+
burnin_trace=burnin_trace,
|
|
463
|
+
main_trace=main_trace,
|
|
464
|
+
callback_state=callback_state,
|
|
465
|
+
)
|
|
310
466
|
|
|
311
|
-
|
|
467
|
+
def loop_noop(carry: _Carry) -> _Carry:
|
|
468
|
+
"""Loop body to run if i_total >= n_iters; it does nothing."""
|
|
469
|
+
return carry
|
|
312
470
|
|
|
313
|
-
|
|
314
|
-
|
|
471
|
+
def loop(carry: _Carry, _) -> tuple[_Carry, None]:
|
|
472
|
+
carry = cond_if_not_profiling(
|
|
473
|
+
carry.i_total < n_iters, loop_impl, loop_noop, carry
|
|
315
474
|
)
|
|
316
|
-
trace_both = update_trace(i_both, trace_both, both_extractor(bart))
|
|
317
|
-
|
|
318
|
-
i_total += 1
|
|
319
|
-
carry = (bart, i_total, key, trace_both, trace_onlymain, callback_state)
|
|
320
475
|
return carry, None
|
|
321
476
|
|
|
322
|
-
carry, _ =
|
|
477
|
+
carry, _ = scan_if_not_profiling(loop, carry, None, inner_loop_length)
|
|
323
478
|
return carry
|
|
324
479
|
|
|
325
480
|
|
|
326
|
-
|
|
481
|
+
@partial(jit, donate_argnums=(0, 1), static_argnums=(2, 3))
|
|
482
|
+
# this is jitted because under profiling _run_mcmc_inner_loop and the loop
|
|
483
|
+
# within it are not, so I need the donate_argnums feature of jit to avoid
|
|
484
|
+
# creating copies of the traces
|
|
485
|
+
def _save_state_to_trace(
|
|
486
|
+
burnin_trace: PyTree,
|
|
487
|
+
main_trace: PyTree,
|
|
488
|
+
burnin_extractor: Callable[[State], PyTree],
|
|
489
|
+
main_extractor: Callable[[State], PyTree],
|
|
490
|
+
bart: State,
|
|
491
|
+
i_total: Int32[Array, ''],
|
|
492
|
+
n_burn: Int32[Array, ''],
|
|
493
|
+
n_skip: Int32[Array, ''],
|
|
494
|
+
) -> tuple[PyTree, PyTree]:
|
|
495
|
+
# trace index where to save during burnin; out-of-bounds => noop after
|
|
496
|
+
# burnin
|
|
497
|
+
burnin_idx = i_total
|
|
498
|
+
|
|
499
|
+
# trace index where to save during main phase; force it out-of-bounds
|
|
500
|
+
# during burnin
|
|
501
|
+
main_idx = (i_total - n_burn) // n_skip
|
|
502
|
+
noop_idx = jnp.iinfo(jnp.int32).max
|
|
503
|
+
noop_cond = i_total < n_burn
|
|
504
|
+
main_idx = jnp.where(noop_cond, noop_idx, main_idx)
|
|
505
|
+
|
|
506
|
+
# prepare array index
|
|
507
|
+
num_chains = get_num_chains(bart)
|
|
508
|
+
burnin_trace = _set(burnin_trace, burnin_idx, burnin_extractor(bart), num_chains)
|
|
509
|
+
main_trace = _set(main_trace, main_idx, main_extractor(bart), num_chains)
|
|
510
|
+
|
|
511
|
+
return burnin_trace, main_trace
|
|
512
|
+
|
|
513
|
+
|
|
514
|
+
def _set(
|
|
515
|
+
trace: PyTree[Array, ' T'],
|
|
516
|
+
index: Int32[Array, ''],
|
|
517
|
+
val: PyTree[Array, ' T'],
|
|
518
|
+
num_chains: int | None,
|
|
519
|
+
) -> PyTree[Array, ' T']:
|
|
520
|
+
"""Do ``trace[index] = val`` but fancier."""
|
|
521
|
+
chain_axis = chain_vmap_axes(val)
|
|
522
|
+
|
|
523
|
+
def at_set(
|
|
524
|
+
trace: Shaped[Array, 'chains samples *shape']
|
|
525
|
+
| Shaped[Array, ' samples *shape']
|
|
526
|
+
| None,
|
|
527
|
+
val: Shaped[Array, ' chains *shape'] | Shaped[Array, '*shape'] | None,
|
|
528
|
+
chain_axis: int | None,
|
|
529
|
+
):
|
|
530
|
+
if trace is None or trace.size == 0:
|
|
531
|
+
# this handles the case where an array is empty because jax refuses
|
|
532
|
+
# to index into an axis of length 0, even if just in the abstract,
|
|
533
|
+
# and optional elements that are considered leaves due to `is_leaf`
|
|
534
|
+
# below needed to traverse `chain_axis`.
|
|
535
|
+
return trace
|
|
536
|
+
|
|
537
|
+
if num_chains is None or chain_axis is None:
|
|
538
|
+
ndindex = (index, ...)
|
|
539
|
+
else:
|
|
540
|
+
ndindex = (slice(None), index, ...)
|
|
541
|
+
|
|
542
|
+
return trace.at[ndindex].set(val, mode='drop')
|
|
543
|
+
|
|
544
|
+
return tree.map(at_set, trace, val, chain_axis, is_leaf=lambda x: x is None)
|
|
545
|
+
|
|
546
|
+
|
|
547
|
+
def make_default_callback(
|
|
548
|
+
state: State,
|
|
549
|
+
*,
|
|
550
|
+
dot_every: int | Integer[Array, ''] | None = 1,
|
|
551
|
+
report_every: int | Integer[Array, ''] | None = 100,
|
|
552
|
+
) -> dict[str, Any]:
|
|
327
553
|
"""
|
|
328
|
-
Prepare
|
|
554
|
+
Prepare a default callback for `run_mcmc`.
|
|
329
555
|
|
|
330
|
-
|
|
331
|
-
report outer loop iteration.
|
|
556
|
+
The callback prints a dot on every iteration, and a longer
|
|
557
|
+
report outer loop iteration, and can do variable selection.
|
|
332
558
|
|
|
333
559
|
Parameters
|
|
334
560
|
----------
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
iterations.
|
|
561
|
+
state
|
|
562
|
+
The bart state to use the callback with, used to determine device
|
|
563
|
+
sharding.
|
|
564
|
+
dot_every
|
|
565
|
+
A dot is printed every `dot_every` MCMC iterations, `None` to disable.
|
|
566
|
+
report_every
|
|
567
|
+
A one line report is printed every `report_every` MCMC iterations,
|
|
568
|
+
`None` to disable.
|
|
340
569
|
|
|
341
570
|
Returns
|
|
342
571
|
-------
|
|
343
|
-
|
|
344
|
-
A dictionary with the arguments to pass to `run_mcmc` as keyword
|
|
345
|
-
arguments to set up the callbacks.
|
|
572
|
+
A dictionary with the arguments to pass to `run_mcmc` as keyword arguments to set up the callback.
|
|
346
573
|
|
|
347
574
|
Examples
|
|
348
575
|
--------
|
|
349
|
-
>>> run_mcmc(..., **
|
|
576
|
+
>>> run_mcmc(key, state, ..., **make_default_callback(state, ...))
|
|
350
577
|
"""
|
|
578
|
+
|
|
579
|
+
def as_replicated_array_or_none(val: None | Any) -> None | Array:
|
|
580
|
+
return None if val is None else _replicate(jnp.asarray(val), state.config.mesh)
|
|
581
|
+
|
|
351
582
|
return dict(
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
583
|
+
callback=print_callback,
|
|
584
|
+
callback_state=PrintCallbackState(
|
|
585
|
+
as_replicated_array_or_none(dot_every),
|
|
586
|
+
as_replicated_array_or_none(report_every),
|
|
356
587
|
),
|
|
357
588
|
)
|
|
358
589
|
|
|
359
590
|
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
if dot_every_inner is not None:
|
|
363
|
-
cond = (i_total + 1) % dot_every_inner == 0
|
|
364
|
-
debug.callback(_print_dot, cond)
|
|
591
|
+
class PrintCallbackState(Module):
|
|
592
|
+
"""State for `print_callback`."""
|
|
365
593
|
|
|
594
|
+
dot_every: Int32[Array, ''] | None
|
|
595
|
+
"""A dot is printed every `dot_every` MCMC iterations, `None` to disable."""
|
|
366
596
|
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
597
|
+
report_every: Int32[Array, ''] | None
|
|
598
|
+
"""A one line report is printed every `report_every` MCMC iterations,
|
|
599
|
+
`None` to disable."""
|
|
370
600
|
|
|
371
601
|
|
|
372
|
-
def
|
|
602
|
+
def print_callback(
|
|
373
603
|
*,
|
|
374
|
-
bart,
|
|
375
|
-
burnin,
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
callback_state,
|
|
382
|
-
i_outer,
|
|
383
|
-
inner_loop_length,
|
|
604
|
+
bart: State,
|
|
605
|
+
burnin: Bool[Array, ''],
|
|
606
|
+
i_total: Int32[Array, ''],
|
|
607
|
+
n_burn: Int32[Array, ''],
|
|
608
|
+
n_save: Int32[Array, ''],
|
|
609
|
+
n_skip: Int32[Array, ''],
|
|
610
|
+
callback_state: PrintCallbackState,
|
|
384
611
|
**_,
|
|
385
612
|
):
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
613
|
+
"""Print a dot and/or a report periodically during the MCMC."""
|
|
614
|
+
report_every = callback_state.report_every
|
|
615
|
+
dot_every = callback_state.dot_every
|
|
616
|
+
it = i_total + 1
|
|
617
|
+
|
|
618
|
+
def get_cond(every: Int32[Array, ''] | None) -> bool | Bool[Array, '']:
|
|
619
|
+
return False if every is None else it % every == 0
|
|
620
|
+
|
|
621
|
+
report_cond = get_cond(report_every)
|
|
622
|
+
dot_cond = get_cond(dot_every)
|
|
623
|
+
|
|
624
|
+
def line_report_branch():
|
|
625
|
+
if report_every is None:
|
|
626
|
+
return
|
|
627
|
+
if dot_every is None:
|
|
628
|
+
print_newline = False
|
|
391
629
|
else:
|
|
392
|
-
|
|
630
|
+
print_newline = it % report_every > it % dot_every
|
|
393
631
|
debug.callback(
|
|
394
632
|
_print_report,
|
|
395
|
-
|
|
396
|
-
|
|
633
|
+
print_dot=dot_cond,
|
|
634
|
+
print_newline=print_newline,
|
|
397
635
|
burnin=burnin,
|
|
398
|
-
|
|
399
|
-
i_total=i_total,
|
|
636
|
+
it=it,
|
|
400
637
|
n_iters=n_burn + n_save * n_skip,
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
prune_acc_count=bart.forest.prune_acc_count,
|
|
405
|
-
prop_total=
|
|
406
|
-
fill=
|
|
638
|
+
num_chains=bart.forest.num_chains(),
|
|
639
|
+
grow_prop_count=bart.forest.grow_prop_count.mean(),
|
|
640
|
+
grow_acc_count=bart.forest.grow_acc_count.mean(),
|
|
641
|
+
prune_acc_count=bart.forest.prune_acc_count.mean(),
|
|
642
|
+
prop_total=bart.forest.split_tree.shape[-2],
|
|
643
|
+
fill=forest_fill(bart.forest.split_tree),
|
|
407
644
|
)
|
|
408
645
|
|
|
646
|
+
def just_dot_branch():
|
|
647
|
+
if dot_every is None:
|
|
648
|
+
return
|
|
649
|
+
debug.callback(
|
|
650
|
+
lambda: print('.', end='', flush=True) # noqa: T201
|
|
651
|
+
)
|
|
652
|
+
# logging can't do in-line printing so we use print
|
|
653
|
+
|
|
654
|
+
cond_if_not_profiling(
|
|
655
|
+
report_cond,
|
|
656
|
+
line_report_branch,
|
|
657
|
+
lambda: cond_if_not_profiling(dot_cond, just_dot_branch, lambda: None),
|
|
658
|
+
)
|
|
409
659
|
|
|
410
|
-
|
|
660
|
+
|
|
661
|
+
def _convert_jax_arrays_in_args(func: Callable) -> Callable:
|
|
411
662
|
"""Remove jax arrays from a function arguments.
|
|
412
663
|
|
|
413
|
-
Converts all jax.Array instances in the arguments to either Python scalars
|
|
664
|
+
Converts all `jax.Array` instances in the arguments to either Python scalars
|
|
414
665
|
or numpy arrays.
|
|
415
666
|
"""
|
|
416
667
|
|
|
417
|
-
def convert_jax_arrays(pytree):
|
|
418
|
-
def
|
|
419
|
-
if not isinstance(val,
|
|
668
|
+
def convert_jax_arrays(pytree: PyTree) -> PyTree:
|
|
669
|
+
def convert_jax_array(val: Any) -> Any:
|
|
670
|
+
if not isinstance(val, Array):
|
|
420
671
|
return val
|
|
421
672
|
elif val.shape:
|
|
422
673
|
return numpy.array(val)
|
|
423
674
|
else:
|
|
424
675
|
return val.item()
|
|
425
676
|
|
|
426
|
-
return tree.map(
|
|
677
|
+
return tree.map(convert_jax_array, pytree)
|
|
427
678
|
|
|
428
|
-
@
|
|
679
|
+
@wraps(func)
|
|
429
680
|
def new_func(*args, **kw):
|
|
430
681
|
args = convert_jax_arrays(args)
|
|
431
682
|
kw = convert_jax_arrays(kw)
|
|
@@ -439,73 +690,170 @@ def _convert_jax_arrays_in_args(func):
|
|
|
439
690
|
# deadlock with the main thread
|
|
440
691
|
def _print_report(
|
|
441
692
|
*,
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
burnin,
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
grow_prop_count,
|
|
449
|
-
grow_acc_count,
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
fill,
|
|
693
|
+
print_dot: bool,
|
|
694
|
+
print_newline: bool,
|
|
695
|
+
burnin: bool,
|
|
696
|
+
it: int,
|
|
697
|
+
n_iters: int,
|
|
698
|
+
num_chains: int | None,
|
|
699
|
+
grow_prop_count: float,
|
|
700
|
+
grow_acc_count: float,
|
|
701
|
+
prune_acc_count: float,
|
|
702
|
+
prop_total: int,
|
|
703
|
+
fill: float,
|
|
454
704
|
):
|
|
455
|
-
|
|
456
|
-
|
|
705
|
+
"""Print the report for `print_callback`."""
|
|
706
|
+
# compute fractions
|
|
707
|
+
grow_prop = grow_prop_count / prop_total
|
|
708
|
+
move_acc = (grow_acc_count + prune_acc_count) / prop_total
|
|
709
|
+
|
|
710
|
+
# determine prefix
|
|
711
|
+
if print_dot:
|
|
712
|
+
prefix = '.\n'
|
|
713
|
+
elif print_newline:
|
|
714
|
+
prefix = '\n'
|
|
715
|
+
else:
|
|
716
|
+
prefix = ''
|
|
717
|
+
|
|
718
|
+
# determine suffix in parentheses
|
|
719
|
+
msgs = []
|
|
720
|
+
if num_chains is not None:
|
|
721
|
+
msgs.append(f'avg. {num_chains} chains')
|
|
722
|
+
if burnin:
|
|
723
|
+
msgs.append('burnin')
|
|
724
|
+
suffix = f' ({", ".join(msgs)})' if msgs else ''
|
|
725
|
+
|
|
726
|
+
print( # noqa: T201, see print_callback for why not logging
|
|
727
|
+
f'{prefix}Iteration {it}/{n_iters}, '
|
|
728
|
+
f'grow prob: {grow_prop:.0%}, '
|
|
729
|
+
f'move acc: {move_acc:.0%}, '
|
|
730
|
+
f'fill: {fill:.0%}{suffix}'
|
|
731
|
+
)
|
|
457
732
|
|
|
458
|
-
def acc_string(acc_count, prop_count):
|
|
459
|
-
if prop_count:
|
|
460
|
-
return f'{acc_count / prop_count:.0%}'
|
|
461
|
-
else:
|
|
462
|
-
return ' n/d'
|
|
463
733
|
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
grow_acc = acc_string(grow_acc_count, grow_prop_count)
|
|
467
|
-
prune_acc = acc_string(prune_acc_count, prune_prop_count)
|
|
734
|
+
class Trace(TreeHeaps, Protocol):
|
|
735
|
+
"""Protocol for a MCMC trace."""
|
|
468
736
|
|
|
469
|
-
|
|
470
|
-
flag = ' (burnin)'
|
|
471
|
-
elif overflow:
|
|
472
|
-
flag = ' (overflow)'
|
|
473
|
-
else:
|
|
474
|
-
flag = ''
|
|
737
|
+
offset: Float32[Array, '*trace_shape']
|
|
475
738
|
|
|
476
|
-
print(
|
|
477
|
-
f'{newline}It {i_total + 1}/{n_iters} '
|
|
478
|
-
f'grow P={grow_prop:.0%} A={grow_acc}, '
|
|
479
|
-
f'prune P={prune_prop:.0%} A={prune_acc}, '
|
|
480
|
-
f'fill={fill:.0%}{flag}'
|
|
481
|
-
)
|
|
482
739
|
|
|
740
|
+
class TreesTrace(Module):
|
|
741
|
+
"""Implementation of `bartz.grove.TreeHeaps` for an MCMC trace."""
|
|
742
|
+
|
|
743
|
+
leaf_tree: (
|
|
744
|
+
Float32[Array, '*trace_shape num_trees 2**d']
|
|
745
|
+
| Float32[Array, '*trace_shape num_trees k 2**d']
|
|
746
|
+
)
|
|
747
|
+
var_tree: UInt[Array, '*trace_shape num_trees 2**(d-1)']
|
|
748
|
+
split_tree: UInt[Array, '*trace_shape num_trees 2**(d-1)']
|
|
483
749
|
|
|
484
|
-
@
|
|
485
|
-
def
|
|
750
|
+
@classmethod
|
|
751
|
+
def from_dataclass(cls, obj: TreeHeaps):
|
|
752
|
+
"""Create a `TreesTrace` from any `bartz.grove.TreeHeaps`."""
|
|
753
|
+
return cls(**{f.name: getattr(obj, f.name) for f in fields(cls)})
|
|
754
|
+
|
|
755
|
+
|
|
756
|
+
@jit
|
|
757
|
+
def evaluate_trace(
|
|
758
|
+
X: UInt[Array, 'p n'], trace: Trace
|
|
759
|
+
) -> Float32[Array, '*trace_shape n'] | Float32[Array, '*trace_shape k n']:
|
|
486
760
|
"""
|
|
487
761
|
Compute predictions for all iterations of the BART MCMC.
|
|
488
762
|
|
|
489
763
|
Parameters
|
|
490
764
|
----------
|
|
491
|
-
|
|
492
|
-
A trace of the BART MCMC, as returned by `run_mcmc`.
|
|
493
|
-
X : array (p, n)
|
|
765
|
+
X
|
|
494
766
|
The predictors matrix, with `p` predictors and `n` observations.
|
|
767
|
+
trace
|
|
768
|
+
A main trace of the BART MCMC, as returned by `run_mcmc`.
|
|
495
769
|
|
|
496
770
|
Returns
|
|
497
771
|
-------
|
|
498
|
-
|
|
499
|
-
The predictions for each iteration of the MCMC.
|
|
772
|
+
The predictions for each chain and iteration of the MCMC.
|
|
500
773
|
"""
|
|
501
|
-
|
|
502
|
-
|
|
774
|
+
# per-device memory limit
|
|
775
|
+
max_io_nbytes = 2**27 # 128 MiB
|
|
776
|
+
|
|
777
|
+
# adjust memory limit for number of devices
|
|
778
|
+
mesh = jax.typeof(trace.leaf_tree).sharding.mesh
|
|
779
|
+
num_devices = get_axis_size(mesh, 'chains') * get_axis_size(mesh, 'data')
|
|
780
|
+
max_io_nbytes *= num_devices
|
|
781
|
+
|
|
782
|
+
# determine batching axes
|
|
783
|
+
has_chains = trace.split_tree.ndim > 3 # chains, samples, trees, nodes
|
|
784
|
+
if has_chains:
|
|
785
|
+
sample_axis = 1
|
|
786
|
+
tree_axis = 2
|
|
787
|
+
else:
|
|
788
|
+
sample_axis = 0
|
|
789
|
+
tree_axis = 1
|
|
790
|
+
|
|
791
|
+
# batch and sum over trees
|
|
792
|
+
batched_eval = autobatch(
|
|
793
|
+
evaluate_forest,
|
|
794
|
+
max_io_nbytes,
|
|
795
|
+
(None, tree_axis),
|
|
796
|
+
tree_axis,
|
|
797
|
+
reduce_ufunc=jnp.add,
|
|
798
|
+
)
|
|
503
799
|
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
800
|
+
# determine output shape (to avoid autobatch tracing everything 4 times)
|
|
801
|
+
is_mv = trace.leaf_tree.ndim > trace.split_tree.ndim
|
|
802
|
+
k = trace.leaf_tree.shape[-2] if is_mv else 1
|
|
803
|
+
mv_shape = (k,) if is_mv else ()
|
|
804
|
+
_, n = X.shape
|
|
805
|
+
out_shape = (*trace.split_tree.shape[:-2], *mv_shape, n)
|
|
806
|
+
|
|
807
|
+
# adjust memory limit keeping into account that trees are summed over
|
|
808
|
+
num_trees, hts = trace.split_tree.shape[-2:]
|
|
809
|
+
out_size = k * n * jnp.float32.dtype.itemsize # the value of the forest
|
|
810
|
+
core_io_size = (
|
|
811
|
+
num_trees
|
|
812
|
+
* hts
|
|
813
|
+
* (
|
|
814
|
+
2 * k * trace.leaf_tree.itemsize
|
|
815
|
+
+ trace.var_tree.itemsize
|
|
816
|
+
+ trace.split_tree.itemsize
|
|
507
817
|
)
|
|
508
|
-
|
|
818
|
+
+ out_size
|
|
819
|
+
)
|
|
820
|
+
core_int_size = (num_trees - 1) * out_size
|
|
821
|
+
max_io_nbytes = max(1, floor(max_io_nbytes / (1 + core_int_size / core_io_size)))
|
|
822
|
+
|
|
823
|
+
# batch over mcmc samples
|
|
824
|
+
batched_eval = autobatch(
|
|
825
|
+
batched_eval,
|
|
826
|
+
max_io_nbytes,
|
|
827
|
+
(None, sample_axis),
|
|
828
|
+
sample_axis,
|
|
829
|
+
warn_on_overflow=False, # the inner autobatch will handle it
|
|
830
|
+
result_shape_dtype=ShapeDtypeStruct(out_shape, jnp.float32),
|
|
831
|
+
)
|
|
832
|
+
|
|
833
|
+
# extract only the trees from the trace
|
|
834
|
+
trees = TreesTrace.from_dataclass(trace)
|
|
835
|
+
|
|
836
|
+
# evaluate trees
|
|
837
|
+
y_centered: Float32[Array, '*trace_shape n'] | Float32[Array, '*trace_shape k n']
|
|
838
|
+
y_centered = batched_eval(X, trees)
|
|
839
|
+
return y_centered + trace.offset[..., None]
|
|
840
|
+
|
|
841
|
+
|
|
842
|
+
@partial(jit, static_argnums=(0,))
|
|
843
|
+
def compute_varcount(p: int, trace: TreeHeaps) -> Int32[Array, '*trace_shape {p}']:
|
|
844
|
+
"""
|
|
845
|
+
Count how many times each predictor is used in each MCMC state.
|
|
509
846
|
|
|
510
|
-
|
|
511
|
-
|
|
847
|
+
Parameters
|
|
848
|
+
----------
|
|
849
|
+
p
|
|
850
|
+
The number of predictors.
|
|
851
|
+
trace
|
|
852
|
+
A main trace of the BART MCMC, as returned by `run_mcmc`.
|
|
853
|
+
|
|
854
|
+
Returns
|
|
855
|
+
-------
|
|
856
|
+
Histogram of predictor usage in each MCMC state.
|
|
857
|
+
"""
|
|
858
|
+
# var_tree has shape (chains? samples trees nodes)
|
|
859
|
+
return var_histogram(p, trace.var_tree, trace.split_tree, sum_batch_axis=-1)
|