bartz 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/mcmcloop.py CHANGED
@@ -22,72 +22,222 @@
22
22
  # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
23
  # SOFTWARE.
24
24
 
25
- """
26
- Functions that implement the full BART posterior MCMC loop.
25
+ """Functions that implement the full BART posterior MCMC loop.
26
+
27
+ The entry points are `run_mcmc` and `make_default_callback`.
27
28
  """
28
29
 
29
- import functools
30
+ from collections.abc import Callable
31
+ from dataclasses import fields, replace
32
+ from functools import partial, wraps
33
+ from typing import Any, Protocol
30
34
 
31
35
  import jax
32
- from jax import debug, lax, random, tree
36
+ import numpy
37
+ from equinox import Module
38
+ from jax import debug, lax, tree
33
39
  from jax import numpy as jnp
40
+ from jax.nn import softmax
41
+ from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, PyTree, Shaped, UInt
42
+
43
+ from bartz import grove, jaxext, mcmcstep
44
+ from bartz.mcmcstep import State
45
+
46
+
47
+ class BurninTrace(Module):
48
+ """MCMC trace with only diagnostic values."""
49
+
50
+ sigma2: Float32[Array, '*trace_length'] | None
51
+ theta: Float32[Array, '*trace_length'] | None
52
+ grow_prop_count: Int32[Array, '*trace_length']
53
+ grow_acc_count: Int32[Array, '*trace_length']
54
+ prune_prop_count: Int32[Array, '*trace_length']
55
+ prune_acc_count: Int32[Array, '*trace_length']
56
+ log_likelihood: Float32[Array, '*trace_length'] | None
57
+ log_trans_prior: Float32[Array, '*trace_length'] | None
58
+
59
+ @classmethod
60
+ def from_state(cls, state: State) -> 'BurninTrace':
61
+ """Create a single-item burn-in trace from a MCMC state."""
62
+ return cls(
63
+ sigma2=state.sigma2,
64
+ theta=state.forest.theta,
65
+ grow_prop_count=state.forest.grow_prop_count,
66
+ grow_acc_count=state.forest.grow_acc_count,
67
+ prune_prop_count=state.forest.prune_prop_count,
68
+ prune_acc_count=state.forest.prune_acc_count,
69
+ log_likelihood=state.forest.log_likelihood,
70
+ log_trans_prior=state.forest.log_trans_prior,
71
+ )
34
72
 
35
- from . import grove, jaxext, mcmcstep
73
+
74
+ class MainTrace(BurninTrace):
75
+ """MCMC trace with trees and diagnostic values."""
76
+
77
+ leaf_tree: Float32[Array, '*trace_length 2**d']
78
+ var_tree: UInt[Array, '*trace_length 2**(d-1)']
79
+ split_tree: UInt[Array, '*trace_length 2**(d-1)']
80
+ offset: Float32[Array, '*trace_length']
81
+ varprob: Float32[Array, '*trace_length p'] | None
82
+
83
+ @classmethod
84
+ def from_state(cls, state: State) -> 'MainTrace':
85
+ """Create a single-item main trace from a MCMC state."""
86
+ # compute varprob
87
+ log_s = state.forest.log_s
88
+ if log_s is None:
89
+ varprob = None
90
+ else:
91
+ varprob = softmax(log_s, where=state.forest.max_split.astype(bool))
92
+
93
+ return cls(
94
+ leaf_tree=state.forest.leaf_tree,
95
+ var_tree=state.forest.var_tree,
96
+ split_tree=state.forest.split_tree,
97
+ offset=state.offset,
98
+ varprob=varprob,
99
+ **vars(BurninTrace.from_state(state)),
100
+ )
36
101
 
37
102
 
38
- @functools.partial(jax.jit, static_argnums=(2, 3, 4, 5))
39
- def run_mcmc(key, bart, n_burn, n_save, n_skip, callback):
103
+ CallbackState = PyTree[Any, 'T']
104
+
105
+
106
+ class Callback(Protocol):
107
+ """Callback type for `run_mcmc`."""
108
+
109
+ def __call__(
110
+ self,
111
+ *,
112
+ key: Key[Array, ''],
113
+ bart: State,
114
+ burnin: Bool[Array, ''],
115
+ i_total: Int32[Array, ''],
116
+ i_skip: Int32[Array, ''],
117
+ callback_state: CallbackState,
118
+ n_burn: Int32[Array, ''],
119
+ n_save: Int32[Array, ''],
120
+ n_skip: Int32[Array, ''],
121
+ i_outer: Int32[Array, ''],
122
+ inner_loop_length: int,
123
+ ) -> tuple[State, CallbackState] | None:
124
+ """Do an arbitrary action after an iteration of the MCMC.
125
+
126
+ Parameters
127
+ ----------
128
+ key
129
+ A key for random number generation.
130
+ bart
131
+ The MCMC state just after updating it.
132
+ burnin
133
+ Whether the last iteration was in the burn-in phase.
134
+ i_total
135
+ The index of the last MCMC iteration (0-based).
136
+ i_skip
137
+ The number of MCMC updates from the last saved state. The initial
138
+ state counts as saved, even if it's not copied into the trace.
139
+ callback_state
140
+ The callback state, initially set to the argument passed to
141
+ `run_mcmc`, afterwards to the value returned by the last invocation
142
+ of the callback.
143
+ n_burn
144
+ n_save
145
+ n_skip
146
+ The corresponding `run_mcmc` arguments as-is.
147
+ i_outer
148
+ The index of the last outer loop iteration (0-based).
149
+ inner_loop_length
150
+ The number of MCMC iterations in the inner loop.
151
+
152
+ Returns
153
+ -------
154
+ bart : State
155
+ A possibly modified MCMC state. To avoid modifying the state,
156
+ return the `bart` argument passed to the callback as-is.
157
+ callback_state : CallbackState
158
+ The new state to be passed on the next callback invocation.
159
+
160
+ Notes
161
+ -----
162
+ For convenience, the callback may return `None`, and the states won't
163
+ be updated.
164
+ """
165
+ ...
166
+
167
+
168
+ class _Carry(Module):
169
+ """Carry used in the loop in `run_mcmc`."""
170
+
171
+ bart: State
172
+ i_total: Int32[Array, '']
173
+ key: Key[Array, '']
174
+ burnin_trace: PyTree[Shaped[Array, 'n_burn *']]
175
+ main_trace: PyTree[Shaped[Array, 'n_save *']]
176
+ callback_state: CallbackState
177
+
178
+
179
+ def run_mcmc(
180
+ key: Key[Array, ''],
181
+ bart: State,
182
+ n_save: int,
183
+ *,
184
+ n_burn: int = 0,
185
+ n_skip: int = 1,
186
+ inner_loop_length: int | None = None,
187
+ callback: Callback | None = None,
188
+ callback_state: CallbackState = None,
189
+ burnin_extractor: Callable[[State], PyTree] = BurninTrace.from_state,
190
+ main_extractor: Callable[[State], PyTree] = MainTrace.from_state,
191
+ ) -> tuple[State, PyTree[Shaped[Array, 'n_burn *']], PyTree[Shaped[Array, 'n_save *']]]:
40
192
  """
41
193
  Run the MCMC for the BART posterior.
42
194
 
43
195
  Parameters
44
196
  ----------
45
- key : jax.dtypes.prng_key array
46
- The key for random number generation.
47
- bart : dict
197
+ key
198
+ A key for random number generation.
199
+ bart
48
200
  The initial MCMC state, as created and updated by the functions in
49
- `bartz.mcmcstep`.
50
- n_burn : int
51
- The number of initial iterations which are not saved.
52
- n_save : int
201
+ `bartz.mcmcstep`. The MCMC loop uses buffer donation to avoid copies,
202
+ so this variable is invalidated after running `run_mcmc`. Make a copy
203
+ beforehand to use it again.
204
+ n_save
53
205
  The number of iterations to save.
54
- n_skip : int
206
+ n_burn
207
+ The number of initial iterations which are not saved.
208
+ n_skip
55
209
  The number of iterations to skip between each saved iteration, plus 1.
56
210
  The effective burn-in is ``n_burn + n_skip - 1``.
57
- callback : callable
58
- An arbitrary function run at each iteration, called with the following
59
- arguments, passed by keyword:
60
-
61
- bart : dict
62
- The MCMC state just after updating it.
63
- burnin : bool
64
- Whether the last iteration was in the burn-in phase.
65
- i_total : int
66
- The index of the last iteration (0-based).
67
- i_skip : int
68
- The number of MCMC updates from the last saved state. The initial
69
- state counts as saved, even if it's not copied into the trace.
70
- n_burn, n_save, n_skip : int
71
- The corresponding arguments as-is.
72
-
73
- Since this function is called under the jax jit, the values are not
74
- available at the time the Python code is executed. Use the utilities in
75
- `jax.debug` to access the values at actual runtime.
211
+ inner_loop_length
212
+ The MCMC loop is split into an outer and an inner loop. The outer loop
213
+ is in Python, while the inner loop is in JAX. `inner_loop_length` is the
214
+ number of iterations of the inner loop to run for each iteration of the
215
+ outer loop. If not specified, the outer loop will iterate just once,
216
+ with all iterations done in a single inner loop run. The inner stride is
217
+ unrelated to the stride used for saving the trace.
218
+ callback
219
+ An arbitrary function run during the loop after updating the state. For
220
+ the signature, see `Callback`. The callback is called under the jax jit,
221
+ so the argument values are not available at the time the Python code is
222
+ executed. Use the utilities in `jax.debug` to access the values at
223
+ actual runtime. The callback may return new values for the MCMC state
224
+ and the callback state.
225
+ callback_state
226
+ The initial custom state for the callback.
227
+ burnin_extractor
228
+ main_extractor
229
+ Functions that extract the variables to be saved respectively in the
230
+ burn-in trace and in the main trace, given the MCMC state as argument.
231
+ Must return a pytree, and must be vmappable.
76
232
 
77
233
  Returns
78
234
  -------
79
- bart : dict
235
+ bart : State
80
236
  The final MCMC state.
81
- burnin_trace : dict of (n_burn, ...) arrays
82
- The trace of the burn-in phase, containing the following subset of
83
- fields from the `bart` dictionary, with an additional head index that
84
- runs over MCMC iterations: 'sigma2', 'grow_prop_count',
85
- 'grow_acc_count', 'prune_prop_count', 'prune_acc_count'.
86
- main_trace : dict of (n_save, ...) arrays
87
- The trace of the main phase, containing the following subset of fields
88
- from the `bart` dictionary, with an additional head index that runs
89
- over MCMC iterations: 'leaf_trees', 'var_trees', 'split_trees', plus
90
- the fields in `burnin_trace`.
237
+ burnin_trace : PyTree[Shaped[Array, 'n_burn *']]
238
+ The trace of the burn-in phase. For the default layout, see `BurninTrace`.
239
+ main_trace : PyTree[Shaped[Array, 'n_save *']]
240
+ The trace of the main phase. For the default layout, see `MainTrace`.
91
241
 
92
242
  Notes
93
243
  -----
@@ -95,164 +245,424 @@ def run_mcmc(key, bart, n_burn, n_save, n_skip, callback):
95
245
  not include the initial state, and include the final state.
96
246
  """
97
247
 
98
- tracevars_light = (
99
- 'sigma2',
100
- 'grow_prop_count',
101
- 'grow_acc_count',
102
- 'prune_prop_count',
103
- 'prune_acc_count',
104
- 'ratios',
105
- )
106
- tracevars_heavy = ('leaf_trees', 'var_trees', 'split_trees')
107
-
108
- def empty_trace(length, bart, tracelist):
109
- bart = {k: v for k, v in bart.items() if k in tracelist}
110
- return jax.vmap(lambda x: x, in_axes=None, out_axes=0, axis_size=length)(bart)
111
-
112
- trace_light = empty_trace(n_burn + n_save, bart, tracevars_light)
113
- trace_heavy = empty_trace(n_save, bart, tracevars_heavy)
248
+ def empty_trace(length, bart, extractor):
249
+ return jax.vmap(extractor, in_axes=None, out_axes=0, axis_size=length)(bart)
250
+
251
+ burnin_trace = empty_trace(n_burn, bart, burnin_extractor)
252
+ main_trace = empty_trace(n_save, bart, main_extractor)
253
+
254
+ # determine number of iterations for inner and outer loops
255
+ n_iters = n_burn + n_skip * n_save
256
+ if inner_loop_length is None:
257
+ inner_loop_length = n_iters
258
+ if inner_loop_length:
259
+ n_outer = n_iters // inner_loop_length + bool(n_iters % inner_loop_length)
260
+ else:
261
+ n_outer = 1
262
+ # setting to 0 would make for a clean noop, but it's useful to keep the
263
+ # same code path for benchmarking and testing
264
+
265
+ carry = _Carry(bart, jnp.int32(0), key, burnin_trace, main_trace, callback_state)
266
+ for i_outer in range(n_outer):
267
+ carry = _run_mcmc_inner_loop(
268
+ carry,
269
+ inner_loop_length,
270
+ callback,
271
+ burnin_extractor,
272
+ main_extractor,
273
+ n_burn,
274
+ n_save,
275
+ n_skip,
276
+ i_outer,
277
+ n_iters,
278
+ )
114
279
 
115
- callback_kw = dict(n_burn=n_burn, n_save=n_save, n_skip=n_skip)
280
+ return carry.bart, carry.burnin_trace, carry.main_trace
116
281
 
117
- carry = (bart, 0, key, trace_light, trace_heavy)
118
282
 
119
- def loop(carry, _):
120
- bart, i_total, key, trace_light, trace_heavy = carry
283
+ def _compute_i_skip(
284
+ i_total: Int32[Array, ''], n_burn: Int32[Array, ''], n_skip: Int32[Array, '']
285
+ ) -> Int32[Array, '']:
286
+ """Compute the `i_skip` argument passed to `callback`."""
287
+ burnin = i_total < n_burn
288
+ return jnp.where(
289
+ burnin,
290
+ i_total + 1,
291
+ (i_total - n_burn + 1) % n_skip
292
+ + jnp.where(i_total - n_burn + 1 < n_skip, n_burn, 0),
293
+ )
121
294
 
122
- key, subkey = random.split(key)
123
- bart = mcmcstep.step(subkey, bart)
124
295
 
125
- burnin = i_total < n_burn
126
- i_skip = jnp.where(
127
- burnin,
128
- i_total + 1,
129
- (i_total + 1) % n_skip + jnp.where(i_total + 1 < n_skip, n_burn, 0),
296
+ @partial(jax.jit, donate_argnums=(0,), static_argnums=(1, 2, 3, 4))
297
+ def _run_mcmc_inner_loop(
298
+ carry: _Carry,
299
+ inner_loop_length: int,
300
+ callback: Callback | None,
301
+ burnin_extractor: Callable[[State], PyTree],
302
+ main_extractor: Callable[[State], PyTree],
303
+ n_burn: Int32[Array, ''],
304
+ n_save: Int32[Array, ''],
305
+ n_skip: Int32[Array, ''],
306
+ i_outer: Int32[Array, ''],
307
+ n_iters: Int32[Array, ''],
308
+ ):
309
+ def loop_impl(carry: _Carry) -> _Carry:
310
+ """Loop body to run if i_total < n_iters."""
311
+ # split random key
312
+ keys = jaxext.split(carry.key, 3)
313
+ carry = replace(carry, key=keys.pop())
314
+
315
+ # update state
316
+ carry = replace(carry, bart=mcmcstep.step(keys.pop(), carry.bart))
317
+
318
+ burnin = carry.i_total < n_burn
319
+
320
+ # invoke callback
321
+ if callback is not None:
322
+ i_skip = _compute_i_skip(carry.i_total, n_burn, n_skip)
323
+ rt = callback(
324
+ key=keys.pop(),
325
+ bart=carry.bart,
326
+ burnin=burnin,
327
+ i_total=carry.i_total,
328
+ i_skip=i_skip,
329
+ callback_state=carry.callback_state,
330
+ n_burn=n_burn,
331
+ n_save=n_save,
332
+ n_skip=n_skip,
333
+ i_outer=i_outer,
334
+ inner_loop_length=inner_loop_length,
335
+ )
336
+ if rt is not None:
337
+ bart, callback_state = rt
338
+ carry = replace(carry, bart=bart, callback_state=callback_state)
339
+
340
+ def save_to_burnin_trace() -> tuple[PyTree, PyTree]:
341
+ return _pytree_at_set(
342
+ carry.burnin_trace, carry.i_total, burnin_extractor(carry.bart)
343
+ ), carry.main_trace
344
+
345
+ def save_to_main_trace() -> tuple[PyTree, PyTree]:
346
+ idx = (carry.i_total - n_burn) // n_skip
347
+ return carry.burnin_trace, _pytree_at_set(
348
+ carry.main_trace, idx, main_extractor(carry.bart)
349
+ )
350
+
351
+ # save state to trace
352
+ burnin_trace, main_trace = lax.cond(
353
+ burnin, save_to_burnin_trace, save_to_main_trace
130
354
  )
131
- callback(
132
- bart=bart, burnin=burnin, i_total=i_total, i_skip=i_skip, **callback_kw
355
+ return replace(
356
+ carry,
357
+ i_total=carry.i_total + 1,
358
+ burnin_trace=burnin_trace,
359
+ main_trace=main_trace,
133
360
  )
134
361
 
135
- i_heavy = jnp.where(burnin, 0, (i_total - n_burn) // n_skip)
136
- i_light = jnp.where(burnin, i_total, n_burn + i_heavy)
137
-
138
- def update_trace(index, trace, bart):
139
- bart = {k: v for k, v in bart.items() if k in trace}
362
+ def loop_noop(carry: _Carry) -> _Carry:
363
+ """Loop body to run if i_total >= n_iters; it does nothing."""
364
+ return carry
140
365
 
141
- def assign_at_index(trace_array, state_array):
142
- if trace_array.size:
143
- return trace_array.at[index, ...].set(state_array)
144
- else:
145
- # this handles the case where a trace is empty (e.g.,
146
- # no burn-in) because jax refuses to index into an array
147
- # of length 0
148
- return trace_array
149
-
150
- return tree.map(assign_at_index, trace, bart)
151
-
152
- trace_heavy = update_trace(i_heavy, trace_heavy, bart)
153
- trace_light = update_trace(i_light, trace_light, bart)
154
-
155
- i_total += 1
156
- carry = (bart, i_total, key, trace_light, trace_heavy)
366
+ def loop(carry: _Carry, _) -> tuple[_Carry, None]:
367
+ carry = lax.cond(carry.i_total < n_iters, loop_impl, loop_noop, carry)
157
368
  return carry, None
158
369
 
159
- carry, _ = lax.scan(loop, carry, None, n_burn + n_skip * n_save)
370
+ carry, _ = lax.scan(loop, carry, None, inner_loop_length)
371
+ return carry
160
372
 
161
- bart, _, _, trace_light, trace_heavy = carry
162
373
 
163
- burnin_trace = tree.map(lambda x: x[:n_burn, ...], trace_light)
164
- main_trace = tree.map(lambda x: x[n_burn:, ...], trace_light)
165
- main_trace.update(trace_heavy)
374
+ def _pytree_at_set(
375
+ dest: PyTree[Array, ' T'], index: Int32[Array, ''], val: PyTree[Array]
376
+ ) -> PyTree[Array, ' T']:
377
+ """Map ``dest.at[index].set(val)`` over pytrees."""
166
378
 
167
- return bart, burnin_trace, main_trace
379
+ def at_set(dest, val):
380
+ if dest.size:
381
+ return dest.at[index, ...].set(val)
382
+ else:
383
+ # this handles the case where an array is empty because jax refuses
384
+ # to index into an array of length 0, even if just in the abstract
385
+ return dest
168
386
 
387
+ return tree.map(at_set, dest, val)
169
388
 
170
- @functools.lru_cache
171
- # cache to make the callback function object unique, such that the jit
172
- # of run_mcmc recognizes it
173
- def make_simple_print_callback(printevery):
389
+
390
+ def make_default_callback(
391
+ *,
392
+ dot_every: int | Integer[Array, ''] | None = 1,
393
+ report_every: int | Integer[Array, ''] | None = 100,
394
+ sparse_on_at: int | Integer[Array, ''] | None = None,
395
+ ) -> dict[str, Any]:
174
396
  """
175
- Create a logging callback function for MCMC iterations.
397
+ Prepare a default callback for `run_mcmc`.
398
+
399
+ The callback prints a dot every `dot_every` iterations, and a longer
400
+ report every `report_every` iterations, and can do variable selection.
176
401
 
177
402
  Parameters
178
403
  ----------
179
- printevery : int
180
- The number of iterations between each log.
404
+ dot_every
405
+ A dot is printed every `dot_every` MCMC iterations, `None` to disable.
406
+ report_every
407
+ A one line report is printed every `report_every` MCMC iterations,
408
+ `None` to disable.
409
+ sparse_on_at
410
+ If specified, variable selection is activated starting from this
411
+ iteration. If `None`, variable selection is not used.
181
412
 
182
413
  Returns
183
414
  -------
184
- callback : callable
185
- A function in the format required by `run_mcmc`.
415
+ A dictionary with the arguments to pass to `run_mcmc` as keyword arguments to set up the callback.
416
+
417
+ Examples
418
+ --------
419
+ >>> run_mcmc(..., **make_default_callback())
420
+ """
421
+
422
+ def asarray_or_none(val: None | Any) -> None | Array:
423
+ return None if val is None else jnp.asarray(val)
424
+
425
+ def callback(*, bart, callback_state, **kwargs):
426
+ print_state, sparse_state = callback_state
427
+ bart, _ = sparse_callback(callback_state=sparse_state, bart=bart, **kwargs)
428
+ print_callback(callback_state=print_state, bart=bart, **kwargs)
429
+ return bart, callback_state
430
+ # here I assume that the callbacks don't update their states
431
+
432
+ return dict(
433
+ callback=callback,
434
+ callback_state=(
435
+ PrintCallbackState(
436
+ asarray_or_none(dot_every), asarray_or_none(report_every)
437
+ ),
438
+ SparseCallbackState(asarray_or_none(sparse_on_at)),
439
+ ),
440
+ )
441
+
442
+
443
+ class PrintCallbackState(Module):
444
+ """State for `print_callback`.
445
+
446
+ Parameters
447
+ ----------
448
+ dot_every
449
+ A dot is printed every `dot_every` MCMC iterations, `None` to disable.
450
+ report_every
451
+ A one line report is printed every `report_every` MCMC iterations,
452
+ `None` to disable.
186
453
  """
187
454
 
188
- def callback(*, bart, burnin, i_total, i_skip, n_burn, n_save, n_skip):
189
- prop_total = len(bart['leaf_trees'])
190
- grow_prop = bart['grow_prop_count'] / prop_total
191
- prune_prop = bart['prune_prop_count'] / prop_total
192
- grow_acc = bart['grow_acc_count'] / bart['grow_prop_count']
193
- prune_acc = bart['prune_acc_count'] / bart['prune_prop_count']
194
- n_total = n_burn + n_save * n_skip
195
- printcond = (i_total + 1) % printevery == 0
196
- debug.callback(
197
- _simple_print_callback,
198
- burnin,
199
- i_total,
200
- n_total,
201
- grow_prop,
202
- grow_acc,
203
- prune_prop,
204
- prune_acc,
205
- printcond,
455
+ dot_every: Int32[Array, ''] | None
456
+ report_every: Int32[Array, ''] | None
457
+
458
+
459
+ def print_callback(
460
+ *,
461
+ bart: State,
462
+ burnin: Bool[Array, ''],
463
+ i_total: Int32[Array, ''],
464
+ n_burn: Int32[Array, ''],
465
+ n_save: Int32[Array, ''],
466
+ n_skip: Int32[Array, ''],
467
+ callback_state: PrintCallbackState,
468
+ **_,
469
+ ):
470
+ """Print a dot and/or a report periodically during the MCMC."""
471
+ if callback_state.dot_every is not None:
472
+ cond = (i_total + 1) % callback_state.dot_every == 0
473
+ lax.cond(
474
+ cond,
475
+ lambda: debug.callback(lambda: print('.', end='', flush=True)), # noqa: T201
476
+ # logging can't do in-line printing so I'll stick to print
477
+ lambda: None,
206
478
  )
207
479
 
208
- return callback
480
+ if callback_state.report_every is not None:
481
+
482
+ def print_report():
483
+ debug.callback(
484
+ _print_report,
485
+ newline=callback_state.dot_every is not None,
486
+ burnin=burnin,
487
+ i_total=i_total,
488
+ n_iters=n_burn + n_save * n_skip,
489
+ grow_prop_count=bart.forest.grow_prop_count,
490
+ grow_acc_count=bart.forest.grow_acc_count,
491
+ prune_prop_count=bart.forest.prune_prop_count,
492
+ prune_acc_count=bart.forest.prune_acc_count,
493
+ prop_total=len(bart.forest.leaf_tree),
494
+ fill=grove.forest_fill(bart.forest.split_tree),
495
+ )
496
+
497
+ cond = (i_total + 1) % callback_state.report_every == 0
498
+ lax.cond(cond, print_report, lambda: None)
499
+
500
+
501
+ def _convert_jax_arrays_in_args(func: Callable) -> Callable:
502
+ """Remove jax arrays from a function's arguments.
503
+
504
+ Converts all `jax.Array` instances in the arguments to either Python scalars
505
+ or numpy arrays.
506
+ """
507
+
508
+ def convert_jax_arrays(pytree: PyTree) -> PyTree:
509
+ def convert_jax_arrays(val: Any) -> Any:
510
+ if not isinstance(val, jax.Array):
511
+ return val
512
+ elif val.shape:
513
+ return numpy.array(val)
514
+ else:
515
+ return val.item()
516
+
517
+ return tree.map(convert_jax_arrays, pytree)
518
+
519
+ @wraps(func)
520
+ def new_func(*args, **kw):
521
+ args = convert_jax_arrays(args)
522
+ kw = convert_jax_arrays(kw)
523
+ return func(*args, **kw)
524
+
525
+ return new_func
526
+
527
+
528
+ @_convert_jax_arrays_in_args
529
+ # convert all jax arrays in arguments because operations on them could lead to
530
+ # deadlock with the main thread
531
+ def _print_report(
532
+ *,
533
+ newline: bool,
534
+ burnin: bool,
535
+ i_total: int,
536
+ n_iters: int,
537
+ grow_prop_count: int,
538
+ grow_acc_count: int,
539
+ prune_prop_count: int,
540
+ prune_acc_count: int,
541
+ prop_total: int,
542
+ fill: float,
543
+ ):
544
+ """Print the report for `print_callback`."""
545
+
546
+ def acc_string(acc_count, prop_count):
547
+ if prop_count:
548
+ return f'{acc_count / prop_count:.0%}'
549
+ else:
550
+ return 'n/d'
551
+
552
+ grow_prop = grow_prop_count / prop_total
553
+ prune_prop = prune_prop_count / prop_total
554
+ grow_acc = acc_string(grow_acc_count, grow_prop_count)
555
+ prune_acc = acc_string(prune_acc_count, prune_prop_count)
556
+
557
+ prefix = '\n' if newline else ''
558
+ suffix = ' (burnin)' if burnin else ''
559
+
560
+ print( # noqa: T201, see print_callback for why not logging
561
+ f'{prefix}It {i_total + 1}/{n_iters} '
562
+ f'grow P={grow_prop:.0%} A={grow_acc}, '
563
+ f'prune P={prune_prop:.0%} A={prune_acc}, '
564
+ f'fill={fill:.0%}{suffix}'
565
+ )
566
+
567
+
568
+ class SparseCallbackState(Module):
569
+ """State for `sparse_callback`.
570
+
571
+ Parameters
572
+ ----------
573
+ sparse_on_at
574
+ If specified, variable selection is activated starting from this
575
+ iteration. If `None`, variable selection is not used.
576
+ """
577
+
578
+ sparse_on_at: Int32[Array, ''] | None
209
579
 
210
580
 
211
- def _simple_print_callback(
212
- burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printcond
581
+ def sparse_callback(
582
+ *,
583
+ key: Key[Array, ''],
584
+ bart: State,
585
+ i_total: Int32[Array, ''],
586
+ callback_state: SparseCallbackState,
587
+ **_,
213
588
  ):
214
- if printcond:
215
- burnin_flag = ' (burnin)' if burnin else ''
216
- total_str = str(n_total)
217
- ndigits = len(total_str)
218
- i_str = str(i_total.item() + 1).rjust(ndigits)
219
- # I do i_total.item() + 1 instead of just i_total + 1 to solve a bug
220
- # originating when jax is combined with some outdated dependencies. (I
221
- # did not track down which dependencies exactly.) Doing .item() makes
222
- # the + 1 operation be done by Python instead of by jax. The bug is that
223
- # jax hangs completely, with a secondary thread blocked at this line.
224
- print(
225
- f'Iteration {i_str}/{total_str} '
226
- f'P_grow={grow_prop:.2f} P_prune={prune_prop:.2f} '
227
- f'A_grow={grow_acc:.2f} A_prune={prune_acc:.2f}{burnin_flag}'
589
+ """Perform variable selection, see `mcmcstep.step_sparse`."""
590
+ if callback_state.sparse_on_at is not None:
591
+ bart = lax.cond(
592
+ i_total < callback_state.sparse_on_at,
593
+ lambda: bart,
594
+ lambda: mcmcstep.step_sparse(key, bart),
228
595
  )
596
+ return bart, callback_state
597
+
598
+
599
+ class Trace(grove.TreeHeaps, Protocol):
600
+ """Protocol for a MCMC trace."""
601
+
602
+ offset: Float32[Array, ' trace_length']
603
+
604
+
605
+ class TreesTrace(Module):
606
+ """Implementation of `bartz.grove.TreeHeaps` for an MCMC trace."""
607
+
608
+ leaf_tree: Float32[Array, 'trace_length num_trees 2**d']
609
+ var_tree: UInt[Array, 'trace_length num_trees 2**(d-1)']
610
+ split_tree: UInt[Array, 'trace_length num_trees 2**(d-1)']
611
+
612
+ @classmethod
613
+ def from_dataclass(cls, obj: grove.TreeHeaps):
614
+ """Create a `TreesTrace` from any `bartz.grove.TreeHeaps`."""
615
+ return cls(**{f.name: getattr(obj, f.name) for f in fields(cls)})
229
616
 
230
617
 
231
618
  @jax.jit
232
- def evaluate_trace(trace, X):
619
+ def evaluate_trace(
620
+ trace: Trace, X: UInt[Array, 'p n']
621
+ ) -> Float32[Array, 'trace_length n']:
233
622
  """
234
623
  Compute predictions for all iterations of the BART MCMC.
235
624
 
236
625
  Parameters
237
626
  ----------
238
- trace : dict
627
+ trace
239
628
  A trace of the BART MCMC, as returned by `run_mcmc`.
240
- X : array (p, n)
629
+ X
241
630
  The predictors matrix, with `p` predictors and `n` observations.
242
631
 
243
632
  Returns
244
633
  -------
245
- y : array (n_trace, n)
246
- The predictions for each iteration of the MCMC.
634
+ The predictions for each iteration of the MCMC.
247
635
  """
248
- evaluate_trees = functools.partial(grove.evaluate_forest, sum_trees=False)
249
- evaluate_trees = jaxext.autobatch(evaluate_trees, 2**29, (None, 0, 0, 0))
636
+ evaluate_trees = partial(grove.evaluate_forest, sum_trees=False)
637
+ evaluate_trees = jaxext.autobatch(evaluate_trees, 2**29, (None, 0))
638
+ trees = TreesTrace.from_dataclass(trace)
250
639
 
251
- def loop(_, state):
252
- values = evaluate_trees(
253
- X, state['leaf_trees'], state['var_trees'], state['split_trees']
254
- )
255
- return None, jnp.sum(values, axis=0, dtype=jnp.float32)
640
+ def loop(_, item):
641
+ offset, trees = item
642
+ values = evaluate_trees(X, trees)
643
+ return None, offset + jnp.sum(values, axis=0, dtype=jnp.float32)
256
644
 
257
- _, y = lax.scan(loop, None, trace)
645
+ _, y = lax.scan(loop, None, (trace.offset, trees))
258
646
  return y
647
+
648
+
649
+ @partial(jax.jit, static_argnums=(0,))
650
+ def compute_varcount(
651
+ p: int, trace: grove.TreeHeaps
652
+ ) -> Int32[Array, 'trace_length {p}']:
653
+ """
654
+ Count how many times each predictor is used in each MCMC state.
655
+
656
+ Parameters
657
+ ----------
658
+ p
659
+ The number of predictors.
660
+ trace
661
+ A trace of the BART MCMC, as returned by `run_mcmc`.
662
+
663
+ Returns
664
+ -------
665
+ Histogram of predictor usage in each MCMC state.
666
+ """
667
+ vmapped_var_histogram = jax.vmap(grove.var_histogram, in_axes=(None, 0, 0))
668
+ return vmapped_var_histogram(p, trace.var_tree, trace.split_tree)