bartz 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bartz/.DS_Store +0 -0
- bartz/BART.py +99 -39
- bartz/__init__.py +6 -14
- bartz/_version.py +1 -1
- bartz/debug.py +42 -16
- bartz/grove.py +20 -11
- bartz/jaxext.py +44 -38
- bartz/mcmcloop.py +119 -58
- bartz/mcmcstep.py +426 -173
- bartz/prepcovars.py +22 -9
- bartz-0.5.0.dist-info/METADATA +48 -0
- bartz-0.5.0.dist-info/RECORD +13 -0
- bartz-0.5.0.dist-info/WHEEL +4 -0
- bartz-0.4.0.dist-info/LICENSE +0 -21
- bartz-0.4.0.dist-info/METADATA +0 -77
- bartz-0.4.0.dist-info/RECORD +0 -13
- bartz-0.4.0.dist-info/WHEEL +0 -4
bartz/mcmcstep.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
# bartz/src/bartz/mcmcstep.py
|
|
2
2
|
#
|
|
3
|
-
# Copyright (c) 2024, Giacomo Petrillo
|
|
3
|
+
# Copyright (c) 2024-2025, Giacomo Petrillo
|
|
4
4
|
#
|
|
5
5
|
# This file is part of bartz.
|
|
6
6
|
#
|
|
@@ -37,14 +37,14 @@ import functools
|
|
|
37
37
|
import math
|
|
38
38
|
|
|
39
39
|
import jax
|
|
40
|
-
from jax import random
|
|
40
|
+
from jax import lax, random
|
|
41
41
|
from jax import numpy as jnp
|
|
42
|
-
from jax import lax
|
|
43
42
|
|
|
44
|
-
from . import jaxext
|
|
45
|
-
from . import grove
|
|
43
|
+
from . import grove, jaxext
|
|
46
44
|
|
|
47
|
-
|
|
45
|
+
|
|
46
|
+
def init(
|
|
47
|
+
*,
|
|
48
48
|
X,
|
|
49
49
|
y,
|
|
50
50
|
max_split,
|
|
@@ -52,13 +52,14 @@ def init(*,
|
|
|
52
52
|
p_nonterminal,
|
|
53
53
|
sigma2_alpha,
|
|
54
54
|
sigma2_beta,
|
|
55
|
+
error_scale=None,
|
|
55
56
|
small_float=jnp.float32,
|
|
56
57
|
large_float=jnp.float32,
|
|
57
58
|
min_points_per_leaf=None,
|
|
58
59
|
resid_batch_size='auto',
|
|
59
60
|
count_batch_size='auto',
|
|
60
61
|
save_ratios=False,
|
|
61
|
-
|
|
62
|
+
):
|
|
62
63
|
"""
|
|
63
64
|
Make a BART posterior sampling MCMC initial state.
|
|
64
65
|
|
|
@@ -76,9 +77,12 @@ def init(*,
|
|
|
76
77
|
The probability of a nonterminal node at each depth. The maximum depth
|
|
77
78
|
of trees is fixed by the length of this array.
|
|
78
79
|
sigma2_alpha : float
|
|
79
|
-
The shape parameter of the inverse gamma prior on the
|
|
80
|
+
The shape parameter of the inverse gamma prior on the error variance.
|
|
80
81
|
sigma2_beta : float
|
|
81
|
-
The scale parameter of the inverse gamma prior on the
|
|
82
|
+
The scale parameter of the inverse gamma prior on the error variance.
|
|
83
|
+
error_scale : float array (n,), optional
|
|
84
|
+
Each error is scaled by the corresponding factor in `error_scale`, so
|
|
85
|
+
the error variance for ``y[i]`` is ``sigma2 * error_scale[i] ** 2``.
|
|
82
86
|
small_float : dtype, default float32
|
|
83
87
|
The dtype for large arrays used in the algorithm.
|
|
84
88
|
large_float : dtype, default float32
|
|
@@ -110,6 +114,8 @@ def init(*,
|
|
|
110
114
|
roundoff.
|
|
111
115
|
'sigma2' : large_float
|
|
112
116
|
The noise variance.
|
|
117
|
+
'prec_scale' : large_float array (n,) or None
|
|
118
|
+
The scale on the error precision, i.e., ``1 / error_scale ** 2``.
|
|
113
119
|
'grow_prop_count', 'prune_prop_count' : int
|
|
114
120
|
The number of grow/prune proposals made during one full MCMC cycle.
|
|
115
121
|
'grow_acc_count', 'prune_acc_count' : int
|
|
@@ -169,16 +175,27 @@ def init(*,
|
|
|
169
175
|
small_float = jnp.dtype(small_float)
|
|
170
176
|
large_float = jnp.dtype(large_float)
|
|
171
177
|
y = jnp.asarray(y, small_float)
|
|
172
|
-
resid_batch_size, count_batch_size = _choose_suffstat_batch_size(
|
|
178
|
+
resid_batch_size, count_batch_size = _choose_suffstat_batch_size(
|
|
179
|
+
resid_batch_size, count_batch_size, y, 2**max_depth * num_trees
|
|
180
|
+
)
|
|
173
181
|
sigma2 = jnp.array(sigma2_beta / sigma2_alpha, large_float)
|
|
174
|
-
sigma2 = jnp.where(
|
|
182
|
+
sigma2 = jnp.where(
|
|
183
|
+
jnp.isfinite(sigma2) & (sigma2 > 0), sigma2, 1
|
|
184
|
+
) # TODO: I don't like this error check, these functions should be low-level and just do the thing. Why is it here?
|
|
175
185
|
|
|
176
186
|
bart = dict(
|
|
177
187
|
leaf_trees=make_forest(max_depth, small_float),
|
|
178
|
-
var_trees=make_forest(
|
|
188
|
+
var_trees=make_forest(
|
|
189
|
+
max_depth - 1, jaxext.minimal_unsigned_dtype(X.shape[0] - 1)
|
|
190
|
+
),
|
|
179
191
|
split_trees=make_forest(max_depth - 1, max_split.dtype),
|
|
180
192
|
resid=jnp.asarray(y, large_float),
|
|
181
193
|
sigma2=sigma2,
|
|
194
|
+
prec_scale=(
|
|
195
|
+
None
|
|
196
|
+
if error_scale is None
|
|
197
|
+
else lax.reciprocal(jnp.square(jnp.asarray(error_scale, large_float)))
|
|
198
|
+
),
|
|
182
199
|
grow_prop_count=jnp.zeros((), int),
|
|
183
200
|
grow_acc_count=jnp.zeros((), int),
|
|
184
201
|
prune_prop_count=jnp.zeros((), int),
|
|
@@ -190,14 +207,18 @@ def init(*,
|
|
|
190
207
|
max_split=jnp.asarray(max_split),
|
|
191
208
|
y=y,
|
|
192
209
|
X=jnp.asarray(X),
|
|
193
|
-
leaf_indices=jnp.ones(
|
|
210
|
+
leaf_indices=jnp.ones(
|
|
211
|
+
(num_trees, y.size), jaxext.minimal_unsigned_dtype(2**max_depth - 1)
|
|
212
|
+
),
|
|
194
213
|
min_points_per_leaf=(
|
|
195
|
-
None if min_points_per_leaf is None else
|
|
196
|
-
jnp.asarray(min_points_per_leaf)
|
|
214
|
+
None if min_points_per_leaf is None else jnp.asarray(min_points_per_leaf)
|
|
197
215
|
),
|
|
198
216
|
affluence_trees=(
|
|
199
|
-
None
|
|
200
|
-
|
|
217
|
+
None
|
|
218
|
+
if min_points_per_leaf is None
|
|
219
|
+
else make_forest(max_depth - 1, bool)
|
|
220
|
+
.at[:, 1]
|
|
221
|
+
.set(y.size >= 2 * min_points_per_leaf)
|
|
201
222
|
),
|
|
202
223
|
opt=jaxext.LeafDict(
|
|
203
224
|
small_float=small_float,
|
|
@@ -216,8 +237,8 @@ def init(*,
|
|
|
216
237
|
|
|
217
238
|
return bart
|
|
218
239
|
|
|
219
|
-
def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, forest_size):
|
|
220
240
|
|
|
241
|
+
def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, forest_size):
|
|
221
242
|
@functools.cache
|
|
222
243
|
def get_platform():
|
|
223
244
|
try:
|
|
@@ -233,9 +254,9 @@ def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, forest_si
|
|
|
233
254
|
platform = get_platform()
|
|
234
255
|
n = max(1, y.size)
|
|
235
256
|
if platform == 'cpu':
|
|
236
|
-
resid_batch_size = 2 ** int(round(math.log2(n / 6)))
|
|
257
|
+
resid_batch_size = 2 ** int(round(math.log2(n / 6))) # n/6
|
|
237
258
|
elif platform == 'gpu':
|
|
238
|
-
resid_batch_size = 2 ** int(round((1 + math.log2(n)) / 3))
|
|
259
|
+
resid_batch_size = 2 ** int(round((1 + math.log2(n)) / 3)) # n^1/3
|
|
239
260
|
resid_batch_size = max(1, resid_batch_size)
|
|
240
261
|
|
|
241
262
|
if count_batch_size == 'auto':
|
|
@@ -244,9 +265,9 @@ def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, forest_si
|
|
|
244
265
|
count_batch_size = None
|
|
245
266
|
elif platform == 'gpu':
|
|
246
267
|
n = max(1, y.size)
|
|
247
|
-
count_batch_size = 2 ** int(round(math.log2(n) / 2 - 2))
|
|
248
|
-
|
|
249
|
-
max_memory = 2
|
|
268
|
+
count_batch_size = 2 ** int(round(math.log2(n) / 2 - 2)) # n^1/2
|
|
269
|
+
# /4 is good on V100, /2 on L4/T4, still haven't tried A100
|
|
270
|
+
max_memory = 2**29
|
|
250
271
|
itemsize = 4
|
|
251
272
|
min_batch_size = int(math.ceil(forest_size * itemsize * n / max_memory))
|
|
252
273
|
count_batch_size = max(count_batch_size, min_batch_size)
|
|
@@ -254,16 +275,17 @@ def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, forest_si
|
|
|
254
275
|
|
|
255
276
|
return resid_batch_size, count_batch_size
|
|
256
277
|
|
|
257
|
-
|
|
278
|
+
|
|
279
|
+
def step(key, bart):
|
|
258
280
|
"""
|
|
259
281
|
Perform one full MCMC step on a BART state.
|
|
260
282
|
|
|
261
283
|
Parameters
|
|
262
284
|
----------
|
|
263
|
-
bart : dict
|
|
264
|
-
A BART mcmc state, as created by `init`.
|
|
265
285
|
key : jax.dtypes.prng_key array
|
|
266
286
|
A jax random key.
|
|
287
|
+
bart : dict
|
|
288
|
+
A BART mcmc state, as created by `init`.
|
|
267
289
|
|
|
268
290
|
Returns
|
|
269
291
|
-------
|
|
@@ -271,19 +293,20 @@ def step(bart, key):
|
|
|
271
293
|
The new BART mcmc state.
|
|
272
294
|
"""
|
|
273
295
|
key, subkey = random.split(key)
|
|
274
|
-
bart = sample_trees(
|
|
275
|
-
return sample_sigma(
|
|
296
|
+
bart = sample_trees(subkey, bart)
|
|
297
|
+
return sample_sigma(key, bart)
|
|
298
|
+
|
|
276
299
|
|
|
277
|
-
def sample_trees(
|
|
300
|
+
def sample_trees(key, bart):
|
|
278
301
|
"""
|
|
279
302
|
Forest sampling step of BART MCMC.
|
|
280
303
|
|
|
281
304
|
Parameters
|
|
282
305
|
----------
|
|
283
|
-
bart : dict
|
|
284
|
-
A BART mcmc state, as created by `init`.
|
|
285
306
|
key : jax.dtypes.prng_key array
|
|
286
307
|
A jax random key.
|
|
308
|
+
bart : dict
|
|
309
|
+
A BART mcmc state, as created by `init`.
|
|
287
310
|
|
|
288
311
|
Returns
|
|
289
312
|
-------
|
|
@@ -295,19 +318,20 @@ def sample_trees(bart, key):
|
|
|
295
318
|
This function zeroes the proposal counters.
|
|
296
319
|
"""
|
|
297
320
|
key, subkey = random.split(key)
|
|
298
|
-
moves = sample_moves(
|
|
299
|
-
return accept_moves_and_sample_leaves(bart, moves
|
|
321
|
+
moves = sample_moves(subkey, bart)
|
|
322
|
+
return accept_moves_and_sample_leaves(key, bart, moves)
|
|
300
323
|
|
|
301
|
-
|
|
324
|
+
|
|
325
|
+
def sample_moves(key, bart):
|
|
302
326
|
"""
|
|
303
327
|
Propose moves for all the trees.
|
|
304
328
|
|
|
305
329
|
Parameters
|
|
306
330
|
----------
|
|
307
|
-
bart : dict
|
|
308
|
-
BART mcmc state.
|
|
309
331
|
key : jax.dtypes.prng_key array
|
|
310
332
|
A jax random key.
|
|
333
|
+
bart : dict
|
|
334
|
+
BART mcmc state.
|
|
311
335
|
|
|
312
336
|
Returns
|
|
313
337
|
-------
|
|
@@ -343,14 +367,22 @@ def sample_moves(bart, key):
|
|
|
343
367
|
key, subkey = key[0], key[1:]
|
|
344
368
|
|
|
345
369
|
# compute moves
|
|
346
|
-
grow_moves, prune_moves = _sample_moves_vmap_trees(
|
|
370
|
+
grow_moves, prune_moves = _sample_moves_vmap_trees(
|
|
371
|
+
subkey,
|
|
372
|
+
bart['var_trees'],
|
|
373
|
+
bart['split_trees'],
|
|
374
|
+
bart['affluence_trees'],
|
|
375
|
+
bart['max_split'],
|
|
376
|
+
bart['p_nonterminal'],
|
|
377
|
+
bart['p_propose_grow'],
|
|
378
|
+
)
|
|
347
379
|
|
|
348
380
|
u, logu = random.uniform(key, (2, ntree), bart['opt']['large_float'])
|
|
349
381
|
|
|
350
382
|
# choose between grow or prune
|
|
351
383
|
grow_allowed = grow_moves['num_growable'].astype(bool)
|
|
352
384
|
p_grow = jnp.where(grow_allowed & prune_moves['allowed'], 0.5, grow_allowed)
|
|
353
|
-
grow = u < p_grow
|
|
385
|
+
grow = u < p_grow # use < instead of <= because u is in [0, 1)
|
|
354
386
|
|
|
355
387
|
# compute children indices
|
|
356
388
|
node = jnp.where(grow, grow_moves['node'], prune_moves['node'])
|
|
@@ -364,22 +396,28 @@ def sample_moves(bart, key):
|
|
|
364
396
|
node=node,
|
|
365
397
|
left=left,
|
|
366
398
|
right=right,
|
|
367
|
-
partial_ratio=jnp.where(
|
|
399
|
+
partial_ratio=jnp.where(
|
|
400
|
+
grow, grow_moves['partial_ratio'], prune_moves['partial_ratio']
|
|
401
|
+
),
|
|
368
402
|
grow_var=grow_moves['var'],
|
|
369
403
|
grow_split=grow_moves['split'],
|
|
370
404
|
var_trees=grow_moves['var_tree'],
|
|
371
405
|
logu=jnp.log1p(-logu),
|
|
372
406
|
)
|
|
373
407
|
|
|
374
|
-
|
|
408
|
+
|
|
409
|
+
@functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, 0, 0, None, None, None))
|
|
375
410
|
def _sample_moves_vmap_trees(*args):
|
|
376
|
-
|
|
411
|
+
key, args = args[0], args[1:]
|
|
377
412
|
key, key1 = random.split(key)
|
|
378
|
-
grow = grow_move(*args
|
|
379
|
-
prune = prune_move(*args
|
|
413
|
+
grow = grow_move(key, *args)
|
|
414
|
+
prune = prune_move(key1, *args)
|
|
380
415
|
return grow, prune
|
|
381
416
|
|
|
382
|
-
|
|
417
|
+
|
|
418
|
+
def grow_move(
|
|
419
|
+
key, var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow
|
|
420
|
+
):
|
|
383
421
|
"""
|
|
384
422
|
Tree structure grow move proposal of BART MCMC.
|
|
385
423
|
|
|
@@ -426,14 +464,18 @@ def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_
|
|
|
426
464
|
|
|
427
465
|
key, key1, key2 = random.split(key, 3)
|
|
428
466
|
|
|
429
|
-
leaf_to_grow, num_growable, prob_choose, num_prunable = choose_leaf(
|
|
467
|
+
leaf_to_grow, num_growable, prob_choose, num_prunable = choose_leaf(
|
|
468
|
+
key, split_tree, affluence_tree, p_propose_grow
|
|
469
|
+
)
|
|
430
470
|
|
|
431
|
-
var = choose_variable(var_tree, split_tree, max_split, leaf_to_grow
|
|
471
|
+
var = choose_variable(key1, var_tree, split_tree, max_split, leaf_to_grow)
|
|
432
472
|
var_tree = var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype))
|
|
433
473
|
|
|
434
|
-
split = choose_split(var_tree, split_tree, max_split, leaf_to_grow
|
|
474
|
+
split = choose_split(key2, var_tree, split_tree, max_split, leaf_to_grow)
|
|
435
475
|
|
|
436
|
-
ratio = compute_partial_ratio(
|
|
476
|
+
ratio = compute_partial_ratio(
|
|
477
|
+
prob_choose, num_prunable, p_nonterminal, leaf_to_grow
|
|
478
|
+
)
|
|
437
479
|
|
|
438
480
|
return dict(
|
|
439
481
|
num_growable=num_growable,
|
|
@@ -444,7 +486,8 @@ def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_
|
|
|
444
486
|
var_tree=var_tree,
|
|
445
487
|
)
|
|
446
488
|
|
|
447
|
-
|
|
489
|
+
|
|
490
|
+
def choose_leaf(key, split_tree, affluence_tree, p_propose_grow):
|
|
448
491
|
"""
|
|
449
492
|
Choose a leaf node to grow in a tree.
|
|
450
493
|
|
|
@@ -482,6 +525,7 @@ def choose_leaf(split_tree, affluence_tree, p_propose_grow, key):
|
|
|
482
525
|
num_prunable = jnp.count_nonzero(is_parent)
|
|
483
526
|
return leaf_to_grow, num_growable, prob_choose, num_prunable
|
|
484
527
|
|
|
528
|
+
|
|
485
529
|
def growable_leaves(split_tree, affluence_tree):
|
|
486
530
|
"""
|
|
487
531
|
Return a mask indicating the leaf nodes that can be proposed for growth.
|
|
@@ -505,6 +549,7 @@ def growable_leaves(split_tree, affluence_tree):
|
|
|
505
549
|
is_growable &= affluence_tree
|
|
506
550
|
return is_growable
|
|
507
551
|
|
|
552
|
+
|
|
508
553
|
def categorical(key, distr):
|
|
509
554
|
"""
|
|
510
555
|
Return a random integer from an arbitrary distribution.
|
|
@@ -521,12 +566,15 @@ def categorical(key, distr):
|
|
|
521
566
|
u : int
|
|
522
567
|
A random integer in the range ``[0, n)``. If all probabilities are zero,
|
|
523
568
|
return ``n``.
|
|
569
|
+
norm : float
|
|
570
|
+
The sum of `distr`.
|
|
524
571
|
"""
|
|
525
572
|
ecdf = jnp.cumsum(distr)
|
|
526
573
|
u = random.uniform(key, (), ecdf.dtype, 0, ecdf[-1])
|
|
527
574
|
return jnp.searchsorted(ecdf, u, 'right'), ecdf[-1]
|
|
528
575
|
|
|
529
|
-
|
|
576
|
+
|
|
577
|
+
def choose_variable(key, var_tree, split_tree, max_split, leaf_index):
|
|
530
578
|
"""
|
|
531
579
|
Choose a variable to split on for a new non-terminal node.
|
|
532
580
|
|
|
@@ -556,6 +604,7 @@ def choose_variable(var_tree, split_tree, max_split, leaf_index, key):
|
|
|
556
604
|
var_to_ignore = fully_used_variables(var_tree, split_tree, max_split, leaf_index)
|
|
557
605
|
return randint_exclude(key, max_split.size, var_to_ignore)
|
|
558
606
|
|
|
607
|
+
|
|
559
608
|
def fully_used_variables(var_tree, split_tree, max_split, leaf_index):
|
|
560
609
|
"""
|
|
561
610
|
Return a list of variables that have an empty split range at a given node.
|
|
@@ -586,6 +635,7 @@ def fully_used_variables(var_tree, split_tree, max_split, leaf_index):
|
|
|
586
635
|
num_split = r - l
|
|
587
636
|
return jnp.where(num_split == 0, var_to_ignore, max_split.size)
|
|
588
637
|
|
|
638
|
+
|
|
589
639
|
def ancestor_variables(var_tree, max_split, node_index):
|
|
590
640
|
"""
|
|
591
641
|
Return the list of variables in the ancestors of a node.
|
|
@@ -606,8 +656,11 @@ def ancestor_variables(var_tree, max_split, node_index):
|
|
|
606
656
|
the parent. Unused spots are filled with `p`.
|
|
607
657
|
"""
|
|
608
658
|
max_num_ancestors = grove.tree_depth(var_tree) - 1
|
|
609
|
-
ancestor_vars = jnp.zeros(
|
|
659
|
+
ancestor_vars = jnp.zeros(
|
|
660
|
+
max_num_ancestors, jaxext.minimal_unsigned_dtype(max_split.size)
|
|
661
|
+
)
|
|
610
662
|
carry = ancestor_vars.size - 1, node_index, ancestor_vars
|
|
663
|
+
|
|
611
664
|
def loop(carry, _):
|
|
612
665
|
i, index, ancestor_vars = carry
|
|
613
666
|
index >>= 1
|
|
@@ -615,9 +668,11 @@ def ancestor_variables(var_tree, max_split, node_index):
|
|
|
615
668
|
var = jnp.where(index, var, max_split.size)
|
|
616
669
|
ancestor_vars = ancestor_vars.at[i].set(var)
|
|
617
670
|
return (i - 1, index, ancestor_vars), None
|
|
671
|
+
|
|
618
672
|
(_, _, ancestor_vars), _ = lax.scan(loop, carry, None, ancestor_vars.size)
|
|
619
673
|
return ancestor_vars
|
|
620
674
|
|
|
675
|
+
|
|
621
676
|
def split_range(var_tree, split_tree, max_split, node_index, ref_var):
|
|
622
677
|
"""
|
|
623
678
|
Return the range of allowed splits for a variable at a given node.
|
|
@@ -641,8 +696,11 @@ def split_range(var_tree, split_tree, max_split, node_index, ref_var):
|
|
|
641
696
|
The range of allowed splits is [l, r).
|
|
642
697
|
"""
|
|
643
698
|
max_num_ancestors = grove.tree_depth(var_tree) - 1
|
|
644
|
-
initial_r = 1 + max_split.at[ref_var].get(mode='fill', fill_value=0).astype(
|
|
699
|
+
initial_r = 1 + max_split.at[ref_var].get(mode='fill', fill_value=0).astype(
|
|
700
|
+
jnp.int32
|
|
701
|
+
)
|
|
645
702
|
carry = 0, initial_r, node_index
|
|
703
|
+
|
|
646
704
|
def loop(carry, _):
|
|
647
705
|
l, r, index = carry
|
|
648
706
|
right_child = (index & 1).astype(bool)
|
|
@@ -652,9 +710,11 @@ def split_range(var_tree, split_tree, max_split, node_index, ref_var):
|
|
|
652
710
|
l = jnp.where(cond & right_child, jnp.maximum(l, split), l)
|
|
653
711
|
r = jnp.where(cond & ~right_child, jnp.minimum(r, split), r)
|
|
654
712
|
return (l, r, index), None
|
|
713
|
+
|
|
655
714
|
(l, r, _), _ = lax.scan(loop, carry, None, max_num_ancestors)
|
|
656
715
|
return l + 1, r
|
|
657
716
|
|
|
717
|
+
|
|
658
718
|
def randint_exclude(key, sup, exclude):
|
|
659
719
|
"""
|
|
660
720
|
Return a random integer in a range, excluding some values.
|
|
@@ -679,12 +739,15 @@ def randint_exclude(key, sup, exclude):
|
|
|
679
739
|
exclude = jnp.unique(exclude, size=exclude.size, fill_value=sup)
|
|
680
740
|
num_allowed = sup - jnp.count_nonzero(exclude < sup)
|
|
681
741
|
u = random.randint(key, (), 0, num_allowed)
|
|
742
|
+
|
|
682
743
|
def loop(u, i):
|
|
683
744
|
return jnp.where(i <= u, u + 1, u), None
|
|
745
|
+
|
|
684
746
|
u, _ = lax.scan(loop, u, exclude)
|
|
685
747
|
return u
|
|
686
748
|
|
|
687
|
-
|
|
749
|
+
|
|
750
|
+
def choose_split(key, var_tree, split_tree, max_split, leaf_index):
|
|
688
751
|
"""
|
|
689
752
|
Choose a split point for a new non-terminal node.
|
|
690
753
|
|
|
@@ -711,6 +774,7 @@ def choose_split(var_tree, split_tree, max_split, leaf_index, key):
|
|
|
711
774
|
l, r = split_range(var_tree, split_tree, max_split, leaf_index, var)
|
|
712
775
|
return random.randint(key, (), l, r)
|
|
713
776
|
|
|
777
|
+
|
|
714
778
|
def compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow):
|
|
715
779
|
"""
|
|
716
780
|
Compute the product of the transition and prior ratios of a grow move.
|
|
@@ -742,9 +806,9 @@ def compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow
|
|
|
742
806
|
# computed in the acceptance phase
|
|
743
807
|
|
|
744
808
|
prune_allowed = leaf_to_grow != 1
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
809
|
+
# prune allowed <---> the initial tree is not a root
|
|
810
|
+
# leaf to grow is root --> the tree can only be a root
|
|
811
|
+
# tree is a root --> the only leaf I can grow is root
|
|
748
812
|
|
|
749
813
|
p_grow = jnp.where(prune_allowed, 0.5, 1)
|
|
750
814
|
|
|
@@ -757,7 +821,10 @@ def compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow
|
|
|
757
821
|
|
|
758
822
|
return tree_ratio / inv_trans_ratio
|
|
759
823
|
|
|
760
|
-
|
|
824
|
+
|
|
825
|
+
def prune_move(
|
|
826
|
+
key, var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow
|
|
827
|
+
):
|
|
761
828
|
"""
|
|
762
829
|
Tree structure prune move proposal of BART MCMC.
|
|
763
830
|
|
|
@@ -792,18 +859,23 @@ def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p
|
|
|
792
859
|
the likelihood ratio and the probability of proposing the prune
|
|
793
860
|
move. This ratio is inverted.
|
|
794
861
|
"""
|
|
795
|
-
node_to_prune, num_prunable, prob_choose = choose_leaf_parent(
|
|
796
|
-
|
|
862
|
+
node_to_prune, num_prunable, prob_choose = choose_leaf_parent(
|
|
863
|
+
key, split_tree, affluence_tree, p_propose_grow
|
|
864
|
+
)
|
|
865
|
+
allowed = split_tree[1].astype(bool) # allowed iff the tree is not a root
|
|
797
866
|
|
|
798
|
-
ratio = compute_partial_ratio(
|
|
867
|
+
ratio = compute_partial_ratio(
|
|
868
|
+
prob_choose, num_prunable, p_nonterminal, node_to_prune
|
|
869
|
+
)
|
|
799
870
|
|
|
800
871
|
return dict(
|
|
801
872
|
allowed=allowed,
|
|
802
873
|
node=node_to_prune,
|
|
803
|
-
partial_ratio=ratio,
|
|
874
|
+
partial_ratio=ratio, # it is inverted in accept_move_and_sample_leaves
|
|
804
875
|
)
|
|
805
876
|
|
|
806
|
-
|
|
877
|
+
|
|
878
|
+
def choose_leaf_parent(key, split_tree, affluence_tree, p_propose_grow):
|
|
807
879
|
"""
|
|
808
880
|
Pick a non-terminal node with leaf children to prune in a tree.
|
|
809
881
|
|
|
@@ -835,8 +907,7 @@ def choose_leaf_parent(split_tree, affluence_tree, p_propose_grow, key):
|
|
|
835
907
|
|
|
836
908
|
split_tree = split_tree.at[node_to_prune].set(0)
|
|
837
909
|
affluence_tree = (
|
|
838
|
-
None if affluence_tree is None else
|
|
839
|
-
affluence_tree.at[node_to_prune].set(True)
|
|
910
|
+
None if affluence_tree is None else affluence_tree.at[node_to_prune].set(True)
|
|
840
911
|
)
|
|
841
912
|
is_growable_leaf = growable_leaves(split_tree, affluence_tree)
|
|
842
913
|
prob_choose = p_propose_grow[node_to_prune]
|
|
@@ -844,6 +915,7 @@ def choose_leaf_parent(split_tree, affluence_tree, p_propose_grow, key):
|
|
|
844
915
|
|
|
845
916
|
return node_to_prune, num_prunable, prob_choose
|
|
846
917
|
|
|
918
|
+
|
|
847
919
|
def randint_masked(key, mask):
|
|
848
920
|
"""
|
|
849
921
|
Return a random integer in a range, including only some values.
|
|
@@ -865,40 +937,46 @@ def randint_masked(key, mask):
|
|
|
865
937
|
u = random.randint(key, (), 0, ecdf[-1])
|
|
866
938
|
return jnp.searchsorted(ecdf, u, 'right')
|
|
867
939
|
|
|
868
|
-
|
|
940
|
+
|
|
941
|
+
def accept_moves_and_sample_leaves(key, bart, moves):
|
|
869
942
|
"""
|
|
870
943
|
Accept or reject the proposed moves and sample the new leaf values.
|
|
871
944
|
|
|
872
945
|
Parameters
|
|
873
946
|
----------
|
|
947
|
+
key : jax.dtypes.prng_key array
|
|
948
|
+
A jax random key.
|
|
874
949
|
bart : dict
|
|
875
950
|
A BART mcmc state.
|
|
876
951
|
moves : dict
|
|
877
952
|
The proposed moves, see `sample_moves`.
|
|
878
|
-
key : jax.dtypes.prng_key array
|
|
879
|
-
A jax random key.
|
|
880
953
|
|
|
881
954
|
Returns
|
|
882
955
|
-------
|
|
883
956
|
bart : dict
|
|
884
957
|
The new BART mcmc state.
|
|
885
958
|
"""
|
|
886
|
-
bart, moves,
|
|
887
|
-
|
|
959
|
+
bart, moves, prec_trees, move_counts, move_precs, prelkv, prelk, prelf = (
|
|
960
|
+
accept_moves_parallel_stage(key, bart, moves)
|
|
961
|
+
)
|
|
962
|
+
bart, moves = accept_moves_sequential_stage(
|
|
963
|
+
bart, prec_trees, moves, move_counts, move_precs, prelkv, prelk, prelf
|
|
964
|
+
)
|
|
888
965
|
return accept_moves_final_stage(bart, moves)
|
|
889
966
|
|
|
890
|
-
|
|
967
|
+
|
|
968
|
+
def accept_moves_parallel_stage(key, bart, moves):
|
|
891
969
|
"""
|
|
892
970
|
Pre-computes quantities used to accept moves, in parallel across trees.
|
|
893
971
|
|
|
894
972
|
Parameters
|
|
895
973
|
----------
|
|
974
|
+
key : jax.dtypes.prng_key array
|
|
975
|
+
A jax random key.
|
|
896
976
|
bart : dict
|
|
897
977
|
A BART mcmc state.
|
|
898
978
|
moves : dict
|
|
899
979
|
The proposed moves, see `sample_moves`.
|
|
900
|
-
key : jax.dtypes.prng_key array
|
|
901
|
-
A jax random key.
|
|
902
980
|
|
|
903
981
|
Returns
|
|
904
982
|
-------
|
|
@@ -907,11 +985,14 @@ def accept_moves_parallel_stage(bart, moves, key):
|
|
|
907
985
|
moves : dict
|
|
908
986
|
The proposed moves, with the field 'partial_ratio' replaced
|
|
909
987
|
by 'log_trans_prior_ratio'.
|
|
910
|
-
|
|
911
|
-
The
|
|
988
|
+
prec_trees : float array (num_trees, 2 ** d)
|
|
989
|
+
The likelihood precision scale in each potential or actual leaf node. If
|
|
990
|
+
there is no precision scale, this is the number of points in each leaf.
|
|
912
991
|
move_counts : dict
|
|
913
992
|
The counts of the number of points in the the nodes modified by the
|
|
914
993
|
moves.
|
|
994
|
+
move_precs : dict
|
|
995
|
+
The likelihood precision scale in each node modified by the moves.
|
|
915
996
|
prelkv, prelk, prelf : dict
|
|
916
997
|
Dictionary with pre-computed terms of the likelihood ratios and leaf
|
|
917
998
|
samples.
|
|
@@ -924,20 +1005,35 @@ def accept_moves_parallel_stage(bart, moves, key):
|
|
|
924
1005
|
bart['leaf_trees'] = adapt_leaf_trees_to_grow_indices(bart['leaf_trees'], moves)
|
|
925
1006
|
|
|
926
1007
|
# count number of datapoints per leaf
|
|
927
|
-
count_trees, move_counts = compute_count_trees(
|
|
1008
|
+
count_trees, move_counts = compute_count_trees(
|
|
1009
|
+
bart['leaf_indices'], moves, bart['opt']['count_batch_size']
|
|
1010
|
+
)
|
|
928
1011
|
if bart['opt']['require_min_points']:
|
|
929
|
-
count_half_trees = count_trees[:, :bart['var_trees'].shape[1]]
|
|
1012
|
+
count_half_trees = count_trees[:, : bart['var_trees'].shape[1]]
|
|
930
1013
|
bart['affluence_trees'] = count_half_trees >= 2 * bart['min_points_per_leaf']
|
|
931
1014
|
|
|
1015
|
+
# count number of datapoints per leaf, weighted by error precision scale
|
|
1016
|
+
if bart['prec_scale'] is None:
|
|
1017
|
+
prec_trees = count_trees
|
|
1018
|
+
move_precs = move_counts
|
|
1019
|
+
else:
|
|
1020
|
+
prec_trees, move_precs = compute_prec_trees(
|
|
1021
|
+
bart['prec_scale'],
|
|
1022
|
+
bart['leaf_indices'],
|
|
1023
|
+
moves,
|
|
1024
|
+
bart['opt']['count_batch_size'],
|
|
1025
|
+
)
|
|
1026
|
+
|
|
932
1027
|
# compute some missing information about moves
|
|
933
1028
|
moves = complete_ratio(moves, move_counts, bart['min_points_per_leaf'])
|
|
934
1029
|
bart['grow_prop_count'] = jnp.sum(moves['grow'])
|
|
935
1030
|
bart['prune_prop_count'] = jnp.sum(moves['allowed'] & ~moves['grow'])
|
|
936
1031
|
|
|
937
|
-
prelkv, prelk = precompute_likelihood_terms(
|
|
938
|
-
prelf = precompute_leaf_terms(
|
|
1032
|
+
prelkv, prelk = precompute_likelihood_terms(bart['sigma2'], move_precs)
|
|
1033
|
+
prelf = precompute_leaf_terms(key, prec_trees, bart['sigma2'])
|
|
1034
|
+
|
|
1035
|
+
return bart, moves, prec_trees, move_counts, move_precs, prelkv, prelk, prelf
|
|
939
1036
|
|
|
940
|
-
return bart, moves, count_trees, move_counts, prelkv, prelk, prelf
|
|
941
1037
|
|
|
942
1038
|
@functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, None))
|
|
943
1039
|
def apply_grow_to_indices(moves, leaf_indices, X):
|
|
@@ -968,15 +1064,16 @@ def apply_grow_to_indices(moves, leaf_indices, X):
|
|
|
968
1064
|
leaf_indices,
|
|
969
1065
|
)
|
|
970
1066
|
|
|
1067
|
+
|
|
971
1068
|
def compute_count_trees(leaf_indices, moves, batch_size):
|
|
972
1069
|
"""
|
|
973
1070
|
Count the number of datapoints in each leaf.
|
|
974
1071
|
|
|
975
1072
|
Parameters
|
|
976
1073
|
----------
|
|
977
|
-
|
|
978
|
-
The index of the leaf each datapoint falls into,
|
|
979
|
-
|
|
1074
|
+
leaf_indices : int array (num_trees, n)
|
|
1075
|
+
The index of the leaf each datapoint falls into, with the deeper version
|
|
1076
|
+
of the tree (post-GROW, pre-PRUNE).
|
|
980
1077
|
moves : dict
|
|
981
1078
|
The proposed moves, see `sample_moves`.
|
|
982
1079
|
batch_size : int or None
|
|
@@ -987,9 +1084,8 @@ def compute_count_trees(leaf_indices, moves, batch_size):
|
|
|
987
1084
|
count_trees : int array (num_trees, 2 ** (d - 1))
|
|
988
1085
|
The number of points in each potential or actual leaf node.
|
|
989
1086
|
counts : dict
|
|
990
|
-
The counts of the number of points in the
|
|
991
|
-
moves,
|
|
992
|
-
'left', 'right', and 'total'.
|
|
1087
|
+
The counts of the number of points in the leaves grown or pruned by the
|
|
1088
|
+
moves, under keys 'left', 'right', and 'total' (left + right).
|
|
993
1089
|
"""
|
|
994
1090
|
|
|
995
1091
|
ntree, tree_size = moves['var_trees'].shape
|
|
@@ -1009,6 +1105,7 @@ def compute_count_trees(leaf_indices, moves, batch_size):
|
|
|
1009
1105
|
|
|
1010
1106
|
return count_trees, counts
|
|
1011
1107
|
|
|
1108
|
+
|
|
1012
1109
|
def count_datapoints_per_leaf(leaf_indices, tree_size, batch_size):
|
|
1013
1110
|
"""
|
|
1014
1111
|
Count the number of datapoints in each leaf.
|
|
@@ -1032,40 +1129,129 @@ def count_datapoints_per_leaf(leaf_indices, tree_size, batch_size):
|
|
|
1032
1129
|
else:
|
|
1033
1130
|
return _count_vec(leaf_indices, tree_size, batch_size)
|
|
1034
1131
|
|
|
1132
|
+
|
|
1035
1133
|
def _count_scan(leaf_indices, tree_size):
|
|
1036
1134
|
def loop(_, leaf_indices):
|
|
1037
1135
|
return None, _aggregate_scatter(1, leaf_indices, tree_size, jnp.uint32)
|
|
1136
|
+
|
|
1038
1137
|
_, count_trees = lax.scan(loop, None, leaf_indices)
|
|
1039
1138
|
return count_trees
|
|
1040
1139
|
|
|
1140
|
+
|
|
1041
1141
|
def _aggregate_scatter(values, indices, size, dtype):
|
|
1042
|
-
return (
|
|
1043
|
-
|
|
1044
|
-
.at[indices]
|
|
1045
|
-
.add(values)
|
|
1046
|
-
)
|
|
1142
|
+
return jnp.zeros(size, dtype).at[indices].add(values)
|
|
1143
|
+
|
|
1047
1144
|
|
|
1048
1145
|
def _count_vec(leaf_indices, tree_size, batch_size):
|
|
1049
|
-
return _aggregate_batched_alltrees(
|
|
1050
|
-
|
|
1146
|
+
return _aggregate_batched_alltrees(
|
|
1147
|
+
1, leaf_indices, tree_size, jnp.uint32, batch_size
|
|
1148
|
+
)
|
|
1149
|
+
# uint16 is super-slow on gpu, don't use it even if n < 2^16
|
|
1150
|
+
|
|
1051
1151
|
|
|
1052
1152
|
def _aggregate_batched_alltrees(values, indices, size, dtype, batch_size):
|
|
1053
1153
|
ntree, n = indices.shape
|
|
1054
1154
|
tree_indices = jnp.arange(ntree)
|
|
1055
1155
|
nbatches = n // batch_size + bool(n % batch_size)
|
|
1056
1156
|
batch_indices = jnp.arange(n) % nbatches
|
|
1057
|
-
return (
|
|
1058
|
-
.zeros((ntree, size, nbatches), dtype)
|
|
1157
|
+
return (
|
|
1158
|
+
jnp.zeros((ntree, size, nbatches), dtype)
|
|
1059
1159
|
.at[tree_indices[:, None], indices, batch_indices]
|
|
1060
1160
|
.add(values)
|
|
1061
1161
|
.sum(axis=2)
|
|
1062
1162
|
)
|
|
1063
1163
|
|
|
1164
|
+
|
|
1165
|
+
def compute_prec_trees(prec_scale, leaf_indices, moves, batch_size):
|
|
1166
|
+
"""
|
|
1167
|
+
Compute the likelihood precision scale in each leaf.
|
|
1168
|
+
|
|
1169
|
+
Parameters
|
|
1170
|
+
----------
|
|
1171
|
+
prec_scale : float array (n,)
|
|
1172
|
+
The scale of the precision of the error on each datapoint.
|
|
1173
|
+
leaf_indices : int array (num_trees, n)
|
|
1174
|
+
The index of the leaf each datapoint falls into, with the deeper version
|
|
1175
|
+
of the tree (post-GROW, pre-PRUNE).
|
|
1176
|
+
moves : dict
|
|
1177
|
+
The proposed moves, see `sample_moves`.
|
|
1178
|
+
batch_size : int or None
|
|
1179
|
+
The data batch size to use for the summation.
|
|
1180
|
+
|
|
1181
|
+
Returns
|
|
1182
|
+
-------
|
|
1183
|
+
prec_trees : float array (num_trees, 2 ** (d - 1))
|
|
1184
|
+
The likelihood precision scale in each potential or actual leaf node.
|
|
1185
|
+
counts : dict
|
|
1186
|
+
The likelihood precision scale in the leaves grown or pruned by the
|
|
1187
|
+
moves, under keys 'left', 'right', and 'total' (left + right).
|
|
1188
|
+
"""
|
|
1189
|
+
|
|
1190
|
+
ntree, tree_size = moves['var_trees'].shape
|
|
1191
|
+
tree_size *= 2
|
|
1192
|
+
tree_indices = jnp.arange(ntree)
|
|
1193
|
+
|
|
1194
|
+
prec_trees = prec_per_leaf(prec_scale, leaf_indices, tree_size, batch_size)
|
|
1195
|
+
|
|
1196
|
+
# prec datapoints in nodes modified by move
|
|
1197
|
+
precs = dict()
|
|
1198
|
+
precs['left'] = prec_trees[tree_indices, moves['left']]
|
|
1199
|
+
precs['right'] = prec_trees[tree_indices, moves['right']]
|
|
1200
|
+
precs['total'] = precs['left'] + precs['right']
|
|
1201
|
+
|
|
1202
|
+
# write prec into non-leaf node
|
|
1203
|
+
prec_trees = prec_trees.at[tree_indices, moves['node']].set(precs['total'])
|
|
1204
|
+
|
|
1205
|
+
return prec_trees, precs
|
|
1206
|
+
|
|
1207
|
+
|
|
1208
|
+
def prec_per_leaf(prec_scale, leaf_indices, tree_size, batch_size):
|
|
1209
|
+
"""
|
|
1210
|
+
Compute the likelihood precision scale in each leaf.
|
|
1211
|
+
|
|
1212
|
+
Parameters
|
|
1213
|
+
----------
|
|
1214
|
+
prec_scale : float array (n,)
|
|
1215
|
+
The scale of the precision of the error on each datapoint.
|
|
1216
|
+
leaf_indices : int array (num_trees, n)
|
|
1217
|
+
The index of the leaf each datapoint falls into.
|
|
1218
|
+
tree_size : int
|
|
1219
|
+
The size of the leaf tree array (2 ** d).
|
|
1220
|
+
batch_size : int or None
|
|
1221
|
+
The data batch size to use for the summation.
|
|
1222
|
+
|
|
1223
|
+
Returns
|
|
1224
|
+
-------
|
|
1225
|
+
prec_trees : int array (num_trees, 2 ** (d - 1))
|
|
1226
|
+
The likelihood precision scale in each leaf node.
|
|
1227
|
+
"""
|
|
1228
|
+
if batch_size is None:
|
|
1229
|
+
return _prec_scan(prec_scale, leaf_indices, tree_size)
|
|
1230
|
+
else:
|
|
1231
|
+
return _prec_vec(prec_scale, leaf_indices, tree_size, batch_size)
|
|
1232
|
+
|
|
1233
|
+
|
|
1234
|
+
def _prec_scan(prec_scale, leaf_indices, tree_size):
|
|
1235
|
+
def loop(_, leaf_indices):
|
|
1236
|
+
return None, _aggregate_scatter(
|
|
1237
|
+
prec_scale, leaf_indices, tree_size, jnp.float32
|
|
1238
|
+
) # TODO: use large_float
|
|
1239
|
+
|
|
1240
|
+
_, prec_trees = lax.scan(loop, None, leaf_indices)
|
|
1241
|
+
return prec_trees
|
|
1242
|
+
|
|
1243
|
+
|
|
1244
|
+
def _prec_vec(prec_scale, leaf_indices, tree_size, batch_size):
|
|
1245
|
+
return _aggregate_batched_alltrees(
|
|
1246
|
+
prec_scale, leaf_indices, tree_size, jnp.float32, batch_size
|
|
1247
|
+
) # TODO: use large_float
|
|
1248
|
+
|
|
1249
|
+
|
|
1064
1250
|
def complete_ratio(moves, move_counts, min_points_per_leaf):
|
|
1065
1251
|
"""
|
|
1066
1252
|
Complete non-likelihood MH ratio calculation.
|
|
1067
1253
|
|
|
1068
|
-
This
|
|
1254
|
+
This function adds the probability of choosing the prune move.
|
|
1069
1255
|
|
|
1070
1256
|
Parameters
|
|
1071
1257
|
----------
|
|
@@ -1084,10 +1270,13 @@ def complete_ratio(moves, move_counts, min_points_per_leaf):
|
|
|
1084
1270
|
'log_trans_prior_ratio'.
|
|
1085
1271
|
"""
|
|
1086
1272
|
moves = moves.copy()
|
|
1087
|
-
p_prune = compute_p_prune(
|
|
1273
|
+
p_prune = compute_p_prune(
|
|
1274
|
+
moves, move_counts['left'], move_counts['right'], min_points_per_leaf
|
|
1275
|
+
)
|
|
1088
1276
|
moves['log_trans_prior_ratio'] = jnp.log(moves.pop('partial_ratio') * p_prune)
|
|
1089
1277
|
return moves
|
|
1090
1278
|
|
|
1279
|
+
|
|
1091
1280
|
def compute_p_prune(moves, left_count, right_count, min_points_per_leaf):
|
|
1092
1281
|
"""
|
|
1093
1282
|
Compute the probability of proposing a prune move.
|
|
@@ -1123,6 +1312,7 @@ def compute_p_prune(moves, left_count, right_count, min_points_per_leaf):
|
|
|
1123
1312
|
|
|
1124
1313
|
return jnp.where(moves['grow'], grow_p_prune, prune_p_prune)
|
|
1125
1314
|
|
|
1315
|
+
|
|
1126
1316
|
@jaxext.vmap_nodoc
|
|
1127
1317
|
def adapt_leaf_trees_to_grow_indices(leaf_trees, moves):
|
|
1128
1318
|
"""
|
|
@@ -1143,26 +1333,25 @@ def adapt_leaf_trees_to_grow_indices(leaf_trees, moves):
|
|
|
1143
1333
|
what would be its children if the grow move was accepted.
|
|
1144
1334
|
"""
|
|
1145
1335
|
values_at_node = leaf_trees[moves['node']]
|
|
1146
|
-
return (
|
|
1147
|
-
.at[jnp.where(moves['grow'], moves['left'], leaf_trees.size)]
|
|
1336
|
+
return (
|
|
1337
|
+
leaf_trees.at[jnp.where(moves['grow'], moves['left'], leaf_trees.size)]
|
|
1148
1338
|
.set(values_at_node)
|
|
1149
1339
|
.at[jnp.where(moves['grow'], moves['right'], leaf_trees.size)]
|
|
1150
1340
|
.set(values_at_node)
|
|
1151
1341
|
)
|
|
1152
1342
|
|
|
1153
|
-
|
|
1343
|
+
|
|
1344
|
+
def precompute_likelihood_terms(sigma2, move_precs):
|
|
1154
1345
|
"""
|
|
1155
1346
|
Pre-compute terms used in the likelihood ratio of the acceptance step.
|
|
1156
1347
|
|
|
1157
1348
|
Parameters
|
|
1158
1349
|
----------
|
|
1159
|
-
count_trees : array (num_trees, 2 ** d)
|
|
1160
|
-
The number of points in each potential or actual leaf node.
|
|
1161
1350
|
sigma2 : float
|
|
1162
1351
|
The noise variance.
|
|
1163
|
-
|
|
1164
|
-
The
|
|
1165
|
-
moves.
|
|
1352
|
+
move_precs : dict
|
|
1353
|
+
The likelihood precision scale in the leaves grown or pruned by the
|
|
1354
|
+
moves, under keys 'left', 'right', and 'total' (left + right).
|
|
1166
1355
|
|
|
1167
1356
|
Returns
|
|
1168
1357
|
-------
|
|
@@ -1173,32 +1362,37 @@ def precompute_likelihood_terms(count_trees, sigma2, move_counts):
|
|
|
1173
1362
|
Dictionary with pre-computed terms of the likelihood ratio, shared by
|
|
1174
1363
|
all trees.
|
|
1175
1364
|
"""
|
|
1176
|
-
ntree = len(
|
|
1365
|
+
ntree = len(move_precs['total'])
|
|
1177
1366
|
sigma_mu2 = 1 / ntree
|
|
1178
1367
|
prelkv = dict()
|
|
1179
|
-
prelkv['sigma2_left'] = sigma2 +
|
|
1180
|
-
prelkv['sigma2_right'] = sigma2 +
|
|
1181
|
-
prelkv['sigma2_total'] = sigma2 +
|
|
1182
|
-
prelkv['sqrt_term'] =
|
|
1183
|
-
|
|
1184
|
-
|
|
1185
|
-
|
|
1368
|
+
prelkv['sigma2_left'] = sigma2 + move_precs['left'] * sigma_mu2
|
|
1369
|
+
prelkv['sigma2_right'] = sigma2 + move_precs['right'] * sigma_mu2
|
|
1370
|
+
prelkv['sigma2_total'] = sigma2 + move_precs['total'] * sigma_mu2
|
|
1371
|
+
prelkv['sqrt_term'] = (
|
|
1372
|
+
jnp.log(
|
|
1373
|
+
sigma2
|
|
1374
|
+
* prelkv['sigma2_total']
|
|
1375
|
+
/ (prelkv['sigma2_left'] * prelkv['sigma2_right'])
|
|
1376
|
+
)
|
|
1377
|
+
/ 2
|
|
1378
|
+
)
|
|
1186
1379
|
return prelkv, dict(
|
|
1187
1380
|
exp_factor=sigma_mu2 / (2 * sigma2),
|
|
1188
1381
|
)
|
|
1189
1382
|
|
|
1190
|
-
|
|
1383
|
+
|
|
1384
|
+
def precompute_leaf_terms(key, prec_trees, sigma2):
|
|
1191
1385
|
"""
|
|
1192
1386
|
Pre-compute terms used to sample leaves from their posterior.
|
|
1193
1387
|
|
|
1194
1388
|
Parameters
|
|
1195
1389
|
----------
|
|
1196
|
-
count_trees : array (num_trees, 2 ** d)
|
|
1197
|
-
The number of points in each potential or actual leaf node.
|
|
1198
|
-
sigma2 : float
|
|
1199
|
-
The noise variance.
|
|
1200
1390
|
key : jax.dtypes.prng_key array
|
|
1201
1391
|
A jax random key.
|
|
1392
|
+
prec_trees : array (num_trees, 2 ** d)
|
|
1393
|
+
The likelihood precision scale in each potential or actual leaf node.
|
|
1394
|
+
sigma2 : float
|
|
1395
|
+
The noise variance.
|
|
1202
1396
|
|
|
1203
1397
|
Returns
|
|
1204
1398
|
-------
|
|
@@ -1206,22 +1400,25 @@ def precompute_leaf_terms(count_trees, sigma2, key):
|
|
|
1206
1400
|
Dictionary with pre-computed terms of the leaf sampling, with fields:
|
|
1207
1401
|
|
|
1208
1402
|
'mean_factor' : float array (num_trees, 2 ** d)
|
|
1209
|
-
The factor to be multiplied by the sum of residuals to
|
|
1210
|
-
posterior mean.
|
|
1403
|
+
The factor to be multiplied by the sum of the scaled residuals to
|
|
1404
|
+
obtain the posterior mean.
|
|
1211
1405
|
'centered_leaves' : float array (num_trees, 2 ** d)
|
|
1212
1406
|
The mean-zero normal values to be added to the posterior mean to
|
|
1213
1407
|
obtain the posterior leaf samples.
|
|
1214
1408
|
"""
|
|
1215
|
-
ntree = len(
|
|
1216
|
-
prec_lk =
|
|
1217
|
-
var_post = lax.reciprocal(prec_lk + ntree)
|
|
1218
|
-
z = random.normal(key,
|
|
1409
|
+
ntree = len(prec_trees)
|
|
1410
|
+
prec_lk = prec_trees / sigma2
|
|
1411
|
+
var_post = lax.reciprocal(prec_lk + ntree) # = 1 / (prec_lk + prec_prior)
|
|
1412
|
+
z = random.normal(key, prec_trees.shape, sigma2.dtype)
|
|
1219
1413
|
return dict(
|
|
1220
|
-
mean_factor=var_post / sigma2,
|
|
1414
|
+
mean_factor=var_post / sigma2, # = mean_lk * prec_lk * var_post / resid_tree
|
|
1221
1415
|
centered_leaves=z * jnp.sqrt(var_post),
|
|
1222
1416
|
)
|
|
1223
1417
|
|
|
1224
|
-
|
|
1418
|
+
|
|
1419
|
+
def accept_moves_sequential_stage(
|
|
1420
|
+
bart, prec_trees, moves, move_counts, move_precs, prelkv, prelk, prelf
|
|
1421
|
+
):
|
|
1225
1422
|
"""
|
|
1226
1423
|
The part of accepting the moves that has to be done one tree at a time.
|
|
1227
1424
|
|
|
@@ -1229,13 +1426,15 @@ def accept_moves_sequential_stage(bart, count_trees, moves, move_counts, prelkv,
|
|
|
1229
1426
|
----------
|
|
1230
1427
|
bart : dict
|
|
1231
1428
|
A partially updated BART mcmc state.
|
|
1232
|
-
|
|
1233
|
-
The
|
|
1429
|
+
prec_trees : float array (num_trees, 2 ** d)
|
|
1430
|
+
The likelihood precision scale in each potential or actual leaf node.
|
|
1234
1431
|
moves : dict
|
|
1235
1432
|
The proposed moves, see `sample_moves`.
|
|
1236
1433
|
move_counts : dict
|
|
1237
1434
|
The counts of the number of points in the the nodes modified by the
|
|
1238
1435
|
moves.
|
|
1436
|
+
move_precs : dict
|
|
1437
|
+
The likelihood precision scale in each node modified by the moves.
|
|
1239
1438
|
prelkv, prelk, prelf : dict
|
|
1240
1439
|
Dictionaries with pre-computed terms of the likelihood ratios and leaf
|
|
1241
1440
|
samples.
|
|
@@ -1262,6 +1461,7 @@ def accept_moves_sequential_stage(bart, count_trees, moves, move_counts, prelkv,
|
|
|
1262
1461
|
len(bart['leaf_trees']),
|
|
1263
1462
|
bart['opt']['resid_batch_size'],
|
|
1264
1463
|
resid,
|
|
1464
|
+
bart['prec_scale'],
|
|
1265
1465
|
bart['min_points_per_leaf'],
|
|
1266
1466
|
'ratios' in bart,
|
|
1267
1467
|
prelk,
|
|
@@ -1270,22 +1470,44 @@ def accept_moves_sequential_stage(bart, count_trees, moves, move_counts, prelkv,
|
|
|
1270
1470
|
return resid, (leaf_tree, acc, to_prune, ratios)
|
|
1271
1471
|
|
|
1272
1472
|
items = (
|
|
1273
|
-
bart['leaf_trees'],
|
|
1274
|
-
|
|
1473
|
+
bart['leaf_trees'],
|
|
1474
|
+
prec_trees,
|
|
1475
|
+
moves,
|
|
1476
|
+
move_counts,
|
|
1477
|
+
move_precs,
|
|
1275
1478
|
bart['leaf_indices'],
|
|
1276
|
-
prelkv,
|
|
1479
|
+
prelkv,
|
|
1480
|
+
prelf,
|
|
1277
1481
|
)
|
|
1278
1482
|
resid, (leaf_trees, acc, to_prune, ratios) = lax.scan(loop, bart['resid'], items)
|
|
1279
1483
|
|
|
1280
1484
|
bart['resid'] = resid
|
|
1281
1485
|
bart['leaf_trees'] = leaf_trees
|
|
1282
|
-
bart.get('ratios', {}).update(ratios)
|
|
1486
|
+
bart.get('ratios', {}).update(ratios) # noop if there are no ratios
|
|
1283
1487
|
moves['acc'] = acc
|
|
1284
1488
|
moves['to_prune'] = to_prune
|
|
1285
1489
|
|
|
1286
1490
|
return bart, moves
|
|
1287
1491
|
|
|
1288
|
-
|
|
1492
|
+
|
|
1493
|
+
def accept_move_and_sample_leaves(
|
|
1494
|
+
X,
|
|
1495
|
+
ntree,
|
|
1496
|
+
resid_batch_size,
|
|
1497
|
+
resid,
|
|
1498
|
+
prec_scale,
|
|
1499
|
+
min_points_per_leaf,
|
|
1500
|
+
save_ratios,
|
|
1501
|
+
prelk,
|
|
1502
|
+
leaf_tree,
|
|
1503
|
+
prec_tree,
|
|
1504
|
+
move,
|
|
1505
|
+
move_counts,
|
|
1506
|
+
move_precs,
|
|
1507
|
+
leaf_indices,
|
|
1508
|
+
prelkv,
|
|
1509
|
+
prelf,
|
|
1510
|
+
):
|
|
1289
1511
|
"""
|
|
1290
1512
|
Accept or reject a proposed move and sample the new leaf values.
|
|
1291
1513
|
|
|
@@ -1299,6 +1521,9 @@ def accept_move_and_sample_leaves(X, ntree, resid_batch_size, resid, min_points_
|
|
|
1299
1521
|
The batch size for computing the sum of residuals in each leaf.
|
|
1300
1522
|
resid : float array (n,)
|
|
1301
1523
|
The residuals (data minus forest value).
|
|
1524
|
+
prec_scale : float array (n,) or None
|
|
1525
|
+
The scale of the precision of the error on each datapoint. If None, it
|
|
1526
|
+
is assumed to be 1.
|
|
1302
1527
|
min_points_per_leaf : int or None
|
|
1303
1528
|
The minimum number of data points in a leaf node.
|
|
1304
1529
|
save_ratios : bool
|
|
@@ -1308,10 +1533,15 @@ def accept_move_and_sample_leaves(X, ntree, resid_batch_size, resid, min_points_
|
|
|
1308
1533
|
trees.
|
|
1309
1534
|
leaf_tree : float array (2 ** d,)
|
|
1310
1535
|
The leaf values of the tree.
|
|
1311
|
-
|
|
1312
|
-
The
|
|
1536
|
+
prec_tree : float array (2 ** d,)
|
|
1537
|
+
The likelihood precision scale in each potential or actual leaf node.
|
|
1313
1538
|
move : dict
|
|
1314
1539
|
The proposed move, see `sample_moves`.
|
|
1540
|
+
move_counts : dict
|
|
1541
|
+
The counts of the number of points in the the nodes modified by the
|
|
1542
|
+
moves.
|
|
1543
|
+
move_precs : dict
|
|
1544
|
+
The likelihood precision scale in each node modified by the moves.
|
|
1315
1545
|
leaf_indices : int array (n,)
|
|
1316
1546
|
The leaf indices for the largest version of the tree compatible with
|
|
1317
1547
|
the move.
|
|
@@ -1334,11 +1564,15 @@ def accept_move_and_sample_leaves(X, ntree, resid_batch_size, resid, min_points_
|
|
|
1334
1564
|
The acceptance ratios for the moves. Empty if not to be saved.
|
|
1335
1565
|
"""
|
|
1336
1566
|
|
|
1337
|
-
# sum residuals
|
|
1338
|
-
|
|
1567
|
+
# sum residuals in each leaf, in tree proposed by grow move
|
|
1568
|
+
if prec_scale is None:
|
|
1569
|
+
scaled_resid = resid
|
|
1570
|
+
else:
|
|
1571
|
+
scaled_resid = resid * prec_scale
|
|
1572
|
+
resid_tree = sum_resid(scaled_resid, leaf_indices, leaf_tree.size, resid_batch_size)
|
|
1339
1573
|
|
|
1340
1574
|
# subtract starting tree from function
|
|
1341
|
-
resid_tree +=
|
|
1575
|
+
resid_tree += prec_tree * leaf_tree
|
|
1342
1576
|
|
|
1343
1577
|
# get indices of move
|
|
1344
1578
|
node = move['node']
|
|
@@ -1353,7 +1587,9 @@ def accept_move_and_sample_leaves(X, ntree, resid_batch_size, resid, min_points_
|
|
|
1353
1587
|
resid_tree = resid_tree.at[node].set(resid_total)
|
|
1354
1588
|
|
|
1355
1589
|
# compute acceptance ratio
|
|
1356
|
-
log_lk_ratio = compute_likelihood_ratio(
|
|
1590
|
+
log_lk_ratio = compute_likelihood_ratio(
|
|
1591
|
+
resid_total, resid_left, resid_right, prelkv, prelk
|
|
1592
|
+
)
|
|
1357
1593
|
log_ratio = move['log_trans_prior_ratio'] + log_lk_ratio
|
|
1358
1594
|
log_ratio = jnp.where(move['grow'], log_ratio, -log_ratio)
|
|
1359
1595
|
ratios = {}
|
|
@@ -1374,10 +1610,10 @@ def accept_move_and_sample_leaves(X, ntree, resid_batch_size, resid, min_points_
|
|
|
1374
1610
|
mean_post = resid_tree * prelf['mean_factor']
|
|
1375
1611
|
leaf_tree = mean_post + prelf['centered_leaves']
|
|
1376
1612
|
|
|
1377
|
-
# copy leaves around such that the leaf indices
|
|
1613
|
+
# copy leaves around such that the leaf indices point to the correct leaf
|
|
1378
1614
|
to_prune = acc ^ move['grow']
|
|
1379
|
-
leaf_tree = (
|
|
1380
|
-
.at[jnp.where(to_prune, left, leaf_tree.size)]
|
|
1615
|
+
leaf_tree = (
|
|
1616
|
+
leaf_tree.at[jnp.where(to_prune, left, leaf_tree.size)]
|
|
1381
1617
|
.set(leaf_tree[node])
|
|
1382
1618
|
.at[jnp.where(to_prune, right, leaf_tree.size)]
|
|
1383
1619
|
.set(leaf_tree[node])
|
|
@@ -1388,14 +1624,16 @@ def accept_move_and_sample_leaves(X, ntree, resid_batch_size, resid, min_points_
|
|
|
1388
1624
|
|
|
1389
1625
|
return resid, leaf_tree, acc, to_prune, ratios
|
|
1390
1626
|
|
|
1391
|
-
|
|
1627
|
+
|
|
1628
|
+
def sum_resid(scaled_resid, leaf_indices, tree_size, batch_size):
|
|
1392
1629
|
"""
|
|
1393
1630
|
Sum the residuals in each leaf.
|
|
1394
1631
|
|
|
1395
1632
|
Parameters
|
|
1396
1633
|
----------
|
|
1397
|
-
|
|
1398
|
-
The residuals (data minus forest value)
|
|
1634
|
+
scaled_resid : float array (n,)
|
|
1635
|
+
The residuals (data minus forest value) multiplied by the error
|
|
1636
|
+
precision scale.
|
|
1399
1637
|
leaf_indices : int array (n,)
|
|
1400
1638
|
The leaf indices of the tree (in which leaf each data point falls into).
|
|
1401
1639
|
tree_size : int
|
|
@@ -1413,29 +1651,32 @@ def sum_resid(resid, leaf_indices, tree_size, batch_size):
|
|
|
1413
1651
|
aggr_func = _aggregate_scatter
|
|
1414
1652
|
else:
|
|
1415
1653
|
aggr_func = functools.partial(_aggregate_batched_onetree, batch_size=batch_size)
|
|
1416
|
-
return aggr_func(
|
|
1654
|
+
return aggr_func(
|
|
1655
|
+
scaled_resid, leaf_indices, tree_size, jnp.float32
|
|
1656
|
+
) # TODO: use large_float
|
|
1657
|
+
|
|
1417
1658
|
|
|
1418
1659
|
def _aggregate_batched_onetree(values, indices, size, dtype, batch_size):
|
|
1419
|
-
n, = indices.shape
|
|
1660
|
+
(n,) = indices.shape
|
|
1420
1661
|
nbatches = n // batch_size + bool(n % batch_size)
|
|
1421
1662
|
batch_indices = jnp.arange(n) % nbatches
|
|
1422
|
-
return (
|
|
1423
|
-
.zeros((size, nbatches), dtype)
|
|
1663
|
+
return (
|
|
1664
|
+
jnp.zeros((size, nbatches), dtype)
|
|
1424
1665
|
.at[indices, batch_indices]
|
|
1425
1666
|
.add(values)
|
|
1426
1667
|
.sum(axis=1)
|
|
1427
1668
|
)
|
|
1428
1669
|
|
|
1670
|
+
|
|
1429
1671
|
def compute_likelihood_ratio(total_resid, left_resid, right_resid, prelkv, prelk):
|
|
1430
1672
|
"""
|
|
1431
1673
|
Compute the likelihood ratio of a grow move.
|
|
1432
1674
|
|
|
1433
1675
|
Parameters
|
|
1434
1676
|
----------
|
|
1435
|
-
total_resid : float
|
|
1436
|
-
The sum of the residuals
|
|
1437
|
-
|
|
1438
|
-
The sum of the residuals in the left/right child of the leaf to grow.
|
|
1677
|
+
total_resid, left_resid, right_resid : float
|
|
1678
|
+
The sum of the residuals (scaled by error precision scale) of the
|
|
1679
|
+
datapoints falling in the nodes involved in the moves.
|
|
1439
1680
|
prelkv, prelk : dict
|
|
1440
1681
|
The pre-computed terms of the likelihood ratio, see
|
|
1441
1682
|
`precompute_likelihood_terms`.
|
|
@@ -1446,12 +1687,13 @@ def compute_likelihood_ratio(total_resid, left_resid, right_resid, prelkv, prelk
|
|
|
1446
1687
|
The likelihood ratio P(data | new tree) / P(data | old tree).
|
|
1447
1688
|
"""
|
|
1448
1689
|
exp_term = prelk['exp_factor'] * (
|
|
1449
|
-
left_resid * left_resid / prelkv['sigma2_left']
|
|
1450
|
-
right_resid * right_resid / prelkv['sigma2_right']
|
|
1451
|
-
total_resid * total_resid / prelkv['sigma2_total']
|
|
1690
|
+
left_resid * left_resid / prelkv['sigma2_left']
|
|
1691
|
+
+ right_resid * right_resid / prelkv['sigma2_right']
|
|
1692
|
+
- total_resid * total_resid / prelkv['sigma2_total']
|
|
1452
1693
|
)
|
|
1453
1694
|
return prelkv['sqrt_term'] + exp_term
|
|
1454
1695
|
|
|
1696
|
+
|
|
1455
1697
|
def accept_moves_final_stage(bart, moves):
|
|
1456
1698
|
"""
|
|
1457
1699
|
The final part of accepting the moves, in parallel across trees.
|
|
@@ -1478,7 +1720,8 @@ def accept_moves_final_stage(bart, moves):
|
|
|
1478
1720
|
bart['split_trees'] = apply_moves_to_split_trees(bart['split_trees'], moves)
|
|
1479
1721
|
return bart
|
|
1480
1722
|
|
|
1481
|
-
|
|
1723
|
+
|
|
1724
|
+
@jaxext.vmap_nodoc
|
|
1482
1725
|
def apply_moves_to_leaf_indices(leaf_indices, moves):
|
|
1483
1726
|
"""
|
|
1484
1727
|
Update the leaf indices to match the accepted move.
|
|
@@ -1497,7 +1740,7 @@ def apply_moves_to_leaf_indices(leaf_indices, moves):
|
|
|
1497
1740
|
leaf_indices : int array (num_trees, n)
|
|
1498
1741
|
The updated leaf indices.
|
|
1499
1742
|
"""
|
|
1500
|
-
mask = ~jnp.array(1, leaf_indices.dtype)
|
|
1743
|
+
mask = ~jnp.array(1, leaf_indices.dtype) # ...1111111110
|
|
1501
1744
|
is_child = (leaf_indices & mask) == moves['left']
|
|
1502
1745
|
return jnp.where(
|
|
1503
1746
|
is_child & moves['to_prune'],
|
|
@@ -1505,7 +1748,8 @@ def apply_moves_to_leaf_indices(leaf_indices, moves):
|
|
|
1505
1748
|
leaf_indices,
|
|
1506
1749
|
)
|
|
1507
1750
|
|
|
1508
|
-
|
|
1751
|
+
|
|
1752
|
+
@jaxext.vmap_nodoc
|
|
1509
1753
|
def apply_moves_to_split_trees(split_trees, moves):
|
|
1510
1754
|
"""
|
|
1511
1755
|
Update the split trees to match the accepted move.
|
|
@@ -1523,31 +1767,36 @@ def apply_moves_to_split_trees(split_trees, moves):
|
|
|
1523
1767
|
split_trees : int array (num_trees, 2 ** (d - 1))
|
|
1524
1768
|
The updated split trees.
|
|
1525
1769
|
"""
|
|
1526
|
-
return (
|
|
1527
|
-
.at[
|
|
1528
|
-
|
|
1529
|
-
|
|
1530
|
-
|
|
1531
|
-
|
|
1770
|
+
return (
|
|
1771
|
+
split_trees.at[
|
|
1772
|
+
jnp.where(
|
|
1773
|
+
moves['grow'],
|
|
1774
|
+
moves['node'],
|
|
1775
|
+
split_trees.size,
|
|
1776
|
+
)
|
|
1777
|
+
]
|
|
1532
1778
|
.set(moves['grow_split'].astype(split_trees.dtype))
|
|
1533
|
-
.at[
|
|
1534
|
-
|
|
1535
|
-
|
|
1536
|
-
|
|
1537
|
-
|
|
1779
|
+
.at[
|
|
1780
|
+
jnp.where(
|
|
1781
|
+
moves['to_prune'],
|
|
1782
|
+
moves['node'],
|
|
1783
|
+
split_trees.size,
|
|
1784
|
+
)
|
|
1785
|
+
]
|
|
1538
1786
|
.set(0)
|
|
1539
1787
|
)
|
|
1540
1788
|
|
|
1541
|
-
|
|
1789
|
+
|
|
1790
|
+
def sample_sigma(key, bart):
|
|
1542
1791
|
"""
|
|
1543
1792
|
Noise variance sampling step of BART MCMC.
|
|
1544
1793
|
|
|
1545
1794
|
Parameters
|
|
1546
1795
|
----------
|
|
1547
|
-
bart : dict
|
|
1548
|
-
A BART mcmc state, as created by `init`.
|
|
1549
1796
|
key : jax.dtypes.prng_key array
|
|
1550
1797
|
A jax random key.
|
|
1798
|
+
bart : dict
|
|
1799
|
+
A BART mcmc state, as created by `init`.
|
|
1551
1800
|
|
|
1552
1801
|
Returns
|
|
1553
1802
|
-------
|
|
@@ -1558,7 +1807,11 @@ def sample_sigma(bart, key):
|
|
|
1558
1807
|
|
|
1559
1808
|
resid = bart['resid']
|
|
1560
1809
|
alpha = bart['sigma2_alpha'] + resid.size / 2
|
|
1561
|
-
|
|
1810
|
+
if bart['prec_scale'] is None:
|
|
1811
|
+
scaled_resid = resid
|
|
1812
|
+
else:
|
|
1813
|
+
scaled_resid = resid * bart['prec_scale']
|
|
1814
|
+
norm2 = resid @ scaled_resid
|
|
1562
1815
|
beta = bart['sigma2_beta'] + norm2 / 2
|
|
1563
1816
|
|
|
1564
1817
|
sample = random.gamma(key, alpha)
|