bartz 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/.DS_Store +0 -0
- bartz/BART/__init__.py +27 -0
- bartz/BART/_gbart.py +522 -0
- bartz/__init__.py +4 -2
- bartz/{BART.py → _interface.py} +256 -132
- bartz/_profiler.py +318 -0
- bartz/_version.py +1 -1
- bartz/debug.py +269 -314
- bartz/grove.py +124 -68
- bartz/jaxext/__init__.py +101 -27
- bartz/jaxext/_autobatch.py +257 -51
- bartz/jaxext/scipy/__init__.py +1 -1
- bartz/jaxext/scipy/special.py +3 -4
- bartz/jaxext/scipy/stats.py +1 -1
- bartz/mcmcloop.py +399 -208
- bartz/mcmcstep/__init__.py +35 -0
- bartz/mcmcstep/_moves.py +904 -0
- bartz/mcmcstep/_state.py +1114 -0
- bartz/mcmcstep/_step.py +1603 -0
- bartz/prepcovars.py +1 -1
- bartz/testing/__init__.py +29 -0
- bartz/testing/_dgp.py +442 -0
- {bartz-0.7.0.dist-info → bartz-0.8.0.dist-info}/METADATA +17 -11
- bartz-0.8.0.dist-info/RECORD +25 -0
- {bartz-0.7.0.dist-info → bartz-0.8.0.dist-info}/WHEEL +1 -1
- bartz/mcmcstep.py +0 -2616
- bartz-0.7.0.dist-info/RECORD +0 -17
bartz/debug.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
# bartz/src/bartz/debug.py
|
|
2
2
|
#
|
|
3
|
-
# Copyright (c) 2024-
|
|
3
|
+
# Copyright (c) 2024-2026, The Bartz Contributors
|
|
4
4
|
#
|
|
5
5
|
# This file is part of bartz.
|
|
6
6
|
#
|
|
@@ -22,13 +22,14 @@
|
|
|
22
22
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
23
23
|
# SOFTWARE.
|
|
24
24
|
|
|
25
|
-
"""Debugging utilities. The
|
|
25
|
+
"""Debugging utilities. The main functionality is the class `debug_mc_gbart`."""
|
|
26
26
|
|
|
27
27
|
from collections.abc import Callable
|
|
28
28
|
from dataclasses import replace
|
|
29
29
|
from functools import partial
|
|
30
30
|
from math import ceil, log2
|
|
31
31
|
from re import fullmatch
|
|
32
|
+
from typing import Literal
|
|
32
33
|
|
|
33
34
|
import numpy
|
|
34
35
|
from equinox import Module, field
|
|
@@ -37,20 +38,22 @@ from jax import numpy as jnp
|
|
|
37
38
|
from jax.tree_util import tree_map
|
|
38
39
|
from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, UInt
|
|
39
40
|
|
|
40
|
-
from bartz.BART import
|
|
41
|
+
from bartz.BART import gbart, mc_gbart
|
|
42
|
+
from bartz.BART._gbart import FloatLike
|
|
41
43
|
from bartz.grove import (
|
|
42
44
|
TreeHeaps,
|
|
43
45
|
evaluate_forest,
|
|
44
46
|
is_actual_leaf,
|
|
45
47
|
is_leaves_parent,
|
|
46
|
-
|
|
48
|
+
normalize_axis_tuple,
|
|
49
|
+
traverse_forest,
|
|
47
50
|
tree_depth,
|
|
48
51
|
tree_depths,
|
|
49
52
|
)
|
|
50
|
-
from bartz.jaxext import minimal_unsigned_dtype, vmap_nodoc
|
|
53
|
+
from bartz.jaxext import autobatch, minimal_unsigned_dtype, vmap_nodoc
|
|
51
54
|
from bartz.jaxext import split as split_key
|
|
52
55
|
from bartz.mcmcloop import TreesTrace
|
|
53
|
-
from bartz.mcmcstep import randint_masked
|
|
56
|
+
from bartz.mcmcstep._moves import randint_masked
|
|
54
57
|
|
|
55
58
|
|
|
56
59
|
def format_tree(tree: TreeHeaps, *, print_all: bool = False) -> str:
|
|
@@ -154,9 +157,11 @@ def tree_actual_depth(split_tree: UInt[Array, ' 2**(d-1)']) -> Int32[Array, '']:
|
|
|
154
157
|
return jnp.max(depth)
|
|
155
158
|
|
|
156
159
|
|
|
160
|
+
@jit
|
|
161
|
+
@partial(jnp.vectorize, signature='(nt,hts)->(d)')
|
|
157
162
|
def forest_depth_distr(
|
|
158
|
-
split_tree: UInt[Array, 'num_trees 2**(d-1)'],
|
|
159
|
-
) -> Int32[Array, ' d']:
|
|
163
|
+
split_tree: UInt[Array, '*batch_shape num_trees 2**(d-1)'],
|
|
164
|
+
) -> Int32[Array, '*batch_shape d']:
|
|
160
165
|
"""Histogram the depths of a set of trees.
|
|
161
166
|
|
|
162
167
|
Parameters
|
|
@@ -173,195 +178,102 @@ def forest_depth_distr(
|
|
|
173
178
|
return jnp.bincount(depths, length=depth)
|
|
174
179
|
|
|
175
180
|
|
|
176
|
-
@jit
|
|
177
|
-
def
|
|
178
|
-
split_tree: UInt[Array, 'trace_length num_trees 2**(d-1)'],
|
|
179
|
-
) -> Int32[Array, 'trace_length d']:
|
|
180
|
-
"""Histogram the depths of a sequence of sets of trees.
|
|
181
|
-
|
|
182
|
-
Parameters
|
|
183
|
-
----------
|
|
184
|
-
split_tree
|
|
185
|
-
The cutpoints of the decision rules of the trees.
|
|
186
|
-
|
|
187
|
-
Returns
|
|
188
|
-
-------
|
|
189
|
-
A matrix where element (t,i) counts how many trees have depth i in set t.
|
|
190
|
-
"""
|
|
191
|
-
return vmap(forest_depth_distr)(split_tree)
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
def points_per_decision_node_distr(
|
|
195
|
-
var_tree: UInt[Array, ' 2**(d-1)'],
|
|
196
|
-
split_tree: UInt[Array, ' 2**(d-1)'],
|
|
181
|
+
@partial(jit, static_argnames=('node_type', 'sum_batch_axis'))
|
|
182
|
+
def points_per_node_distr(
|
|
197
183
|
X: UInt[Array, 'p n'],
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
----------
|
|
206
|
-
var_tree
|
|
207
|
-
The variables of the decision rules.
|
|
208
|
-
split_tree
|
|
209
|
-
The cutpoints of the decision rules.
|
|
210
|
-
X
|
|
211
|
-
The set of points to count.
|
|
212
|
-
|
|
213
|
-
Returns
|
|
214
|
-
-------
|
|
215
|
-
A vector where the i-th element counts how many next-to-leaf nodes have i points.
|
|
216
|
-
"""
|
|
217
|
-
traverse_tree_X = vmap(traverse_tree, in_axes=(1, None, None))
|
|
218
|
-
indices = traverse_tree_X(X, var_tree, split_tree)
|
|
219
|
-
indices >>= 1
|
|
220
|
-
count_tree = jnp.zeros(split_tree.size, int).at[indices].add(1).at[0].set(0)
|
|
221
|
-
is_parent = is_leaves_parent(split_tree)
|
|
222
|
-
return jnp.zeros(X.shape[1] + 1, int).at[count_tree].add(is_parent)
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
def forest_points_per_decision_node_distr(
|
|
226
|
-
trees: TreeHeaps, X: UInt[Array, 'p n']
|
|
227
|
-
) -> Int32[Array, ' n+1']:
|
|
228
|
-
"""Histogram points-per-node counts for a set of trees.
|
|
229
|
-
|
|
230
|
-
Count how many parent-of-leaf nodes in a set of trees select each possible
|
|
231
|
-
amount of points.
|
|
232
|
-
|
|
233
|
-
Parameters
|
|
234
|
-
----------
|
|
235
|
-
trees
|
|
236
|
-
The set of trees. The variables must have broadcast shape (num_trees,).
|
|
237
|
-
X
|
|
238
|
-
The set of points to count.
|
|
239
|
-
|
|
240
|
-
Returns
|
|
241
|
-
-------
|
|
242
|
-
A vector where the i-th element counts how many next-to-leaf nodes have i points.
|
|
243
|
-
"""
|
|
244
|
-
distr = jnp.zeros(X.shape[1] + 1, int)
|
|
184
|
+
var_tree: UInt[Array, '*batch_shape 2**(d-1)'],
|
|
185
|
+
split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
|
|
186
|
+
node_type: Literal['leaf', 'leaf-parent'],
|
|
187
|
+
*,
|
|
188
|
+
sum_batch_axis: int | tuple[int, ...] = (),
|
|
189
|
+
) -> Int32[Array, '*reduced_batch_shape n+1']:
|
|
190
|
+
"""Histogram points-per-node counts in a set of trees.
|
|
245
191
|
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
distr, _ = lax.scan(loop, distr, (trees.var_tree, trees.split_tree))
|
|
250
|
-
return distr
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
@jit
|
|
254
|
-
def trace_points_per_decision_node_distr(
|
|
255
|
-
trace: TreeHeaps, X: UInt[Array, 'p n']
|
|
256
|
-
) -> Int32[Array, 'trace_length n+1']:
|
|
257
|
-
"""Separately histogram points-per-node counts over a sequence of sets of trees.
|
|
258
|
-
|
|
259
|
-
For each set of trees, count how many parent-of-leaf nodes select each
|
|
260
|
-
possible amount of points.
|
|
192
|
+
Count how many nodes in a tree select each possible amount of points,
|
|
193
|
+
over a certain subset of nodes.
|
|
261
194
|
|
|
262
195
|
Parameters
|
|
263
196
|
----------
|
|
264
|
-
trace
|
|
265
|
-
The sequence of sets of trees. The variables must have broadcast shape
|
|
266
|
-
(trace_length, num_trees).
|
|
267
197
|
X
|
|
268
198
|
The set of points to count.
|
|
269
|
-
|
|
270
|
-
Returns
|
|
271
|
-
-------
|
|
272
|
-
A matrix where element (t,i) counts how many next-to-leaf nodes have i points in set t.
|
|
273
|
-
"""
|
|
274
|
-
|
|
275
|
-
def loop(_, trace):
|
|
276
|
-
return None, forest_points_per_decision_node_distr(trace, X)
|
|
277
|
-
|
|
278
|
-
_, distr = lax.scan(loop, None, trace)
|
|
279
|
-
return distr
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
def points_per_leaf_distr(
|
|
283
|
-
var_tree: UInt[Array, ' 2**(d-1)'],
|
|
284
|
-
split_tree: UInt[Array, ' 2**(d-1)'],
|
|
285
|
-
X: UInt[Array, 'p n'],
|
|
286
|
-
) -> Int32[Array, ' n+1']:
|
|
287
|
-
"""Histogram points-per-leaf counts in a tree.
|
|
288
|
-
|
|
289
|
-
Count how many leaves in a tree select each possible amount of points.
|
|
290
|
-
|
|
291
|
-
Parameters
|
|
292
|
-
----------
|
|
293
199
|
var_tree
|
|
294
200
|
The variables of the decision rules.
|
|
295
201
|
split_tree
|
|
296
202
|
The cutpoints of the decision rules.
|
|
297
|
-
|
|
298
|
-
The
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
return jnp.zeros(X.shape[1] + 1, int).at[count_tree].add(is_leaf)
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
def forest_points_per_leaf_distr(
|
|
312
|
-
trees: TreeHeaps, X: UInt[Array, 'p n']
|
|
313
|
-
) -> Int32[Array, ' n+1']:
|
|
314
|
-
"""Histogram points-per-leaf counts over a set of trees.
|
|
315
|
-
|
|
316
|
-
Count how many leaves in a set of trees select each possible amount of points.
|
|
317
|
-
|
|
318
|
-
Parameters
|
|
319
|
-
----------
|
|
320
|
-
trees
|
|
321
|
-
The set of trees. The variables must have broadcast shape (num_trees,).
|
|
322
|
-
X
|
|
323
|
-
The set of points to count.
|
|
203
|
+
node_type
|
|
204
|
+
The type of nodes to consider. Can be:
|
|
205
|
+
|
|
206
|
+
'leaf'
|
|
207
|
+
Count only leaf nodes.
|
|
208
|
+
'leaf-parent'
|
|
209
|
+
Count only parent-of-leaf nodes.
|
|
210
|
+
sum_batch_axis
|
|
211
|
+
Aggregate the histogram over these batch axes, counting how many nodes
|
|
212
|
+
have each possible amount of points over subsets of trees instead of
|
|
213
|
+
in each tree separately.
|
|
324
214
|
|
|
325
215
|
Returns
|
|
326
216
|
-------
|
|
327
|
-
A vector where the i-th element counts how many
|
|
217
|
+
A vector where the i-th element counts how many nodes have i points.
|
|
328
218
|
"""
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
219
|
+
batch_ndim = var_tree.ndim - 1
|
|
220
|
+
axes = normalize_axis_tuple(sum_batch_axis, batch_ndim)
|
|
221
|
+
|
|
222
|
+
def func(
|
|
223
|
+
var_tree: UInt[Array, '*batch_shape 2**(d-1)'],
|
|
224
|
+
split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
|
|
225
|
+
) -> Int32[Array, '*reduced_batch_shape n+1']:
|
|
226
|
+
indices: UInt[Array, '*batch_shape n']
|
|
227
|
+
indices = traverse_forest(X, var_tree, split_tree)
|
|
228
|
+
|
|
229
|
+
@partial(jnp.vectorize, signature='(hts),(n)->(ts_or_hts),(ts_or_hts)')
|
|
230
|
+
def count_points(
|
|
231
|
+
split_tree: UInt[Array, '*batch_shape 2**(d-1)'],
|
|
232
|
+
indices: UInt[Array, '*batch_shape n'],
|
|
233
|
+
) -> (
|
|
234
|
+
tuple[UInt[Array, '*batch_shape 2**d'], Bool[Array, '*batch_shape 2**d']]
|
|
235
|
+
| tuple[
|
|
236
|
+
UInt[Array, '*batch_shape 2**(d-1)'],
|
|
237
|
+
Bool[Array, '*batch_shape 2**(d-1)'],
|
|
238
|
+
]
|
|
239
|
+
):
|
|
240
|
+
if node_type == 'leaf-parent':
|
|
241
|
+
indices >>= 1
|
|
242
|
+
predicate = is_leaves_parent(split_tree)
|
|
243
|
+
elif node_type == 'leaf':
|
|
244
|
+
predicate = is_actual_leaf(split_tree, add_bottom_level=True)
|
|
245
|
+
else:
|
|
246
|
+
raise ValueError(node_type)
|
|
247
|
+
count_tree = jnp.zeros(predicate.size, int).at[indices].add(1).at[0].set(0)
|
|
248
|
+
return count_tree, predicate
|
|
249
|
+
|
|
250
|
+
count_tree, predicate = count_points(split_tree, indices)
|
|
251
|
+
|
|
252
|
+
def count_nodes(
|
|
253
|
+
count_tree: UInt[Array, '*summed_batch_axes half_tree_size'],
|
|
254
|
+
predicate: Bool[Array, '*summed_batch_axes half_tree_size'],
|
|
255
|
+
) -> Int32[Array, ' n+1']:
|
|
256
|
+
return jnp.zeros(X.shape[1] + 1, int).at[count_tree].add(predicate)
|
|
257
|
+
|
|
258
|
+
# vmap count_nodes over the batch dims that are not summed over
|
|
259
|
+
for i in reversed(range(batch_ndim)):
|
|
260
|
+
neg_i = i - var_tree.ndim
|
|
261
|
+
if i not in axes:
|
|
262
|
+
count_nodes = vmap(count_nodes, in_axes=neg_i)
|
|
263
|
+
|
|
264
|
+
return count_nodes(count_tree, predicate)
|
|
265
|
+
|
|
266
|
+
# automatically batch over the batch dimensions that are not summed over
|
|
267
|
+
max_io_nbytes = 2**27 # 128 MiB
|
|
268
|
+
out_dim_shift = len(axes)
|
|
269
|
+
for i in reversed(range(batch_ndim)):
|
|
270
|
+
if i in axes:
|
|
271
|
+
out_dim_shift -= 1
|
|
272
|
+
else:
|
|
273
|
+
func = autobatch(func, max_io_nbytes, i, i - out_dim_shift)
|
|
274
|
+
assert out_dim_shift == 0
|
|
362
275
|
|
|
363
|
-
|
|
364
|
-
return distr
|
|
276
|
+
return func(var_tree, split_tree)
|
|
365
277
|
|
|
366
278
|
|
|
367
279
|
check_functions = []
|
|
@@ -400,29 +312,36 @@ def check_types(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> bool:
|
|
|
400
312
|
return (
|
|
401
313
|
tree.var_tree.dtype == expected_var_dtype
|
|
402
314
|
and tree.split_tree.dtype == expected_split_dtype
|
|
315
|
+
and jnp.issubdtype(max_split.dtype, jnp.unsignedinteger)
|
|
403
316
|
)
|
|
404
317
|
|
|
405
318
|
|
|
406
319
|
@check
|
|
407
|
-
def check_sizes(tree: TreeHeaps,
|
|
320
|
+
def check_sizes(tree: TreeHeaps, _max_split: UInt[Array, ' p']) -> bool:
|
|
408
321
|
"""Check that array sizes are coherent."""
|
|
409
322
|
return tree.leaf_tree.size == 2 * tree.var_tree.size == 2 * tree.split_tree.size
|
|
410
323
|
|
|
411
324
|
|
|
412
325
|
@check
|
|
413
|
-
def check_unused_node(
|
|
326
|
+
def check_unused_node(
|
|
327
|
+
tree: TreeHeaps, _max_split: UInt[Array, ' p']
|
|
328
|
+
) -> Bool[Array, '']:
|
|
414
329
|
"""Check that the unused node slot at index 0 is not dirty."""
|
|
415
330
|
return (tree.var_tree[0] == 0) & (tree.split_tree[0] == 0)
|
|
416
331
|
|
|
417
332
|
|
|
418
333
|
@check
|
|
419
|
-
def check_leaf_values(
|
|
334
|
+
def check_leaf_values(
|
|
335
|
+
tree: TreeHeaps, _max_split: UInt[Array, ' p']
|
|
336
|
+
) -> Bool[Array, '']:
|
|
420
337
|
"""Check that all leaf values are not inf or nan."""
|
|
421
338
|
return jnp.all(jnp.isfinite(tree.leaf_tree))
|
|
422
339
|
|
|
423
340
|
|
|
424
341
|
@check
|
|
425
|
-
def check_stray_nodes(
|
|
342
|
+
def check_stray_nodes(
|
|
343
|
+
tree: TreeHeaps, _max_split: UInt[Array, ' p']
|
|
344
|
+
) -> Bool[Array, '']:
|
|
426
345
|
"""Check if there is any marked-non-leaf node with a marked-leaf parent."""
|
|
427
346
|
index = jnp.arange(
|
|
428
347
|
2 * tree.split_tree.size,
|
|
@@ -446,12 +365,12 @@ def check_rule_consistency(
|
|
|
446
365
|
|
|
447
366
|
# initial boundaries of decision rules. use extreme integers instead of 0,
|
|
448
367
|
# max_split to avoid checking if there is something out of bounds.
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
#
|
|
368
|
+
dtype = tree.split_tree.dtype
|
|
369
|
+
small = jnp.iinfo(dtype).min
|
|
370
|
+
large = jnp.iinfo(dtype).max
|
|
371
|
+
lower = jnp.full(max_split.size, small, dtype)
|
|
372
|
+
upper = jnp.full(max_split.size, large, dtype)
|
|
373
|
+
# the split must be in (lower[var], upper[var]]
|
|
455
374
|
|
|
456
375
|
def _check_recursive(node, lower, upper):
|
|
457
376
|
# read decision rule
|
|
@@ -464,20 +383,14 @@ def check_rule_consistency(
|
|
|
464
383
|
upper_var = upper.at[var].get(mode='fill', fill_value=large)
|
|
465
384
|
|
|
466
385
|
# check rule is in bounds
|
|
467
|
-
bad = jnp.where(split, (split <= lower_var) | (split
|
|
386
|
+
bad = jnp.where(split, (split <= lower_var) | (split > upper_var), False)
|
|
468
387
|
|
|
469
388
|
# recurse
|
|
470
389
|
if node < tree.var_tree.size // 2:
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
)
|
|
476
|
-
bad |= _check_recursive(
|
|
477
|
-
2 * node + 1,
|
|
478
|
-
lower.at[jnp.where(split, var, max_split.size)].set(split),
|
|
479
|
-
upper,
|
|
480
|
-
)
|
|
390
|
+
idx = jnp.where(split, var, max_split.size)
|
|
391
|
+
bad |= _check_recursive(2 * node, lower, upper.at[idx].set(split - 1))
|
|
392
|
+
bad |= _check_recursive(2 * node + 1, lower.at[idx].set(split), upper)
|
|
393
|
+
|
|
481
394
|
return bad
|
|
482
395
|
|
|
483
396
|
return ~_check_recursive(1, lower, upper)
|
|
@@ -557,30 +470,39 @@ def describe_error(error: int | Integer[Array, '']) -> list[str]:
|
|
|
557
470
|
|
|
558
471
|
|
|
559
472
|
@jit
|
|
560
|
-
@partial(vmap_nodoc, in_axes=(0, None))
|
|
561
473
|
def check_trace(
|
|
562
474
|
trace: TreeHeaps, max_split: UInt[Array, ' p']
|
|
563
|
-
) -> UInt[Array, '
|
|
564
|
-
"""Check the validity of a
|
|
475
|
+
) -> UInt[Array, '*batch_shape']:
|
|
476
|
+
"""Check the validity of a set of trees.
|
|
565
477
|
|
|
566
478
|
Use `describe_error` to parse the error codes returned by this function.
|
|
567
479
|
|
|
568
480
|
Parameters
|
|
569
481
|
----------
|
|
570
482
|
trace
|
|
571
|
-
The
|
|
572
|
-
|
|
573
|
-
additional attributes beyond the tree arrays, they are ignored.
|
|
483
|
+
The set of trees to check. This object can have additional attributes
|
|
484
|
+
beyond the tree arrays, they are ignored.
|
|
574
485
|
max_split
|
|
575
486
|
The maximum split value for each variable.
|
|
576
487
|
|
|
577
488
|
Returns
|
|
578
489
|
-------
|
|
579
|
-
A
|
|
490
|
+
A tensor of error codes for each tree.
|
|
580
491
|
"""
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
492
|
+
# vectorize check_tree over all batch dimensions
|
|
493
|
+
unpack_check_tree = lambda l, v, s: check_tree(TreesTrace(l, v, s), max_split)
|
|
494
|
+
is_mv = trace.leaf_tree.ndim > trace.split_tree.ndim
|
|
495
|
+
signature = '(k,ts),(hts),(hts)->()' if is_mv else '(ts),(hts),(hts)->()'
|
|
496
|
+
vec_check_tree = jnp.vectorize(unpack_check_tree, signature=signature)
|
|
497
|
+
|
|
498
|
+
# automatically batch over all batch dimensions
|
|
499
|
+
max_io_nbytes = 2**24 # 16 MiB
|
|
500
|
+
batch_ndim = trace.split_tree.ndim - 1
|
|
501
|
+
batched_check_tree = vec_check_tree
|
|
502
|
+
for i in reversed(range(batch_ndim)):
|
|
503
|
+
batched_check_tree = autobatch(batched_check_tree, max_io_nbytes, i, i)
|
|
504
|
+
|
|
505
|
+
return batched_check_tree(trace.leaf_tree, trace.var_tree, trace.split_tree)
|
|
584
506
|
|
|
585
507
|
|
|
586
508
|
def _get_next_line(s: str, i: int) -> tuple[str, int]:
|
|
@@ -592,24 +514,19 @@ def _get_next_line(s: str, i: int) -> tuple[str, int]:
|
|
|
592
514
|
|
|
593
515
|
|
|
594
516
|
class BARTTraceMeta(Module):
|
|
595
|
-
"""Metadata of R BART tree traces.
|
|
596
|
-
|
|
597
|
-
Parameters
|
|
598
|
-
----------
|
|
599
|
-
ndpost
|
|
600
|
-
The number of posterior draws.
|
|
601
|
-
ntree
|
|
602
|
-
The number of trees in the model.
|
|
603
|
-
numcut
|
|
604
|
-
The maximum split value for each variable.
|
|
605
|
-
heap_size
|
|
606
|
-
The size of the heap required to store the trees.
|
|
607
|
-
"""
|
|
517
|
+
"""Metadata of R BART tree traces."""
|
|
608
518
|
|
|
609
519
|
ndpost: int = field(static=True)
|
|
520
|
+
"""The number of posterior draws."""
|
|
521
|
+
|
|
610
522
|
ntree: int = field(static=True)
|
|
523
|
+
"""The number of trees in the model."""
|
|
524
|
+
|
|
611
525
|
numcut: UInt[Array, ' p']
|
|
526
|
+
"""The maximum split value for each variable."""
|
|
527
|
+
|
|
612
528
|
heap_size: int = field(static=True)
|
|
529
|
+
"""The size of the heap required to store the trees."""
|
|
613
530
|
|
|
614
531
|
|
|
615
532
|
def scan_BART_trees(trees: str) -> BARTTraceMeta:
|
|
@@ -801,25 +718,24 @@ class SamplePriorStack(Module):
|
|
|
801
718
|
|
|
802
719
|
Each level of the stack represents a recursion into a child node in a
|
|
803
720
|
binary tree of maximum depth `d`.
|
|
804
|
-
|
|
805
|
-
Parameters
|
|
806
|
-
----------
|
|
807
|
-
nonterminal
|
|
808
|
-
Whether the node is valid or the recursion is into unused node slots.
|
|
809
|
-
lower
|
|
810
|
-
upper
|
|
811
|
-
The available cutpoints along ``var`` are in the integer range
|
|
812
|
-
``[1 + lower[var], 1 + upper[var])``.
|
|
813
|
-
var
|
|
814
|
-
split
|
|
815
|
-
The variable and cutpoint of a decision node.
|
|
816
721
|
"""
|
|
817
722
|
|
|
818
723
|
nonterminal: Bool[Array, ' d-1']
|
|
724
|
+
"""Whether the node is valid or the recursion is into unused node slots."""
|
|
725
|
+
|
|
819
726
|
lower: UInt[Array, 'd-1 p']
|
|
727
|
+
"""The available cutpoints along ``var`` are in the integer range
|
|
728
|
+
``[1 + lower[var], 1 + upper[var])``."""
|
|
729
|
+
|
|
820
730
|
upper: UInt[Array, 'd-1 p']
|
|
731
|
+
"""The available cutpoints along ``var`` are in the integer range
|
|
732
|
+
``[1 + lower[var], 1 + upper[var])``."""
|
|
733
|
+
|
|
821
734
|
var: UInt[Array, ' d-1']
|
|
735
|
+
"""The variable of a decision node."""
|
|
736
|
+
|
|
822
737
|
split: UInt[Array, ' d-1']
|
|
738
|
+
"""The cutpoint of a decision node."""
|
|
823
739
|
|
|
824
740
|
@classmethod
|
|
825
741
|
def initial(
|
|
@@ -850,19 +766,16 @@ class SamplePriorStack(Module):
|
|
|
850
766
|
|
|
851
767
|
|
|
852
768
|
class SamplePriorTrees(Module):
|
|
853
|
-
"""Object holding the trees generated by `sample_prior`.
|
|
854
|
-
|
|
855
|
-
Parameters
|
|
856
|
-
----------
|
|
857
|
-
leaf_tree
|
|
858
|
-
var_tree
|
|
859
|
-
split_tree
|
|
860
|
-
The arrays representing the trees, see `bartz.grove`.
|
|
861
|
-
"""
|
|
769
|
+
"""Object holding the trees generated by `sample_prior`."""
|
|
862
770
|
|
|
863
771
|
leaf_tree: Float32[Array, '* 2**d']
|
|
772
|
+
"""The array representing the trees, see `bartz.grove`."""
|
|
773
|
+
|
|
864
774
|
var_tree: UInt[Array, '* 2**(d-1)']
|
|
775
|
+
"""The array representing the trees, see `bartz.grove`."""
|
|
776
|
+
|
|
865
777
|
split_tree: UInt[Array, '* 2**(d-1)']
|
|
778
|
+
"""The array representing the trees, see `bartz.grove`."""
|
|
866
779
|
|
|
867
780
|
@classmethod
|
|
868
781
|
def initial(
|
|
@@ -903,21 +816,16 @@ class SamplePriorTrees(Module):
|
|
|
903
816
|
|
|
904
817
|
|
|
905
818
|
class SamplePriorCarry(Module):
|
|
906
|
-
"""Object holding values carried along the recursion in `sample_prior`.
|
|
907
|
-
|
|
908
|
-
Parameters
|
|
909
|
-
----------
|
|
910
|
-
key
|
|
911
|
-
A jax random key used to sample decision rules.
|
|
912
|
-
stack
|
|
913
|
-
The stack used to manage the recursion.
|
|
914
|
-
trees
|
|
915
|
-
The output arrays.
|
|
916
|
-
"""
|
|
819
|
+
"""Object holding values carried along the recursion in `sample_prior`."""
|
|
917
820
|
|
|
918
821
|
key: Key[Array, '']
|
|
822
|
+
"""A jax random key used to sample decision rules."""
|
|
823
|
+
|
|
919
824
|
stack: SamplePriorStack
|
|
825
|
+
"""The stack used to manage the recursion."""
|
|
826
|
+
|
|
920
827
|
trees: SamplePriorTrees
|
|
828
|
+
"""The output arrays."""
|
|
921
829
|
|
|
922
830
|
@classmethod
|
|
923
831
|
def initial(
|
|
@@ -958,21 +866,17 @@ class SamplePriorX(Module):
|
|
|
958
866
|
|
|
959
867
|
The sequence of nodes to visit is pre-computed recursively once, unrolling
|
|
960
868
|
the recursion schedule.
|
|
961
|
-
|
|
962
|
-
Parameters
|
|
963
|
-
----------
|
|
964
|
-
node
|
|
965
|
-
The heap index of the node to visit.
|
|
966
|
-
depth
|
|
967
|
-
The depth of the node.
|
|
968
|
-
next_depth
|
|
969
|
-
The depth of the next node to visit, either the left child or the right
|
|
970
|
-
sibling of the node or of an ancestor.
|
|
971
869
|
"""
|
|
972
870
|
|
|
973
871
|
node: Int32[Array, ' 2**(d-1)-1']
|
|
872
|
+
"""The heap index of the node to visit."""
|
|
873
|
+
|
|
974
874
|
depth: Int32[Array, ' 2**(d-1)-1']
|
|
875
|
+
"""The depth of the node."""
|
|
876
|
+
|
|
975
877
|
next_depth: Int32[Array, ' 2**(d-1)-1']
|
|
878
|
+
"""The depth of the next node to visit, either the left child or the right
|
|
879
|
+
sibling of the node or of an ancestor."""
|
|
976
880
|
|
|
977
881
|
@classmethod
|
|
978
882
|
def initial(cls, p_nonterminal: Float32[Array, ' d-1']) -> 'SamplePriorX':
|
|
@@ -992,7 +896,7 @@ class SamplePriorX(Module):
|
|
|
992
896
|
assert len(seq) == 2**p_nonterminal.size - 1
|
|
993
897
|
node = [node for node, depth in seq]
|
|
994
898
|
depth = [depth for node, depth in seq]
|
|
995
|
-
next_depth = depth[1:]
|
|
899
|
+
next_depth = [*depth[1:], p_nonterminal.size]
|
|
996
900
|
return cls(
|
|
997
901
|
node=jnp.array(node),
|
|
998
902
|
depth=jnp.array(depth),
|
|
@@ -1173,18 +1077,18 @@ def sample_prior(
|
|
|
1173
1077
|
return tree_map(lambda x: x.reshape(trace_length, num_trees, -1), trees)
|
|
1174
1078
|
|
|
1175
1079
|
|
|
1176
|
-
class
|
|
1177
|
-
"""A subclass of `
|
|
1080
|
+
class debug_mc_gbart(mc_gbart):
|
|
1081
|
+
"""A subclass of `mc_gbart` that adds debugging functionality.
|
|
1178
1082
|
|
|
1179
1083
|
Parameters
|
|
1180
1084
|
----------
|
|
1181
1085
|
*args
|
|
1182
|
-
Passed to `
|
|
1086
|
+
Passed to `mc_gbart`.
|
|
1183
1087
|
check_trees
|
|
1184
1088
|
If `True`, check all trees with `check_trace` after running the MCMC,
|
|
1185
1089
|
and assert that they are all valid. Set to `False` to allow jax tracing.
|
|
1186
1090
|
**kw
|
|
1187
|
-
Passed to `
|
|
1091
|
+
Passed to `mc_gbart`.
|
|
1188
1092
|
"""
|
|
1189
1093
|
|
|
1190
1094
|
def __init__(self, *args, check_trees: bool = True, **kw):
|
|
@@ -1194,24 +1098,28 @@ class debug_gbart(gbart):
|
|
|
1194
1098
|
bad_count = jnp.count_nonzero(bad)
|
|
1195
1099
|
assert bad_count == 0
|
|
1196
1100
|
|
|
1197
|
-
def
|
|
1101
|
+
def print_tree(
|
|
1102
|
+
self, i_chain: int, i_sample: int, i_tree: int, print_all: bool = False
|
|
1103
|
+
):
|
|
1198
1104
|
"""Print a single tree in human-readable format.
|
|
1199
1105
|
|
|
1200
1106
|
Parameters
|
|
1201
1107
|
----------
|
|
1108
|
+
i_chain
|
|
1109
|
+
The index of the MCMC chain.
|
|
1202
1110
|
i_sample
|
|
1203
|
-
The index of the
|
|
1111
|
+
The index of the (post-burnin) sample in the chain.
|
|
1204
1112
|
i_tree
|
|
1205
1113
|
The index of the tree in the sample.
|
|
1206
1114
|
print_all
|
|
1207
1115
|
If `True`, also print the content of unused node slots.
|
|
1208
1116
|
"""
|
|
1209
1117
|
tree = TreesTrace.from_dataclass(self._main_trace)
|
|
1210
|
-
tree = tree_map(lambda x: x[i_sample, i_tree, :], tree)
|
|
1118
|
+
tree = tree_map(lambda x: x[i_chain, i_sample, i_tree, :], tree)
|
|
1211
1119
|
s = format_tree(tree, print_all=print_all)
|
|
1212
1120
|
print(s) # noqa: T201, this method is intended for debug
|
|
1213
1121
|
|
|
1214
|
-
def sigma_harmonic_mean(self, prior: bool = False) -> Float32[Array, '']:
|
|
1122
|
+
def sigma_harmonic_mean(self, prior: bool = False) -> Float32[Array, ' mc_cores']:
|
|
1215
1123
|
"""Return the harmonic mean of the error variance.
|
|
1216
1124
|
|
|
1217
1125
|
Parameters
|
|
@@ -1225,33 +1133,36 @@ class debug_gbart(gbart):
|
|
|
1225
1133
|
The harmonic mean 1/E[1/sigma^2] in the selected distribution.
|
|
1226
1134
|
"""
|
|
1227
1135
|
bart = self._mcmc_state
|
|
1228
|
-
assert bart.
|
|
1136
|
+
assert bart.error_cov_df is not None
|
|
1229
1137
|
assert bart.z is None
|
|
1138
|
+
# inverse gamma prior: alpha = df / 2, beta = scale / 2
|
|
1230
1139
|
if prior:
|
|
1231
|
-
alpha = bart.
|
|
1232
|
-
beta = bart.
|
|
1140
|
+
alpha = bart.error_cov_df / 2
|
|
1141
|
+
beta = bart.error_cov_scale / 2
|
|
1233
1142
|
else:
|
|
1234
|
-
|
|
1235
|
-
|
|
1236
|
-
|
|
1237
|
-
|
|
1238
|
-
|
|
1239
|
-
|
|
1240
|
-
|
|
1241
|
-
|
|
1143
|
+
alpha = bart.error_cov_df / 2 + bart.resid.size / 2
|
|
1144
|
+
norm2 = jnp.einsum('ij,ij->i', bart.resid, bart.resid)
|
|
1145
|
+
beta = bart.error_cov_scale / 2 + norm2 / 2
|
|
1146
|
+
error_cov_inv = alpha / beta
|
|
1147
|
+
return jnp.sqrt(lax.reciprocal(error_cov_inv))
|
|
1148
|
+
|
|
1149
|
+
def compare_resid(
|
|
1150
|
+
self,
|
|
1151
|
+
) -> tuple[Float32[Array, 'mc_cores n'], Float32[Array, 'mc_cores n']]:
|
|
1242
1152
|
"""Re-compute residuals to compare them with the updated ones.
|
|
1243
1153
|
|
|
1244
1154
|
Returns
|
|
1245
1155
|
-------
|
|
1246
|
-
resid1 : Float32[Array, 'n']
|
|
1156
|
+
resid1 : Float32[Array, 'mc_cores n']
|
|
1247
1157
|
The final state of the residuals updated during the MCMC.
|
|
1248
|
-
resid2 : Float32[Array, 'n']
|
|
1158
|
+
resid2 : Float32[Array, 'mc_cores n']
|
|
1249
1159
|
The residuals computed from the final state of the trees.
|
|
1250
1160
|
"""
|
|
1251
1161
|
bart = self._mcmc_state
|
|
1252
1162
|
resid1 = bart.resid
|
|
1253
1163
|
|
|
1254
|
-
|
|
1164
|
+
forests = TreesTrace.from_dataclass(bart.forest)
|
|
1165
|
+
trees = evaluate_forest(bart.X, forests, sum_batch_axis=-1)
|
|
1255
1166
|
|
|
1256
1167
|
if bart.z is not None:
|
|
1257
1168
|
ref = bart.z
|
|
@@ -1261,14 +1172,16 @@ class debug_gbart(gbart):
|
|
|
1261
1172
|
|
|
1262
1173
|
return resid1, resid2
|
|
1263
1174
|
|
|
1264
|
-
def avg_acc(
|
|
1175
|
+
def avg_acc(
|
|
1176
|
+
self,
|
|
1177
|
+
) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]:
|
|
1265
1178
|
"""Compute the average acceptance rates of tree moves.
|
|
1266
1179
|
|
|
1267
1180
|
Returns
|
|
1268
1181
|
-------
|
|
1269
|
-
acc_grow : Float32[Array, '']
|
|
1182
|
+
acc_grow : Float32[Array, 'mc_cores']
|
|
1270
1183
|
The average acceptance rate of grow moves.
|
|
1271
|
-
acc_prune : Float32[Array, '']
|
|
1184
|
+
acc_prune : Float32[Array, 'mc_cores']
|
|
1272
1185
|
The average acceptance rate of prune moves.
|
|
1273
1186
|
"""
|
|
1274
1187
|
trace = self._main_trace
|
|
@@ -1276,18 +1189,20 @@ class debug_gbart(gbart):
|
|
|
1276
1189
|
def acc(prefix):
|
|
1277
1190
|
acc = getattr(trace, f'{prefix}_acc_count')
|
|
1278
1191
|
prop = getattr(trace, f'{prefix}_prop_count')
|
|
1279
|
-
return acc.sum() / prop.sum()
|
|
1192
|
+
return acc.sum(axis=1) / prop.sum(axis=1)
|
|
1280
1193
|
|
|
1281
1194
|
return acc('grow'), acc('prune')
|
|
1282
1195
|
|
|
1283
|
-
def avg_prop(
|
|
1196
|
+
def avg_prop(
|
|
1197
|
+
self,
|
|
1198
|
+
) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]:
|
|
1284
1199
|
"""Compute the average proposal rate of grow and prune moves.
|
|
1285
1200
|
|
|
1286
1201
|
Returns
|
|
1287
1202
|
-------
|
|
1288
|
-
prop_grow : Float32[Array, '']
|
|
1203
|
+
prop_grow : Float32[Array, 'mc_cores']
|
|
1289
1204
|
The fraction of times grow was proposed instead of prune.
|
|
1290
|
-
prop_prune : Float32[Array, '']
|
|
1205
|
+
prop_prune : Float32[Array, 'mc_cores']
|
|
1291
1206
|
The fraction of times prune was proposed instead of grow.
|
|
1292
1207
|
|
|
1293
1208
|
Notes
|
|
@@ -1298,61 +1213,86 @@ class debug_gbart(gbart):
|
|
|
1298
1213
|
trace = self._main_trace
|
|
1299
1214
|
|
|
1300
1215
|
def prop(prefix):
|
|
1301
|
-
return getattr(trace, f'{prefix}_prop_count').sum()
|
|
1216
|
+
return getattr(trace, f'{prefix}_prop_count').sum(axis=1)
|
|
1302
1217
|
|
|
1303
1218
|
pgrow = prop('grow')
|
|
1304
1219
|
pprune = prop('prune')
|
|
1305
1220
|
total = pgrow + pprune
|
|
1306
1221
|
return pgrow / total, pprune / total
|
|
1307
1222
|
|
|
1308
|
-
def avg_move(
|
|
1223
|
+
def avg_move(
|
|
1224
|
+
self,
|
|
1225
|
+
) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]:
|
|
1309
1226
|
"""Compute the move rate.
|
|
1310
1227
|
|
|
1311
1228
|
Returns
|
|
1312
1229
|
-------
|
|
1313
|
-
rate_grow : Float32[Array, '']
|
|
1230
|
+
rate_grow : Float32[Array, 'mc_cores']
|
|
1314
1231
|
The fraction of times a grow move was proposed and accepted.
|
|
1315
|
-
rate_prune : Float32[Array, '']
|
|
1232
|
+
rate_prune : Float32[Array, 'mc_cores']
|
|
1316
1233
|
The fraction of times a prune move was proposed and accepted.
|
|
1317
1234
|
"""
|
|
1318
1235
|
agrow, aprune = self.avg_acc()
|
|
1319
1236
|
pgrow, pprune = self.avg_prop()
|
|
1320
1237
|
return agrow * pgrow, aprune * pprune
|
|
1321
1238
|
|
|
1322
|
-
def depth_distr(self) ->
|
|
1239
|
+
def depth_distr(self) -> Int32[Array, 'mc_cores ndpost/mc_cores d']:
|
|
1323
1240
|
"""Histogram of tree depths for each state of the trees.
|
|
1324
1241
|
|
|
1325
1242
|
Returns
|
|
1326
1243
|
-------
|
|
1327
1244
|
For each chain, a matrix where each row contains a histogram of tree depths.
|
|
1328
1245
|
"""
|
|
1329
|
-
|
|
1246
|
+
out: Int32[Array, '*chains samples d']
|
|
1247
|
+
out = forest_depth_distr(self._main_trace.split_tree)
|
|
1248
|
+
if out.ndim < 3:
|
|
1249
|
+
out = out[None, :, :]
|
|
1250
|
+
return out
|
|
1251
|
+
|
|
1252
|
+
def _points_per_node_distr(
|
|
1253
|
+
self, node_type: str
|
|
1254
|
+
) -> Int32[Array, 'mc_cores ndpost/mc_cores n+1']:
|
|
1255
|
+
out: Int32[Array, '*chains samples n+1']
|
|
1256
|
+
out = points_per_node_distr(
|
|
1257
|
+
self._mcmc_state.X,
|
|
1258
|
+
self._main_trace.var_tree,
|
|
1259
|
+
self._main_trace.split_tree,
|
|
1260
|
+
node_type,
|
|
1261
|
+
sum_batch_axis=-1,
|
|
1262
|
+
)
|
|
1263
|
+
if out.ndim < 3:
|
|
1264
|
+
out = out[None, :, :]
|
|
1265
|
+
return out
|
|
1330
1266
|
|
|
1331
|
-
def points_per_decision_node_distr(
|
|
1267
|
+
def points_per_decision_node_distr(
|
|
1268
|
+
self,
|
|
1269
|
+
) -> Int32[Array, 'mc_cores ndpost/mc_cores n+1']:
|
|
1332
1270
|
"""Histogram of number of points belonging to parent-of-leaf nodes.
|
|
1333
1271
|
|
|
1334
1272
|
Returns
|
|
1335
1273
|
-------
|
|
1336
|
-
|
|
1274
|
+
For each chain, a matrix where each row contains a histogram of number of points.
|
|
1337
1275
|
"""
|
|
1338
|
-
return
|
|
1339
|
-
self._main_trace, self._mcmc_state.X
|
|
1340
|
-
)
|
|
1276
|
+
return self._points_per_node_distr('leaf-parent')
|
|
1341
1277
|
|
|
1342
|
-
def points_per_leaf_distr(self) ->
|
|
1278
|
+
def points_per_leaf_distr(self) -> Int32[Array, 'mc_cores ndpost/mc_cores n+1']:
|
|
1343
1279
|
"""Histogram of number of points belonging to leaves.
|
|
1344
1280
|
|
|
1345
1281
|
Returns
|
|
1346
1282
|
-------
|
|
1347
1283
|
For each chain, a matrix where each row contains a histogram of number of points.
|
|
1348
1284
|
"""
|
|
1349
|
-
return
|
|
1285
|
+
return self._points_per_node_distr('leaf')
|
|
1350
1286
|
|
|
1351
|
-
def check_trees(self) -> UInt[Array, '
|
|
1287
|
+
def check_trees(self) -> UInt[Array, 'mc_cores ndpost/mc_cores ntree']:
|
|
1352
1288
|
"""Apply `check_trace` to all the tree draws."""
|
|
1353
|
-
|
|
1289
|
+
out: UInt[Array, '*chains samples num_trees']
|
|
1290
|
+
out = check_trace(self._main_trace, self._mcmc_state.forest.max_split)
|
|
1291
|
+
if out.ndim < 3:
|
|
1292
|
+
out = out[None, :, :]
|
|
1293
|
+
return out
|
|
1354
1294
|
|
|
1355
|
-
def tree_goes_bad(self) -> Bool[Array, '
|
|
1295
|
+
def tree_goes_bad(self) -> Bool[Array, 'mc_cores ndpost/mc_cores ntree']:
|
|
1356
1296
|
"""Find iterations where a tree becomes invalid.
|
|
1357
1297
|
|
|
1358
1298
|
Returns
|
|
@@ -1360,5 +1300,20 @@ class debug_gbart(gbart):
|
|
|
1360
1300
|
For each chain, a matrix where (i,j) is `True` if tree j is invalid at iteration i but not i-1.
|
|
1361
1301
|
"""
|
|
1362
1302
|
bad = self.check_trees().astype(bool)
|
|
1363
|
-
bad_before = jnp.pad(bad[:-1], [(1, 0), (0, 0)])
|
|
1303
|
+
bad_before = jnp.pad(bad[:, :-1, :], [(0, 0), (1, 0), (0, 0)])
|
|
1364
1304
|
return bad & ~bad_before
|
|
1305
|
+
|
|
1306
|
+
|
|
1307
|
+
class debug_gbart(debug_mc_gbart, gbart):
|
|
1308
|
+
"""A subclass of `gbart` that adds debugging functionality.
|
|
1309
|
+
|
|
1310
|
+
Parameters
|
|
1311
|
+
----------
|
|
1312
|
+
*args
|
|
1313
|
+
Passed to `gbart`.
|
|
1314
|
+
check_trees
|
|
1315
|
+
If `True`, check all trees with `check_trace` after running the MCMC,
|
|
1316
|
+
and assert that they are all valid. Set to `False` to allow jax tracing.
|
|
1317
|
+
**kw
|
|
1318
|
+
Passed to `gbart`.
|
|
1319
|
+
"""
|