bartz 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/mcmcloop.py CHANGED
@@ -1,6 +1,6 @@
1
1
  # bartz/src/bartz/mcmcloop.py
2
2
  #
3
- # Copyright (c) 2024-2025, Giacomo Petrillo
3
+ # Copyright (c) 2024-2026, The Bartz Contributors
4
4
  #
5
5
  # This file is part of bartz.
6
6
  #
@@ -22,268 +22,416 @@
22
22
  # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
23
  # SOFTWARE.
24
24
 
25
- """Functions that implement the full BART posterior MCMC loop."""
25
+ """Functions that implement the full BART posterior MCMC loop.
26
26
 
27
- import functools
27
+ The entry points are `run_mcmc` and `make_default_callback`.
28
+ """
29
+
30
+ from collections.abc import Callable
31
+ from dataclasses import fields
32
+ from functools import partial, wraps
33
+ from math import floor
34
+ from typing import Any, Protocol
28
35
 
29
36
  import jax
30
37
  import numpy
31
- from jax import debug, lax, tree
38
+ from equinox import Module
39
+ from jax import (
40
+ NamedSharding,
41
+ ShapeDtypeStruct,
42
+ debug,
43
+ device_put,
44
+ eval_shape,
45
+ jit,
46
+ tree,
47
+ )
32
48
  from jax import numpy as jnp
33
- from jaxtyping import Array, Real
49
+ from jax.nn import softmax
50
+ from jax.sharding import Mesh, PartitionSpec
51
+ from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, PyTree, Shaped, UInt
52
+
53
+ from bartz import jaxext, mcmcstep
54
+ from bartz._profiler import (
55
+ cond_if_not_profiling,
56
+ get_profile_mode,
57
+ jit_if_not_profiling,
58
+ scan_if_not_profiling,
59
+ )
60
+ from bartz.grove import TreeHeaps, evaluate_forest, forest_fill, var_histogram
61
+ from bartz.jaxext import autobatch
62
+ from bartz.mcmcstep import State
63
+ from bartz.mcmcstep._state import chain_vmap_axes, field, get_axis_size, get_num_chains
64
+
65
+
66
+ class BurninTrace(Module):
67
+ """MCMC trace with only diagnostic values."""
68
+
69
+ error_cov_inv: (
70
+ Float32[Array, '*chains_and_samples']
71
+ | Float32[Array, '*chains_and_samples k k']
72
+ | None
73
+ ) = field(chains=True)
74
+ theta: Float32[Array, '*chains_and_samples'] | None = field(chains=True)
75
+ grow_prop_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
76
+ grow_acc_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
77
+ prune_prop_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
78
+ prune_acc_count: Int32[Array, '*chains_and_samples'] = field(chains=True)
79
+ log_likelihood: Float32[Array, '*chains_and_samples'] | None = field(chains=True)
80
+ log_trans_prior: Float32[Array, '*chains_and_samples'] | None = field(chains=True)
81
+
82
+ @classmethod
83
+ def from_state(cls, state: State) -> 'BurninTrace':
84
+ """Create a single-item burn-in trace from a MCMC state."""
85
+ return cls(
86
+ error_cov_inv=state.error_cov_inv,
87
+ theta=state.forest.theta,
88
+ grow_prop_count=state.forest.grow_prop_count,
89
+ grow_acc_count=state.forest.grow_acc_count,
90
+ prune_prop_count=state.forest.prune_prop_count,
91
+ prune_acc_count=state.forest.prune_acc_count,
92
+ log_likelihood=state.forest.log_likelihood,
93
+ log_trans_prior=state.forest.log_trans_prior,
94
+ )
34
95
 
35
- from . import grove, jaxext, mcmcstep
36
- from .mcmcstep import State
37
96
 
97
+ class MainTrace(BurninTrace):
98
+ """MCMC trace with trees and diagnostic values."""
99
+
100
+ leaf_tree: (
101
+ Float32[Array, '*chains_and_samples 2**d']
102
+ | Float32[Array, '*chains_and_samples k 2**d']
103
+ ) = field(chains=True)
104
+ var_tree: UInt[Array, '*chains_and_samples 2**(d-1)'] = field(chains=True)
105
+ split_tree: UInt[Array, '*chains_and_samples 2**(d-1)'] = field(chains=True)
106
+ offset: Float32[Array, '*samples'] | Float32[Array, '*samples k']
107
+ varprob: Float32[Array, '*chains_and_samples p'] | None = field(chains=True)
108
+
109
+ @classmethod
110
+ def from_state(cls, state: State) -> 'MainTrace':
111
+ """Create a single-item main trace from a MCMC state."""
112
+ # compute varprob
113
+ log_s = state.forest.log_s
114
+ if log_s is None:
115
+ varprob = None
116
+ else:
117
+ varprob = softmax(log_s, where=state.forest.max_split.astype(bool))
118
+
119
+ return cls(
120
+ leaf_tree=state.forest.leaf_tree,
121
+ var_tree=state.forest.var_tree,
122
+ split_tree=state.forest.split_tree,
123
+ offset=state.offset,
124
+ varprob=varprob,
125
+ **vars(BurninTrace.from_state(state)),
126
+ )
38
127
 
39
- def default_onlymain_extractor(state: State) -> dict[str, Real[Array, 'samples *']]:
40
- """Extract variables for the main trace, to be used in `run_mcmc`."""
41
- return dict(
42
- leaf_trees=state.forest.leaf_trees,
43
- var_trees=state.forest.var_trees,
44
- split_trees=state.forest.split_trees,
45
- offset=state.offset,
46
- )
47
128
 
129
+ CallbackState = PyTree[Any, 'T']
130
+
131
+
132
+ class Callback(Protocol):
133
+ """Callback type for `run_mcmc`."""
134
+
135
+ def __call__(
136
+ self,
137
+ *,
138
+ key: Key[Array, ''],
139
+ bart: State,
140
+ burnin: Bool[Array, ''],
141
+ i_total: Int32[Array, ''],
142
+ i_skip: Int32[Array, ''],
143
+ callback_state: CallbackState,
144
+ n_burn: Int32[Array, ''],
145
+ n_save: Int32[Array, ''],
146
+ n_skip: Int32[Array, ''],
147
+ i_outer: Int32[Array, ''],
148
+ inner_loop_length: int,
149
+ ) -> tuple[State, CallbackState] | None:
150
+ """Do an arbitrary action after an iteration of the MCMC.
151
+
152
+ Parameters
153
+ ----------
154
+ key
155
+ A key for random number generation.
156
+ bart
157
+ The MCMC state just after updating it.
158
+ burnin
159
+ Whether the last iteration was in the burn-in phase.
160
+ i_total
161
+ The index of the last MCMC iteration (0-based).
162
+ i_skip
163
+ The number of MCMC updates from the last saved state. The initial
164
+ state counts as saved, even if it's not copied into the trace.
165
+ callback_state
166
+ The callback state, initially set to the argument passed to
167
+ `run_mcmc`, afterwards to the value returned by the last invocation
168
+ of the callback.
169
+ n_burn
170
+ n_save
171
+ n_skip
172
+ The corresponding `run_mcmc` arguments as-is.
173
+ i_outer
174
+ The index of the last outer loop iteration (0-based).
175
+ inner_loop_length
176
+ The number of MCMC iterations in the inner loop.
177
+
178
+ Returns
179
+ -------
180
+ bart : State
181
+ A possibly modified MCMC state. To avoid modifying the state,
182
+ return the `bart` argument passed to the callback as-is.
183
+ callback_state : CallbackState
184
+ The new state to be passed on the next callback invocation.
48
185
 
49
- def default_both_extractor(state: State) -> dict[str, Real[Array, 'samples *'] | None]:
50
- """Extract variables for main & burn-in traces, to be used in `run_mcmc`."""
51
- return dict(
52
- sigma2=state.sigma2,
53
- grow_prop_count=state.forest.grow_prop_count,
54
- grow_acc_count=state.forest.grow_acc_count,
55
- prune_prop_count=state.forest.prune_prop_count,
56
- prune_acc_count=state.forest.prune_acc_count,
57
- log_likelihood=state.forest.log_likelihood,
58
- log_trans_prior=state.forest.log_trans_prior,
59
- )
186
+ Notes
187
+ -----
188
+ For convenience, the callback may return `None`, and the states won't
189
+ be updated.
190
+ """
191
+ ...
192
+
193
+
194
+ class _Carry(Module):
195
+ """Carry used in the loop in `run_mcmc`."""
196
+
197
+ bart: State
198
+ i_total: Int32[Array, '']
199
+ key: Key[Array, '']
200
+ burnin_trace: PyTree[
201
+ Shaped[Array, 'n_burn ...'] | Shaped[Array, 'num_chains n_burn ...']
202
+ ]
203
+ main_trace: PyTree[
204
+ Shaped[Array, 'n_save ...'] | Shaped[Array, 'num_chains n_save ...']
205
+ ]
206
+ callback_state: CallbackState
60
207
 
61
208
 
62
209
  def run_mcmc(
63
- key,
64
- bart,
65
- n_save,
210
+ key: Key[Array, ''],
211
+ bart: State,
212
+ n_save: int,
66
213
  *,
67
- n_burn=0,
68
- n_skip=1,
69
- inner_loop_length=None,
70
- allow_overflow=False,
71
- inner_callback=None,
72
- outer_callback=None,
73
- callback_state=None,
74
- onlymain_extractor=default_onlymain_extractor,
75
- both_extractor=default_both_extractor,
76
- ):
214
+ n_burn: int = 0,
215
+ n_skip: int = 1,
216
+ inner_loop_length: int | None = None,
217
+ callback: Callback | None = None,
218
+ callback_state: CallbackState = None,
219
+ burnin_extractor: Callable[[State], PyTree] = BurninTrace.from_state,
220
+ main_extractor: Callable[[State], PyTree] = MainTrace.from_state,
221
+ ) -> tuple[
222
+ State,
223
+ PyTree[Shaped[Array, 'n_burn ...'] | Shaped[Array, 'num_chains n_burn ...']],
224
+ PyTree[Shaped[Array, 'n_save ...'] | Shaped[Array, 'num_chains n_save ...']],
225
+ ]:
77
226
  """
78
227
  Run the MCMC for the BART posterior.
79
228
 
80
229
  Parameters
81
230
  ----------
82
- key : jax.dtypes.prng_key array
231
+ key
83
232
  A key for random number generation.
84
- bart : dict
233
+ bart
85
234
  The initial MCMC state, as created and updated by the functions in
86
235
  `bartz.mcmcstep`. The MCMC loop uses buffer donation to avoid copies,
87
236
  so this variable is invalidated after running `run_mcmc`. Make a copy
88
237
  beforehand to use it again.
89
- n_save : int
238
+ n_save
90
239
  The number of iterations to save.
91
- n_burn : int, default 0
240
+ n_burn
92
241
  The number of initial iterations which are not saved.
93
- n_skip : int, default 1
242
+ n_skip
94
243
  The number of iterations to skip between each saved iteration, plus 1.
95
244
  The effective burn-in is ``n_burn + n_skip - 1``.
96
- inner_loop_length : int, optional
245
+ inner_loop_length
97
246
  The MCMC loop is split into an outer and an inner loop. The outer loop
98
247
  is in Python, while the inner loop is in JAX. `inner_loop_length` is the
99
248
  number of iterations of the inner loop to run for each iteration of the
100
249
  outer loop. If not specified, the outer loop will iterate just once,
101
250
  with all iterations done in a single inner loop run. The inner stride is
102
251
  unrelated to the stride used for saving the trace.
103
- allow_overflow : bool, default False
104
- If `False`, `inner_loop_length` must be a divisor of the total number of
105
- iterations ``n_burn + n_skip * n_save``. If `True` and
106
- `inner_loop_length` is not a divisor, some of the MCMC iterations in the
107
- last outer loop iteration will not be saved to the trace.
108
- inner_callback : callable, optional
109
- outer_callback : callable, optional
110
- Arbitrary functions run during the loop after updating the state.
111
- `inner_callback` is called after each update, while `outer_callback` is
112
- called after completing an inner loop. The callbacks are invoked with
113
- the following arguments, passed by keyword:
114
-
115
- bart : dict
116
- The MCMC state just after updating it.
117
- burnin : bool
118
- Whether the last iteration was in the burn-in phase.
119
- overflow : bool
120
- Whether the last iteration was in the overflow phase (iterations
121
- not saved due to `inner_loop_length` not being a divisor of the
122
- total number of iterations).
123
- i_total : int
124
- The index of the last MCMC iteration (0-based).
125
- i_skip : int
126
- The number of MCMC updates from the last saved state. The initial
127
- state counts as saved, even if it's not copied into the trace.
128
- callback_state : jax pytree
129
- The callback state, initially set to the argument passed to
130
- `run_mcmc`, afterwards to the value returned by the last invocation
131
- of `inner_callback` or `outer_callback`.
132
- n_burn, n_save, n_skip : int
133
- The corresponding arguments as-is.
134
- i_outer : int
135
- The index of the last outer loop iteration (0-based).
136
- inner_loop_length : int
137
- The number of MCMC iterations in the inner loop.
138
-
139
- `inner_callback` is called under the jax jit, so the argument values are
140
- not available at the time the Python code is executed. Use the utilities
141
- in `jax.debug` to access the values at actual runtime.
142
-
143
- The callbacks must return two values:
144
-
145
- bart : dict
146
- A possibly modified MCMC state. To avoid modifying the state,
147
- return the `bart` argument passed to the callback as-is.
148
- callback_state : jax pytree
149
- The new state to be passed on the next callback invocation.
150
-
151
- For convenience, if a callback returns `None`, the states are not
152
- updated.
153
- callback_state : jax pytree, optional
154
- The initial state for the callbacks.
155
- onlymain_extractor : callable, optional
156
- both_extractor : callable, optional
157
- Functions that extract the variables to be saved respectively only in
158
- the main trace and in both traces, given the MCMC state as argument.
159
- Must return a pytree, and must be vmappable.
252
+ callback
253
+ An arbitrary function run during the loop after updating the state. For
254
+ the signature, see `Callback`. The callback is called under the jax jit,
255
+ so the argument values are not available at the time the Python code is
256
+ executed. Use the utilities in `jax.debug` to access the values at
257
+ actual runtime. The callback may return new values for the MCMC state
258
+ and the callback state.
259
+ callback_state
260
+ The initial custom state for the callback.
261
+ burnin_extractor
262
+ main_extractor
263
+ Functions that extract the variables to be saved respectively in the
264
+ burn-in and main traces, given the MCMC state as argument. Must
265
+ return a pytree, and must be vmappable.
160
266
 
161
267
  Returns
162
268
  -------
163
- bart : dict
269
+ bart : State
164
270
  The final MCMC state.
165
- burnin_trace : dict of (n_burn, ...) arrays
166
- The trace of the burn-in phase, containing the following subset of
167
- fields from the `bart` dictionary, with an additional head index that
168
- runs over MCMC iterations: 'sigma2', 'grow_prop_count',
169
- 'grow_acc_count', 'prune_prop_count', 'prune_acc_count' (or if specified
170
- the fields in `tracevars_both`).
171
- main_trace : dict of (n_save, ...) arrays
172
- The trace of the main phase, containing the following subset of fields
173
- from the `bart` dictionary, with an additional head index that runs over
174
- MCMC iterations: 'leaf_trees', 'var_trees', 'split_trees' (or if
175
- specified the fields in `tracevars_onlymain`), plus the fields in
176
- `burnin_trace`.
271
+ burnin_trace : PyTree[Shaped[Array, 'n_burn *']]
272
+ The trace of the burn-in phase. For the default layout, see `BurninTrace`.
273
+ main_trace : PyTree[Shaped[Array, 'n_save *']]
274
+ The trace of the main phase. For the default layout, see `MainTrace`.
177
275
 
178
276
  Raises
179
277
  ------
180
- ValueError
181
- If `inner_loop_length` is not a divisor of the total number of
182
- iterations and `allow_overflow` is `False`.
278
+ RuntimeError
279
+ If `run_mcmc` detects it's being invoked in a `jit`-wrapped context and
280
+ with settings that would create unrolled loops in the trace.
183
281
 
184
282
  Notes
185
283
  -----
186
284
  The number of MCMC updates is ``n_burn + n_skip * n_save``. The traces do
187
285
  not include the initial state, and include the final state.
188
286
  """
189
-
190
- def empty_trace(length, bart, extractor):
191
- return jax.vmap(extractor, in_axes=None, out_axes=0, axis_size=length)(bart)
192
-
193
- trace_both = empty_trace(n_burn + n_save, bart, both_extractor)
194
- trace_onlymain = empty_trace(n_save, bart, onlymain_extractor)
287
+ # create empty traces
288
+ burnin_trace = _empty_trace(n_burn, bart, burnin_extractor)
289
+ main_trace = _empty_trace(n_save, bart, main_extractor)
195
290
 
196
291
  # determine number of iterations for inner and outer loops
197
292
  n_iters = n_burn + n_skip * n_save
198
293
  if inner_loop_length is None:
199
294
  inner_loop_length = n_iters
200
- n_outer = n_iters // inner_loop_length
201
- if n_iters % inner_loop_length:
202
- if allow_overflow:
203
- n_outer += 1
204
- else:
205
- raise ValueError(f'{n_iters=} is not divisible by {inner_loop_length=}')
206
-
207
- carry = (bart, 0, key, trace_both, trace_onlymain, callback_state)
295
+ if inner_loop_length:
296
+ n_outer = n_iters // inner_loop_length + bool(n_iters % inner_loop_length)
297
+ else:
298
+ n_outer = 1
299
+ # setting to 0 would make for a clean noop, but it's useful to keep the
300
+ # same code path for benchmarking and testing
301
+
302
+ # error if under jit and there are unrolled loops or profile mode is on
303
+ under_jit = not hasattr(jnp.empty(0), 'platform')
304
+ if under_jit and (n_outer > 1 or get_profile_mode()):
305
+ msg = (
306
+ '`run_mcmc` was called within a jit-compiled function and '
307
+ 'there are either more than 1 outer loops or profile mode is active, '
308
+ 'please either do not jit, set `inner_loop_length=None`, or disable '
309
+ 'profile mode.'
310
+ )
311
+ raise RuntimeError(msg)
312
+
313
+ replicate = partial(_replicate, mesh=bart.config.mesh)
314
+ carry = _Carry(
315
+ bart,
316
+ replicate(jnp.int32(0)),
317
+ replicate(key),
318
+ burnin_trace,
319
+ main_trace,
320
+ callback_state,
321
+ )
322
+ _run_mcmc_inner_loop._fun.reset_call_counter() # noqa: SLF001
208
323
  for i_outer in range(n_outer):
209
324
  carry = _run_mcmc_inner_loop(
210
325
  carry,
211
326
  inner_loop_length,
212
- inner_callback,
213
- onlymain_extractor,
214
- both_extractor,
327
+ callback,
328
+ burnin_extractor,
329
+ main_extractor,
215
330
  n_burn,
216
331
  n_save,
217
332
  n_skip,
218
333
  i_outer,
334
+ n_iters,
219
335
  )
220
- if outer_callback is not None:
221
- bart, i_total, key, trace_both, trace_onlymain, callback_state = carry
222
- i_total -= 1 # because i_total is updated at the end of the inner loop
223
- i_skip = _compute_i_skip(i_total, n_burn, n_skip)
224
- rt = outer_callback(
225
- bart=bart,
226
- burnin=i_total < n_burn,
227
- overflow=i_total >= n_iters,
228
- i_total=i_total,
229
- i_skip=i_skip,
230
- callback_state=callback_state,
231
- n_burn=n_burn,
232
- n_save=n_save,
233
- n_skip=n_skip,
234
- i_outer=i_outer,
235
- inner_loop_length=inner_loop_length,
236
- )
237
- if rt is not None:
238
- bart, callback_state = rt
239
- i_total += 1
240
- carry = (bart, i_total, key, trace_both, trace_onlymain, callback_state)
241
336
 
242
- bart, _, _, trace_both, trace_onlymain, _ = carry
337
+ return carry.bart, carry.burnin_trace, carry.main_trace
243
338
 
244
- burnin_trace = tree.map(lambda x: x[:n_burn, ...], trace_both)
245
- main_trace = tree.map(lambda x: x[n_burn:, ...], trace_both)
246
- main_trace.update(trace_onlymain)
247
339
 
248
- return bart, burnin_trace, main_trace
340
+ def _replicate(x: Array, mesh: Mesh | None) -> Array:
341
+ if mesh is None:
342
+ return x
343
+ else:
344
+ return device_put(x, NamedSharding(mesh, PartitionSpec()))
249
345
 
250
346
 
251
- def _compute_i_skip(i_total, n_burn, n_skip):
347
+ @partial(jit, static_argnums=(0, 2))
348
+ def _empty_trace(
349
+ length: int, bart: State, extractor: Callable[[State], PyTree]
350
+ ) -> PyTree:
351
+ num_chains = get_num_chains(bart)
352
+ if num_chains is None:
353
+ out_axes = 0
354
+ else:
355
+ example_output = eval_shape(extractor, bart)
356
+ chain_axes = chain_vmap_axes(example_output)
357
+ out_axes = tree.map(
358
+ lambda a: 0 if a is None else 1, chain_axes, is_leaf=lambda a: a is None
359
+ )
360
+ return jax.vmap(extractor, in_axes=None, out_axes=out_axes, axis_size=length)(bart)
361
+
362
+
363
+ @jit
364
+ def _compute_i_skip(
365
+ i_total: Int32[Array, ''], n_burn: Int32[Array, ''], n_skip: Int32[Array, '']
366
+ ) -> Int32[Array, '']:
367
+ """Compute the `i_skip` argument passed to `callback`."""
252
368
  burnin = i_total < n_burn
253
369
  return jnp.where(
254
370
  burnin,
255
371
  i_total + 1,
256
- (i_total + 1) % n_skip + jnp.where(i_total + 1 < n_skip, n_burn, 0),
372
+ (i_total - n_burn + 1) % n_skip
373
+ + jnp.where(i_total - n_burn + 1 < n_skip, n_burn, 0),
257
374
  )
258
375
 
259
376
 
260
- @functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1, 2, 3, 4))
261
- def _run_mcmc_inner_loop(
262
- carry,
263
- inner_loop_length,
264
- inner_callback,
265
- onlymain_extractor,
266
- both_extractor,
267
- n_burn,
268
- n_save,
269
- n_skip,
270
- i_outer,
271
- ):
272
- def loop(carry, _):
273
- bart, i_total, key, trace_both, trace_onlymain, callback_state = carry
377
+ class _CallCounter:
378
+ """Wrap a callable to check it's not called more than once."""
379
+
380
+ def __init__(self, func: Callable) -> None:
381
+ self.func = func
382
+ self.n_calls = 0
274
383
 
275
- keys = jaxext.split(key)
384
+ def reset_call_counter(self) -> None:
385
+ """Reset the call counter."""
386
+ self.n_calls = 0
387
+
388
+ def __call__(self, *args: Any, **kwargs: Any) -> Any:
389
+ if self.n_calls and not get_profile_mode():
390
+ msg = (
391
+ 'The inner loop of `run_mcmc` was traced more than once, '
392
+ 'which indicates a double compilation of the MCMC code. This '
393
+ 'probably depends on the input state having different type from the '
394
+ 'output state. Check the input is in a format that is the '
395
+ 'same jax would output, e.g., all arrays and scalars are jax '
396
+ 'arrays, with the right shardings.'
397
+ )
398
+ raise RuntimeError(msg)
399
+ self.n_calls += 1
400
+ return self.func(*args, **kwargs)
401
+
402
+
403
+ @partial(jit_if_not_profiling, donate_argnums=(0,), static_argnums=(1, 2, 3, 4))
404
+ @_CallCounter
405
+ def _run_mcmc_inner_loop(
406
+ carry: _Carry,
407
+ inner_loop_length: int,
408
+ callback: Callback | None,
409
+ burnin_extractor: Callable[[State], PyTree],
410
+ main_extractor: Callable[[State], PyTree],
411
+ n_burn: Int32[Array, ''],
412
+ n_save: Int32[Array, ''],
413
+ n_skip: Int32[Array, ''],
414
+ i_outer: Int32[Array, ''],
415
+ n_iters: Int32[Array, ''],
416
+ ) -> _Carry:
417
+ def loop_impl(carry: _Carry) -> _Carry:
418
+ """Loop body to run if i_total < n_iters."""
419
+ # split random key
420
+ keys = jaxext.split(carry.key, 3)
276
421
  key = keys.pop()
277
- bart = mcmcstep.step(keys.pop(), bart)
278
422
 
279
- burnin = i_total < n_burn
280
- if inner_callback is not None:
281
- i_skip = _compute_i_skip(i_total, n_burn, n_skip)
282
- rt = inner_callback(
423
+ # update state
424
+ bart = mcmcstep.step(keys.pop(), carry.bart)
425
+
426
+ # invoke callback
427
+ callback_state = carry.callback_state
428
+ if callback is not None:
429
+ i_skip = _compute_i_skip(carry.i_total, n_burn, n_skip)
430
+ rt = callback(
431
+ key=keys.pop(),
283
432
  bart=bart,
284
- burnin=burnin,
285
- overflow=i_total >= n_burn + n_save * n_skip,
286
- i_total=i_total,
433
+ burnin=carry.i_total < n_burn,
434
+ i_total=carry.i_total,
287
435
  i_skip=i_skip,
288
436
  callback_state=callback_state,
289
437
  n_burn=n_burn,
@@ -295,137 +443,240 @@ def _run_mcmc_inner_loop(
295
443
  if rt is not None:
296
444
  bart, callback_state = rt
297
445
 
298
- i_onlymain = jnp.where(burnin, 0, (i_total - n_burn) // n_skip)
299
- i_both = jnp.where(burnin, i_total, n_burn + i_onlymain)
446
+ # save to trace
447
+ burnin_trace, main_trace = _save_state_to_trace(
448
+ carry.burnin_trace,
449
+ carry.main_trace,
450
+ burnin_extractor,
451
+ main_extractor,
452
+ bart,
453
+ carry.i_total,
454
+ n_burn,
455
+ n_skip,
456
+ )
300
457
 
301
- def update_trace(index, trace, state):
302
- def assign_at_index(trace_array, state_array):
303
- if trace_array.size:
304
- return trace_array.at[index, ...].set(state_array)
305
- else:
306
- # this handles the case where a trace is empty (e.g.,
307
- # no burn-in) because jax refuses to index into an array
308
- # of length 0
309
- return trace_array
458
+ return _Carry(
459
+ bart=bart,
460
+ i_total=carry.i_total + 1,
461
+ key=key,
462
+ burnin_trace=burnin_trace,
463
+ main_trace=main_trace,
464
+ callback_state=callback_state,
465
+ )
310
466
 
311
- return tree.map(assign_at_index, trace, state)
467
+ def loop_noop(carry: _Carry) -> _Carry:
468
+ """Loop body to run if i_total >= n_iters; it does nothing."""
469
+ return carry
312
470
 
313
- trace_onlymain = update_trace(
314
- i_onlymain, trace_onlymain, onlymain_extractor(bart)
471
+ def loop(carry: _Carry, _) -> tuple[_Carry, None]:
472
+ carry = cond_if_not_profiling(
473
+ carry.i_total < n_iters, loop_impl, loop_noop, carry
315
474
  )
316
- trace_both = update_trace(i_both, trace_both, both_extractor(bart))
317
-
318
- i_total += 1
319
- carry = (bart, i_total, key, trace_both, trace_onlymain, callback_state)
320
475
  return carry, None
321
476
 
322
- carry, _ = lax.scan(loop, carry, None, inner_loop_length)
477
+ carry, _ = scan_if_not_profiling(loop, carry, None, inner_loop_length)
323
478
  return carry
324
479
 
325
480
 
326
- def make_print_callbacks(dot_every_inner=1, report_every_outer=1):
481
+ @partial(jit, donate_argnums=(0, 1), static_argnums=(2, 3))
482
+ # this is jitted because under profiling _run_mcmc_inner_loop and the loop
483
+ # within it are not, so I need the donate_argnums feature of jit to avoid
484
+ # creating copies of the traces
485
+ def _save_state_to_trace(
486
+ burnin_trace: PyTree,
487
+ main_trace: PyTree,
488
+ burnin_extractor: Callable[[State], PyTree],
489
+ main_extractor: Callable[[State], PyTree],
490
+ bart: State,
491
+ i_total: Int32[Array, ''],
492
+ n_burn: Int32[Array, ''],
493
+ n_skip: Int32[Array, ''],
494
+ ) -> tuple[PyTree, PyTree]:
495
+ # trace index where to save during burnin; out-of-bounds => noop after
496
+ # burnin
497
+ burnin_idx = i_total
498
+
499
+ # trace index where to save during main phase; force it out-of-bounds
500
+ # during burnin
501
+ main_idx = (i_total - n_burn) // n_skip
502
+ noop_idx = jnp.iinfo(jnp.int32).max
503
+ noop_cond = i_total < n_burn
504
+ main_idx = jnp.where(noop_cond, noop_idx, main_idx)
505
+
506
+ # prepare array index
507
+ num_chains = get_num_chains(bart)
508
+ burnin_trace = _set(burnin_trace, burnin_idx, burnin_extractor(bart), num_chains)
509
+ main_trace = _set(main_trace, main_idx, main_extractor(bart), num_chains)
510
+
511
+ return burnin_trace, main_trace
512
+
513
+
514
+ def _set(
515
+ trace: PyTree[Array, ' T'],
516
+ index: Int32[Array, ''],
517
+ val: PyTree[Array, ' T'],
518
+ num_chains: int | None,
519
+ ) -> PyTree[Array, ' T']:
520
+ """Do ``trace[index] = val`` but fancier."""
521
+ chain_axis = chain_vmap_axes(val)
522
+
523
+ def at_set(
524
+ trace: Shaped[Array, 'chains samples *shape']
525
+ | Shaped[Array, ' samples *shape']
526
+ | None,
527
+ val: Shaped[Array, ' chains *shape'] | Shaped[Array, '*shape'] | None,
528
+ chain_axis: int | None,
529
+ ):
530
+ if trace is None or trace.size == 0:
531
+ # return the trace unchanged in two cases: (1) the array is empty, because
532
+ # jax refuses to index into an axis of length 0, even if just in the
533
+ # abstract; (2) the element is an optional `None`, which is considered a
534
+ # leaf due to the `is_leaf` below, needed to traverse `chain_axis`.
535
+ return trace
536
+
537
+ if num_chains is None or chain_axis is None:
538
+ ndindex = (index, ...)
539
+ else:
540
+ ndindex = (slice(None), index, ...)
541
+
542
+ return trace.at[ndindex].set(val, mode='drop')
543
+
544
+ return tree.map(at_set, trace, val, chain_axis, is_leaf=lambda x: x is None)
545
+
546
+
547
+ def make_default_callback(
548
+ state: State,
549
+ *,
550
+ dot_every: int | Integer[Array, ''] | None = 1,
551
+ report_every: int | Integer[Array, ''] | None = 100,
552
+ ) -> dict[str, Any]:
327
553
  """
328
- Prepare logging callbacks for `run_mcmc`.
554
+ Prepare a default callback for `run_mcmc`.
329
555
 
330
- Prepare callbacks which print a dot on every iteration, and a longer
331
- report outer loop iteration.
556
+ The callback prints a dot every `dot_every` iterations and a one-line
557
+ report every `report_every` iterations, and can do variable selection.
332
558
 
333
559
  Parameters
334
560
  ----------
335
- dot_every_inner : int, default 1
336
- A dot is printed every `dot_every_inner` MCMC iterations.
337
- report_every_outer : int, default 1
338
- A report is printed every `report_every_outer` outer loop
339
- iterations.
561
+ state
562
+ The bart state to use the callback with, used to determine device
563
+ sharding.
564
+ dot_every
565
+ A dot is printed every `dot_every` MCMC iterations, `None` to disable.
566
+ report_every
567
+ A one line report is printed every `report_every` MCMC iterations,
568
+ `None` to disable.
340
569
 
341
570
  Returns
342
571
  -------
343
- kwargs : dict
344
- A dictionary with the arguments to pass to `run_mcmc` as keyword
345
- arguments to set up the callbacks.
572
+ A dictionary with the arguments to pass to `run_mcmc` as keyword arguments to set up the callback.
346
573
 
347
574
  Examples
348
575
  --------
349
- >>> run_mcmc(..., **make_print_callbacks())
576
+ >>> run_mcmc(key, state, ..., **make_default_callback(state, ...))
350
577
  """
578
+
579
+ def as_replicated_array_or_none(val: None | Any) -> None | Array:
580
+ return None if val is None else _replicate(jnp.asarray(val), state.config.mesh)
581
+
351
582
  return dict(
352
- inner_callback=_print_callback_inner,
353
- outer_callback=_print_callback_outer,
354
- callback_state=dict(
355
- dot_every_inner=dot_every_inner, report_every_outer=report_every_outer
583
+ callback=print_callback,
584
+ callback_state=PrintCallbackState(
585
+ as_replicated_array_or_none(dot_every),
586
+ as_replicated_array_or_none(report_every),
356
587
  ),
357
588
  )
358
589
 
359
590
 
360
- def _print_callback_inner(*, i_total, callback_state, **_):
361
- dot_every_inner = callback_state['dot_every_inner']
362
- if dot_every_inner is not None:
363
- cond = (i_total + 1) % dot_every_inner == 0
364
- debug.callback(_print_dot, cond)
591
+ class PrintCallbackState(Module):
592
+ """State for `print_callback`."""
365
593
 
594
+ dot_every: Int32[Array, ''] | None
595
+ """A dot is printed every `dot_every` MCMC iterations, `None` to disable."""
366
596
 
367
- def _print_dot(cond):
368
- if cond:
369
- print('.', end='', flush=True)
597
+ report_every: Int32[Array, ''] | None
598
+ """A one line report is printed every `report_every` MCMC iterations,
599
+ `None` to disable."""
370
600
 
371
601
 
372
- def _print_callback_outer(
602
+ def print_callback(
373
603
  *,
374
- bart,
375
- burnin,
376
- overflow,
377
- i_total,
378
- n_burn,
379
- n_save,
380
- n_skip,
381
- callback_state,
382
- i_outer,
383
- inner_loop_length,
604
+ bart: State,
605
+ burnin: Bool[Array, ''],
606
+ i_total: Int32[Array, ''],
607
+ n_burn: Int32[Array, ''],
608
+ n_save: Int32[Array, ''],
609
+ n_skip: Int32[Array, ''],
610
+ callback_state: PrintCallbackState,
384
611
  **_,
385
612
  ):
386
- report_every_outer = callback_state['report_every_outer']
387
- if report_every_outer is not None:
388
- dot_every_inner = callback_state['dot_every_inner']
389
- if dot_every_inner is None:
390
- newline = False
613
+ """Print a dot and/or a report periodically during the MCMC."""
614
+ report_every = callback_state.report_every
615
+ dot_every = callback_state.dot_every
616
+ it = i_total + 1
617
+
618
+ def get_cond(every: Int32[Array, ''] | None) -> bool | Bool[Array, '']:
619
+ return False if every is None else it % every == 0
620
+
621
+ report_cond = get_cond(report_every)
622
+ dot_cond = get_cond(dot_every)
623
+
624
+ def line_report_branch():
625
+ if report_every is None:
626
+ return
627
+ if dot_every is None:
628
+ print_newline = False
391
629
  else:
392
- newline = dot_every_inner < inner_loop_length
630
+ print_newline = it % report_every > it % dot_every
393
631
  debug.callback(
394
632
  _print_report,
395
- cond=(i_outer + 1) % report_every_outer == 0,
396
- newline=newline,
633
+ print_dot=dot_cond,
634
+ print_newline=print_newline,
397
635
  burnin=burnin,
398
- overflow=overflow,
399
- i_total=i_total,
636
+ it=it,
400
637
  n_iters=n_burn + n_save * n_skip,
401
- grow_prop_count=bart.forest.grow_prop_count,
402
- grow_acc_count=bart.forest.grow_acc_count,
403
- prune_prop_count=bart.forest.prune_prop_count,
404
- prune_acc_count=bart.forest.prune_acc_count,
405
- prop_total=len(bart.forest.leaf_trees),
406
- fill=grove.forest_fill(bart.forest.split_trees),
638
+ num_chains=bart.forest.num_chains(),
639
+ grow_prop_count=bart.forest.grow_prop_count.mean(),
640
+ grow_acc_count=bart.forest.grow_acc_count.mean(),
641
+ prune_acc_count=bart.forest.prune_acc_count.mean(),
642
+ prop_total=bart.forest.split_tree.shape[-2],
643
+ fill=forest_fill(bart.forest.split_tree),
407
644
  )
408
645
 
646
+ def just_dot_branch():
647
+ if dot_every is None:
648
+ return
649
+ debug.callback(
650
+ lambda: print('.', end='', flush=True) # noqa: T201
651
+ )
652
+ # logging can't do in-line printing so we use print
653
+
654
+ cond_if_not_profiling(
655
+ report_cond,
656
+ line_report_branch,
657
+ lambda: cond_if_not_profiling(dot_cond, just_dot_branch, lambda: None),
658
+ )
409
659
 
410
- def _convert_jax_arrays_in_args(func):
660
+
661
+ def _convert_jax_arrays_in_args(func: Callable) -> Callable:
411
662
  """Remove jax arrays from a function arguments.
412
663
 
413
- Converts all jax.Array instances in the arguments to either Python scalars
664
+ Converts all `jax.Array` instances in the arguments to either Python scalars
414
665
  or numpy arrays.
415
666
  """
416
667
 
417
- def convert_jax_arrays(pytree):
418
- def convert_jax_arrays(val):
419
- if not isinstance(val, jax.Array):
668
+ def convert_jax_arrays(pytree: PyTree) -> PyTree:
669
+ def convert_jax_array(val: Any) -> Any:
670
+ if not isinstance(val, Array):
420
671
  return val
421
672
  elif val.shape:
422
673
  return numpy.array(val)
423
674
  else:
424
675
  return val.item()
425
676
 
426
- return tree.map(convert_jax_arrays, pytree)
677
+ return tree.map(convert_jax_array, pytree)
427
678
 
428
- @functools.wraps(func)
679
+ @wraps(func)
429
680
  def new_func(*args, **kw):
430
681
  args = convert_jax_arrays(args)
431
682
  kw = convert_jax_arrays(kw)
@@ -439,73 +690,170 @@ def _convert_jax_arrays_in_args(func):
439
690
  # deadlock with the main thread
440
691
  def _print_report(
441
692
  *,
442
- cond,
443
- newline,
444
- burnin,
445
- overflow,
446
- i_total,
447
- n_iters,
448
- grow_prop_count,
449
- grow_acc_count,
450
- prune_prop_count,
451
- prune_acc_count,
452
- prop_total,
453
- fill,
693
+ print_dot: bool,
694
+ print_newline: bool,
695
+ burnin: bool,
696
+ it: int,
697
+ n_iters: int,
698
+ num_chains: int | None,
699
+ grow_prop_count: float,
700
+ grow_acc_count: float,
701
+ prune_acc_count: float,
702
+ prop_total: int,
703
+ fill: float,
454
704
  ):
455
- if cond:
456
- newline = '\n' if newline else ''
705
+ """Print the report for `print_callback`."""
706
+ # compute fractions
707
+ grow_prop = grow_prop_count / prop_total
708
+ move_acc = (grow_acc_count + prune_acc_count) / prop_total
709
+
710
+ # determine prefix
711
+ if print_dot:
712
+ prefix = '.\n'
713
+ elif print_newline:
714
+ prefix = '\n'
715
+ else:
716
+ prefix = ''
717
+
718
+ # determine suffix in parentheses
719
+ msgs = []
720
+ if num_chains is not None:
721
+ msgs.append(f'avg. {num_chains} chains')
722
+ if burnin:
723
+ msgs.append('burnin')
724
+ suffix = f' ({", ".join(msgs)})' if msgs else ''
725
+
726
+ print( # noqa: T201, see print_callback for why not logging
727
+ f'{prefix}Iteration {it}/{n_iters}, '
728
+ f'grow prob: {grow_prop:.0%}, '
729
+ f'move acc: {move_acc:.0%}, '
730
+ f'fill: {fill:.0%}{suffix}'
731
+ )
457
732
 
458
- def acc_string(acc_count, prop_count):
459
- if prop_count:
460
- return f'{acc_count / prop_count:.0%}'
461
- else:
462
- return ' n/d'
463
733
 
464
- grow_prop = grow_prop_count / prop_total
465
- prune_prop = prune_prop_count / prop_total
466
- grow_acc = acc_string(grow_acc_count, grow_prop_count)
467
- prune_acc = acc_string(prune_acc_count, prune_prop_count)
734
+ class Trace(TreeHeaps, Protocol):
735
+ """Protocol for a MCMC trace."""
468
736
 
469
- if burnin:
470
- flag = ' (burnin)'
471
- elif overflow:
472
- flag = ' (overflow)'
473
- else:
474
- flag = ''
737
+ offset: Float32[Array, '*trace_shape']
475
738
 
476
- print(
477
- f'{newline}It {i_total + 1}/{n_iters} '
478
- f'grow P={grow_prop:.0%} A={grow_acc}, '
479
- f'prune P={prune_prop:.0%} A={prune_acc}, '
480
- f'fill={fill:.0%}{flag}'
481
- )
482
739
 
740
+ class TreesTrace(Module):
741
+ """Implementation of `bartz.grove.TreeHeaps` for an MCMC trace."""
742
+
743
+ leaf_tree: (
744
+ Float32[Array, '*trace_shape num_trees 2**d']
745
+ | Float32[Array, '*trace_shape num_trees k 2**d']
746
+ )
747
+ var_tree: UInt[Array, '*trace_shape num_trees 2**(d-1)']
748
+ split_tree: UInt[Array, '*trace_shape num_trees 2**(d-1)']
483
749
 
484
- @jax.jit
485
- def evaluate_trace(trace, X):
750
+ @classmethod
751
+ def from_dataclass(cls, obj: TreeHeaps):
752
+ """Create a `TreesTrace` from any `bartz.grove.TreeHeaps`."""
753
+ return cls(**{f.name: getattr(obj, f.name) for f in fields(cls)})
754
+
755
+
756
+ @jit
757
+ def evaluate_trace(
758
+ X: UInt[Array, 'p n'], trace: Trace
759
+ ) -> Float32[Array, '*trace_shape n'] | Float32[Array, '*trace_shape k n']:
486
760
  """
487
761
  Compute predictions for all iterations of the BART MCMC.
488
762
 
489
763
  Parameters
490
764
  ----------
491
- trace : dict
492
- A trace of the BART MCMC, as returned by `run_mcmc`.
493
- X : array (p, n)
765
+ X
494
766
  The predictors matrix, with `p` predictors and `n` observations.
767
+ trace
768
+ A main trace of the BART MCMC, as returned by `run_mcmc`.
495
769
 
496
770
  Returns
497
771
  -------
498
- y : array (n_trace, n)
499
- The predictions for each iteration of the MCMC.
772
+ The predictions for each chain and iteration of the MCMC.
500
773
  """
501
- evaluate_trees = functools.partial(grove.evaluate_forest, sum_trees=False)
502
- evaluate_trees = jaxext.autobatch(evaluate_trees, 2**29, (None, 0, 0, 0))
774
+ # per-device memory limit
775
+ max_io_nbytes = 2**27 # 128 MiB
776
+
777
+ # adjust memory limit for number of devices
778
+ mesh = jax.typeof(trace.leaf_tree).sharding.mesh
779
+ num_devices = get_axis_size(mesh, 'chains') * get_axis_size(mesh, 'data')
780
+ max_io_nbytes *= num_devices
781
+
782
+ # determine batching axes
783
+ has_chains = trace.split_tree.ndim > 3 # chains, samples, trees, nodes
784
+ if has_chains:
785
+ sample_axis = 1
786
+ tree_axis = 2
787
+ else:
788
+ sample_axis = 0
789
+ tree_axis = 1
790
+
791
+ # batch and sum over trees
792
+ batched_eval = autobatch(
793
+ evaluate_forest,
794
+ max_io_nbytes,
795
+ (None, tree_axis),
796
+ tree_axis,
797
+ reduce_ufunc=jnp.add,
798
+ )
503
799
 
504
- def loop(_, row):
505
- values = evaluate_trees(
506
- X, row['leaf_trees'], row['var_trees'], row['split_trees']
800
+ # determine output shape (to avoid autobatch tracing everything 4 times)
801
+ is_mv = trace.leaf_tree.ndim > trace.split_tree.ndim
802
+ k = trace.leaf_tree.shape[-2] if is_mv else 1
803
+ mv_shape = (k,) if is_mv else ()
804
+ _, n = X.shape
805
+ out_shape = (*trace.split_tree.shape[:-2], *mv_shape, n)
806
+
807
+ # adjust memory limit keeping into account that trees are summed over
808
+ num_trees, hts = trace.split_tree.shape[-2:]
809
+ out_size = k * n * jnp.float32.dtype.itemsize # the value of the forest
810
+ core_io_size = (
811
+ num_trees
812
+ * hts
813
+ * (
814
+ 2 * k * trace.leaf_tree.itemsize
815
+ + trace.var_tree.itemsize
816
+ + trace.split_tree.itemsize
507
817
  )
508
- return None, row['offset'] + jnp.sum(values, axis=0, dtype=jnp.float32)
818
+ + out_size
819
+ )
820
+ core_int_size = (num_trees - 1) * out_size
821
+ max_io_nbytes = max(1, floor(max_io_nbytes / (1 + core_int_size / core_io_size)))
822
+
823
+ # batch over mcmc samples
824
+ batched_eval = autobatch(
825
+ batched_eval,
826
+ max_io_nbytes,
827
+ (None, sample_axis),
828
+ sample_axis,
829
+ warn_on_overflow=False, # the inner autobatch will handle it
830
+ result_shape_dtype=ShapeDtypeStruct(out_shape, jnp.float32),
831
+ )
832
+
833
+ # extract only the trees from the trace
834
+ trees = TreesTrace.from_dataclass(trace)
835
+
836
+ # evaluate trees
837
+ y_centered: Float32[Array, '*trace_shape n'] | Float32[Array, '*trace_shape k n']
838
+ y_centered = batched_eval(X, trees)
839
+ return y_centered + trace.offset[..., None]
840
+
841
+
842
+ @partial(jit, static_argnums=(0,))
843
+ def compute_varcount(p: int, trace: TreeHeaps) -> Int32[Array, '*trace_shape {p}']:
844
+ """
845
+ Count how many times each predictor is used in each MCMC state.
509
846
 
510
- _, y = lax.scan(loop, None, trace)
511
- return y
847
+ Parameters
848
+ ----------
849
+ p
850
+ The number of predictors.
851
+ trace
852
+ A main trace of the BART MCMC, as returned by `run_mcmc`.
853
+
854
+ Returns
855
+ -------
856
+ Histogram of predictor usage in each MCMC state.
857
+ """
858
+ # var_tree has shape (chains? samples trees nodes)
859
+ return var_histogram(p, trace.var_tree, trace.split_tree, sum_batch_axis=-1)