bartz 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/_version.py CHANGED
@@ -1 +1 @@
1
- __version__ = '0.3.0'
1
+ __version__ = '0.4.0'
bartz/mcmcstep.py CHANGED
@@ -148,6 +148,14 @@ def init(*,
148
148
  Whether the `min_points_per_leaf` parameter is specified.
149
149
  'resid_batch_size', 'count_batch_size' : int or None
150
150
  The data batch sizes for computing the sufficient statistics.
151
+ 'ratios' : dict, optional
152
+ If `save_ratios` is True, this field is present. It has the fields:
153
+
154
+ 'log_trans_prior' : large_float array (num_trees,)
155
+ The log transition and prior Metropolis-Hastings ratio for the
156
+ proposed move on each tree.
157
+ 'log_likelihood' : large_float array (num_trees,)
158
+ The log likelihood ratio.
151
159
  """
152
160
 
153
161
  p_nonterminal = jnp.asarray(p_nonterminal, large_float)
@@ -161,7 +169,7 @@ def init(*,
161
169
  small_float = jnp.dtype(small_float)
162
170
  large_float = jnp.dtype(large_float)
163
171
  y = jnp.asarray(y, small_float)
164
- resid_batch_size, count_batch_size = _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y)
172
+ resid_batch_size, count_batch_size = _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, 2 ** max_depth * num_trees)
165
173
  sigma2 = jnp.array(sigma2_beta / sigma2_alpha, large_float)
166
174
  sigma2 = jnp.where(jnp.isfinite(sigma2) & (sigma2 > 0), sigma2, 1)
167
175
 
@@ -202,19 +210,13 @@ def init(*,
202
210
 
203
211
  if save_ratios:
204
212
  bart['ratios'] = dict(
205
- grow=dict(
206
- trans_prior=jnp.full(num_trees, jnp.nan),
207
- likelihood=jnp.full(num_trees, jnp.nan),
208
- ),
209
- prune=dict(
210
- trans_prior=jnp.full(num_trees, jnp.nan),
211
- likelihood=jnp.full(num_trees, jnp.nan),
212
- ),
213
+ log_trans_prior=jnp.full(num_trees, jnp.nan),
214
+ log_likelihood=jnp.full(num_trees, jnp.nan),
213
215
  )
214
216
 
215
217
  return bart
216
218
 
217
- def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y):
219
+ def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, forest_size):
218
220
 
219
221
  @functools.cache
220
222
  def get_platform():
@@ -244,6 +246,10 @@ def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y):
244
246
  n = max(1, y.size)
245
247
  count_batch_size = 2 ** int(round(math.log2(n) / 2 - 2)) # n^1/2
246
248
  # /4 is good on V100, /2 on L4/T4, still haven't tried A100
249
+ max_memory = 2 ** 29
250
+ itemsize = 4
251
+ min_batch_size = int(math.ceil(forest_size * itemsize * n / max_memory))
252
+ count_batch_size = max(count_batch_size, min_batch_size)
247
253
  count_batch_size = max(1, count_batch_size)
248
254
 
249
255
  return resid_batch_size, count_batch_size
@@ -289,8 +295,8 @@ def sample_trees(bart, key):
289
295
  This function zeroes the proposal counters.
290
296
  """
291
297
  key, subkey = random.split(key)
292
- grow_moves, prune_moves = sample_moves(bart, subkey)
293
- return accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, key)
298
+ moves = sample_moves(bart, subkey)
299
+ return accept_moves_and_sample_leaves(bart, moves, key)
294
300
 
295
301
  def sample_moves(bart, key):
296
302
  """
@@ -305,11 +311,65 @@ def sample_moves(bart, key):
305
311
 
306
312
  Returns
307
313
  -------
308
- grow_moves, prune_moves : dict
309
- The proposals for grow and prune moves. See `grow_move` and `prune_move`.
314
+ moves : dict
315
+ A dictionary with fields:
316
+
317
+ 'allowed' : bool array (num_trees,)
318
+ Whether the move is possible.
319
+ 'grow' : bool array (num_trees,)
320
+ Whether the move is a grow move or a prune move.
321
+ 'num_growable' : int array (num_trees,)
322
+ The number of growable leaves in the original tree.
323
+ 'node' : int array (num_trees,)
324
+ The index of the leaf to grow or node to prune.
325
+ 'left', 'right' : int array (num_trees,)
326
+ The indices of the children of 'node'.
327
+ 'partial_ratio' : float array (num_trees,)
328
+ A factor of the Metropolis-Hastings ratio of the move. It lacks
329
+ the likelihood ratio and the probability of proposing the prune
330
+ move. If the move is Prune, the ratio is inverted.
331
+ 'grow_var' : int array (num_trees,)
332
+ The decision axes of the new rules.
333
+ 'grow_split' : int array (num_trees,)
334
+ The decision boundaries of the new rules.
335
+ 'var_trees' : int array (num_trees, 2 ** (d - 1))
336
+ The updated decision axes of the trees, valid whatever move.
337
+ 'logu' : float array (num_trees,)
338
+ The logarithm of a uniform (0, 1] random variable to be used to
339
+ accept the move. It's in (-oo, 0].
310
340
  """
311
- key = random.split(key, bart['var_trees'].shape[0])
312
- return _sample_moves_vmap_trees(bart['var_trees'], bart['split_trees'], bart['affluence_trees'], bart['max_split'], bart['p_nonterminal'], bart['p_propose_grow'], key)
341
+ ntree = bart['leaf_trees'].shape[0]
342
+ key = random.split(key, 1 + ntree)
343
+ key, subkey = key[0], key[1:]
344
+
345
+ # compute moves
346
+ grow_moves, prune_moves = _sample_moves_vmap_trees(bart['var_trees'], bart['split_trees'], bart['affluence_trees'], bart['max_split'], bart['p_nonterminal'], bart['p_propose_grow'], subkey)
347
+
348
+ u, logu = random.uniform(key, (2, ntree), bart['opt']['large_float'])
349
+
350
+ # choose between grow or prune
351
+ grow_allowed = grow_moves['num_growable'].astype(bool)
352
+ p_grow = jnp.where(grow_allowed & prune_moves['allowed'], 0.5, grow_allowed)
353
+ grow = u < p_grow # use < instead of <= because u is in [0, 1)
354
+
355
+ # compute children indices
356
+ node = jnp.where(grow, grow_moves['node'], prune_moves['node'])
357
+ left = node << 1
358
+ right = left + 1
359
+
360
+ return dict(
361
+ allowed=grow | prune_moves['allowed'],
362
+ grow=grow,
363
+ num_growable=grow_moves['num_growable'],
364
+ node=node,
365
+ left=left,
366
+ right=right,
367
+ partial_ratio=jnp.where(grow, grow_moves['partial_ratio'], prune_moves['partial_ratio']),
368
+ grow_var=grow_moves['var'],
369
+ grow_split=grow_moves['split'],
370
+ var_trees=grow_moves['var_tree'],
371
+ logu=jnp.log1p(-logu),
372
+ )
313
373
 
314
374
  @functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, 0, None, None, None, 0))
315
375
  def _sample_moves_vmap_trees(*args):
@@ -354,16 +414,14 @@ def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_
354
414
  'node' : int
355
415
  The index of the leaf to grow. ``2 ** d`` if there are no growable
356
416
  leaves.
357
- 'left', 'right' : int
358
- The indices of the children of 'node'.
359
417
  'var', 'split' : int
360
418
  The decision axis and boundary of the new rule.
361
419
  'partial_ratio' : float
362
420
  A factor of the Metropolis-Hastings ratio of the move. It lacks
363
421
  the likelihood ratio and the probability of proposing the prune
364
422
  move.
365
- 'var_tree', 'split_tree' : array (2 ** (d - 1),)
366
- The updated decision axes and boundaries of the tree.
423
+ 'var_tree' : array (2 ** (d - 1),)
424
+ The updated decision axes of the tree.
367
425
  """
368
426
 
369
427
  key, key1, key2 = random.split(key, 3)
@@ -374,21 +432,16 @@ def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_
374
432
  var_tree = var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype))
375
433
 
376
434
  split = choose_split(var_tree, split_tree, max_split, leaf_to_grow, key2)
377
- split_tree = split_tree.at[leaf_to_grow].set(split.astype(split_tree.dtype))
378
435
 
379
- ratio = compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow, split_tree)
436
+ ratio = compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow)
380
437
 
381
- left = leaf_to_grow << 1
382
438
  return dict(
383
439
  num_growable=num_growable,
384
440
  node=leaf_to_grow,
385
- left=left,
386
- right=left + 1,
387
441
  var=var,
388
442
  split=split,
389
443
  partial_ratio=ratio,
390
444
  var_tree=var_tree,
391
- split_tree=split_tree,
392
445
  )
393
446
 
394
447
  def choose_leaf(split_tree, affluence_tree, p_propose_grow, key):
@@ -658,7 +711,7 @@ def choose_split(var_tree, split_tree, max_split, leaf_index, key):
658
711
  l, r = split_range(var_tree, split_tree, max_split, leaf_index, var)
659
712
  return random.randint(key, (), l, r)
660
713
 
661
- def compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow, new_split_tree):
714
+ def compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow):
662
715
  """
663
716
  Compute the product of the transition and prior ratios of a grow move.
664
717
 
@@ -673,8 +726,6 @@ def compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow
673
726
  The probability of a nonterminal node at each depth.
674
727
  leaf_to_grow : int
675
728
  The index of the leaf to grow.
676
- new_split_tree : array (2 ** (d - 1),)
677
- The splitting points of the tree, after the leaf is grown.
678
729
 
679
730
  Returns
680
731
  -------
@@ -699,7 +750,7 @@ def compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow
699
750
 
700
751
  inv_trans_ratio = p_grow * prob_choose * num_prunable
701
752
 
702
- depth = grove.tree_depths(new_split_tree.size)[leaf_to_grow]
753
+ depth = grove.tree_depths(2 ** (p_nonterminal.size - 1))[leaf_to_grow]
703
754
  p_parent = p_nonterminal[depth]
704
755
  cp_children = 1 - p_nonterminal[depth + 1]
705
756
  tree_ratio = cp_children * cp_children * p_parent / (1 - p_parent)
@@ -736,8 +787,6 @@ def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p
736
787
  Whether the move is possible.
737
788
  'node' : int
738
789
  The index of the node to prune. ``2 ** d`` if no node can be pruned.
739
- 'left', 'right' : int
740
- The indices of the children of 'node'.
741
790
  'partial_ratio' : float
742
791
  A factor of the Metropolis-Hastings ratio of the move. It lacks
743
792
  the likelihood ratio and the probability of proposing the prune
@@ -746,14 +795,11 @@ def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p
746
795
  node_to_prune, num_prunable, prob_choose = choose_leaf_parent(split_tree, affluence_tree, p_propose_grow, key)
747
796
  allowed = split_tree[1].astype(bool) # allowed iff the tree is not a root
748
797
 
749
- ratio = compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, node_to_prune, split_tree)
798
+ ratio = compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, node_to_prune)
750
799
 
751
- left = node_to_prune << 1
752
800
  return dict(
753
801
  allowed=allowed,
754
802
  node=node_to_prune,
755
- left=left,
756
- right=left + 1,
757
803
  partial_ratio=ratio, # it is inverted in accept_move_and_sample_leaves
758
804
  )
759
805
 
@@ -819,7 +865,7 @@ def randint_masked(key, mask):
819
865
  u = random.randint(key, (), 0, ecdf[-1])
820
866
  return jnp.searchsorted(ecdf, u, 'right')
821
867
 
822
- def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, key):
868
+ def accept_moves_and_sample_leaves(bart, moves, key):
823
869
  """
824
870
  Accept or reject the proposed moves and sample the new leaf values.
825
871
 
@@ -827,12 +873,8 @@ def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, key):
827
873
  ----------
828
874
  bart : dict
829
875
  A BART mcmc state.
830
- grow_moves : dict
831
- The proposals for grow moves, batched over the first axis. See
832
- `grow_move`.
833
- prune_moves : dict
834
- The proposals for prune moves, batched over the first axis. See
835
- `prune_move`.
876
+ moves : dict
877
+ The proposed moves, see `sample_moves`.
836
878
  key : jax.dtypes.prng_key array
837
879
  A jax random key.
838
880
 
@@ -841,11 +883,11 @@ def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, key):
841
883
  bart : dict
842
884
  The new BART mcmc state.
843
885
  """
844
- bart, grow_moves, prune_moves, count_trees, move_counts, u, z = accept_moves_parallel_stage(bart, grow_moves, prune_moves, key)
845
- bart, counts = accept_moves_sequential_stage(bart, count_trees, grow_moves, prune_moves, move_counts, u, z)
846
- return accept_moves_final_stage(bart, counts, grow_moves, prune_moves)
886
+ bart, moves, count_trees, move_counts, prelkv, prelk, prelf = accept_moves_parallel_stage(bart, moves, key)
887
+ bart, moves = accept_moves_sequential_stage(bart, count_trees, moves, move_counts, prelkv, prelk, prelf)
888
+ return accept_moves_final_stage(bart, moves)
847
889
 
848
- def accept_moves_parallel_stage(bart, grow_moves, prune_moves, key):
890
+ def accept_moves_parallel_stage(bart, moves, key):
849
891
  """
850
892
  Pre-computes quantities used to accept moves, in parallel across trees.
851
893
 
@@ -853,9 +895,8 @@ def accept_moves_parallel_stage(bart, grow_moves, prune_moves, key):
853
895
  ----------
854
896
  bart : dict
855
897
  A BART mcmc state.
856
- grow_moves, prune_moves : dict
857
- The proposals for the moves, batched over the first axis. See
858
- `grow_move` and `prune_move`.
898
+ moves : dict
899
+ The proposed moves, see `sample_moves`.
859
900
  key : jax.dtypes.prng_key array
860
901
  A jax random key.
861
902
 
@@ -863,51 +904,50 @@ def accept_moves_parallel_stage(bart, grow_moves, prune_moves, key):
863
904
  -------
864
905
  bart : dict
865
906
  A partially updated BART mcmc state.
866
- grow_moves, prune_moves : dict
867
- The proposals for the moves, with the field 'partial_ratio' replaced
868
- by 'trans_prior_ratio'.
869
- count_trees : array (num_trees, 2 ** (d - 1))
907
+ moves : dict
908
+ The proposed moves, with the field 'partial_ratio' replaced
909
+ by 'log_trans_prior_ratio'.
910
+ count_trees : array (num_trees, 2 ** d)
870
911
  The number of points in each potential or actual leaf node.
871
912
  move_counts : dict
872
913
  The counts of the number of points in the the nodes modified by the
873
914
  moves.
874
- u : float array (num_trees, 2)
875
- Random uniform values used to accept the moves.
876
- z : float array (num_trees, 2 ** d)
877
- Random standard normal values used to sample the new leaf values.
915
+ prelkv, prelk, prelf : dict
916
+ Dictionary with pre-computed terms of the likelihood ratios and leaf
917
+ samples.
878
918
  """
879
919
  bart = bart.copy()
880
920
 
881
- bart['var_trees'] = grow_moves['var_tree']
882
- # Since var_tree can contain garbage, I can set the var of leaf to be
883
- # grown irrespectively of what move I'm gonna accept in the end.
884
-
885
- bart['leaf_indices'] = apply_grow_to_indices(grow_moves, bart['leaf_indices'], bart['X'])
886
-
887
- count_trees, move_counts = compute_count_trees(bart['leaf_indices'], grow_moves, prune_moves, bart['opt']['count_batch_size'])
888
-
889
- grow_moves, prune_moves = complete_ratio(grow_moves, prune_moves, move_counts, bart['min_points_per_leaf'])
921
+ # where the move is grow, modify the state like the move was accepted
922
+ bart['var_trees'] = moves['var_trees']
923
+ bart['leaf_indices'] = apply_grow_to_indices(moves, bart['leaf_indices'], bart['X'])
924
+ bart['leaf_trees'] = adapt_leaf_trees_to_grow_indices(bart['leaf_trees'], moves)
890
925
 
926
+ # count number of datapoints per leaf
927
+ count_trees, move_counts = compute_count_trees(bart['leaf_indices'], moves, bart['opt']['count_batch_size'])
891
928
  if bart['opt']['require_min_points']:
892
- count_half_trees = count_trees[:, :grow_moves['split_tree'].shape[1]]
929
+ count_half_trees = count_trees[:, :bart['var_trees'].shape[1]]
893
930
  bart['affluence_trees'] = count_half_trees >= 2 * bart['min_points_per_leaf']
894
931
 
895
- bart['leaf_trees'] = adapt_leaf_trees_to_grow_indices(bart['leaf_trees'], grow_moves)
932
+ # compute some missing information about moves
933
+ moves = complete_ratio(moves, move_counts, bart['min_points_per_leaf'])
934
+ bart['grow_prop_count'] = jnp.sum(moves['grow'])
935
+ bart['prune_prop_count'] = jnp.sum(moves['allowed'] & ~moves['grow'])
896
936
 
897
- key, subkey = random.split(key)
898
- u = random.uniform(subkey, (len(bart['leaf_trees']), 2), bart['opt']['large_float'])
899
- z = random.normal(key, bart['leaf_trees'].shape, bart['opt']['large_float'])
937
+ prelkv, prelk = precompute_likelihood_terms(count_trees, bart['sigma2'], move_counts)
938
+ prelf = precompute_leaf_terms(count_trees, bart['sigma2'], key)
900
939
 
901
- return bart, grow_moves, prune_moves, count_trees, move_counts, u, z
940
+ return bart, moves, count_trees, move_counts, prelkv, prelk, prelf
902
941
 
903
- def apply_grow_to_indices(grow_moves, leaf_indices, X):
942
+ @functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, None))
943
+ def apply_grow_to_indices(moves, leaf_indices, X):
904
944
  """
905
945
  Update the leaf indices to apply a grow move.
906
946
 
907
947
  Parameters
908
948
  ----------
909
- grow_moves : dict
910
- The proposals for grow moves. See `grow_move`.
949
+ moves : dict
950
+ The proposed moves, see `sample_moves`.
911
951
  leaf_indices : array (num_trees, n)
912
952
  The index of the leaf each datapoint falls into.
913
953
  X : array (p, n)
@@ -918,17 +958,17 @@ def apply_grow_to_indices(grow_moves, leaf_indices, X):
918
958
  grow_leaf_indices : array (num_trees, n)
919
959
  The updated leaf indices.
920
960
  """
921
- left_child = grow_moves['node'].astype(leaf_indices.dtype) << 1
922
- go_right = X[grow_moves['var'], :] >= grow_moves['split'][:, None]
923
- tree_size = jnp.array(2 * grow_moves['split_tree'].shape[1])
924
- node_to_update = jnp.where(grow_moves['num_growable'], grow_moves['node'], tree_size)
961
+ left_child = moves['node'].astype(leaf_indices.dtype) << 1
962
+ go_right = X[moves['grow_var'], :] >= moves['grow_split']
963
+ tree_size = jnp.array(2 * moves['var_trees'].size)
964
+ node_to_update = jnp.where(moves['grow'], moves['node'], tree_size)
925
965
  return jnp.where(
926
- leaf_indices == node_to_update[:, None],
927
- left_child[:, None] + go_right,
966
+ leaf_indices == node_to_update,
967
+ left_child + go_right,
928
968
  leaf_indices,
929
969
  )
930
970
 
931
- def compute_count_trees(grow_leaf_indices, grow_moves, prune_moves, batch_size):
971
+ def compute_count_trees(leaf_indices, moves, batch_size):
932
972
  """
933
973
  Count the number of datapoints in each leaf.
934
974
 
@@ -937,8 +977,8 @@ def compute_count_trees(grow_leaf_indices, grow_moves, prune_moves, batch_size):
937
977
  grow_leaf_indices : int array (num_trees, n)
938
978
  The index of the leaf each datapoint falls into, if the grow move is
939
979
  accepted.
940
- grow_moves, prune_moves : dict
941
- The proposals for the moves. See `grow_move` and `prune_move`.
980
+ moves : dict
981
+ The proposed moves, see `sample_moves`.
942
982
  batch_size : int or None
943
983
  The data batch size to use for the summation.
944
984
 
@@ -952,24 +992,20 @@ def compute_count_trees(grow_leaf_indices, grow_moves, prune_moves, batch_size):
952
992
  'left', 'right', and 'total'.
953
993
  """
954
994
 
955
- ntree, tree_size = grow_moves['split_tree'].shape
995
+ ntree, tree_size = moves['var_trees'].shape
956
996
  tree_size *= 2
957
- counts = dict(grow=dict(), prune=dict())
958
997
  tree_indices = jnp.arange(ntree)
959
998
 
960
- count_trees = count_datapoints_per_leaf(grow_leaf_indices, tree_size, batch_size)
999
+ count_trees = count_datapoints_per_leaf(leaf_indices, tree_size, batch_size)
961
1000
 
962
- # count datapoints in leaf to grow
963
- counts['grow']['left'] = count_trees[tree_indices, grow_moves['left']]
964
- counts['grow']['right'] = count_trees[tree_indices, grow_moves['right']]
965
- counts['grow']['total'] = counts['grow']['left'] + counts['grow']['right']
966
- count_trees = count_trees.at[tree_indices, grow_moves['node']].set(counts['grow']['total'])
1001
+ # count datapoints in nodes modified by move
1002
+ counts = dict()
1003
+ counts['left'] = count_trees[tree_indices, moves['left']]
1004
+ counts['right'] = count_trees[tree_indices, moves['right']]
1005
+ counts['total'] = counts['left'] + counts['right']
967
1006
 
968
- # count datapoints in node to prune
969
- counts['prune']['left'] = count_trees[tree_indices, prune_moves['left']]
970
- counts['prune']['right'] = count_trees[tree_indices, prune_moves['right']]
971
- counts['prune']['total'] = counts['prune']['left'] + counts['prune']['right']
972
- count_trees = count_trees.at[tree_indices, prune_moves['node']].set(counts['prune']['total'])
1007
+ # write count into non-leaf node
1008
+ count_trees = count_trees.at[tree_indices, moves['node']].set(counts['total'])
973
1009
 
974
1010
  return count_trees, counts
975
1011
 
@@ -1025,7 +1061,7 @@ def _aggregate_batched_alltrees(values, indices, size, dtype, batch_size):
1025
1061
  .sum(axis=2)
1026
1062
  )
1027
1063
 
1028
- def complete_ratio(grow_moves, prune_moves, move_counts, min_points_per_leaf):
1064
+ def complete_ratio(moves, move_counts, min_points_per_leaf):
1029
1065
  """
1030
1066
  Complete non-likelihood MH ratio calculation.
1031
1067
 
@@ -1033,8 +1069,8 @@ def complete_ratio(grow_moves, prune_moves, move_counts, min_points_per_leaf):
1033
1069
 
1034
1070
  Parameters
1035
1071
  ----------
1036
- grow_moves, prune_moves : dict
1037
- The proposals for the moves. See `grow_move` and `prune_move`.
1072
+ moves : dict
1073
+ The proposed moves, see `sample_moves`.
1038
1074
  move_counts : dict
1039
1075
  The counts of the number of points in the the nodes modified by the
1040
1076
  moves.
@@ -1043,61 +1079,62 @@ def complete_ratio(grow_moves, prune_moves, move_counts, min_points_per_leaf):
1043
1079
 
1044
1080
  Returns
1045
1081
  -------
1046
- grow_moves, prune_moves : dict
1047
- The proposals for the moves, with the field 'partial_ratio' replaced
1048
- by 'trans_prior_ratio'.
1082
+ moves : dict
1083
+ The updated moves, with the field 'partial_ratio' replaced by
1084
+ 'log_trans_prior_ratio'.
1049
1085
  """
1050
- grow_moves = grow_moves.copy()
1051
- prune_moves = prune_moves.copy()
1052
- compute_p_prune_vec = jax.vmap(compute_p_prune, in_axes=(0, 0, 0, None))
1053
- grow_p_prune, prune_p_prune = compute_p_prune_vec(grow_moves, move_counts['grow']['left'], move_counts['grow']['right'], min_points_per_leaf)
1054
- grow_moves['trans_prior_ratio'] = grow_moves.pop('partial_ratio') * grow_p_prune
1055
- prune_moves['trans_prior_ratio'] = prune_moves.pop('partial_ratio') * prune_p_prune
1056
- return grow_moves, prune_moves
1086
+ moves = moves.copy()
1087
+ p_prune = compute_p_prune(moves, move_counts['left'], move_counts['right'], min_points_per_leaf)
1088
+ moves['log_trans_prior_ratio'] = jnp.log(moves.pop('partial_ratio') * p_prune)
1089
+ return moves
1057
1090
 
1058
- def compute_p_prune(grow_move, grow_left_count, grow_right_count, min_points_per_leaf):
1091
+ def compute_p_prune(moves, left_count, right_count, min_points_per_leaf):
1059
1092
  """
1060
1093
  Compute the probability of proposing a prune move.
1061
1094
 
1062
1095
  Parameters
1063
1096
  ----------
1064
- grow_move : dict
1065
- The proposal for the grow move, see `grow_move`.
1066
- grow_left_count, grow_right_count : int
1097
+ moves : dict
1098
+ The proposed moves, see `sample_moves`.
1099
+ left_count, right_count : int
1067
1100
  The number of datapoints in the proposed children of the leaf to grow.
1068
1101
  min_points_per_leaf : int or None
1069
1102
  The minimum number of data points in a leaf node.
1070
1103
 
1071
1104
  Returns
1072
1105
  -------
1073
- grow_p_prune : float
1074
- The probability of proposing a prune move, after accepting the grow
1075
- move.
1076
- prune_p_prune : float
1077
- The probability of proposing the prune move.
1078
- """
1079
- other_growable_leaves = grow_move['num_growable'] >= 2
1080
- new_leaves_growable = grow_move['node'] < grow_move['split_tree'].size // 2
1106
+ p_prune : float
1107
+ The probability of proposing a prune move. If grow: after accepting the
1108
+ grow move, if prune: right away.
1109
+ """
1110
+
1111
+ # calculation in case the move is grow
1112
+ other_growable_leaves = moves['num_growable'] >= 2
1113
+ new_leaves_growable = moves['node'] < moves['var_trees'].shape[1] // 2
1081
1114
  if min_points_per_leaf is not None:
1082
- any_above_threshold = grow_left_count >= 2 * min_points_per_leaf
1083
- any_above_threshold |= grow_right_count >= 2 * min_points_per_leaf
1115
+ any_above_threshold = left_count >= 2 * min_points_per_leaf
1116
+ any_above_threshold |= right_count >= 2 * min_points_per_leaf
1084
1117
  new_leaves_growable &= any_above_threshold
1085
1118
  grow_again_allowed = other_growable_leaves | new_leaves_growable
1086
1119
  grow_p_prune = jnp.where(grow_again_allowed, 0.5, 1)
1087
- prune_p_prune = jnp.where(grow_move['num_growable'], 0.5, 1)
1088
- return grow_p_prune, prune_p_prune
1089
1120
 
1090
- def adapt_leaf_trees_to_grow_indices(leaf_trees, grow_moves):
1121
+ # calculation in case the move is prune
1122
+ prune_p_prune = jnp.where(moves['num_growable'], 0.5, 1)
1123
+
1124
+ return jnp.where(moves['grow'], grow_p_prune, prune_p_prune)
1125
+
1126
+ @jaxext.vmap_nodoc
1127
+ def adapt_leaf_trees_to_grow_indices(leaf_trees, moves):
1091
1128
  """
1092
- Modify leaf values such that the indices of the grow move work on the
1129
+ Modify leaf values such that the indices of the grow moves work on the
1093
1130
  original tree.
1094
1131
 
1095
1132
  Parameters
1096
1133
  ----------
1097
1134
  leaf_trees : float array (num_trees, 2 ** d)
1098
1135
  The leaf values.
1099
- grow_moves : dict
1100
- The proposals for grow moves. See `grow_move`.
1136
+ moves : dict
1137
+ The proposed moves, see `sample_moves`.
1101
1138
 
1102
1139
  Returns
1103
1140
  -------
@@ -1105,17 +1142,86 @@ def adapt_leaf_trees_to_grow_indices(leaf_trees, grow_moves):
1105
1142
  The modified leaf values. The value of the leaf to grow is copied to
1106
1143
  what would be its children if the grow move was accepted.
1107
1144
  """
1108
- ntree, _ = leaf_trees.shape
1109
- tree_indices = jnp.arange(ntree)
1110
- values_at_node = leaf_trees[tree_indices, grow_moves['node']]
1145
+ values_at_node = leaf_trees[moves['node']]
1111
1146
  return (leaf_trees
1112
- .at[tree_indices, grow_moves['left']]
1147
+ .at[jnp.where(moves['grow'], moves['left'], leaf_trees.size)]
1113
1148
  .set(values_at_node)
1114
- .at[tree_indices, grow_moves['right']]
1149
+ .at[jnp.where(moves['grow'], moves['right'], leaf_trees.size)]
1115
1150
  .set(values_at_node)
1116
1151
  )
1117
1152
 
1118
- def accept_moves_sequential_stage(bart, count_trees, grow_moves, prune_moves, move_counts, u, z):
1153
+ def precompute_likelihood_terms(count_trees, sigma2, move_counts):
1154
+ """
1155
+ Pre-compute terms used in the likelihood ratio of the acceptance step.
1156
+
1157
+ Parameters
1158
+ ----------
1159
+ count_trees : array (num_trees, 2 ** d)
1160
+ The number of points in each potential or actual leaf node.
1161
+ sigma2 : float
1162
+ The noise variance.
1163
+ move_counts : dict
1164
+ The counts of the number of points in the the nodes modified by the
1165
+ moves.
1166
+
1167
+ Returns
1168
+ -------
1169
+ prelkv : dict
1170
+ Dictionary with pre-computed terms of the likelihood ratio, one per
1171
+ tree.
1172
+ prelk : dict
1173
+ Dictionary with pre-computed terms of the likelihood ratio, shared by
1174
+ all trees.
1175
+ """
1176
+ ntree = len(count_trees)
1177
+ sigma_mu2 = 1 / ntree
1178
+ prelkv = dict()
1179
+ prelkv['sigma2_left'] = sigma2 + move_counts['left'] * sigma_mu2
1180
+ prelkv['sigma2_right'] = sigma2 + move_counts['right'] * sigma_mu2
1181
+ prelkv['sigma2_total'] = sigma2 + move_counts['total'] * sigma_mu2
1182
+ prelkv['sqrt_term'] = jnp.log(
1183
+ sigma2 * prelkv['sigma2_total'] /
1184
+ (prelkv['sigma2_left'] * prelkv['sigma2_right'])
1185
+ ) / 2
1186
+ return prelkv, dict(
1187
+ exp_factor=sigma_mu2 / (2 * sigma2),
1188
+ )
1189
+
1190
+ def precompute_leaf_terms(count_trees, sigma2, key):
1191
+ """
1192
+ Pre-compute terms used to sample leaves from their posterior.
1193
+
1194
+ Parameters
1195
+ ----------
1196
+ count_trees : array (num_trees, 2 ** d)
1197
+ The number of points in each potential or actual leaf node.
1198
+ sigma2 : float
1199
+ The noise variance.
1200
+ key : jax.dtypes.prng_key array
1201
+ A jax random key.
1202
+
1203
+ Returns
1204
+ -------
1205
+ prelf : dict
1206
+ Dictionary with pre-computed terms of the leaf sampling, with fields:
1207
+
1208
+ 'mean_factor' : float array (num_trees, 2 ** d)
1209
+ The factor to be multiplied by the sum of residuals to obtain the
1210
+ posterior mean.
1211
+ 'centered_leaves' : float array (num_trees, 2 ** d)
1212
+ The mean-zero normal values to be added to the posterior mean to
1213
+ obtain the posterior leaf samples.
1214
+ """
1215
+ ntree = len(count_trees)
1216
+ prec_lk = count_trees / sigma2
1217
+ var_post = lax.reciprocal(prec_lk + ntree) # = 1 / (prec_lk + prec_prior)
1218
+ z = random.normal(key, count_trees.shape, sigma2.dtype)
1219
+ return dict(
1220
+ mean_factor=var_post / sigma2, # = mean_lk * prec_lk * var_post / resid_tree
1221
+ centered_leaves=z * jnp.sqrt(var_post),
1222
+ )
1223
+
1224
+ def accept_moves_sequential_stage(bart, count_trees, moves, move_counts, prelkv, prelk, prelf):
1119
1225
  """
1120
1226
  The part of accepting the moves that has to be done one tree at a time.
1121
1227
 
@@ -1123,57 +1229,63 @@ def accept_moves_sequential_stage(bart, count_trees, grow_moves, prune_moves, mo
1123
1229
  ----------
1124
1230
  bart : dict
1125
1231
  A partially updated BART mcmc state.
1126
- count_trees : array (num_trees, 2 ** (d - 1))
1232
+ count_trees : array (num_trees, 2 ** d)
1127
1233
  The number of points in each potential or actual leaf node.
1128
- grow_moves, prune_moves : dict
1129
- The proposals for the moves, with completed ratios. See `grow_move` and
1130
- `prune_move`.
1234
+ moves : dict
1235
+ The proposed moves, see `sample_moves`.
1131
1236
  move_counts : dict
1132
1237
  The counts of the number of points in the the nodes modified by the
1133
1238
  moves.
1134
- u : float array (num_trees, 2)
1135
- Random uniform values used to for proposal and accept decisions.
1136
- z : float array (num_trees, 2 ** d)
1137
- Random standard normal values used to sample the new leaf values.
1239
+ prelkv, prelk, prelf : dict
1240
+ Dictionaries with pre-computed terms of the likelihood ratios and leaf
1241
+ samples.
1138
1242
 
1139
1243
  Returns
1140
1244
  -------
1141
1245
  bart : dict
1142
1246
  A partially updated BART mcmc state.
1143
- counts : dict
1144
- The indicators of proposals and acceptances for grow and prune moves.
1247
+ moves : dict
1248
+ The proposed moves, with these additional fields:
1249
+
1250
+ 'acc' : bool array (num_trees,)
1251
+ Whether the move was accepted.
1252
+ 'to_prune' : bool array (num_trees,)
1253
+ Whether, to reflect the acceptance status of the move, the state
1254
+ should be updated by pruning the leaves involved in the move.
1145
1255
  """
1146
1256
  bart = bart.copy()
1257
+ moves = moves.copy()
1147
1258
 
1148
1259
  def loop(resid, item):
1149
- resid, leaf_tree, split_tree, counts, ratios = accept_move_and_sample_leaves(
1260
+ resid, leaf_tree, acc, to_prune, ratios = accept_move_and_sample_leaves(
1150
1261
  bart['X'],
1151
1262
  len(bart['leaf_trees']),
1152
1263
  bart['opt']['resid_batch_size'],
1153
1264
  resid,
1154
- bart['sigma2'],
1155
1265
  bart['min_points_per_leaf'],
1156
1266
  'ratios' in bart,
1267
+ prelk,
1157
1268
  *item,
1158
1269
  )
1159
- return resid, (leaf_tree, split_tree, counts, ratios)
1270
+ return resid, (leaf_tree, acc, to_prune, ratios)
1160
1271
 
1161
1272
  items = (
1162
1273
  bart['leaf_trees'], count_trees,
1163
- grow_moves, prune_moves, move_counts,
1274
+ moves, move_counts,
1164
1275
  bart['leaf_indices'],
1165
- u, z,
1276
+ prelkv, prelf,
1166
1277
  )
1167
- resid, (leaf_trees, split_trees, counts, ratios) = lax.scan(loop, bart['resid'], items)
1278
+ resid, (leaf_trees, acc, to_prune, ratios) = lax.scan(loop, bart['resid'], items)
1168
1279
 
1169
1280
  bart['resid'] = resid
1170
1281
  bart['leaf_trees'] = leaf_trees
1171
- bart['split_trees'] = split_trees
1172
1282
  bart.get('ratios', {}).update(ratios)
1283
+ moves['acc'] = acc
1284
+ moves['to_prune'] = to_prune
1173
1285
 
1174
- return bart, counts
1286
+ return bart, moves
1175
1287
 
1176
- def accept_move_and_sample_leaves(X, ntree, resid_batch_size, resid, sigma2, min_points_per_leaf, save_ratios, leaf_tree, count_tree, grow_move, prune_move, move_counts, grow_leaf_indices, u, z):
1288
+ def accept_move_and_sample_leaves(X, ntree, resid_batch_size, resid, min_points_per_leaf, save_ratios, prelk, leaf_tree, count_tree, move, move_counts, leaf_indices, prelkv, prelf):
1177
1289
  """
1178
1290
  Accept or reject a proposed move and sample the new leaf values.
1179
1291
 
@@ -1187,25 +1299,25 @@ def accept_move_and_sample_leaves(X, ntree, resid_batch_size, resid, sigma2, min
1187
1299
  The batch size for computing the sum of residuals in each leaf.
1188
1300
  resid : float array (n,)
1189
1301
  The residuals (data minus forest value).
1190
- sigma2 : float
1191
- The noise variance.
1192
1302
  min_points_per_leaf : int or None
1193
1303
  The minimum number of data points in a leaf node.
1194
1304
  save_ratios : bool
1195
1305
  Whether to save the acceptance ratios.
1306
+ prelk : dict
1307
+ The pre-computed terms of the likelihood ratio which are shared across
1308
+ trees.
1196
1309
  leaf_tree : float array (2 ** d,)
1197
1310
  The leaf values of the tree.
1198
1311
  count_tree : int array (2 ** d,)
1199
1312
  The number of datapoints in each leaf.
1200
- grow_move, prune_move : dict
1201
- The proposals for the moves, with completed ratios. See `grow_move` and
1202
- `prune_move`.
1203
- grow_leaf_indices : int array (n,)
1204
- The leaf indices of the tree proposed by the grow move.
1205
- u : float array (2,)
1206
- Two uniform random values in [0, 1).
1207
- z : float array (2 ** d,)
1208
- Standard normal random values.
1313
+ move : dict
1314
+ The proposed move, see `sample_moves`.
1315
+ leaf_indices : int array (n,)
1316
+ The leaf indices for the largest version of the tree compatible with
1317
+ the move.
1318
+ prelkv, prelf : dict
1319
+ The pre-computed terms of the likelihood ratio and leaf sampling which
1320
+ are specific to the tree.
1209
1321
 
1210
1322
  Returns
1211
1323
  -------
@@ -1213,123 +1325,68 @@ def accept_move_and_sample_leaves(X, ntree, resid_batch_size, resid, sigma2, min
1213
1325
  The updated residuals (data minus forest value).
1214
1326
  leaf_tree : float array (2 ** d,)
1215
1327
  The new leaf values of the tree.
1216
- split_tree : int array (2 ** (d - 1),)
1217
- The updated decision boundaries of the tree.
1218
- counts : dict
1219
- The indicators of proposals and acceptances for grow and prune moves.
1328
+ acc : bool
1329
+ Whether the move was accepted.
1330
+ to_prune : bool
1331
+ Whether, to reflect the acceptance status of the move, the state should
1332
+ be updated by pruning the leaves involved in the move.
1220
1333
  ratios : dict
1221
1334
  The acceptance ratios for the moves. Empty if not to be saved.
1222
1335
  """
1223
1336
 
1224
1337
  # sum residuals and count units per leaf, in tree proposed by grow move
1225
- resid_tree = sum_resid(resid, grow_leaf_indices, leaf_tree.size, resid_batch_size)
1338
+ resid_tree = sum_resid(resid, leaf_indices, leaf_tree.size, resid_batch_size)
1226
1339
 
1227
1340
  # subtract starting tree from function
1228
1341
  resid_tree += count_tree * leaf_tree
1229
1342
 
1230
- # get indices of grow move
1231
- grow_node = grow_move['node']
1232
- assert grow_node.dtype == jnp.int32
1233
- grow_left = grow_move['left']
1234
- grow_right = grow_move['right']
1235
-
1236
- # sum residuals in leaf to grow
1237
- grow_resid_left = resid_tree[grow_left]
1238
- grow_resid_right = resid_tree[grow_right]
1239
- grow_resid_total = grow_resid_left + grow_resid_right
1240
- resid_tree = resid_tree.at[grow_node].set(grow_resid_total)
1241
-
1242
- # get indices of prune move
1243
- prune_node = prune_move['node']
1244
- assert prune_node.dtype == jnp.int32
1245
- prune_left = prune_move['left']
1246
- prune_right = prune_move['right']
1247
-
1248
- # sum residuals in node to prune
1249
- prune_resid_left = resid_tree[prune_left]
1250
- prune_resid_right = resid_tree[prune_right]
1251
- prune_resid_total = prune_resid_left + prune_resid_right
1252
- resid_tree = resid_tree.at[prune_node].set(prune_resid_total)
1253
-
1254
- # Now resid_tree and count_tree contain correct values whatever move is
1255
- # accepted.
1256
-
1257
- # compute likelihood ratios
1258
- grow_lk_ratio = compute_likelihood_ratio(grow_resid_total, grow_resid_left, grow_resid_right, move_counts['grow']['total'], move_counts['grow']['left'], move_counts['grow']['right'], sigma2, ntree)
1259
- prune_lk_ratio = compute_likelihood_ratio(prune_resid_total, prune_resid_left, prune_resid_right, move_counts['prune']['total'], move_counts['prune']['left'], move_counts['prune']['right'], sigma2, ntree)
1260
-
1261
- # compute acceptance ratios
1262
- grow_ratio = grow_move['trans_prior_ratio'] * grow_lk_ratio
1263
- if min_points_per_leaf is not None:
1264
- grow_ratio = jnp.where(move_counts['grow']['left'] >= min_points_per_leaf, grow_ratio, 0)
1265
- grow_ratio = jnp.where(move_counts['grow']['right'] >= min_points_per_leaf, grow_ratio, 0)
1266
- prune_ratio = prune_move['trans_prior_ratio'] * prune_lk_ratio
1267
- prune_ratio = lax.reciprocal(prune_ratio)
1268
-
1269
- # save acceptance ratios
1343
+ # get indices of move
1344
+ node = move['node']
1345
+ assert node.dtype == jnp.int32
1346
+ left = move['left']
1347
+ right = move['right']
1348
+
1349
+ # sum residuals in parent node modified by move
1350
+ resid_left = resid_tree[left]
1351
+ resid_right = resid_tree[right]
1352
+ resid_total = resid_left + resid_right
1353
+ resid_tree = resid_tree.at[node].set(resid_total)
1354
+
1355
+ # compute acceptance ratio
1356
+ log_lk_ratio = compute_likelihood_ratio(resid_total, resid_left, resid_right, prelkv, prelk)
1357
+ log_ratio = move['log_trans_prior_ratio'] + log_lk_ratio
1358
+ log_ratio = jnp.where(move['grow'], log_ratio, -log_ratio)
1270
1359
  ratios = {}
1271
1360
  if save_ratios:
1272
1361
  ratios.update(
1273
- grow=dict(
1274
- trans_prior=grow_move['trans_prior_ratio'],
1275
- likelihood=grow_lk_ratio,
1276
- ),
1277
- prune=dict(
1278
- trans_prior=lax.reciprocal(prune_move['trans_prior_ratio']),
1279
- likelihood=lax.reciprocal(prune_lk_ratio),
1280
- ),
1362
+ log_trans_prior=move['log_trans_prior_ratio'],
1363
+ log_likelihood=log_lk_ratio,
1281
1364
  )
1282
1365
 
1283
- # determine what move to propose (not proposing anything is an option)
1284
- grow_allowed = grow_move['num_growable'].astype(bool)
1285
- p_grow = jnp.where(grow_allowed & prune_move['allowed'], 0.5, grow_allowed)
1286
- try_grow = u[0] < p_grow # use < instead of <= because coins are in [0, 1)
1287
- try_prune = prune_move['allowed'] & ~try_grow
1288
-
1289
1366
  # determine whether to accept the move
1290
- do_grow = try_grow & (u[1] < grow_ratio)
1291
- do_prune = try_prune & (u[1] < prune_ratio)
1292
-
1293
- # pick split tree for chosen move
1294
- split_tree = grow_move['split_tree']
1295
- split_tree = split_tree.at[jnp.where(do_grow, split_tree.size, grow_node)].set(0)
1296
- split_tree = split_tree.at[jnp.where(do_prune, prune_node, split_tree.size)].set(0)
1297
- # I can leave garbage in var_tree, resid_tree, count_tree
1367
+ acc = move['allowed'] & (move['logu'] <= log_ratio)
1368
+ if min_points_per_leaf is not None:
1369
+ acc &= move_counts['left'] >= min_points_per_leaf
1370
+ acc &= move_counts['right'] >= min_points_per_leaf
1298
1371
 
1299
1372
  # compute leaves posterior and sample leaves
1300
- inv_sigma2 = lax.reciprocal(sigma2)
1301
- prec_lk = count_tree * inv_sigma2
1302
- var_post = lax.reciprocal(prec_lk + ntree) # = 1 / (prec_lk + prec_prior)
1303
- mean_post = resid_tree * inv_sigma2 * var_post # = mean_lk * prec_lk * var_post
1304
1373
  initial_leaf_tree = leaf_tree
1305
- leaf_tree = mean_post + z * jnp.sqrt(var_post)
1374
+ mean_post = resid_tree * prelf['mean_factor']
1375
+ leaf_tree = mean_post + prelf['centered_leaves']
1306
1376
 
1307
- # copy leaves around such that the grow leaf indices select the right leaf
1308
- leaf_tree = (leaf_tree
1309
- .at[jnp.where(do_prune, prune_left, leaf_tree.size)]
1310
- .set(leaf_tree[prune_node])
1311
- .at[jnp.where(do_prune, prune_right, leaf_tree.size)]
1312
- .set(leaf_tree[prune_node])
1313
- )
1377
+ # copy leaves around such that the leaf indices select the right leaf
1378
+ to_prune = acc ^ move['grow']
1314
1379
  leaf_tree = (leaf_tree
1315
- .at[jnp.where(do_grow, leaf_tree.size, grow_left)]
1316
- .set(leaf_tree[grow_node])
1317
- .at[jnp.where(do_grow, leaf_tree.size, grow_right)]
1318
- .set(leaf_tree[grow_node])
1380
+ .at[jnp.where(to_prune, left, leaf_tree.size)]
1381
+ .set(leaf_tree[node])
1382
+ .at[jnp.where(to_prune, right, leaf_tree.size)]
1383
+ .set(leaf_tree[node])
1319
1384
  )
1320
1385
 
1321
1386
  # replace old tree with new tree in function values
1322
- resid += (initial_leaf_tree - leaf_tree)[grow_leaf_indices]
1323
-
1324
- # pack proposal and acceptance indicators
1325
- counts = dict(
1326
- grow_prop_count=try_grow,
1327
- grow_acc_count=do_grow,
1328
- prune_prop_count=try_prune,
1329
- prune_acc_count=do_prune,
1330
- )
1387
+ resid += (initial_leaf_tree - leaf_tree)[leaf_indices]
1331
1388
 
1332
- return resid, leaf_tree, split_tree, counts, ratios
1389
+ return resid, leaf_tree, acc, to_prune, ratios
1333
1390
 
1334
1391
  def sum_resid(resid, leaf_indices, tree_size, batch_size):
1335
1392
  """
@@ -1369,7 +1426,7 @@ def _aggregate_batched_onetree(values, indices, size, dtype, batch_size):
1369
1426
  .sum(axis=1)
1370
1427
  )
1371
1428
 
1372
- def compute_likelihood_ratio(total_resid, left_resid, right_resid, total_count, left_count, right_count, sigma2, n_tree):
1429
+ def compute_likelihood_ratio(total_resid, left_resid, right_resid, prelkv, prelk):
1373
1430
  """
1374
1431
  Compute the likelihood ratio of a grow move.
1375
1432
 
@@ -1379,37 +1436,23 @@ def compute_likelihood_ratio(total_resid, left_resid, right_resid, total_count,
1379
1436
  The sum of the residuals in the leaf to grow.
1380
1437
  left_resid, right_resid : float
1381
1438
  The sum of the residuals in the left/right child of the leaf to grow.
1382
- total_count : int
1383
- The number of datapoints in the leaf to grow.
1384
- left_count, right_count : int
1385
- The number of datapoints in the left/right child of the leaf to grow.
1386
- sigma2 : float
1387
- The noise variance.
1388
- n_tree : int
1389
- The number of trees in the forest.
1439
+ prelkv, prelk : dict
1440
+ The pre-computed terms of the likelihood ratio, see
1441
+ `precompute_likelihood_terms`.
1390
1442
 
1391
1443
  Returns
1392
1444
  -------
1393
1445
  ratio : float
1394
1446
  The likelihood ratio P(data | new tree) / P(data | old tree).
1395
1447
  """
1396
-
1397
- sigma_mu2 = 1 / n_tree
1398
- sigma2_left = sigma2 + left_count * sigma_mu2
1399
- sigma2_right = sigma2 + right_count * sigma_mu2
1400
- sigma2_total = sigma2 + total_count * sigma_mu2
1401
-
1402
- sqrt_term = sigma2 * sigma2_total / (sigma2_left * sigma2_right)
1403
-
1404
- exp_term = sigma_mu2 / (2 * sigma2) * (
1405
- left_resid * left_resid / sigma2_left +
1406
- right_resid * right_resid / sigma2_right -
1407
- total_resid * total_resid / sigma2_total
1448
+ exp_term = prelk['exp_factor'] * (
1449
+ left_resid * left_resid / prelkv['sigma2_left'] +
1450
+ right_resid * right_resid / prelkv['sigma2_right'] -
1451
+ total_resid * total_resid / prelkv['sigma2_total']
1408
1452
  )
1453
+ return prelkv['sqrt_term'] + exp_term
1409
1454
 
1410
- return jnp.sqrt(sqrt_term) * jnp.exp(exp_term)
1411
-
1412
- def accept_moves_final_stage(bart, counts, grow_moves, prune_moves):
1455
+ def accept_moves_final_stage(bart, moves):
1413
1456
  """
1414
1457
  The final part of accepting the moves, in parallel across trees.
1415
1458
 
@@ -1419,8 +1462,9 @@ def accept_moves_final_stage(bart, counts, grow_moves, prune_moves):
1419
1462
  A partially updated BART mcmc state.
1420
1463
  counts : dict
1421
1464
  The indicators of proposals and acceptances for grow and prune moves.
1422
- grow_moves, prune_moves : dict
1423
- The proposals for the moves. See `grow_move` and `prune_move`.
1465
+ moves : dict
1466
+ The proposed moves (see `sample_moves`) as updated by
1467
+ `accept_moves_sequential_stage`.
1424
1468
 
1425
1469
  Returns
1426
1470
  -------
@@ -1428,15 +1472,14 @@ def accept_moves_final_stage(bart, counts, grow_moves, prune_moves):
1428
1472
  The fully updated BART mcmc state.
1429
1473
  """
1430
1474
  bart = bart.copy()
1431
-
1432
- for k, v in counts.items():
1433
- bart[k] = jnp.sum(v, axis=0)
1434
-
1435
- bart['leaf_indices'] = apply_moves_to_indices(bart['leaf_indices'], counts, grow_moves, prune_moves)
1436
-
1475
+ bart['grow_acc_count'] = jnp.sum(moves['acc'] & moves['grow'])
1476
+ bart['prune_acc_count'] = jnp.sum(moves['acc'] & ~moves['grow'])
1477
+ bart['leaf_indices'] = apply_moves_to_leaf_indices(bart['leaf_indices'], moves)
1478
+ bart['split_trees'] = apply_moves_to_split_trees(bart['split_trees'], moves)
1437
1479
  return bart
1438
1480
 
1439
- def apply_moves_to_indices(leaf_indices, counts, grow_moves, prune_moves):
1481
+ @jax.vmap
1482
+ def apply_moves_to_leaf_indices(leaf_indices, moves):
1440
1483
  """
1441
1484
  Update the leaf indices to match the accepted move.
1442
1485
 
@@ -1445,10 +1488,9 @@ def apply_moves_to_indices(leaf_indices, counts, grow_moves, prune_moves):
1445
1488
  leaf_indices : int array (num_trees, n)
1446
1489
  The index of the leaf each datapoint falls into, if the grow move was
1447
1490
  accepted.
1448
- counts : dict
1449
- The indicators of proposals and acceptances for grow and prune moves.
1450
- grow_moves, prune_moves : dict
1451
- The proposals for the moves. See `grow_move` and `prune_move`.
1491
+ moves : dict
1492
+ The proposed moves (see `sample_moves`), as updated by
1493
+ `accept_moves_sequential_stage`.
1452
1494
 
1453
1495
  Returns
1454
1496
  -------
@@ -1456,19 +1498,46 @@ def apply_moves_to_indices(leaf_indices, counts, grow_moves, prune_moves):
1456
1498
  The updated leaf indices.
1457
1499
  """
1458
1500
  mask = ~jnp.array(1, leaf_indices.dtype) # ...1111111110
1459
- cond = (leaf_indices & mask) == grow_moves['left'][:, None]
1460
- leaf_indices = jnp.where(
1461
- cond & ~counts['grow_acc_count'][:, None],
1462
- grow_moves['node'][:, None].astype(leaf_indices.dtype),
1463
- leaf_indices,
1464
- )
1465
- cond = (leaf_indices & mask) == prune_moves['left'][:, None]
1501
+ is_child = (leaf_indices & mask) == moves['left']
1466
1502
  return jnp.where(
1467
- cond & counts['prune_acc_count'][:, None],
1468
- prune_moves['node'][:, None].astype(leaf_indices.dtype),
1503
+ is_child & moves['to_prune'],
1504
+ moves['node'].astype(leaf_indices.dtype),
1469
1505
  leaf_indices,
1470
1506
  )
1471
1507
 
1508
+ @jax.vmap
1509
+ def apply_moves_to_split_trees(split_trees, moves):
1510
+ """
1511
+ Update the split trees to match the accepted move.
1512
+
1513
+ Parameters
1514
+ ----------
1515
+ split_trees : int array (num_trees, 2 ** (d - 1))
1516
+ The cutpoints of the decision nodes in the initial trees.
1517
+ moves : dict
1518
+ The proposed moves (see `sample_moves`), as updated by
1519
+ `accept_moves_sequential_stage`.
1520
+
1521
+ Returns
1522
+ -------
1523
+ split_trees : int array (num_trees, 2 ** (d - 1))
1524
+ The updated split trees.
1525
+ """
1526
+ return (split_trees
1527
+ .at[jnp.where(
1528
+ moves['grow'],
1529
+ moves['node'],
1530
+ split_trees.size,
1531
+ )]
1532
+ .set(moves['grow_split'].astype(split_trees.dtype))
1533
+ .at[jnp.where(
1534
+ moves['to_prune'],
1535
+ moves['node'],
1536
+ split_trees.size,
1537
+ )]
1538
+ .set(0)
1539
+ )
1540
+
1472
1541
  def sample_sigma(bart, key):
1473
1542
  """
1474
1543
  Noise variance sampling step of BART MCMC.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: bartz
3
- Version: 0.3.0
3
+ Version: 0.4.0
4
4
  Summary: A JAX implementation of BART
5
5
  Home-page: https://github.com/Gattocrucco/bartz
6
6
  License: MIT
@@ -1,13 +1,13 @@
1
1
  bartz/BART.py,sha256=CbGzFWtYw5u38Z9-Hy3CbDXpKOOvPFAAkSqu2HZl8no,16862
2
2
  bartz/__init__.py,sha256=E96vsP0bZ8brejpZmEmRoXuMsUdinO_B_SKUUl1rLsg,1448
3
- bartz/_version.py,sha256=3wVEs2QD_7OcTlD97cZdCeizd2hUbJJ0GeIO8wQIjrk,22
3
+ bartz/_version.py,sha256=2eiWQI55fd-roDdkt4Hvl9WzrTJ4xQo33VzFud6D03U,22
4
4
  bartz/debug.py,sha256=9ZH-JfwZVu5OPhHBEyXQHAU5H9KIu1vxLK7yNv4m4Ew,5314
5
5
  bartz/grove.py,sha256=x_6NK_l-hrXfy1PhssYNJkX41-w_WqjDziww0E7YRS8,8500
6
6
  bartz/jaxext.py,sha256=RcVWTCGS8lXF7GBsNbKrpuA4MTcokItq0CpWm3s7CGk,12033
7
7
  bartz/mcmcloop.py,sha256=lKDszvniNXka99X3e9RCrTgvEAZHA7ZbVXEgxUYvKMY,7634
8
- bartz/mcmcstep.py,sha256=HPcxfl5f-OESZul-iurn0JmOnUJBe6IYTVaATeR6YBA,54221
8
+ bartz/mcmcstep.py,sha256=diI9vHXHMvu_Lk_bSJ-a038OnEbXDpNEikVPhRcxEys,54987
9
9
  bartz/prepcovars.py,sha256=mMgfL-LGJ_8QpOL6iy7yfkL8A7FrT7Zfn5M3voyNwSQ,5818
10
- bartz-0.3.0.dist-info/LICENSE,sha256=heuIJZQK9IexJYC-fYHoLUrgj8HG8yS3G072EvKh-94,1073
11
- bartz-0.3.0.dist-info/METADATA,sha256=ymZNoowDdqQFyAJdeKKj6t7h8_eBXQr2cVPglyoYLDQ,4500
12
- bartz-0.3.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
13
- bartz-0.3.0.dist-info/RECORD,,
10
+ bartz-0.4.0.dist-info/LICENSE,sha256=heuIJZQK9IexJYC-fYHoLUrgj8HG8yS3G072EvKh-94,1073
11
+ bartz-0.4.0.dist-info/METADATA,sha256=K86CVXT6ayPnc2hjhreYGMEeYWfYJIZdDkKuBB0-FYA,4500
12
+ bartz-0.4.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
13
+ bartz-0.4.0.dist-info/RECORD,,
File without changes
File without changes