bartz 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/mcmcloop.py CHANGED
@@ -22,164 +22,222 @@
22
22
  # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
23
  # SOFTWARE.
24
24
 
25
- """Functions that implement the full BART posterior MCMC loop."""
25
+ """Functions that implement the full BART posterior MCMC loop.
26
26
 
27
- import functools
27
+ The entry points are `run_mcmc` and `make_default_callback`.
28
+ """
29
+
30
+ from collections.abc import Callable
31
+ from dataclasses import fields, replace
32
+ from functools import partial, wraps
33
+ from typing import Any, Protocol
28
34
 
29
35
  import jax
30
36
  import numpy
37
+ from equinox import Module
31
38
  from jax import debug, lax, tree
32
39
  from jax import numpy as jnp
33
- from jaxtyping import Array, Real
40
+ from jax.nn import softmax
41
+ from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, PyTree, Shaped, UInt
42
+
43
+ from bartz import grove, jaxext, mcmcstep
44
+ from bartz.mcmcstep import State
45
+
46
+
47
+ class BurninTrace(Module):
48
+ """MCMC trace with only diagnostic values."""
49
+
50
+ sigma2: Float32[Array, '*trace_length'] | None
51
+ theta: Float32[Array, '*trace_length'] | None
52
+ grow_prop_count: Int32[Array, '*trace_length']
53
+ grow_acc_count: Int32[Array, '*trace_length']
54
+ prune_prop_count: Int32[Array, '*trace_length']
55
+ prune_acc_count: Int32[Array, '*trace_length']
56
+ log_likelihood: Float32[Array, '*trace_length'] | None
57
+ log_trans_prior: Float32[Array, '*trace_length'] | None
58
+
59
+ @classmethod
60
+ def from_state(cls, state: State) -> 'BurninTrace':
61
+ """Create a single-item burn-in trace from a MCMC state."""
62
+ return cls(
63
+ sigma2=state.sigma2,
64
+ theta=state.forest.theta,
65
+ grow_prop_count=state.forest.grow_prop_count,
66
+ grow_acc_count=state.forest.grow_acc_count,
67
+ prune_prop_count=state.forest.prune_prop_count,
68
+ prune_acc_count=state.forest.prune_acc_count,
69
+ log_likelihood=state.forest.log_likelihood,
70
+ log_trans_prior=state.forest.log_trans_prior,
71
+ )
34
72
 
35
- from . import grove, jaxext, mcmcstep
36
- from .mcmcstep import State
37
73
 
74
+ class MainTrace(BurninTrace):
75
+ """MCMC trace with trees and diagnostic values."""
38
76
 
39
- def default_onlymain_extractor(state: State) -> dict[str, Real[Array, 'samples *']]:
40
- """Extract variables for the main trace, to be used in `run_mcmc`."""
41
- return dict(
42
- leaf_trees=state.forest.leaf_trees,
43
- var_trees=state.forest.var_trees,
44
- split_trees=state.forest.split_trees,
45
- offset=state.offset,
46
- )
77
+ leaf_tree: Float32[Array, '*trace_length 2**d']
78
+ var_tree: UInt[Array, '*trace_length 2**(d-1)']
79
+ split_tree: UInt[Array, '*trace_length 2**(d-1)']
80
+ offset: Float32[Array, '*trace_length']
81
+ varprob: Float32[Array, '*trace_length p'] | None
47
82
 
83
+ @classmethod
84
+ def from_state(cls, state: State) -> 'MainTrace':
85
+ """Create a single-item main trace from a MCMC state."""
86
+ # compute varprob
87
+ log_s = state.forest.log_s
88
+ if log_s is None:
89
+ varprob = None
90
+ else:
91
+ varprob = softmax(log_s, where=state.forest.max_split.astype(bool))
92
+
93
+ return cls(
94
+ leaf_tree=state.forest.leaf_tree,
95
+ var_tree=state.forest.var_tree,
96
+ split_tree=state.forest.split_tree,
97
+ offset=state.offset,
98
+ varprob=varprob,
99
+ **vars(BurninTrace.from_state(state)),
100
+ )
48
101
 
49
- def default_both_extractor(state: State) -> dict[str, Real[Array, 'samples *'] | None]:
50
- """Extract variables for main & burn-in traces, to be used in `run_mcmc`."""
51
- return dict(
52
- sigma2=state.sigma2,
53
- grow_prop_count=state.forest.grow_prop_count,
54
- grow_acc_count=state.forest.grow_acc_count,
55
- prune_prop_count=state.forest.prune_prop_count,
56
- prune_acc_count=state.forest.prune_acc_count,
57
- log_likelihood=state.forest.log_likelihood,
58
- log_trans_prior=state.forest.log_trans_prior,
59
- )
102
+
103
+ CallbackState = PyTree[Any, 'T']
104
+
105
+
106
+ class Callback(Protocol):
107
+ """Callback type for `run_mcmc`."""
108
+
109
+ def __call__(
110
+ self,
111
+ *,
112
+ key: Key[Array, ''],
113
+ bart: State,
114
+ burnin: Bool[Array, ''],
115
+ i_total: Int32[Array, ''],
116
+ i_skip: Int32[Array, ''],
117
+ callback_state: CallbackState,
118
+ n_burn: Int32[Array, ''],
119
+ n_save: Int32[Array, ''],
120
+ n_skip: Int32[Array, ''],
121
+ i_outer: Int32[Array, ''],
122
+ inner_loop_length: int,
123
+ ) -> tuple[State, CallbackState] | None:
124
+ """Do an arbitrary action after an iteration of the MCMC.
125
+
126
+ Parameters
127
+ ----------
128
+ key
129
+ A key for random number generation.
130
+ bart
131
+ The MCMC state just after updating it.
132
+ burnin
133
+ Whether the last iteration was in the burn-in phase.
134
+ i_total
135
+ The index of the last MCMC iteration (0-based).
136
+ i_skip
137
+ The number of MCMC updates from the last saved state. The initial
138
+ state counts as saved, even if it's not copied into the trace.
139
+ callback_state
140
+ The callback state, initially set to the argument passed to
141
+ `run_mcmc`, afterwards to the value returned by the last invocation
142
+ of the callback.
143
+ n_burn
144
+ n_save
145
+ n_skip
146
+ The corresponding `run_mcmc` arguments as-is.
147
+ i_outer
148
+ The index of the last outer loop iteration (0-based).
149
+ inner_loop_length
150
+ The number of MCMC iterations in the inner loop.
151
+
152
+ Returns
153
+ -------
154
+ bart : State
155
+ A possibly modified MCMC state. To avoid modifying the state,
156
+ return the `bart` argument passed to the callback as-is.
157
+ callback_state : CallbackState
158
+ The new state to be passed on the next callback invocation.
159
+
160
+ Notes
161
+ -----
162
+ For convenience, the callback may return `None`, and the states won't
163
+ be updated.
164
+ """
165
+ ...
166
+
167
+
168
+ class _Carry(Module):
169
+ """Carry used in the loop in `run_mcmc`."""
170
+
171
+ bart: State
172
+ i_total: Int32[Array, '']
173
+ key: Key[Array, '']
174
+ burnin_trace: PyTree[Shaped[Array, 'n_burn *']]
175
+ main_trace: PyTree[Shaped[Array, 'n_save *']]
176
+ callback_state: CallbackState
60
177
 
61
178
 
62
179
  def run_mcmc(
63
- key,
64
- bart,
65
- n_save,
180
+ key: Key[Array, ''],
181
+ bart: State,
182
+ n_save: int,
66
183
  *,
67
- n_burn=0,
68
- n_skip=1,
69
- inner_loop_length=None,
70
- allow_overflow=False,
71
- inner_callback=None,
72
- outer_callback=None,
73
- callback_state=None,
74
- onlymain_extractor=default_onlymain_extractor,
75
- both_extractor=default_both_extractor,
76
- ):
184
+ n_burn: int = 0,
185
+ n_skip: int = 1,
186
+ inner_loop_length: int | None = None,
187
+ callback: Callback | None = None,
188
+ callback_state: CallbackState = None,
189
+ burnin_extractor: Callable[[State], PyTree] = BurninTrace.from_state,
190
+ main_extractor: Callable[[State], PyTree] = MainTrace.from_state,
191
+ ) -> tuple[State, PyTree[Shaped[Array, 'n_burn *']], PyTree[Shaped[Array, 'n_save *']]]:
77
192
  """
78
193
  Run the MCMC for the BART posterior.
79
194
 
80
195
  Parameters
81
196
  ----------
82
- key : jax.dtypes.prng_key array
197
+ key
83
198
  A key for random number generation.
84
- bart : dict
199
+ bart
85
200
  The initial MCMC state, as created and updated by the functions in
86
201
  `bartz.mcmcstep`. The MCMC loop uses buffer donation to avoid copies,
87
202
  so this variable is invalidated after running `run_mcmc`. Make a copy
88
203
  beforehand to use it again.
89
- n_save : int
204
+ n_save
90
205
  The number of iterations to save.
91
- n_burn : int, default 0
206
+ n_burn
92
207
  The number of initial iterations which are not saved.
93
- n_skip : int, default 1
208
+ n_skip
94
209
  The number of iterations to skip between each saved iteration, plus 1.
95
210
  The effective burn-in is ``n_burn + n_skip - 1``.
96
- inner_loop_length : int, optional
211
+ inner_loop_length
97
212
  The MCMC loop is split into an outer and an inner loop. The outer loop
98
213
  is in Python, while the inner loop is in JAX. `inner_loop_length` is the
99
214
  number of iterations of the inner loop to run for each iteration of the
100
215
  outer loop. If not specified, the outer loop will iterate just once,
101
216
  with all iterations done in a single inner loop run. The inner stride is
102
217
  unrelated to the stride used for saving the trace.
103
- allow_overflow : bool, default False
104
- If `False`, `inner_loop_length` must be a divisor of the total number of
105
- iterations ``n_burn + n_skip * n_save``. If `True` and
106
- `inner_loop_length` is not a divisor, some of the MCMC iterations in the
107
- last outer loop iteration will not be saved to the trace.
108
- inner_callback : callable, optional
109
- outer_callback : callable, optional
110
- Arbitrary functions run during the loop after updating the state.
111
- `inner_callback` is called after each update, while `outer_callback` is
112
- called after completing an inner loop. The callbacks are invoked with
113
- the following arguments, passed by keyword:
114
-
115
- bart : dict
116
- The MCMC state just after updating it.
117
- burnin : bool
118
- Whether the last iteration was in the burn-in phase.
119
- overflow : bool
120
- Whether the last iteration was in the overflow phase (iterations
121
- not saved due to `inner_loop_length` not being a divisor of the
122
- total number of iterations).
123
- i_total : int
124
- The index of the last MCMC iteration (0-based).
125
- i_skip : int
126
- The number of MCMC updates from the last saved state. The initial
127
- state counts as saved, even if it's not copied into the trace.
128
- callback_state : jax pytree
129
- The callback state, initially set to the argument passed to
130
- `run_mcmc`, afterwards to the value returned by the last invocation
131
- of `inner_callback` or `outer_callback`.
132
- n_burn, n_save, n_skip : int
133
- The corresponding arguments as-is.
134
- i_outer : int
135
- The index of the last outer loop iteration (0-based).
136
- inner_loop_length : int
137
- The number of MCMC iterations in the inner loop.
138
-
139
- `inner_callback` is called under the jax jit, so the argument values are
140
- not available at the time the Python code is executed. Use the utilities
141
- in `jax.debug` to access the values at actual runtime.
142
-
143
- The callbacks must return two values:
144
-
145
- bart : dict
146
- A possibly modified MCMC state. To avoid modifying the state,
147
- return the `bart` argument passed to the callback as-is.
148
- callback_state : jax pytree
149
- The new state to be passed on the next callback invocation.
150
-
151
- For convenience, if a callback returns `None`, the states are not
152
- updated.
153
- callback_state : jax pytree, optional
154
- The initial state for the callbacks.
155
- onlymain_extractor : callable, optional
156
- both_extractor : callable, optional
218
+ callback
219
+ An arbitrary function run during the loop after updating the state. For
220
+ the signature, see `Callback`. The callback is called under the jax jit,
221
+ so the argument values are not available at the time the Python code is
222
+ executed. Use the utilities in `jax.debug` to access the values at
223
+ actual runtime. The callback may return new values for the MCMC state
224
+ and the callback state.
225
+ callback_state
226
+ The initial custom state for the callback.
227
+ burnin_extractor
228
+ main_extractor
157
229
  Functions that extract the variables to be saved respectively only in
158
230
  the main trace and in both traces, given the MCMC state as argument.
159
231
  Must return a pytree, and must be vmappable.
160
232
 
161
233
  Returns
162
234
  -------
163
- bart : dict
235
+ bart : State
164
236
  The final MCMC state.
165
- burnin_trace : dict of (n_burn, ...) arrays
166
- The trace of the burn-in phase, containing the following subset of
167
- fields from the `bart` dictionary, with an additional head index that
168
- runs over MCMC iterations: 'sigma2', 'grow_prop_count',
169
- 'grow_acc_count', 'prune_prop_count', 'prune_acc_count' (or if specified
170
- the fields in `tracevars_both`).
171
- main_trace : dict of (n_save, ...) arrays
172
- The trace of the main phase, containing the following subset of fields
173
- from the `bart` dictionary, with an additional head index that runs over
174
- MCMC iterations: 'leaf_trees', 'var_trees', 'split_trees' (or if
175
- specified the fields in `tracevars_onlymain`), plus the fields in
176
- `burnin_trace`.
177
-
178
- Raises
179
- ------
180
- ValueError
181
- If `inner_loop_length` is not a divisor of the total number of
182
- iterations and `allow_overflow` is `False`.
237
+ burnin_trace : PyTree[Shaped[Array, 'n_burn *']]
238
+ The trace of the burn-in phase. For the default layout, see `BurninTrace`.
239
+ main_trace : PyTree[Shaped[Array, 'n_save *']]
240
+ The trace of the main phase. For the default layout, see `MainTrace`.
183
241
 
184
242
  Notes
185
243
  -----
@@ -190,102 +248,85 @@ def run_mcmc(
190
248
  def empty_trace(length, bart, extractor):
191
249
  return jax.vmap(extractor, in_axes=None, out_axes=0, axis_size=length)(bart)
192
250
 
193
- trace_both = empty_trace(n_burn + n_save, bart, both_extractor)
194
- trace_onlymain = empty_trace(n_save, bart, onlymain_extractor)
251
+ burnin_trace = empty_trace(n_burn, bart, burnin_extractor)
252
+ main_trace = empty_trace(n_save, bart, main_extractor)
195
253
 
196
254
  # determine number of iterations for inner and outer loops
197
255
  n_iters = n_burn + n_skip * n_save
198
256
  if inner_loop_length is None:
199
257
  inner_loop_length = n_iters
200
- n_outer = n_iters // inner_loop_length
201
- if n_iters % inner_loop_length:
202
- if allow_overflow:
203
- n_outer += 1
204
- else:
205
- raise ValueError(f'{n_iters=} is not divisible by {inner_loop_length=}')
206
-
207
- carry = (bart, 0, key, trace_both, trace_onlymain, callback_state)
258
+ if inner_loop_length:
259
+ n_outer = n_iters // inner_loop_length + bool(n_iters % inner_loop_length)
260
+ else:
261
+ n_outer = 1
262
+ # setting to 0 would make for a clean noop, but it's useful to keep the
263
+ # same code path for benchmarking and testing
264
+
265
+ carry = _Carry(bart, jnp.int32(0), key, burnin_trace, main_trace, callback_state)
208
266
  for i_outer in range(n_outer):
209
267
  carry = _run_mcmc_inner_loop(
210
268
  carry,
211
269
  inner_loop_length,
212
- inner_callback,
213
- onlymain_extractor,
214
- both_extractor,
270
+ callback,
271
+ burnin_extractor,
272
+ main_extractor,
215
273
  n_burn,
216
274
  n_save,
217
275
  n_skip,
218
276
  i_outer,
277
+ n_iters,
219
278
  )
220
- if outer_callback is not None:
221
- bart, i_total, key, trace_both, trace_onlymain, callback_state = carry
222
- i_total -= 1 # because i_total is updated at the end of the inner loop
223
- i_skip = _compute_i_skip(i_total, n_burn, n_skip)
224
- rt = outer_callback(
225
- bart=bart,
226
- burnin=i_total < n_burn,
227
- overflow=i_total >= n_iters,
228
- i_total=i_total,
229
- i_skip=i_skip,
230
- callback_state=callback_state,
231
- n_burn=n_burn,
232
- n_save=n_save,
233
- n_skip=n_skip,
234
- i_outer=i_outer,
235
- inner_loop_length=inner_loop_length,
236
- )
237
- if rt is not None:
238
- bart, callback_state = rt
239
- i_total += 1
240
- carry = (bart, i_total, key, trace_both, trace_onlymain, callback_state)
241
-
242
- bart, _, _, trace_both, trace_onlymain, _ = carry
243
279
 
244
- burnin_trace = tree.map(lambda x: x[:n_burn, ...], trace_both)
245
- main_trace = tree.map(lambda x: x[n_burn:, ...], trace_both)
246
- main_trace.update(trace_onlymain)
280
+ return carry.bart, carry.burnin_trace, carry.main_trace
247
281
 
248
- return bart, burnin_trace, main_trace
249
282
 
250
-
251
- def _compute_i_skip(i_total, n_burn, n_skip):
283
+ def _compute_i_skip(
284
+ i_total: Int32[Array, ''], n_burn: Int32[Array, ''], n_skip: Int32[Array, '']
285
+ ) -> Int32[Array, '']:
286
+ """Compute the `i_skip` argument passed to `callback`."""
252
287
  burnin = i_total < n_burn
253
288
  return jnp.where(
254
289
  burnin,
255
290
  i_total + 1,
256
- (i_total + 1) % n_skip + jnp.where(i_total + 1 < n_skip, n_burn, 0),
291
+ (i_total - n_burn + 1) % n_skip
292
+ + jnp.where(i_total - n_burn + 1 < n_skip, n_burn, 0),
257
293
  )
258
294
 
259
295
 
260
- @functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1, 2, 3, 4))
296
+ @partial(jax.jit, donate_argnums=(0,), static_argnums=(1, 2, 3, 4))
261
297
  def _run_mcmc_inner_loop(
262
- carry,
263
- inner_loop_length,
264
- inner_callback,
265
- onlymain_extractor,
266
- both_extractor,
267
- n_burn,
268
- n_save,
269
- n_skip,
270
- i_outer,
298
+ carry: _Carry,
299
+ inner_loop_length: int,
300
+ callback: Callback | None,
301
+ burnin_extractor: Callable[[State], PyTree],
302
+ main_extractor: Callable[[State], PyTree],
303
+ n_burn: Int32[Array, ''],
304
+ n_save: Int32[Array, ''],
305
+ n_skip: Int32[Array, ''],
306
+ i_outer: Int32[Array, ''],
307
+ n_iters: Int32[Array, ''],
271
308
  ):
272
- def loop(carry, _):
273
- bart, i_total, key, trace_both, trace_onlymain, callback_state = carry
274
-
275
- keys = jaxext.split(key)
276
- key = keys.pop()
277
- bart = mcmcstep.step(keys.pop(), bart)
278
-
279
- burnin = i_total < n_burn
280
- if inner_callback is not None:
281
- i_skip = _compute_i_skip(i_total, n_burn, n_skip)
282
- rt = inner_callback(
283
- bart=bart,
309
+ def loop_impl(carry: _Carry) -> _Carry:
310
+ """Loop body to run if i_total < n_iters."""
311
+ # split random key
312
+ keys = jaxext.split(carry.key, 3)
313
+ carry = replace(carry, key=keys.pop())
314
+
315
+ # update state
316
+ carry = replace(carry, bart=mcmcstep.step(keys.pop(), carry.bart))
317
+
318
+ burnin = carry.i_total < n_burn
319
+
320
+ # invoke callback
321
+ if callback is not None:
322
+ i_skip = _compute_i_skip(carry.i_total, n_burn, n_skip)
323
+ rt = callback(
324
+ key=keys.pop(),
325
+ bart=carry.bart,
284
326
  burnin=burnin,
285
- overflow=i_total >= n_burn + n_save * n_skip,
286
- i_total=i_total,
327
+ i_total=carry.i_total,
287
328
  i_skip=i_skip,
288
- callback_state=callback_state,
329
+ callback_state=carry.callback_state,
289
330
  n_burn=n_burn,
290
331
  n_save=n_save,
291
332
  n_skip=n_skip,
@@ -294,128 +335,178 @@ def _run_mcmc_inner_loop(
294
335
  )
295
336
  if rt is not None:
296
337
  bart, callback_state = rt
338
+ carry = replace(carry, bart=bart, callback_state=callback_state)
297
339
 
298
- i_onlymain = jnp.where(burnin, 0, (i_total - n_burn) // n_skip)
299
- i_both = jnp.where(burnin, i_total, n_burn + i_onlymain)
300
-
301
- def update_trace(index, trace, state):
302
- def assign_at_index(trace_array, state_array):
303
- if trace_array.size:
304
- return trace_array.at[index, ...].set(state_array)
305
- else:
306
- # this handles the case where a trace is empty (e.g.,
307
- # no burn-in) because jax refuses to index into an array
308
- # of length 0
309
- return trace_array
340
+ def save_to_burnin_trace() -> tuple[PyTree, PyTree]:
341
+ return _pytree_at_set(
342
+ carry.burnin_trace, carry.i_total, burnin_extractor(carry.bart)
343
+ ), carry.main_trace
310
344
 
311
- return tree.map(assign_at_index, trace, state)
345
+ def save_to_main_trace() -> tuple[PyTree, PyTree]:
346
+ idx = (carry.i_total - n_burn) // n_skip
347
+ return carry.burnin_trace, _pytree_at_set(
348
+ carry.main_trace, idx, main_extractor(carry.bart)
349
+ )
312
350
 
313
- trace_onlymain = update_trace(
314
- i_onlymain, trace_onlymain, onlymain_extractor(bart)
351
+ # save state to trace
352
+ burnin_trace, main_trace = lax.cond(
353
+ burnin, save_to_burnin_trace, save_to_main_trace
354
+ )
355
+ return replace(
356
+ carry,
357
+ i_total=carry.i_total + 1,
358
+ burnin_trace=burnin_trace,
359
+ main_trace=main_trace,
315
360
  )
316
- trace_both = update_trace(i_both, trace_both, both_extractor(bart))
317
361
 
318
- i_total += 1
319
- carry = (bart, i_total, key, trace_both, trace_onlymain, callback_state)
362
+ def loop_noop(carry: _Carry) -> _Carry:
363
+ """Loop body to run if i_total >= n_iters; it does nothing."""
364
+ return carry
365
+
366
+ def loop(carry: _Carry, _) -> tuple[_Carry, None]:
367
+ carry = lax.cond(carry.i_total < n_iters, loop_impl, loop_noop, carry)
320
368
  return carry, None
321
369
 
322
370
  carry, _ = lax.scan(loop, carry, None, inner_loop_length)
323
371
  return carry
324
372
 
325
373
 
326
- def make_print_callbacks(dot_every_inner=1, report_every_outer=1):
374
+ def _pytree_at_set(
375
+ dest: PyTree[Array, ' T'], index: Int32[Array, ''], val: PyTree[Array]
376
+ ) -> PyTree[Array, ' T']:
377
+ """Map ``dest.at[index].set(val)`` over pytrees."""
378
+
379
+ def at_set(dest, val):
380
+ if dest.size:
381
+ return dest.at[index, ...].set(val)
382
+ else:
383
+ # this handles the case where an array is empty because jax refuses
384
+ # to index into an array of length 0, even if just in the abstract
385
+ return dest
386
+
387
+ return tree.map(at_set, dest, val)
388
+
389
+
390
+ def make_default_callback(
391
+ *,
392
+ dot_every: int | Integer[Array, ''] | None = 1,
393
+ report_every: int | Integer[Array, ''] | None = 100,
394
+ sparse_on_at: int | Integer[Array, ''] | None = None,
395
+ ) -> dict[str, Any]:
327
396
  """
328
- Prepare logging callbacks for `run_mcmc`.
397
+ Prepare a default callback for `run_mcmc`.
329
398
 
330
- Prepare callbacks which print a dot on every iteration, and a longer
331
- report outer loop iteration.
399
+ The callback prints a dot on every iteration, and a longer
400
+ report outer loop iteration, and can do variable selection.
332
401
 
333
402
  Parameters
334
403
  ----------
335
- dot_every_inner : int, default 1
336
- A dot is printed every `dot_every_inner` MCMC iterations.
337
- report_every_outer : int, default 1
338
- A report is printed every `report_every_outer` outer loop
339
- iterations.
404
+ dot_every
405
+ A dot is printed every `dot_every` MCMC iterations, `None` to disable.
406
+ report_every
407
+ A one line report is printed every `report_every` MCMC iterations,
408
+ `None` to disable.
409
+ sparse_on_at
410
+ If specified, variable selection is activated starting from this
411
+ iteration. If `None`, variable selection is not used.
340
412
 
341
413
  Returns
342
414
  -------
343
- kwargs : dict
344
- A dictionary with the arguments to pass to `run_mcmc` as keyword
345
- arguments to set up the callbacks.
415
+ A dictionary with the arguments to pass to `run_mcmc` as keyword arguments to set up the callback.
346
416
 
347
417
  Examples
348
418
  --------
349
- >>> run_mcmc(..., **make_print_callbacks())
419
+ >>> run_mcmc(..., **make_default_callback())
350
420
  """
421
+
422
+ def asarray_or_none(val: None | Any) -> None | Array:
423
+ return None if val is None else jnp.asarray(val)
424
+
425
+ def callback(*, bart, callback_state, **kwargs):
426
+ print_state, sparse_state = callback_state
427
+ bart, _ = sparse_callback(callback_state=sparse_state, bart=bart, **kwargs)
428
+ print_callback(callback_state=print_state, bart=bart, **kwargs)
429
+ return bart, callback_state
430
+ # here I assume that the callbacks don't update their states
431
+
351
432
  return dict(
352
- inner_callback=_print_callback_inner,
353
- outer_callback=_print_callback_outer,
354
- callback_state=dict(
355
- dot_every_inner=dot_every_inner, report_every_outer=report_every_outer
433
+ callback=callback,
434
+ callback_state=(
435
+ PrintCallbackState(
436
+ asarray_or_none(dot_every), asarray_or_none(report_every)
437
+ ),
438
+ SparseCallbackState(asarray_or_none(sparse_on_at)),
356
439
  ),
357
440
  )
358
441
 
359
442
 
360
- def _print_callback_inner(*, i_total, callback_state, **_):
361
- dot_every_inner = callback_state['dot_every_inner']
362
- if dot_every_inner is not None:
363
- cond = (i_total + 1) % dot_every_inner == 0
364
- debug.callback(_print_dot, cond)
443
+ class PrintCallbackState(Module):
444
+ """State for `print_callback`.
365
445
 
446
+ Parameters
447
+ ----------
448
+ dot_every
449
+ A dot is printed every `dot_every` MCMC iterations, `None` to disable.
450
+ report_every
451
+ A one line report is printed every `report_every` MCMC iterations,
452
+ `None` to disable.
453
+ """
366
454
 
367
- def _print_dot(cond):
368
- if cond:
369
- print('.', end='', flush=True)
455
+ dot_every: Int32[Array, ''] | None
456
+ report_every: Int32[Array, ''] | None
370
457
 
371
458
 
372
- def _print_callback_outer(
459
+ def print_callback(
373
460
  *,
374
- bart,
375
- burnin,
376
- overflow,
377
- i_total,
378
- n_burn,
379
- n_save,
380
- n_skip,
381
- callback_state,
382
- i_outer,
383
- inner_loop_length,
461
+ bart: State,
462
+ burnin: Bool[Array, ''],
463
+ i_total: Int32[Array, ''],
464
+ n_burn: Int32[Array, ''],
465
+ n_save: Int32[Array, ''],
466
+ n_skip: Int32[Array, ''],
467
+ callback_state: PrintCallbackState,
384
468
  **_,
385
469
  ):
386
- report_every_outer = callback_state['report_every_outer']
387
- if report_every_outer is not None:
388
- dot_every_inner = callback_state['dot_every_inner']
389
- if dot_every_inner is None:
390
- newline = False
391
- else:
392
- newline = dot_every_inner < inner_loop_length
393
- debug.callback(
394
- _print_report,
395
- cond=(i_outer + 1) % report_every_outer == 0,
396
- newline=newline,
397
- burnin=burnin,
398
- overflow=overflow,
399
- i_total=i_total,
400
- n_iters=n_burn + n_save * n_skip,
401
- grow_prop_count=bart.forest.grow_prop_count,
402
- grow_acc_count=bart.forest.grow_acc_count,
403
- prune_prop_count=bart.forest.prune_prop_count,
404
- prune_acc_count=bart.forest.prune_acc_count,
405
- prop_total=len(bart.forest.leaf_trees),
406
- fill=grove.forest_fill(bart.forest.split_trees),
470
+ """Print a dot and/or a report periodically during the MCMC."""
471
+ if callback_state.dot_every is not None:
472
+ cond = (i_total + 1) % callback_state.dot_every == 0
473
+ lax.cond(
474
+ cond,
475
+ lambda: debug.callback(lambda: print('.', end='', flush=True)), # noqa: T201
476
+ # logging can't do in-line printing so I'll stick to print
477
+ lambda: None,
407
478
  )
408
479
 
480
+ if callback_state.report_every is not None:
481
+
482
+ def print_report():
483
+ debug.callback(
484
+ _print_report,
485
+ newline=callback_state.dot_every is not None,
486
+ burnin=burnin,
487
+ i_total=i_total,
488
+ n_iters=n_burn + n_save * n_skip,
489
+ grow_prop_count=bart.forest.grow_prop_count,
490
+ grow_acc_count=bart.forest.grow_acc_count,
491
+ prune_prop_count=bart.forest.prune_prop_count,
492
+ prune_acc_count=bart.forest.prune_acc_count,
493
+ prop_total=len(bart.forest.leaf_tree),
494
+ fill=grove.forest_fill(bart.forest.split_tree),
495
+ )
496
+
497
+ cond = (i_total + 1) % callback_state.report_every == 0
498
+ lax.cond(cond, print_report, lambda: None)
409
499
 
410
- def _convert_jax_arrays_in_args(func):
500
+
501
+ def _convert_jax_arrays_in_args(func: Callable) -> Callable:
411
502
  """Remove jax arrays from a function arguments.
412
503
 
413
- Converts all jax.Array instances in the arguments to either Python scalars
504
+ Converts all `jax.Array` instances in the arguments to either Python scalars
414
505
  or numpy arrays.
415
506
  """
416
507
 
417
- def convert_jax_arrays(pytree):
418
- def convert_jax_arrays(val):
508
+ def convert_jax_arrays(pytree: PyTree) -> PyTree:
509
+ def convert_jax_arrays(val: Any) -> Any:
419
510
  if not isinstance(val, jax.Array):
420
511
  return val
421
512
  elif val.shape:
@@ -425,7 +516,7 @@ def _convert_jax_arrays_in_args(func):
425
516
 
426
517
  return tree.map(convert_jax_arrays, pytree)
427
518
 
428
- @functools.wraps(func)
519
+ @wraps(func)
429
520
  def new_func(*args, **kw):
430
521
  args = convert_jax_arrays(args)
431
522
  kw = convert_jax_arrays(kw)
@@ -439,73 +530,139 @@ def _convert_jax_arrays_in_args(func):
439
530
  # deadlock with the main thread
440
531
  def _print_report(
441
532
  *,
442
- cond,
443
- newline,
444
- burnin,
445
- overflow,
446
- i_total,
447
- n_iters,
448
- grow_prop_count,
449
- grow_acc_count,
450
- prune_prop_count,
451
- prune_acc_count,
452
- prop_total,
453
- fill,
533
+ newline: bool,
534
+ burnin: bool,
535
+ i_total: int,
536
+ n_iters: int,
537
+ grow_prop_count: int,
538
+ grow_acc_count: int,
539
+ prune_prop_count: int,
540
+ prune_acc_count: int,
541
+ prop_total: int,
542
+ fill: float,
454
543
  ):
455
- if cond:
456
- newline = '\n' if newline else ''
544
+ """Print the report for `print_callback`."""
457
545
 
458
- def acc_string(acc_count, prop_count):
459
- if prop_count:
460
- return f'{acc_count / prop_count:.0%}'
461
- else:
462
- return ' n/d'
546
+ def acc_string(acc_count, prop_count):
547
+ if prop_count:
548
+ return f'{acc_count / prop_count:.0%}'
549
+ else:
550
+ return 'n/d'
463
551
 
464
- grow_prop = grow_prop_count / prop_total
465
- prune_prop = prune_prop_count / prop_total
466
- grow_acc = acc_string(grow_acc_count, grow_prop_count)
467
- prune_acc = acc_string(prune_acc_count, prune_prop_count)
552
+ grow_prop = grow_prop_count / prop_total
553
+ prune_prop = prune_prop_count / prop_total
554
+ grow_acc = acc_string(grow_acc_count, grow_prop_count)
555
+ prune_acc = acc_string(prune_acc_count, prune_prop_count)
556
+
557
+ prefix = '\n' if newline else ''
558
+ suffix = ' (burnin)' if burnin else ''
559
+
560
+ print( # noqa: T201, see print_callback for why not logging
561
+ f'{prefix}It {i_total + 1}/{n_iters} '
562
+ f'grow P={grow_prop:.0%} A={grow_acc}, '
563
+ f'prune P={prune_prop:.0%} A={prune_acc}, '
564
+ f'fill={fill:.0%}{suffix}'
565
+ )
468
566
 
469
- if burnin:
470
- flag = ' (burnin)'
471
- elif overflow:
472
- flag = ' (overflow)'
473
- else:
474
- flag = ''
475
567
 
476
- print(
477
- f'{newline}It {i_total + 1}/{n_iters} '
478
- f'grow P={grow_prop:.0%} A={grow_acc}, '
479
- f'prune P={prune_prop:.0%} A={prune_acc}, '
480
- f'fill={fill:.0%}{flag}'
568
+ class SparseCallbackState(Module):
569
+ """State for `sparse_callback`.
570
+
571
+ Parameters
572
+ ----------
573
+ sparse_on_at
574
+ If specified, variable selection is activated starting from this
575
+ iteration. If `None`, variable selection is not used.
576
+ """
577
+
578
+ sparse_on_at: Int32[Array, ''] | None
579
+
580
+
581
+ def sparse_callback(
582
+ *,
583
+ key: Key[Array, ''],
584
+ bart: State,
585
+ i_total: Int32[Array, ''],
586
+ callback_state: SparseCallbackState,
587
+ **_,
588
+ ):
589
+ """Perform variable selection, see `mcmcstep.step_sparse`."""
590
+ if callback_state.sparse_on_at is not None:
591
+ bart = lax.cond(
592
+ i_total < callback_state.sparse_on_at,
593
+ lambda: bart,
594
+ lambda: mcmcstep.step_sparse(key, bart),
481
595
  )
596
+ return bart, callback_state
597
+
598
+
599
+ class Trace(grove.TreeHeaps, Protocol):
600
+ """Protocol for a MCMC trace."""
601
+
602
+ offset: Float32[Array, ' trace_length']
603
+
604
+
605
+ class TreesTrace(Module):
606
+ """Implementation of `bartz.grove.TreeHeaps` for an MCMC trace."""
607
+
608
+ leaf_tree: Float32[Array, 'trace_length num_trees 2**d']
609
+ var_tree: UInt[Array, 'trace_length num_trees 2**(d-1)']
610
+ split_tree: UInt[Array, 'trace_length num_trees 2**(d-1)']
611
+
612
+ @classmethod
613
+ def from_dataclass(cls, obj: grove.TreeHeaps):
614
+ """Create a `TreesTrace` from any `bartz.grove.TreeHeaps`."""
615
+ return cls(**{f.name: getattr(obj, f.name) for f in fields(cls)})
482
616
 
483
617
 
484
618
  @jax.jit
485
- def evaluate_trace(trace, X):
619
+ def evaluate_trace(
620
+ trace: Trace, X: UInt[Array, 'p n']
621
+ ) -> Float32[Array, 'trace_length n']:
486
622
  """
487
623
  Compute predictions for all iterations of the BART MCMC.
488
624
 
489
625
  Parameters
490
626
  ----------
491
- trace : dict
627
+ trace
492
628
  A trace of the BART MCMC, as returned by `run_mcmc`.
493
- X : array (p, n)
629
+ X
494
630
  The predictors matrix, with `p` predictors and `n` observations.
495
631
 
496
632
  Returns
497
633
  -------
498
- y : array (n_trace, n)
499
- The predictions for each iteration of the MCMC.
634
+ The predictions for each iteration of the MCMC.
500
635
  """
501
- evaluate_trees = functools.partial(grove.evaluate_forest, sum_trees=False)
502
- evaluate_trees = jaxext.autobatch(evaluate_trees, 2**29, (None, 0, 0, 0))
636
+ evaluate_trees = partial(grove.evaluate_forest, sum_trees=False)
637
+ evaluate_trees = jaxext.autobatch(evaluate_trees, 2**29, (None, 0))
638
+ trees = TreesTrace.from_dataclass(trace)
503
639
 
504
- def loop(_, row):
505
- values = evaluate_trees(
506
- X, row['leaf_trees'], row['var_trees'], row['split_trees']
507
- )
508
- return None, row['offset'] + jnp.sum(values, axis=0, dtype=jnp.float32)
640
+ def loop(_, item):
641
+ offset, trees = item
642
+ values = evaluate_trees(X, trees)
643
+ return None, offset + jnp.sum(values, axis=0, dtype=jnp.float32)
509
644
 
510
- _, y = lax.scan(loop, None, trace)
645
+ _, y = lax.scan(loop, None, (trace.offset, trees))
511
646
  return y
647
+
648
+
649
+ @partial(jax.jit, static_argnums=(0,))
650
+ def compute_varcount(
651
+ p: int, trace: grove.TreeHeaps
652
+ ) -> Int32[Array, 'trace_length {p}']:
653
+ """
654
+ Count how many times each predictor is used in each MCMC state.
655
+
656
+ Parameters
657
+ ----------
658
+ p
659
+ The number of predictors.
660
+ trace
661
+ A trace of the BART MCMC, as returned by `run_mcmc`.
662
+
663
+ Returns
664
+ -------
665
+ Histogram of predictor usage in each MCMC state.
666
+ """
667
+ vmapped_var_histogram = jax.vmap(grove.var_histogram, in_axes=(None, 0, 0))
668
+ return vmapped_var_histogram(p, trace.var_tree, trace.split_tree)