bartz 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/_version.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = '0.3.0'
|
|
1
|
+
__version__ = '0.4.0'
|
bartz/mcmcstep.py
CHANGED
|
@@ -148,6 +148,14 @@ def init(*,
|
|
|
148
148
|
Whether the `min_points_per_leaf` parameter is specified.
|
|
149
149
|
'resid_batch_size', 'count_batch_size' : int or None
|
|
150
150
|
The data batch sizes for computing the sufficient statistics.
|
|
151
|
+
'ratios' : dict, optional
|
|
152
|
+
If `save_ratios` is True, this field is present. It has the fields:
|
|
153
|
+
|
|
154
|
+
'log_trans_prior' : large_float array (num_trees,)
|
|
155
|
+
The log transition and prior Metropolis-Hastings ratio for the
|
|
156
|
+
proposed move on each tree.
|
|
157
|
+
'log_likelihood' : large_float array (num_trees,)
|
|
158
|
+
The log likelihood ratio.
|
|
151
159
|
"""
|
|
152
160
|
|
|
153
161
|
p_nonterminal = jnp.asarray(p_nonterminal, large_float)
|
|
@@ -161,7 +169,7 @@ def init(*,
|
|
|
161
169
|
small_float = jnp.dtype(small_float)
|
|
162
170
|
large_float = jnp.dtype(large_float)
|
|
163
171
|
y = jnp.asarray(y, small_float)
|
|
164
|
-
resid_batch_size, count_batch_size = _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y)
|
|
172
|
+
resid_batch_size, count_batch_size = _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, 2 ** max_depth * num_trees)
|
|
165
173
|
sigma2 = jnp.array(sigma2_beta / sigma2_alpha, large_float)
|
|
166
174
|
sigma2 = jnp.where(jnp.isfinite(sigma2) & (sigma2 > 0), sigma2, 1)
|
|
167
175
|
|
|
@@ -202,19 +210,13 @@ def init(*,
|
|
|
202
210
|
|
|
203
211
|
if save_ratios:
|
|
204
212
|
bart['ratios'] = dict(
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
likelihood=jnp.full(num_trees, jnp.nan),
|
|
208
|
-
),
|
|
209
|
-
prune=dict(
|
|
210
|
-
trans_prior=jnp.full(num_trees, jnp.nan),
|
|
211
|
-
likelihood=jnp.full(num_trees, jnp.nan),
|
|
212
|
-
),
|
|
213
|
+
log_trans_prior=jnp.full(num_trees, jnp.nan),
|
|
214
|
+
log_likelihood=jnp.full(num_trees, jnp.nan),
|
|
213
215
|
)
|
|
214
216
|
|
|
215
217
|
return bart
|
|
216
218
|
|
|
217
|
-
def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y):
|
|
219
|
+
def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, forest_size):
|
|
218
220
|
|
|
219
221
|
@functools.cache
|
|
220
222
|
def get_platform():
|
|
@@ -244,6 +246,10 @@ def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y):
|
|
|
244
246
|
n = max(1, y.size)
|
|
245
247
|
count_batch_size = 2 ** int(round(math.log2(n) / 2 - 2)) # n^1/2
|
|
246
248
|
# /4 is good on V100, /2 on L4/T4, still haven't tried A100
|
|
249
|
+
max_memory = 2 ** 29
|
|
250
|
+
itemsize = 4
|
|
251
|
+
min_batch_size = int(math.ceil(forest_size * itemsize * n / max_memory))
|
|
252
|
+
count_batch_size = max(count_batch_size, min_batch_size)
|
|
247
253
|
count_batch_size = max(1, count_batch_size)
|
|
248
254
|
|
|
249
255
|
return resid_batch_size, count_batch_size
|
|
@@ -289,8 +295,8 @@ def sample_trees(bart, key):
|
|
|
289
295
|
This function zeroes the proposal counters.
|
|
290
296
|
"""
|
|
291
297
|
key, subkey = random.split(key)
|
|
292
|
-
|
|
293
|
-
return accept_moves_and_sample_leaves(bart,
|
|
298
|
+
moves = sample_moves(bart, subkey)
|
|
299
|
+
return accept_moves_and_sample_leaves(bart, moves, key)
|
|
294
300
|
|
|
295
301
|
def sample_moves(bart, key):
|
|
296
302
|
"""
|
|
@@ -305,11 +311,65 @@ def sample_moves(bart, key):
|
|
|
305
311
|
|
|
306
312
|
Returns
|
|
307
313
|
-------
|
|
308
|
-
|
|
309
|
-
|
|
314
|
+
moves : dict
|
|
315
|
+
A dictionary with fields:
|
|
316
|
+
|
|
317
|
+
'allowed' : bool array (num_trees,)
|
|
318
|
+
Whether the move is possible.
|
|
319
|
+
'grow' : bool array (num_trees,)
|
|
320
|
+
Whether the move is a grow move or a prune move.
|
|
321
|
+
'num_growable' : int array (num_trees,)
|
|
322
|
+
The number of growable leaves in the original tree.
|
|
323
|
+
'node' : int array (num_trees,)
|
|
324
|
+
The index of the leaf to grow or node to prune.
|
|
325
|
+
'left', 'right' : int array (num_trees,)
|
|
326
|
+
The indices of the children of 'node'.
|
|
327
|
+
'partial_ratio' : float array (num_trees,)
|
|
328
|
+
A factor of the Metropolis-Hastings ratio of the move. It lacks
|
|
329
|
+
the likelihood ratio and the probability of proposing the prune
|
|
330
|
+
move. If the move is Prune, the ratio is inverted.
|
|
331
|
+
'grow_var' : int array (num_trees,)
|
|
332
|
+
The decision axes of the new rules.
|
|
333
|
+
'grow_split' : int array (num_trees,)
|
|
334
|
+
The decision boundaries of the new rules.
|
|
335
|
+
'var_trees' : int array (num_trees, 2 ** (d - 1))
|
|
336
|
+
The updated decision axes of the trees, valid whatever move.
|
|
337
|
+
'logu' : float array (num_trees,)
|
|
338
|
+
The logarithm of a uniform (0, 1] random variable to be used to
|
|
339
|
+
accept the move. It's in (-oo, 0].
|
|
310
340
|
"""
|
|
311
|
-
|
|
312
|
-
|
|
341
|
+
ntree = bart['leaf_trees'].shape[0]
|
|
342
|
+
key = random.split(key, 1 + ntree)
|
|
343
|
+
key, subkey = key[0], key[1:]
|
|
344
|
+
|
|
345
|
+
# compute moves
|
|
346
|
+
grow_moves, prune_moves = _sample_moves_vmap_trees(bart['var_trees'], bart['split_trees'], bart['affluence_trees'], bart['max_split'], bart['p_nonterminal'], bart['p_propose_grow'], subkey)
|
|
347
|
+
|
|
348
|
+
u, logu = random.uniform(key, (2, ntree), bart['opt']['large_float'])
|
|
349
|
+
|
|
350
|
+
# choose between grow or prune
|
|
351
|
+
grow_allowed = grow_moves['num_growable'].astype(bool)
|
|
352
|
+
p_grow = jnp.where(grow_allowed & prune_moves['allowed'], 0.5, grow_allowed)
|
|
353
|
+
grow = u < p_grow # use < instead of <= because u is in [0, 1)
|
|
354
|
+
|
|
355
|
+
# compute children indices
|
|
356
|
+
node = jnp.where(grow, grow_moves['node'], prune_moves['node'])
|
|
357
|
+
left = node << 1
|
|
358
|
+
right = left + 1
|
|
359
|
+
|
|
360
|
+
return dict(
|
|
361
|
+
allowed=grow | prune_moves['allowed'],
|
|
362
|
+
grow=grow,
|
|
363
|
+
num_growable=grow_moves['num_growable'],
|
|
364
|
+
node=node,
|
|
365
|
+
left=left,
|
|
366
|
+
right=right,
|
|
367
|
+
partial_ratio=jnp.where(grow, grow_moves['partial_ratio'], prune_moves['partial_ratio']),
|
|
368
|
+
grow_var=grow_moves['var'],
|
|
369
|
+
grow_split=grow_moves['split'],
|
|
370
|
+
var_trees=grow_moves['var_tree'],
|
|
371
|
+
logu=jnp.log1p(-logu),
|
|
372
|
+
)
|
|
313
373
|
|
|
314
374
|
@functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, 0, None, None, None, 0))
|
|
315
375
|
def _sample_moves_vmap_trees(*args):
|
|
@@ -354,16 +414,14 @@ def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_
|
|
|
354
414
|
'node' : int
|
|
355
415
|
The index of the leaf to grow. ``2 ** d`` if there are no growable
|
|
356
416
|
leaves.
|
|
357
|
-
'left', 'right' : int
|
|
358
|
-
The indices of the children of 'node'.
|
|
359
417
|
'var', 'split' : int
|
|
360
418
|
The decision axis and boundary of the new rule.
|
|
361
419
|
'partial_ratio' : float
|
|
362
420
|
A factor of the Metropolis-Hastings ratio of the move. It lacks
|
|
363
421
|
the likelihood ratio and the probability of proposing the prune
|
|
364
422
|
move.
|
|
365
|
-
'var_tree'
|
|
366
|
-
The updated decision axes
|
|
423
|
+
'var_tree' : array (2 ** (d - 1),)
|
|
424
|
+
The updated decision axes of the tree.
|
|
367
425
|
"""
|
|
368
426
|
|
|
369
427
|
key, key1, key2 = random.split(key, 3)
|
|
@@ -374,21 +432,16 @@ def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_
|
|
|
374
432
|
var_tree = var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype))
|
|
375
433
|
|
|
376
434
|
split = choose_split(var_tree, split_tree, max_split, leaf_to_grow, key2)
|
|
377
|
-
split_tree = split_tree.at[leaf_to_grow].set(split.astype(split_tree.dtype))
|
|
378
435
|
|
|
379
|
-
ratio = compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow
|
|
436
|
+
ratio = compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow)
|
|
380
437
|
|
|
381
|
-
left = leaf_to_grow << 1
|
|
382
438
|
return dict(
|
|
383
439
|
num_growable=num_growable,
|
|
384
440
|
node=leaf_to_grow,
|
|
385
|
-
left=left,
|
|
386
|
-
right=left + 1,
|
|
387
441
|
var=var,
|
|
388
442
|
split=split,
|
|
389
443
|
partial_ratio=ratio,
|
|
390
444
|
var_tree=var_tree,
|
|
391
|
-
split_tree=split_tree,
|
|
392
445
|
)
|
|
393
446
|
|
|
394
447
|
def choose_leaf(split_tree, affluence_tree, p_propose_grow, key):
|
|
@@ -658,7 +711,7 @@ def choose_split(var_tree, split_tree, max_split, leaf_index, key):
|
|
|
658
711
|
l, r = split_range(var_tree, split_tree, max_split, leaf_index, var)
|
|
659
712
|
return random.randint(key, (), l, r)
|
|
660
713
|
|
|
661
|
-
def compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow, new_split_tree):
|
|
714
|
+
def compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow):
|
|
662
715
|
"""
|
|
663
716
|
Compute the product of the transition and prior ratios of a grow move.
|
|
664
717
|
|
|
@@ -673,8 +726,6 @@ def compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow
|
|
|
673
726
|
The probability of a nonterminal node at each depth.
|
|
674
727
|
leaf_to_grow : int
|
|
675
728
|
The index of the leaf to grow.
|
|
676
|
-
new_split_tree : array (2 ** (d - 1),)
|
|
677
|
-
The splitting points of the tree, after the leaf is grown.
|
|
678
729
|
|
|
679
730
|
Returns
|
|
680
731
|
-------
|
|
@@ -699,7 +750,7 @@ def compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow
|
|
|
699
750
|
|
|
700
751
|
inv_trans_ratio = p_grow * prob_choose * num_prunable
|
|
701
752
|
|
|
702
|
-
depth = grove.tree_depths(
|
|
753
|
+
depth = grove.tree_depths(2 ** (p_nonterminal.size - 1))[leaf_to_grow]
|
|
703
754
|
p_parent = p_nonterminal[depth]
|
|
704
755
|
cp_children = 1 - p_nonterminal[depth + 1]
|
|
705
756
|
tree_ratio = cp_children * cp_children * p_parent / (1 - p_parent)
|
|
@@ -736,8 +787,6 @@ def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p
|
|
|
736
787
|
Whether the move is possible.
|
|
737
788
|
'node' : int
|
|
738
789
|
The index of the node to prune. ``2 ** d`` if no node can be pruned.
|
|
739
|
-
'left', 'right' : int
|
|
740
|
-
The indices of the children of 'node'.
|
|
741
790
|
'partial_ratio' : float
|
|
742
791
|
A factor of the Metropolis-Hastings ratio of the move. It lacks
|
|
743
792
|
the likelihood ratio and the probability of proposing the prune
|
|
@@ -746,14 +795,11 @@ def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p
|
|
|
746
795
|
node_to_prune, num_prunable, prob_choose = choose_leaf_parent(split_tree, affluence_tree, p_propose_grow, key)
|
|
747
796
|
allowed = split_tree[1].astype(bool) # allowed iff the tree is not a root
|
|
748
797
|
|
|
749
|
-
ratio = compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, node_to_prune
|
|
798
|
+
ratio = compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, node_to_prune)
|
|
750
799
|
|
|
751
|
-
left = node_to_prune << 1
|
|
752
800
|
return dict(
|
|
753
801
|
allowed=allowed,
|
|
754
802
|
node=node_to_prune,
|
|
755
|
-
left=left,
|
|
756
|
-
right=left + 1,
|
|
757
803
|
partial_ratio=ratio, # it is inverted in accept_move_and_sample_leaves
|
|
758
804
|
)
|
|
759
805
|
|
|
@@ -819,7 +865,7 @@ def randint_masked(key, mask):
|
|
|
819
865
|
u = random.randint(key, (), 0, ecdf[-1])
|
|
820
866
|
return jnp.searchsorted(ecdf, u, 'right')
|
|
821
867
|
|
|
822
|
-
def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, key):
|
|
868
|
+
def accept_moves_and_sample_leaves(bart, moves, key):
|
|
823
869
|
"""
|
|
824
870
|
Accept or reject the proposed moves and sample the new leaf values.
|
|
825
871
|
|
|
@@ -827,12 +873,8 @@ def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, key):
|
|
|
827
873
|
----------
|
|
828
874
|
bart : dict
|
|
829
875
|
A BART mcmc state.
|
|
830
|
-
|
|
831
|
-
The
|
|
832
|
-
`grow_move`.
|
|
833
|
-
prune_moves : dict
|
|
834
|
-
The proposals for prune moves, batched over the first axis. See
|
|
835
|
-
`prune_move`.
|
|
876
|
+
moves : dict
|
|
877
|
+
The proposed moves, see `sample_moves`.
|
|
836
878
|
key : jax.dtypes.prng_key array
|
|
837
879
|
A jax random key.
|
|
838
880
|
|
|
@@ -841,11 +883,11 @@ def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, key):
|
|
|
841
883
|
bart : dict
|
|
842
884
|
The new BART mcmc state.
|
|
843
885
|
"""
|
|
844
|
-
bart,
|
|
845
|
-
bart,
|
|
846
|
-
return accept_moves_final_stage(bart,
|
|
886
|
+
bart, moves, count_trees, move_counts, prelkv, prelk, prelf = accept_moves_parallel_stage(bart, moves, key)
|
|
887
|
+
bart, moves = accept_moves_sequential_stage(bart, count_trees, moves, move_counts, prelkv, prelk, prelf)
|
|
888
|
+
return accept_moves_final_stage(bart, moves)
|
|
847
889
|
|
|
848
|
-
def accept_moves_parallel_stage(bart, grow_moves, prune_moves, key):
|
|
890
|
+
def accept_moves_parallel_stage(bart, moves, key):
|
|
849
891
|
"""
|
|
850
892
|
Pre-computes quantities used to accept moves, in parallel across trees.
|
|
851
893
|
|
|
@@ -853,9 +895,8 @@ def accept_moves_parallel_stage(bart, grow_moves, prune_moves, key):
|
|
|
853
895
|
----------
|
|
854
896
|
bart : dict
|
|
855
897
|
A BART mcmc state.
|
|
856
|
-
|
|
857
|
-
The
|
|
858
|
-
`grow_move` and `prune_move`.
|
|
898
|
+
moves : dict
|
|
899
|
+
The proposed moves, see `sample_moves`.
|
|
859
900
|
key : jax.dtypes.prng_key array
|
|
860
901
|
A jax random key.
|
|
861
902
|
|
|
@@ -863,51 +904,50 @@ def accept_moves_parallel_stage(bart, grow_moves, prune_moves, key):
|
|
|
863
904
|
-------
|
|
864
905
|
bart : dict
|
|
865
906
|
A partially updated BART mcmc state.
|
|
866
|
-
|
|
867
|
-
The
|
|
868
|
-
by '
|
|
869
|
-
count_trees : array (num_trees, 2 **
|
|
907
|
+
moves : dict
|
|
908
|
+
The proposed moves, with the field 'partial_ratio' replaced
|
|
909
|
+
by 'log_trans_prior_ratio'.
|
|
910
|
+
count_trees : array (num_trees, 2 ** d)
|
|
870
911
|
The number of points in each potential or actual leaf node.
|
|
871
912
|
move_counts : dict
|
|
872
913
|
The counts of the number of points in the the nodes modified by the
|
|
873
914
|
moves.
|
|
874
|
-
|
|
875
|
-
|
|
876
|
-
|
|
877
|
-
Random standard normal values used to sample the new leaf values.
|
|
915
|
+
prelkv, prelk, prelf : dict
|
|
916
|
+
Dictionary with pre-computed terms of the likelihood ratios and leaf
|
|
917
|
+
samples.
|
|
878
918
|
"""
|
|
879
919
|
bart = bart.copy()
|
|
880
920
|
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
bart['leaf_indices'] = apply_grow_to_indices(grow_moves, bart['leaf_indices'], bart['X'])
|
|
886
|
-
|
|
887
|
-
count_trees, move_counts = compute_count_trees(bart['leaf_indices'], grow_moves, prune_moves, bart['opt']['count_batch_size'])
|
|
888
|
-
|
|
889
|
-
grow_moves, prune_moves = complete_ratio(grow_moves, prune_moves, move_counts, bart['min_points_per_leaf'])
|
|
921
|
+
# where the move is grow, modify the state like the move was accepted
|
|
922
|
+
bart['var_trees'] = moves['var_trees']
|
|
923
|
+
bart['leaf_indices'] = apply_grow_to_indices(moves, bart['leaf_indices'], bart['X'])
|
|
924
|
+
bart['leaf_trees'] = adapt_leaf_trees_to_grow_indices(bart['leaf_trees'], moves)
|
|
890
925
|
|
|
926
|
+
# count number of datapoints per leaf
|
|
927
|
+
count_trees, move_counts = compute_count_trees(bart['leaf_indices'], moves, bart['opt']['count_batch_size'])
|
|
891
928
|
if bart['opt']['require_min_points']:
|
|
892
|
-
count_half_trees = count_trees[:, :
|
|
929
|
+
count_half_trees = count_trees[:, :bart['var_trees'].shape[1]]
|
|
893
930
|
bart['affluence_trees'] = count_half_trees >= 2 * bart['min_points_per_leaf']
|
|
894
931
|
|
|
895
|
-
|
|
932
|
+
# compute some missing information about moves
|
|
933
|
+
moves = complete_ratio(moves, move_counts, bart['min_points_per_leaf'])
|
|
934
|
+
bart['grow_prop_count'] = jnp.sum(moves['grow'])
|
|
935
|
+
bart['prune_prop_count'] = jnp.sum(moves['allowed'] & ~moves['grow'])
|
|
896
936
|
|
|
897
|
-
|
|
898
|
-
|
|
899
|
-
z = random.normal(key, bart['leaf_trees'].shape, bart['opt']['large_float'])
|
|
937
|
+
prelkv, prelk = precompute_likelihood_terms(count_trees, bart['sigma2'], move_counts)
|
|
938
|
+
prelf = precompute_leaf_terms(count_trees, bart['sigma2'], key)
|
|
900
939
|
|
|
901
|
-
return bart,
|
|
940
|
+
return bart, moves, count_trees, move_counts, prelkv, prelk, prelf
|
|
902
941
|
|
|
903
|
-
|
|
942
|
+
@functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, None))
|
|
943
|
+
def apply_grow_to_indices(moves, leaf_indices, X):
|
|
904
944
|
"""
|
|
905
945
|
Update the leaf indices to apply a grow move.
|
|
906
946
|
|
|
907
947
|
Parameters
|
|
908
948
|
----------
|
|
909
|
-
|
|
910
|
-
The
|
|
949
|
+
moves : dict
|
|
950
|
+
The proposed moves, see `sample_moves`.
|
|
911
951
|
leaf_indices : array (num_trees, n)
|
|
912
952
|
The index of the leaf each datapoint falls into.
|
|
913
953
|
X : array (p, n)
|
|
@@ -918,17 +958,17 @@ def apply_grow_to_indices(grow_moves, leaf_indices, X):
|
|
|
918
958
|
grow_leaf_indices : array (num_trees, n)
|
|
919
959
|
The updated leaf indices.
|
|
920
960
|
"""
|
|
921
|
-
left_child =
|
|
922
|
-
go_right = X[
|
|
923
|
-
tree_size = jnp.array(2 *
|
|
924
|
-
node_to_update = jnp.where(
|
|
961
|
+
left_child = moves['node'].astype(leaf_indices.dtype) << 1
|
|
962
|
+
go_right = X[moves['grow_var'], :] >= moves['grow_split']
|
|
963
|
+
tree_size = jnp.array(2 * moves['var_trees'].size)
|
|
964
|
+
node_to_update = jnp.where(moves['grow'], moves['node'], tree_size)
|
|
925
965
|
return jnp.where(
|
|
926
|
-
leaf_indices == node_to_update
|
|
927
|
-
left_child
|
|
966
|
+
leaf_indices == node_to_update,
|
|
967
|
+
left_child + go_right,
|
|
928
968
|
leaf_indices,
|
|
929
969
|
)
|
|
930
970
|
|
|
931
|
-
def compute_count_trees(grow_leaf_indices, grow_moves, prune_moves, batch_size):
|
|
971
|
+
def compute_count_trees(leaf_indices, moves, batch_size):
|
|
932
972
|
"""
|
|
933
973
|
Count the number of datapoints in each leaf.
|
|
934
974
|
|
|
@@ -937,8 +977,8 @@ def compute_count_trees(grow_leaf_indices, grow_moves, prune_moves, batch_size):
|
|
|
937
977
|
grow_leaf_indices : int array (num_trees, n)
|
|
938
978
|
The index of the leaf each datapoint falls into, if the grow move is
|
|
939
979
|
accepted.
|
|
940
|
-
|
|
941
|
-
The
|
|
980
|
+
moves : dict
|
|
981
|
+
The proposed moves, see `sample_moves`.
|
|
942
982
|
batch_size : int or None
|
|
943
983
|
The data batch size to use for the summation.
|
|
944
984
|
|
|
@@ -952,24 +992,20 @@ def compute_count_trees(grow_leaf_indices, grow_moves, prune_moves, batch_size):
|
|
|
952
992
|
'left', 'right', and 'total'.
|
|
953
993
|
"""
|
|
954
994
|
|
|
955
|
-
ntree, tree_size =
|
|
995
|
+
ntree, tree_size = moves['var_trees'].shape
|
|
956
996
|
tree_size *= 2
|
|
957
|
-
counts = dict(grow=dict(), prune=dict())
|
|
958
997
|
tree_indices = jnp.arange(ntree)
|
|
959
998
|
|
|
960
|
-
count_trees = count_datapoints_per_leaf(
|
|
999
|
+
count_trees = count_datapoints_per_leaf(leaf_indices, tree_size, batch_size)
|
|
961
1000
|
|
|
962
|
-
# count datapoints in
|
|
963
|
-
counts
|
|
964
|
-
counts['
|
|
965
|
-
counts['
|
|
966
|
-
|
|
1001
|
+
# count datapoints in nodes modified by move
|
|
1002
|
+
counts = dict()
|
|
1003
|
+
counts['left'] = count_trees[tree_indices, moves['left']]
|
|
1004
|
+
counts['right'] = count_trees[tree_indices, moves['right']]
|
|
1005
|
+
counts['total'] = counts['left'] + counts['right']
|
|
967
1006
|
|
|
968
|
-
# count
|
|
969
|
-
|
|
970
|
-
counts['prune']['right'] = count_trees[tree_indices, prune_moves['right']]
|
|
971
|
-
counts['prune']['total'] = counts['prune']['left'] + counts['prune']['right']
|
|
972
|
-
count_trees = count_trees.at[tree_indices, prune_moves['node']].set(counts['prune']['total'])
|
|
1007
|
+
# write count into non-leaf node
|
|
1008
|
+
count_trees = count_trees.at[tree_indices, moves['node']].set(counts['total'])
|
|
973
1009
|
|
|
974
1010
|
return count_trees, counts
|
|
975
1011
|
|
|
@@ -1025,7 +1061,7 @@ def _aggregate_batched_alltrees(values, indices, size, dtype, batch_size):
|
|
|
1025
1061
|
.sum(axis=2)
|
|
1026
1062
|
)
|
|
1027
1063
|
|
|
1028
|
-
def complete_ratio(grow_moves, prune_moves, move_counts, min_points_per_leaf):
|
|
1064
|
+
def complete_ratio(moves, move_counts, min_points_per_leaf):
|
|
1029
1065
|
"""
|
|
1030
1066
|
Complete non-likelihood MH ratio calculation.
|
|
1031
1067
|
|
|
@@ -1033,8 +1069,8 @@ def complete_ratio(grow_moves, prune_moves, move_counts, min_points_per_leaf):
|
|
|
1033
1069
|
|
|
1034
1070
|
Parameters
|
|
1035
1071
|
----------
|
|
1036
|
-
|
|
1037
|
-
The
|
|
1072
|
+
moves : dict
|
|
1073
|
+
The proposed moves, see `sample_moves`.
|
|
1038
1074
|
move_counts : dict
|
|
1039
1075
|
The counts of the number of points in the the nodes modified by the
|
|
1040
1076
|
moves.
|
|
@@ -1043,61 +1079,62 @@ def complete_ratio(grow_moves, prune_moves, move_counts, min_points_per_leaf):
|
|
|
1043
1079
|
|
|
1044
1080
|
Returns
|
|
1045
1081
|
-------
|
|
1046
|
-
|
|
1047
|
-
The
|
|
1048
|
-
|
|
1082
|
+
moves : dict
|
|
1083
|
+
The updated moves, with the field 'partial_ratio' replaced by
|
|
1084
|
+
'log_trans_prior_ratio'.
|
|
1049
1085
|
"""
|
|
1050
|
-
|
|
1051
|
-
|
|
1052
|
-
|
|
1053
|
-
|
|
1054
|
-
grow_moves['trans_prior_ratio'] = grow_moves.pop('partial_ratio') * grow_p_prune
|
|
1055
|
-
prune_moves['trans_prior_ratio'] = prune_moves.pop('partial_ratio') * prune_p_prune
|
|
1056
|
-
return grow_moves, prune_moves
|
|
1086
|
+
moves = moves.copy()
|
|
1087
|
+
p_prune = compute_p_prune(moves, move_counts['left'], move_counts['right'], min_points_per_leaf)
|
|
1088
|
+
moves['log_trans_prior_ratio'] = jnp.log(moves.pop('partial_ratio') * p_prune)
|
|
1089
|
+
return moves
|
|
1057
1090
|
|
|
1058
|
-
def compute_p_prune(
|
|
1091
|
+
def compute_p_prune(moves, left_count, right_count, min_points_per_leaf):
|
|
1059
1092
|
"""
|
|
1060
1093
|
Compute the probability of proposing a prune move.
|
|
1061
1094
|
|
|
1062
1095
|
Parameters
|
|
1063
1096
|
----------
|
|
1064
|
-
|
|
1065
|
-
The
|
|
1066
|
-
|
|
1097
|
+
moves : dict
|
|
1098
|
+
The proposed moves, see `sample_moves`.
|
|
1099
|
+
left_count, right_count : int
|
|
1067
1100
|
The number of datapoints in the proposed children of the leaf to grow.
|
|
1068
1101
|
min_points_per_leaf : int or None
|
|
1069
1102
|
The minimum number of data points in a leaf node.
|
|
1070
1103
|
|
|
1071
1104
|
Returns
|
|
1072
1105
|
-------
|
|
1073
|
-
|
|
1074
|
-
The probability of proposing a prune move
|
|
1075
|
-
move.
|
|
1076
|
-
|
|
1077
|
-
|
|
1078
|
-
|
|
1079
|
-
other_growable_leaves =
|
|
1080
|
-
new_leaves_growable =
|
|
1106
|
+
p_prune : float
|
|
1107
|
+
The probability of proposing a prune move. If grow: after accepting the
|
|
1108
|
+
grow move, if prune: right away.
|
|
1109
|
+
"""
|
|
1110
|
+
|
|
1111
|
+
# calculation in case the move is grow
|
|
1112
|
+
other_growable_leaves = moves['num_growable'] >= 2
|
|
1113
|
+
new_leaves_growable = moves['node'] < moves['var_trees'].shape[1] // 2
|
|
1081
1114
|
if min_points_per_leaf is not None:
|
|
1082
|
-
any_above_threshold =
|
|
1083
|
-
any_above_threshold |=
|
|
1115
|
+
any_above_threshold = left_count >= 2 * min_points_per_leaf
|
|
1116
|
+
any_above_threshold |= right_count >= 2 * min_points_per_leaf
|
|
1084
1117
|
new_leaves_growable &= any_above_threshold
|
|
1085
1118
|
grow_again_allowed = other_growable_leaves | new_leaves_growable
|
|
1086
1119
|
grow_p_prune = jnp.where(grow_again_allowed, 0.5, 1)
|
|
1087
|
-
prune_p_prune = jnp.where(grow_move['num_growable'], 0.5, 1)
|
|
1088
|
-
return grow_p_prune, prune_p_prune
|
|
1089
1120
|
|
|
1090
|
-
|
|
1121
|
+
# calculation in case the move is prune
|
|
1122
|
+
prune_p_prune = jnp.where(moves['num_growable'], 0.5, 1)
|
|
1123
|
+
|
|
1124
|
+
return jnp.where(moves['grow'], grow_p_prune, prune_p_prune)
|
|
1125
|
+
|
|
1126
|
+
@jaxext.vmap_nodoc
|
|
1127
|
+
def adapt_leaf_trees_to_grow_indices(leaf_trees, moves):
|
|
1091
1128
|
"""
|
|
1092
|
-
Modify leaf values such that the indices of the grow
|
|
1129
|
+
Modify leaf values such that the indices of the grow moves work on the
|
|
1093
1130
|
original tree.
|
|
1094
1131
|
|
|
1095
1132
|
Parameters
|
|
1096
1133
|
----------
|
|
1097
1134
|
leaf_trees : float array (num_trees, 2 ** d)
|
|
1098
1135
|
The leaf values.
|
|
1099
|
-
|
|
1100
|
-
The
|
|
1136
|
+
moves : dict
|
|
1137
|
+
The proposed moves, see `sample_moves`.
|
|
1101
1138
|
|
|
1102
1139
|
Returns
|
|
1103
1140
|
-------
|
|
@@ -1105,17 +1142,86 @@ def adapt_leaf_trees_to_grow_indices(leaf_trees, grow_moves):
|
|
|
1105
1142
|
The modified leaf values. The value of the leaf to grow is copied to
|
|
1106
1143
|
what would be its children if the grow move was accepted.
|
|
1107
1144
|
"""
|
|
1108
|
-
|
|
1109
|
-
tree_indices = jnp.arange(ntree)
|
|
1110
|
-
values_at_node = leaf_trees[tree_indices, grow_moves['node']]
|
|
1145
|
+
values_at_node = leaf_trees[moves['node']]
|
|
1111
1146
|
return (leaf_trees
|
|
1112
|
-
.at[
|
|
1147
|
+
.at[jnp.where(moves['grow'], moves['left'], leaf_trees.size)]
|
|
1113
1148
|
.set(values_at_node)
|
|
1114
|
-
.at[
|
|
1149
|
+
.at[jnp.where(moves['grow'], moves['right'], leaf_trees.size)]
|
|
1115
1150
|
.set(values_at_node)
|
|
1116
1151
|
)
|
|
1117
1152
|
|
|
1118
|
-
def
|
|
1153
|
+
def precompute_likelihood_terms(count_trees, sigma2, move_counts):
|
|
1154
|
+
"""
|
|
1155
|
+
Pre-compute terms used in the likelihood ratio of the acceptance step.
|
|
1156
|
+
|
|
1157
|
+
Parameters
|
|
1158
|
+
----------
|
|
1159
|
+
count_trees : array (num_trees, 2 ** d)
|
|
1160
|
+
The number of points in each potential or actual leaf node.
|
|
1161
|
+
sigma2 : float
|
|
1162
|
+
The noise variance.
|
|
1163
|
+
move_counts : dict
|
|
1164
|
+
The counts of the number of points in the the nodes modified by the
|
|
1165
|
+
moves.
|
|
1166
|
+
|
|
1167
|
+
Returns
|
|
1168
|
+
-------
|
|
1169
|
+
prelkv : dict
|
|
1170
|
+
Dictionary with pre-computed terms of the likelihood ratio, one per
|
|
1171
|
+
tree.
|
|
1172
|
+
prelk : dict
|
|
1173
|
+
Dictionary with pre-computed terms of the likelihood ratio, shared by
|
|
1174
|
+
all trees.
|
|
1175
|
+
"""
|
|
1176
|
+
ntree = len(count_trees)
|
|
1177
|
+
sigma_mu2 = 1 / ntree
|
|
1178
|
+
prelkv = dict()
|
|
1179
|
+
prelkv['sigma2_left'] = sigma2 + move_counts['left'] * sigma_mu2
|
|
1180
|
+
prelkv['sigma2_right'] = sigma2 + move_counts['right'] * sigma_mu2
|
|
1181
|
+
prelkv['sigma2_total'] = sigma2 + move_counts['total'] * sigma_mu2
|
|
1182
|
+
prelkv['sqrt_term'] = jnp.log(
|
|
1183
|
+
sigma2 * prelkv['sigma2_total'] /
|
|
1184
|
+
(prelkv['sigma2_left'] * prelkv['sigma2_right'])
|
|
1185
|
+
) / 2
|
|
1186
|
+
return prelkv, dict(
|
|
1187
|
+
exp_factor=sigma_mu2 / (2 * sigma2),
|
|
1188
|
+
)
|
|
1189
|
+
|
|
1190
|
+
def precompute_leaf_terms(count_trees, sigma2, key):
|
|
1191
|
+
"""
|
|
1192
|
+
Pre-compute terms used to sample leaves from their posterior.
|
|
1193
|
+
|
|
1194
|
+
Parameters
|
|
1195
|
+
----------
|
|
1196
|
+
count_trees : array (num_trees, 2 ** d)
|
|
1197
|
+
The number of points in each potential or actual leaf node.
|
|
1198
|
+
sigma2 : float
|
|
1199
|
+
The noise variance.
|
|
1200
|
+
key : jax.dtypes.prng_key array
|
|
1201
|
+
A jax random key.
|
|
1202
|
+
|
|
1203
|
+
Returns
|
|
1204
|
+
-------
|
|
1205
|
+
prelf : dict
|
|
1206
|
+
Dictionary with pre-computed terms of the leaf sampling, with fields:
|
|
1207
|
+
|
|
1208
|
+
'mean_factor' : float array (num_trees, 2 ** d)
|
|
1209
|
+
The factor to be multiplied by the sum of residuals to obtain the
|
|
1210
|
+
posterior mean.
|
|
1211
|
+
'centered_leaves' : float array (num_trees, 2 ** d)
|
|
1212
|
+
The mean-zero normal values to be added to the posterior mean to
|
|
1213
|
+
obtain the posterior leaf samples.
|
|
1214
|
+
"""
|
|
1215
|
+
ntree = len(count_trees)
|
|
1216
|
+
prec_lk = count_trees / sigma2
|
|
1217
|
+
var_post = lax.reciprocal(prec_lk + ntree) # = 1 / (prec_lk + prec_prior)
|
|
1218
|
+
z = random.normal(key, count_trees.shape, sigma2.dtype)
|
|
1219
|
+
return dict(
|
|
1220
|
+
mean_factor=var_post / sigma2, # = mean_lk * prec_lk * var_post / resid_tree
|
|
1221
|
+
centered_leaves=z * jnp.sqrt(var_post),
|
|
1222
|
+
)
|
|
1223
|
+
|
|
1224
|
+
def accept_moves_sequential_stage(bart, count_trees, moves, move_counts, prelkv, prelk, prelf):
|
|
1119
1225
|
"""
|
|
1120
1226
|
The part of accepting the moves that has to be done one tree at a time.
|
|
1121
1227
|
|
|
@@ -1123,57 +1229,63 @@ def accept_moves_sequential_stage(bart, count_trees, grow_moves, prune_moves, mo
|
|
|
1123
1229
|
----------
|
|
1124
1230
|
bart : dict
|
|
1125
1231
|
A partially updated BART mcmc state.
|
|
1126
|
-
count_trees : array (num_trees, 2 **
|
|
1232
|
+
count_trees : array (num_trees, 2 ** d)
|
|
1127
1233
|
The number of points in each potential or actual leaf node.
|
|
1128
|
-
|
|
1129
|
-
The
|
|
1130
|
-
`prune_move`.
|
|
1234
|
+
moves : dict
|
|
1235
|
+
The proposed moves, see `sample_moves`.
|
|
1131
1236
|
move_counts : dict
|
|
1132
1237
|
The counts of the number of points in the the nodes modified by the
|
|
1133
1238
|
moves.
|
|
1134
|
-
|
|
1135
|
-
|
|
1136
|
-
|
|
1137
|
-
Random standard normal values used to sample the new leaf values.
|
|
1239
|
+
prelkv, prelk, prelf : dict
|
|
1240
|
+
Dictionaries with pre-computed terms of the likelihood ratios and leaf
|
|
1241
|
+
samples.
|
|
1138
1242
|
|
|
1139
1243
|
Returns
|
|
1140
1244
|
-------
|
|
1141
1245
|
bart : dict
|
|
1142
1246
|
A partially updated BART mcmc state.
|
|
1143
|
-
|
|
1144
|
-
The
|
|
1247
|
+
moves : dict
|
|
1248
|
+
The proposed moves, with these additional fields:
|
|
1249
|
+
|
|
1250
|
+
'acc' : bool array (num_trees,)
|
|
1251
|
+
Whether the move was accepted.
|
|
1252
|
+
'to_prune' : bool array (num_trees,)
|
|
1253
|
+
Whether, to reflect the acceptance status of the move, the state
|
|
1254
|
+
should be updated by pruning the leaves involved in the move.
|
|
1145
1255
|
"""
|
|
1146
1256
|
bart = bart.copy()
|
|
1257
|
+
moves = moves.copy()
|
|
1147
1258
|
|
|
1148
1259
|
def loop(resid, item):
|
|
1149
|
-
resid, leaf_tree,
|
|
1260
|
+
resid, leaf_tree, acc, to_prune, ratios = accept_move_and_sample_leaves(
|
|
1150
1261
|
bart['X'],
|
|
1151
1262
|
len(bart['leaf_trees']),
|
|
1152
1263
|
bart['opt']['resid_batch_size'],
|
|
1153
1264
|
resid,
|
|
1154
|
-
bart['sigma2'],
|
|
1155
1265
|
bart['min_points_per_leaf'],
|
|
1156
1266
|
'ratios' in bart,
|
|
1267
|
+
prelk,
|
|
1157
1268
|
*item,
|
|
1158
1269
|
)
|
|
1159
|
-
return resid, (leaf_tree,
|
|
1270
|
+
return resid, (leaf_tree, acc, to_prune, ratios)
|
|
1160
1271
|
|
|
1161
1272
|
items = (
|
|
1162
1273
|
bart['leaf_trees'], count_trees,
|
|
1163
|
-
|
|
1274
|
+
moves, move_counts,
|
|
1164
1275
|
bart['leaf_indices'],
|
|
1165
|
-
|
|
1276
|
+
prelkv, prelf,
|
|
1166
1277
|
)
|
|
1167
|
-
resid, (leaf_trees,
|
|
1278
|
+
resid, (leaf_trees, acc, to_prune, ratios) = lax.scan(loop, bart['resid'], items)
|
|
1168
1279
|
|
|
1169
1280
|
bart['resid'] = resid
|
|
1170
1281
|
bart['leaf_trees'] = leaf_trees
|
|
1171
|
-
bart['split_trees'] = split_trees
|
|
1172
1282
|
bart.get('ratios', {}).update(ratios)
|
|
1283
|
+
moves['acc'] = acc
|
|
1284
|
+
moves['to_prune'] = to_prune
|
|
1173
1285
|
|
|
1174
|
-
return bart,
|
|
1286
|
+
return bart, moves
|
|
1175
1287
|
|
|
1176
|
-
def accept_move_and_sample_leaves(X, ntree, resid_batch_size, resid,
|
|
1288
|
+
def accept_move_and_sample_leaves(X, ntree, resid_batch_size, resid, min_points_per_leaf, save_ratios, prelk, leaf_tree, count_tree, move, move_counts, leaf_indices, prelkv, prelf):
|
|
1177
1289
|
"""
|
|
1178
1290
|
Accept or reject a proposed move and sample the new leaf values.
|
|
1179
1291
|
|
|
@@ -1187,25 +1299,25 @@ def accept_move_and_sample_leaves(X, ntree, resid_batch_size, resid, sigma2, min
|
|
|
1187
1299
|
The batch size for computing the sum of residuals in each leaf.
|
|
1188
1300
|
resid : float array (n,)
|
|
1189
1301
|
The residuals (data minus forest value).
|
|
1190
|
-
sigma2 : float
|
|
1191
|
-
The noise variance.
|
|
1192
1302
|
min_points_per_leaf : int or None
|
|
1193
1303
|
The minimum number of data points in a leaf node.
|
|
1194
1304
|
save_ratios : bool
|
|
1195
1305
|
Whether to save the acceptance ratios.
|
|
1306
|
+
prelk : dict
|
|
1307
|
+
The pre-computed terms of the likelihood ratio which are shared across
|
|
1308
|
+
trees.
|
|
1196
1309
|
leaf_tree : float array (2 ** d,)
|
|
1197
1310
|
The leaf values of the tree.
|
|
1198
1311
|
count_tree : int array (2 ** d,)
|
|
1199
1312
|
The number of datapoints in each leaf.
|
|
1200
|
-
|
|
1201
|
-
The
|
|
1202
|
-
|
|
1203
|
-
|
|
1204
|
-
|
|
1205
|
-
|
|
1206
|
-
|
|
1207
|
-
|
|
1208
|
-
Standard normal random values.
|
|
1313
|
+
move : dict
|
|
1314
|
+
The proposed move, see `sample_moves`.
|
|
1315
|
+
leaf_indices : int array (n,)
|
|
1316
|
+
The leaf indices for the largest version of the tree compatible with
|
|
1317
|
+
the move.
|
|
1318
|
+
prelkv, prelf : dict
|
|
1319
|
+
The pre-computed terms of the likelihood ratio and leaf sampling which
|
|
1320
|
+
are specific to the tree.
|
|
1209
1321
|
|
|
1210
1322
|
Returns
|
|
1211
1323
|
-------
|
|
@@ -1213,123 +1325,68 @@ def accept_move_and_sample_leaves(X, ntree, resid_batch_size, resid, sigma2, min
|
|
|
1213
1325
|
The updated residuals (data minus forest value).
|
|
1214
1326
|
leaf_tree : float array (2 ** d,)
|
|
1215
1327
|
The new leaf values of the tree.
|
|
1216
|
-
|
|
1217
|
-
|
|
1218
|
-
|
|
1219
|
-
|
|
1328
|
+
acc : bool
|
|
1329
|
+
Whether the move was accepted.
|
|
1330
|
+
to_prune : bool
|
|
1331
|
+
Whether, to reflect the acceptance status of the move, the state should
|
|
1332
|
+
be updated by pruning the leaves involved in the move.
|
|
1220
1333
|
ratios : dict
|
|
1221
1334
|
The acceptance ratios for the moves. Empty if not to be saved.
|
|
1222
1335
|
"""
|
|
1223
1336
|
|
|
1224
1337
|
# sum residuals and count units per leaf, in tree proposed by grow move
|
|
1225
|
-
resid_tree = sum_resid(resid,
|
|
1338
|
+
resid_tree = sum_resid(resid, leaf_indices, leaf_tree.size, resid_batch_size)
|
|
1226
1339
|
|
|
1227
1340
|
# subtract starting tree from function
|
|
1228
1341
|
resid_tree += count_tree * leaf_tree
|
|
1229
1342
|
|
|
1230
|
-
# get indices of
|
|
1231
|
-
|
|
1232
|
-
assert
|
|
1233
|
-
|
|
1234
|
-
|
|
1235
|
-
|
|
1236
|
-
# sum residuals in
|
|
1237
|
-
|
|
1238
|
-
|
|
1239
|
-
|
|
1240
|
-
resid_tree = resid_tree.at[
|
|
1241
|
-
|
|
1242
|
-
#
|
|
1243
|
-
|
|
1244
|
-
|
|
1245
|
-
|
|
1246
|
-
prune_right = prune_move['right']
|
|
1247
|
-
|
|
1248
|
-
# sum residuals in node to prune
|
|
1249
|
-
prune_resid_left = resid_tree[prune_left]
|
|
1250
|
-
prune_resid_right = resid_tree[prune_right]
|
|
1251
|
-
prune_resid_total = prune_resid_left + prune_resid_right
|
|
1252
|
-
resid_tree = resid_tree.at[prune_node].set(prune_resid_total)
|
|
1253
|
-
|
|
1254
|
-
# Now resid_tree and count_tree contain correct values whatever move is
|
|
1255
|
-
# accepted.
|
|
1256
|
-
|
|
1257
|
-
# compute likelihood ratios
|
|
1258
|
-
grow_lk_ratio = compute_likelihood_ratio(grow_resid_total, grow_resid_left, grow_resid_right, move_counts['grow']['total'], move_counts['grow']['left'], move_counts['grow']['right'], sigma2, ntree)
|
|
1259
|
-
prune_lk_ratio = compute_likelihood_ratio(prune_resid_total, prune_resid_left, prune_resid_right, move_counts['prune']['total'], move_counts['prune']['left'], move_counts['prune']['right'], sigma2, ntree)
|
|
1260
|
-
|
|
1261
|
-
# compute acceptance ratios
|
|
1262
|
-
grow_ratio = grow_move['trans_prior_ratio'] * grow_lk_ratio
|
|
1263
|
-
if min_points_per_leaf is not None:
|
|
1264
|
-
grow_ratio = jnp.where(move_counts['grow']['left'] >= min_points_per_leaf, grow_ratio, 0)
|
|
1265
|
-
grow_ratio = jnp.where(move_counts['grow']['right'] >= min_points_per_leaf, grow_ratio, 0)
|
|
1266
|
-
prune_ratio = prune_move['trans_prior_ratio'] * prune_lk_ratio
|
|
1267
|
-
prune_ratio = lax.reciprocal(prune_ratio)
|
|
1268
|
-
|
|
1269
|
-
# save acceptance ratios
|
|
1343
|
+
# get indices of move
|
|
1344
|
+
node = move['node']
|
|
1345
|
+
assert node.dtype == jnp.int32
|
|
1346
|
+
left = move['left']
|
|
1347
|
+
right = move['right']
|
|
1348
|
+
|
|
1349
|
+
# sum residuals in parent node modified by move
|
|
1350
|
+
resid_left = resid_tree[left]
|
|
1351
|
+
resid_right = resid_tree[right]
|
|
1352
|
+
resid_total = resid_left + resid_right
|
|
1353
|
+
resid_tree = resid_tree.at[node].set(resid_total)
|
|
1354
|
+
|
|
1355
|
+
# compute acceptance ratio
|
|
1356
|
+
log_lk_ratio = compute_likelihood_ratio(resid_total, resid_left, resid_right, prelkv, prelk)
|
|
1357
|
+
log_ratio = move['log_trans_prior_ratio'] + log_lk_ratio
|
|
1358
|
+
log_ratio = jnp.where(move['grow'], log_ratio, -log_ratio)
|
|
1270
1359
|
ratios = {}
|
|
1271
1360
|
if save_ratios:
|
|
1272
1361
|
ratios.update(
|
|
1273
|
-
|
|
1274
|
-
|
|
1275
|
-
likelihood=grow_lk_ratio,
|
|
1276
|
-
),
|
|
1277
|
-
prune=dict(
|
|
1278
|
-
trans_prior=lax.reciprocal(prune_move['trans_prior_ratio']),
|
|
1279
|
-
likelihood=lax.reciprocal(prune_lk_ratio),
|
|
1280
|
-
),
|
|
1362
|
+
log_trans_prior=move['log_trans_prior_ratio'],
|
|
1363
|
+
log_likelihood=log_lk_ratio,
|
|
1281
1364
|
)
|
|
1282
1365
|
|
|
1283
|
-
# determine what move to propose (not proposing anything is an option)
|
|
1284
|
-
grow_allowed = grow_move['num_growable'].astype(bool)
|
|
1285
|
-
p_grow = jnp.where(grow_allowed & prune_move['allowed'], 0.5, grow_allowed)
|
|
1286
|
-
try_grow = u[0] < p_grow # use < instead of <= because coins are in [0, 1)
|
|
1287
|
-
try_prune = prune_move['allowed'] & ~try_grow
|
|
1288
|
-
|
|
1289
1366
|
# determine whether to accept the move
|
|
1290
|
-
|
|
1291
|
-
|
|
1292
|
-
|
|
1293
|
-
|
|
1294
|
-
split_tree = grow_move['split_tree']
|
|
1295
|
-
split_tree = split_tree.at[jnp.where(do_grow, split_tree.size, grow_node)].set(0)
|
|
1296
|
-
split_tree = split_tree.at[jnp.where(do_prune, prune_node, split_tree.size)].set(0)
|
|
1297
|
-
# I can leave garbage in var_tree, resid_tree, count_tree
|
|
1367
|
+
acc = move['allowed'] & (move['logu'] <= log_ratio)
|
|
1368
|
+
if min_points_per_leaf is not None:
|
|
1369
|
+
acc &= move_counts['left'] >= min_points_per_leaf
|
|
1370
|
+
acc &= move_counts['right'] >= min_points_per_leaf
|
|
1298
1371
|
|
|
1299
1372
|
# compute leaves posterior and sample leaves
|
|
1300
|
-
inv_sigma2 = lax.reciprocal(sigma2)
|
|
1301
|
-
prec_lk = count_tree * inv_sigma2
|
|
1302
|
-
var_post = lax.reciprocal(prec_lk + ntree) # = 1 / (prec_lk + prec_prior)
|
|
1303
|
-
mean_post = resid_tree * inv_sigma2 * var_post # = mean_lk * prec_lk * var_post
|
|
1304
1373
|
initial_leaf_tree = leaf_tree
|
|
1305
|
-
|
|
1374
|
+
mean_post = resid_tree * prelf['mean_factor']
|
|
1375
|
+
leaf_tree = mean_post + prelf['centered_leaves']
|
|
1306
1376
|
|
|
1307
|
-
# copy leaves around such that the
|
|
1308
|
-
|
|
1309
|
-
.at[jnp.where(do_prune, prune_left, leaf_tree.size)]
|
|
1310
|
-
.set(leaf_tree[prune_node])
|
|
1311
|
-
.at[jnp.where(do_prune, prune_right, leaf_tree.size)]
|
|
1312
|
-
.set(leaf_tree[prune_node])
|
|
1313
|
-
)
|
|
1377
|
+
# copy leaves around such that the leaf indices select the right leaf
|
|
1378
|
+
to_prune = acc ^ move['grow']
|
|
1314
1379
|
leaf_tree = (leaf_tree
|
|
1315
|
-
.at[jnp.where(
|
|
1316
|
-
.set(leaf_tree[
|
|
1317
|
-
.at[jnp.where(
|
|
1318
|
-
.set(leaf_tree[
|
|
1380
|
+
.at[jnp.where(to_prune, left, leaf_tree.size)]
|
|
1381
|
+
.set(leaf_tree[node])
|
|
1382
|
+
.at[jnp.where(to_prune, right, leaf_tree.size)]
|
|
1383
|
+
.set(leaf_tree[node])
|
|
1319
1384
|
)
|
|
1320
1385
|
|
|
1321
1386
|
# replace old tree with new tree in function values
|
|
1322
|
-
resid += (initial_leaf_tree - leaf_tree)[
|
|
1323
|
-
|
|
1324
|
-
# pack proposal and acceptance indicators
|
|
1325
|
-
counts = dict(
|
|
1326
|
-
grow_prop_count=try_grow,
|
|
1327
|
-
grow_acc_count=do_grow,
|
|
1328
|
-
prune_prop_count=try_prune,
|
|
1329
|
-
prune_acc_count=do_prune,
|
|
1330
|
-
)
|
|
1387
|
+
resid += (initial_leaf_tree - leaf_tree)[leaf_indices]
|
|
1331
1388
|
|
|
1332
|
-
return resid, leaf_tree,
|
|
1389
|
+
return resid, leaf_tree, acc, to_prune, ratios
|
|
1333
1390
|
|
|
1334
1391
|
def sum_resid(resid, leaf_indices, tree_size, batch_size):
|
|
1335
1392
|
"""
|
|
@@ -1369,7 +1426,7 @@ def _aggregate_batched_onetree(values, indices, size, dtype, batch_size):
|
|
|
1369
1426
|
.sum(axis=1)
|
|
1370
1427
|
)
|
|
1371
1428
|
|
|
1372
|
-
def compute_likelihood_ratio(total_resid, left_resid, right_resid,
|
|
1429
|
+
def compute_likelihood_ratio(total_resid, left_resid, right_resid, prelkv, prelk):
|
|
1373
1430
|
"""
|
|
1374
1431
|
Compute the likelihood ratio of a grow move.
|
|
1375
1432
|
|
|
@@ -1379,37 +1436,23 @@ def compute_likelihood_ratio(total_resid, left_resid, right_resid, total_count,
|
|
|
1379
1436
|
The sum of the residuals in the leaf to grow.
|
|
1380
1437
|
left_resid, right_resid : float
|
|
1381
1438
|
The sum of the residuals in the left/right child of the leaf to grow.
|
|
1382
|
-
|
|
1383
|
-
The
|
|
1384
|
-
|
|
1385
|
-
The number of datapoints in the left/right child of the leaf to grow.
|
|
1386
|
-
sigma2 : float
|
|
1387
|
-
The noise variance.
|
|
1388
|
-
n_tree : int
|
|
1389
|
-
The number of trees in the forest.
|
|
1439
|
+
prelkv, prelk : dict
|
|
1440
|
+
The pre-computed terms of the likelihood ratio, see
|
|
1441
|
+
`precompute_likelihood_terms`.
|
|
1390
1442
|
|
|
1391
1443
|
Returns
|
|
1392
1444
|
-------
|
|
1393
1445
|
ratio : float
|
|
1394
1446
|
The likelihood ratio P(data | new tree) / P(data | old tree).
|
|
1395
1447
|
"""
|
|
1396
|
-
|
|
1397
|
-
|
|
1398
|
-
|
|
1399
|
-
|
|
1400
|
-
sigma2_total = sigma2 + total_count * sigma_mu2
|
|
1401
|
-
|
|
1402
|
-
sqrt_term = sigma2 * sigma2_total / (sigma2_left * sigma2_right)
|
|
1403
|
-
|
|
1404
|
-
exp_term = sigma_mu2 / (2 * sigma2) * (
|
|
1405
|
-
left_resid * left_resid / sigma2_left +
|
|
1406
|
-
right_resid * right_resid / sigma2_right -
|
|
1407
|
-
total_resid * total_resid / sigma2_total
|
|
1448
|
+
exp_term = prelk['exp_factor'] * (
|
|
1449
|
+
left_resid * left_resid / prelkv['sigma2_left'] +
|
|
1450
|
+
right_resid * right_resid / prelkv['sigma2_right'] -
|
|
1451
|
+
total_resid * total_resid / prelkv['sigma2_total']
|
|
1408
1452
|
)
|
|
1453
|
+
return prelkv['sqrt_term'] + exp_term
|
|
1409
1454
|
|
|
1410
|
-
|
|
1411
|
-
|
|
1412
|
-
def accept_moves_final_stage(bart, counts, grow_moves, prune_moves):
|
|
1455
|
+
def accept_moves_final_stage(bart, moves):
|
|
1413
1456
|
"""
|
|
1414
1457
|
The final part of accepting the moves, in parallel across trees.
|
|
1415
1458
|
|
|
@@ -1419,8 +1462,9 @@ def accept_moves_final_stage(bart, counts, grow_moves, prune_moves):
|
|
|
1419
1462
|
A partially updated BART mcmc state.
|
|
1420
1463
|
counts : dict
|
|
1421
1464
|
The indicators of proposals and acceptances for grow and prune moves.
|
|
1422
|
-
|
|
1423
|
-
The
|
|
1465
|
+
moves : dict
|
|
1466
|
+
The proposed moves (see `sample_moves`) as updated by
|
|
1467
|
+
`accept_moves_sequential_stage`.
|
|
1424
1468
|
|
|
1425
1469
|
Returns
|
|
1426
1470
|
-------
|
|
@@ -1428,15 +1472,14 @@ def accept_moves_final_stage(bart, counts, grow_moves, prune_moves):
|
|
|
1428
1472
|
The fully updated BART mcmc state.
|
|
1429
1473
|
"""
|
|
1430
1474
|
bart = bart.copy()
|
|
1431
|
-
|
|
1432
|
-
|
|
1433
|
-
|
|
1434
|
-
|
|
1435
|
-
bart['leaf_indices'] = apply_moves_to_indices(bart['leaf_indices'], counts, grow_moves, prune_moves)
|
|
1436
|
-
|
|
1475
|
+
bart['grow_acc_count'] = jnp.sum(moves['acc'] & moves['grow'])
|
|
1476
|
+
bart['prune_acc_count'] = jnp.sum(moves['acc'] & ~moves['grow'])
|
|
1477
|
+
bart['leaf_indices'] = apply_moves_to_leaf_indices(bart['leaf_indices'], moves)
|
|
1478
|
+
bart['split_trees'] = apply_moves_to_split_trees(bart['split_trees'], moves)
|
|
1437
1479
|
return bart
|
|
1438
1480
|
|
|
1439
|
-
|
|
1481
|
+
@jax.vmap
|
|
1482
|
+
def apply_moves_to_leaf_indices(leaf_indices, moves):
|
|
1440
1483
|
"""
|
|
1441
1484
|
Update the leaf indices to match the accepted move.
|
|
1442
1485
|
|
|
@@ -1445,10 +1488,9 @@ def apply_moves_to_indices(leaf_indices, counts, grow_moves, prune_moves):
|
|
|
1445
1488
|
leaf_indices : int array (num_trees, n)
|
|
1446
1489
|
The index of the leaf each datapoint falls into, if the grow move was
|
|
1447
1490
|
accepted.
|
|
1448
|
-
|
|
1449
|
-
The
|
|
1450
|
-
|
|
1451
|
-
The proposals for the moves. See `grow_move` and `prune_move`.
|
|
1491
|
+
moves : dict
|
|
1492
|
+
The proposed moves (see `sample_moves`), as updated by
|
|
1493
|
+
`accept_moves_sequential_stage`.
|
|
1452
1494
|
|
|
1453
1495
|
Returns
|
|
1454
1496
|
-------
|
|
@@ -1456,19 +1498,46 @@ def apply_moves_to_indices(leaf_indices, counts, grow_moves, prune_moves):
|
|
|
1456
1498
|
The updated leaf indices.
|
|
1457
1499
|
"""
|
|
1458
1500
|
mask = ~jnp.array(1, leaf_indices.dtype) # ...1111111110
|
|
1459
|
-
|
|
1460
|
-
leaf_indices = jnp.where(
|
|
1461
|
-
cond & ~counts['grow_acc_count'][:, None],
|
|
1462
|
-
grow_moves['node'][:, None].astype(leaf_indices.dtype),
|
|
1463
|
-
leaf_indices,
|
|
1464
|
-
)
|
|
1465
|
-
cond = (leaf_indices & mask) == prune_moves['left'][:, None]
|
|
1501
|
+
is_child = (leaf_indices & mask) == moves['left']
|
|
1466
1502
|
return jnp.where(
|
|
1467
|
-
|
|
1468
|
-
|
|
1503
|
+
is_child & moves['to_prune'],
|
|
1504
|
+
moves['node'].astype(leaf_indices.dtype),
|
|
1469
1505
|
leaf_indices,
|
|
1470
1506
|
)
|
|
1471
1507
|
|
|
1508
|
+
@jax.vmap
|
|
1509
|
+
def apply_moves_to_split_trees(split_trees, moves):
|
|
1510
|
+
"""
|
|
1511
|
+
Update the split trees to match the accepted move.
|
|
1512
|
+
|
|
1513
|
+
Parameters
|
|
1514
|
+
----------
|
|
1515
|
+
split_trees : int array (num_trees, 2 ** (d - 1))
|
|
1516
|
+
The cutpoints of the decision nodes in the initial trees.
|
|
1517
|
+
moves : dict
|
|
1518
|
+
The proposed moves (see `sample_moves`), as updated by
|
|
1519
|
+
`accept_moves_sequential_stage`.
|
|
1520
|
+
|
|
1521
|
+
Returns
|
|
1522
|
+
-------
|
|
1523
|
+
split_trees : int array (num_trees, 2 ** (d - 1))
|
|
1524
|
+
The updated split trees.
|
|
1525
|
+
"""
|
|
1526
|
+
return (split_trees
|
|
1527
|
+
.at[jnp.where(
|
|
1528
|
+
moves['grow'],
|
|
1529
|
+
moves['node'],
|
|
1530
|
+
split_trees.size,
|
|
1531
|
+
)]
|
|
1532
|
+
.set(moves['grow_split'].astype(split_trees.dtype))
|
|
1533
|
+
.at[jnp.where(
|
|
1534
|
+
moves['to_prune'],
|
|
1535
|
+
moves['node'],
|
|
1536
|
+
split_trees.size,
|
|
1537
|
+
)]
|
|
1538
|
+
.set(0)
|
|
1539
|
+
)
|
|
1540
|
+
|
|
1472
1541
|
def sample_sigma(bart, key):
|
|
1473
1542
|
"""
|
|
1474
1543
|
Noise variance sampling step of BART MCMC.
|
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
bartz/BART.py,sha256=CbGzFWtYw5u38Z9-Hy3CbDXpKOOvPFAAkSqu2HZl8no,16862
|
|
2
2
|
bartz/__init__.py,sha256=E96vsP0bZ8brejpZmEmRoXuMsUdinO_B_SKUUl1rLsg,1448
|
|
3
|
-
bartz/_version.py,sha256=
|
|
3
|
+
bartz/_version.py,sha256=2eiWQI55fd-roDdkt4Hvl9WzrTJ4xQo33VzFud6D03U,22
|
|
4
4
|
bartz/debug.py,sha256=9ZH-JfwZVu5OPhHBEyXQHAU5H9KIu1vxLK7yNv4m4Ew,5314
|
|
5
5
|
bartz/grove.py,sha256=x_6NK_l-hrXfy1PhssYNJkX41-w_WqjDziww0E7YRS8,8500
|
|
6
6
|
bartz/jaxext.py,sha256=RcVWTCGS8lXF7GBsNbKrpuA4MTcokItq0CpWm3s7CGk,12033
|
|
7
7
|
bartz/mcmcloop.py,sha256=lKDszvniNXka99X3e9RCrTgvEAZHA7ZbVXEgxUYvKMY,7634
|
|
8
|
-
bartz/mcmcstep.py,sha256=
|
|
8
|
+
bartz/mcmcstep.py,sha256=diI9vHXHMvu_Lk_bSJ-a038OnEbXDpNEikVPhRcxEys,54987
|
|
9
9
|
bartz/prepcovars.py,sha256=mMgfL-LGJ_8QpOL6iy7yfkL8A7FrT7Zfn5M3voyNwSQ,5818
|
|
10
|
-
bartz-0.
|
|
11
|
-
bartz-0.
|
|
12
|
-
bartz-0.
|
|
13
|
-
bartz-0.
|
|
10
|
+
bartz-0.4.0.dist-info/LICENSE,sha256=heuIJZQK9IexJYC-fYHoLUrgj8HG8yS3G072EvKh-94,1073
|
|
11
|
+
bartz-0.4.0.dist-info/METADATA,sha256=K86CVXT6ayPnc2hjhreYGMEeYWfYJIZdDkKuBB0-FYA,4500
|
|
12
|
+
bartz-0.4.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
13
|
+
bartz-0.4.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|