bartz 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff compares the contents of two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
bartz/debug.py CHANGED
@@ -1,6 +1,6 @@
1
1
  # bartz/src/bartz/debug.py
2
2
  #
3
- # Copyright (c) 2024-2025, Giacomo Petrillo
3
+ # Copyright (c) 2024-2026, The Bartz Contributors
4
4
  #
5
5
  # This file is part of bartz.
6
6
  #
@@ -22,13 +22,14 @@
22
22
  # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
23
  # SOFTWARE.
24
24
 
25
- """Debugging utilities. The entry point is the class `debug_gbart`."""
25
+ """Debugging utilities. The main functionality is the class `debug_mc_gbart`."""
26
26
 
27
27
  from collections.abc import Callable
28
28
  from dataclasses import replace
29
29
  from functools import partial
30
30
  from math import ceil, log2
31
31
  from re import fullmatch
32
+ from typing import Literal
32
33
 
33
34
  import numpy
34
35
  from equinox import Module, field
@@ -37,20 +38,22 @@ from jax import numpy as jnp
37
38
  from jax.tree_util import tree_map
38
39
  from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, UInt
39
40
 
40
- from bartz.BART import FloatLike, gbart
41
+ from bartz.BART import gbart, mc_gbart
42
+ from bartz.BART._gbart import FloatLike
41
43
  from bartz.grove import (
42
44
  TreeHeaps,
43
45
  evaluate_forest,
44
46
  is_actual_leaf,
45
47
  is_leaves_parent,
46
- traverse_tree,
48
+ normalize_axis_tuple,
49
+ traverse_forest,
47
50
  tree_depth,
48
51
  tree_depths,
49
52
  )
50
- from bartz.jaxext import minimal_unsigned_dtype, vmap_nodoc
53
+ from bartz.jaxext import autobatch, minimal_unsigned_dtype, vmap_nodoc
51
54
  from bartz.jaxext import split as split_key
52
55
  from bartz.mcmcloop import TreesTrace
53
- from bartz.mcmcstep import randint_masked
56
+ from bartz.mcmcstep._moves import randint_masked
54
57
 
55
58
 
56
59
  def format_tree(tree: TreeHeaps, *, print_all: bool = False) -> str:
@@ -154,9 +157,11 @@ def tree_actual_depth(split_tree: UInt[Array, ' 2**(d-1)']) -> Int32[Array, '']:
154
157
  return jnp.max(depth)
155
158
 
156
159
 
160
+ @jit
161
+ @partial(jnp.vectorize, signature='(nt,hts)->(d)')
157
162
  def forest_depth_distr(
158
- split_tree: UInt[Array, 'num_trees 2**(d-1)'],
159
- ) -> Int32[Array, ' d']:
163
+ split_tree: UInt[Array, '*batch_shape num_trees 2**(d-1)'],
164
+ ) -> Int32[Array, '*batch_shape d']:
160
165
  """Histogram the depths of a set of trees.
161
166
 
162
167
  Parameters
@@ -173,195 +178,102 @@ def forest_depth_distr(
173
178
  return jnp.bincount(depths, length=depth)
174
179
 
175
180
 
176
- @jit
177
- def trace_depth_distr(
178
- split_tree: UInt[Array, 'trace_length num_trees 2**(d-1)'],
179
- ) -> Int32[Array, 'trace_length d']:
180
- """Histogram the depths of a sequence of sets of trees.
181
-
182
- Parameters
183
- ----------
184
- split_tree
185
- The cutpoints of the decision rules of the trees.
186
-
187
- Returns
188
- -------
189
- A matrix where element (t,i) counts how many trees have depth i in set t.
190
- """
191
- return vmap(forest_depth_distr)(split_tree)
192
-
193
-
194
- def points_per_decision_node_distr(
195
- var_tree: UInt[Array, ' 2**(d-1)'],
196
- split_tree: UInt[Array, ' 2**(d-1)'],
181
+ @partial(jit, static_argnames=('node_type', 'sum_batch_axis'))
182
+ def points_per_node_distr(
197
183
  X: UInt[Array, 'p n'],
198
- ) -> Int32[Array, ' n+1']:
199
- """Histogram points-per-node counts.
200
-
201
- Count how many parent-of-leaf nodes in a tree select each possible amount
202
- of points.
203
-
204
- Parameters
205
- ----------
206
- var_tree
207
- The variables of the decision rules.
208
- split_tree
209
- The cutpoints of the decision rules.
210
- X
211
- The set of points to count.
212
-
213
- Returns
214
- -------
215
- A vector where the i-th element counts how many next-to-leaf nodes have i points.
216
- """
217
- traverse_tree_X = vmap(traverse_tree, in_axes=(1, None, None))
218
- indices = traverse_tree_X(X, var_tree, split_tree)
219
- indices >>= 1
220
- count_tree = jnp.zeros(split_tree.size, int).at[indices].add(1).at[0].set(0)
221
- is_parent = is_leaves_parent(split_tree)
222
- return jnp.zeros(X.shape[1] + 1, int).at[count_tree].add(is_parent)
223
-
224
-
225
- def forest_points_per_decision_node_distr(
226
- trees: TreeHeaps, X: UInt[Array, 'p n']
227
- ) -> Int32[Array, ' n+1']:
228
- """Histogram points-per-node counts for a set of trees.
229
-
230
- Count how many parent-of-leaf nodes in a set of trees select each possible
231
- amount of points.
232
-
233
- Parameters
234
- ----------
235
- trees
236
- The set of trees. The variables must have broadcast shape (num_trees,).
237
- X
238
- The set of points to count.
239
-
240
- Returns
241
- -------
242
- A vector where the i-th element counts how many next-to-leaf nodes have i points.
243
- """
244
- distr = jnp.zeros(X.shape[1] + 1, int)
184
+ var_tree: UInt[Array, '*batch_shape 2**(d-1)'],
185
+ split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
186
+ node_type: Literal['leaf', 'leaf-parent'],
187
+ *,
188
+ sum_batch_axis: int | tuple[int, ...] = (),
189
+ ) -> Int32[Array, '*reduced_batch_shape n+1']:
190
+ """Histogram points-per-node counts in a set of trees.
245
191
 
246
- def loop(distr, heaps: tuple[Array, Array]):
247
- return distr + points_per_decision_node_distr(*heaps, X), None
248
-
249
- distr, _ = lax.scan(loop, distr, (trees.var_tree, trees.split_tree))
250
- return distr
251
-
252
-
253
- @jit
254
- def trace_points_per_decision_node_distr(
255
- trace: TreeHeaps, X: UInt[Array, 'p n']
256
- ) -> Int32[Array, 'trace_length n+1']:
257
- """Separately histogram points-per-node counts over a sequence of sets of trees.
258
-
259
- For each set of trees, count how many parent-of-leaf nodes select each
260
- possible amount of points.
192
+ Count how many nodes in a tree select each possible amount of points,
193
+ over a certain subset of nodes.
261
194
 
262
195
  Parameters
263
196
  ----------
264
- trace
265
- The sequence of sets of trees. The variables must have broadcast shape
266
- (trace_length, num_trees).
267
197
  X
268
198
  The set of points to count.
269
-
270
- Returns
271
- -------
272
- A matrix where element (t,i) counts how many next-to-leaf nodes have i points in set t.
273
- """
274
-
275
- def loop(_, trace):
276
- return None, forest_points_per_decision_node_distr(trace, X)
277
-
278
- _, distr = lax.scan(loop, None, trace)
279
- return distr
280
-
281
-
282
- def points_per_leaf_distr(
283
- var_tree: UInt[Array, ' 2**(d-1)'],
284
- split_tree: UInt[Array, ' 2**(d-1)'],
285
- X: UInt[Array, 'p n'],
286
- ) -> Int32[Array, ' n+1']:
287
- """Histogram points-per-leaf counts in a tree.
288
-
289
- Count how many leaves in a tree select each possible amount of points.
290
-
291
- Parameters
292
- ----------
293
199
  var_tree
294
200
  The variables of the decision rules.
295
201
  split_tree
296
202
  The cutpoints of the decision rules.
297
- X
298
- The set of points to count.
299
-
300
- Returns
301
- -------
302
- A vector where the i-th element counts how many leaves have i points.
303
- """
304
- traverse_tree_X = vmap(traverse_tree, in_axes=(1, None, None))
305
- indices = traverse_tree_X(X, var_tree, split_tree)
306
- count_tree = jnp.zeros(2 * split_tree.size, int).at[indices].add(1)
307
- is_leaf = is_actual_leaf(split_tree, add_bottom_level=True)
308
- return jnp.zeros(X.shape[1] + 1, int).at[count_tree].add(is_leaf)
309
-
310
-
311
- def forest_points_per_leaf_distr(
312
- trees: TreeHeaps, X: UInt[Array, 'p n']
313
- ) -> Int32[Array, ' n+1']:
314
- """Histogram points-per-leaf counts over a set of trees.
315
-
316
- Count how many leaves in a set of trees select each possible amount of points.
317
-
318
- Parameters
319
- ----------
320
- trees
321
- The set of trees. The variables must have broadcast shape (num_trees,).
322
- X
323
- The set of points to count.
203
+ node_type
204
+ The type of nodes to consider. Can be:
205
+
206
+ 'leaf'
207
+ Count only leaf nodes.
208
+ 'leaf-parent'
209
+ Count only parent-of-leaf nodes.
210
+ sum_batch_axis
211
+ Aggregate the histogram over these batch axes, counting how many nodes
212
+ have each possible amount of points over subsets of trees instead of
213
+ in each tree separately.
324
214
 
325
215
  Returns
326
216
  -------
327
- A vector where the i-th element counts how many leaves have i points.
217
+ An array where the i-th element along the last axis counts how many nodes have i points; batch axes not listed in `sum_batch_axis` are kept as leading axes.
328
218
  """
329
- distr = jnp.zeros(X.shape[1] + 1, int)
330
-
331
- def loop(distr, heaps: tuple[Array, Array]):
332
- return distr + points_per_leaf_distr(*heaps, X), None
333
-
334
- distr, _ = lax.scan(loop, distr, (trees.var_tree, trees.split_tree))
335
- return distr
336
-
337
-
338
- @jit
339
- def trace_points_per_leaf_distr(
340
- trace: TreeHeaps, X: UInt[Array, 'p n']
341
- ) -> Int32[Array, 'trace_length n+1']:
342
- """Separately histogram points-per-leaf counts over a sequence of sets of trees.
343
-
344
- For each set of trees, count how many leaves select each possible amount of
345
- points.
346
-
347
- Parameters
348
- ----------
349
- trace
350
- The sequence of sets of trees. The variables must have broadcast shape
351
- (trace_length, num_trees).
352
- X
353
- The set of points to count.
354
-
355
- Returns
356
- -------
357
- A matrix where element (t,i) counts how many leaves have i points in set t.
358
- """
359
-
360
- def loop(_, trace):
361
- return None, forest_points_per_leaf_distr(trace, X)
219
+ batch_ndim = var_tree.ndim - 1
220
+ axes = normalize_axis_tuple(sum_batch_axis, batch_ndim)
221
+
222
+ def func(
223
+ var_tree: UInt[Array, '*batch_shape 2**(d-1)'],
224
+ split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
225
+ ) -> Int32[Array, '*reduced_batch_shape n+1']:
226
+ indices: UInt[Array, '*batch_shape n']
227
+ indices = traverse_forest(X, var_tree, split_tree)
228
+
229
+ @partial(jnp.vectorize, signature='(hts),(n)->(ts_or_hts),(ts_or_hts)')
230
+ def count_points(
231
+ split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
232
+ indices: UInt[Array, '*batch_shape n'],
233
+ ) -> (
234
+ tuple[UInt[Array, '*batch_shape 2**d'], Bool[Array, '*batch_shape 2**d']]
235
+ | tuple[
236
+ UInt[Array, '*batch_shape 2**(d-1)'],
237
+ Bool[Array, '*batch_shape 2**(d-1)'],
238
+ ]
239
+ ):
240
+ if node_type == 'leaf-parent':
241
+ indices >>= 1
242
+ predicate = is_leaves_parent(split_tree)
243
+ elif node_type == 'leaf':
244
+ predicate = is_actual_leaf(split_tree, add_bottom_level=True)
245
+ else:
246
+ raise ValueError(node_type)
247
+ count_tree = jnp.zeros(predicate.size, int).at[indices].add(1).at[0].set(0)
248
+ return count_tree, predicate
249
+
250
+ count_tree, predicate = count_points(split_tree, indices)
251
+
252
+ def count_nodes(
253
+ count_tree: UInt[Array, '*summed_batch_axes half_tree_size'],
254
+ predicate: Bool[Array, '*summed_batch_axes half_tree_size'],
255
+ ) -> Int32[Array, ' n+1']:
256
+ return jnp.zeros(X.shape[1] + 1, int).at[count_tree].add(predicate)
257
+
258
+ # vmap count_nodes over the batch dims that are not summed over
259
+ for i in reversed(range(batch_ndim)):
260
+ neg_i = i - var_tree.ndim
261
+ if i not in axes:
262
+ count_nodes = vmap(count_nodes, in_axes=neg_i)
263
+
264
+ return count_nodes(count_tree, predicate)
265
+
266
+ # automatically batch over the batch dimensions that are not summed over
267
+ max_io_nbytes = 2**27 # 128 MiB
268
+ out_dim_shift = len(axes)
269
+ for i in reversed(range(batch_ndim)):
270
+ if i in axes:
271
+ out_dim_shift -= 1
272
+ else:
273
+ func = autobatch(func, max_io_nbytes, i, i - out_dim_shift)
274
+ assert out_dim_shift == 0
362
275
 
363
- _, distr = lax.scan(loop, None, trace)
364
- return distr
276
+ return func(var_tree, split_tree)
365
277
 
366
278
 
367
279
  check_functions = []
@@ -400,29 +312,36 @@ def check_types(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> bool:
400
312
  return (
401
313
  tree.var_tree.dtype == expected_var_dtype
402
314
  and tree.split_tree.dtype == expected_split_dtype
315
+ and jnp.issubdtype(max_split.dtype, jnp.unsignedinteger)
403
316
  )
404
317
 
405
318
 
406
319
  @check
407
- def check_sizes(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> bool: # noqa: ARG001
320
+ def check_sizes(tree: TreeHeaps, _max_split: UInt[Array, ' p']) -> bool:
408
321
  """Check that array sizes are coherent."""
409
322
  return tree.leaf_tree.size == 2 * tree.var_tree.size == 2 * tree.split_tree.size
410
323
 
411
324
 
412
325
  @check
413
- def check_unused_node(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']: # noqa: ARG001
326
+ def check_unused_node(
327
+ tree: TreeHeaps, _max_split: UInt[Array, ' p']
328
+ ) -> Bool[Array, '']:
414
329
  """Check that the unused node slot at index 0 is not dirty."""
415
330
  return (tree.var_tree[0] == 0) & (tree.split_tree[0] == 0)
416
331
 
417
332
 
418
333
  @check
419
- def check_leaf_values(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']: # noqa: ARG001
334
+ def check_leaf_values(
335
+ tree: TreeHeaps, _max_split: UInt[Array, ' p']
336
+ ) -> Bool[Array, '']:
420
337
  """Check that all leaf values are not inf or nan."""
421
338
  return jnp.all(jnp.isfinite(tree.leaf_tree))
422
339
 
423
340
 
424
341
  @check
425
- def check_stray_nodes(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']: # noqa: ARG001
342
+ def check_stray_nodes(
343
+ tree: TreeHeaps, _max_split: UInt[Array, ' p']
344
+ ) -> Bool[Array, '']:
426
345
  """Check if there is any marked-non-leaf node with a marked-leaf parent."""
427
346
  index = jnp.arange(
428
347
  2 * tree.split_tree.size,
@@ -446,12 +365,12 @@ def check_rule_consistency(
446
365
 
447
366
  # initial boundaries of decision rules. use extreme integers instead of 0,
448
367
  # max_split to avoid checking if there is something out of bounds.
449
- small = jnp.iinfo(jnp.int32).min
450
- large = jnp.iinfo(jnp.int32).max
451
- lower = jnp.full(max_split.size, small, jnp.int32)
452
- upper = jnp.full(max_split.size, large, jnp.int32)
453
- # specify the type explicitly, otherwise they are weakly typed and get
454
- # implicitly converted to split.dtype (typically uint8) in the expressions
368
+ dtype = tree.split_tree.dtype
369
+ small = jnp.iinfo(dtype).min
370
+ large = jnp.iinfo(dtype).max
371
+ lower = jnp.full(max_split.size, small, dtype)
372
+ upper = jnp.full(max_split.size, large, dtype)
373
+ # the split must be in (lower[var], upper[var]]
455
374
 
456
375
  def _check_recursive(node, lower, upper):
457
376
  # read decision rule
@@ -464,20 +383,14 @@ def check_rule_consistency(
464
383
  upper_var = upper.at[var].get(mode='fill', fill_value=large)
465
384
 
466
385
  # check rule is in bounds
467
- bad = jnp.where(split, (split <= lower_var) | (split >= upper_var), False)
386
+ bad = jnp.where(split, (split <= lower_var) | (split > upper_var), False)
468
387
 
469
388
  # recurse
470
389
  if node < tree.var_tree.size // 2:
471
- bad |= _check_recursive(
472
- 2 * node,
473
- lower,
474
- upper.at[jnp.where(split, var, max_split.size)].set(split),
475
- )
476
- bad |= _check_recursive(
477
- 2 * node + 1,
478
- lower.at[jnp.where(split, var, max_split.size)].set(split),
479
- upper,
480
- )
390
+ idx = jnp.where(split, var, max_split.size)
391
+ bad |= _check_recursive(2 * node, lower, upper.at[idx].set(split - 1))
392
+ bad |= _check_recursive(2 * node + 1, lower.at[idx].set(split), upper)
393
+
481
394
  return bad
482
395
 
483
396
  return ~_check_recursive(1, lower, upper)
@@ -557,30 +470,39 @@ def describe_error(error: int | Integer[Array, '']) -> list[str]:
557
470
 
558
471
 
559
472
  @jit
560
- @partial(vmap_nodoc, in_axes=(0, None))
561
473
  def check_trace(
562
474
  trace: TreeHeaps, max_split: UInt[Array, ' p']
563
- ) -> UInt[Array, 'trace_length num_trees']:
564
- """Check the validity of a sequence of sets of trees.
475
+ ) -> UInt[Array, '*batch_shape']:
476
+ """Check the validity of a set of trees.
565
477
 
566
478
  Use `describe_error` to parse the error codes returned by this function.
567
479
 
568
480
  Parameters
569
481
  ----------
570
482
  trace
571
- The sequence of sets of trees to check. The tree arrays must have
572
- broadcast shape (trace_length, num_trees). This object can have
573
- additional attributes beyond the tree arrays, they are ignored.
483
+ The set of trees to check. This object can have additional attributes
484
+ beyond the tree arrays; they are ignored.
574
485
  max_split
575
486
  The maximum split value for each variable.
576
487
 
577
488
  Returns
578
489
  -------
579
- A matrix of error codes for each tree.
490
+ A tensor of error codes for each tree.
580
491
  """
581
- trees = TreesTrace.from_dataclass(trace)
582
- check_forest = vmap(check_tree, in_axes=(0, None))
583
- return check_forest(trees, max_split)
492
+ # vectorize check_tree over all batch dimensions
493
+ unpack_check_tree = lambda l, v, s: check_tree(TreesTrace(l, v, s), max_split)
494
+ is_mv = trace.leaf_tree.ndim > trace.split_tree.ndim
495
+ signature = '(k,ts),(hts),(hts)->()' if is_mv else '(ts),(hts),(hts)->()'
496
+ vec_check_tree = jnp.vectorize(unpack_check_tree, signature=signature)
497
+
498
+ # automatically batch over all batch dimensions
499
+ max_io_nbytes = 2**24 # 16 MiB
500
+ batch_ndim = trace.split_tree.ndim - 1
501
+ batched_check_tree = vec_check_tree
502
+ for i in reversed(range(batch_ndim)):
503
+ batched_check_tree = autobatch(batched_check_tree, max_io_nbytes, i, i)
504
+
505
+ return batched_check_tree(trace.leaf_tree, trace.var_tree, trace.split_tree)
584
506
 
585
507
 
586
508
  def _get_next_line(s: str, i: int) -> tuple[str, int]:
@@ -592,24 +514,19 @@ def _get_next_line(s: str, i: int) -> tuple[str, int]:
592
514
 
593
515
 
594
516
  class BARTTraceMeta(Module):
595
- """Metadata of R BART tree traces.
596
-
597
- Parameters
598
- ----------
599
- ndpost
600
- The number of posterior draws.
601
- ntree
602
- The number of trees in the model.
603
- numcut
604
- The maximum split value for each variable.
605
- heap_size
606
- The size of the heap required to store the trees.
607
- """
517
+ """Metadata of R BART tree traces."""
608
518
 
609
519
  ndpost: int = field(static=True)
520
+ """The number of posterior draws."""
521
+
610
522
  ntree: int = field(static=True)
523
+ """The number of trees in the model."""
524
+
611
525
  numcut: UInt[Array, ' p']
526
+ """The maximum split value for each variable."""
527
+
612
528
  heap_size: int = field(static=True)
529
+ """The size of the heap required to store the trees."""
613
530
 
614
531
 
615
532
  def scan_BART_trees(trees: str) -> BARTTraceMeta:
@@ -801,25 +718,24 @@ class SamplePriorStack(Module):
801
718
 
802
719
  Each level of the stack represents a recursion into a child node in a
803
720
  binary tree of maximum depth `d`.
804
-
805
- Parameters
806
- ----------
807
- nonterminal
808
- Whether the node is valid or the recursion is into unused node slots.
809
- lower
810
- upper
811
- The available cutpoints along ``var`` are in the integer range
812
- ``[1 + lower[var], 1 + upper[var])``.
813
- var
814
- split
815
- The variable and cutpoint of a decision node.
816
721
  """
817
722
 
818
723
  nonterminal: Bool[Array, ' d-1']
724
+ """Whether the node is valid or the recursion is into unused node slots."""
725
+
819
726
  lower: UInt[Array, 'd-1 p']
727
+ """The available cutpoints along ``var`` are in the integer range
728
+ ``[1 + lower[var], 1 + upper[var])``."""
729
+
820
730
  upper: UInt[Array, 'd-1 p']
731
+ """The available cutpoints along ``var`` are in the integer range
732
+ ``[1 + lower[var], 1 + upper[var])``."""
733
+
821
734
  var: UInt[Array, ' d-1']
735
+ """The variable of a decision node."""
736
+
822
737
  split: UInt[Array, ' d-1']
738
+ """The cutpoint of a decision node."""
823
739
 
824
740
  @classmethod
825
741
  def initial(
@@ -850,19 +766,16 @@ class SamplePriorStack(Module):
850
766
 
851
767
 
852
768
  class SamplePriorTrees(Module):
853
- """Object holding the trees generated by `sample_prior`.
854
-
855
- Parameters
856
- ----------
857
- leaf_tree
858
- var_tree
859
- split_tree
860
- The arrays representing the trees, see `bartz.grove`.
861
- """
769
+ """Object holding the trees generated by `sample_prior`."""
862
770
 
863
771
  leaf_tree: Float32[Array, '* 2**d']
772
+ """The array representing the trees, see `bartz.grove`."""
773
+
864
774
  var_tree: UInt[Array, '* 2**(d-1)']
775
+ """The array representing the trees, see `bartz.grove`."""
776
+
865
777
  split_tree: UInt[Array, '* 2**(d-1)']
778
+ """The array representing the trees, see `bartz.grove`."""
866
779
 
867
780
  @classmethod
868
781
  def initial(
@@ -903,21 +816,16 @@ class SamplePriorTrees(Module):
903
816
 
904
817
 
905
818
  class SamplePriorCarry(Module):
906
- """Object holding values carried along the recursion in `sample_prior`.
907
-
908
- Parameters
909
- ----------
910
- key
911
- A jax random key used to sample decision rules.
912
- stack
913
- The stack used to manage the recursion.
914
- trees
915
- The output arrays.
916
- """
819
+ """Object holding values carried along the recursion in `sample_prior`."""
917
820
 
918
821
  key: Key[Array, '']
822
+ """A jax random key used to sample decision rules."""
823
+
919
824
  stack: SamplePriorStack
825
+ """The stack used to manage the recursion."""
826
+
920
827
  trees: SamplePriorTrees
828
+ """The output arrays."""
921
829
 
922
830
  @classmethod
923
831
  def initial(
@@ -958,21 +866,17 @@ class SamplePriorX(Module):
958
866
 
959
867
  The sequence of nodes to visit is pre-computed recursively once, unrolling
960
868
  the recursion schedule.
961
-
962
- Parameters
963
- ----------
964
- node
965
- The heap index of the node to visit.
966
- depth
967
- The depth of the node.
968
- next_depth
969
- The depth of the next node to visit, either the left child or the right
970
- sibling of the node or of an ancestor.
971
869
  """
972
870
 
973
871
  node: Int32[Array, ' 2**(d-1)-1']
872
+ """The heap index of the node to visit."""
873
+
974
874
  depth: Int32[Array, ' 2**(d-1)-1']
875
+ """The depth of the node."""
876
+
975
877
  next_depth: Int32[Array, ' 2**(d-1)-1']
878
+ """The depth of the next node to visit, either the left child or the right
879
+ sibling of the node or of an ancestor."""
976
880
 
977
881
  @classmethod
978
882
  def initial(cls, p_nonterminal: Float32[Array, ' d-1']) -> 'SamplePriorX':
@@ -992,7 +896,7 @@ class SamplePriorX(Module):
992
896
  assert len(seq) == 2**p_nonterminal.size - 1
993
897
  node = [node for node, depth in seq]
994
898
  depth = [depth for node, depth in seq]
995
- next_depth = depth[1:] + [p_nonterminal.size]
899
+ next_depth = [*depth[1:], p_nonterminal.size]
996
900
  return cls(
997
901
  node=jnp.array(node),
998
902
  depth=jnp.array(depth),
@@ -1173,18 +1077,18 @@ def sample_prior(
1173
1077
  return tree_map(lambda x: x.reshape(trace_length, num_trees, -1), trees)
1174
1078
 
1175
1079
 
1176
- class debug_gbart(gbart):
1177
- """A subclass of `gbart` that adds debugging functionality.
1080
+ class debug_mc_gbart(mc_gbart):
1081
+ """A subclass of `mc_gbart` that adds debugging functionality.
1178
1082
 
1179
1083
  Parameters
1180
1084
  ----------
1181
1085
  *args
1182
- Passed to `gbart`.
1086
+ Passed to `mc_gbart`.
1183
1087
  check_trees
1184
1088
  If `True`, check all trees with `check_trace` after running the MCMC,
1185
1089
  and assert that they are all valid. Set to `False` to allow jax tracing.
1186
1090
  **kw
1187
- Passed to `gbart`.
1091
+ Passed to `mc_gbart`.
1188
1092
  """
1189
1093
 
1190
1094
  def __init__(self, *args, check_trees: bool = True, **kw):
@@ -1194,24 +1098,28 @@ class debug_gbart(gbart):
1194
1098
  bad_count = jnp.count_nonzero(bad)
1195
1099
  assert bad_count == 0
1196
1100
 
1197
- def show_tree(self, i_sample: int, i_tree: int, print_all: bool = False):
1101
+ def print_tree(
1102
+ self, i_chain: int, i_sample: int, i_tree: int, print_all: bool = False
1103
+ ):
1198
1104
  """Print a single tree in human-readable format.
1199
1105
 
1200
1106
  Parameters
1201
1107
  ----------
1108
+ i_chain
1109
+ The index of the MCMC chain.
1202
1110
  i_sample
1203
- The index of the posterior sample.
1111
+ The index of the (post-burnin) sample in the chain.
1204
1112
  i_tree
1205
1113
  The index of the tree in the sample.
1206
1114
  print_all
1207
1115
  If `True`, also print the content of unused node slots.
1208
1116
  """
1209
1117
  tree = TreesTrace.from_dataclass(self._main_trace)
1210
- tree = tree_map(lambda x: x[i_sample, i_tree, :], tree)
1118
+ tree = tree_map(lambda x: x[i_chain, i_sample, i_tree, :], tree)
1211
1119
  s = format_tree(tree, print_all=print_all)
1212
1120
  print(s) # noqa: T201, this method is intended for debug
1213
1121
 
1214
- def sigma_harmonic_mean(self, prior: bool = False) -> Float32[Array, '']:
1122
+ def sigma_harmonic_mean(self, prior: bool = False) -> Float32[Array, ' mc_cores']:
1215
1123
  """Return the harmonic mean of the error variance.
1216
1124
 
1217
1125
  Parameters
@@ -1225,33 +1133,36 @@ class debug_gbart(gbart):
1225
1133
  The harmonic mean 1/E[1/sigma^2] in the selected distribution.
1226
1134
  """
1227
1135
  bart = self._mcmc_state
1228
- assert bart.sigma2_alpha is not None
1136
+ assert bart.error_cov_df is not None
1229
1137
  assert bart.z is None
1138
+ # inverse gamma prior: alpha = df / 2, beta = scale / 2
1230
1139
  if prior:
1231
- alpha = bart.sigma2_alpha
1232
- beta = bart.sigma2_beta
1140
+ alpha = bart.error_cov_df / 2
1141
+ beta = bart.error_cov_scale / 2
1233
1142
  else:
1234
- resid = bart.resid
1235
- alpha = bart.sigma2_alpha + resid.size / 2
1236
- norm2 = resid @ resid
1237
- beta = bart.sigma2_beta + norm2 / 2
1238
- sigma2 = beta / alpha
1239
- return jnp.sqrt(sigma2)
1240
-
1241
- def compare_resid(self) -> tuple[Float32[Array, ' n'], Float32[Array, ' n']]:
1143
+ alpha = bart.error_cov_df / 2 + bart.resid.size / 2
1144
+ norm2 = jnp.einsum('ij,ij->i', bart.resid, bart.resid)
1145
+ beta = bart.error_cov_scale / 2 + norm2 / 2
1146
+ error_cov_inv = alpha / beta
1147
+ return jnp.sqrt(lax.reciprocal(error_cov_inv))
1148
+
1149
+ def compare_resid(
1150
+ self,
1151
+ ) -> tuple[Float32[Array, 'mc_cores n'], Float32[Array, 'mc_cores n']]:
1242
1152
  """Re-compute residuals to compare them with the updated ones.
1243
1153
 
1244
1154
  Returns
1245
1155
  -------
1246
- resid1 : Float32[Array, 'n']
1156
+ resid1 : Float32[Array, 'mc_cores n']
1247
1157
  The final state of the residuals updated during the MCMC.
1248
- resid2 : Float32[Array, 'n']
1158
+ resid2 : Float32[Array, 'mc_cores n']
1249
1159
  The residuals computed from the final state of the trees.
1250
1160
  """
1251
1161
  bart = self._mcmc_state
1252
1162
  resid1 = bart.resid
1253
1163
 
1254
- trees = evaluate_forest(bart.X, bart.forest)
1164
+ forests = TreesTrace.from_dataclass(bart.forest)
1165
+ trees = evaluate_forest(bart.X, forests, sum_batch_axis=-1)
1255
1166
 
1256
1167
  if bart.z is not None:
1257
1168
  ref = bart.z
@@ -1261,14 +1172,16 @@ class debug_gbart(gbart):
1261
1172
 
1262
1173
  return resid1, resid2
1263
1174
 
1264
- def avg_acc(self) -> tuple[Float32[Array, ''], Float32[Array, '']]:
1175
+ def avg_acc(
1176
+ self,
1177
+ ) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]:
1265
1178
  """Compute the average acceptance rates of tree moves.
1266
1179
 
1267
1180
  Returns
1268
1181
  -------
1269
- acc_grow : Float32[Array, '']
1182
+ acc_grow : Float32[Array, 'mc_cores']
1270
1183
  The average acceptance rate of grow moves.
1271
- acc_prune : Float32[Array, '']
1184
+ acc_prune : Float32[Array, 'mc_cores']
1272
1185
  The average acceptance rate of prune moves.
1273
1186
  """
1274
1187
  trace = self._main_trace
@@ -1276,18 +1189,20 @@ class debug_gbart(gbart):
1276
1189
  def acc(prefix):
1277
1190
  acc = getattr(trace, f'{prefix}_acc_count')
1278
1191
  prop = getattr(trace, f'{prefix}_prop_count')
1279
- return acc.sum() / prop.sum()
1192
+ return acc.sum(axis=1) / prop.sum(axis=1)
1280
1193
 
1281
1194
  return acc('grow'), acc('prune')
1282
1195
 
1283
- def avg_prop(self) -> tuple[Float32[Array, ''], Float32[Array, '']]:
1196
+ def avg_prop(
1197
+ self,
1198
+ ) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]:
1284
1199
  """Compute the average proposal rate of grow and prune moves.
1285
1200
 
1286
1201
  Returns
1287
1202
  -------
1288
- prop_grow : Float32[Array, '']
1203
+ prop_grow : Float32[Array, 'mc_cores']
1289
1204
  The fraction of times grow was proposed instead of prune.
1290
- prop_prune : Float32[Array, '']
1205
+ prop_prune : Float32[Array, 'mc_cores']
1291
1206
  The fraction of times prune was proposed instead of grow.
1292
1207
 
1293
1208
  Notes
@@ -1298,61 +1213,86 @@ class debug_gbart(gbart):
1298
1213
  trace = self._main_trace
1299
1214
 
1300
1215
  def prop(prefix):
1301
- return getattr(trace, f'{prefix}_prop_count').sum()
1216
+ return getattr(trace, f'{prefix}_prop_count').sum(axis=1)
1302
1217
 
1303
1218
  pgrow = prop('grow')
1304
1219
  pprune = prop('prune')
1305
1220
  total = pgrow + pprune
1306
1221
  return pgrow / total, pprune / total
1307
1222
 
1308
- def avg_move(self) -> tuple[Float32[Array, ''], Float32[Array, '']]:
1223
+ def avg_move(
1224
+ self,
1225
+ ) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]:
1309
1226
  """Compute the move rate.
1310
1227
 
1311
1228
  Returns
1312
1229
  -------
1313
- rate_grow : Float32[Array, '']
1230
+ rate_grow : Float32[Array, 'mc_cores']
1314
1231
  The fraction of times a grow move was proposed and accepted.
1315
- rate_prune : Float32[Array, '']
1232
+ rate_prune : Float32[Array, 'mc_cores']
1316
1233
  The fraction of times a prune move was proposed and accepted.
1317
1234
  """
1318
1235
  agrow, aprune = self.avg_acc()
1319
1236
  pgrow, pprune = self.avg_prop()
1320
1237
  return agrow * pgrow, aprune * pprune
1321
1238
 
1322
- def depth_distr(self) -> Float32[Array, 'trace_length d']:
1239
+ def depth_distr(self) -> Int32[Array, 'mc_cores ndpost/mc_cores d']:
1323
1240
  """Histogram of tree depths for each state of the trees.
1324
1241
 
1325
1242
  Returns
1326
1243
  -------
1327
1244
  For each chain, a matrix where each row contains a histogram of tree depths.
1328
1245
  """
1329
- return trace_depth_distr(self._main_trace.split_tree)
1246
+ out: Int32[Array, '*chains samples d']
1247
+ out = forest_depth_distr(self._main_trace.split_tree)
1248
+ if out.ndim < 3:
1249
+ out = out[None, :, :]
1250
+ return out
1251
+
1252
+ def _points_per_node_distr(
1253
+ self, node_type: str
1254
+ ) -> Int32[Array, 'mc_cores ndpost/mc_cores n+1']:
1255
+ out: Int32[Array, '*chains samples n+1']
1256
+ out = points_per_node_distr(
1257
+ self._mcmc_state.X,
1258
+ self._main_trace.var_tree,
1259
+ self._main_trace.split_tree,
1260
+ node_type,
1261
+ sum_batch_axis=-1,
1262
+ )
1263
+ if out.ndim < 3:
1264
+ out = out[None, :, :]
1265
+ return out
1330
1266
 
1331
- def points_per_decision_node_distr(self) -> Float32[Array, 'trace_length n+1']:
1267
+ def points_per_decision_node_distr(
1268
+ self,
1269
+ ) -> Int32[Array, 'mc_cores ndpost/mc_cores n+1']:
1332
1270
  """Histogram of number of points belonging to parent-of-leaf nodes.
1333
1271
 
1334
1272
  Returns
1335
1273
  -------
1336
- A matrix where each row contains a histogram of number of points.
1274
+ For each chain, a matrix where each row contains a histogram of number of points.
1337
1275
  """
1338
- return trace_points_per_decision_node_distr(
1339
- self._main_trace, self._mcmc_state.X
1340
- )
1276
+ return self._points_per_node_distr('leaf-parent')
1341
1277
 
1342
- def points_per_leaf_distr(self) -> Float32[Array, 'trace_length n+1']:
1278
+ def points_per_leaf_distr(self) -> Int32[Array, 'mc_cores ndpost/mc_cores n+1']:
1343
1279
  """Histogram of number of points belonging to leaves.
1344
1280
 
1345
1281
  Returns
1346
1282
  -------
1347
1283
  For each chain, a matrix where each row contains a histogram of number of points.
1348
1284
  """
1349
- return trace_points_per_leaf_distr(self._main_trace, self._mcmc_state.X)
1285
+ return self._points_per_node_distr('leaf')
1350
1286
 
1351
- def check_trees(self) -> UInt[Array, 'trace_length ntree']:
1287
+ def check_trees(self) -> UInt[Array, 'mc_cores ndpost/mc_cores ntree']:
1352
1288
  """Apply `check_trace` to all the tree draws."""
1353
- return check_trace(self._main_trace, self._mcmc_state.forest.max_split)
1289
+ out: UInt[Array, '*chains samples num_trees']
1290
+ out = check_trace(self._main_trace, self._mcmc_state.forest.max_split)
1291
+ if out.ndim < 3:
1292
+ out = out[None, :, :]
1293
+ return out
1354
1294
 
1355
- def tree_goes_bad(self) -> Bool[Array, 'trace_length ntree']:
1295
+ def tree_goes_bad(self) -> Bool[Array, 'mc_cores ndpost/mc_cores ntree']:
1356
1296
  """Find iterations where a tree becomes invalid.
1357
1297
 
1358
1298
  Returns
@@ -1360,5 +1300,20 @@ class debug_gbart(gbart):
1360
1300
  An array where, for each chain, element (i,j) is `True` if tree j is invalid at iteration i but not at iteration i-1.
1361
1301
  """
1362
1302
  bad = self.check_trees().astype(bool)
1363
- bad_before = jnp.pad(bad[:-1], [(1, 0), (0, 0)])
1303
+ bad_before = jnp.pad(bad[:, :-1, :], [(0, 0), (1, 0), (0, 0)])
1364
1304
  return bad & ~bad_before
1305
+
1306
+
1307
+ class debug_gbart(debug_mc_gbart, gbart):
1308
+ """A subclass of `gbart` that adds debugging functionality.
1309
+
1310
+ Parameters
1311
+ ----------
1312
+ *args
1313
+ Passed to `gbart`.
1314
+ check_trees
1315
+ If `True`, check all trees with `check_trace` after running the MCMC,
1316
+ and assert that they are all valid. Set to `False` to allow jax tracing.
1317
+ **kw
1318
+ Passed to `gbart`.
1319
+ """