bartz 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/BART.py +582 -279
- bartz/__init__.py +3 -3
- bartz/_version.py +1 -1
- bartz/debug.py +1259 -79
- bartz/grove.py +168 -81
- bartz/jaxext/__init__.py +213 -0
- bartz/jaxext/_autobatch.py +238 -0
- bartz/jaxext/scipy/__init__.py +25 -0
- bartz/jaxext/scipy/special.py +240 -0
- bartz/jaxext/scipy/stats.py +36 -0
- bartz/mcmcloop.py +568 -158
- bartz/mcmcstep.py +1722 -926
- bartz/prepcovars.py +142 -44
- {bartz-0.5.0.dist-info → bartz-0.7.0.dist-info}/METADATA +6 -5
- bartz-0.7.0.dist-info/RECORD +17 -0
- {bartz-0.5.0.dist-info → bartz-0.7.0.dist-info}/WHEEL +1 -1
- bartz/jaxext.py +0 -374
- bartz-0.5.0.dist-info/RECORD +0 -13
bartz/debug.py
CHANGED
|
@@ -1,13 +1,72 @@
|
|
|
1
|
-
|
|
1
|
+
# bartz/src/bartz/debug.py
|
|
2
|
+
#
|
|
3
|
+
# Copyright (c) 2024-2025, Giacomo Petrillo
|
|
4
|
+
#
|
|
5
|
+
# This file is part of bartz.
|
|
6
|
+
#
|
|
7
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
8
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
9
|
+
# in the Software without restriction, including without limitation the rights
|
|
10
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
11
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
12
|
+
# furnished to do so, subject to the following conditions:
|
|
13
|
+
#
|
|
14
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
15
|
+
# copies or substantial portions of the Software.
|
|
16
|
+
#
|
|
17
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
18
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
19
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
20
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
21
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
22
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
23
|
+
# SOFTWARE.
|
|
2
24
|
|
|
3
|
-
|
|
4
|
-
|
|
25
|
+
"""Debugging utilities. The entry point is the class `debug_gbart`."""
|
|
26
|
+
|
|
27
|
+
from collections.abc import Callable
|
|
28
|
+
from dataclasses import replace
|
|
29
|
+
from functools import partial
|
|
30
|
+
from math import ceil, log2
|
|
31
|
+
from re import fullmatch
|
|
32
|
+
|
|
33
|
+
import numpy
|
|
34
|
+
from equinox import Module, field
|
|
35
|
+
from jax import jit, lax, random, vmap
|
|
5
36
|
from jax import numpy as jnp
|
|
37
|
+
from jax.tree_util import tree_map
|
|
38
|
+
from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, UInt
|
|
39
|
+
|
|
40
|
+
from bartz.BART import FloatLike, gbart
|
|
41
|
+
from bartz.grove import (
|
|
42
|
+
TreeHeaps,
|
|
43
|
+
evaluate_forest,
|
|
44
|
+
is_actual_leaf,
|
|
45
|
+
is_leaves_parent,
|
|
46
|
+
traverse_tree,
|
|
47
|
+
tree_depth,
|
|
48
|
+
tree_depths,
|
|
49
|
+
)
|
|
50
|
+
from bartz.jaxext import minimal_unsigned_dtype, vmap_nodoc
|
|
51
|
+
from bartz.jaxext import split as split_key
|
|
52
|
+
from bartz.mcmcloop import TreesTrace
|
|
53
|
+
from bartz.mcmcstep import randint_masked
|
|
54
|
+
|
|
6
55
|
|
|
7
|
-
|
|
56
|
+
def format_tree(tree: TreeHeaps, *, print_all: bool = False) -> str:
|
|
57
|
+
"""Convert a tree to a human-readable string.
|
|
8
58
|
|
|
59
|
+
Parameters
|
|
60
|
+
----------
|
|
61
|
+
tree
|
|
62
|
+
A single tree to format.
|
|
63
|
+
print_all
|
|
64
|
+
If `True`, also print the contents of unused node slots in the arrays.
|
|
9
65
|
|
|
10
|
-
|
|
66
|
+
Returns
|
|
67
|
+
-------
|
|
68
|
+
A string representation of the tree.
|
|
69
|
+
"""
|
|
11
70
|
tee = '├──'
|
|
12
71
|
corner = '└──'
|
|
13
72
|
join = '│ '
|
|
@@ -15,12 +74,20 @@ def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
|
|
|
15
74
|
down = '┐'
|
|
16
75
|
bottom = '╢' # '┨' #
|
|
17
76
|
|
|
18
|
-
def traverse_tree(
|
|
19
|
-
|
|
77
|
+
def traverse_tree(
|
|
78
|
+
lines: list[str],
|
|
79
|
+
index: int,
|
|
80
|
+
depth: int,
|
|
81
|
+
indent: str,
|
|
82
|
+
first_indent: str,
|
|
83
|
+
next_indent: str,
|
|
84
|
+
unused: bool,
|
|
85
|
+
):
|
|
86
|
+
if index >= len(tree.leaf_tree):
|
|
20
87
|
return
|
|
21
88
|
|
|
22
|
-
var = var_tree.at[index].get(mode='fill', fill_value=0)
|
|
23
|
-
split = split_tree.at[index].get(mode='fill', fill_value=0)
|
|
89
|
+
var: int = tree.var_tree.at[index].get(mode='fill', fill_value=0).item()
|
|
90
|
+
split: int = tree.split_tree.at[index].get(mode='fill', fill_value=0).item()
|
|
24
91
|
|
|
25
92
|
is_leaf = split == 0
|
|
26
93
|
left_child = 2 * index
|
|
@@ -33,26 +100,26 @@ def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
|
|
|
33
100
|
category = 'leaf'
|
|
34
101
|
else:
|
|
35
102
|
category = 'decision'
|
|
36
|
-
node_str = f'{category}({var}, {split}, {leaf_tree[index]})'
|
|
103
|
+
node_str = f'{category}({var}, {split}, {tree.leaf_tree[index]})'
|
|
37
104
|
else:
|
|
38
105
|
assert not unused
|
|
39
106
|
if is_leaf:
|
|
40
|
-
node_str = f'{leaf_tree[index]:#.2g}'
|
|
107
|
+
node_str = f'{tree.leaf_tree[index]:#.2g}'
|
|
41
108
|
else:
|
|
42
|
-
node_str = f'
|
|
109
|
+
node_str = f'x{var} < {split}'
|
|
43
110
|
|
|
44
|
-
if not is_leaf or (print_all and left_child < len(leaf_tree)):
|
|
111
|
+
if not is_leaf or (print_all and left_child < len(tree.leaf_tree)):
|
|
45
112
|
link = down
|
|
46
|
-
elif not print_all and left_child >= len(leaf_tree):
|
|
113
|
+
elif not print_all and left_child >= len(tree.leaf_tree):
|
|
47
114
|
link = bottom
|
|
48
115
|
else:
|
|
49
116
|
link = ' '
|
|
50
117
|
|
|
51
|
-
max_number = len(leaf_tree) - 1
|
|
118
|
+
max_number = len(tree.leaf_tree) - 1
|
|
52
119
|
ndigits = len(str(max_number))
|
|
53
120
|
number = str(index).rjust(ndigits)
|
|
54
121
|
|
|
55
|
-
|
|
122
|
+
lines.append(f' {number} {indent}{first_indent}{link}{node_str}')
|
|
56
123
|
|
|
57
124
|
indent += next_indent
|
|
58
125
|
unused = unused or is_leaf
|
|
@@ -60,125 +127,1238 @@ def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
|
|
|
60
127
|
if unused and not print_all:
|
|
61
128
|
return
|
|
62
129
|
|
|
63
|
-
traverse_tree(left_child, depth + 1, indent, tee, join, unused)
|
|
64
|
-
traverse_tree(right_child, depth + 1, indent, corner, space, unused)
|
|
130
|
+
traverse_tree(lines, left_child, depth + 1, indent, tee, join, unused)
|
|
131
|
+
traverse_tree(lines, right_child, depth + 1, indent, corner, space, unused)
|
|
65
132
|
|
|
66
|
-
|
|
133
|
+
lines = []
|
|
134
|
+
traverse_tree(lines, 1, 0, '', '', '', False)
|
|
135
|
+
return '\n'.join(lines)
|
|
67
136
|
|
|
68
137
|
|
|
69
|
-
def tree_actual_depth(split_tree):
|
|
70
|
-
|
|
71
|
-
|
|
138
|
+
def tree_actual_depth(split_tree: UInt[Array, ' 2**(d-1)']) -> Int32[Array, '']:
|
|
139
|
+
"""Measure the depth of the tree.
|
|
140
|
+
|
|
141
|
+
Parameters
|
|
142
|
+
----------
|
|
143
|
+
split_tree
|
|
144
|
+
The cutpoints of the decision rules.
|
|
145
|
+
|
|
146
|
+
Returns
|
|
147
|
+
-------
|
|
148
|
+
The depth of the deepest leaf in the tree. The root is at depth 0.
|
|
149
|
+
"""
|
|
150
|
+
# this could be done just with split_tree != 0
|
|
151
|
+
is_leaf = is_actual_leaf(split_tree, add_bottom_level=True)
|
|
152
|
+
depth = tree_depths(is_leaf.size)
|
|
72
153
|
depth = jnp.where(is_leaf, depth, 0)
|
|
73
154
|
return jnp.max(depth)
|
|
74
155
|
|
|
75
156
|
|
|
76
|
-
def forest_depth_distr(
|
|
77
|
-
|
|
78
|
-
|
|
157
|
+
def forest_depth_distr(
|
|
158
|
+
split_tree: UInt[Array, 'num_trees 2**(d-1)'],
|
|
159
|
+
) -> Int32[Array, ' d']:
|
|
160
|
+
"""Histogram the depths of a set of trees.
|
|
161
|
+
|
|
162
|
+
Parameters
|
|
163
|
+
----------
|
|
164
|
+
split_tree
|
|
165
|
+
The cutpoints of the decision rules of the trees.
|
|
166
|
+
|
|
167
|
+
Returns
|
|
168
|
+
-------
|
|
169
|
+
An integer vector where the i-th element counts how many trees have depth i.
|
|
170
|
+
"""
|
|
171
|
+
depth = tree_depth(split_tree) + 1
|
|
172
|
+
depths = vmap(tree_actual_depth)(split_tree)
|
|
79
173
|
return jnp.bincount(depths, length=depth)
|
|
80
174
|
|
|
81
175
|
|
|
82
|
-
|
|
83
|
-
|
|
176
|
+
@jit
|
|
177
|
+
def trace_depth_distr(
|
|
178
|
+
split_tree: UInt[Array, 'trace_length num_trees 2**(d-1)'],
|
|
179
|
+
) -> Int32[Array, 'trace_length d']:
|
|
180
|
+
"""Histogram the depths of a sequence of sets of trees.
|
|
84
181
|
|
|
182
|
+
Parameters
|
|
183
|
+
----------
|
|
184
|
+
split_tree
|
|
185
|
+
The cutpoints of the decision rules of the trees.
|
|
85
186
|
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
187
|
+
Returns
|
|
188
|
+
-------
|
|
189
|
+
A matrix where element (t,i) counts how many trees have depth i in set t.
|
|
190
|
+
"""
|
|
191
|
+
return vmap(forest_depth_distr)(split_tree)
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def points_per_decision_node_distr(
|
|
195
|
+
var_tree: UInt[Array, ' 2**(d-1)'],
|
|
196
|
+
split_tree: UInt[Array, ' 2**(d-1)'],
|
|
197
|
+
X: UInt[Array, 'p n'],
|
|
198
|
+
) -> Int32[Array, ' n+1']:
|
|
199
|
+
"""Histogram points-per-node counts.
|
|
200
|
+
|
|
201
|
+
Count how many parent-of-leaf nodes in a tree select each possible amount
|
|
202
|
+
of points.
|
|
203
|
+
|
|
204
|
+
Parameters
|
|
205
|
+
----------
|
|
206
|
+
var_tree
|
|
207
|
+
The variables of the decision rules.
|
|
208
|
+
split_tree
|
|
209
|
+
The cutpoints of the decision rules.
|
|
210
|
+
X
|
|
211
|
+
The set of points to count.
|
|
212
|
+
|
|
213
|
+
Returns
|
|
214
|
+
-------
|
|
215
|
+
A vector where the i-th element counts how many next-to-leaf nodes have i points.
|
|
216
|
+
"""
|
|
217
|
+
traverse_tree_X = vmap(traverse_tree, in_axes=(1, None, None))
|
|
218
|
+
indices = traverse_tree_X(X, var_tree, split_tree)
|
|
219
|
+
indices >>= 1
|
|
220
|
+
count_tree = jnp.zeros(split_tree.size, int).at[indices].add(1).at[0].set(0)
|
|
221
|
+
is_parent = is_leaves_parent(split_tree)
|
|
222
|
+
return jnp.zeros(X.shape[1] + 1, int).at[count_tree].add(is_parent)
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def forest_points_per_decision_node_distr(
|
|
226
|
+
trees: TreeHeaps, X: UInt[Array, 'p n']
|
|
227
|
+
) -> Int32[Array, ' n+1']:
|
|
228
|
+
"""Histogram points-per-node counts for a set of trees.
|
|
229
|
+
|
|
230
|
+
Count how many parent-of-leaf nodes in a set of trees select each possible
|
|
231
|
+
amount of points.
|
|
232
|
+
|
|
233
|
+
Parameters
|
|
234
|
+
----------
|
|
235
|
+
trees
|
|
236
|
+
The set of trees. The variables must have broadcast shape (num_trees,).
|
|
237
|
+
X
|
|
238
|
+
The set of points to count.
|
|
239
|
+
|
|
240
|
+
Returns
|
|
241
|
+
-------
|
|
242
|
+
A vector where the i-th element counts how many next-to-leaf nodes have i points.
|
|
243
|
+
"""
|
|
244
|
+
distr = jnp.zeros(X.shape[1] + 1, int)
|
|
245
|
+
|
|
246
|
+
def loop(distr, heaps: tuple[Array, Array]):
|
|
247
|
+
return distr + points_per_decision_node_distr(*heaps, X), None
|
|
248
|
+
|
|
249
|
+
distr, _ = lax.scan(loop, distr, (trees.var_tree, trees.split_tree))
|
|
250
|
+
return distr
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
@jit
|
|
254
|
+
def trace_points_per_decision_node_distr(
|
|
255
|
+
trace: TreeHeaps, X: UInt[Array, 'p n']
|
|
256
|
+
) -> Int32[Array, 'trace_length n+1']:
|
|
257
|
+
"""Separately histogram points-per-node counts over a sequence of sets of trees.
|
|
95
258
|
|
|
259
|
+
For each set of trees, count how many parent-of-leaf nodes select each
|
|
260
|
+
possible amount of points.
|
|
96
261
|
|
|
97
|
-
|
|
262
|
+
Parameters
|
|
263
|
+
----------
|
|
264
|
+
trace
|
|
265
|
+
The sequence of sets of trees. The variables must have broadcast shape
|
|
266
|
+
(trace_length, num_trees).
|
|
267
|
+
X
|
|
268
|
+
The set of points to count.
|
|
269
|
+
|
|
270
|
+
Returns
|
|
271
|
+
-------
|
|
272
|
+
A matrix where element (t,i) counts how many next-to-leaf nodes have i points in set t.
|
|
273
|
+
"""
|
|
274
|
+
|
|
275
|
+
def loop(_, trace):
|
|
276
|
+
return None, forest_points_per_decision_node_distr(trace, X)
|
|
277
|
+
|
|
278
|
+
_, distr = lax.scan(loop, None, trace)
|
|
279
|
+
return distr
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def points_per_leaf_distr(
|
|
283
|
+
var_tree: UInt[Array, ' 2**(d-1)'],
|
|
284
|
+
split_tree: UInt[Array, ' 2**(d-1)'],
|
|
285
|
+
X: UInt[Array, 'p n'],
|
|
286
|
+
) -> Int32[Array, ' n+1']:
|
|
287
|
+
"""Histogram points-per-leaf counts in a tree.
|
|
288
|
+
|
|
289
|
+
Count how many leaves in a tree select each possible amount of points.
|
|
290
|
+
|
|
291
|
+
Parameters
|
|
292
|
+
----------
|
|
293
|
+
var_tree
|
|
294
|
+
The variables of the decision rules.
|
|
295
|
+
split_tree
|
|
296
|
+
The cutpoints of the decision rules.
|
|
297
|
+
X
|
|
298
|
+
The set of points to count.
|
|
299
|
+
|
|
300
|
+
Returns
|
|
301
|
+
-------
|
|
302
|
+
A vector where the i-th element counts how many leaves have i points.
|
|
303
|
+
"""
|
|
304
|
+
traverse_tree_X = vmap(traverse_tree, in_axes=(1, None, None))
|
|
305
|
+
indices = traverse_tree_X(X, var_tree, split_tree)
|
|
306
|
+
count_tree = jnp.zeros(2 * split_tree.size, int).at[indices].add(1)
|
|
307
|
+
is_leaf = is_actual_leaf(split_tree, add_bottom_level=True)
|
|
308
|
+
return jnp.zeros(X.shape[1] + 1, int).at[count_tree].add(is_leaf)
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
def forest_points_per_leaf_distr(
|
|
312
|
+
trees: TreeHeaps, X: UInt[Array, 'p n']
|
|
313
|
+
) -> Int32[Array, ' n+1']:
|
|
314
|
+
"""Histogram points-per-leaf counts over a set of trees.
|
|
315
|
+
|
|
316
|
+
Count how many leaves in a set of trees select each possible amount of points.
|
|
317
|
+
|
|
318
|
+
Parameters
|
|
319
|
+
----------
|
|
320
|
+
trees
|
|
321
|
+
The set of trees. The variables must have broadcast shape (num_trees,).
|
|
322
|
+
X
|
|
323
|
+
The set of points to count.
|
|
324
|
+
|
|
325
|
+
Returns
|
|
326
|
+
-------
|
|
327
|
+
A vector where the i-th element counts how many leaves have i points.
|
|
328
|
+
"""
|
|
98
329
|
distr = jnp.zeros(X.shape[1] + 1, int)
|
|
99
|
-
trees = bart['var_trees'], bart['split_trees']
|
|
100
330
|
|
|
101
|
-
def loop(distr,
|
|
102
|
-
return distr + points_per_leaf_distr(*
|
|
331
|
+
def loop(distr, heaps: tuple[Array, Array]):
|
|
332
|
+
return distr + points_per_leaf_distr(*heaps, X), None
|
|
103
333
|
|
|
104
|
-
distr, _ = lax.scan(loop, distr, trees)
|
|
334
|
+
distr, _ = lax.scan(loop, distr, (trees.var_tree, trees.split_tree))
|
|
105
335
|
return distr
|
|
106
336
|
|
|
107
337
|
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
338
|
+
@jit
|
|
339
|
+
def trace_points_per_leaf_distr(
|
|
340
|
+
trace: TreeHeaps, X: UInt[Array, 'p n']
|
|
341
|
+
) -> Int32[Array, 'trace_length n+1']:
|
|
342
|
+
"""Separately histogram points-per-leaf counts over a sequence of sets of trees.
|
|
343
|
+
|
|
344
|
+
For each set of trees, count how many leaves select each possible amount of
|
|
345
|
+
points.
|
|
346
|
+
|
|
347
|
+
Parameters
|
|
348
|
+
----------
|
|
349
|
+
trace
|
|
350
|
+
The sequence of sets of trees. The variables must have broadcast shape
|
|
351
|
+
(trace_length, num_trees).
|
|
352
|
+
X
|
|
353
|
+
The set of points to count.
|
|
111
354
|
|
|
112
|
-
|
|
355
|
+
Returns
|
|
356
|
+
-------
|
|
357
|
+
A matrix where element (t,i) counts how many leaves have i points in set t.
|
|
358
|
+
"""
|
|
359
|
+
|
|
360
|
+
def loop(_, trace):
|
|
361
|
+
return None, forest_points_per_leaf_distr(trace, X)
|
|
362
|
+
|
|
363
|
+
_, distr = lax.scan(loop, None, trace)
|
|
113
364
|
return distr
|
|
114
365
|
|
|
115
366
|
|
|
116
|
-
|
|
117
|
-
|
|
367
|
+
check_functions = []
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
CheckFunc = Callable[[TreeHeaps, UInt[Array, ' p']], bool | Bool[Array, '']]
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
def check(func: CheckFunc) -> CheckFunc:
|
|
374
|
+
"""Add a function to a list of functions used to check trees.
|
|
375
|
+
|
|
376
|
+
Use to decorate functions that check whether a tree is valid in some way.
|
|
377
|
+
These functions are invoked automatically by `check_tree`, `check_trace` and
|
|
378
|
+
`debug_gbart`.
|
|
379
|
+
|
|
380
|
+
Parameters
|
|
381
|
+
----------
|
|
382
|
+
func
|
|
383
|
+
The function to add to the list. It must accept a `TreeHeaps` and a
|
|
384
|
+
`max_split` argument, and return a boolean scalar that indicates if the
|
|
385
|
+
tree is ok.
|
|
386
|
+
|
|
387
|
+
Returns
|
|
388
|
+
-------
|
|
389
|
+
The function unchanged.
|
|
390
|
+
"""
|
|
391
|
+
check_functions.append(func)
|
|
392
|
+
return func
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
@check
|
|
396
|
+
def check_types(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> bool:
|
|
397
|
+
"""Check that integer types are as small as possible and coherent."""
|
|
398
|
+
expected_var_dtype = minimal_unsigned_dtype(max_split.size - 1)
|
|
118
399
|
expected_split_dtype = max_split.dtype
|
|
119
400
|
return (
|
|
120
|
-
var_tree.dtype == expected_var_dtype
|
|
121
|
-
and split_tree.dtype == expected_split_dtype
|
|
401
|
+
tree.var_tree.dtype == expected_var_dtype
|
|
402
|
+
and tree.split_tree.dtype == expected_split_dtype
|
|
122
403
|
)
|
|
123
404
|
|
|
124
405
|
|
|
125
|
-
|
|
126
|
-
|
|
406
|
+
@check
|
|
407
|
+
def check_sizes(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> bool: # noqa: ARG001
|
|
408
|
+
"""Check that array sizes are coherent."""
|
|
409
|
+
return tree.leaf_tree.size == 2 * tree.var_tree.size == 2 * tree.split_tree.size
|
|
127
410
|
|
|
128
411
|
|
|
129
|
-
|
|
130
|
-
|
|
412
|
+
@check
|
|
413
|
+
def check_unused_node(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']: # noqa: ARG001
|
|
414
|
+
"""Check that the unused node slot at index 0 is not dirty."""
|
|
415
|
+
return (tree.var_tree[0] == 0) & (tree.split_tree[0] == 0)
|
|
131
416
|
|
|
132
417
|
|
|
133
|
-
|
|
134
|
-
|
|
418
|
+
@check
|
|
419
|
+
def check_leaf_values(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']: # noqa: ARG001
|
|
420
|
+
"""Check that all leaf values are not inf or nan."""
|
|
421
|
+
return jnp.all(jnp.isfinite(tree.leaf_tree))
|
|
135
422
|
|
|
136
423
|
|
|
137
|
-
|
|
424
|
+
@check
|
|
425
|
+
def check_stray_nodes(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']: # noqa: ARG001
|
|
426
|
+
"""Check that there is no marked-non-leaf node with a marked-leaf parent."""
|
|
138
427
|
index = jnp.arange(
|
|
139
|
-
2 * split_tree.size,
|
|
140
|
-
dtype=
|
|
428
|
+
2 * tree.split_tree.size,
|
|
429
|
+
dtype=minimal_unsigned_dtype(2 * tree.split_tree.size - 1),
|
|
141
430
|
)
|
|
142
431
|
parent_index = index >> 1
|
|
143
|
-
is_not_leaf = split_tree.at[index].get(mode='fill', fill_value=0) != 0
|
|
144
|
-
parent_is_leaf = split_tree[parent_index] == 0
|
|
432
|
+
is_not_leaf = tree.split_tree.at[index].get(mode='fill', fill_value=0) != 0
|
|
433
|
+
parent_is_leaf = tree.split_tree[parent_index] == 0
|
|
145
434
|
stray = is_not_leaf & parent_is_leaf
|
|
146
435
|
stray = stray.at[1].set(False)
|
|
147
436
|
return ~jnp.any(stray)
|
|
148
437
|
|
|
149
438
|
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
439
|
+
@check
|
|
440
|
+
def check_rule_consistency(
|
|
441
|
+
tree: TreeHeaps, max_split: UInt[Array, ' p']
|
|
442
|
+
) -> bool | Bool[Array, '']:
|
|
443
|
+
"""Check that decision rules define proper subsets of ancestor rules."""
|
|
444
|
+
if tree.var_tree.size < 4:
|
|
445
|
+
return True
|
|
446
|
+
|
|
447
|
+
# initial boundaries of decision rules. use extreme integers instead of 0,
|
|
448
|
+
# max_split to avoid checking if there is something out of bounds.
|
|
449
|
+
small = jnp.iinfo(jnp.int32).min
|
|
450
|
+
large = jnp.iinfo(jnp.int32).max
|
|
451
|
+
lower = jnp.full(max_split.size, small, jnp.int32)
|
|
452
|
+
upper = jnp.full(max_split.size, large, jnp.int32)
|
|
453
|
+
# specify the type explicitly, otherwise they are weakly typed and get
|
|
454
|
+
# implicitly converted to split.dtype (typically uint8) in the expressions
|
|
455
|
+
|
|
456
|
+
def _check_recursive(node, lower, upper):
|
|
457
|
+
# read decision rule
|
|
458
|
+
var = tree.var_tree[node]
|
|
459
|
+
split = tree.split_tree[node]
|
|
460
|
+
|
|
461
|
+
# get rule boundaries from ancestors. use fill value in case var is
|
|
462
|
+
# out of bounds, we don't want to check out of bounds in this function
|
|
463
|
+
lower_var = lower.at[var].get(mode='fill', fill_value=small)
|
|
464
|
+
upper_var = upper.at[var].get(mode='fill', fill_value=large)
|
|
465
|
+
|
|
466
|
+
# check rule is in bounds
|
|
467
|
+
bad = jnp.where(split, (split <= lower_var) | (split >= upper_var), False)
|
|
468
|
+
|
|
469
|
+
# recurse
|
|
470
|
+
if node < tree.var_tree.size // 2:
|
|
471
|
+
bad |= _check_recursive(
|
|
472
|
+
2 * node,
|
|
473
|
+
lower,
|
|
474
|
+
upper.at[jnp.where(split, var, max_split.size)].set(split),
|
|
475
|
+
)
|
|
476
|
+
bad |= _check_recursive(
|
|
477
|
+
2 * node + 1,
|
|
478
|
+
lower.at[jnp.where(split, var, max_split.size)].set(split),
|
|
479
|
+
upper,
|
|
480
|
+
)
|
|
481
|
+
return bad
|
|
157
482
|
|
|
483
|
+
return ~_check_recursive(1, lower, upper)
|
|
158
484
|
|
|
159
|
-
|
|
160
|
-
|
|
485
|
+
|
|
486
|
+
@check
|
|
487
|
+
def check_num_nodes(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']: # noqa: ARG001
|
|
488
|
+
"""Check that #leaves = 1 + #(internal nodes)."""
|
|
489
|
+
is_leaf = is_actual_leaf(tree.split_tree, add_bottom_level=True)
|
|
490
|
+
num_leaves = jnp.count_nonzero(is_leaf)
|
|
491
|
+
num_internal = jnp.count_nonzero(tree.split_tree)
|
|
492
|
+
return num_leaves == num_internal + 1
|
|
493
|
+
|
|
494
|
+
|
|
495
|
+
@check
|
|
496
|
+
def check_var_in_bounds(
|
|
497
|
+
tree: TreeHeaps, max_split: UInt[Array, ' p']
|
|
498
|
+
) -> Bool[Array, '']:
|
|
499
|
+
"""Check that variables are in [0, max_split.size)."""
|
|
500
|
+
decision_node = tree.split_tree.astype(bool)
|
|
501
|
+
in_bounds = (tree.var_tree >= 0) & (tree.var_tree < max_split.size)
|
|
502
|
+
return jnp.all(in_bounds | ~decision_node)
|
|
503
|
+
|
|
504
|
+
|
|
505
|
+
@check
|
|
506
|
+
def check_split_in_bounds(
|
|
507
|
+
tree: TreeHeaps, max_split: UInt[Array, ' p']
|
|
508
|
+
) -> Bool[Array, '']:
|
|
509
|
+
"""Check that splits are in [0, max_split[var]]."""
|
|
510
|
+
max_split_var = (
|
|
511
|
+
max_split.astype(jnp.int32)
|
|
512
|
+
.at[tree.var_tree]
|
|
513
|
+
.get(mode='fill', fill_value=jnp.iinfo(jnp.int32).max)
|
|
514
|
+
)
|
|
515
|
+
return jnp.all((tree.split_tree >= 0) & (tree.split_tree <= max_split_var))
|
|
516
|
+
|
|
517
|
+
|
|
518
|
+
def check_tree(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> UInt[Array, '']:
|
|
519
|
+
"""Check the validity of a tree.
|
|
520
|
+
|
|
521
|
+
Use `describe_error` to parse the error code returned by this function.
|
|
522
|
+
|
|
523
|
+
Parameters
|
|
524
|
+
----------
|
|
525
|
+
tree
|
|
526
|
+
The tree to check.
|
|
527
|
+
max_split
|
|
528
|
+
The maximum split value for each variable.
|
|
529
|
+
|
|
530
|
+
Returns
|
|
531
|
+
-------
|
|
532
|
+
An integer where each bit indicates whether a check failed.
|
|
533
|
+
"""
|
|
534
|
+
error_type = minimal_unsigned_dtype(2 ** len(check_functions) - 1)
|
|
161
535
|
error = error_type(0)
|
|
162
536
|
for i, func in enumerate(check_functions):
|
|
163
|
-
ok = func(
|
|
537
|
+
ok = func(tree, max_split)
|
|
164
538
|
ok = jnp.bool_(ok)
|
|
165
539
|
bit = (~ok) << i
|
|
166
540
|
error |= bit
|
|
167
541
|
return error
|
|
168
542
|
|
|
169
543
|
|
|
170
|
-
def describe_error(error):
|
|
544
|
+
def describe_error(error: int | Integer[Array, '']) -> list[str]:
|
|
545
|
+
"""Describe the error code returned by `check_tree`.
|
|
546
|
+
|
|
547
|
+
Parameters
|
|
548
|
+
----------
|
|
549
|
+
error
|
|
550
|
+
The error code returned by `check_tree`.
|
|
551
|
+
|
|
552
|
+
Returns
|
|
553
|
+
-------
|
|
554
|
+
A list of the function names that implement the failed checks.
|
|
555
|
+
"""
|
|
171
556
|
return [func.__name__ for i, func in enumerate(check_functions) if error & (1 << i)]
|
|
172
557
|
|
|
173
558
|
|
|
174
|
-
|
|
559
|
+
@jit
|
|
560
|
+
@partial(vmap_nodoc, in_axes=(0, None))
|
|
561
|
+
def check_trace(
|
|
562
|
+
trace: TreeHeaps, max_split: UInt[Array, ' p']
|
|
563
|
+
) -> UInt[Array, 'trace_length num_trees']:
|
|
564
|
+
"""Check the validity of a sequence of sets of trees.
|
|
565
|
+
|
|
566
|
+
Use `describe_error` to parse the error codes returned by this function.
|
|
567
|
+
|
|
568
|
+
Parameters
|
|
569
|
+
----------
|
|
570
|
+
trace
|
|
571
|
+
The sequence of sets of trees to check. The tree arrays must have
|
|
572
|
+
broadcast shape (trace_length, num_trees). This object can have
|
|
573
|
+
additional attributes beyond the tree arrays; they are ignored.
|
|
574
|
+
max_split
|
|
575
|
+
The maximum split value for each variable.
|
|
576
|
+
|
|
577
|
+
Returns
|
|
578
|
+
-------
|
|
579
|
+
A matrix of error codes for each tree.
|
|
580
|
+
"""
|
|
581
|
+
trees = TreesTrace.from_dataclass(trace)
|
|
582
|
+
check_forest = vmap(check_tree, in_axes=(0, None))
|
|
583
|
+
return check_forest(trees, max_split)
|
|
584
|
+
|
|
585
|
+
|
|
586
|
+
def _get_next_line(s: str, i: int) -> tuple[str, int]:
|
|
587
|
+
"""Get the next line from a string and the new index."""
|
|
588
|
+
i_new = s.find('\n', i)
|
|
589
|
+
if i_new == -1:
|
|
590
|
+
return s[i:], len(s)
|
|
591
|
+
return s[i:i_new], i_new + 1
|
|
592
|
+
|
|
593
|
+
|
|
594
|
+
class BARTTraceMeta(Module):
|
|
595
|
+
"""Metadata of R BART tree traces.
|
|
596
|
+
|
|
597
|
+
Parameters
|
|
598
|
+
----------
|
|
599
|
+
ndpost
|
|
600
|
+
The number of posterior draws.
|
|
601
|
+
ntree
|
|
602
|
+
The number of trees in the model.
|
|
603
|
+
numcut
|
|
604
|
+
The maximum split value for each variable.
|
|
605
|
+
heap_size
|
|
606
|
+
The size of the heap required to store the trees.
|
|
607
|
+
"""
|
|
608
|
+
|
|
609
|
+
ndpost: int = field(static=True)
|
|
610
|
+
ntree: int = field(static=True)
|
|
611
|
+
numcut: UInt[Array, ' p']
|
|
612
|
+
heap_size: int = field(static=True)
|
|
613
|
+
|
|
614
|
+
|
|
615
|
+
def scan_BART_trees(trees: str) -> BARTTraceMeta:
|
|
616
|
+
"""Scan an R BART tree trace checking for errors and parsing metadata.
|
|
617
|
+
|
|
618
|
+
Parameters
|
|
619
|
+
----------
|
|
620
|
+
trees
|
|
621
|
+
The string representation of a trace of trees of the R BART package.
|
|
622
|
+
Can be accessed from ``mc_gbart(...).treedraws['trees']``.
|
|
623
|
+
|
|
624
|
+
Returns
|
|
625
|
+
-------
|
|
626
|
+
An object containing the metadata.
|
|
627
|
+
|
|
628
|
+
Raises
|
|
629
|
+
------
|
|
630
|
+
ValueError
|
|
631
|
+
If the string is malformed or contains leftover characters.
|
|
632
|
+
"""
|
|
633
|
+
# parse first line
|
|
634
|
+
line, i_char = _get_next_line(trees, 0)
|
|
635
|
+
i_line = 1
|
|
636
|
+
match = fullmatch(r'(\d+) (\d+) (\d+)', line)
|
|
637
|
+
if match is None:
|
|
638
|
+
msg = f'Malformed header at {i_line=}'
|
|
639
|
+
raise ValueError(msg)
|
|
640
|
+
ndpost, ntree, p = map(int, match.groups())
|
|
641
|
+
|
|
642
|
+
# initial values for maxima
|
|
643
|
+
max_heap_index = 0
|
|
644
|
+
numcut = numpy.zeros(p, int)
|
|
645
|
+
|
|
646
|
+
# cycle over iterations and trees
|
|
647
|
+
for i_iter in range(ndpost):
|
|
648
|
+
for i_tree in range(ntree):
|
|
649
|
+
# parse first line of tree definition
|
|
650
|
+
line, i_char = _get_next_line(trees, i_char)
|
|
651
|
+
i_line += 1
|
|
652
|
+
match = fullmatch(r'(\d+)', line)
|
|
653
|
+
if match is None:
|
|
654
|
+
msg = f'Malformed tree header at {i_iter=} {i_tree=} {i_line=}'
|
|
655
|
+
raise ValueError(msg)
|
|
656
|
+
num_nodes = int(line)
|
|
657
|
+
|
|
658
|
+
# cycle over nodes
|
|
659
|
+
for i_node in range(num_nodes):
|
|
660
|
+
# parse node definition
|
|
661
|
+
line, i_char = _get_next_line(trees, i_char)
|
|
662
|
+
i_line += 1
|
|
663
|
+
match = fullmatch(
|
|
664
|
+
r'(\d+) (\d+) (\d+) (-?\d+(\.\d+)?(e(\+|-|)\d+)?)', line
|
|
665
|
+
)
|
|
666
|
+
if match is None:
|
|
667
|
+
msg = f'Malformed node definition at {i_iter=} {i_tree=} {i_node=} {i_line=}'
|
|
668
|
+
raise ValueError(msg)
|
|
669
|
+
i_heap = int(match.group(1))
|
|
670
|
+
var = int(match.group(2))
|
|
671
|
+
split = int(match.group(3))
|
|
672
|
+
|
|
673
|
+
# update maxima
|
|
674
|
+
numcut[var] = max(numcut[var], split)
|
|
675
|
+
max_heap_index = max(max_heap_index, i_heap)
|
|
676
|
+
|
|
677
|
+
assert i_char <= len(trees)
|
|
678
|
+
if i_char < len(trees):
|
|
679
|
+
msg = f'Leftover {len(trees) - i_char} characters in string'
|
|
680
|
+
raise ValueError(msg)
|
|
681
|
+
|
|
682
|
+
# determine minimal integer type for numcut
|
|
683
|
+
numcut += 1 # because BART is 0-based
|
|
684
|
+
split_dtype = minimal_unsigned_dtype(numcut.max())
|
|
685
|
+
numcut = jnp.array(numcut.astype(split_dtype))
|
|
686
|
+
|
|
687
|
+
# determine minimum heap size to store the trees
|
|
688
|
+
heap_size = 2 ** ceil(log2(max_heap_index + 1))
|
|
689
|
+
|
|
690
|
+
return BARTTraceMeta(ndpost=ndpost, ntree=ntree, numcut=numcut, heap_size=heap_size)
|
|
691
|
+
|
|
692
|
+
|
|
693
|
+
class TraceWithOffset(Module):
|
|
694
|
+
"""Implementation of `bartz.mcmcloop.Trace`."""
|
|
695
|
+
|
|
696
|
+
leaf_tree: Float32[Array, 'ndpost ntree 2**d']
|
|
697
|
+
var_tree: UInt[Array, 'ndpost ntree 2**(d-1)']
|
|
698
|
+
split_tree: UInt[Array, 'ndpost ntree 2**(d-1)']
|
|
699
|
+
offset: Float32[Array, ' ndpost']
|
|
700
|
+
|
|
701
|
+
@classmethod
|
|
702
|
+
def from_trees_trace(
|
|
703
|
+
cls, trees: TreeHeaps, offset: Float32[Array, '']
|
|
704
|
+
) -> 'TraceWithOffset':
|
|
705
|
+
"""Create a `TraceWithOffset` from a `TreeHeaps`."""
|
|
706
|
+
ndpost, _, _ = trees.leaf_tree.shape
|
|
707
|
+
return cls(
|
|
708
|
+
leaf_tree=trees.leaf_tree,
|
|
709
|
+
var_tree=trees.var_tree,
|
|
710
|
+
split_tree=trees.split_tree,
|
|
711
|
+
offset=jnp.full(ndpost, offset),
|
|
712
|
+
)
|
|
713
|
+
|
|
714
|
+
|
|
715
|
+
def trees_BART_to_bartz(
    trees: str, *, min_maxdepth: int = 0, offset: FloatLike | None = None
) -> tuple[TraceWithOffset, BARTTraceMeta]:
    """Parse a tree trace in the R BART text format into bartz heap arrays.

    Parameters
    ----------
    trees
        The textual trace of trees produced by the R BART package, e.g.,
        ``mc_gbart(...).treedraws['trees']``.
    min_maxdepth
        The heap is sized to hold the largest heap index found in the input;
        it is enlarged to ``2**min_maxdepth`` slots if that is larger.
    offset
        Value replicated over the iterations in the `offset` field of the
        returned trace, mirroring the offset carried by the trace of
        `bartz.mcmcloop.run_mcmc`. If `None`, the offset is zero.

    Returns
    -------
    trace : TraceWithOffset
        The trees laid out as arrays with batch shape (iterations, trees).
    meta : BARTTraceMeta
        The sizes determined by `scan_BART_trees` on the input string.
    """
    # first pass: validate the whole string and measure what to allocate
    meta = scan_BART_trees(trees)

    # the first line is not a tree definition, start reading right after it
    _, cursor = _get_next_line(trees, 0)

    heap_size = max(meta.heap_size, 2**min_maxdepth)
    half_size = heap_size // 2
    batch_shape = (meta.ndpost, meta.ntree)

    leaf_trees = numpy.zeros((*batch_shape, heap_size), dtype=numpy.float32)
    var_trees = numpy.zeros(
        (*batch_shape, half_size),
        dtype=minimal_unsigned_dtype(meta.numcut.size - 1),
    )
    split_trees = numpy.zeros((*batch_shape, half_size), dtype=meta.numcut.dtype)

    # second pass: fill the arrays one tree at a time
    for i_iter in range(meta.ndpost):
        for i_tree in range(meta.ntree):
            # a tree definition opens with its number of nodes
            header, cursor = _get_next_line(trees, cursor)
            num_nodes = int(header)

            # views into the output arrays for the current tree
            leaves = leaf_trees[i_iter, i_tree]
            variables = var_trees[i_iter, i_tree]
            cutpoints = split_trees[i_iter, i_tree]

            # flags the heap slots that are the parent of some listed node
            has_child = numpy.zeros(half_size, dtype=bool)

            for _ in range(num_nodes):
                # one node per line: heap index, variable, cutpoint, leaf value
                record, cursor = _get_next_line(trees, cursor)
                tokens = record.split()
                i_heap = int(tokens[0])
                var = int(tokens[1])
                split = int(tokens[2])
                leaf = float(tokens[3])

                leaves[i_heap] = leaf
                has_child[i_heap // 2] = True

                # nodes in the bottom half of the heap have no decision slot
                if i_heap >= half_size:
                    continue
                variables[i_heap] = var
                # shifted by one, 0 is used below for nodes without children
                cutpoints[i_heap] = split + 1

            # slot 0 gets flagged through node 1 (1 // 2 == 0), clear it
            has_child[0] = False
            # nodes with no children listed in the input get a null cutpoint
            cutpoints[~has_child] = 0

    if offset is None:
        offset_array = jnp.zeros(meta.ndpost)
    else:
        offset_array = jnp.full(meta.ndpost, offset)

    trace = TraceWithOffset(
        leaf_tree=jnp.array(leaf_trees),
        var_tree=jnp.array(var_trees),
        split_tree=jnp.array(split_trees),
        offset=offset_array,
    )
    return trace, meta
|
|
797
|
+
|
|
798
|
+
|
|
799
|
+
class SamplePriorStack(Module):
    """Hand-managed recursion stack used by `sample_prior`.

    There is one stack level per depth at which a decision node can sit in a
    binary tree of maximum depth `d`; each level stores the state of one
    step of the descent into a child node.

    Parameters
    ----------
    nonterminal
        Whether the node is valid or the recursion is into unused node slots.
    lower
    upper
        The available cutpoints along ``var`` are in the integer range
        ``[1 + lower[var], 1 + upper[var])``.
    var
    split
        The variable and cutpoint of a decision node.
    """

    nonterminal: Bool[Array, ' d-1']
    lower: UInt[Array, 'd-1 p']
    upper: UInt[Array, 'd-1 p']
    var: UInt[Array, ' d-1']
    split: UInt[Array, ' d-1']

    @classmethod
    def initial(
        cls, p_nonterminal: Float32[Array, ' d-1'], max_split: UInt[Array, ' p']
    ) -> 'SamplePriorStack':
        """Build the stack in its state before the first node is visited.

        Parameters
        ----------
        p_nonterminal
            The prior probability of a node being non-terminal conditional on
            its ancestors and on having available decision rules, at each
            depth. Only its size is used here, as the number of stack levels.
        max_split
            The number of cutpoints along each variable.

        Returns
        -------
        A `SamplePriorStack` initialized to start the recursion.
        """
        num_levels = p_nonterminal.size
        num_vars = max_split.size
        bounds_shape = (num_levels, num_vars)
        split_dtype = max_split.dtype
        var_dtype = minimal_unsigned_dtype(num_vars - 1)
        return cls(
            # every level starts out flagged as valid
            nonterminal=jnp.ones(num_levels, bool),
            # at every level the full range of cutpoints is available
            lower=jnp.zeros(bounds_shape, split_dtype),
            upper=jnp.broadcast_to(max_split, bounds_shape),
            var=jnp.zeros(num_levels, var_dtype),
            split=jnp.zeros(num_levels, split_dtype),
        )
|
|
850
|
+
|
|
851
|
+
|
|
852
|
+
class SamplePriorTrees(Module):
    """Container of the trees produced by `sample_prior`.

    Parameters
    ----------
    leaf_tree
    var_tree
    split_tree
        The arrays representing the trees, see `bartz.grove`.
    """

    leaf_tree: Float32[Array, '* 2**d']
    var_tree: UInt[Array, '* 2**(d-1)']
    split_tree: UInt[Array, '* 2**(d-1)']

    @classmethod
    def initial(
        cls,
        key: Key[Array, ''],
        sigma_mu: Float32[Array, ''],
        p_nonterminal: Float32[Array, ' d-1'],
        max_split: UInt[Array, ' p'],
    ) -> 'SamplePriorTrees':
        """Create a single tree with sampled leaves and an empty structure.

        All the leaf slots are filled with normal draws scaled by `sigma_mu`,
        so the leaves need no further update; the decision rules are all
        zeros, to be filled in afterwards.

        Parameters
        ----------
        key
            A jax random key.
        sigma_mu
            The prior standard deviation of each leaf.
        p_nonterminal
            The prior probability of a node being non-terminal conditional on
            its ancestors and on having available decision rules, at each
            depth. Only its size is used here, to size the heap.
        max_split
            The number of cutpoints along each variable.

        Returns
        -------
        Trees initialized with random leaves and stub tree structures.
        """
        num_leaf_slots = 2 ** (p_nonterminal.size + 1)
        num_decision_slots = num_leaf_slots // 2
        var_dtype = minimal_unsigned_dtype(max_split.size - 1)

        leaves = sigma_mu * random.normal(key, (num_leaf_slots,))
        return cls(
            leaf_tree=leaves,
            var_tree=jnp.zeros(num_decision_slots, dtype=var_dtype),
            split_tree=jnp.zeros(num_decision_slots, dtype=max_split.dtype),
        )
|
|
903
|
+
|
|
904
|
+
|
|
905
|
+
class SamplePriorCarry(Module):
    """State threaded through the iterations of the scan in `sample_prior`.

    Parameters
    ----------
    key
        A jax random key used to sample decision rules.
    stack
        The stack used to manage the recursion.
    trees
        The output arrays.
    """

    key: Key[Array, '']
    stack: SamplePriorStack
    trees: SamplePriorTrees

    @classmethod
    def initial(
        cls,
        key: Key[Array, ''],
        sigma_mu: Float32[Array, ''],
        p_nonterminal: Float32[Array, ' d-1'],
        max_split: UInt[Array, ' p'],
    ) -> 'SamplePriorCarry':
        """Assemble the carry from a fresh stack and freshly initialized trees.

        Parameters
        ----------
        key
            A jax random key.
        sigma_mu
            The prior standard deviation of each leaf.
        p_nonterminal
            The prior probability of a node being non-terminal conditional on
            its ancestors and on having available decision rules, at each depth.
        max_split
            The number of cutpoints along each variable.

        Returns
        -------
        A `SamplePriorCarry` initialized to start the recursion.
        """
        subkeys = split_key(key)
        # the first subkey stays in the carry, the second one samples the leaves
        carry_key = subkeys.pop()
        stack = SamplePriorStack.initial(p_nonterminal, max_split)
        trees_key = subkeys.pop()
        trees = SamplePriorTrees.initial(trees_key, sigma_mu, p_nonterminal, max_split)
        return cls(key=carry_key, stack=stack, trees=trees)
|
|
954
|
+
|
|
955
|
+
|
|
956
|
+
class SamplePriorX(Module):
    """Schedule of node visits scanned over in `sample_prior`.

    The order in which the nodes are visited is worked out once with an
    ordinary Python recursion and stored as flat arrays.

    Parameters
    ----------
    node
        The heap index of the node to visit.
    depth
        The depth of the node.
    next_depth
        The depth of the next node to visit, either the left child or the right
        sibling of the node or of an ancestor.
    """

    node: Int32[Array, ' 2**(d-1)-1']
    depth: Int32[Array, ' 2**(d-1)-1']
    next_depth: Int32[Array, ' 2**(d-1)-1']

    @classmethod
    def initial(cls, p_nonterminal: Float32[Array, ' d-1']) -> 'SamplePriorX':
        """Compute the visiting schedule for a given maximum depth.

        Parameters
        ----------
        p_nonterminal
            The prior probability of a node being non-terminal conditional on
            its ancestors and on having available decision rules, at each
            depth. Only its size is used here, as the number of levels to
            visit.

        Returns
        -------
        A `SamplePriorX` initialized with the sequence of nodes to visit.
        """
        num_levels = p_nonterminal.size
        visits = cls._sequence(num_levels)
        # a complete binary tree with `num_levels` levels
        assert len(visits) == 2**num_levels - 1

        heap_indices = [heap_index for heap_index, _ in visits]
        levels = [level for _, level in visits]
        # shift by one visit; the entry after the last visit is `num_levels`
        following_levels = [*levels[1:], num_levels]

        return cls(
            node=jnp.array(heap_indices),
            depth=jnp.array(levels),
            next_depth=jnp.array(following_levels),
        )

    @classmethod
    def _sequence(
        cls, max_depth: int, depth: int = 0, node: int = 1
    ) -> tuple[tuple[int, int], ...]:
        """Recursively list (node, depth) pairs: node first, then left, then right."""
        if depth >= max_depth:
            return ()
        left_visits = cls._sequence(max_depth, depth + 1, 2 * node)
        right_visits = cls._sequence(max_depth, depth + 1, 2 * node + 1)
        return ((node, depth), *left_visits, *right_visits)
|
|
1013
|
+
|
|
1014
|
+
|
|
1015
|
+
def sample_prior_onetree(
    key: Key[Array, ''],
    max_split: UInt[Array, ' p'],
    p_nonterminal: Float32[Array, ' d-1'],
    sigma_mu: Float32[Array, ''],
) -> SamplePriorTrees:
    """Sample a tree from the BART prior.

    The tree is built by scanning over a precomputed sequence of nodes (see
    `SamplePriorX`), carrying the random key, a manually managed recursion
    stack, and the output arrays (see `SamplePriorCarry`).

    Parameters
    ----------
    key
        A jax random key.
    max_split
        The maximum split value for each variable.
    p_nonterminal
        The prior probability of a node being non-terminal conditional on
        its ancestors and on having available decision rules, at each depth.
    sigma_mu
        The prior standard deviation of each leaf.

    Returns
    -------
    An object containing a generated tree.
    """
    carry = SamplePriorCarry.initial(key, sigma_mu, p_nonterminal, max_split)
    xs = SamplePriorX.initial(p_nonterminal)

    def loop(carry: SamplePriorCarry, x: SamplePriorX):
        """Visit one node: sample its rule, write it out, update the stack."""
        # four subkeys, popped in this order: variable, cutpoint, grow
        # decision, key stored in the next carry
        keys = split_key(carry.key, 4)

        # get variables at current stack level
        stack = carry.stack
        nonterminal = stack.nonterminal[x.depth]
        lower = stack.lower[x.depth, :]
        upper = stack.upper[x.depth, :]

        # sample a random decision rule
        # a variable is available if its range of cutpoints is not empty
        available: Bool[Array, ' p'] = lower < upper
        allowed = jnp.any(available)
        var = randint_masked(keys.pop(), available)
        # the bounds are offset by one w.r.t. the cutpoint, see `SamplePriorStack`
        split = 1 + random.randint(keys.pop(), (), lower[var], upper[var])

        # cast to shorter integer types
        var = var.astype(carry.trees.var_tree.dtype)
        split = split.astype(carry.trees.split_tree.dtype)

        # decide whether to try to grow the node if it is growable
        pnt = p_nonterminal[x.depth]
        try_nonterminal: Bool[Array, ''] = random.bernoulli(keys.pop(), pnt)
        # the node is a decision node only if the stack marks it as valid, the
        # coin flip succeeded, and there is at least one available variable
        nonterminal &= try_nonterminal & allowed

        # update trees
        # the variable is always written, while the cutpoint is set to 0 if
        # the node did not turn out to be a decision node
        trees = carry.trees
        trees = replace(
            trees,
            var_tree=trees.var_tree.at[x.node].set(var),
            split_tree=trees.split_tree.at[x.node].set(
                jnp.where(nonterminal, split, 0)
            ),
        )

        def write_push_stack() -> SamplePriorStack:
            """Update the stack to go to the left child.

            The rule of the current node is saved at the current level, and
            the next level receives the current bounds with the upper bound
            along `var` lowered to ``split - 1``, together with the
            `nonterminal` flag of the current node.
            """
            # NOTE(review): for the last scanned node `next_depth` equals the
            # number of stack levels (see `SamplePriorX.initial`); presumably
            # these out-of-bounds writes are dropped by jax -- confirm
            return replace(
                stack,
                nonterminal=stack.nonterminal.at[x.next_depth].set(nonterminal),
                lower=stack.lower.at[x.next_depth, :].set(lower),
                upper=stack.upper.at[x.next_depth, :].set(upper.at[var].set(split - 1)),
                var=stack.var.at[x.depth].set(var),
                split=stack.split.at[x.depth].set(split),
            )

        def pop_push_stack() -> SamplePriorStack:
            """Update the stack to go to the right sibling, possibly at lower depth.

            The rule and bounds are read back from the level above the next
            node, and the next level receives those bounds with the lower
            bound along the stored variable raised to the stored cutpoint.
            """
            # these names shadow the ones of the enclosing scope on purpose:
            # they refer to the parent of the next node, not to the current node
            var = stack.var[x.next_depth - 1]
            split = stack.split[x.next_depth - 1]
            lower = stack.lower[x.next_depth - 1, :]
            upper = stack.upper[x.next_depth - 1, :]
            return replace(
                stack,
                lower=stack.lower.at[x.next_depth, :].set(lower.at[var].set(split)),
                upper=stack.upper.at[x.next_depth, :].set(upper),
            )

        # update stack
        # going deeper means moving to the left child, otherwise the next node
        # is a right sibling of this node or of one of its ancestors
        stack = lax.cond(x.next_depth > x.depth, write_push_stack, pop_push_stack)

        # update carry
        carry = replace(carry, key=keys.pop(), stack=stack, trees=trees)
        return carry, None

    # only the final carry is needed, the scan produces no per-step output
    carry, _ = lax.scan(loop, carry, xs)
    return carry.trees
|
|
1108
|
+
|
|
1109
|
+
|
|
1110
|
+
@partial(vmap_nodoc, in_axes=(0, None, None, None))
def sample_prior_forest(
    keys: Key[Array, ' num_trees'],
    max_split: UInt[Array, ' p'],
    p_nonterminal: Float32[Array, ' d-1'],
    sigma_mu: Float32[Array, ''],
) -> SamplePriorTrees:
    """Sample a set of independent trees from the BART prior.

    This is `sample_prior_onetree` mapped over the leading axis of `keys`,
    with the other arguments shared by all the trees.

    Parameters
    ----------
    keys
        A sequence of jax random keys, one for each tree. This determines the
        number of trees sampled.
    max_split
        The maximum split value for each variable.
    p_nonterminal
        The prior probability of a node being non-terminal conditional on
        its ancestors and on having available decision rules, at each depth.
    sigma_mu
        The prior standard deviation of each leaf.

    Returns
    -------
    An object containing the generated trees.
    """
    # within the mapped function, `keys` is the key of a single tree
    tree = sample_prior_onetree(keys, max_split, p_nonterminal, sigma_mu)
    return tree
|
|
1137
|
+
|
|
1138
|
+
|
|
1139
|
+
@partial(jit, static_argnums=(1, 2))
def sample_prior(
    key: Key[Array, ''],
    trace_length: int,
    num_trees: int,
    max_split: UInt[Array, ' p'],
    p_nonterminal: Float32[Array, ' d-1'],
    sigma_mu: Float32[Array, ''],
) -> SamplePriorTrees:
    """Sample independent trees from the BART prior.

    Parameters
    ----------
    key
        A jax random key.
    trace_length
        The number of iterations.
    num_trees
        The number of trees for each iteration.
    max_split
        The number of cutpoints along each variable.
    p_nonterminal
        The prior probability of a node being non-terminal conditional on
        its ancestors and on having available decision rules, at each depth.
        This determines the maximum depth of the trees.
    sigma_mu
        The prior standard deviation of each leaf.

    Returns
    -------
    An object containing the generated trees, with batch shape (trace_length, num_trees).
    """
    # sample all the trees as a single flat batch, one key per tree
    total_trees = trace_length * num_trees
    flat_keys = random.split(key, total_trees)
    flat_trees = sample_prior_forest(flat_keys, max_split, p_nonterminal, sigma_mu)

    def unflatten(leaf):
        """Split the flat batch axis into (iteration, tree)."""
        return leaf.reshape(trace_length, num_trees, -1)

    return tree_map(unflatten, flat_trees)
|
|
1174
|
+
|
|
1175
|
+
|
|
1176
|
+
class debug_gbart(gbart):
    """A subclass of `gbart` that adds debugging functionality.

    Parameters
    ----------
    *args
        Passed to `gbart`.
    check_trees
        If `True`, check all trees with `check_trace` after running the MCMC,
        and assert that they are all valid. Set to `False` to allow jax tracing.
    **kw
        Passed to `gbart`.
    """

    def __init__(self, *args, check_trees: bool = True, **kw):
        super().__init__(*args, **kw)
        if not check_trees:
            return
        # `check_trees` the argument shadows the method only as a local name
        num_bad = jnp.count_nonzero(self.check_trees())
        assert num_bad == 0

    def show_tree(self, i_sample: int, i_tree: int, print_all: bool = False):
        """Print a single tree in human-readable format.

        Parameters
        ----------
        i_sample
            The index of the posterior sample.
        i_tree
            The index of the tree in the sample.
        print_all
            If `True`, also print the content of unused node slots.
        """

        def pick(leaf):
            """Select the requested tree out of the (sample, tree) batch."""
            return leaf[i_sample, i_tree, :]

        trace = TreesTrace.from_dataclass(self._main_trace)
        single_tree = tree_map(pick, trace)
        print(format_tree(single_tree, print_all=print_all))  # noqa: T201, this method is intended for debug

    def sigma_harmonic_mean(self, prior: bool = False) -> Float32[Array, '']:
        """Return the square root of the harmonic mean of the error variance.

        Parameters
        ----------
        prior
            If `True`, use the prior distribution, otherwise use the full
            conditional at the last MCMC iteration.

        Returns
        -------
        The square root of 1/E[1/sigma^2] in the selected distribution.

        Notes
        -----
        The result is ``sqrt(beta / alpha)``, where `alpha` and `beta` are the
        prior parameters ``sigma2_alpha`` and ``sigma2_beta`` of the MCMC
        state, updated with the final residuals unless `prior` is set.
        """
        state = self._mcmc_state
        assert state.sigma2_alpha is not None
        assert state.z is None

        alpha = state.sigma2_alpha
        beta = state.sigma2_beta
        if not prior:
            # conditional update given the residuals
            resid = state.resid
            alpha = alpha + resid.size / 2
            beta = beta + (resid @ resid) / 2

        return jnp.sqrt(beta / alpha)

    def compare_resid(self) -> tuple[Float32[Array, ' n'], Float32[Array, ' n']]:
        """Re-compute residuals to compare them with the updated ones.

        Returns
        -------
        resid1 : Float32[Array, 'n']
            The final state of the residuals updated during the MCMC.
        resid2 : Float32[Array, 'n']
            The residuals computed from the final state of the trees.
        """
        state = self._mcmc_state
        forest_values = evaluate_forest(state.X, state.forest)

        # the residuals are relative to `z` when the state has it, else to `y`
        target = state.y if state.z is None else state.z
        recomputed = target - (forest_values + state.offset)

        return state.resid, recomputed

    def avg_acc(self) -> tuple[Float32[Array, ''], Float32[Array, '']]:
        """Compute the average acceptance rates of tree moves.

        Returns
        -------
        acc_grow : Float32[Array, '']
            The average acceptance rate of grow moves.
        acc_prune : Float32[Array, '']
            The average acceptance rate of prune moves.
        """
        trace = self._main_trace

        def acceptance_rate(move):
            """Total accepted over total proposed for one kind of move."""
            accepted = getattr(trace, f'{move}_acc_count')
            proposed = getattr(trace, f'{move}_prop_count')
            return accepted.sum() / proposed.sum()

        acc_grow = acceptance_rate('grow')
        acc_prune = acceptance_rate('prune')
        return acc_grow, acc_prune

    def avg_prop(self) -> tuple[Float32[Array, ''], Float32[Array, '']]:
        """Compute the average proposal rate of grow and prune moves.

        Returns
        -------
        prop_grow : Float32[Array, '']
            The fraction of times grow was proposed instead of prune.
        prop_prune : Float32[Array, '']
            The fraction of times prune was proposed instead of grow.

        Notes
        -----
        This function does not take into account cases where no move was
        proposed.
        """
        trace = self._main_trace
        num_grow = trace.grow_prop_count.sum()
        num_prune = trace.prune_prop_count.sum()
        num_proposed = num_grow + num_prune
        return num_grow / num_proposed, num_prune / num_proposed

    def avg_move(self) -> tuple[Float32[Array, ''], Float32[Array, '']]:
        """Compute the move rate.

        Returns
        -------
        rate_grow : Float32[Array, '']
            The fraction of times a grow move was proposed and accepted.
        rate_prune : Float32[Array, '']
            The fraction of times a prune move was proposed and accepted.
        """
        acc_grow, acc_prune = self.avg_acc()
        prop_grow, prop_prune = self.avg_prop()
        rate_grow = acc_grow * prop_grow
        rate_prune = acc_prune * prop_prune
        return rate_grow, rate_prune

    def depth_distr(self) -> Float32[Array, 'trace_length d']:
        """Histogram of tree depths for each state of the trees.

        Returns
        -------
        A matrix where each row contains a histogram of tree depths.
        """
        split_tree = self._main_trace.split_tree
        return trace_depth_distr(split_tree)

    def points_per_decision_node_distr(self) -> Float32[Array, 'trace_length n+1']:
        """Histogram of number of points belonging to parent-of-leaf nodes.

        Returns
        -------
        A matrix where each row contains a histogram of number of points.
        """
        trace = self._main_trace
        predictors = self._mcmc_state.X
        return trace_points_per_decision_node_distr(trace, predictors)

    def points_per_leaf_distr(self) -> Float32[Array, 'trace_length n+1']:
        """Histogram of number of points belonging to leaves.

        Returns
        -------
        A matrix where each row contains a histogram of number of points.
        """
        trace = self._main_trace
        predictors = self._mcmc_state.X
        return trace_points_per_leaf_distr(trace, predictors)

    def check_trees(self) -> UInt[Array, 'trace_length ntree']:
        """Apply `check_trace` to all the tree draws."""
        max_split = self._mcmc_state.forest.max_split
        return check_trace(self._main_trace, max_split)

    def tree_goes_bad(self) -> Bool[Array, 'trace_length ntree']:
        """Find iterations where a tree becomes invalid.

        Returns
        -------
        A matrix where (i,j) is `True` if tree j is invalid at iteration i but not i-1.
        """
        is_bad = self.check_trees().astype(bool)
        # shift down by one iteration, padding the first row, so that row i
        # holds the state at iteration i - 1
        was_bad = jnp.pad(is_bad[:-1], [(1, 0), (0, 0)])
        return is_bad & ~was_bad
|