bartz 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/debug.py CHANGED
@@ -1,13 +1,75 @@
1
- import functools
1
+ # bartz/src/bartz/debug.py
2
+ #
3
+ # Copyright (c) 2024-2026, The Bartz Contributors
4
+ #
5
+ # This file is part of bartz.
6
+ #
7
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ # of this software and associated documentation files (the "Software"), to deal
9
+ # in the Software without restriction, including without limitation the rights
10
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ # copies of the Software, and to permit persons to whom the Software is
12
+ # furnished to do so, subject to the following conditions:
13
+ #
14
+ # The above copyright notice and this permission notice shall be included in all
15
+ # copies or substantial portions of the Software.
16
+ #
17
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ # SOFTWARE.
2
24
 
3
- import jax
4
- from jax import lax
25
+ """Debugging utilities. The main functionality is the class `debug_mc_gbart`."""
26
+
27
+ from collections.abc import Callable
28
+ from dataclasses import replace
29
+ from functools import partial
30
+ from math import ceil, log2
31
+ from re import fullmatch
32
+ from typing import Literal
33
+
34
+ import numpy
35
+ from equinox import Module, field
36
+ from jax import jit, lax, random, vmap
5
37
  from jax import numpy as jnp
38
+ from jax.tree_util import tree_map
39
+ from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, UInt
40
+
41
+ from bartz.BART import gbart, mc_gbart
42
+ from bartz.BART._gbart import FloatLike
43
+ from bartz.grove import (
44
+ TreeHeaps,
45
+ evaluate_forest,
46
+ is_actual_leaf,
47
+ is_leaves_parent,
48
+ normalize_axis_tuple,
49
+ traverse_forest,
50
+ tree_depth,
51
+ tree_depths,
52
+ )
53
+ from bartz.jaxext import autobatch, minimal_unsigned_dtype, vmap_nodoc
54
+ from bartz.jaxext import split as split_key
55
+ from bartz.mcmcloop import TreesTrace
56
+ from bartz.mcmcstep._moves import randint_masked
6
57
 
7
- from . import grove, jaxext
8
58
 
59
+ def format_tree(tree: TreeHeaps, *, print_all: bool = False) -> str:
60
+ """Convert a tree to a human-readable string.
9
61
 
10
- def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
62
+ Parameters
63
+ ----------
64
+ tree
65
+ A single tree to format.
66
+ print_all
67
+ If `True`, also print the contents of unused node slots in the arrays.
68
+
69
+ Returns
70
+ -------
71
+ A string representation of the tree.
72
+ """
11
73
  tee = '├──'
12
74
  corner = '└──'
13
75
  join = '│ '
@@ -15,12 +77,20 @@ def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
15
77
  down = '┐'
16
78
  bottom = '╢' # '┨' #
17
79
 
18
- def traverse_tree(index, depth, indent, first_indent, next_indent, unused):
19
- if index >= len(leaf_tree):
80
+ def traverse_tree(
81
+ lines: list[str],
82
+ index: int,
83
+ depth: int,
84
+ indent: str,
85
+ first_indent: str,
86
+ next_indent: str,
87
+ unused: bool,
88
+ ):
89
+ if index >= len(tree.leaf_tree):
20
90
  return
21
91
 
22
- var = var_tree.at[index].get(mode='fill', fill_value=0)
23
- split = split_tree.at[index].get(mode='fill', fill_value=0)
92
+ var: int = tree.var_tree.at[index].get(mode='fill', fill_value=0).item()
93
+ split: int = tree.split_tree.at[index].get(mode='fill', fill_value=0).item()
24
94
 
25
95
  is_leaf = split == 0
26
96
  left_child = 2 * index
@@ -33,26 +103,26 @@ def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
33
103
  category = 'leaf'
34
104
  else:
35
105
  category = 'decision'
36
- node_str = f'{category}({var}, {split}, {leaf_tree[index]})'
106
+ node_str = f'{category}({var}, {split}, {tree.leaf_tree[index]})'
37
107
  else:
38
108
  assert not unused
39
109
  if is_leaf:
40
- node_str = f'{leaf_tree[index]:#.2g}'
110
+ node_str = f'{tree.leaf_tree[index]:#.2g}'
41
111
  else:
42
- node_str = f'({var}: {split})'
112
+ node_str = f'x{var} < {split}'
43
113
 
44
- if not is_leaf or (print_all and left_child < len(leaf_tree)):
114
+ if not is_leaf or (print_all and left_child < len(tree.leaf_tree)):
45
115
  link = down
46
- elif not print_all and left_child >= len(leaf_tree):
116
+ elif not print_all and left_child >= len(tree.leaf_tree):
47
117
  link = bottom
48
118
  else:
49
119
  link = ' '
50
120
 
51
- max_number = len(leaf_tree) - 1
121
+ max_number = len(tree.leaf_tree) - 1
52
122
  ndigits = len(str(max_number))
53
123
  number = str(index).rjust(ndigits)
54
124
 
55
- print(f' {number} {indent}{first_indent}{link}{node_str}')
125
+ lines.append(f' {number} {indent}{first_indent}{link}{node_str}')
56
126
 
57
127
  indent += next_indent
58
128
  unused = unused or is_leaf
@@ -60,125 +130,1190 @@ def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
60
130
  if unused and not print_all:
61
131
  return
62
132
 
63
- traverse_tree(left_child, depth + 1, indent, tee, join, unused)
64
- traverse_tree(right_child, depth + 1, indent, corner, space, unused)
133
+ traverse_tree(lines, left_child, depth + 1, indent, tee, join, unused)
134
+ traverse_tree(lines, right_child, depth + 1, indent, corner, space, unused)
135
+
136
+ lines = []
137
+ traverse_tree(lines, 1, 0, '', '', '', False)
138
+ return '\n'.join(lines)
139
+
65
140
 
66
- traverse_tree(1, 0, '', '', '', False)
141
+ def tree_actual_depth(split_tree: UInt[Array, ' 2**(d-1)']) -> Int32[Array, '']:
142
+ """Measure the depth of the tree.
67
143
 
144
+ Parameters
145
+ ----------
146
+ split_tree
147
+ The cutpoints of the decision rules.
68
148
 
69
- def tree_actual_depth(split_tree):
70
- is_leaf = grove.is_actual_leaf(split_tree, add_bottom_level=True)
71
- depth = grove.tree_depths(is_leaf.size)
149
+ Returns
150
+ -------
151
+ The depth of the deepest leaf in the tree. The root is at depth 0.
152
+ """
153
+ # this could be done just with split_tree != 0
154
+ is_leaf = is_actual_leaf(split_tree, add_bottom_level=True)
155
+ depth = tree_depths(is_leaf.size)
72
156
  depth = jnp.where(is_leaf, depth, 0)
73
157
  return jnp.max(depth)
74
158
 
75
159
 
76
- def forest_depth_distr(split_trees):
77
- depth = grove.tree_depth(split_trees) + 1
78
- depths = jax.vmap(tree_actual_depth)(split_trees)
160
+ @jit
161
+ @partial(jnp.vectorize, signature='(nt,hts)->(d)')
162
+ def forest_depth_distr(
163
+ split_tree: UInt[Array, '*batch_shape num_trees 2**(d-1)'],
164
+ ) -> Int32[Array, '*batch_shape d']:
165
+ """Histogram the depths of a set of trees.
166
+
167
+ Parameters
168
+ ----------
169
+ split_tree
170
+ The cutpoints of the decision rules of the trees.
171
+
172
+ Returns
173
+ -------
174
+ An integer vector where the i-th element counts how many trees have depth i.
175
+ """
176
+ depth = tree_depth(split_tree) + 1
177
+ depths = vmap(tree_actual_depth)(split_tree)
79
178
  return jnp.bincount(depths, length=depth)
80
179
 
81
180
 
82
- def trace_depth_distr(split_trees_trace):
83
- return jax.vmap(forest_depth_distr)(split_trees_trace)
181
+ @partial(jit, static_argnames=('node_type', 'sum_batch_axis'))
182
+ def points_per_node_distr(
183
+ X: UInt[Array, 'p n'],
184
+ var_tree: UInt[Array, '*batch_shape 2**(d-1)'],
185
+ split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
186
+ node_type: Literal['leaf', 'leaf-parent'],
187
+ *,
188
+ sum_batch_axis: int | tuple[int, ...] = (),
189
+ ) -> Int32[Array, '*reduced_batch_shape n+1']:
190
+ """Histogram points-per-node counts in a set of trees.
84
191
 
192
+ Count how many nodes in a tree select each possible amount of points,
193
+ over a certain subset of nodes.
85
194
 
86
- def points_per_leaf_distr(var_tree, split_tree, X):
87
- traverse_tree = jax.vmap(grove.traverse_tree, in_axes=(1, None, None))
88
- indices = traverse_tree(X, var_tree, split_tree)
89
- count_tree = jnp.zeros(
90
- 2 * split_tree.size, dtype=jaxext.minimal_unsigned_dtype(indices.size)
91
- )
92
- count_tree = count_tree.at[indices].add(1)
93
- is_leaf = grove.is_actual_leaf(split_tree, add_bottom_level=True).view(jnp.uint8)
94
- return jnp.bincount(count_tree, is_leaf, length=X.shape[1] + 1)
195
+ Parameters
196
+ ----------
197
+ X
198
+ The set of points to count.
199
+ var_tree
200
+ The variables of the decision rules.
201
+ split_tree
202
+ The cutpoints of the decision rules.
203
+ node_type
204
+ The type of nodes to consider. Can be:
205
+
206
+ 'leaf'
207
+ Count only leaf nodes.
208
+ 'leaf-parent'
209
+ Count only parent-of-leaf nodes.
210
+ sum_batch_axis
211
+ Aggregate the histogram over these batch axes, counting how many nodes
212
+ have each possible amount of points over subsets of trees instead of
213
+ in each tree separately.
214
+
215
+ Returns
216
+ -------
217
+ A vector where the i-th element counts how many nodes have i points.
218
+ """
219
+ batch_ndim = var_tree.ndim - 1
220
+ axes = normalize_axis_tuple(sum_batch_axis, batch_ndim)
221
+
222
+ def func(
223
+ var_tree: UInt[Array, '*batch_shape 2**(d-1)'],
224
+ split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
225
+ ) -> Int32[Array, '*reduced_batch_shape n+1']:
226
+ indices: UInt[Array, '*batch_shape n']
227
+ indices = traverse_forest(X, var_tree, split_tree)
228
+
229
+ @partial(jnp.vectorize, signature='(hts),(n)->(ts_or_hts),(ts_or_hts)')
230
+ def count_points(
231
+ split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
232
+ indices: UInt[Array, '*batch_shape n'],
233
+ ) -> (
234
+ tuple[UInt[Array, '*batch_shape 2**d'], Bool[Array, '*batch_shape 2**d']]
235
+ | tuple[
236
+ UInt[Array, '*batch_shape 2**(d-1)'],
237
+ Bool[Array, '*batch_shape 2**(d-1)'],
238
+ ]
239
+ ):
240
+ if node_type == 'leaf-parent':
241
+ indices >>= 1
242
+ predicate = is_leaves_parent(split_tree)
243
+ elif node_type == 'leaf':
244
+ predicate = is_actual_leaf(split_tree, add_bottom_level=True)
245
+ else:
246
+ raise ValueError(node_type)
247
+ count_tree = jnp.zeros(predicate.size, int).at[indices].add(1).at[0].set(0)
248
+ return count_tree, predicate
249
+
250
+ count_tree, predicate = count_points(split_tree, indices)
251
+
252
+ def count_nodes(
253
+ count_tree: UInt[Array, '*summed_batch_axes half_tree_size'],
254
+ predicate: Bool[Array, '*summed_batch_axes half_tree_size'],
255
+ ) -> Int32[Array, ' n+1']:
256
+ return jnp.zeros(X.shape[1] + 1, int).at[count_tree].add(predicate)
257
+
258
+ # vmap count_nodes over non-batched dims
259
+ for i in reversed(range(batch_ndim)):
260
+ neg_i = i - var_tree.ndim
261
+ if i not in axes:
262
+ count_nodes = vmap(count_nodes, in_axes=neg_i)
263
+
264
+ return count_nodes(count_tree, predicate)
265
+
266
+ # automatically batch over all batch dimensions
267
+ max_io_nbytes = 2**27 # 128 MiB
268
+ out_dim_shift = len(axes)
269
+ for i in reversed(range(batch_ndim)):
270
+ if i in axes:
271
+ out_dim_shift -= 1
272
+ else:
273
+ func = autobatch(func, max_io_nbytes, i, i - out_dim_shift)
274
+ assert out_dim_shift == 0
275
+
276
+ return func(var_tree, split_tree)
277
+
278
+
279
+ check_functions = []
95
280
 
96
281
 
97
- def forest_points_per_leaf_distr(bart, X):
98
- distr = jnp.zeros(X.shape[1] + 1, int)
99
- trees = bart['var_trees'], bart['split_trees']
282
+ CheckFunc = Callable[[TreeHeaps, UInt[Array, ' p']], bool | Bool[Array, '']]
100
283
 
101
- def loop(distr, tree):
102
- return distr + points_per_leaf_distr(*tree, X), None
103
284
 
104
- distr, _ = lax.scan(loop, distr, trees)
105
- return distr
285
+ def check(func: CheckFunc) -> CheckFunc:
286
+ """Add a function to a list of functions used to check trees.
106
287
 
288
+ Use to decorate functions that check whether a tree is valid in some way.
289
+ These functions are invoked automatically by `check_tree`, `check_trace` and
290
+ `debug_gbart`.
107
291
 
108
- def trace_points_per_leaf_distr(bart, X):
109
- def loop(_, bart):
110
- return None, forest_points_per_leaf_distr(bart, X)
292
+ Parameters
293
+ ----------
294
+ func
295
+ The function to add to the list. It must accept a `TreeHeaps` and a
296
+ `max_split` argument, and return a boolean scalar that indicates if the
297
+ tree is ok.
111
298
 
112
- _, distr = lax.scan(loop, None, bart)
113
- return distr
299
+ Returns
300
+ -------
301
+ The function unchanged.
302
+ """
303
+ check_functions.append(func)
304
+ return func
114
305
 
115
306
 
116
- def check_types(leaf_tree, var_tree, split_tree, max_split):
117
- expected_var_dtype = jaxext.minimal_unsigned_dtype(max_split.size - 1)
307
+ @check
308
+ def check_types(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> bool:
309
+ """Check that integer types are as small as possible and coherent."""
310
+ expected_var_dtype = minimal_unsigned_dtype(max_split.size - 1)
118
311
  expected_split_dtype = max_split.dtype
119
312
  return (
120
- var_tree.dtype == expected_var_dtype
121
- and split_tree.dtype == expected_split_dtype
313
+ tree.var_tree.dtype == expected_var_dtype
314
+ and tree.split_tree.dtype == expected_split_dtype
315
+ and jnp.issubdtype(max_split.dtype, jnp.unsignedinteger)
122
316
  )
123
317
 
124
318
 
125
- def check_sizes(leaf_tree, var_tree, split_tree, max_split):
126
- return leaf_tree.size == 2 * var_tree.size == 2 * split_tree.size
319
+ @check
320
+ def check_sizes(tree: TreeHeaps, _max_split: UInt[Array, ' p']) -> bool:
321
+ """Check that array sizes are coherent."""
322
+ return tree.leaf_tree.size == 2 * tree.var_tree.size == 2 * tree.split_tree.size
127
323
 
128
324
 
129
- def check_unused_node(leaf_tree, var_tree, split_tree, max_split):
130
- return (var_tree[0] == 0) & (split_tree[0] == 0)
325
+ @check
326
+ def check_unused_node(
327
+ tree: TreeHeaps, _max_split: UInt[Array, ' p']
328
+ ) -> Bool[Array, '']:
329
+ """Check that the unused node slot at index 0 is not dirty."""
330
+ return (tree.var_tree[0] == 0) & (tree.split_tree[0] == 0)
131
331
 
132
332
 
133
- def check_leaf_values(leaf_tree, var_tree, split_tree, max_split):
134
- return jnp.all(jnp.isfinite(leaf_tree))
333
+ @check
334
+ def check_leaf_values(
335
+ tree: TreeHeaps, _max_split: UInt[Array, ' p']
336
+ ) -> Bool[Array, '']:
337
+ """Check that all leaf values are not inf or nan."""
338
+ return jnp.all(jnp.isfinite(tree.leaf_tree))
135
339
 
136
340
 
137
- def check_stray_nodes(leaf_tree, var_tree, split_tree, max_split):
341
+ @check
342
+ def check_stray_nodes(
343
+ tree: TreeHeaps, _max_split: UInt[Array, ' p']
344
+ ) -> Bool[Array, '']:
345
+ """Check if there is any marked-non-leaf node with a marked-leaf parent."""
138
346
  index = jnp.arange(
139
- 2 * split_tree.size,
140
- dtype=jaxext.minimal_unsigned_dtype(2 * split_tree.size - 1),
347
+ 2 * tree.split_tree.size,
348
+ dtype=minimal_unsigned_dtype(2 * tree.split_tree.size - 1),
141
349
  )
142
350
  parent_index = index >> 1
143
- is_not_leaf = split_tree.at[index].get(mode='fill', fill_value=0) != 0
144
- parent_is_leaf = split_tree[parent_index] == 0
351
+ is_not_leaf = tree.split_tree.at[index].get(mode='fill', fill_value=0) != 0
352
+ parent_is_leaf = tree.split_tree[parent_index] == 0
145
353
  stray = is_not_leaf & parent_is_leaf
146
354
  stray = stray.at[1].set(False)
147
355
  return ~jnp.any(stray)
148
356
 
149
357
 
150
- check_functions = [
151
- check_types,
152
- check_sizes,
153
- check_unused_node,
154
- check_leaf_values,
155
- check_stray_nodes,
156
- ]
358
+ @check
359
+ def check_rule_consistency(
360
+ tree: TreeHeaps, max_split: UInt[Array, ' p']
361
+ ) -> bool | Bool[Array, '']:
362
+ """Check that decision rules define proper subsets of ancestor rules."""
363
+ if tree.var_tree.size < 4:
364
+ return True
157
365
 
366
+ # initial boundaries of decision rules. use extreme integers instead of 0,
367
+ # max_split to avoid checking if there is something out of bounds.
368
+ dtype = tree.split_tree.dtype
369
+ small = jnp.iinfo(dtype).min
370
+ large = jnp.iinfo(dtype).max
371
+ lower = jnp.full(max_split.size, small, dtype)
372
+ upper = jnp.full(max_split.size, large, dtype)
373
+ # the split must be in (lower[var], upper[var]]
158
374
 
159
- def check_tree(leaf_tree, var_tree, split_tree, max_split):
160
- error_type = jaxext.minimal_unsigned_dtype(2 ** len(check_functions) - 1)
375
+ def _check_recursive(node, lower, upper):
376
+ # read decision rule
377
+ var = tree.var_tree[node]
378
+ split = tree.split_tree[node]
379
+
380
+ # get rule boundaries from ancestors. use fill value in case var is
381
+ # out of bounds, we don't want to check out of bounds in this function
382
+ lower_var = lower.at[var].get(mode='fill', fill_value=small)
383
+ upper_var = upper.at[var].get(mode='fill', fill_value=large)
384
+
385
+ # check rule is in bounds
386
+ bad = jnp.where(split, (split <= lower_var) | (split > upper_var), False)
387
+
388
+ # recurse
389
+ if node < tree.var_tree.size // 2:
390
+ idx = jnp.where(split, var, max_split.size)
391
+ bad |= _check_recursive(2 * node, lower, upper.at[idx].set(split - 1))
392
+ bad |= _check_recursive(2 * node + 1, lower.at[idx].set(split), upper)
393
+
394
+ return bad
395
+
396
+ return ~_check_recursive(1, lower, upper)
397
+
398
+
399
+ @check
400
+ def check_num_nodes(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']: # noqa: ARG001
401
+ """Check that #leaves = 1 + #(internal nodes)."""
402
+ is_leaf = is_actual_leaf(tree.split_tree, add_bottom_level=True)
403
+ num_leaves = jnp.count_nonzero(is_leaf)
404
+ num_internal = jnp.count_nonzero(tree.split_tree)
405
+ return num_leaves == num_internal + 1
406
+
407
+
408
+ @check
409
+ def check_var_in_bounds(
410
+ tree: TreeHeaps, max_split: UInt[Array, ' p']
411
+ ) -> Bool[Array, '']:
412
+ """Check that variables are in [0, max_split.size)."""
413
+ decision_node = tree.split_tree.astype(bool)
414
+ in_bounds = (tree.var_tree >= 0) & (tree.var_tree < max_split.size)
415
+ return jnp.all(in_bounds | ~decision_node)
416
+
417
+
418
+ @check
419
+ def check_split_in_bounds(
420
+ tree: TreeHeaps, max_split: UInt[Array, ' p']
421
+ ) -> Bool[Array, '']:
422
+ """Check that splits are in [0, max_split[var]]."""
423
+ max_split_var = (
424
+ max_split.astype(jnp.int32)
425
+ .at[tree.var_tree]
426
+ .get(mode='fill', fill_value=jnp.iinfo(jnp.int32).max)
427
+ )
428
+ return jnp.all((tree.split_tree >= 0) & (tree.split_tree <= max_split_var))
429
+
430
+
431
+ def check_tree(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> UInt[Array, '']:
432
+ """Check the validity of a tree.
433
+
434
+ Use `describe_error` to parse the error code returned by this function.
435
+
436
+ Parameters
437
+ ----------
438
+ tree
439
+ The tree to check.
440
+ max_split
441
+ The maximum split value for each variable.
442
+
443
+ Returns
444
+ -------
445
+ An integer where each bit indicates whether a check failed.
446
+ """
447
+ error_type = minimal_unsigned_dtype(2 ** len(check_functions) - 1)
161
448
  error = error_type(0)
162
449
  for i, func in enumerate(check_functions):
163
- ok = func(leaf_tree, var_tree, split_tree, max_split)
450
+ ok = func(tree, max_split)
164
451
  ok = jnp.bool_(ok)
165
452
  bit = (~ok) << i
166
453
  error |= bit
167
454
  return error
168
455
 
169
456
 
170
- def describe_error(error):
457
+ def describe_error(error: int | Integer[Array, '']) -> list[str]:
458
+ """Describe the error code returned by `check_tree`.
459
+
460
+ Parameters
461
+ ----------
462
+ error
463
+ The error code returned by `check_tree`.
464
+
465
+ Returns
466
+ -------
467
+ A list of the function names that implement the failed checks.
468
+ """
171
469
  return [func.__name__ for i, func in enumerate(check_functions) if error & (1 << i)]
172
470
 
173
471
 
174
- check_forest = jax.vmap(check_tree, in_axes=(0, 0, 0, None))
472
+ @jit
473
+ def check_trace(
474
+ trace: TreeHeaps, max_split: UInt[Array, ' p']
475
+ ) -> UInt[Array, '*batch_shape']:
476
+ """Check the validity of a set of trees.
477
+
478
+ Use `describe_error` to parse the error codes returned by this function.
479
+
480
+ Parameters
481
+ ----------
482
+ trace
483
+ The set of trees to check. This object can have additional attributes
484
+ beyond the tree arrays, they are ignored.
485
+ max_split
486
+ The maximum split value for each variable.
487
+
488
+ Returns
489
+ -------
490
+ A tensor of error codes for each tree.
491
+ """
492
+ # vectorize check_tree over all batch dimensions
493
+ unpack_check_tree = lambda l, v, s: check_tree(TreesTrace(l, v, s), max_split)
494
+ is_mv = trace.leaf_tree.ndim > trace.split_tree.ndim
495
+ signature = '(k,ts),(hts),(hts)->()' if is_mv else '(ts),(hts),(hts)->()'
496
+ vec_check_tree = jnp.vectorize(unpack_check_tree, signature=signature)
497
+
498
+ # automatically batch over all batch dimensions
499
+ max_io_nbytes = 2**24 # 16 MiB
500
+ batch_ndim = trace.split_tree.ndim - 1
501
+ batched_check_tree = vec_check_tree
502
+ for i in reversed(range(batch_ndim)):
503
+ batched_check_tree = autobatch(batched_check_tree, max_io_nbytes, i, i)
175
504
 
505
+ return batched_check_tree(trace.leaf_tree, trace.var_tree, trace.split_tree)
176
506
 
177
- @functools.partial(jax.vmap, in_axes=(0, None))
178
- def check_trace(trace, state):
179
- return check_forest(
180
- trace['leaf_trees'],
181
- trace['var_trees'],
182
- trace['split_trees'],
183
- state.max_split,
507
+
508
+ def _get_next_line(s: str, i: int) -> tuple[str, int]:
509
+ """Get the next line from a string and the new index."""
510
+ i_new = s.find('\n', i)
511
+ if i_new == -1:
512
+ return s[i:], len(s)
513
+ return s[i:i_new], i_new + 1
514
+
515
+
516
+ class BARTTraceMeta(Module):
517
+ """Metadata of R BART tree traces."""
518
+
519
+ ndpost: int = field(static=True)
520
+ """The number of posterior draws."""
521
+
522
+ ntree: int = field(static=True)
523
+ """The number of trees in the model."""
524
+
525
+ numcut: UInt[Array, ' p']
526
+ """The maximum split value for each variable."""
527
+
528
+ heap_size: int = field(static=True)
529
+ """The size of the heap required to store the trees."""
530
+
531
+
532
+ def scan_BART_trees(trees: str) -> BARTTraceMeta:
533
+ """Scan an R BART tree trace checking for errors and parsing metadata.
534
+
535
+ Parameters
536
+ ----------
537
+ trees
538
+ The string representation of a trace of trees of the R BART package.
539
+ Can be accessed from ``mc_gbart(...).treedraws['trees']``.
540
+
541
+ Returns
542
+ -------
543
+ An object containing the metadata.
544
+
545
+ Raises
546
+ ------
547
+ ValueError
548
+ If the string is malformed or contains leftover characters.
549
+ """
550
+ # parse first line
551
+ line, i_char = _get_next_line(trees, 0)
552
+ i_line = 1
553
+ match = fullmatch(r'(\d+) (\d+) (\d+)', line)
554
+ if match is None:
555
+ msg = f'Malformed header at {i_line=}'
556
+ raise ValueError(msg)
557
+ ndpost, ntree, p = map(int, match.groups())
558
+
559
+ # initial values for maxima
560
+ max_heap_index = 0
561
+ numcut = numpy.zeros(p, int)
562
+
563
+ # cycle over iterations and trees
564
+ for i_iter in range(ndpost):
565
+ for i_tree in range(ntree):
566
+ # parse first line of tree definition
567
+ line, i_char = _get_next_line(trees, i_char)
568
+ i_line += 1
569
+ match = fullmatch(r'(\d+)', line)
570
+ if match is None:
571
+ msg = f'Malformed tree header at {i_iter=} {i_tree=} {i_line=}'
572
+ raise ValueError(msg)
573
+ num_nodes = int(line)
574
+
575
+ # cycle over nodes
576
+ for i_node in range(num_nodes):
577
+ # parse node definition
578
+ line, i_char = _get_next_line(trees, i_char)
579
+ i_line += 1
580
+ match = fullmatch(
581
+ r'(\d+) (\d+) (\d+) (-?\d+(\.\d+)?(e(\+|-|)\d+)?)', line
582
+ )
583
+ if match is None:
584
+ msg = f'Malformed node definition at {i_iter=} {i_tree=} {i_node=} {i_line=}'
585
+ raise ValueError(msg)
586
+ i_heap = int(match.group(1))
587
+ var = int(match.group(2))
588
+ split = int(match.group(3))
589
+
590
+ # update maxima
591
+ numcut[var] = max(numcut[var], split)
592
+ max_heap_index = max(max_heap_index, i_heap)
593
+
594
+ assert i_char <= len(trees)
595
+ if i_char < len(trees):
596
+ msg = f'Leftover {len(trees) - i_char} characters in string'
597
+ raise ValueError(msg)
598
+
599
+ # determine minimal integer type for numcut
600
+ numcut += 1 # because BART is 0-based
601
+ split_dtype = minimal_unsigned_dtype(numcut.max())
602
+ numcut = jnp.array(numcut.astype(split_dtype))
603
+
604
+ # determine minimum heap size to store the trees
605
+ heap_size = 2 ** ceil(log2(max_heap_index + 1))
606
+
607
+ return BARTTraceMeta(ndpost=ndpost, ntree=ntree, numcut=numcut, heap_size=heap_size)
608
+
609
+
610
+ class TraceWithOffset(Module):
611
+ """Implementation of `bartz.mcmcloop.Trace`."""
612
+
613
+ leaf_tree: Float32[Array, 'ndpost ntree 2**d']
614
+ var_tree: UInt[Array, 'ndpost ntree 2**(d-1)']
615
+ split_tree: UInt[Array, 'ndpost ntree 2**(d-1)']
616
+ offset: Float32[Array, ' ndpost']
617
+
618
+ @classmethod
619
+ def from_trees_trace(
620
+ cls, trees: TreeHeaps, offset: Float32[Array, '']
621
+ ) -> 'TraceWithOffset':
622
+ """Create a `TraceWithOffset` from a `TreeHeaps`."""
623
+ ndpost, _, _ = trees.leaf_tree.shape
624
+ return cls(
625
+ leaf_tree=trees.leaf_tree,
626
+ var_tree=trees.var_tree,
627
+ split_tree=trees.split_tree,
628
+ offset=jnp.full(ndpost, offset),
629
+ )
630
+
631
+
632
+ def trees_BART_to_bartz(
633
+ trees: str, *, min_maxdepth: int = 0, offset: FloatLike | None = None
634
+ ) -> tuple[TraceWithOffset, BARTTraceMeta]:
635
+ """Convert trees from the R BART format to the bartz format.
636
+
637
+ Parameters
638
+ ----------
639
+ trees
640
+ The string representation of a trace of trees of the R BART package.
641
+ Can be accessed from ``mc_gbart(...).treedraws['trees']``.
642
+ min_maxdepth
643
+ The maximum tree depth of the output will be set to the maximum
644
+ observed depth in the input trees. Use this parameter to require at
645
+ least this maximum depth in the output format.
646
+ offset
647
+ The trace returned by `bartz.mcmcloop.run_mcmc` contains an offset to be
648
+ summed to the sum of trees. To match that behavior, this function
649
+ returns an offset as well, zero by default. Set with this parameter
650
+ otherwise.
651
+
652
+ Returns
653
+ -------
654
+ trace : TraceWithOffset
655
+ A representation of the trees compatible with the trace returned by
656
+ `bartz.mcmcloop.run_mcmc`.
657
+ meta : BARTTraceMeta
658
+ The metadata of the trace, containing the number of iterations, trees,
659
+ and the maximum split value.
660
+ """
661
+ # scan all the string checking for errors and determining sizes
662
+ meta = scan_BART_trees(trees)
663
+
664
+ # skip first line
665
+ _, i_char = _get_next_line(trees, 0)
666
+
667
+ heap_size = max(meta.heap_size, 2**min_maxdepth)
668
+ leaf_trees = numpy.zeros((meta.ndpost, meta.ntree, heap_size), dtype=numpy.float32)
669
+ var_trees = numpy.zeros(
670
+ (meta.ndpost, meta.ntree, heap_size // 2),
671
+ dtype=minimal_unsigned_dtype(meta.numcut.size - 1),
672
+ )
673
+ split_trees = numpy.zeros(
674
+ (meta.ndpost, meta.ntree, heap_size // 2), dtype=meta.numcut.dtype
184
675
  )
676
+
677
+ # cycle over iterations and trees
678
+ for i_iter in range(meta.ndpost):
679
+ for i_tree in range(meta.ntree):
680
+ # parse first line of tree definition
681
+ line, i_char = _get_next_line(trees, i_char)
682
+ num_nodes = int(line)
683
+
684
+ is_internal = numpy.zeros(heap_size // 2, dtype=bool)
685
+
686
+ # cycle over nodes
687
+ for _ in range(num_nodes):
688
+ # parse node definition
689
+ line, i_char = _get_next_line(trees, i_char)
690
+ values = line.split()
691
+ i_heap = int(values[0])
692
+ var = int(values[1])
693
+ split = int(values[2])
694
+ leaf = float(values[3])
695
+
696
+ # update values
697
+ leaf_trees[i_iter, i_tree, i_heap] = leaf
698
+ is_internal[i_heap // 2] = True
699
+ if i_heap < heap_size // 2:
700
+ var_trees[i_iter, i_tree, i_heap] = var
701
+ split_trees[i_iter, i_tree, i_heap] = split + 1
702
+
703
+ is_internal[0] = False
704
+ split_trees[i_iter, i_tree, ~is_internal] = 0
705
+
706
+ return TraceWithOffset(
707
+ leaf_tree=jnp.array(leaf_trees),
708
+ var_tree=jnp.array(var_trees),
709
+ split_tree=jnp.array(split_trees),
710
+ offset=jnp.zeros(meta.ndpost)
711
+ if offset is None
712
+ else jnp.full(meta.ndpost, offset),
713
+ ), meta
714
+
715
+
716
+ class SamplePriorStack(Module):
717
+ """Represent the manually managed stack used in `sample_prior`.
718
+
719
+ Each level of the stack represents a recursion into a child node in a
720
+ binary tree of maximum depth `d`.
721
+ """
722
+
723
+ nonterminal: Bool[Array, ' d-1']
724
+ """Whether the node is valid or the recursion is into unused node slots."""
725
+
726
+ lower: UInt[Array, 'd-1 p']
727
+ """The available cutpoints along ``var`` are in the integer range
728
+ ``[1 + lower[var], 1 + upper[var])``."""
729
+
730
+ upper: UInt[Array, 'd-1 p']
731
+ """The available cutpoints along ``var`` are in the integer range
732
+ ``[1 + lower[var], 1 + upper[var])``."""
733
+
734
+ var: UInt[Array, ' d-1']
735
+ """The variable of a decision node."""
736
+
737
+ split: UInt[Array, ' d-1']
738
+ """The cutpoint of a decision node."""
739
+
740
+ @classmethod
741
+ def initial(
742
+ cls, p_nonterminal: Float32[Array, ' d-1'], max_split: UInt[Array, ' p']
743
+ ) -> 'SamplePriorStack':
744
+ """Initialize the stack.
745
+
746
+ Parameters
747
+ ----------
748
+ p_nonterminal
749
+ The prior probability of a node being non-terminal conditional on
750
+ its ancestors and on having available decision rules, at each depth.
751
+ max_split
752
+ The number of cutpoints along each variable.
753
+
754
+ Returns
755
+ -------
756
+ A `SamplePriorStack` initialized to start the recursion.
757
+ """
758
+ var_dtype = minimal_unsigned_dtype(max_split.size - 1)
759
+ return cls(
760
+ nonterminal=jnp.ones(p_nonterminal.size, bool),
761
+ lower=jnp.zeros((p_nonterminal.size, max_split.size), max_split.dtype),
762
+ upper=jnp.broadcast_to(max_split, (p_nonterminal.size, max_split.size)),
763
+ var=jnp.zeros(p_nonterminal.size, var_dtype),
764
+ split=jnp.zeros(p_nonterminal.size, max_split.dtype),
765
+ )
766
+
767
+
768
class SamplePriorTrees(Module):
    """Container for the tree arrays produced by `sample_prior`."""

    leaf_tree: Float32[Array, '* 2**d']
    """The array representing the trees, see `bartz.grove`."""

    var_tree: UInt[Array, '* 2**(d-1)']
    """The array representing the trees, see `bartz.grove`."""

    split_tree: UInt[Array, '* 2**(d-1)']
    """The array representing the trees, see `bartz.grove`."""

    @classmethod
    def initial(
        cls,
        key: Key[Array, ''],
        sigma_mu: Float32[Array, ''],
        p_nonterminal: Float32[Array, ' d-1'],
        max_split: UInt[Array, ' p'],
    ) -> 'SamplePriorTrees':
        """Create trees with random leaves and stub tree structures.

        The leaves are already drawn from their prior (scaled standard
        normals), so they do not need to be changed afterwards.

        Parameters
        ----------
        key
            A jax random key.
        sigma_mu
            The prior standard deviation of each leaf.
        p_nonterminal
            The prior probability of a node being non-terminal conditional on
            its ancestors and on having available decision rules, at each depth.
        max_split
            The number of cutpoints along each variable.

        Returns
        -------
        Trees initialized with random leaves and stub tree structures.
        """
        # heap with one extra level for the leaves; decision-rule arrays
        # only cover the first half (the internal nodes)
        num_nodes = 1 << (p_nonterminal.size + 1)
        num_decision_slots = num_nodes >> 1
        leaves = sigma_mu * random.normal(key, (num_nodes,))
        var_dtype = minimal_unsigned_dtype(max_split.size - 1)
        return cls(
            leaf_tree=leaves,
            var_tree=jnp.zeros(num_decision_slots, dtype=var_dtype),
            split_tree=jnp.zeros(num_decision_slots, dtype=max_split.dtype),
        )
816
+
817
+
818
class SamplePriorCarry(Module):
    """Values threaded through the recursion in `sample_prior`."""

    key: Key[Array, '']
    """A jax random key used to sample decision rules."""

    stack: SamplePriorStack
    """The stack used to manage the recursion."""

    trees: SamplePriorTrees
    """The output arrays."""

    @classmethod
    def initial(
        cls,
        key: Key[Array, ''],
        sigma_mu: Float32[Array, ''],
        p_nonterminal: Float32[Array, ' d-1'],
        max_split: UInt[Array, ' p'],
    ) -> 'SamplePriorCarry':
        """Create the carry to start the recursion.

        Parameters
        ----------
        key
            A jax random key.
        sigma_mu
            The prior standard deviation of each leaf.
        p_nonterminal
            The prior probability of a node being non-terminal conditional on
            its ancestors and on having available decision rules, at each depth.
        max_split
            The number of cutpoints along each variable.

        Returns
        -------
        A `SamplePriorCarry` initialized to start the recursion.
        """
        keys = split_key(key)
        # preserve pop order: the first key becomes the carry key, the
        # second one initializes the random leaves
        carry_key = keys.pop()
        stack = SamplePriorStack.initial(p_nonterminal, max_split)
        trees = SamplePriorTrees.initial(keys.pop(), sigma_mu, p_nonterminal, max_split)
        return cls(carry_key, stack, trees)
862
+
863
+
864
class SamplePriorX(Module):
    """Pre-computed visit schedule for the recursion scan in `sample_prior`.

    The sequence of nodes to visit is computed once by unrolling the
    depth-first recursion over the heap of tree nodes.
    """

    node: Int32[Array, ' 2**(d-1)-1']
    """The heap index of the node to visit."""

    depth: Int32[Array, ' 2**(d-1)-1']
    """The depth of the node."""

    next_depth: Int32[Array, ' 2**(d-1)-1']
    """The depth of the next node to visit, either the left child or the right
    sibling of the node or of an ancestor."""

    @classmethod
    def initial(cls, p_nonterminal: Float32[Array, ' d-1']) -> 'SamplePriorX':
        """Build the sequence of nodes to visit.

        Parameters
        ----------
        p_nonterminal
            The prior probability of a node being non-terminal conditional on
            its ancestors and on having available decision rules, at each depth.

        Returns
        -------
        A `SamplePriorX` initialized with the sequence of nodes to visit.
        """
        schedule = cls._sequence(p_nonterminal.size)
        assert len(schedule) == 2**p_nonterminal.size - 1
        nodes = [entry[0] for entry in schedule]
        depths = [entry[1] for entry in schedule]
        # the "next depth" after the final node is one past the last level
        next_depths = [*depths[1:], p_nonterminal.size]
        return cls(
            node=jnp.array(nodes),
            depth=jnp.array(depths),
            next_depth=jnp.array(next_depths),
        )

    @classmethod
    def _sequence(
        cls, max_depth: int, depth: int = 0, node: int = 1
    ) -> tuple[tuple[int, int], ...]:
        """Recursively generate a sequence [(node, depth), ...] in DFS order."""
        if depth >= max_depth:
            return ()
        left = cls._sequence(max_depth, depth + 1, 2 * node)
        right = cls._sequence(max_depth, depth + 1, 2 * node + 1)
        return ((node, depth), *left, *right)
917
+
918
+
919
def sample_prior_onetree(
    key: Key[Array, ''],
    max_split: UInt[Array, ' p'],
    p_nonterminal: Float32[Array, ' d-1'],
    sigma_mu: Float32[Array, ''],
) -> SamplePriorTrees:
    """Sample a tree from the BART prior.

    The tree is generated by a depth-first traversal of the node heap,
    implemented as a `lax.scan` over a pre-computed visit schedule
    (`SamplePriorX`) with an explicit stack (`SamplePriorStack`) in place of
    recursion.

    Parameters
    ----------
    key
        A jax random key.
    max_split
        The maximum split value for each variable.
    p_nonterminal
        The prior probability of a node being non-terminal conditional on
        its ancestors and on having available decision rules, at each depth.
    sigma_mu
        The prior standard deviation of each leaf.

    Returns
    -------
    An object containing a generated tree.
    """
    carry = SamplePriorCarry.initial(key, sigma_mu, p_nonterminal, max_split)
    xs = SamplePriorX.initial(p_nonterminal)

    def loop(carry: SamplePriorCarry, x: SamplePriorX):
        # one key per random draw below, plus one to carry to the next step
        keys = split_key(carry.key, 4)

        # get variables at current stack level
        stack = carry.stack
        nonterminal = stack.nonterminal[x.depth]
        lower = stack.lower[x.depth, :]  # per-variable cutpoint bounds for this node
        upper = stack.upper[x.depth, :]

        # sample a random decision rule
        available: Bool[Array, ' p'] = lower < upper  # variables with free cutpoints
        allowed = jnp.any(available)  # whether the node can be non-terminal at all
        var = randint_masked(keys.pop(), available)
        # randint's upper bound is exclusive, so split lies in [lower+1, upper]
        split = 1 + random.randint(keys.pop(), (), lower[var], upper[var])

        # cast to shorter integer types
        var = var.astype(carry.trees.var_tree.dtype)
        split = split.astype(carry.trees.split_tree.dtype)

        # decide whether to try to grow the node if it is growable
        pnt = p_nonterminal[x.depth]
        try_nonterminal: Bool[Array, ''] = random.bernoulli(keys.pop(), pnt)
        # non-terminal only if the parent was, the coin flip succeeded, and
        # there is at least one available decision rule
        nonterminal &= try_nonterminal & allowed

        # update trees; terminal nodes write split 0, which marks a leaf
        trees = carry.trees
        trees = replace(
            trees,
            var_tree=trees.var_tree.at[x.node].set(var),
            split_tree=trees.split_tree.at[x.node].set(
                jnp.where(nonterminal, split, 0)
            ),
        )

        def write_push_stack() -> SamplePriorStack:
            """Update the stack to go to the left child."""
            return replace(
                stack,
                nonterminal=stack.nonterminal.at[x.next_depth].set(nonterminal),
                lower=stack.lower.at[x.next_depth, :].set(lower),
                # the left child only sees cutpoints strictly below `split`
                upper=stack.upper.at[x.next_depth, :].set(upper.at[var].set(split - 1)),
                # save this node's rule so the right sibling can restore it later
                var=stack.var.at[x.depth].set(var),
                split=stack.split.at[x.depth].set(split),
            )

        def pop_push_stack() -> SamplePriorStack:
            """Update the stack to go to the right sibling, possibly at lower depth."""
            # restore the parent's saved rule and bounds (written by
            # write_push_stack when the parent was visited)
            var = stack.var[x.next_depth - 1]
            split = stack.split[x.next_depth - 1]
            lower = stack.lower[x.next_depth - 1, :]
            upper = stack.upper[x.next_depth - 1, :]
            return replace(
                stack,
                # the right child only sees cutpoints at or above `split`
                lower=stack.lower.at[x.next_depth, :].set(lower.at[var].set(split)),
                upper=stack.upper.at[x.next_depth, :].set(upper),
            )

        # update stack: moving deeper means descending to a left child,
        # otherwise we move to a right sibling at the same or shallower depth
        stack = lax.cond(x.next_depth > x.depth, write_push_stack, pop_push_stack)

        # update carry
        carry = replace(carry, key=keys.pop(), stack=stack, trees=trees)
        return carry, None

    carry, _ = lax.scan(loop, carry, xs)
    return carry.trees
1012
+
1013
+
1014
@partial(vmap_nodoc, in_axes=(0, None, None, None))
def sample_prior_forest(
    keys: Key[Array, ' num_trees'],
    max_split: UInt[Array, ' p'],
    p_nonterminal: Float32[Array, ' d-1'],
    sigma_mu: Float32[Array, ''],
) -> SamplePriorTrees:
    """Sample a set of independent trees from the BART prior.

    Parameters
    ----------
    keys
        A sequence of jax random keys, one for each tree. This determines the
        number of trees sampled.
    max_split
        The maximum split value for each variable.
    p_nonterminal
        The prior probability of a node being non-terminal conditional on
        its ancestors and on having available decision rules, at each depth.
    sigma_mu
        The prior standard deviation of each leaf.

    Returns
    -------
    An object containing the generated trees.
    """
    # vmapped over axis 0 of `keys`: the scalar-key sampler runs once per key
    return sample_prior_onetree(keys, max_split, p_nonterminal, sigma_mu)
1041
+
1042
+
1043
@partial(jit, static_argnums=(1, 2))
def sample_prior(
    key: Key[Array, ''],
    trace_length: int,
    num_trees: int,
    max_split: UInt[Array, ' p'],
    p_nonterminal: Float32[Array, ' d-1'],
    sigma_mu: Float32[Array, ''],
) -> SamplePriorTrees:
    """Sample independent trees from the BART prior.

    Parameters
    ----------
    key
        A jax random key.
    trace_length
        The number of iterations.
    num_trees
        The number of trees for each iteration.
    max_split
        The number of cutpoints along each variable.
    p_nonterminal
        The prior probability of a node being non-terminal conditional on
        its ancestors and on having available decision rules, at each depth.
        This determines the maximum depth of the trees.
    sigma_mu
        The prior standard deviation of each leaf.

    Returns
    -------
    An object containing the generated trees, with batch shape (trace_length, num_trees).
    """
    # one key per tree, over all iterations at once
    num_keys = trace_length * num_trees
    flat_trees = sample_prior_forest(
        random.split(key, num_keys), max_split, p_nonterminal, sigma_mu
    )

    def unflatten(leaf):
        # split the flat per-tree batch axis into (trace_length, num_trees)
        return leaf.reshape(trace_length, num_trees, -1)

    return tree_map(unflatten, flat_trees)
1078
+
1079
+
1080
class debug_mc_gbart(mc_gbart):
    """A subclass of `mc_gbart` that adds debugging functionality.

    Parameters
    ----------
    *args
        Passed to `mc_gbart`.
    check_trees
        If `True`, check all trees with `check_trace` after running the MCMC,
        and assert that they are all valid. Set to `False` to allow jax tracing.
    **kw
        Passed to `mc_gbart`.
    """

    def __init__(self, *args, check_trees: bool = True, **kw):
        # run the full MCMC first, then optionally validate every tree draw
        super().__init__(*args, **kw)
        if check_trees:
            bad = self.check_trees()
            bad_count = jnp.count_nonzero(bad)
            assert bad_count == 0

    def print_tree(
        self, i_chain: int, i_sample: int, i_tree: int, print_all: bool = False
    ):
        """Print a single tree in human-readable format.

        Parameters
        ----------
        i_chain
            The index of the MCMC chain.
        i_sample
            The index of the (post-burnin) sample in the chain.
        i_tree
            The index of the tree in the sample.
        print_all
            If `True`, also print the content of unused node slots.
        """
        # select one (chain, sample, tree) slice of each trace array
        tree = TreesTrace.from_dataclass(self._main_trace)
        tree = tree_map(lambda x: x[i_chain, i_sample, i_tree, :], tree)
        s = format_tree(tree, print_all=print_all)
        print(s)  # noqa: T201, this method is intended for debug

    def sigma_harmonic_mean(self, prior: bool = False) -> Float32[Array, ' mc_cores']:
        """Return the harmonic mean of the error variance.

        Parameters
        ----------
        prior
            If `True`, use the prior distribution, otherwise use the full
            conditional at the last MCMC iteration.

        Returns
        -------
        The harmonic mean 1/E[1/sigma^2] in the selected distribution.
        """
        bart = self._mcmc_state
        assert bart.error_cov_df is not None
        # NOTE(review): z appears to be the latent variable of the binary
        # model (see compare_resid); this method presumably only supports the
        # continuous-response model — confirm
        assert bart.z is None
        # inverse gamma prior: alpha = df / 2, beta = scale / 2
        if prior:
            alpha = bart.error_cov_df / 2
            beta = bart.error_cov_scale / 2
        else:
            # conjugate update of the inverse gamma with the current residuals
            # NOTE(review): resid.size counts all chains; the per-chain count
            # would be resid.shape[-1] — confirm intent when mc_cores > 1
            alpha = bart.error_cov_df / 2 + bart.resid.size / 2
            norm2 = jnp.einsum('ij,ij->i', bart.resid, bart.resid)
            beta = bart.error_cov_scale / 2 + norm2 / 2
        error_cov_inv = alpha / beta
        # harmonic mean of sigma^2 is beta/alpha; return its square root
        return jnp.sqrt(lax.reciprocal(error_cov_inv))

    def compare_resid(
        self,
    ) -> tuple[Float32[Array, 'mc_cores n'], Float32[Array, 'mc_cores n']]:
        """Re-compute residuals to compare them with the updated ones.

        Returns
        -------
        resid1 : Float32[Array, 'mc_cores n']
            The final state of the residuals updated during the MCMC.
        resid2 : Float32[Array, 'mc_cores n']
            The residuals computed from the final state of the trees.
        """
        bart = self._mcmc_state
        resid1 = bart.resid

        # re-evaluate the whole forest on the training predictors
        forests = TreesTrace.from_dataclass(bart.forest)
        trees = evaluate_forest(bart.X, forests, sum_batch_axis=-1)

        # binary model: residuals are relative to the latent z, not to y
        if bart.z is not None:
            ref = bart.z
        else:
            ref = bart.y
        resid2 = ref - (trees + bart.offset)

        return resid1, resid2

    def avg_acc(
        self,
    ) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]:
        """Compute the average acceptance rates of tree moves.

        Returns
        -------
        acc_grow : Float32[Array, 'mc_cores']
            The average acceptance rate of grow moves.
        acc_prune : Float32[Array, 'mc_cores']
            The average acceptance rate of prune moves.
        """
        trace = self._main_trace

        def acc(prefix):
            # ratio of accepted to proposed moves, summed over iterations
            acc = getattr(trace, f'{prefix}_acc_count')
            prop = getattr(trace, f'{prefix}_prop_count')
            return acc.sum(axis=1) / prop.sum(axis=1)

        return acc('grow'), acc('prune')

    def avg_prop(
        self,
    ) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]:
        """Compute the average proposal rate of grow and prune moves.

        Returns
        -------
        prop_grow : Float32[Array, 'mc_cores']
            The fraction of times grow was proposed instead of prune.
        prop_prune : Float32[Array, 'mc_cores']
            The fraction of times prune was proposed instead of grow.

        Notes
        -----
        This function does not take into account cases where no move was
        proposed.
        """
        trace = self._main_trace

        def prop(prefix):
            return getattr(trace, f'{prefix}_prop_count').sum(axis=1)

        pgrow = prop('grow')
        pprune = prop('prune')
        total = pgrow + pprune
        return pgrow / total, pprune / total

    def avg_move(
        self,
    ) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]:
        """Compute the move rate.

        Returns
        -------
        rate_grow : Float32[Array, 'mc_cores']
            The fraction of times a grow move was proposed and accepted.
        rate_prune : Float32[Array, 'mc_cores']
            The fraction of times a prune move was proposed and accepted.
        """
        # P(moved) = P(proposed) * P(accepted | proposed)
        agrow, aprune = self.avg_acc()
        pgrow, pprune = self.avg_prop()
        return agrow * pgrow, aprune * pprune

    def depth_distr(self) -> Int32[Array, 'mc_cores ndpost/mc_cores d']:
        """Histogram of tree depths for each state of the trees.

        Returns
        -------
        A matrix where each row contains a histogram of tree depths.
        """
        out: Int32[Array, '*chains samples d']
        out = forest_depth_distr(self._main_trace.split_tree)
        # normalize to 3 axes: prepend a length-1 chain axis if absent
        if out.ndim < 3:
            out = out[None, :, :]
        return out

    def _points_per_node_distr(
        self, node_type: str
    ) -> Int32[Array, 'mc_cores ndpost/mc_cores n+1']:
        # shared implementation for the points-per-node histograms below
        out: Int32[Array, '*chains samples n+1']
        out = points_per_node_distr(
            self._mcmc_state.X,
            self._main_trace.var_tree,
            self._main_trace.split_tree,
            node_type,
            sum_batch_axis=-1,
        )
        # normalize to 3 axes: prepend a length-1 chain axis if absent
        if out.ndim < 3:
            out = out[None, :, :]
        return out

    def points_per_decision_node_distr(
        self,
    ) -> Int32[Array, 'mc_cores ndpost/mc_cores n+1']:
        """Histogram of number of points belonging to parent-of-leaf nodes.

        Returns
        -------
        For each chain, a matrix where each row contains a histogram of number of points.
        """
        return self._points_per_node_distr('leaf-parent')

    def points_per_leaf_distr(self) -> Int32[Array, 'mc_cores ndpost/mc_cores n+1']:
        """Histogram of number of points belonging to leaves.

        Returns
        -------
        A matrix where each row contains a histogram of number of points.
        """
        return self._points_per_node_distr('leaf')

    def check_trees(self) -> UInt[Array, 'mc_cores ndpost/mc_cores ntree']:
        """Apply `check_trace` to all the tree draws."""
        out: UInt[Array, '*chains samples num_trees']
        out = check_trace(self._main_trace, self._mcmc_state.forest.max_split)
        # normalize to 3 axes: prepend a length-1 chain axis if absent
        if out.ndim < 3:
            out = out[None, :, :]
        return out

    def tree_goes_bad(self) -> Bool[Array, 'mc_cores ndpost/mc_cores ntree']:
        """Find iterations where a tree becomes invalid.

        Returns
        -------
        A matrix where element (i, j) is `True` if tree j is invalid at
        iteration i but not at iteration i-1.
        """
        bad = self.check_trees().astype(bool)
        # shift the sample axis forward by one, padding the first sample
        # with "not bad", to compare each draw with its predecessor
        bad_before = jnp.pad(bad[:, :-1, :], [(0, 0), (1, 0), (0, 0)])
        return bad & ~bad_before
1305
+
1306
+
1307
+ class debug_gbart(debug_mc_gbart, gbart):
1308
+ """A subclass of `gbart` that adds debugging functionality.
1309
+
1310
+ Parameters
1311
+ ----------
1312
+ *args
1313
+ Passed to `gbart`.
1314
+ check_trees
1315
+ If `True`, check all trees with `check_trace` after running the MCMC,
1316
+ and assert that they are all valid. Set to `False` to allow jax tracing.
1317
+ **kw
1318
+ Passed to `gbart`.
1319
+ """