bartz 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/debug.py CHANGED
@@ -1,13 +1,72 @@
1
- import functools
1
+ # bartz/src/bartz/debug.py
2
+ #
3
+ # Copyright (c) 2024-2025, Giacomo Petrillo
4
+ #
5
+ # This file is part of bartz.
6
+ #
7
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ # of this software and associated documentation files (the "Software"), to deal
9
+ # in the Software without restriction, including without limitation the rights
10
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ # copies of the Software, and to permit persons to whom the Software is
12
+ # furnished to do so, subject to the following conditions:
13
+ #
14
+ # The above copyright notice and this permission notice shall be included in all
15
+ # copies or substantial portions of the Software.
16
+ #
17
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ # SOFTWARE.
2
24
 
3
- import jax
4
- from jax import lax
25
+ """Debugging utilities. The entry point is the class `debug_gbart`."""
26
+
27
+ from collections.abc import Callable
28
+ from dataclasses import replace
29
+ from functools import partial
30
+ from math import ceil, log2
31
+ from re import fullmatch
32
+
33
+ import numpy
34
+ from equinox import Module, field
35
+ from jax import jit, lax, random, vmap
5
36
  from jax import numpy as jnp
37
+ from jax.tree_util import tree_map
38
+ from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, UInt
39
+
40
+ from bartz.BART import FloatLike, gbart
41
+ from bartz.grove import (
42
+ TreeHeaps,
43
+ evaluate_forest,
44
+ is_actual_leaf,
45
+ is_leaves_parent,
46
+ traverse_tree,
47
+ tree_depth,
48
+ tree_depths,
49
+ )
50
+ from bartz.jaxext import minimal_unsigned_dtype, vmap_nodoc
51
+ from bartz.jaxext import split as split_key
52
+ from bartz.mcmcloop import TreesTrace
53
+ from bartz.mcmcstep import randint_masked
54
+
6
55
 
7
- from . import grove, jaxext
56
+ def format_tree(tree: TreeHeaps, *, print_all: bool = False) -> str:
57
+ """Convert a tree to a human-readable string.
8
58
 
59
+ Parameters
60
+ ----------
61
+ tree
62
+ A single tree to format.
63
+ print_all
64
+ If `True`, also print the contents of unused node slots in the arrays.
9
65
 
10
- def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
66
+ Returns
67
+ -------
68
+ A string representation of the tree.
69
+ """
11
70
  tee = '├──'
12
71
  corner = '└──'
13
72
  join = '│ '
@@ -15,12 +74,20 @@ def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
15
74
  down = '┐'
16
75
  bottom = '╢' # '┨' #
17
76
 
18
- def traverse_tree(index, depth, indent, first_indent, next_indent, unused):
19
- if index >= len(leaf_tree):
77
+ def traverse_tree(
78
+ lines: list[str],
79
+ index: int,
80
+ depth: int,
81
+ indent: str,
82
+ first_indent: str,
83
+ next_indent: str,
84
+ unused: bool,
85
+ ):
86
+ if index >= len(tree.leaf_tree):
20
87
  return
21
88
 
22
- var = var_tree.at[index].get(mode='fill', fill_value=0)
23
- split = split_tree.at[index].get(mode='fill', fill_value=0)
89
+ var: int = tree.var_tree.at[index].get(mode='fill', fill_value=0).item()
90
+ split: int = tree.split_tree.at[index].get(mode='fill', fill_value=0).item()
24
91
 
25
92
  is_leaf = split == 0
26
93
  left_child = 2 * index
@@ -33,26 +100,26 @@ def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
33
100
  category = 'leaf'
34
101
  else:
35
102
  category = 'decision'
36
- node_str = f'{category}({var}, {split}, {leaf_tree[index]})'
103
+ node_str = f'{category}({var}, {split}, {tree.leaf_tree[index]})'
37
104
  else:
38
105
  assert not unused
39
106
  if is_leaf:
40
- node_str = f'{leaf_tree[index]:#.2g}'
107
+ node_str = f'{tree.leaf_tree[index]:#.2g}'
41
108
  else:
42
- node_str = f'({var}: {split})'
109
+ node_str = f'x{var} < {split}'
43
110
 
44
- if not is_leaf or (print_all and left_child < len(leaf_tree)):
111
+ if not is_leaf or (print_all and left_child < len(tree.leaf_tree)):
45
112
  link = down
46
- elif not print_all and left_child >= len(leaf_tree):
113
+ elif not print_all and left_child >= len(tree.leaf_tree):
47
114
  link = bottom
48
115
  else:
49
116
  link = ' '
50
117
 
51
- max_number = len(leaf_tree) - 1
118
+ max_number = len(tree.leaf_tree) - 1
52
119
  ndigits = len(str(max_number))
53
120
  number = str(index).rjust(ndigits)
54
121
 
55
- print(f' {number} {indent}{first_indent}{link}{node_str}')
122
+ lines.append(f' {number} {indent}{first_indent}{link}{node_str}')
56
123
 
57
124
  indent += next_indent
58
125
  unused = unused or is_leaf
@@ -60,125 +127,1238 @@ def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
60
127
  if unused and not print_all:
61
128
  return
62
129
 
63
- traverse_tree(left_child, depth + 1, indent, tee, join, unused)
64
- traverse_tree(right_child, depth + 1, indent, corner, space, unused)
130
+ traverse_tree(lines, left_child, depth + 1, indent, tee, join, unused)
131
+ traverse_tree(lines, right_child, depth + 1, indent, corner, space, unused)
65
132
 
66
- traverse_tree(1, 0, '', '', '', False)
133
+ lines = []
134
+ traverse_tree(lines, 1, 0, '', '', '', False)
135
+ return '\n'.join(lines)
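As an editorial usage sketch (not part of the diff): `format_tree` only reads the `leaf_tree`, `var_tree`, and `split_tree` attributes, so any container exposing those heaps works; the `_ToyTree` dataclass below is a hypothetical stand-in for a `TreeHeaps`.

from dataclasses import dataclass

from jax import numpy as jnp

from bartz.debug import format_tree


@dataclass
class _ToyTree:
    # hypothetical stand-in: format_tree only reads these three heap arrays
    leaf_tree: jnp.ndarray
    var_tree: jnp.ndarray
    split_tree: jnp.ndarray


# depth-2 tree: node 1 splits x0 at cutpoint 3, nodes 2 and 3 are leaves
toy = _ToyTree(
    leaf_tree=jnp.array([0.0, 0.0, -1.5, 2.0, 0.0, 0.0, 0.0, 0.0]),
    var_tree=jnp.zeros(4, jnp.uint8),
    split_tree=jnp.array([0, 3, 0, 0], jnp.uint8),
)
print(format_tree(toy))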
67
136
 
68
137
 
69
- def tree_actual_depth(split_tree):
70
- is_leaf = grove.is_actual_leaf(split_tree, add_bottom_level=True)
71
- depth = grove.tree_depths(is_leaf.size)
138
+ def tree_actual_depth(split_tree: UInt[Array, ' 2**(d-1)']) -> Int32[Array, '']:
139
+ """Measure the depth of the tree.
140
+
141
+ Parameters
142
+ ----------
143
+ split_tree
144
+ The cutpoints of the decision rules.
145
+
146
+ Returns
147
+ -------
148
+ The depth of the deepest leaf in the tree. The root is at depth 0.
149
+ """
150
+ # this could be done just with split_tree != 0
151
+ is_leaf = is_actual_leaf(split_tree, add_bottom_level=True)
152
+ depth = tree_depths(is_leaf.size)
72
153
  depth = jnp.where(is_leaf, depth, 0)
73
154
  return jnp.max(depth)
74
155
 
75
156
 
76
- def forest_depth_distr(split_trees):
77
- depth = grove.tree_depth(split_trees) + 1
78
- depths = jax.vmap(tree_actual_depth)(split_trees)
157
+ def forest_depth_distr(
158
+ split_tree: UInt[Array, 'num_trees 2**(d-1)'],
159
+ ) -> Int32[Array, ' d']:
160
+ """Histogram the depths of a set of trees.
161
+
162
+ Parameters
163
+ ----------
164
+ split_tree
165
+ The cutpoints of the decision rules of the trees.
166
+
167
+ Returns
168
+ -------
169
+ An integer vector where the i-th element counts how many trees have depth i.
170
+ """
171
+ depth = tree_depth(split_tree) + 1
172
+ depths = vmap(tree_actual_depth)(split_tree)
79
173
  return jnp.bincount(depths, length=depth)
80
174
 
81
175
 
82
- def trace_depth_distr(split_trees_trace):
83
- return jax.vmap(forest_depth_distr)(split_trees_trace)
176
+ @jit
177
+ def trace_depth_distr(
178
+ split_tree: UInt[Array, 'trace_length num_trees 2**(d-1)'],
179
+ ) -> Int32[Array, 'trace_length d']:
180
+ """Histogram the depths of a sequence of sets of trees.
84
181
 
182
+ Parameters
183
+ ----------
184
+ split_tree
185
+ The cutpoints of the decision rules of the trees.
85
186
 
86
- def points_per_leaf_distr(var_tree, split_tree, X):
87
- traverse_tree = jax.vmap(grove.traverse_tree, in_axes=(1, None, None))
88
- indices = traverse_tree(X, var_tree, split_tree)
89
- count_tree = jnp.zeros(
90
- 2 * split_tree.size, dtype=jaxext.minimal_unsigned_dtype(indices.size)
91
- )
92
- count_tree = count_tree.at[indices].add(1)
93
- is_leaf = grove.is_actual_leaf(split_tree, add_bottom_level=True).view(jnp.uint8)
94
- return jnp.bincount(count_tree, is_leaf, length=X.shape[1] + 1)
187
+ Returns
188
+ -------
189
+ A matrix where element (t,i) counts how many trees have depth i in set t.
190
+ """
191
+ return vmap(forest_depth_distr)(split_tree)
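A brief usage sketch: assuming `trace` is an MCMC trace whose `split_tree` field has shape (trace_length, num_trees, 2**(d-1)), as in the `_main_trace` attribute used by `debug_gbart` below, the per-iteration depth histogram and its mean follow directly.

from jax import numpy as jnp

from bartz.debug import trace_depth_distr

hist = trace_depth_distr(trace.split_tree)   # (trace_length, d)
mean_depth = (hist * jnp.arange(hist.shape[1])).sum(1) / hist.sum(1)  # average tree depth per iteration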
192
+
193
+
194
+ def points_per_decision_node_distr(
195
+ var_tree: UInt[Array, ' 2**(d-1)'],
196
+ split_tree: UInt[Array, ' 2**(d-1)'],
197
+ X: UInt[Array, 'p n'],
198
+ ) -> Int32[Array, ' n+1']:
199
+ """Histogram points-per-node counts.
200
+
201
+ Count how many parent-of-leaf nodes in a tree select each possible amount
202
+ of points.
203
+
204
+ Parameters
205
+ ----------
206
+ var_tree
207
+ The variables of the decision rules.
208
+ split_tree
209
+ The cutpoints of the decision rules.
210
+ X
211
+ The set of points to count.
212
+
213
+ Returns
214
+ -------
215
+ A vector where the i-th element counts how many next-to-leaf nodes have i points.
216
+ """
217
+ traverse_tree_X = vmap(traverse_tree, in_axes=(1, None, None))
218
+ indices = traverse_tree_X(X, var_tree, split_tree)
219
+ indices >>= 1
220
+ count_tree = jnp.zeros(split_tree.size, int).at[indices].add(1).at[0].set(0)
221
+ is_parent = is_leaves_parent(split_tree)
222
+ return jnp.zeros(X.shape[1] + 1, int).at[count_tree].add(is_parent)
223
+
224
+
225
+ def forest_points_per_decision_node_distr(
226
+ trees: TreeHeaps, X: UInt[Array, 'p n']
227
+ ) -> Int32[Array, ' n+1']:
228
+ """Histogram points-per-node counts for a set of trees.
229
+
230
+ Count how many parent-of-leaf nodes in a set of trees select each possible
231
+ amount of points.
232
+
233
+ Parameters
234
+ ----------
235
+ trees
236
+ The set of trees. The variables must have broadcast shape (num_trees,).
237
+ X
238
+ The set of points to count.
239
+
240
+ Returns
241
+ -------
242
+ A vector where the i-th element counts how many next-to-leaf nodes have i points.
243
+ """
244
+ distr = jnp.zeros(X.shape[1] + 1, int)
245
+
246
+ def loop(distr, heaps: tuple[Array, Array]):
247
+ return distr + points_per_decision_node_distr(*heaps, X), None
248
+
249
+ distr, _ = lax.scan(loop, distr, (trees.var_tree, trees.split_tree))
250
+ return distr
251
+
252
+
253
+ @jit
254
+ def trace_points_per_decision_node_distr(
255
+ trace: TreeHeaps, X: UInt[Array, 'p n']
256
+ ) -> Int32[Array, 'trace_length n+1']:
257
+ """Separately histogram points-per-node counts over a sequence of sets of trees.
95
258
 
259
+ For each set of trees, count how many parent-of-leaf nodes select each
260
+ possible amount of points.
96
261
 
97
- def forest_points_per_leaf_distr(bart, X):
262
+ Parameters
263
+ ----------
264
+ trace
265
+ The sequence of sets of trees. The variables must have broadcast shape
266
+ (trace_length, num_trees).
267
+ X
268
+ The set of points to count.
269
+
270
+ Returns
271
+ -------
272
+ A matrix where element (t,i) counts how many next-to-leaf nodes have i points in set t.
273
+ """
274
+
275
+ def loop(_, trace):
276
+ return None, forest_points_per_decision_node_distr(trace, X)
277
+
278
+ _, distr = lax.scan(loop, None, trace)
279
+ return distr
280
+
281
+
282
+ def points_per_leaf_distr(
283
+ var_tree: UInt[Array, ' 2**(d-1)'],
284
+ split_tree: UInt[Array, ' 2**(d-1)'],
285
+ X: UInt[Array, 'p n'],
286
+ ) -> Int32[Array, ' n+1']:
287
+ """Histogram points-per-leaf counts in a tree.
288
+
289
+ Count how many leaves in a tree select each possible amount of points.
290
+
291
+ Parameters
292
+ ----------
293
+ var_tree
294
+ The variables of the decision rules.
295
+ split_tree
296
+ The cutpoints of the decision rules.
297
+ X
298
+ The set of points to count.
299
+
300
+ Returns
301
+ -------
302
+ A vector where the i-th element counts how many leaves have i points.
303
+ """
304
+ traverse_tree_X = vmap(traverse_tree, in_axes=(1, None, None))
305
+ indices = traverse_tree_X(X, var_tree, split_tree)
306
+ count_tree = jnp.zeros(2 * split_tree.size, int).at[indices].add(1)
307
+ is_leaf = is_actual_leaf(split_tree, add_bottom_level=True)
308
+ return jnp.zeros(X.shape[1] + 1, int).at[count_tree].add(is_leaf)
309
+
310
+
311
+ def forest_points_per_leaf_distr(
312
+ trees: TreeHeaps, X: UInt[Array, 'p n']
313
+ ) -> Int32[Array, ' n+1']:
314
+ """Histogram points-per-leaf counts over a set of trees.
315
+
316
+ Count how many leaves in a set of trees select each possible amount of points.
317
+
318
+ Parameters
319
+ ----------
320
+ trees
321
+ The set of trees. The variables must have broadcast shape (num_trees,).
322
+ X
323
+ The set of points to count.
324
+
325
+ Returns
326
+ -------
327
+ A vector where the i-th element counts how many leaves have i points.
328
+ """
98
329
  distr = jnp.zeros(X.shape[1] + 1, int)
99
- trees = bart['var_trees'], bart['split_trees']
100
330
 
101
- def loop(distr, tree):
102
- return distr + points_per_leaf_distr(*tree, X), None
331
+ def loop(distr, heaps: tuple[Array, Array]):
332
+ return distr + points_per_leaf_distr(*heaps, X), None
103
333
 
104
- distr, _ = lax.scan(loop, distr, trees)
334
+ distr, _ = lax.scan(loop, distr, (trees.var_tree, trees.split_tree))
105
335
  return distr
106
336
 
107
337
 
108
- def trace_points_per_leaf_distr(bart, X):
109
- def loop(_, bart):
110
- return None, forest_points_per_leaf_distr(bart, X)
338
+ @jit
339
+ def trace_points_per_leaf_distr(
340
+ trace: TreeHeaps, X: UInt[Array, 'p n']
341
+ ) -> Int32[Array, 'trace_length n+1']:
342
+ """Separately histogram points-per-leaf counts over a sequence of sets of trees.
343
+
344
+ For each set of trees, count how many leaves select each possible amount of
345
+ points.
346
+
347
+ Parameters
348
+ ----------
349
+ trace
350
+ The sequence of sets of trees. The variables must have broadcast shape
351
+ (trace_length, num_trees).
352
+ X
353
+ The set of points to count.
111
354
 
112
- _, distr = lax.scan(loop, None, bart)
355
+ Returns
356
+ -------
357
+ A matrix where element (t,i) counts how many leaves have i points in set t.
358
+ """
359
+
360
+ def loop(_, trace):
361
+ return None, forest_points_per_leaf_distr(trace, X)
362
+
363
+ _, distr = lax.scan(loop, None, trace)
113
364
  return distr
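Usage sketch for the two histogram families above, with `trace` a trace of tree heaps with batch shape (trace_length, num_trees) and `X` the (p, n) matrix of binned predictors; both are assumed to come from a fitted model, as in the `debug_gbart` methods below.

ppl = trace_points_per_leaf_distr(trace, X)            # (trace_length, n + 1)
ppd = trace_points_per_decision_node_distr(trace, X)   # (trace_length, n + 1)
empty_leaves_per_iteration = ppl[:, 0]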
114
365
 
115
366
 
116
- def check_types(leaf_tree, var_tree, split_tree, max_split):
117
- expected_var_dtype = jaxext.minimal_unsigned_dtype(max_split.size - 1)
367
+ check_functions = []
368
+
369
+
370
+ CheckFunc = Callable[[TreeHeaps, UInt[Array, ' p']], bool | Bool[Array, '']]
371
+
372
+
373
+ def check(func: CheckFunc) -> CheckFunc:
374
+ """Add a function to a list of functions used to check trees.
375
+
376
+ Use to decorate functions that check whether a tree is valid in some way.
377
+ These functions are invoked automatically by `check_tree`, `check_trace` and
378
+ `debug_gbart`.
379
+
380
+ Parameters
381
+ ----------
382
+ func
383
+ The function to add to the list. It must accept a `TreeHeaps` and a
384
+ `max_split` argument, and return a boolean scalar that indicates if the
385
+ tree is ok.
386
+
387
+ Returns
388
+ -------
389
+ The function unchanged.
390
+ """
391
+ check_functions.append(func)
392
+ return func
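Because `check` only appends to the module-level `check_functions` list, user code can register extra invariants that `check_tree`, `check_trace`, and the `debug_gbart` constructor then apply automatically; a minimal sketch with a made-up invariant:

from jax import numpy as jnp
from jaxtyping import Array, Bool, UInt

from bartz.debug import check
from bartz.grove import TreeHeaps


@check
def check_leaf_magnitude(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']:  # noqa: ARG001
    """Illustrative extra check: leaf values stay within a loose bound."""
    return jnp.all(jnp.abs(tree.leaf_tree) < 1e6)

If the check fails, its name appears in the output of `describe_error`.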
393
+
394
+
395
+ @check
396
+ def check_types(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> bool:
397
+ """Check that integer types are as small as possible and coherent."""
398
+ expected_var_dtype = minimal_unsigned_dtype(max_split.size - 1)
118
399
  expected_split_dtype = max_split.dtype
119
400
  return (
120
- var_tree.dtype == expected_var_dtype
121
- and split_tree.dtype == expected_split_dtype
401
+ tree.var_tree.dtype == expected_var_dtype
402
+ and tree.split_tree.dtype == expected_split_dtype
122
403
  )
123
404
 
124
405
 
125
- def check_sizes(leaf_tree, var_tree, split_tree, max_split):
126
- return leaf_tree.size == 2 * var_tree.size == 2 * split_tree.size
406
+ @check
407
+ def check_sizes(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> bool: # noqa: ARG001
408
+ """Check that array sizes are coherent."""
409
+ return tree.leaf_tree.size == 2 * tree.var_tree.size == 2 * tree.split_tree.size
127
410
 
128
411
 
129
- def check_unused_node(leaf_tree, var_tree, split_tree, max_split):
130
- return (var_tree[0] == 0) & (split_tree[0] == 0)
412
+ @check
413
+ def check_unused_node(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']: # noqa: ARG001
414
+ """Check that the unused node slot at index 0 is not dirty."""
415
+ return (tree.var_tree[0] == 0) & (tree.split_tree[0] == 0)
131
416
 
132
417
 
133
- def check_leaf_values(leaf_tree, var_tree, split_tree, max_split):
134
- return jnp.all(jnp.isfinite(leaf_tree))
418
+ @check
419
+ def check_leaf_values(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']: # noqa: ARG001
420
+ """Check that all leaf values are not inf of nan."""
421
+ return jnp.all(jnp.isfinite(tree.leaf_tree))
135
422
 
136
423
 
137
- def check_stray_nodes(leaf_tree, var_tree, split_tree, max_split):
424
+ @check
425
+ def check_stray_nodes(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']: # noqa: ARG001
426
+ """Check if there is any marked-non-leaf node with a marked-leaf parent."""
138
427
  index = jnp.arange(
139
- 2 * split_tree.size,
140
- dtype=jaxext.minimal_unsigned_dtype(2 * split_tree.size - 1),
428
+ 2 * tree.split_tree.size,
429
+ dtype=minimal_unsigned_dtype(2 * tree.split_tree.size - 1),
141
430
  )
142
431
  parent_index = index >> 1
143
- is_not_leaf = split_tree.at[index].get(mode='fill', fill_value=0) != 0
144
- parent_is_leaf = split_tree[parent_index] == 0
432
+ is_not_leaf = tree.split_tree.at[index].get(mode='fill', fill_value=0) != 0
433
+ parent_is_leaf = tree.split_tree[parent_index] == 0
145
434
  stray = is_not_leaf & parent_is_leaf
146
435
  stray = stray.at[1].set(False)
147
436
  return ~jnp.any(stray)
148
437
 
149
438
 
150
- check_functions = [
151
- check_types,
152
- check_sizes,
153
- check_unused_node,
154
- check_leaf_values,
155
- check_stray_nodes,
156
- ]
439
+ @check
440
+ def check_rule_consistency(
441
+ tree: TreeHeaps, max_split: UInt[Array, ' p']
442
+ ) -> bool | Bool[Array, '']:
443
+ """Check that decision rules define proper subsets of ancestor rules."""
444
+ if tree.var_tree.size < 4:
445
+ return True
446
+
447
+ # initial boundaries of decision rules. use extreme integers instead of 0,
448
+ # max_split to avoid checking if there is something out of bounds.
449
+ small = jnp.iinfo(jnp.int32).min
450
+ large = jnp.iinfo(jnp.int32).max
451
+ lower = jnp.full(max_split.size, small, jnp.int32)
452
+ upper = jnp.full(max_split.size, large, jnp.int32)
453
+ # specify the type explicitly, otherwise they are weakly typed and get
454
+ # implicitly converted to split.dtype (typically uint8) in the expressions
455
+
456
+ def _check_recursive(node, lower, upper):
457
+ # read decision rule
458
+ var = tree.var_tree[node]
459
+ split = tree.split_tree[node]
460
+
461
+ # get rule boundaries from ancestors. use fill value in case var is
462
+ # out of bounds; we don't want to check out of bounds in this function
463
+ lower_var = lower.at[var].get(mode='fill', fill_value=small)
464
+ upper_var = upper.at[var].get(mode='fill', fill_value=large)
465
+
466
+ # check rule is in bounds
467
+ bad = jnp.where(split, (split <= lower_var) | (split >= upper_var), False)
468
+
469
+ # recurse
470
+ if node < tree.var_tree.size // 2:
471
+ bad |= _check_recursive(
472
+ 2 * node,
473
+ lower,
474
+ upper.at[jnp.where(split, var, max_split.size)].set(split),
475
+ )
476
+ bad |= _check_recursive(
477
+ 2 * node + 1,
478
+ lower.at[jnp.where(split, var, max_split.size)].set(split),
479
+ upper,
480
+ )
481
+ return bad
157
482
 
483
+ return ~_check_recursive(1, lower, upper)
158
484
 
159
- def check_tree(leaf_tree, var_tree, split_tree, max_split):
160
- error_type = jaxext.minimal_unsigned_dtype(2 ** len(check_functions) - 1)
485
+
486
+ @check
487
+ def check_num_nodes(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']: # noqa: ARG001
488
+ """Check that #leaves = 1 + #(internal nodes)."""
489
+ is_leaf = is_actual_leaf(tree.split_tree, add_bottom_level=True)
490
+ num_leaves = jnp.count_nonzero(is_leaf)
491
+ num_internal = jnp.count_nonzero(tree.split_tree)
492
+ return num_leaves == num_internal + 1
493
+
494
+
495
+ @check
496
+ def check_var_in_bounds(
497
+ tree: TreeHeaps, max_split: UInt[Array, ' p']
498
+ ) -> Bool[Array, '']:
499
+ """Check that variables are in [0, max_split.size)."""
500
+ decision_node = tree.split_tree.astype(bool)
501
+ in_bounds = (tree.var_tree >= 0) & (tree.var_tree < max_split.size)
502
+ return jnp.all(in_bounds | ~decision_node)
503
+
504
+
505
+ @check
506
+ def check_split_in_bounds(
507
+ tree: TreeHeaps, max_split: UInt[Array, ' p']
508
+ ) -> Bool[Array, '']:
509
+ """Check that splits are in [0, max_split[var]]."""
510
+ max_split_var = (
511
+ max_split.astype(jnp.int32)
512
+ .at[tree.var_tree]
513
+ .get(mode='fill', fill_value=jnp.iinfo(jnp.int32).max)
514
+ )
515
+ return jnp.all((tree.split_tree >= 0) & (tree.split_tree <= max_split_var))
516
+
517
+
518
+ def check_tree(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> UInt[Array, '']:
519
+ """Check the validity of a tree.
520
+
521
+ Use `describe_error` to parse the error code returned by this function.
522
+
523
+ Parameters
524
+ ----------
525
+ tree
526
+ The tree to check.
527
+ max_split
528
+ The maximum split value for each variable.
529
+
530
+ Returns
531
+ -------
532
+ An integer where each bit indicates whether a check failed.
533
+ """
534
+ error_type = minimal_unsigned_dtype(2 ** len(check_functions) - 1)
161
535
  error = error_type(0)
162
536
  for i, func in enumerate(check_functions):
163
- ok = func(leaf_tree, var_tree, split_tree, max_split)
537
+ ok = func(tree, max_split)
164
538
  ok = jnp.bool_(ok)
165
539
  bit = (~ok) << i
166
540
  error |= bit
167
541
  return error
168
542
 
169
543
 
170
- def describe_error(error):
544
+ def describe_error(error: int | Integer[Array, '']) -> list[str]:
545
+ """Describe the error code returned by `check_tree`.
546
+
547
+ Parameters
548
+ ----------
549
+ error
550
+ The error code returned by `check_tree`.
551
+
552
+ Returns
553
+ -------
554
+ A list of the function names that implement the failed checks.
555
+ """
171
556
  return [func.__name__ for i, func in enumerate(check_functions) if error & (1 << i)]
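Decoding an error code for a single tree is then direct; a sketch assuming `tree` and `max_split` as in the signatures above:

error = check_tree(tree, max_split)   # bit mask; 0 means every registered check passed
if error:
    print(describe_error(error))      # names of the failed check functions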
172
557
 
173
558
 
174
- check_forest = jax.vmap(check_tree, in_axes=(0, 0, 0, None))
559
+ @jit
560
+ @partial(vmap_nodoc, in_axes=(0, None))
561
+ def check_trace(
562
+ trace: TreeHeaps, max_split: UInt[Array, ' p']
563
+ ) -> UInt[Array, 'trace_length num_trees']:
564
+ """Check the validity of a sequence of sets of trees.
565
+
566
+ Use `describe_error` to parse the error codes returned by this function.
567
+
568
+ Parameters
569
+ ----------
570
+ trace
571
+ The sequence of sets of trees to check. The tree arrays must have
572
+ broadcast shape (trace_length, num_trees). This object can have
573
+ additional attributes beyond the tree arrays, they are ignored.
574
+ max_split
575
+ The maximum split value for each variable.
576
+
577
+ Returns
578
+ -------
579
+ A matrix of error codes for each tree.
580
+ """
581
+ trees = TreesTrace.from_dataclass(trace)
582
+ check_forest = vmap(check_tree, in_axes=(0, None))
583
+ return check_forest(trees, max_split)
584
+
585
+
586
+ def _get_next_line(s: str, i: int) -> tuple[str, int]:
587
+ """Get the next line from a string and the new index."""
588
+ i_new = s.find('\n', i)
589
+ if i_new == -1:
590
+ return s[i:], len(s)
591
+ return s[i:i_new], i_new + 1
592
+
593
+
594
+ class BARTTraceMeta(Module):
595
+ """Metadata of R BART tree traces.
596
+
597
+ Parameters
598
+ ----------
599
+ ndpost
600
+ The number of posterior draws.
601
+ ntree
602
+ The number of trees in the model.
603
+ numcut
604
+ The maximum split value for each variable.
605
+ heap_size
606
+ The size of the heap required to store the trees.
607
+ """
608
+
609
+ ndpost: int = field(static=True)
610
+ ntree: int = field(static=True)
611
+ numcut: UInt[Array, ' p']
612
+ heap_size: int = field(static=True)
613
+
614
+
615
+ def scan_BART_trees(trees: str) -> BARTTraceMeta:
616
+ """Scan an R BART tree trace checking for errors and parsing metadata.
617
+
618
+ Parameters
619
+ ----------
620
+ trees
621
+ The string representation of a trace of trees of the R BART package.
622
+ Can be accessed from ``mc_gbart(...).treedraws['trees']``.
623
+
624
+ Returns
625
+ -------
626
+ An object containing the metadata.
627
+
628
+ Raises
629
+ ------
630
+ ValueError
631
+ If the string is malformed or contains leftover characters.
632
+ """
633
+ # parse first line
634
+ line, i_char = _get_next_line(trees, 0)
635
+ i_line = 1
636
+ match = fullmatch(r'(\d+) (\d+) (\d+)', line)
637
+ if match is None:
638
+ msg = f'Malformed header at {i_line=}'
639
+ raise ValueError(msg)
640
+ ndpost, ntree, p = map(int, match.groups())
641
+
642
+ # initial values for maxima
643
+ max_heap_index = 0
644
+ numcut = numpy.zeros(p, int)
645
+
646
+ # cycle over iterations and trees
647
+ for i_iter in range(ndpost):
648
+ for i_tree in range(ntree):
649
+ # parse first line of tree definition
650
+ line, i_char = _get_next_line(trees, i_char)
651
+ i_line += 1
652
+ match = fullmatch(r'(\d+)', line)
653
+ if match is None:
654
+ msg = f'Malformed tree header at {i_iter=} {i_tree=} {i_line=}'
655
+ raise ValueError(msg)
656
+ num_nodes = int(line)
657
+
658
+ # cycle over nodes
659
+ for i_node in range(num_nodes):
660
+ # parse node definition
661
+ line, i_char = _get_next_line(trees, i_char)
662
+ i_line += 1
663
+ match = fullmatch(
664
+ r'(\d+) (\d+) (\d+) (-?\d+(\.\d+)?(e(\+|-|)\d+)?)', line
665
+ )
666
+ if match is None:
667
+ msg = f'Malformed node definition at {i_iter=} {i_tree=} {i_node=} {i_line=}'
668
+ raise ValueError(msg)
669
+ i_heap = int(match.group(1))
670
+ var = int(match.group(2))
671
+ split = int(match.group(3))
672
+
673
+ # update maxima
674
+ numcut[var] = max(numcut[var], split)
675
+ max_heap_index = max(max_heap_index, i_heap)
676
+
677
+ assert i_char <= len(trees)
678
+ if i_char < len(trees):
679
+ msg = f'Leftover {len(trees) - i_char} characters in string'
680
+ raise ValueError(msg)
681
+
682
+ # determine minimal integer type for numcut
683
+ numcut += 1 # because BART is 0-based
684
+ split_dtype = minimal_unsigned_dtype(numcut.max())
685
+ numcut = jnp.array(numcut.astype(split_dtype))
686
+
687
+ # determine minimum heap size to store the trees
688
+ heap_size = 2 ** ceil(log2(max_heap_index + 1))
689
+
690
+ return BARTTraceMeta(ndpost=ndpost, ntree=ntree, numcut=numcut, heap_size=heap_size)
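For reference, the text format accepted by this scanner (and converted below by `trees_BART_to_bartz`) is the one stored by the R BART package in ``treedraws['trees']``: a header line ``ndpost ntree p``, then for each draw and tree a node-count line followed by one ``heap_index variable cutpoint leaf`` line per node. A minimal hand-written string that scans cleanly:

from bartz.debug import scan_BART_trees

toy_trace = '\n'.join([
    '2 1 3',       # ndpost=2 draws, ntree=1 tree, p=3 predictors
    '3',           # draw 0, tree 0: 3 nodes
    '1 0 4 0.0',   # root: decision on variable 0 at (0-based) cutpoint 4
    '2 0 0 -0.5',  # left leaf
    '3 0 0 0.5',   # right leaf
    '1',           # draw 1, tree 0: a single root leaf
    '1 0 0 0.1',
])
meta = scan_BART_trees(toy_trace)
# meta.ndpost == 2, meta.ntree == 1, meta.heap_size == 4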
691
+
692
+
693
+ class TraceWithOffset(Module):
694
+ """Implementation of `bartz.mcmcloop.Trace`."""
695
+
696
+ leaf_tree: Float32[Array, 'ndpost ntree 2**d']
697
+ var_tree: UInt[Array, 'ndpost ntree 2**(d-1)']
698
+ split_tree: UInt[Array, 'ndpost ntree 2**(d-1)']
699
+ offset: Float32[Array, ' ndpost']
700
+
701
+ @classmethod
702
+ def from_trees_trace(
703
+ cls, trees: TreeHeaps, offset: Float32[Array, '']
704
+ ) -> 'TraceWithOffset':
705
+ """Create a `TraceWithOffset` from a `TreeHeaps`."""
706
+ ndpost, _, _ = trees.leaf_tree.shape
707
+ return cls(
708
+ leaf_tree=trees.leaf_tree,
709
+ var_tree=trees.var_tree,
710
+ split_tree=trees.split_tree,
711
+ offset=jnp.full(ndpost, offset),
712
+ )
713
+
714
+
715
+ def trees_BART_to_bartz(
716
+ trees: str, *, min_maxdepth: int = 0, offset: FloatLike | None = None
717
+ ) -> tuple[TraceWithOffset, BARTTraceMeta]:
718
+ """Convert trees from the R BART format to the bartz format.
719
+
720
+ Parameters
721
+ ----------
722
+ trees
723
+ The string representation of a trace of trees of the R BART package.
724
+ Can be accessed from ``mc_gbart(...).treedraws['trees']``.
725
+ min_maxdepth
726
+ By default, the maximum tree depth of the output is set to the maximum
727
+ depth observed in the input trees. Use this parameter to require at
728
+ least this maximum depth in the output format.
729
+ offset
730
+ The trace returned by `bartz.mcmcloop.run_mcmc` contains an offset to be
731
+ added to the sum of trees. To match that behavior, this function
732
+ returns an offset as well, zero by default. Use this parameter to set
733
+ a different value.
175
734
 
735
+ Returns
736
+ -------
737
+ trace : TraceWithOffset
738
+ A representation of the trees compatible with the trace returned by
739
+ `bartz.mcmcloop.run_mcmc`.
740
+ meta : BARTTraceMeta
741
+ The metadata of the trace, containing the number of iterations, trees,
742
+ and the maximum split value.
743
+ """
744
+ # scan all the string checking for errors and determining sizes
745
+ meta = scan_BART_trees(trees)
176
746
 
177
- @functools.partial(jax.vmap, in_axes=(0, None))
178
- def check_trace(trace, state):
179
- return check_forest(
180
- trace['leaf_trees'],
181
- trace['var_trees'],
182
- trace['split_trees'],
183
- state['max_split'],
747
+ # skip first line
748
+ _, i_char = _get_next_line(trees, 0)
749
+
750
+ heap_size = max(meta.heap_size, 2**min_maxdepth)
751
+ leaf_trees = numpy.zeros((meta.ndpost, meta.ntree, heap_size), dtype=numpy.float32)
752
+ var_trees = numpy.zeros(
753
+ (meta.ndpost, meta.ntree, heap_size // 2),
754
+ dtype=minimal_unsigned_dtype(meta.numcut.size - 1),
755
+ )
756
+ split_trees = numpy.zeros(
757
+ (meta.ndpost, meta.ntree, heap_size // 2), dtype=meta.numcut.dtype
184
758
  )
759
+
760
+ # cycle over iterations and trees
761
+ for i_iter in range(meta.ndpost):
762
+ for i_tree in range(meta.ntree):
763
+ # parse first line of tree definition
764
+ line, i_char = _get_next_line(trees, i_char)
765
+ num_nodes = int(line)
766
+
767
+ is_internal = numpy.zeros(heap_size // 2, dtype=bool)
768
+
769
+ # cycle over nodes
770
+ for _ in range(num_nodes):
771
+ # parse node definition
772
+ line, i_char = _get_next_line(trees, i_char)
773
+ values = line.split()
774
+ i_heap = int(values[0])
775
+ var = int(values[1])
776
+ split = int(values[2])
777
+ leaf = float(values[3])
778
+
779
+ # update values
780
+ leaf_trees[i_iter, i_tree, i_heap] = leaf
781
+ is_internal[i_heap // 2] = True
782
+ if i_heap < heap_size // 2:
783
+ var_trees[i_iter, i_tree, i_heap] = var
784
+ split_trees[i_iter, i_tree, i_heap] = split + 1
785
+
786
+ is_internal[0] = False
787
+ split_trees[i_iter, i_tree, ~is_internal] = 0
788
+
789
+ return TraceWithOffset(
790
+ leaf_tree=jnp.array(leaf_trees),
791
+ var_tree=jnp.array(var_trees),
792
+ split_tree=jnp.array(split_trees),
793
+ offset=jnp.zeros(meta.ndpost)
794
+ if offset is None
795
+ else jnp.full(meta.ndpost, offset),
796
+ ), meta
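Continuing the sketch above, the same toy string converts into bartz heaps that pass the validity checks defined earlier; the root cutpoint is stored as 4 + 1 = 5 because of the shift applied above.

from bartz.debug import check_trace, trees_BART_to_bartz

trace, meta = trees_BART_to_bartz(toy_trace)
# trace.leaf_tree.shape == (2, 1, 4)
errors = check_trace(trace, meta.numcut)   # (ndpost, ntree) matrix of error codes
assert not errors.any()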
797
+
798
+
799
+ class SamplePriorStack(Module):
800
+ """Represent the manually managed stack used in `sample_prior`.
801
+
802
+ Each level of the stack represents a recursion into a child node in a
803
+ binary tree of maximum depth `d`.
804
+
805
+ Parameters
806
+ ----------
807
+ nonterminal
808
+ Whether the node is an actual node, or the recursion has moved into unused node slots.
809
+ lower
810
+ upper
811
+ The available cutpoints along ``var`` are in the integer range
812
+ ``[1 + lower[var], 1 + upper[var])``.
813
+ var
814
+ split
815
+ The variable and cutpoint of a decision node.
816
+ """
817
+
818
+ nonterminal: Bool[Array, ' d-1']
819
+ lower: UInt[Array, 'd-1 p']
820
+ upper: UInt[Array, 'd-1 p']
821
+ var: UInt[Array, ' d-1']
822
+ split: UInt[Array, ' d-1']
823
+
824
+ @classmethod
825
+ def initial(
826
+ cls, p_nonterminal: Float32[Array, ' d-1'], max_split: UInt[Array, ' p']
827
+ ) -> 'SamplePriorStack':
828
+ """Initialize the stack.
829
+
830
+ Parameters
831
+ ----------
832
+ p_nonterminal
833
+ The prior probability of a node being non-terminal conditional on
834
+ its ancestors and on having available decision rules, at each depth.
835
+ max_split
836
+ The number of cutpoints along each variable.
837
+
838
+ Returns
839
+ -------
840
+ A `SamplePriorStack` initialized to start the recursion.
841
+ """
842
+ var_dtype = minimal_unsigned_dtype(max_split.size - 1)
843
+ return cls(
844
+ nonterminal=jnp.ones(p_nonterminal.size, bool),
845
+ lower=jnp.zeros((p_nonterminal.size, max_split.size), max_split.dtype),
846
+ upper=jnp.broadcast_to(max_split, (p_nonterminal.size, max_split.size)),
847
+ var=jnp.zeros(p_nonterminal.size, var_dtype),
848
+ split=jnp.zeros(p_nonterminal.size, max_split.dtype),
849
+ )
850
+
851
+
852
+ class SamplePriorTrees(Module):
853
+ """Object holding the trees generated by `sample_prior`.
854
+
855
+ Parameters
856
+ ----------
857
+ leaf_tree
858
+ var_tree
859
+ split_tree
860
+ The arrays representing the trees, see `bartz.grove`.
861
+ """
862
+
863
+ leaf_tree: Float32[Array, '* 2**d']
864
+ var_tree: UInt[Array, '* 2**(d-1)']
865
+ split_tree: UInt[Array, '* 2**(d-1)']
866
+
867
+ @classmethod
868
+ def initial(
869
+ cls,
870
+ key: Key[Array, ''],
871
+ sigma_mu: Float32[Array, ''],
872
+ p_nonterminal: Float32[Array, ' d-1'],
873
+ max_split: UInt[Array, ' p'],
874
+ ) -> 'SamplePriorTrees':
875
+ """Initialize the trees.
876
+
877
+ The leaves are already correct and do not need to be changed.
878
+
879
+ Parameters
880
+ ----------
881
+ key
882
+ A jax random key.
883
+ sigma_mu
884
+ The prior standard deviation of each leaf.
885
+ p_nonterminal
886
+ The prior probability of a node being non-terminal conditional on
887
+ its ancestors and on having available decision rules, at each depth.
888
+ max_split
889
+ The number of cutpoints along each variable.
890
+
891
+ Returns
892
+ -------
893
+ Trees initialized with random leaves and stub tree structures.
894
+ """
895
+ heap_size = 2 ** (p_nonterminal.size + 1)
896
+ return cls(
897
+ leaf_tree=sigma_mu * random.normal(key, (heap_size,)),
898
+ var_tree=jnp.zeros(
899
+ heap_size // 2, dtype=minimal_unsigned_dtype(max_split.size - 1)
900
+ ),
901
+ split_tree=jnp.zeros(heap_size // 2, dtype=max_split.dtype),
902
+ )
903
+
904
+
905
+ class SamplePriorCarry(Module):
906
+ """Object holding values carried along the recursion in `sample_prior`.
907
+
908
+ Parameters
909
+ ----------
910
+ key
911
+ A jax random key used to sample decision rules.
912
+ stack
913
+ The stack used to manage the recursion.
914
+ trees
915
+ The output arrays.
916
+ """
917
+
918
+ key: Key[Array, '']
919
+ stack: SamplePriorStack
920
+ trees: SamplePriorTrees
921
+
922
+ @classmethod
923
+ def initial(
924
+ cls,
925
+ key: Key[Array, ''],
926
+ sigma_mu: Float32[Array, ''],
927
+ p_nonterminal: Float32[Array, ' d-1'],
928
+ max_split: UInt[Array, ' p'],
929
+ ) -> 'SamplePriorCarry':
930
+ """Initialize the carry object.
931
+
932
+ Parameters
933
+ ----------
934
+ key
935
+ A jax random key.
936
+ sigma_mu
937
+ The prior standard deviation of each leaf.
938
+ p_nonterminal
939
+ The prior probability of a node being non-terminal conditional on
940
+ its ancestors and on having available decision rules, at each depth.
941
+ max_split
942
+ The number of cutpoints along each variable.
943
+
944
+ Returns
945
+ -------
946
+ A `SamplePriorCarry` initialized to start the recursion.
947
+ """
948
+ keys = split_key(key)
949
+ return cls(
950
+ keys.pop(),
951
+ SamplePriorStack.initial(p_nonterminal, max_split),
952
+ SamplePriorTrees.initial(keys.pop(), sigma_mu, p_nonterminal, max_split),
953
+ )
954
+
955
+
956
+ class SamplePriorX(Module):
957
+ """Object representing the recursion scan in `sample_prior`.
958
+
959
+ The sequence of nodes to visit is pre-computed recursively once, unrolling
960
+ the recursion schedule.
961
+
962
+ Parameters
963
+ ----------
964
+ node
965
+ The heap index of the node to visit.
966
+ depth
967
+ The depth of the node.
968
+ next_depth
969
+ The depth of the next node to visit, either the left child or the right
970
+ sibling of the node or of an ancestor.
971
+ """
972
+
973
+ node: Int32[Array, ' 2**(d-1)-1']
974
+ depth: Int32[Array, ' 2**(d-1)-1']
975
+ next_depth: Int32[Array, ' 2**(d-1)-1']
976
+
977
+ @classmethod
978
+ def initial(cls, p_nonterminal: Float32[Array, ' d-1']) -> 'SamplePriorX':
979
+ """Initialize the sequence of nodes to visit.
980
+
981
+ Parameters
982
+ ----------
983
+ p_nonterminal
984
+ The prior probability of a node being non-terminal conditional on
985
+ its ancestors and on having available decision rules, at each depth.
986
+
987
+ Returns
988
+ -------
989
+ A `SamplePriorX` initialized with the sequence of nodes to visit.
990
+ """
991
+ seq = cls._sequence(p_nonterminal.size)
992
+ assert len(seq) == 2**p_nonterminal.size - 1
993
+ node = [node for node, depth in seq]
994
+ depth = [depth for node, depth in seq]
995
+ next_depth = depth[1:] + [p_nonterminal.size]
996
+ return cls(
997
+ node=jnp.array(node),
998
+ depth=jnp.array(depth),
999
+ next_depth=jnp.array(next_depth),
1000
+ )
1001
+
1002
+ @classmethod
1003
+ def _sequence(
1004
+ cls, max_depth: int, depth: int = 0, node: int = 1
1005
+ ) -> tuple[tuple[int, int], ...]:
1006
+ """Recursively generate a sequence [(node, depth), ...]."""
1007
+ if depth < max_depth:
1008
+ out = ((node, depth),)
1009
+ out += cls._sequence(max_depth, depth + 1, 2 * node)
1010
+ out += cls._sequence(max_depth, depth + 1, 2 * node + 1)
1011
+ return out
1012
+ return ()
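For instance, with a maximum depth of 2 the precomputed schedule is a pre-order visit of the root and its two children, and the `next_depth` vector built in `initial` becomes [1, 1, 2]:

assert SamplePriorX._sequence(2) == ((1, 0), (2, 1), (3, 1))  # (heap index, depth) pairs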
1013
+
1014
+
1015
+ def sample_prior_onetree(
1016
+ key: Key[Array, ''],
1017
+ max_split: UInt[Array, ' p'],
1018
+ p_nonterminal: Float32[Array, ' d-1'],
1019
+ sigma_mu: Float32[Array, ''],
1020
+ ) -> SamplePriorTrees:
1021
+ """Sample a tree from the BART prior.
1022
+
1023
+ Parameters
1024
+ ----------
1025
+ key
1026
+ A jax random key.
1027
+ max_split
1028
+ The maximum split value for each variable.
1029
+ p_nonterminal
1030
+ The prior probability of a node being non-terminal conditional on
1031
+ its ancestors and on having available decision rules, at each depth.
1032
+ sigma_mu
1033
+ The prior standard deviation of each leaf.
1034
+
1035
+ Returns
1036
+ -------
1037
+ An object containing a generated tree.
1038
+ """
1039
+ carry = SamplePriorCarry.initial(key, sigma_mu, p_nonterminal, max_split)
1040
+ xs = SamplePriorX.initial(p_nonterminal)
1041
+
1042
+ def loop(carry: SamplePriorCarry, x: SamplePriorX):
1043
+ keys = split_key(carry.key, 4)
1044
+
1045
+ # get variables at current stack level
1046
+ stack = carry.stack
1047
+ nonterminal = stack.nonterminal[x.depth]
1048
+ lower = stack.lower[x.depth, :]
1049
+ upper = stack.upper[x.depth, :]
1050
+
1051
+ # sample a random decision rule
1052
+ available: Bool[Array, ' p'] = lower < upper
1053
+ allowed = jnp.any(available)
1054
+ var = randint_masked(keys.pop(), available)
1055
+ split = 1 + random.randint(keys.pop(), (), lower[var], upper[var])
1056
+
1057
+ # cast to shorter integer types
1058
+ var = var.astype(carry.trees.var_tree.dtype)
1059
+ split = split.astype(carry.trees.split_tree.dtype)
1060
+
1061
+ # decide whether to try to grow the node if it is growable
1062
+ pnt = p_nonterminal[x.depth]
1063
+ try_nonterminal: Bool[Array, ''] = random.bernoulli(keys.pop(), pnt)
1064
+ nonterminal &= try_nonterminal & allowed
1065
+
1066
+ # update trees
1067
+ trees = carry.trees
1068
+ trees = replace(
1069
+ trees,
1070
+ var_tree=trees.var_tree.at[x.node].set(var),
1071
+ split_tree=trees.split_tree.at[x.node].set(
1072
+ jnp.where(nonterminal, split, 0)
1073
+ ),
1074
+ )
1075
+
1076
+ def write_push_stack() -> SamplePriorStack:
1077
+ """Update the stack to go to the left child."""
1078
+ return replace(
1079
+ stack,
1080
+ nonterminal=stack.nonterminal.at[x.next_depth].set(nonterminal),
1081
+ lower=stack.lower.at[x.next_depth, :].set(lower),
1082
+ upper=stack.upper.at[x.next_depth, :].set(upper.at[var].set(split - 1)),
1083
+ var=stack.var.at[x.depth].set(var),
1084
+ split=stack.split.at[x.depth].set(split),
1085
+ )
1086
+
1087
+ def pop_push_stack() -> SamplePriorStack:
1088
+ """Update the stack to go to the right sibling, possibly at lower depth."""
1089
+ var = stack.var[x.next_depth - 1]
1090
+ split = stack.split[x.next_depth - 1]
1091
+ lower = stack.lower[x.next_depth - 1, :]
1092
+ upper = stack.upper[x.next_depth - 1, :]
1093
+ return replace(
1094
+ stack,
1095
+ lower=stack.lower.at[x.next_depth, :].set(lower.at[var].set(split)),
1096
+ upper=stack.upper.at[x.next_depth, :].set(upper),
1097
+ )
1098
+
1099
+ # update stack
1100
+ stack = lax.cond(x.next_depth > x.depth, write_push_stack, pop_push_stack)
1101
+
1102
+ # update carry
1103
+ carry = replace(carry, key=keys.pop(), stack=stack, trees=trees)
1104
+ return carry, None
1105
+
1106
+ carry, _ = lax.scan(loop, carry, xs)
1107
+ return carry.trees
1108
+
1109
+
1110
+ @partial(vmap_nodoc, in_axes=(0, None, None, None))
1111
+ def sample_prior_forest(
1112
+ keys: Key[Array, ' num_trees'],
1113
+ max_split: UInt[Array, ' p'],
1114
+ p_nonterminal: Float32[Array, ' d-1'],
1115
+ sigma_mu: Float32[Array, ''],
1116
+ ) -> SamplePriorTrees:
1117
+ """Sample a set of independent trees from the BART prior.
1118
+
1119
+ Parameters
1120
+ ----------
1121
+ keys
1122
+ A sequence of jax random keys, one for each tree. This determines the
1123
+ number of trees sampled.
1124
+ max_split
1125
+ The maximum split value for each variable.
1126
+ p_nonterminal
1127
+ The prior probability of a node being non-terminal conditional on
1128
+ its ancestors and on having available decision rules, at each depth.
1129
+ sigma_mu
1130
+ The prior standard deviation of each leaf.
1131
+
1132
+ Returns
1133
+ -------
1134
+ An object containing the generated trees.
1135
+ """
1136
+ return sample_prior_onetree(keys, max_split, p_nonterminal, sigma_mu)
1137
+
1138
+
1139
+ @partial(jit, static_argnums=(1, 2))
1140
+ def sample_prior(
1141
+ key: Key[Array, ''],
1142
+ trace_length: int,
1143
+ num_trees: int,
1144
+ max_split: UInt[Array, ' p'],
1145
+ p_nonterminal: Float32[Array, ' d-1'],
1146
+ sigma_mu: Float32[Array, ''],
1147
+ ) -> SamplePriorTrees:
1148
+ """Sample independent trees from the BART prior.
1149
+
1150
+ Parameters
1151
+ ----------
1152
+ key
1153
+ A jax random key.
1154
+ trace_length
1155
+ The number of iterations.
1156
+ num_trees
1157
+ The number of trees for each iteration.
1158
+ max_split
1159
+ The number of cutpoints along each variable.
1160
+ p_nonterminal
1161
+ The prior probability of a node being non-terminal conditional on
1162
+ its ancestors and on having available decision rules, at each depth.
1163
+ This determines the maximum depth of the trees.
1164
+ sigma_mu
1165
+ The prior standard deviation of each leaf.
1166
+
1167
+ Returns
1168
+ -------
1169
+ An object containing the generated trees, with batch shape (trace_length, num_trees).
1170
+ """
1171
+ keys = random.split(key, trace_length * num_trees)
1172
+ trees = sample_prior_forest(keys, max_split, p_nonterminal, sigma_mu)
1173
+ return tree_map(lambda x: x.reshape(trace_length, num_trees, -1), trees)
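A sketch of drawing trees from the prior and validating them with the checks defined earlier; the parameter values are illustrative.

from jax import numpy as jnp
from jax import random

from bartz.debug import check_trace, sample_prior

key = random.key(0)
max_split = jnp.full(10, 255, jnp.uint8)       # 10 predictors with 255 cutpoints each
p_nonterminal = 0.95 / 2.0 ** jnp.arange(5)    # grow probability halving with depth
prior_trees = sample_prior(key, 100, 20, max_split, p_nonterminal, jnp.float32(0.3))
errors = check_trace(prior_trees, max_split)   # expected to be a (100, 20) matrix of zeros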
1174
+
1175
+
1176
+ class debug_gbart(gbart):
1177
+ """A subclass of `gbart` that adds debugging functionality.
1178
+
1179
+ Parameters
1180
+ ----------
1181
+ *args
1182
+ Passed to `gbart`.
1183
+ check_trees
1184
+ If `True`, check all trees with `check_trace` after running the MCMC,
1185
+ and assert that they are all valid. Set to `False` to allow jax tracing.
1186
+ **kw
1187
+ Passed to `gbart`.
1188
+ """
1189
+
1190
+ def __init__(self, *args, check_trees: bool = True, **kw):
1191
+ super().__init__(*args, **kw)
1192
+ if check_trees:
1193
+ bad = self.check_trees()
1194
+ bad_count = jnp.count_nonzero(bad)
1195
+ assert bad_count == 0
1196
+
1197
+ def show_tree(self, i_sample: int, i_tree: int, print_all: bool = False):
1198
+ """Print a single tree in human-readable format.
1199
+
1200
+ Parameters
1201
+ ----------
1202
+ i_sample
1203
+ The index of the posterior sample.
1204
+ i_tree
1205
+ The index of the tree in the sample.
1206
+ print_all
1207
+ If `True`, also print the content of unused node slots.
1208
+ """
1209
+ tree = TreesTrace.from_dataclass(self._main_trace)
1210
+ tree = tree_map(lambda x: x[i_sample, i_tree, :], tree)
1211
+ s = format_tree(tree, print_all=print_all)
1212
+ print(s) # noqa: T201, this method is intended for debug
1213
+
1214
+ def sigma_harmonic_mean(self, prior: bool = False) -> Float32[Array, '']:
1215
+ """Return the harmonic mean of the error variance.
1216
+
1217
+ Parameters
1218
+ ----------
1219
+ prior
1220
+ If `True`, use the prior distribution, otherwise use the full
1221
+ conditional at the last MCMC iteration.
1222
+
1223
+ Returns
1224
+ -------
1225
+ The harmonic mean 1/E[1/sigma^2] in the selected distribution.
1226
+ """
1227
+ bart = self._mcmc_state
1228
+ assert bart.sigma2_alpha is not None
1229
+ assert bart.z is None
1230
+ if prior:
1231
+ alpha = bart.sigma2_alpha
1232
+ beta = bart.sigma2_beta
1233
+ else:
1234
+ resid = bart.resid
1235
+ alpha = bart.sigma2_alpha + resid.size / 2
1236
+ norm2 = resid @ resid
1237
+ beta = bart.sigma2_beta + norm2 / 2
1238
+ sigma2 = beta / alpha
1239
+ return jnp.sqrt(sigma2)
1240
+
1241
+ def compare_resid(self) -> tuple[Float32[Array, ' n'], Float32[Array, ' n']]:
1242
+ """Re-compute residuals to compare them with the updated ones.
1243
+
1244
+ Returns
1245
+ -------
1246
+ resid1 : Float32[Array, 'n']
1247
+ The final state of the residuals updated during the MCMC.
1248
+ resid2 : Float32[Array, 'n']
1249
+ The residuals computed from the final state of the trees.
1250
+ """
1251
+ bart = self._mcmc_state
1252
+ resid1 = bart.resid
1253
+
1254
+ trees = evaluate_forest(bart.X, bart.forest)
1255
+
1256
+ if bart.z is not None:
1257
+ ref = bart.z
1258
+ else:
1259
+ ref = bart.y
1260
+ resid2 = ref - (trees + bart.offset)
1261
+
1262
+ return resid1, resid2
1263
+
1264
+ def avg_acc(self) -> tuple[Float32[Array, ''], Float32[Array, '']]:
1265
+ """Compute the average acceptance rates of tree moves.
1266
+
1267
+ Returns
1268
+ -------
1269
+ acc_grow : Float32[Array, '']
1270
+ The average acceptance rate of grow moves.
1271
+ acc_prune : Float32[Array, '']
1272
+ The average acceptance rate of prune moves.
1273
+ """
1274
+ trace = self._main_trace
1275
+
1276
+ def acc(prefix):
1277
+ acc = getattr(trace, f'{prefix}_acc_count')
1278
+ prop = getattr(trace, f'{prefix}_prop_count')
1279
+ return acc.sum() / prop.sum()
1280
+
1281
+ return acc('grow'), acc('prune')
1282
+
1283
+ def avg_prop(self) -> tuple[Float32[Array, ''], Float32[Array, '']]:
1284
+ """Compute the average proposal rate of grow and prune moves.
1285
+
1286
+ Returns
1287
+ -------
1288
+ prop_grow : Float32[Array, '']
1289
+ The fraction of times grow was proposed instead of prune.
1290
+ prop_prune : Float32[Array, '']
1291
+ The fraction of times prune was proposed instead of grow.
1292
+
1293
+ Notes
1294
+ -----
1295
+ This function does not take into account cases where no move was
1296
+ proposed.
1297
+ """
1298
+ trace = self._main_trace
1299
+
1300
+ def prop(prefix):
1301
+ return getattr(trace, f'{prefix}_prop_count').sum()
1302
+
1303
+ pgrow = prop('grow')
1304
+ pprune = prop('prune')
1305
+ total = pgrow + pprune
1306
+ return pgrow / total, pprune / total
1307
+
1308
+ def avg_move(self) -> tuple[Float32[Array, ''], Float32[Array, '']]:
1309
+ """Compute the move rate.
1310
+
1311
+ Returns
1312
+ -------
1313
+ rate_grow : Float32[Array, '']
1314
+ The fraction of times a grow move was proposed and accepted.
1315
+ rate_prune : Float32[Array, '']
1316
+ The fraction of times a prune move was proposed and accepted.
1317
+ """
1318
+ agrow, aprune = self.avg_acc()
1319
+ pgrow, pprune = self.avg_prop()
1320
+ return agrow * pgrow, aprune * pprune
1321
+
1322
+ def depth_distr(self) -> Float32[Array, 'trace_length d']:
1323
+ """Histogram of tree depths for each state of the trees.
1324
+
1325
+ Returns
1326
+ -------
1327
+ A matrix where each row contains a histogram of tree depths.
1328
+ """
1329
+ return trace_depth_distr(self._main_trace.split_tree)
1330
+
1331
+ def points_per_decision_node_distr(self) -> Float32[Array, 'trace_length n+1']:
1332
+ """Histogram of number of points belonging to parent-of-leaf nodes.
1333
+
1334
+ Returns
1335
+ -------
1336
+ A matrix where each row contains a histogram of number of points.
1337
+ """
1338
+ return trace_points_per_decision_node_distr(
1339
+ self._main_trace, self._mcmc_state.X
1340
+ )
1341
+
1342
+ def points_per_leaf_distr(self) -> Float32[Array, 'trace_length n+1']:
1343
+ """Histogram of number of points belonging to leaves.
1344
+
1345
+ Returns
1346
+ -------
1347
+ A matrix where each row contains a histogram of number of points.
1348
+ """
1349
+ return trace_points_per_leaf_distr(self._main_trace, self._mcmc_state.X)
1350
+
1351
+ def check_trees(self) -> UInt[Array, 'trace_length ntree']:
1352
+ """Apply `check_trace` to all the tree draws."""
1353
+ return check_trace(self._main_trace, self._mcmc_state.forest.max_split)
1354
+
1355
+ def tree_goes_bad(self) -> Bool[Array, 'trace_length ntree']:
1356
+ """Find iterations where a tree becomes invalid.
1357
+
1358
+ Returns
1359
+ -------
1360
+ A matrix where element (i,j) is `True` if tree j is invalid at iteration i but not at i-1.
1361
+ """
1362
+ bad = self.check_trees().astype(bool)
1363
+ bad_before = jnp.pad(bad[:-1], [(1, 0), (0, 0)])
1364
+ return bad & ~bad_before
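Finally, an editorial usage sketch of the wrapper class on synthetic data; the (p, n) predictor layout and the bare ``(x_train, y_train)`` call are assumptions about the `gbart` interface, not taken from this diff.

from jax import random

from bartz.debug import debug_gbart

key_x, key_eps = random.split(random.key(0))
X = random.uniform(key_x, (5, 200))                   # (p, n) predictors, layout assumed
y = X[0] - X[1] + 0.1 * random.normal(key_eps, (200,))

fit = debug_gbart(X, y, check_trees=True)             # extra gbart keyword arguments pass through
fit.show_tree(i_sample=0, i_tree=0)                   # pretty-print one sampled tree
print(fit.avg_move())                                 # accepted grow/prune move rates
print(fit.depth_distr()[-1])                          # tree-depth histogram at the last saved draw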