bartz 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/mcmcstep.py CHANGED
@@ -10,10 +10,10 @@
10
10
  # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
11
  # copies of the Software, and to permit persons to whom the Software is
12
12
  # furnished to do so, subject to the following conditions:
13
- #
13
+ #
14
14
  # The above copyright notice and this permission notice shall be included in all
15
15
  # copies or substantial portions of the Software.
16
- #
16
+ #
17
17
  # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
18
  # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
19
  # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
@@ -34,6 +34,7 @@ range of possible values.
34
34
  """
35
35
 
36
36
  import functools
37
+ import math
37
38
 
38
39
  import jax
39
40
  from jax import random
@@ -54,7 +55,9 @@ def init(*,
54
55
  small_float=jnp.float32,
55
56
  large_float=jnp.float32,
56
57
  min_points_per_leaf=None,
57
- suffstat_batch_size='auto',
58
+ resid_batch_size='auto',
59
+ count_batch_size='auto',
60
+ save_ratios=False,
58
61
  ):
59
62
  """
60
63
  Make a BART posterior sampling MCMC initial state.
@@ -82,9 +85,13 @@ def init(*,
82
85
  The dtype for scalars, small arrays, and arrays which require accuracy.
83
86
  min_points_per_leaf : int, optional
84
87
  The minimum number of data points in a leaf node. 0 if not specified.
85
- suffstat_batch_size : int, None, str, default 'auto'
86
- The batch size for computing sufficient statistics. `None` for no
87
- batching. If 'auto', pick a value based on the device of `y`.
88
+ resid_batch_size, count_batch_size : int, None, str, default 'auto'
89
+ The batch sizes, along datapoints, for summing the residuals and
90
+ counting the number of datapoints in each leaf. `None` for no batching.
91
+ If 'auto', pick a value based on the device of `y`, or the default
92
+ device.
93
+ save_ratios : bool, default False
94
+ Whether to save the Metropolis-Hastings ratios.
88
95
 
89
96
  Returns
90
97
  -------
@@ -110,6 +117,8 @@ def init(*,
110
117
  'p_nonterminal' : large_float array (d,)
111
118
  The probability of a nonterminal node at each depth, padded with a
112
119
  zero.
120
+ 'p_propose_grow' : large_float array (2 ** (d - 1),)
121
+ The unnormalized probability of picking a leaf for a grow proposal.
113
122
  'sigma2_alpha' : large_float
114
123
  The shape parameter of the inverse gamma prior on the noise variance.
115
124
  'sigma2_beta' : large_float
@@ -120,6 +129,8 @@ def init(*,
120
129
  The response.
121
130
  'X' : int array (p, n)
122
131
  The predictors.
132
+ 'leaf_indices' : int array (num_trees, n)
133
+ The index of the leaf each datapoint falls into, for each tree.
123
134
  'min_points_per_leaf' : int or None
124
135
  The minimum number of data points in a leaf node.
125
136
  'affluence_trees' : bool array (num_trees, 2 ** (d - 1)) or None
@@ -128,8 +139,6 @@ def init(*,
128
139
  'opt' : LeafDict
129
140
  A dictionary with config values:
130
141
 
131
- 'suffstat_batch_size' : int or None
132
- The batch size for computing sufficient statistics.
133
142
  'small_float' : dtype
134
143
  The dtype for large arrays used in the algorithm.
135
144
  'large_float' : dtype
@@ -137,6 +146,8 @@ def init(*,
137
146
  accuracy.
138
147
  'require_min_points' : bool
139
148
  Whether the `min_points_per_leaf` parameter is specified.
149
+ 'resid_batch_size', 'count_batch_size' : int or None
150
+ The data batch sizes for computing the sufficient statistics.
140
151
  """
141
152
 
142
153
  p_nonterminal = jnp.asarray(p_nonterminal, large_float)
@@ -150,24 +161,28 @@ def init(*,
150
161
  small_float = jnp.dtype(small_float)
151
162
  large_float = jnp.dtype(large_float)
152
163
  y = jnp.asarray(y, small_float)
153
- suffstat_batch_size = _choose_suffstat_batch_size(suffstat_batch_size, y)
164
+ resid_batch_size, count_batch_size = _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y)
165
+ sigma2 = jnp.array(sigma2_beta / sigma2_alpha, large_float)
166
+ sigma2 = jnp.where(jnp.isfinite(sigma2) & (sigma2 > 0), sigma2, 1)
154
167
 
155
168
  bart = dict(
156
169
  leaf_trees=make_forest(max_depth, small_float),
157
170
  var_trees=make_forest(max_depth - 1, jaxext.minimal_unsigned_dtype(X.shape[0] - 1)),
158
171
  split_trees=make_forest(max_depth - 1, max_split.dtype),
159
172
  resid=jnp.asarray(y, large_float),
160
- sigma2=jnp.ones((), large_float),
173
+ sigma2=sigma2,
161
174
  grow_prop_count=jnp.zeros((), int),
162
175
  grow_acc_count=jnp.zeros((), int),
163
176
  prune_prop_count=jnp.zeros((), int),
164
177
  prune_acc_count=jnp.zeros((), int),
165
178
  p_nonterminal=p_nonterminal,
179
+ p_propose_grow=p_nonterminal[grove.tree_depths(2 ** (max_depth - 1))],
166
180
  sigma2_alpha=jnp.asarray(sigma2_alpha, large_float),
167
181
  sigma2_beta=jnp.asarray(sigma2_beta, large_float),
168
182
  max_split=jnp.asarray(max_split),
169
183
  y=y,
170
184
  X=jnp.asarray(X),
185
+ leaf_indices=jnp.ones((num_trees, y.size), jaxext.minimal_unsigned_dtype(2 ** max_depth - 1)),
171
186
  min_points_per_leaf=(
172
187
  None if min_points_per_leaf is None else
173
188
  jnp.asarray(min_points_per_leaf)
@@ -177,30 +192,61 @@ def init(*,
177
192
  make_forest(max_depth - 1, bool).at[:, 1].set(y.size >= 2 * min_points_per_leaf)
178
193
  ),
179
194
  opt=jaxext.LeafDict(
180
- suffstat_batch_size=suffstat_batch_size,
181
195
  small_float=small_float,
182
196
  large_float=large_float,
183
197
  require_min_points=min_points_per_leaf is not None,
198
+ resid_batch_size=resid_batch_size,
199
+ count_batch_size=count_batch_size,
184
200
  ),
185
201
  )
186
202
 
203
+ if save_ratios:
204
+ bart['ratios'] = dict(
205
+ grow=dict(
206
+ trans_prior=jnp.full(num_trees, jnp.nan),
207
+ likelihood=jnp.full(num_trees, jnp.nan),
208
+ ),
209
+ prune=dict(
210
+ trans_prior=jnp.full(num_trees, jnp.nan),
211
+ likelihood=jnp.full(num_trees, jnp.nan),
212
+ ),
213
+ )
214
+
187
215
  return bart
188
216
 
189
- def _choose_suffstat_batch_size(size, y):
190
- if size == 'auto':
191
- platform = y.devices().pop().platform
217
+ def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y):
218
+
219
+ @functools.cache
220
+ def get_platform():
221
+ try:
222
+ device = y.devices().pop()
223
+ except jax.errors.ConcretizationTypeError:
224
+ device = jax.devices()[0]
225
+ platform = device.platform
226
+ if platform not in ('cpu', 'gpu'):
227
+ raise KeyError(f'Unknown platform: {platform}')
228
+ return platform
229
+
230
+ if resid_batch_size == 'auto':
231
+ platform = get_platform()
232
+ n = max(1, y.size)
192
233
  if platform == 'cpu':
193
- return None
194
- # maybe I should batch residuals (not counts) for numerical
195
- # accuracy, even if it's slower
234
+ resid_batch_size = 2 ** int(round(math.log2(n / 6))) # n/6
196
235
  elif platform == 'gpu':
197
- return 128 # 128 is good on A100, and V100 at high n
198
- # 512 is good on T4, and V100 at low n
199
- else:
200
- raise KeyError(f'Unknown platform: {platform}')
201
- elif size is not None:
202
- return int(size)
203
- return size
236
+ resid_batch_size = 2 ** int(round((1 + math.log2(n)) / 3)) # n^1/3
237
+ resid_batch_size = max(1, resid_batch_size)
238
+
239
+ if count_batch_size == 'auto':
240
+ platform = get_platform()
241
+ if platform == 'cpu':
242
+ count_batch_size = None
243
+ elif platform == 'gpu':
244
+ n = max(1, y.size)
245
+ count_batch_size = 2 ** int(round(math.log2(n) / 2 - 2)) # n^1/2
246
+ # /4 is good on V100, /2 on L4/T4, still haven't tried A100
247
+ count_batch_size = max(1, count_batch_size)
248
+
249
+ return resid_batch_size, count_batch_size
204
250
 
205
251
  def step(bart, key):
206
252
  """
@@ -240,14 +286,11 @@ def sample_trees(bart, key):
240
286
 
241
287
  Notes
242
288
  -----
243
- This function zeroes the proposal counters before using them.
289
+ This function zeroes the proposal counters.
244
290
  """
245
- bart = bart.copy()
246
291
  key, subkey = random.split(key)
247
292
  grow_moves, prune_moves = sample_moves(bart, subkey)
248
- bart['var_trees'] = grow_moves['var_tree']
249
- grow_leaf_indices = grove.traverse_forest(bart['X'], grow_moves['var_tree'], grow_moves['split_tree'])
250
- return accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, grow_leaf_indices, key)
293
+ return accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, key)
251
294
 
252
295
  def sample_moves(bart, key):
253
296
  """
@@ -266,17 +309,17 @@ def sample_moves(bart, key):
266
309
  The proposals for grow and prune moves. See `grow_move` and `prune_move`.
267
310
  """
268
311
  key = random.split(key, bart['var_trees'].shape[0])
269
- return sample_moves_vmap_trees(bart['var_trees'], bart['split_trees'], bart['affluence_trees'], bart['max_split'], bart['p_nonterminal'], key)
312
+ return _sample_moves_vmap_trees(bart['var_trees'], bart['split_trees'], bart['affluence_trees'], bart['max_split'], bart['p_nonterminal'], bart['p_propose_grow'], key)
270
313
 
271
- @functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, 0, None, None, 0))
272
- def sample_moves_vmap_trees(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, key):
314
+ @functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, 0, None, None, None, 0))
315
+ def _sample_moves_vmap_trees(*args):
316
+ args, key = args[:-1], args[-1]
273
317
  key, key1 = random.split(key)
274
- args = var_tree, split_tree, affluence_tree, max_split, p_nonterminal
275
318
  grow = grow_move(*args, key)
276
319
  prune = prune_move(*args, key1)
277
320
  return grow, prune
278
321
 
279
- def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, key):
322
+ def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow, key):
280
323
  """
281
324
  Tree structure grow move proposal of BART MCMC.
282
325
 
@@ -296,6 +339,8 @@ def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, ke
296
339
  The maximum split index for each variable.
297
340
  p_nonterminal : array (d,)
298
341
  The probability of a nonterminal node at each depth.
342
+ p_propose_grow : array (2 ** (d - 1),)
343
+ The unnormalized probability of choosing a leaf to grow.
299
344
  key : jax.dtypes.prng_key array
300
345
  A jax random key.
301
346
 
@@ -304,41 +349,49 @@ def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, ke
304
349
  grow_move : dict
305
350
  A dictionary with fields:
306
351
 
307
- 'allowed' : bool
308
- Whether the move is possible.
352
+ 'num_growable' : int
353
+ The number of growable leaves.
309
354
  'node' : int
310
- The index of the leaf to grow.
311
- 'var_tree' : array (2 ** (d - 1),)
312
- The new decision axes of the tree.
313
- 'split_tree' : array (2 ** (d - 1),)
314
- The new decision boundaries of the tree.
355
+ The index of the leaf to grow. ``2 ** d`` if there are no growable
356
+ leaves.
357
+ 'left', 'right' : int
358
+ The indices of the children of 'node'.
359
+ 'var', 'split' : int
360
+ The decision axis and boundary of the new rule.
315
361
  'partial_ratio' : float
316
362
  A factor of the Metropolis-Hastings ratio of the move. It lacks
317
363
  the likelihood ratio and the probability of proposing the prune
318
364
  move.
365
+ 'var_tree', 'split_tree' : array (2 ** (d - 1),)
366
+ The updated decision axes and boundaries of the tree.
319
367
  """
320
368
 
321
369
  key, key1, key2 = random.split(key, 3)
322
-
323
- leaf_to_grow, num_growable, num_prunable, allowed = choose_leaf(split_tree, affluence_tree, key)
370
+
371
+ leaf_to_grow, num_growable, prob_choose, num_prunable = choose_leaf(split_tree, affluence_tree, p_propose_grow, key)
324
372
 
325
373
  var = choose_variable(var_tree, split_tree, max_split, leaf_to_grow, key1)
326
374
  var_tree = var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype))
327
-
375
+
328
376
  split = choose_split(var_tree, split_tree, max_split, leaf_to_grow, key2)
329
377
  split_tree = split_tree.at[leaf_to_grow].set(split.astype(split_tree.dtype))
330
378
 
331
- ratio = compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_grow, split_tree)
379
+ ratio = compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow, split_tree)
332
380
 
381
+ left = leaf_to_grow << 1
333
382
  return dict(
334
- allowed=allowed,
383
+ num_growable=num_growable,
335
384
  node=leaf_to_grow,
385
+ left=left,
386
+ right=left + 1,
387
+ var=var,
388
+ split=split,
336
389
  partial_ratio=ratio,
337
390
  var_tree=var_tree,
338
391
  split_tree=split_tree,
339
392
  )
340
393
 
341
- def choose_leaf(split_tree, affluence_tree, key):
394
+ def choose_leaf(split_tree, affluence_tree, p_propose_grow, key):
342
395
  """
343
396
  Choose a leaf node to grow in a tree.
344
397
 
@@ -348,6 +401,8 @@ def choose_leaf(split_tree, affluence_tree, key):
348
401
  The splitting points of the tree.
349
402
  affluence_tree : bool array (2 ** (d - 1),) or None
350
403
  Whether a leaf has enough points to be grown.
404
+ p_propose_grow : array (2 ** (d - 1),)
405
+ The unnormalized probability of choosing a leaf to grow.
351
406
  key : jax.dtypes.prng_key array
352
407
  A jax random key.
353
408
 
@@ -358,19 +413,21 @@ def choose_leaf(split_tree, affluence_tree, key):
358
413
  ``2 ** d``.
359
414
  num_growable : int
360
415
  The number of leaf nodes that can be grown.
416
+ prob_choose : float
417
+ The normalized probability of choosing the selected leaf.
361
418
  num_prunable : int
362
419
  The number of leaf parents that could be pruned, after converting the
363
420
  selected leaf to a non-terminal node.
364
- allowed : bool
365
- Whether the grow move is allowed.
366
421
  """
367
- is_growable, allowed = growable_leaves(split_tree, affluence_tree)
368
- leaf_to_grow = randint_masked(key, is_growable)
369
- leaf_to_grow = jnp.where(allowed, leaf_to_grow, 2 * split_tree.size)
422
+ is_growable = growable_leaves(split_tree, affluence_tree)
370
423
  num_growable = jnp.count_nonzero(is_growable)
424
+ distr = jnp.where(is_growable, p_propose_grow, 0)
425
+ leaf_to_grow, distr_norm = categorical(key, distr)
426
+ leaf_to_grow = jnp.where(num_growable, leaf_to_grow, 2 * split_tree.size)
427
+ prob_choose = distr[leaf_to_grow] / distr_norm
371
428
  is_parent = grove.is_leaves_parent(split_tree.at[leaf_to_grow].set(1))
372
429
  num_prunable = jnp.count_nonzero(is_parent)
373
- return leaf_to_grow, num_growable, num_prunable, allowed
430
+ return leaf_to_grow, num_growable, prob_choose, num_prunable
374
431
 
375
432
  def growable_leaves(split_tree, affluence_tree):
376
433
  """
@@ -389,34 +446,32 @@ def growable_leaves(split_tree, affluence_tree):
389
446
  The mask indicating the leaf nodes that can be proposed to grow, i.e.,
390
447
  that are not at the bottom level and have at least two times the number
391
448
  of minimum points per leaf.
392
- allowed : bool
393
- Whether the grow move is allowed, i.e., there are growable leaves.
394
449
  """
395
450
  is_growable = grove.is_actual_leaf(split_tree)
396
451
  if affluence_tree is not None:
397
452
  is_growable &= affluence_tree
398
- return is_growable, jnp.any(is_growable)
453
+ return is_growable
399
454
 
400
- def randint_masked(key, mask):
455
+ def categorical(key, distr):
401
456
  """
402
- Return a random integer in a range, including only some values.
457
+ Return a random integer from an arbitrary distribution.
403
458
 
404
459
  Parameters
405
460
  ----------
406
461
  key : jax.dtypes.prng_key array
407
462
  A jax random key.
408
- mask : bool array (n,)
409
- The mask indicating the allowed values.
463
+ distr : float array (n,)
464
+ An unnormalized probability distribution.
410
465
 
411
466
  Returns
412
467
  -------
413
468
  u : int
414
- A random integer in the range ``[0, n)``, and which satisfies
415
- ``mask[u] == True``. If all values in the mask are `False`, return `n`.
469
+ A random integer in the range ``[0, n)``. If all probabilities are zero,
470
+ return ``n``.
416
471
  """
417
- ecdf = jnp.cumsum(mask)
418
- u = random.randint(key, (), 0, ecdf[-1])
419
- return jnp.searchsorted(ecdf, u, 'right')
472
+ ecdf = jnp.cumsum(distr)
473
+ u = random.uniform(key, (), ecdf.dtype, 0, ecdf[-1])
474
+ return jnp.searchsorted(ecdf, u, 'right'), ecdf[-1]
420
475
 
421
476
  def choose_variable(var_tree, split_tree, max_split, leaf_index, key):
422
477
  """
@@ -471,7 +526,7 @@ def fully_used_variables(var_tree, split_tree, max_split, leaf_index):
471
526
  filled with `p`. The fill values are not guaranteed to be placed in any
472
527
  particular order. Variables may appear more than once.
473
528
  """
474
-
529
+
475
530
  var_to_ignore = ancestor_variables(var_tree, max_split, leaf_index)
476
531
  split_range_vec = jax.vmap(split_range, in_axes=(None, None, None, None, 0))
477
532
  l, r = split_range_vec(var_tree, split_tree, max_split, leaf_index, var_to_ignore)
@@ -603,7 +658,7 @@ def choose_split(var_tree, split_tree, max_split, leaf_index, key):
603
658
  l, r = split_range(var_tree, split_tree, max_split, leaf_index, var)
604
659
  return random.randint(key, (), l, r)
605
660
 
606
- def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_grow, new_split_tree):
661
+ def compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow, new_split_tree):
607
662
  """
608
663
  Compute the product of the transition and prior ratios of a grow move.
609
664
 
@@ -632,6 +687,9 @@ def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_gro
632
687
  # the two ratios also contain factors num_available_split *
633
688
  # num_available_var, but they cancel out
634
689
 
690
+ # p_prune can't be computed here because it needs the count trees, which are
691
+ # computed in the acceptance phase
692
+
635
693
  prune_allowed = leaf_to_grow != 1
636
694
  # prune allowed <---> the initial tree is not a root
637
695
  # leaf to grow is root --> the tree can only be a root
@@ -639,31 +697,33 @@ def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_gro
639
697
 
640
698
  p_grow = jnp.where(prune_allowed, 0.5, 1)
641
699
 
642
- trans_ratio = num_growable / (p_grow * num_prunable)
700
+ inv_trans_ratio = p_grow * prob_choose * num_prunable
643
701
 
644
702
  depth = grove.tree_depths(new_split_tree.size)[leaf_to_grow]
645
703
  p_parent = p_nonterminal[depth]
646
704
  cp_children = 1 - p_nonterminal[depth + 1]
647
705
  tree_ratio = cp_children * cp_children * p_parent / (1 - p_parent)
648
706
 
649
- return trans_ratio * tree_ratio
707
+ return tree_ratio / inv_trans_ratio
650
708
 
651
- def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, key):
709
+ def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow, key):
652
710
  """
653
711
  Tree structure prune move proposal of BART MCMC.
654
712
 
655
713
  Parameters
656
714
  ----------
657
- var_tree : array (2 ** (d - 1),)
715
+ var_tree : int array (2 ** (d - 1),)
658
716
  The variable indices of the tree.
659
- split_tree : array (2 ** (d - 1),)
717
+ split_tree : int array (2 ** (d - 1),)
660
718
  The splitting points of the tree.
661
719
  affluence_tree : bool array (2 ** (d - 1),) or None
662
720
  Whether a leaf has enough points to be grown.
663
- max_split : array (p,)
721
+ max_split : int array (p,)
664
722
  The maximum split index for each variable.
665
- p_nonterminal : array (d,)
723
+ p_nonterminal : float array (d,)
666
724
  The probability of a nonterminal node at each depth.
725
+ p_propose_grow : float array (2 ** (d - 1),)
726
+ The unnormalized probability of choosing a leaf to grow.
667
727
  key : jax.dtypes.prng_key array
668
728
  A jax random key.
669
729
 
@@ -675,24 +735,29 @@ def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, k
675
735
  'allowed' : bool
676
736
  Whether the move is possible.
677
737
  'node' : int
678
- The index of the node to prune.
738
+ The index of the node to prune. ``2 ** d`` if no node can be pruned.
739
+ 'left', 'right' : int
740
+ The indices of the children of 'node'.
679
741
  'partial_ratio' : float
680
742
  A factor of the Metropolis-Hastings ratio of the move. It lacks
681
743
  the likelihood ratio and the probability of proposing the prune
682
744
  move. This ratio is inverted.
683
745
  """
684
- node_to_prune, num_prunable, num_growable = choose_leaf_parent(split_tree, affluence_tree, key)
746
+ node_to_prune, num_prunable, prob_choose = choose_leaf_parent(split_tree, affluence_tree, p_propose_grow, key)
685
747
  allowed = split_tree[1].astype(bool) # allowed iff the tree is not a root
686
748
 
687
- ratio = compute_partial_ratio(num_growable, num_prunable, p_nonterminal, node_to_prune, split_tree)
749
+ ratio = compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, node_to_prune, split_tree)
688
750
 
751
+ left = node_to_prune << 1
689
752
  return dict(
690
753
  allowed=allowed,
691
754
  node=node_to_prune,
755
+ left=left,
756
+ right=left + 1,
692
757
  partial_ratio=ratio, # it is inverted in accept_move_and_sample_leaves
693
758
  )
694
759
 
695
- def choose_leaf_parent(split_tree, affluence_tree, key):
760
+ def choose_leaf_parent(split_tree, affluence_tree, p_propose_grow, key):
696
761
  """
697
762
  Pick a non-terminal node with leaf children to prune in a tree.
698
763
 
@@ -702,6 +767,8 @@ def choose_leaf_parent(split_tree, affluence_tree, key):
702
767
  The splitting points of the tree.
703
768
  affluence_tree : bool array (2 ** (d - 1),) or None
704
769
  Whether a leaf has enough points to be grown.
770
+ p_propose_grow : array (2 ** (d - 1),)
771
+ The unnormalized probability of choosing a leaf to grow.
705
772
  key : jax.dtypes.prng_key array
706
773
  A jax random key.
707
774
 
@@ -709,28 +776,50 @@ def choose_leaf_parent(split_tree, affluence_tree, key):
709
776
  -------
710
777
  node_to_prune : int
711
778
  The index of the node to prune. If ``num_prunable == 0``, return
712
- ``split_tree.size``.
779
+ ``2 ** d``.
713
780
  num_prunable : int
714
781
  The number of leaf parents that could be pruned.
715
- num_growable : int
716
- The number of leaf nodes that can be grown, after pruning the chosen
717
- node.
782
+ prob_choose : float
783
+ The normalized probability of choosing the node to prune for growth.
718
784
  """
719
785
  is_prunable = grove.is_leaves_parent(split_tree)
720
- node_to_prune = randint_masked(key, is_prunable)
721
786
  num_prunable = jnp.count_nonzero(is_prunable)
787
+ node_to_prune = randint_masked(key, is_prunable)
788
+ node_to_prune = jnp.where(num_prunable, node_to_prune, 2 * split_tree.size)
722
789
 
723
- pruned_split_tree = split_tree.at[node_to_prune].set(0)
724
- pruned_affluence_tree = (
790
+ split_tree = split_tree.at[node_to_prune].set(0)
791
+ affluence_tree = (
725
792
  None if affluence_tree is None else
726
793
  affluence_tree.at[node_to_prune].set(True)
727
794
  )
728
- is_growable_leaf, _ = growable_leaves(pruned_split_tree, pruned_affluence_tree)
729
- num_growable = jnp.count_nonzero(is_growable_leaf)
795
+ is_growable_leaf = growable_leaves(split_tree, affluence_tree)
796
+ prob_choose = p_propose_grow[node_to_prune]
797
+ prob_choose /= jnp.sum(p_propose_grow, where=is_growable_leaf)
798
+
799
+ return node_to_prune, num_prunable, prob_choose
800
+
801
+ def randint_masked(key, mask):
802
+ """
803
+ Return a random integer in a range, including only some values.
730
804
 
731
- return node_to_prune, num_prunable, num_growable
805
+ Parameters
806
+ ----------
807
+ key : jax.dtypes.prng_key array
808
+ A jax random key.
809
+ mask : bool array (n,)
810
+ The mask indicating the allowed values.
811
+
812
+ Returns
813
+ -------
814
+ u : int
815
+ A random integer in the range ``[0, n)``, and which satisfies
816
+ ``mask[u] == True``. If all values in the mask are `False`, return `n`.
817
+ """
818
+ ecdf = jnp.cumsum(mask)
819
+ u = random.randint(key, (), 0, ecdf[-1])
820
+ return jnp.searchsorted(ecdf, u, 'right')
732
821
 
733
- def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, grow_leaf_indices, key):
822
+ def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, key):
734
823
  """
735
824
  Accept or reject the proposed moves and sample the new leaf values.
736
825
 
@@ -744,8 +833,6 @@ def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, grow_leaf_indi
744
833
  prune_moves : dict
745
834
  The proposals for prune moves, batched over the first axis. See
746
835
  `prune_move`.
747
- grow_leaf_indices : int array (num_trees, n)
748
- The leaf indices of the trees proposed by the grow move.
749
836
  key : jax.dtypes.prng_key array
750
837
  A jax random key.
751
838
 
@@ -754,41 +841,339 @@ def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, grow_leaf_indi
754
841
  bart : dict
755
842
  The new BART mcmc state.
756
843
  """
844
+ bart, grow_moves, prune_moves, count_trees, move_counts, u, z = accept_moves_parallel_stage(bart, grow_moves, prune_moves, key)
845
+ bart, counts = accept_moves_sequential_stage(bart, count_trees, grow_moves, prune_moves, move_counts, u, z)
846
+ return accept_moves_final_stage(bart, counts, grow_moves, prune_moves)
847
+
848
+ def accept_moves_parallel_stage(bart, grow_moves, prune_moves, key):
849
+ """
850
+ Pre-computes quantities used to accept moves, in parallel across trees.
851
+
852
+ Parameters
853
+ ----------
854
+ bart : dict
855
+ A BART mcmc state.
856
+ grow_moves, prune_moves : dict
857
+ The proposals for the moves, batched over the first axis. See
858
+ `grow_move` and `prune_move`.
859
+ key : jax.dtypes.prng_key array
860
+ A jax random key.
861
+
862
+ Returns
863
+ -------
864
+ bart : dict
865
+ A partially updated BART mcmc state.
866
+ grow_moves, prune_moves : dict
867
+ The proposals for the moves, with the field 'partial_ratio' replaced
868
+ by 'trans_prior_ratio'.
869
+ count_trees : array (num_trees, 2 ** d)
870
+ The number of points in each potential or actual leaf node.
871
+ move_counts : dict
872
+ The counts of the number of points in the nodes modified by the
873
+ moves.
874
+ u : float array (num_trees, 2)
875
+ Random uniform values used to accept the moves.
876
+ z : float array (num_trees, 2 ** d)
877
+ Random standard normal values used to sample the new leaf values.
878
+ """
879
+ bart = bart.copy()
880
+
881
+ bart['var_trees'] = grow_moves['var_tree']
882
+ # Since var_tree can contain garbage, I can set the var of leaf to be
883
+ # grown irrespectively of what move I'm gonna accept in the end.
884
+
885
+ bart['leaf_indices'] = apply_grow_to_indices(grow_moves, bart['leaf_indices'], bart['X'])
886
+
887
+ count_trees, move_counts = compute_count_trees(bart['leaf_indices'], grow_moves, prune_moves, bart['opt']['count_batch_size'])
888
+
889
+ grow_moves, prune_moves = complete_ratio(grow_moves, prune_moves, move_counts, bart['min_points_per_leaf'])
890
+
891
+ if bart['opt']['require_min_points']:
892
+ count_half_trees = count_trees[:, :grow_moves['split_tree'].shape[1]]
893
+ bart['affluence_trees'] = count_half_trees >= 2 * bart['min_points_per_leaf']
894
+
895
+ bart['leaf_trees'] = adapt_leaf_trees_to_grow_indices(bart['leaf_trees'], grow_moves)
896
+
897
+ key, subkey = random.split(key)
898
+ u = random.uniform(subkey, (len(bart['leaf_trees']), 2), bart['opt']['large_float'])
899
+ z = random.normal(key, bart['leaf_trees'].shape, bart['opt']['large_float'])
900
+
901
+ return bart, grow_moves, prune_moves, count_trees, move_counts, u, z
902
+
903
+ def apply_grow_to_indices(grow_moves, leaf_indices, X):
904
+ """
905
+ Update the leaf indices to apply a grow move.
906
+
907
+ Parameters
908
+ ----------
909
+ grow_moves : dict
910
+ The proposals for grow moves. See `grow_move`.
911
+ leaf_indices : array (num_trees, n)
912
+ The index of the leaf each datapoint falls into.
913
+ X : array (p, n)
914
+ The predictors matrix.
915
+
916
+ Returns
917
+ -------
918
+ grow_leaf_indices : array (num_trees, n)
919
+ The updated leaf indices.
920
+ """
921
+ left_child = grow_moves['node'].astype(leaf_indices.dtype) << 1
922
+ go_right = X[grow_moves['var'], :] >= grow_moves['split'][:, None]
923
+ tree_size = jnp.array(2 * grow_moves['split_tree'].shape[1])
924
+ node_to_update = jnp.where(grow_moves['num_growable'], grow_moves['node'], tree_size)
925
+ return jnp.where(
926
+ leaf_indices == node_to_update[:, None],
927
+ left_child[:, None] + go_right,
928
+ leaf_indices,
929
+ )
930
+
931
+ def compute_count_trees(grow_leaf_indices, grow_moves, prune_moves, batch_size):
932
+ """
933
+ Count the number of datapoints in each leaf.
934
+
935
+ Parameters
936
+ ----------
937
+ grow_leaf_indices : int array (num_trees, n)
938
+ The index of the leaf each datapoint falls into, if the grow move is
939
+ accepted.
940
+ grow_moves, prune_moves : dict
941
+ The proposals for the moves. See `grow_move` and `prune_move`.
942
+ batch_size : int or None
943
+ The data batch size to use for the summation.
944
+
945
+ Returns
946
+ -------
947
+ count_trees : int array (num_trees, 2 ** d)
948
+ The number of points in each potential or actual leaf node.
949
+ counts : dict
950
+ The counts of the number of points in the nodes modified by the
951
+ moves, organized as two dictionaries 'grow' and 'prune', with subfields
952
+ 'left', 'right', and 'total'.
953
+ """
954
+
955
+ ntree, tree_size = grow_moves['split_tree'].shape
956
+ tree_size *= 2
957
+ counts = dict(grow=dict(), prune=dict())
958
+ tree_indices = jnp.arange(ntree)
959
+
960
+ count_trees = count_datapoints_per_leaf(grow_leaf_indices, tree_size, batch_size)
961
+
962
+ # count datapoints in leaf to grow
963
+ counts['grow']['left'] = count_trees[tree_indices, grow_moves['left']]
964
+ counts['grow']['right'] = count_trees[tree_indices, grow_moves['right']]
965
+ counts['grow']['total'] = counts['grow']['left'] + counts['grow']['right']
966
+ count_trees = count_trees.at[tree_indices, grow_moves['node']].set(counts['grow']['total'])
967
+
968
+ # count datapoints in node to prune
969
+ counts['prune']['left'] = count_trees[tree_indices, prune_moves['left']]
970
+ counts['prune']['right'] = count_trees[tree_indices, prune_moves['right']]
971
+ counts['prune']['total'] = counts['prune']['left'] + counts['prune']['right']
972
+ count_trees = count_trees.at[tree_indices, prune_moves['node']].set(counts['prune']['total'])
973
+
974
+ return count_trees, counts
975
+
976
+ def count_datapoints_per_leaf(leaf_indices, tree_size, batch_size):
977
+ """
978
+ Count the number of datapoints in each leaf.
979
+
980
+ Parameters
981
+ ----------
982
+ leaf_indices : int array (num_trees, n)
983
+ The index of the leaf each datapoint falls into.
984
+ tree_size : int
985
+ The size of the leaf tree array (2 ** d).
986
+ batch_size : int or None
987
+ The data batch size to use for the summation.
988
+
989
+ Returns
990
+ -------
991
+ count_trees : int array (num_trees, 2 ** d)
992
+ The number of points in each leaf node.
993
+ """
994
+ if batch_size is None:
995
+ return _count_scan(leaf_indices, tree_size)
996
+ else:
997
+ return _count_vec(leaf_indices, tree_size, batch_size)
998
+
999
+ def _count_scan(leaf_indices, tree_size):
1000
+ def loop(_, leaf_indices):
1001
+ return None, _aggregate_scatter(1, leaf_indices, tree_size, jnp.uint32)
1002
+ _, count_trees = lax.scan(loop, None, leaf_indices)
1003
+ return count_trees
1004
+
1005
+ def _aggregate_scatter(values, indices, size, dtype):
1006
+ return (jnp
1007
+ .zeros(size, dtype)
1008
+ .at[indices]
1009
+ .add(values)
1010
+ )
1011
+
1012
+ def _count_vec(leaf_indices, tree_size, batch_size):
1013
+ return _aggregate_batched_alltrees(1, leaf_indices, tree_size, jnp.uint32, batch_size)
1014
+ # uint16 is super-slow on gpu, don't use it even if n < 2^16
1015
+
1016
+ def _aggregate_batched_alltrees(values, indices, size, dtype, batch_size):
1017
+ ntree, n = indices.shape
1018
+ tree_indices = jnp.arange(ntree)
1019
+ nbatches = n // batch_size + bool(n % batch_size)
1020
+ batch_indices = jnp.arange(n) % nbatches
1021
+ return (jnp
1022
+ .zeros((ntree, size, nbatches), dtype)
1023
+ .at[tree_indices[:, None], indices, batch_indices]
1024
+ .add(values)
1025
+ .sum(axis=2)
1026
+ )
1027
+
1028
+ def complete_ratio(grow_moves, prune_moves, move_counts, min_points_per_leaf):
1029
+ """
1030
+ Complete non-likelihood MH ratio calculation.
1031
+
1032
+ This function multiplies the partial ratios by the probability of proposing the prune move.
1033
+
1034
+ Parameters
1035
+ ----------
1036
+ grow_moves, prune_moves : dict
1037
+ The proposals for the moves. See `grow_move` and `prune_move`.
1038
+ move_counts : dict
1039
+ The counts of the number of points in the nodes modified by the
1040
+ moves.
1041
+ min_points_per_leaf : int or None
1042
+ The minimum number of data points in a leaf node.
1043
+
1044
+ Returns
1045
+ -------
1046
+ grow_moves, prune_moves : dict
1047
+ The proposals for the moves, with the field 'partial_ratio' replaced
1048
+ by 'trans_prior_ratio'.
1049
+ """
1050
+ grow_moves = grow_moves.copy()
1051
+ prune_moves = prune_moves.copy()
1052
+ compute_p_prune_vec = jax.vmap(compute_p_prune, in_axes=(0, 0, 0, None))
1053
+ grow_p_prune, prune_p_prune = compute_p_prune_vec(grow_moves, move_counts['grow']['left'], move_counts['grow']['right'], min_points_per_leaf)
1054
+ grow_moves['trans_prior_ratio'] = grow_moves.pop('partial_ratio') * grow_p_prune
1055
+ prune_moves['trans_prior_ratio'] = prune_moves.pop('partial_ratio') * prune_p_prune
1056
+ return grow_moves, prune_moves
1057
+
1058
+ def compute_p_prune(grow_move, grow_left_count, grow_right_count, min_points_per_leaf):
1059
+ """
1060
+ Compute the probability of proposing a prune move.
1061
+
1062
+ Parameters
1063
+ ----------
1064
+ grow_move : dict
1065
+ The proposal for the grow move, see `grow_move`.
1066
+ grow_left_count, grow_right_count : int
1067
+ The number of datapoints in the proposed children of the leaf to grow.
1068
+ min_points_per_leaf : int or None
1069
+ The minimum number of data points in a leaf node.
1070
+
1071
+ Returns
1072
+ -------
1073
+ grow_p_prune : float
1074
+ The probability of proposing a prune move, after accepting the grow
1075
+ move.
1076
+ prune_p_prune : float
1077
+ The probability of proposing the prune move.
1078
+ """
1079
+ other_growable_leaves = grow_move['num_growable'] >= 2
1080
+ new_leaves_growable = grow_move['node'] < grow_move['split_tree'].size // 2
1081
+ if min_points_per_leaf is not None:
1082
+ any_above_threshold = grow_left_count >= 2 * min_points_per_leaf
1083
+ any_above_threshold |= grow_right_count >= 2 * min_points_per_leaf
1084
+ new_leaves_growable &= any_above_threshold
1085
+ grow_again_allowed = other_growable_leaves | new_leaves_growable
1086
+ grow_p_prune = jnp.where(grow_again_allowed, 0.5, 1)
1087
+ prune_p_prune = jnp.where(grow_move['num_growable'], 0.5, 1)
1088
+ return grow_p_prune, prune_p_prune
1089
+
1090
+ def adapt_leaf_trees_to_grow_indices(leaf_trees, grow_moves):
1091
+ """
1092
+ Modify leaf values such that the indices of the grow move work on the
1093
+ original tree.
1094
+
1095
+ Parameters
1096
+ ----------
1097
+ leaf_trees : float array (num_trees, 2 ** d)
1098
+ The leaf values.
1099
+ grow_moves : dict
1100
+ The proposals for grow moves. See `grow_move`.
1101
+
1102
+ Returns
1103
+ -------
1104
+ leaf_trees : float array (num_trees, 2 ** d)
1105
+ The modified leaf values. The value of the leaf to grow is copied to
1106
+ what would be its children if the grow move was accepted.
1107
+ """
1108
+ ntree, _ = leaf_trees.shape
1109
+ tree_indices = jnp.arange(ntree)
1110
+ values_at_node = leaf_trees[tree_indices, grow_moves['node']]
1111
+ return (leaf_trees
1112
+ .at[tree_indices, grow_moves['left']]
1113
+ .set(values_at_node)
1114
+ .at[tree_indices, grow_moves['right']]
1115
+ .set(values_at_node)
1116
+ )
1117
+
1118
+ def accept_moves_sequential_stage(bart, count_trees, grow_moves, prune_moves, move_counts, u, z):
1119
+ """
1120
+ The part of accepting the moves that has to be done one tree at a time.
1121
+
1122
+ Parameters
1123
+ ----------
1124
+ bart : dict
1125
+ A partially updated BART mcmc state.
1126
+ count_trees : array (num_trees, 2 ** d)
1127
+ The number of points in each potential or actual leaf node.
1128
+ grow_moves, prune_moves : dict
1129
+ The proposals for the moves, with completed ratios. See `grow_move` and
1130
+ `prune_move`.
1131
+ move_counts : dict
1132
+ The counts of the number of points in the nodes modified by the
1133
+ moves.
1134
+ u : float array (num_trees, 2)
1135
+ Random uniform values used for proposal and accept decisions.
1136
+ z : float array (num_trees, 2 ** d)
1137
+ Random standard normal values used to sample the new leaf values.
1138
+
1139
+ Returns
1140
+ -------
1141
+ bart : dict
1142
+ A partially updated BART mcmc state.
1143
+ counts : dict
1144
+ The indicators of proposals and acceptances for grow and prune moves.
1145
+ """
757
1146
  bart = bart.copy()
758
- def loop(carry, item):
759
- resid = carry.pop('resid')
760
- resid, carry, trees = accept_move_and_sample_leaves(
1147
+
1148
+ def loop(resid, item):
1149
+ resid, leaf_tree, split_tree, counts, ratios = accept_move_and_sample_leaves(
761
1150
  bart['X'],
762
1151
  len(bart['leaf_trees']),
763
- bart['opt']['suffstat_batch_size'],
1152
+ bart['opt']['resid_batch_size'],
764
1153
  resid,
765
1154
  bart['sigma2'],
766
1155
  bart['min_points_per_leaf'],
767
- carry,
1156
+ 'ratios' in bart,
768
1157
  *item,
769
1158
  )
770
- carry['resid'] = resid
771
- return carry, trees
772
- carry = {
773
- k: jnp.zeros_like(bart[k]) for k in
774
- ['grow_prop_count', 'prune_prop_count', 'grow_acc_count', 'prune_acc_count']
775
- }
776
- carry['resid'] = bart['resid']
1159
+ return resid, (leaf_tree, split_tree, counts, ratios)
1160
+
777
1161
  items = (
778
- bart['leaf_trees'],
779
- bart['split_trees'],
780
- bart['affluence_trees'],
781
- grow_moves,
782
- prune_moves,
783
- grow_leaf_indices,
784
- random.split(key, len(bart['leaf_trees'])),
1162
+ bart['leaf_trees'], count_trees,
1163
+ grow_moves, prune_moves, move_counts,
1164
+ bart['leaf_indices'],
1165
+ u, z,
785
1166
  )
786
- carry, trees = lax.scan(loop, carry, items)
787
- bart.update(carry)
788
- bart.update(trees)
789
- return bart
1167
+ resid, (leaf_trees, split_trees, counts, ratios) = lax.scan(loop, bart['resid'], items)
1168
+
1169
+ bart['resid'] = resid
1170
+ bart['leaf_trees'] = leaf_trees
1171
+ bart['split_trees'] = split_trees
1172
+ bart.get('ratios', {}).update(ratios)
790
1173
 
791
- def accept_move_and_sample_leaves(X, ntree, suffstat_batch_size, resid, sigma2, min_points_per_leaf, counts, leaf_tree, split_tree, affluence_tree, grow_move, prune_move, grow_leaf_indices, key):
1174
+ return bart, counts
1175
+
1176
+ def accept_move_and_sample_leaves(X, ntree, resid_batch_size, resid, sigma2, min_points_per_leaf, save_ratios, leaf_tree, count_tree, grow_move, prune_move, move_counts, grow_leaf_indices, u, z):
792
1177
  """
793
1178
  Accept or reject a proposed move and sample the new leaf values.
794
1179
 
@@ -798,158 +1183,157 @@ def accept_move_and_sample_leaves(X, ntree, suffstat_batch_size, resid, sigma2,
798
1183
  The predictors.
799
1184
  ntree : int
800
1185
  The number of trees in the forest.
801
- suffstat_batch_size : int, None
802
- The batch size for computing sufficient statistics.
1186
+ resid_batch_size : int, None
1187
+ The batch size for computing the sum of residuals in each leaf.
803
1188
  resid : float array (n,)
804
1189
  The residuals (data minus forest value).
805
1190
  sigma2 : float
806
1191
  The noise variance.
807
1192
  min_points_per_leaf : int or None
808
1193
  The minimum number of data points in a leaf node.
809
- counts : dict
810
- The acceptance counts from the mcmc state dict.
1194
+ save_ratios : bool
1195
+ Whether to save the acceptance ratios.
811
1196
  leaf_tree : float array (2 ** d,)
812
1197
  The leaf values of the tree.
813
- split_tree : int array (2 ** (d - 1),)
814
- The decision boundaries of the tree.
815
- affluence_tree : bool array (2 ** (d - 1),) or None
816
- Whether a leaf has enough points to be grown.
817
- grow_move : dict
818
- The proposal for the grow move. See `grow_move`.
819
- prune_move : dict
820
- The proposal for the prune move. See `prune_move`.
1198
+ count_tree : int array (2 ** d,)
1199
+ The number of datapoints in each leaf.
1200
+ grow_move, prune_move : dict
1201
+ The proposals for the moves, with completed ratios. See `grow_move` and
1202
+ `prune_move`.
821
1203
  grow_leaf_indices : int array (n,)
822
1204
  The leaf indices of the tree proposed by the grow move.
823
- key : jax.dtypes.prng_key array
824
- A jax random key.
1205
+ u : float array (2,)
1206
+ Two uniform random values in [0, 1).
1207
+ z : float array (2 ** d,)
1208
+ Standard normal random values.
825
1209
 
826
1210
  Returns
827
1211
  -------
828
1212
  resid : float array (n,)
829
1213
  The updated residuals (data minus forest value).
1214
+ leaf_tree : float array (2 ** d,)
1215
+ The new leaf values of the tree.
1216
+ split_tree : int array (2 ** (d - 1),)
1217
+ The updated decision boundaries of the tree.
830
1218
  counts : dict
831
- The updated acceptance counts.
832
- trees : dict
833
- The updated tree arrays.
1219
+ The indicators of proposals and acceptances for grow and prune moves.
1220
+ ratios : dict
1221
+ The acceptance ratios for the moves. Empty if not to be saved.
834
1222
  """
835
-
836
- # compute leaf indices in starting tree
837
- grow_node = grow_move['node']
838
- grow_left = grow_node << 1
839
- grow_right = grow_left + 1
840
- leaf_indices = jnp.where(
841
- (grow_leaf_indices == grow_left) | (grow_leaf_indices == grow_right),
842
- grow_node,
843
- grow_leaf_indices,
844
- )
845
1223
 
846
- # compute leaf indices in prune tree
847
- prune_node = prune_move['node']
848
- prune_left = prune_node << 1
849
- prune_right = prune_left + 1
850
- prune_leaf_indices = jnp.where(
851
- (leaf_indices == prune_left) | (leaf_indices == prune_right),
852
- prune_node,
853
- leaf_indices,
854
- )
1224
+ # sum residuals per leaf, in tree proposed by grow move
1225
+ resid_tree = sum_resid(resid, grow_leaf_indices, leaf_tree.size, resid_batch_size)
855
1226
 
856
1227
  # subtract starting tree from function
857
- resid += leaf_tree[leaf_indices]
1228
+ resid_tree += count_tree * leaf_tree
858
1229
 
859
- # aggregate residuals and count units per leaf
860
- grow_resid_tree, grow_count_tree = sufficient_stat(resid, grow_leaf_indices, leaf_tree.size, suffstat_batch_size)
1230
+ # get indices of grow move
1231
+ grow_node = grow_move['node']
1232
+ assert grow_node.dtype == jnp.int32
1233
+ grow_left = grow_move['left']
1234
+ grow_right = grow_move['right']
861
1235
 
862
- # compute aggregations in starting tree
863
- # I do not zero the children because garbage there does not matter
864
- resid_tree = (grow_resid_tree.at[grow_node]
865
- .set(grow_resid_tree[grow_left] + grow_resid_tree[grow_right]))
866
- count_tree = (grow_count_tree.at[grow_node]
867
- .set(grow_count_tree[grow_left] + grow_count_tree[grow_right]))
1236
+ # sum residuals in leaf to grow
1237
+ grow_resid_left = resid_tree[grow_left]
1238
+ grow_resid_right = resid_tree[grow_right]
1239
+ grow_resid_total = grow_resid_left + grow_resid_right
1240
+ resid_tree = resid_tree.at[grow_node].set(grow_resid_total)
868
1241
 
869
- # compute aggregations in prune tree
870
- prune_resid_tree = (resid_tree.at[prune_node]
871
- .set(resid_tree[prune_left] + resid_tree[prune_right]))
872
- prune_count_tree = (count_tree.at[prune_node]
873
- .set(count_tree[prune_left] + count_tree[prune_right]))
1242
+ # get indices of prune move
1243
+ prune_node = prune_move['node']
1244
+ assert prune_node.dtype == jnp.int32
1245
+ prune_left = prune_move['left']
1246
+ prune_right = prune_move['right']
874
1247
 
875
- # compute affluence trees
876
- if min_points_per_leaf is not None:
877
- grow_affluence_tree = grow_count_tree[:grow_count_tree.size // 2] >= 2 * min_points_per_leaf
878
- prune_affluence_tree = affluence_tree.at[prune_node].set(True)
1248
+ # sum residuals in node to prune
1249
+ prune_resid_left = resid_tree[prune_left]
1250
+ prune_resid_right = resid_tree[prune_right]
1251
+ prune_resid_total = prune_resid_left + prune_resid_right
1252
+ resid_tree = resid_tree.at[prune_node].set(prune_resid_total)
879
1253
 
880
- # compute probability of proposing prune
881
- grow_p_prune = compute_p_prune_back(grow_move['split_tree'], grow_affluence_tree)
882
- prune_p_prune = compute_p_prune_back(split_tree, affluence_tree)
1254
+ # Now resid_tree and count_tree contain correct values whatever move is
1255
+ # accepted.
883
1256
 
884
1257
  # compute likelihood ratios
885
- grow_lk_ratio = compute_likelihood_ratio(grow_resid_tree, grow_count_tree, sigma2, grow_node, ntree, min_points_per_leaf)
886
- prune_lk_ratio = compute_likelihood_ratio(resid_tree, count_tree, sigma2, prune_node, ntree, min_points_per_leaf)
1258
+ grow_lk_ratio = compute_likelihood_ratio(grow_resid_total, grow_resid_left, grow_resid_right, move_counts['grow']['total'], move_counts['grow']['left'], move_counts['grow']['right'], sigma2, ntree)
1259
+ prune_lk_ratio = compute_likelihood_ratio(prune_resid_total, prune_resid_left, prune_resid_right, move_counts['prune']['total'], move_counts['prune']['left'], move_counts['prune']['right'], sigma2, ntree)
887
1260
 
888
1261
  # compute acceptance ratios
889
- grow_ratio = grow_p_prune * grow_move['partial_ratio'] * grow_lk_ratio
890
- prune_ratio = prune_p_prune * prune_move['partial_ratio'] * prune_lk_ratio
1262
+ grow_ratio = grow_move['trans_prior_ratio'] * grow_lk_ratio
1263
+ if min_points_per_leaf is not None:
1264
+ grow_ratio = jnp.where(move_counts['grow']['left'] >= min_points_per_leaf, grow_ratio, 0)
1265
+ grow_ratio = jnp.where(move_counts['grow']['right'] >= min_points_per_leaf, grow_ratio, 0)
1266
+ prune_ratio = prune_move['trans_prior_ratio'] * prune_lk_ratio
891
1267
  prune_ratio = lax.reciprocal(prune_ratio)
892
1268
 
893
- # random coins in [0, 1) for proposal and acceptance
894
- key, subkey = random.split(key)
895
- u0, u1 = random.uniform(subkey, (2,))
1269
+ # save acceptance ratios
1270
+ ratios = {}
1271
+ if save_ratios:
1272
+ ratios.update(
1273
+ grow=dict(
1274
+ trans_prior=grow_move['trans_prior_ratio'],
1275
+ likelihood=grow_lk_ratio,
1276
+ ),
1277
+ prune=dict(
1278
+ trans_prior=lax.reciprocal(prune_move['trans_prior_ratio']),
1279
+ likelihood=lax.reciprocal(prune_lk_ratio),
1280
+ ),
1281
+ )
896
1282
 
897
1283
  # determine what move to propose (not proposing anything is an option)
898
- p_grow = jnp.where(grow_move['allowed'] & prune_move['allowed'], 0.5, grow_move['allowed'])
899
- try_grow = u0 < p_grow
1284
+ grow_allowed = grow_move['num_growable'].astype(bool)
1285
+ p_grow = jnp.where(grow_allowed & prune_move['allowed'], 0.5, grow_allowed)
1286
+ try_grow = u[0] < p_grow # use < instead of <= because coins are in [0, 1)
900
1287
  try_prune = prune_move['allowed'] & ~try_grow
901
1288
 
902
1289
  # determine whether to accept the move
903
- do_grow = try_grow & (u1 < grow_ratio)
904
- do_prune = try_prune & (u1 < prune_ratio)
905
-
906
- # pick trees for chosen move
907
- trees = {}
908
- split_tree = jnp.where(do_grow, grow_move['split_tree'], split_tree)
909
- # the prune var tree is equal to the initial one, because I leave garbage values behind
910
- split_tree = split_tree.at[prune_node].set(
911
- jnp.where(do_prune, 0, split_tree[prune_node]))
912
- if min_points_per_leaf is not None:
913
- affluence_tree = jnp.where(do_grow, grow_affluence_tree, affluence_tree)
914
- affluence_tree = jnp.where(do_prune, prune_affluence_tree, affluence_tree)
915
- resid_tree = jnp.where(do_grow, grow_resid_tree, resid_tree)
916
- count_tree = jnp.where(do_grow, grow_count_tree, count_tree)
917
- resid_tree = jnp.where(do_prune, prune_resid_tree, resid_tree)
918
- count_tree = jnp.where(do_prune, prune_count_tree, count_tree)
919
-
920
- # update acceptance counts
921
- counts = counts.copy()
922
- counts['grow_prop_count'] += try_grow
923
- counts['grow_acc_count'] += do_grow
924
- counts['prune_prop_count'] += try_prune
925
- counts['prune_acc_count'] += do_prune
926
-
927
- # compute leaves posterior
928
- prec_lk = count_tree / sigma2
1290
+ do_grow = try_grow & (u[1] < grow_ratio)
1291
+ do_prune = try_prune & (u[1] < prune_ratio)
1292
+
1293
+ # pick split tree for chosen move
1294
+ split_tree = grow_move['split_tree']
1295
+ split_tree = split_tree.at[jnp.where(do_grow, split_tree.size, grow_node)].set(0)
1296
+ split_tree = split_tree.at[jnp.where(do_prune, prune_node, split_tree.size)].set(0)
1297
+ # I can leave garbage in var_tree, resid_tree, count_tree
1298
+
1299
+ # compute leaves posterior and sample leaves
1300
+ inv_sigma2 = lax.reciprocal(sigma2)
1301
+ prec_lk = count_tree * inv_sigma2
929
1302
  var_post = lax.reciprocal(prec_lk + ntree) # = 1 / (prec_lk + prec_prior)
930
- mean_post = resid_tree / sigma2 * var_post # = mean_lk * prec_lk * var_post
931
-
932
- # sample leaves
933
- z = random.normal(key, mean_post.shape, mean_post.dtype)
1303
+ mean_post = resid_tree * inv_sigma2 * var_post # = mean_lk * prec_lk * var_post
1304
+ initial_leaf_tree = leaf_tree
934
1305
  leaf_tree = mean_post + z * jnp.sqrt(var_post)
935
1306
 
936
- # add new tree to function
937
- leaf_indices = jnp.where(do_grow, grow_leaf_indices, leaf_indices)
938
- leaf_indices = jnp.where(do_prune, prune_leaf_indices, leaf_indices)
939
- resid -= leaf_tree[leaf_indices]
1307
+ # copy leaves around such that the grow leaf indices select the right leaf
1308
+ leaf_tree = (leaf_tree
1309
+ .at[jnp.where(do_prune, prune_left, leaf_tree.size)]
1310
+ .set(leaf_tree[prune_node])
1311
+ .at[jnp.where(do_prune, prune_right, leaf_tree.size)]
1312
+ .set(leaf_tree[prune_node])
1313
+ )
1314
+ leaf_tree = (leaf_tree
1315
+ .at[jnp.where(do_grow, leaf_tree.size, grow_left)]
1316
+ .set(leaf_tree[grow_node])
1317
+ .at[jnp.where(do_grow, leaf_tree.size, grow_right)]
1318
+ .set(leaf_tree[grow_node])
1319
+ )
1320
+
1321
+ # replace old tree with new tree in function values
1322
+ resid += (initial_leaf_tree - leaf_tree)[grow_leaf_indices]
940
1323
 
941
- # pack trees
942
- trees = {
943
- 'leaf_trees': leaf_tree,
944
- 'split_trees': split_tree,
945
- 'affluence_trees': affluence_tree,
946
- }
1324
+ # pack proposal and acceptance indicators
1325
+ counts = dict(
1326
+ grow_prop_count=try_grow,
1327
+ grow_acc_count=do_grow,
1328
+ prune_prop_count=try_prune,
1329
+ prune_acc_count=do_prune,
1330
+ )
947
1331
 
948
- return resid, counts, trees
1332
+ return resid, leaf_tree, split_tree, counts, ratios
949
1333
 
950
- def sufficient_stat(resid, leaf_indices, tree_size, batch_size):
1334
+ def sum_resid(resid, leaf_indices, tree_size, batch_size):
951
1335
  """
952
- Compute the sufficient statistics for the likelihood ratio of a tree move.
1336
+ Sum the residuals in each leaf.
953
1337
 
954
1338
  Parameters
955
1339
  ----------
@@ -960,104 +1344,56 @@ def sufficient_stat(resid, leaf_indices, tree_size, batch_size):
960
1344
  tree_size : int
961
1345
  The size of the tree array (2 ** d).
962
1346
  batch_size : int, None
963
- The batch size for the aggregation. Batching increases numerical
1347
+ The data batch size for the aggregation. Batching increases numerical
964
1348
  accuracy and parallelism.
965
1349
 
966
1350
  Returns
967
1351
  -------
968
1352
  resid_tree : float array (2 ** d,)
969
1353
  The sum of the residuals at data points in each leaf.
970
- count_tree : int array (2 ** d,)
971
- The number of data points in each leaf.
972
1354
  """
973
1355
  if batch_size is None:
974
1356
  aggr_func = _aggregate_scatter
975
1357
  else:
976
- aggr_func = functools.partial(_aggregate_batched, batch_size=batch_size)
977
- resid_tree = aggr_func(resid, leaf_indices, tree_size, jnp.float32)
978
- count_tree = aggr_func(1, leaf_indices, tree_size, jnp.uint32)
979
- return resid_tree, count_tree
980
-
981
- def _aggregate_scatter(values, indices, size, dtype):
982
- return (jnp
983
- .zeros(size, dtype)
984
- .at[indices]
985
- .add(values)
986
- )
1358
+ aggr_func = functools.partial(_aggregate_batched_onetree, batch_size=batch_size)
1359
+ return aggr_func(resid, leaf_indices, tree_size, jnp.float32)
987
1360
 
988
- def _aggregate_batched(values, indices, size, dtype, batch_size):
989
- nbatches = indices.size // batch_size + bool(indices.size % batch_size)
990
- batch_indices = jnp.arange(indices.size) // batch_size
1361
+ def _aggregate_batched_onetree(values, indices, size, dtype, batch_size):
1362
+ n, = indices.shape
1363
+ nbatches = n // batch_size + bool(n % batch_size)
1364
+ batch_indices = jnp.arange(n) % nbatches
991
1365
  return (jnp
992
- .zeros((nbatches, size), dtype)
993
- .at[batch_indices, indices]
1366
+ .zeros((size, nbatches), dtype)
1367
+ .at[indices, batch_indices]
994
1368
  .add(values)
995
- .sum(axis=0)
1369
+ .sum(axis=1)
996
1370
  )
997
1371
 
998
- def compute_p_prune_back(new_split_tree, new_affluence_tree):
999
- """
1000
- Compute the probability of proposing a prune move after doing a grow move.
1001
-
1002
- Parameters
1003
- ----------
1004
- new_split_tree : int array (2 ** (d - 1),)
1005
- The decision boundaries of the tree, after the grow move.
1006
- new_affluence_tree : bool array (2 ** (d - 1),)
1007
- Which leaves have enough points to be grown, after the grow move.
1008
-
1009
- Returns
1010
- -------
1011
- p_prune : float
1012
- The probability of proposing a prune move after the grow move. This is
1013
- 0.5 if grow is possible again, and 1 if it isn't. It can't be 0 because
1014
- at least the node just grown can be pruned.
1015
- """
1016
- _, grow_again_allowed = growable_leaves(new_split_tree, new_affluence_tree)
1017
- return jnp.where(grow_again_allowed, 0.5, 1)
1018
-
1019
- def compute_likelihood_ratio(resid_tree, count_tree, sigma2, node, n_tree, min_points_per_leaf):
1372
+ def compute_likelihood_ratio(total_resid, left_resid, right_resid, total_count, left_count, right_count, sigma2, n_tree):
1020
1373
  """
1021
1374
  Compute the likelihood ratio of a grow move.
1022
1375
 
1023
1376
  Parameters
1024
1377
  ----------
1025
- resid_tree : float array (2 ** d,)
1026
- The sum of the residuals at data points in each leaf.
1027
- count_tree : int array (2 ** d,)
1028
- The number of data points in each leaf.
1378
+ total_resid : float
1379
+ The sum of the residuals in the leaf to grow.
1380
+ left_resid, right_resid : float
1381
+ The sum of the residuals in the left/right child of the leaf to grow.
1382
+ total_count : int
1383
+ The number of datapoints in the leaf to grow.
1384
+ left_count, right_count : int
1385
+ The number of datapoints in the left/right child of the leaf to grow.
1029
1386
  sigma2 : float
1030
1387
  The noise variance.
1031
- node : int
1032
- The index of the leaf that has been grown.
1033
1388
  n_tree : int
1034
1389
  The number of trees in the forest.
1035
- min_points_per_leaf : int or None
1036
- The minimum number of data points in a leaf node.
1037
1390
 
1038
1391
  Returns
1039
1392
  -------
1040
1393
  ratio : float
1041
1394
  The likelihood ratio P(data | new tree) / P(data | old tree).
1042
-
1043
- Notes
1044
- -----
1045
- The ratio is set to 0 if the grow move would create leaves with not enough
1046
- datapoints per leaf, although this is part of the prior rather than the
1047
- likelihood.
1048
1395
  """
1049
1396
 
1050
- left_child = node << 1
1051
- right_child = left_child + 1
1052
-
1053
- left_resid = resid_tree[left_child]
1054
- right_resid = resid_tree[right_child]
1055
- total_resid = left_resid + right_resid
1056
-
1057
- left_count = count_tree[left_child]
1058
- right_count = count_tree[right_child]
1059
- total_count = left_count + right_count
1060
-
1061
1397
  sigma_mu2 = 1 / n_tree
1062
1398
  sigma2_left = sigma2 + left_count * sigma_mu2
1063
1399
  sigma2_right = sigma2 + right_count * sigma_mu2
@@ -1071,13 +1407,67 @@ def compute_likelihood_ratio(resid_tree, count_tree, sigma2, node, n_tree, min_p
1071
1407
  total_resid * total_resid / sigma2_total
1072
1408
  )
1073
1409
 
1074
- ratio = jnp.sqrt(sqrt_term) * jnp.exp(exp_term)
1410
+ return jnp.sqrt(sqrt_term) * jnp.exp(exp_term)
1075
1411
 
1076
- if min_points_per_leaf is not None:
1077
- ratio = jnp.where(right_count >= min_points_per_leaf, ratio, 0)
1078
- ratio = jnp.where(left_count >= min_points_per_leaf, ratio, 0)
1412
def accept_moves_final_stage(bart, counts, grow_moves, prune_moves):
    """
    Finish accepting the moves, working on all the trees at once.

    Parameters
    ----------
    bart : dict
        A partially updated BART mcmc state. It is not modified; a shallow
        copy is updated and returned instead.
    counts : dict
        The indicators of proposals and acceptances for grow and prune moves.
        Each entry is summed along its first axis and stored in the state
        under the same key.
    grow_moves, prune_moves : dict
        The proposals for the moves. See `grow_move` and `prune_move`.

    Returns
    -------
    bart : dict
        The fully updated BART mcmc state, with the summed counts and the
        leaf indices brought in line with the accepted moves.
    """
    updated = bart.copy()

    # collapse each per-move indicator array into a total along axis 0
    updated.update({
        name: jnp.sum(flags, axis=0)
        for name, flags in counts.items()
    })

    # the indices are read from the state after the counts were written,
    # and the unreduced indicators are the ones passed along
    updated['leaf_indices'] = apply_moves_to_indices(
        updated['leaf_indices'],
        counts,
        grow_moves,
        prune_moves,
    )

    return updated
1438
+
1439
def apply_moves_to_indices(leaf_indices, counts, grow_moves, prune_moves):
    """
    Bring the leaf indices in line with the moves that were accepted.

    Parameters
    ----------
    leaf_indices : int array (num_trees, n)
        The index of the leaf each datapoint falls into, if the grow move was
        accepted.
    counts : dict
        The indicators of proposals and acceptances for grow and prune moves.
        Only the entries 'grow_acc_count' and 'prune_acc_count' are read here,
        one value per tree.
    grow_moves, prune_moves : dict
        The proposals for the moves. See `grow_move` and `prune_move`. Only
        the entries 'left' and 'node' are read here, one value per tree.

    Returns
    -------
    leaf_indices : int array (num_trees, n)
        The updated leaf indices, with the same dtype as the input.
    """
    index_dtype = leaf_indices.dtype

    def per_tree(values):
        # add a trailing axis so a per-tree value broadcasts over datapoints
        return values[:, None]

    # ...1111111110: clearing the lowest bit maps a right child index onto
    # its left sibling, so both children compare equal to `left`
    clear_lowest_bit = ~jnp.array(1, index_dtype)

    # grow: the incoming indices already point into the proposed children;
    # where the grow move was not accepted, send them back to the parent
    in_grown_children = (leaf_indices & clear_lowest_bit) == per_tree(grow_moves['left'])
    grow_rejected = ~per_tree(counts['grow_acc_count'])
    grow_parent = per_tree(grow_moves['node']).astype(index_dtype)
    leaf_indices = jnp.where(
        in_grown_children & grow_rejected,
        grow_parent,
        leaf_indices,
    )

    # prune: where the prune move was accepted, datapoints in either of the
    # pruned children move up to the parent node
    in_pruned_children = (leaf_indices & clear_lowest_bit) == per_tree(prune_moves['left'])
    prune_accepted = per_tree(counts['prune_acc_count'])
    prune_parent = per_tree(prune_moves['node']).astype(leaf_indices.dtype)
    return jnp.where(
        in_pruned_children & prune_accepted,
        prune_parent,
        leaf_indices,
    )
1081
1471
 
1082
1472
  def sample_sigma(bart, key):
1083
1473
  """
@@ -1099,7 +1489,7 @@ def sample_sigma(bart, key):
1099
1489
 
1100
1490
  resid = bart['resid']
1101
1491
  alpha = bart['sigma2_alpha'] + resid.size / 2
1102
- norm2 = jnp.dot(resid, resid, preferred_element_type=bart['sigma2_beta'].dtype)
1492
+ norm2 = jnp.dot(resid, resid, preferred_element_type=bart['opt']['large_float'])
1103
1493
  beta = bart['sigma2_beta'] + norm2 / 2
1104
1494
 
1105
1495
  sample = random.gamma(key, alpha)