bartz 0.6.0-py3-none-any.whl → 0.8.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/.DS_Store +0 -0
- bartz/BART/__init__.py +27 -0
- bartz/BART/_gbart.py +522 -0
- bartz/__init__.py +6 -4
- bartz/_interface.py +937 -0
- bartz/_profiler.py +318 -0
- bartz/_version.py +1 -1
- bartz/debug.py +1217 -82
- bartz/grove.py +205 -103
- bartz/jaxext/__init__.py +287 -0
- bartz/jaxext/_autobatch.py +444 -0
- bartz/jaxext/scipy/__init__.py +25 -0
- bartz/jaxext/scipy/special.py +239 -0
- bartz/jaxext/scipy/stats.py +36 -0
- bartz/mcmcloop.py +662 -314
- bartz/mcmcstep/__init__.py +35 -0
- bartz/mcmcstep/_moves.py +904 -0
- bartz/mcmcstep/_state.py +1114 -0
- bartz/mcmcstep/_step.py +1603 -0
- bartz/prepcovars.py +140 -44
- bartz/testing/__init__.py +29 -0
- bartz/testing/_dgp.py +442 -0
- {bartz-0.6.0.dist-info → bartz-0.8.0.dist-info}/METADATA +18 -13
- bartz-0.8.0.dist-info/RECORD +25 -0
- {bartz-0.6.0.dist-info → bartz-0.8.0.dist-info}/WHEEL +1 -1
- bartz/BART.py +0 -603
- bartz/jaxext.py +0 -423
- bartz/mcmcstep.py +0 -2335
- bartz-0.6.0.dist-info/RECORD +0 -13
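Reading the file list: the flat modules `bartz/BART.py`, `bartz/jaxext.py`, and `bartz/mcmcstep.py` from 0.6.0 are removed and replaced by the packages `bartz/BART/`, `bartz/jaxext/`, and `bartz/mcmcstep/`. A minimal import sketch of how 0.8.0 code reaches the same entry points, based only on the imports visible in the `bartz/debug.py` diff below (any other paths are assumptions):

    # imports taken verbatim from the new bartz/debug.py shown below
    from bartz.BART import gbart, mc_gbart            # was bartz/BART.py, now a package
    from bartz.jaxext import minimal_unsigned_dtype   # was bartz/jaxext.py, now a package
    from bartz.mcmcstep._moves import randint_masked  # was bartz/mcmcstep.py, now split into submodules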
bartz/debug.py
CHANGED
@@ -1,13 +1,75 @@
-
+# bartz/src/bartz/debug.py
+#
+# Copyright (c) 2024-2026, The Bartz Contributors
+#
+# This file is part of bartz.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.

-
-
+"""Debugging utilities. The main functionality is the class `debug_mc_gbart`."""
+
+from collections.abc import Callable
+from dataclasses import replace
+from functools import partial
+from math import ceil, log2
+from re import fullmatch
+from typing import Literal
+
+import numpy
+from equinox import Module, field
+from jax import jit, lax, random, vmap
 from jax import numpy as jnp
+from jax.tree_util import tree_map
+from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, UInt
+
+from bartz.BART import gbart, mc_gbart
+from bartz.BART._gbart import FloatLike
+from bartz.grove import (
+    TreeHeaps,
+    evaluate_forest,
+    is_actual_leaf,
+    is_leaves_parent,
+    normalize_axis_tuple,
+    traverse_forest,
+    tree_depth,
+    tree_depths,
+)
+from bartz.jaxext import autobatch, minimal_unsigned_dtype, vmap_nodoc
+from bartz.jaxext import split as split_key
+from bartz.mcmcloop import TreesTrace
+from bartz.mcmcstep._moves import randint_masked

-from . import grove, jaxext

+def format_tree(tree: TreeHeaps, *, print_all: bool = False) -> str:
+    """Convert a tree to a human-readable string.

-
+    Parameters
+    ----------
+    tree
+        A single tree to format.
+    print_all
+        If `True`, also print the contents of unused node slots in the arrays.
+
+    Returns
+    -------
+    A string representation of the tree.
+    """
     tee = '├──'
     corner = '└──'
     join = '│  '
@@ -15,12 +77,20 @@ def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
     down = '┐'
     bottom = '╢'  # '┨' #

-    def traverse_tree(
-
+    def traverse_tree(
+        lines: list[str],
+        index: int,
+        depth: int,
+        indent: str,
+        first_indent: str,
+        next_indent: str,
+        unused: bool,
+    ):
+        if index >= len(tree.leaf_tree):
             return

-        var = var_tree.at[index].get(mode='fill', fill_value=0)
-        split = split_tree.at[index].get(mode='fill', fill_value=0)
+        var: int = tree.var_tree.at[index].get(mode='fill', fill_value=0).item()
+        split: int = tree.split_tree.at[index].get(mode='fill', fill_value=0).item()

         is_leaf = split == 0
         left_child = 2 * index
@@ -33,26 +103,26 @@ def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
                 category = 'leaf'
             else:
                 category = 'decision'
-            node_str = f'{category}({var}, {split}, {leaf_tree[index]})'
+            node_str = f'{category}({var}, {split}, {tree.leaf_tree[index]})'
         else:
             assert not unused
             if is_leaf:
-                node_str = f'{leaf_tree[index]:#.2g}'
+                node_str = f'{tree.leaf_tree[index]:#.2g}'
             else:
-                node_str = f'
+                node_str = f'x{var} < {split}'

-        if not is_leaf or (print_all and left_child < len(leaf_tree)):
+        if not is_leaf or (print_all and left_child < len(tree.leaf_tree)):
             link = down
-        elif not print_all and left_child >= len(leaf_tree):
+        elif not print_all and left_child >= len(tree.leaf_tree):
             link = bottom
         else:
             link = ' '

-        max_number = len(leaf_tree) - 1
+        max_number = len(tree.leaf_tree) - 1
         ndigits = len(str(max_number))
         number = str(index).rjust(ndigits)

-
+        lines.append(f' {number} {indent}{first_indent}{link}{node_str}')

         indent += next_indent
         unused = unused or is_leaf
@@ -60,125 +130,1190 @@ def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
         if unused and not print_all:
             return

-        traverse_tree(left_child, depth + 1, indent, tee, join, unused)
-        traverse_tree(right_child, depth + 1, indent, corner, space, unused)
+        traverse_tree(lines, left_child, depth + 1, indent, tee, join, unused)
+        traverse_tree(lines, right_child, depth + 1, indent, corner, space, unused)
+
+    lines = []
+    traverse_tree(lines, 1, 0, '', '', '', False)
+    return '\n'.join(lines)
+

-
+def tree_actual_depth(split_tree: UInt[Array, ' 2**(d-1)']) -> Int32[Array, '']:
+    """Measure the depth of the tree.

+    Parameters
+    ----------
+    split_tree
+        The cutpoints of the decision rules.

-
-
-    depth
+    Returns
+    -------
+    The depth of the deepest leaf in the tree. The root is at depth 0.
+    """
+    # this could be done just with split_tree != 0
+    is_leaf = is_actual_leaf(split_tree, add_bottom_level=True)
+    depth = tree_depths(is_leaf.size)
     depth = jnp.where(is_leaf, depth, 0)
     return jnp.max(depth)


-
-
-
+@jit
+@partial(jnp.vectorize, signature='(nt,hts)->(d)')
+def forest_depth_distr(
+    split_tree: UInt[Array, '*batch_shape num_trees 2**(d-1)'],
+) -> Int32[Array, '*batch_shape d']:
+    """Histogram the depths of a set of trees.
+
+    Parameters
+    ----------
+    split_tree
+        The cutpoints of the decision rules of the trees.
+
+    Returns
+    -------
+    An integer vector where the i-th element counts how many trees have depth i.
+    """
+    depth = tree_depth(split_tree) + 1
+    depths = vmap(tree_actual_depth)(split_tree)
     return jnp.bincount(depths, length=depth)


-
-
+@partial(jit, static_argnames=('node_type', 'sum_batch_axis'))
+def points_per_node_distr(
+    X: UInt[Array, 'p n'],
+    var_tree: UInt[Array, '*batch_shape 2**(d-1)'],
+    split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
+    node_type: Literal['leaf', 'leaf-parent'],
+    *,
+    sum_batch_axis: int | tuple[int, ...] = (),
+) -> Int32[Array, '*reduced_batch_shape n+1']:
+    """Histogram points-per-node counts in a set of trees.

+    Count how many nodes in a tree select each possible amount of points,
+    over a certain subset of nodes.

-
-
-
-
-
-
-
-
-
+    Parameters
+    ----------
+    X
+        The set of points to count.
+    var_tree
+        The variables of the decision rules.
+    split_tree
+        The cutpoints of the decision rules.
+    node_type
+        The type of nodes to consider. Can be:
+
+        'leaf'
+            Count only leaf nodes.
+        'leaf-parent'
+            Count only parent-of-leaf nodes.
+    sum_batch_axis
+        Aggregate the histogram over these batch axes, counting how many nodes
+        have each possible amount of points over subsets of trees instead of
+        in each tree separately.
+
+    Returns
+    -------
+    A vector where the i-th element counts how many nodes have i points.
+    """
+    batch_ndim = var_tree.ndim - 1
+    axes = normalize_axis_tuple(sum_batch_axis, batch_ndim)
+
+    def func(
+        var_tree: UInt[Array, '*batch_shape 2**(d-1)'],
+        split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
+    ) -> Int32[Array, '*reduced_batch_shape n+1']:
+        indices: UInt[Array, '*batch_shape n']
+        indices = traverse_forest(X, var_tree, split_tree)
+
+        @partial(jnp.vectorize, signature='(hts),(n)->(ts_or_hts),(ts_or_hts)')
+        def count_points(
+            split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
+            indices: UInt[Array, '*batch_shape n'],
+        ) -> (
+            tuple[UInt[Array, '*batch_shape 2**d'], Bool[Array, '*batch_shape 2**d']]
+            | tuple[
+                UInt[Array, '*batch_shape 2**(d-1)'],
+                Bool[Array, '*batch_shape 2**(d-1)'],
+            ]
+        ):
+            if node_type == 'leaf-parent':
+                indices >>= 1
+                predicate = is_leaves_parent(split_tree)
+            elif node_type == 'leaf':
+                predicate = is_actual_leaf(split_tree, add_bottom_level=True)
+            else:
+                raise ValueError(node_type)
+            count_tree = jnp.zeros(predicate.size, int).at[indices].add(1).at[0].set(0)
+            return count_tree, predicate
+
+        count_tree, predicate = count_points(split_tree, indices)
+
+        def count_nodes(
+            count_tree: UInt[Array, '*summed_batch_axes half_tree_size'],
+            predicate: Bool[Array, '*summed_batch_axes half_tree_size'],
+        ) -> Int32[Array, ' n+1']:
+            return jnp.zeros(X.shape[1] + 1, int).at[count_tree].add(predicate)
+
+        # vmap count_nodes over non-batched dims
+        for i in reversed(range(batch_ndim)):
+            neg_i = i - var_tree.ndim
+            if i not in axes:
+                count_nodes = vmap(count_nodes, in_axes=neg_i)
+
+        return count_nodes(count_tree, predicate)
+
+    # automatically batch over all batch dimensions
+    max_io_nbytes = 2**27  # 128 MiB
+    out_dim_shift = len(axes)
+    for i in reversed(range(batch_ndim)):
+        if i in axes:
+            out_dim_shift -= 1
+        else:
+            func = autobatch(func, max_io_nbytes, i, i - out_dim_shift)
+    assert out_dim_shift == 0
+
+    return func(var_tree, split_tree)
+
+
+check_functions = []


-
-    distr = jnp.zeros(X.shape[1] + 1, int)
-    trees = bart['var_trees'], bart['split_trees']
+CheckFunc = Callable[[TreeHeaps, UInt[Array, ' p']], bool | Bool[Array, '']]

-    def loop(distr, tree):
-        return distr + points_per_leaf_distr(*tree, X), None

-
-
+def check(func: CheckFunc) -> CheckFunc:
+    """Add a function to a list of functions used to check trees.

+    Use to decorate functions that check whether a tree is valid in some way.
+    These functions are invoked automatically by `check_tree`, `check_trace` and
+    `debug_gbart`.

-
-
-
+    Parameters
+    ----------
+    func
+        The function to add to the list. It must accept a `TreeHeaps` and a
+        `max_split` argument, and return a boolean scalar that indicates if the
+        tree is ok.

-
-
+    Returns
+    -------
+    The function unchanged.
+    """
+    check_functions.append(func)
+    return func


-
-
+@check
+def check_types(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> bool:
+    """Check that integer types are as small as possible and coherent."""
+    expected_var_dtype = minimal_unsigned_dtype(max_split.size - 1)
     expected_split_dtype = max_split.dtype
     return (
-        var_tree.dtype == expected_var_dtype
-        and split_tree.dtype == expected_split_dtype
+        tree.var_tree.dtype == expected_var_dtype
+        and tree.split_tree.dtype == expected_split_dtype
+        and jnp.issubdtype(max_split.dtype, jnp.unsignedinteger)
     )


-
-
+@check
+def check_sizes(tree: TreeHeaps, _max_split: UInt[Array, ' p']) -> bool:
+    """Check that array sizes are coherent."""
+    return tree.leaf_tree.size == 2 * tree.var_tree.size == 2 * tree.split_tree.size


-
-
+@check
+def check_unused_node(
+    tree: TreeHeaps, _max_split: UInt[Array, ' p']
+) -> Bool[Array, '']:
+    """Check that the unused node slot at index 0 is not dirty."""
+    return (tree.var_tree[0] == 0) & (tree.split_tree[0] == 0)


-
-
+@check
+def check_leaf_values(
+    tree: TreeHeaps, _max_split: UInt[Array, ' p']
+) -> Bool[Array, '']:
+    """Check that all leaf values are not inf of nan."""
+    return jnp.all(jnp.isfinite(tree.leaf_tree))


-
+@check
+def check_stray_nodes(
+    tree: TreeHeaps, _max_split: UInt[Array, ' p']
+) -> Bool[Array, '']:
+    """Check if there is any marked-non-leaf node with a marked-leaf parent."""
     index = jnp.arange(
-        2 * split_tree.size,
-        dtype=
+        2 * tree.split_tree.size,
+        dtype=minimal_unsigned_dtype(2 * tree.split_tree.size - 1),
     )
     parent_index = index >> 1
-    is_not_leaf = split_tree.at[index].get(mode='fill', fill_value=0) != 0
-    parent_is_leaf = split_tree[parent_index] == 0
+    is_not_leaf = tree.split_tree.at[index].get(mode='fill', fill_value=0) != 0
+    parent_is_leaf = tree.split_tree[parent_index] == 0
     stray = is_not_leaf & parent_is_leaf
     stray = stray.at[1].set(False)
     return ~jnp.any(stray)


-
-
-
-
-
-
-
+@check
+def check_rule_consistency(
+    tree: TreeHeaps, max_split: UInt[Array, ' p']
+) -> bool | Bool[Array, '']:
+    """Check that decision rules define proper subsets of ancestor rules."""
+    if tree.var_tree.size < 4:
+        return True

+    # initial boundaries of decision rules. use extreme integers instead of 0,
+    # max_split to avoid checking if there is something out of bounds.
+    dtype = tree.split_tree.dtype
+    small = jnp.iinfo(dtype).min
+    large = jnp.iinfo(dtype).max
+    lower = jnp.full(max_split.size, small, dtype)
+    upper = jnp.full(max_split.size, large, dtype)
+    # the split must be in (lower[var], upper[var]]

-    def
-
+    def _check_recursive(node, lower, upper):
+        # read decision rule
+        var = tree.var_tree[node]
+        split = tree.split_tree[node]
+
+        # get rule boundaries from ancestors. use fill value in case var is
+        # out of bounds, we don't want to check out of bounds in this function
+        lower_var = lower.at[var].get(mode='fill', fill_value=small)
+        upper_var = upper.at[var].get(mode='fill', fill_value=large)
+
+        # check rule is in bounds
+        bad = jnp.where(split, (split <= lower_var) | (split > upper_var), False)
+
+        # recurse
+        if node < tree.var_tree.size // 2:
+            idx = jnp.where(split, var, max_split.size)
+            bad |= _check_recursive(2 * node, lower, upper.at[idx].set(split - 1))
+            bad |= _check_recursive(2 * node + 1, lower.at[idx].set(split), upper)
+
+        return bad
+
+    return ~_check_recursive(1, lower, upper)
+
+
+@check
+def check_num_nodes(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']:  # noqa: ARG001
+    """Check that #leaves = 1 + #(internal nodes)."""
+    is_leaf = is_actual_leaf(tree.split_tree, add_bottom_level=True)
+    num_leaves = jnp.count_nonzero(is_leaf)
+    num_internal = jnp.count_nonzero(tree.split_tree)
+    return num_leaves == num_internal + 1
+
+
+@check
+def check_var_in_bounds(
+    tree: TreeHeaps, max_split: UInt[Array, ' p']
+) -> Bool[Array, '']:
+    """Check that variables are in [0, max_split.size)."""
+    decision_node = tree.split_tree.astype(bool)
+    in_bounds = (tree.var_tree >= 0) & (tree.var_tree < max_split.size)
+    return jnp.all(in_bounds | ~decision_node)
+
+
+@check
+def check_split_in_bounds(
+    tree: TreeHeaps, max_split: UInt[Array, ' p']
+) -> Bool[Array, '']:
+    """Check that splits are in [0, max_split[var]]."""
+    max_split_var = (
+        max_split.astype(jnp.int32)
+        .at[tree.var_tree]
+        .get(mode='fill', fill_value=jnp.iinfo(jnp.int32).max)
+    )
+    return jnp.all((tree.split_tree >= 0) & (tree.split_tree <= max_split_var))
+
+
+def check_tree(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> UInt[Array, '']:
+    """Check the validity of a tree.
+
+    Use `describe_error` to parse the error code returned by this function.
+
+    Parameters
+    ----------
+    tree
+        The tree to check.
+    max_split
+        The maximum split value for each variable.
+
+    Returns
+    -------
+    An integer where each bit indicates whether a check failed.
+    """
+    error_type = minimal_unsigned_dtype(2 ** len(check_functions) - 1)
     error = error_type(0)
     for i, func in enumerate(check_functions):
-        ok = func(
+        ok = func(tree, max_split)
         ok = jnp.bool_(ok)
         bit = (~ok) << i
         error |= bit
     return error


-def describe_error(error):
+def describe_error(error: int | Integer[Array, '']) -> list[str]:
+    """Describe the error code returned by `check_tree`.
+
+    Parameters
+    ----------
+    error
+        The error code returned by `check_tree`.
+
+    Returns
+    -------
+    A list of the function names that implement the failed checks.
+    """
     return [func.__name__ for i, func in enumerate(check_functions) if error & (1 << i)]


-
+@jit
+def check_trace(
+    trace: TreeHeaps, max_split: UInt[Array, ' p']
+) -> UInt[Array, '*batch_shape']:
+    """Check the validity of a set of trees.
+
+    Use `describe_error` to parse the error codes returned by this function.
+
+    Parameters
+    ----------
+    trace
+        The set of trees to check. This object can have additional attributes
+        beyond the tree arrays, they are ignored.
+    max_split
+        The maximum split value for each variable.
+
+    Returns
+    -------
+    A tensor of error codes for each tree.
+    """
+    # vectorize check_tree over all batch dimensions
+    unpack_check_tree = lambda l, v, s: check_tree(TreesTrace(l, v, s), max_split)
+    is_mv = trace.leaf_tree.ndim > trace.split_tree.ndim
+    signature = '(k,ts),(hts),(hts)->()' if is_mv else '(ts),(hts),(hts)->()'
+    vec_check_tree = jnp.vectorize(unpack_check_tree, signature=signature)
+
+    # automatically batch over all batch dimensions
+    max_io_nbytes = 2**24  # 16 MiB
+    batch_ndim = trace.split_tree.ndim - 1
+    batched_check_tree = vec_check_tree
+    for i in reversed(range(batch_ndim)):
+        batched_check_tree = autobatch(batched_check_tree, max_io_nbytes, i, i)

+    return batched_check_tree(trace.leaf_tree, trace.var_tree, trace.split_tree)

-
-def
-
-
-
-
-
+
+def _get_next_line(s: str, i: int) -> tuple[str, int]:
+    """Get the next line from a string and the new index."""
+    i_new = s.find('\n', i)
+    if i_new == -1:
+        return s[i:], len(s)
+    return s[i:i_new], i_new + 1
+
+
+class BARTTraceMeta(Module):
+    """Metadata of R BART tree traces."""
+
+    ndpost: int = field(static=True)
+    """The number of posterior draws."""
+
+    ntree: int = field(static=True)
+    """The number of trees in the model."""
+
+    numcut: UInt[Array, ' p']
+    """The maximum split value for each variable."""
+
+    heap_size: int = field(static=True)
+    """The size of the heap required to store the trees."""
+
+
+def scan_BART_trees(trees: str) -> BARTTraceMeta:
+    """Scan an R BART tree trace checking for errors and parsing metadata.
+
+    Parameters
+    ----------
+    trees
+        The string representation of a trace of trees of the R BART package.
+        Can be accessed from ``mc_gbart(...).treedraws['trees']``.
+
+    Returns
+    -------
+    An object containing the metadata.
+
+    Raises
+    ------
+    ValueError
+        If the string is malformed or contains leftover characters.
+    """
+    # parse first line
+    line, i_char = _get_next_line(trees, 0)
+    i_line = 1
+    match = fullmatch(r'(\d+) (\d+) (\d+)', line)
+    if match is None:
+        msg = f'Malformed header at {i_line=}'
+        raise ValueError(msg)
+    ndpost, ntree, p = map(int, match.groups())
+
+    # initial values for maxima
+    max_heap_index = 0
+    numcut = numpy.zeros(p, int)
+
+    # cycle over iterations and trees
+    for i_iter in range(ndpost):
+        for i_tree in range(ntree):
+            # parse first line of tree definition
+            line, i_char = _get_next_line(trees, i_char)
+            i_line += 1
+            match = fullmatch(r'(\d+)', line)
+            if match is None:
+                msg = f'Malformed tree header at {i_iter=} {i_tree=} {i_line=}'
+                raise ValueError(msg)
+            num_nodes = int(line)
+
+            # cycle over nodes
+            for i_node in range(num_nodes):
+                # parse node definition
+                line, i_char = _get_next_line(trees, i_char)
+                i_line += 1
+                match = fullmatch(
+                    r'(\d+) (\d+) (\d+) (-?\d+(\.\d+)?(e(\+|-|)\d+)?)', line
+                )
+                if match is None:
+                    msg = f'Malformed node definition at {i_iter=} {i_tree=} {i_node=} {i_line=}'
+                    raise ValueError(msg)
+                i_heap = int(match.group(1))
+                var = int(match.group(2))
+                split = int(match.group(3))
+
+                # update maxima
+                numcut[var] = max(numcut[var], split)
+                max_heap_index = max(max_heap_index, i_heap)
+
+    assert i_char <= len(trees)
+    if i_char < len(trees):
+        msg = f'Leftover {len(trees) - i_char} characters in string'
+        raise ValueError(msg)
+
+    # determine minimal integer type for numcut
+    numcut += 1  # because BART is 0-based
+    split_dtype = minimal_unsigned_dtype(numcut.max())
+    numcut = jnp.array(numcut.astype(split_dtype))
+
+    # determine minimum heap size to store the trees
+    heap_size = 2 ** ceil(log2(max_heap_index + 1))
+
+    return BARTTraceMeta(ndpost=ndpost, ntree=ntree, numcut=numcut, heap_size=heap_size)
+
+
+class TraceWithOffset(Module):
+    """Implementation of `bartz.mcmcloop.Trace`."""
+
+    leaf_tree: Float32[Array, 'ndpost ntree 2**d']
+    var_tree: UInt[Array, 'ndpost ntree 2**(d-1)']
+    split_tree: UInt[Array, 'ndpost ntree 2**(d-1)']
+    offset: Float32[Array, ' ndpost']
+
+    @classmethod
+    def from_trees_trace(
+        cls, trees: TreeHeaps, offset: Float32[Array, '']
+    ) -> 'TraceWithOffset':
+        """Create a `TraceWithOffset` from a `TreeHeaps`."""
+        ndpost, _, _ = trees.leaf_tree.shape
+        return cls(
+            leaf_tree=trees.leaf_tree,
+            var_tree=trees.var_tree,
+            split_tree=trees.split_tree,
+            offset=jnp.full(ndpost, offset),
+        )
+
+
+def trees_BART_to_bartz(
+    trees: str, *, min_maxdepth: int = 0, offset: FloatLike | None = None
+) -> tuple[TraceWithOffset, BARTTraceMeta]:
+    """Convert trees from the R BART format to the bartz format.
+
+    Parameters
+    ----------
+    trees
+        The string representation of a trace of trees of the R BART package.
+        Can be accessed from ``mc_gbart(...).treedraws['trees']``.
+    min_maxdepth
+        The maximum tree depth of the output will be set to the maximum
+        observed depth in the input trees. Use this parameter to require at
+        least this maximum depth in the output format.
+    offset
+        The trace returned by `bartz.mcmcloop.run_mcmc` contains an offset to be
+        summed to the sum of trees. To match that behavior, this function
+        returns an offset as well, zero by default. Set with this parameter
+        otherwise.
+
+    Returns
+    -------
+    trace : TraceWithOffset
+        A representation of the trees compatible with the trace returned by
+        `bartz.mcmcloop.run_mcmc`.
+    meta : BARTTraceMeta
+        The metadata of the trace, containing the number of iterations, trees,
+        and the maximum split value.
+    """
+    # scan all the string checking for errors and determining sizes
+    meta = scan_BART_trees(trees)
+
+    # skip first line
+    _, i_char = _get_next_line(trees, 0)
+
+    heap_size = max(meta.heap_size, 2**min_maxdepth)
+    leaf_trees = numpy.zeros((meta.ndpost, meta.ntree, heap_size), dtype=numpy.float32)
+    var_trees = numpy.zeros(
+        (meta.ndpost, meta.ntree, heap_size // 2),
+        dtype=minimal_unsigned_dtype(meta.numcut.size - 1),
+    )
+    split_trees = numpy.zeros(
+        (meta.ndpost, meta.ntree, heap_size // 2), dtype=meta.numcut.dtype
     )
+
+    # cycle over iterations and trees
+    for i_iter in range(meta.ndpost):
+        for i_tree in range(meta.ntree):
+            # parse first line of tree definition
+            line, i_char = _get_next_line(trees, i_char)
+            num_nodes = int(line)
+
+            is_internal = numpy.zeros(heap_size // 2, dtype=bool)
+
+            # cycle over nodes
+            for _ in range(num_nodes):
+                # parse node definition
+                line, i_char = _get_next_line(trees, i_char)
+                values = line.split()
+                i_heap = int(values[0])
+                var = int(values[1])
+                split = int(values[2])
+                leaf = float(values[3])
+
+                # update values
+                leaf_trees[i_iter, i_tree, i_heap] = leaf
+                is_internal[i_heap // 2] = True
+                if i_heap < heap_size // 2:
+                    var_trees[i_iter, i_tree, i_heap] = var
+                    split_trees[i_iter, i_tree, i_heap] = split + 1
+
+            is_internal[0] = False
+            split_trees[i_iter, i_tree, ~is_internal] = 0
+
+    return TraceWithOffset(
+        leaf_tree=jnp.array(leaf_trees),
+        var_tree=jnp.array(var_trees),
+        split_tree=jnp.array(split_trees),
+        offset=jnp.zeros(meta.ndpost)
+        if offset is None
+        else jnp.full(meta.ndpost, offset),
+    ), meta
+
+
+class SamplePriorStack(Module):
+    """Represent the manually managed stack used in `sample_prior`.
+
+    Each level of the stack represents a recursion into a child node in a
+    binary tree of maximum depth `d`.
+    """
+
+    nonterminal: Bool[Array, ' d-1']
+    """Whether the node is valid or the recursion is into unused node slots."""
+
+    lower: UInt[Array, 'd-1 p']
+    """The available cutpoints along ``var`` are in the integer range
+    ``[1 + lower[var], 1 + upper[var])``."""
+
+    upper: UInt[Array, 'd-1 p']
+    """The available cutpoints along ``var`` are in the integer range
+    ``[1 + lower[var], 1 + upper[var])``."""
+
+    var: UInt[Array, ' d-1']
+    """The variable of a decision node."""
+
+    split: UInt[Array, ' d-1']
+    """The cutpoint of a decision node."""
+
+    @classmethod
+    def initial(
+        cls, p_nonterminal: Float32[Array, ' d-1'], max_split: UInt[Array, ' p']
+    ) -> 'SamplePriorStack':
+        """Initialize the stack.
+
+        Parameters
+        ----------
+        p_nonterminal
+            The prior probability of a node being non-terminal conditional on
+            its ancestors and on having available decision rules, at each depth.
+        max_split
+            The number of cutpoints along each variable.
+
+        Returns
+        -------
+        A `SamplePriorStack` initialized to start the recursion.
+        """
+        var_dtype = minimal_unsigned_dtype(max_split.size - 1)
+        return cls(
+            nonterminal=jnp.ones(p_nonterminal.size, bool),
+            lower=jnp.zeros((p_nonterminal.size, max_split.size), max_split.dtype),
+            upper=jnp.broadcast_to(max_split, (p_nonterminal.size, max_split.size)),
+            var=jnp.zeros(p_nonterminal.size, var_dtype),
+            split=jnp.zeros(p_nonterminal.size, max_split.dtype),
+        )
+
+
+class SamplePriorTrees(Module):
+    """Object holding the trees generated by `sample_prior`."""
+
+    leaf_tree: Float32[Array, '* 2**d']
+    """The array representing the trees, see `bartz.grove`."""
+
+    var_tree: UInt[Array, '* 2**(d-1)']
+    """The array representing the trees, see `bartz.grove`."""
+
+    split_tree: UInt[Array, '* 2**(d-1)']
+    """The array representing the trees, see `bartz.grove`."""
+
+    @classmethod
+    def initial(
+        cls,
+        key: Key[Array, ''],
+        sigma_mu: Float32[Array, ''],
+        p_nonterminal: Float32[Array, ' d-1'],
+        max_split: UInt[Array, ' p'],
+    ) -> 'SamplePriorTrees':
+        """Initialize the trees.
+
+        The leaves are already correct and do not need to be changed.
+
+        Parameters
+        ----------
+        key
+            A jax random key.
+        sigma_mu
+            The prior standard deviation of each leaf.
+        p_nonterminal
+            The prior probability of a node being non-terminal conditional on
+            its ancestors and on having available decision rules, at each depth.
+        max_split
+            The number of cutpoints along each variable.
+
+        Returns
+        -------
+        Trees initialized with random leaves and stub tree structures.
+        """
+        heap_size = 2 ** (p_nonterminal.size + 1)
+        return cls(
+            leaf_tree=sigma_mu * random.normal(key, (heap_size,)),
+            var_tree=jnp.zeros(
+                heap_size // 2, dtype=minimal_unsigned_dtype(max_split.size - 1)
+            ),
+            split_tree=jnp.zeros(heap_size // 2, dtype=max_split.dtype),
+        )
+
+
+class SamplePriorCarry(Module):
+    """Object holding values carried along the recursion in `sample_prior`."""
+
+    key: Key[Array, '']
+    """A jax random key used to sample decision rules."""
+
+    stack: SamplePriorStack
+    """The stack used to manage the recursion."""
+
+    trees: SamplePriorTrees
+    """The output arrays."""
+
+    @classmethod
+    def initial(
+        cls,
+        key: Key[Array, ''],
+        sigma_mu: Float32[Array, ''],
+        p_nonterminal: Float32[Array, ' d-1'],
+        max_split: UInt[Array, ' p'],
+    ) -> 'SamplePriorCarry':
+        """Initialize the carry object.
+
+        Parameters
+        ----------
+        key
+            A jax random key.
+        sigma_mu
+            The prior standard deviation of each leaf.
+        p_nonterminal
+            The prior probability of a node being non-terminal conditional on
+            its ancestors and on having available decision rules, at each depth.
+        max_split
+            The number of cutpoints along each variable.
+
+        Returns
+        -------
+        A `SamplePriorCarry` initialized to start the recursion.
+        """
+        keys = split_key(key)
+        return cls(
+            keys.pop(),
+            SamplePriorStack.initial(p_nonterminal, max_split),
+            SamplePriorTrees.initial(keys.pop(), sigma_mu, p_nonterminal, max_split),
+        )
+
+
+class SamplePriorX(Module):
+    """Object representing the recursion scan in `sample_prior`.
+
+    The sequence of nodes to visit is pre-computed recursively once, unrolling
+    the recursion schedule.
+    """
+
+    node: Int32[Array, ' 2**(d-1)-1']
+    """The heap index of the node to visit."""
+
+    depth: Int32[Array, ' 2**(d-1)-1']
+    """The depth of the node."""
+
+    next_depth: Int32[Array, ' 2**(d-1)-1']
+    """The depth of the next node to visit, either the left child or the right
+    sibling of the node or of an ancestor."""
+
+    @classmethod
+    def initial(cls, p_nonterminal: Float32[Array, ' d-1']) -> 'SamplePriorX':
+        """Initialize the sequence of nodes to visit.
+
+        Parameters
+        ----------
+        p_nonterminal
+            The prior probability of a node being non-terminal conditional on
+            its ancestors and on having available decision rules, at each depth.
+
+        Returns
+        -------
+        A `SamplePriorX` initialized with the sequence of nodes to visit.
+        """
+        seq = cls._sequence(p_nonterminal.size)
+        assert len(seq) == 2**p_nonterminal.size - 1
+        node = [node for node, depth in seq]
+        depth = [depth for node, depth in seq]
+        next_depth = [*depth[1:], p_nonterminal.size]
+        return cls(
+            node=jnp.array(node),
+            depth=jnp.array(depth),
+            next_depth=jnp.array(next_depth),
+        )
+
+    @classmethod
+    def _sequence(
+        cls, max_depth: int, depth: int = 0, node: int = 1
+    ) -> tuple[tuple[int, int], ...]:
+        """Recursively generate a sequence [(node, depth), ...]."""
+        if depth < max_depth:
+            out = ((node, depth),)
+            out += cls._sequence(max_depth, depth + 1, 2 * node)
+            out += cls._sequence(max_depth, depth + 1, 2 * node + 1)
+            return out
+        return ()
+
+
+def sample_prior_onetree(
+    key: Key[Array, ''],
+    max_split: UInt[Array, ' p'],
+    p_nonterminal: Float32[Array, ' d-1'],
+    sigma_mu: Float32[Array, ''],
+) -> SamplePriorTrees:
+    """Sample a tree from the BART prior.
+
+    Parameters
+    ----------
+    key
+        A jax random key.
+    max_split
+        The maximum split value for each variable.
+    p_nonterminal
+        The prior probability of a node being non-terminal conditional on
+        its ancestors and on having available decision rules, at each depth.
+    sigma_mu
+        The prior standard deviation of each leaf.
+
+    Returns
+    -------
+    An object containing a generated tree.
+    """
+    carry = SamplePriorCarry.initial(key, sigma_mu, p_nonterminal, max_split)
+    xs = SamplePriorX.initial(p_nonterminal)
+
+    def loop(carry: SamplePriorCarry, x: SamplePriorX):
+        keys = split_key(carry.key, 4)
+
+        # get variables at current stack level
+        stack = carry.stack
+        nonterminal = stack.nonterminal[x.depth]
+        lower = stack.lower[x.depth, :]
+        upper = stack.upper[x.depth, :]
+
+        # sample a random decision rule
+        available: Bool[Array, ' p'] = lower < upper
+        allowed = jnp.any(available)
+        var = randint_masked(keys.pop(), available)
+        split = 1 + random.randint(keys.pop(), (), lower[var], upper[var])
+
+        # cast to shorter integer types
+        var = var.astype(carry.trees.var_tree.dtype)
+        split = split.astype(carry.trees.split_tree.dtype)
+
+        # decide whether to try to grow the node if it is growable
+        pnt = p_nonterminal[x.depth]
+        try_nonterminal: Bool[Array, ''] = random.bernoulli(keys.pop(), pnt)
+        nonterminal &= try_nonterminal & allowed
+
+        # update trees
+        trees = carry.trees
+        trees = replace(
+            trees,
+            var_tree=trees.var_tree.at[x.node].set(var),
+            split_tree=trees.split_tree.at[x.node].set(
+                jnp.where(nonterminal, split, 0)
+            ),
+        )
+
+        def write_push_stack() -> SamplePriorStack:
+            """Update the stack to go to the left child."""
+            return replace(
+                stack,
+                nonterminal=stack.nonterminal.at[x.next_depth].set(nonterminal),
+                lower=stack.lower.at[x.next_depth, :].set(lower),
+                upper=stack.upper.at[x.next_depth, :].set(upper.at[var].set(split - 1)),
+                var=stack.var.at[x.depth].set(var),
+                split=stack.split.at[x.depth].set(split),
+            )
+
+        def pop_push_stack() -> SamplePriorStack:
+            """Update the stack to go to the right sibling, possibly at lower depth."""
+            var = stack.var[x.next_depth - 1]
+            split = stack.split[x.next_depth - 1]
+            lower = stack.lower[x.next_depth - 1, :]
+            upper = stack.upper[x.next_depth - 1, :]
+            return replace(
+                stack,
+                lower=stack.lower.at[x.next_depth, :].set(lower.at[var].set(split)),
+                upper=stack.upper.at[x.next_depth, :].set(upper),
+            )
+
+        # update stack
+        stack = lax.cond(x.next_depth > x.depth, write_push_stack, pop_push_stack)
+
+        # update carry
+        carry = replace(carry, key=keys.pop(), stack=stack, trees=trees)
+        return carry, None
+
+    carry, _ = lax.scan(loop, carry, xs)
+    return carry.trees
+
+
+@partial(vmap_nodoc, in_axes=(0, None, None, None))
+def sample_prior_forest(
+    keys: Key[Array, ' num_trees'],
+    max_split: UInt[Array, ' p'],
+    p_nonterminal: Float32[Array, ' d-1'],
+    sigma_mu: Float32[Array, ''],
+) -> SamplePriorTrees:
+    """Sample a set of independent trees from the BART prior.
+
+    Parameters
+    ----------
+    keys
+        A sequence of jax random keys, one for each tree. This determined the
+        number of trees sampled.
+    max_split
+        The maximum split value for each variable.
+    p_nonterminal
+        The prior probability of a node being non-terminal conditional on
+        its ancestors and on having available decision rules, at each depth.
+    sigma_mu
+        The prior standard deviation of each leaf.
+
+    Returns
+    -------
+    An object containing the generated trees.
+    """
+    return sample_prior_onetree(keys, max_split, p_nonterminal, sigma_mu)
+
+
+@partial(jit, static_argnums=(1, 2))
+def sample_prior(
+    key: Key[Array, ''],
+    trace_length: int,
+    num_trees: int,
+    max_split: UInt[Array, ' p'],
+    p_nonterminal: Float32[Array, ' d-1'],
+    sigma_mu: Float32[Array, ''],
+) -> SamplePriorTrees:
+    """Sample independent trees from the BART prior.
+
+    Parameters
+    ----------
+    key
+        A jax random key.
+    trace_length
+        The number of iterations.
+    num_trees
+        The number of trees for each iteration.
+    max_split
+        The number of cutpoints along each variable.
+    p_nonterminal
+        The prior probability of a node being non-terminal conditional on
+        its ancestors and on having available decision rules, at each depth.
+        This determines the maximum depth of the trees.
+    sigma_mu
+        The prior standard deviation of each leaf.
+
+    Returns
+    -------
+    An object containing the generated trees, with batch shape (trace_length, num_trees).
+    """
+    keys = random.split(key, trace_length * num_trees)
+    trees = sample_prior_forest(keys, max_split, p_nonterminal, sigma_mu)
+    return tree_map(lambda x: x.reshape(trace_length, num_trees, -1), trees)
+
+
+class debug_mc_gbart(mc_gbart):
+    """A subclass of `mc_gbart` that adds debugging functionality.
+
+    Parameters
+    ----------
+    *args
+        Passed to `mc_gbart`.
+    check_trees
+        If `True`, check all trees with `check_trace` after running the MCMC,
+        and assert that they are all valid. Set to `False` to allow jax tracing.
+    **kw
+        Passed to `mc_gbart`.
+    """
+
+    def __init__(self, *args, check_trees: bool = True, **kw):
+        super().__init__(*args, **kw)
+        if check_trees:
+            bad = self.check_trees()
+            bad_count = jnp.count_nonzero(bad)
+            assert bad_count == 0
+
+    def print_tree(
+        self, i_chain: int, i_sample: int, i_tree: int, print_all: bool = False
+    ):
+        """Print a single tree in human-readable format.
+
+        Parameters
+        ----------
+        i_chain
+            The index of the MCMC chain.
+        i_sample
+            The index of the (post-burnin) sample in the chain.
+        i_tree
+            The index of the tree in the sample.
+        print_all
+            If `True`, also print the content of unused node slots.
+        """
+        tree = TreesTrace.from_dataclass(self._main_trace)
+        tree = tree_map(lambda x: x[i_chain, i_sample, i_tree, :], tree)
+        s = format_tree(tree, print_all=print_all)
+        print(s)  # noqa: T201, this method is intended for debug
+
+    def sigma_harmonic_mean(self, prior: bool = False) -> Float32[Array, ' mc_cores']:
+        """Return the harmonic mean of the error variance.
+
+        Parameters
+        ----------
+        prior
+            If `True`, use the prior distribution, otherwise use the full
+            conditional at the last MCMC iteration.
+
+        Returns
+        -------
+        The harmonic mean 1/E[1/sigma^2] in the selected distribution.
+        """
+        bart = self._mcmc_state
+        assert bart.error_cov_df is not None
+        assert bart.z is None
+        # inverse gamma prior: alpha = df / 2, beta = scale / 2
+        if prior:
+            alpha = bart.error_cov_df / 2
+            beta = bart.error_cov_scale / 2
+        else:
+            alpha = bart.error_cov_df / 2 + bart.resid.size / 2
+            norm2 = jnp.einsum('ij,ij->i', bart.resid, bart.resid)
+            beta = bart.error_cov_scale / 2 + norm2 / 2
+        error_cov_inv = alpha / beta
+        return jnp.sqrt(lax.reciprocal(error_cov_inv))
+
+    def compare_resid(
+        self,
+    ) -> tuple[Float32[Array, 'mc_cores n'], Float32[Array, 'mc_cores n']]:
+        """Re-compute residuals to compare them with the updated ones.
+
+        Returns
+        -------
+        resid1 : Float32[Array, 'mc_cores n']
+            The final state of the residuals updated during the MCMC.
+        resid2 : Float32[Array, 'mc_cores n']
+            The residuals computed from the final state of the trees.
+        """
+        bart = self._mcmc_state
+        resid1 = bart.resid
+
+        forests = TreesTrace.from_dataclass(bart.forest)
+        trees = evaluate_forest(bart.X, forests, sum_batch_axis=-1)
+
+        if bart.z is not None:
+            ref = bart.z
+        else:
+            ref = bart.y
+        resid2 = ref - (trees + bart.offset)
+
+        return resid1, resid2
+
+    def avg_acc(
+        self,
+    ) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]:
+        """Compute the average acceptance rates of tree moves.
+
+        Returns
+        -------
+        acc_grow : Float32[Array, 'mc_cores']
+            The average acceptance rate of grow moves.
+        acc_prune : Float32[Array, 'mc_cores']
+            The average acceptance rate of prune moves.
+        """
+        trace = self._main_trace
+
+        def acc(prefix):
+            acc = getattr(trace, f'{prefix}_acc_count')
+            prop = getattr(trace, f'{prefix}_prop_count')
+            return acc.sum(axis=1) / prop.sum(axis=1)
+
+        return acc('grow'), acc('prune')
+
+    def avg_prop(
+        self,
+    ) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]:
+        """Compute the average proposal rate of grow and prune moves.
+
+        Returns
+        -------
+        prop_grow : Float32[Array, 'mc_cores']
+            The fraction of times grow was proposed instead of prune.
+        prop_prune : Float32[Array, 'mc_cores']
+            The fraction of times prune was proposed instead of grow.
+
+        Notes
+        -----
+        This function does not take into account cases where no move was
+        proposed.
+        """
+        trace = self._main_trace
+
+        def prop(prefix):
+            return getattr(trace, f'{prefix}_prop_count').sum(axis=1)
+
+        pgrow = prop('grow')
+        pprune = prop('prune')
+        total = pgrow + pprune
+        return pgrow / total, pprune / total
+
+    def avg_move(
+        self,
+    ) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]:
+        """Compute the move rate.
+
+        Returns
+        -------
+        rate_grow : Float32[Array, 'mc_cores']
+            The fraction of times a grow move was proposed and accepted.
+        rate_prune : Float32[Array, 'mc_cores']
+            The fraction of times a prune move was proposed and accepted.
+        """
+        agrow, aprune = self.avg_acc()
+        pgrow, pprune = self.avg_prop()
+        return agrow * pgrow, aprune * pprune
+
+    def depth_distr(self) -> Int32[Array, 'mc_cores ndpost/mc_cores d']:
+        """Histogram of tree depths for each state of the trees.
+
+        Returns
+        -------
+        A matrix where each row contains a histogram of tree depths.
+        """
+        out: Int32[Array, '*chains samples d']
+        out = forest_depth_distr(self._main_trace.split_tree)
+        if out.ndim < 3:
+            out = out[None, :, :]
+        return out
+
+    def _points_per_node_distr(
+        self, node_type: str
+    ) -> Int32[Array, 'mc_cores ndpost/mc_cores n+1']:
+        out: Int32[Array, '*chains samples n+1']
+        out = points_per_node_distr(
+            self._mcmc_state.X,
+            self._main_trace.var_tree,
+            self._main_trace.split_tree,
+            node_type,
+            sum_batch_axis=-1,
+        )
+        if out.ndim < 3:
+            out = out[None, :, :]
+        return out
+
+    def points_per_decision_node_distr(
+        self,
+    ) -> Int32[Array, 'mc_cores ndpost/mc_cores n+1']:
+        """Histogram of number of points belonging to parent-of-leaf nodes.
+
+        Returns
+        -------
+        For each chain, a matrix where each row contains a histogram of number of points.
+        """
+        return self._points_per_node_distr('leaf-parent')
+
+    def points_per_leaf_distr(self) -> Int32[Array, 'mc_cores ndpost/mc_cores n+1']:
+        """Histogram of number of points belonging to leaves.
+
+        Returns
+        -------
+        A matrix where each row contains a histogram of number of points.
+        """
+        return self._points_per_node_distr('leaf')
+
+    def check_trees(self) -> UInt[Array, 'mc_cores ndpost/mc_cores ntree']:
+        """Apply `check_trace` to all the tree draws."""
+        out: UInt[Array, '*chains samples num_trees']
+        out = check_trace(self._main_trace, self._mcmc_state.forest.max_split)
+        if out.ndim < 3:
+            out = out[None, :, :]
+        return out
+
+    def tree_goes_bad(self) -> Bool[Array, 'mc_cores ndpost/mc_cores ntree']:
+        """Find iterations where a tree becomes invalid.
+
+        Returns
+        -------
+        A where (i,j) is `True` if tree j is invalid at iteration i but not i-1.
+        """
+        bad = self.check_trees().astype(bool)
+        bad_before = jnp.pad(bad[:, :-1, :], [(0, 0), (1, 0), (0, 0)])
+        return bad & ~bad_before
+
+
+class debug_gbart(debug_mc_gbart, gbart):
+    """A subclass of `gbart` that adds debugging functionality.
+
+    Parameters
+    ----------
+    *args
+        Passed to `gbart`.
+    check_trees
+        If `True`, check all trees with `check_trace` after running the MCMC,
+        and assert that they are all valid. Set to `False` to allow jax tracing.
+    **kw
+        Passed to `gbart`.
+    """