bartz 0.2.1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/mcmcstep.py CHANGED
@@ -10,10 +10,10 @@
10
10
  # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
11
  # copies of the Software, and to permit persons to whom the Software is
12
12
  # furnished to do so, subject to the following conditions:
13
- #
13
+ #
14
14
  # The above copyright notice and this permission notice shall be included in all
15
15
  # copies or substantial portions of the Software.
16
- #
16
+ #
17
17
  # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
18
  # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
19
  # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
@@ -34,6 +34,7 @@ range of possible values.
34
34
  """
35
35
 
36
36
  import functools
37
+ import math
37
38
 
38
39
  import jax
39
40
  from jax import random
@@ -54,7 +55,9 @@ def init(*,
54
55
  small_float=jnp.float32,
55
56
  large_float=jnp.float32,
56
57
  min_points_per_leaf=None,
57
- suffstat_batch_size='auto',
58
+ resid_batch_size='auto',
59
+ count_batch_size='auto',
60
+ save_ratios=False,
58
61
  ):
59
62
  """
60
63
  Make a BART posterior sampling MCMC initial state.
@@ -82,10 +85,13 @@ def init(*,
82
85
  The dtype for scalars, small arrays, and arrays which require accuracy.
83
86
  min_points_per_leaf : int, optional
84
87
  The minimum number of data points in a leaf node. 0 if not specified.
85
- suffstat_batch_size : int, None, str, default 'auto'
86
- The batch size for computing sufficient statistics. `None` for no
87
- batching. If 'auto', pick a value based on the device of `y`, or the
88
- default device.
88
+ resid_batch_size, count_batch_size : int, None, str, default 'auto'
89
+ The batch sizes, along datapoints, for summing the residuals and
90
+ counting the number of datapoints in each leaf. `None` for no batching.
91
+ If 'auto', pick a value based on the device of `y`, or the default
92
+ device.
93
+ save_ratios : bool, default False
94
+ Whether to save the Metropolis-Hastings ratios.
89
95
 
90
96
  Returns
91
97
  -------
@@ -111,6 +117,8 @@ def init(*,
111
117
  'p_nonterminal' : large_float array (d,)
112
118
  The probability of a nonterminal node at each depth, padded with a
113
119
  zero.
120
+ 'p_propose_grow' : large_float array (2 ** (d - 1),)
121
+ The unnormalized probability of picking a leaf for a grow proposal.
114
122
  'sigma2_alpha' : large_float
115
123
  The shape parameter of the inverse gamma prior on the noise variance.
116
124
  'sigma2_beta' : large_float
@@ -121,6 +129,8 @@ def init(*,
121
129
  The response.
122
130
  'X' : int array (p, n)
123
131
  The predictors.
132
+ 'leaf_indices' : int array (num_trees, n)
133
+ The index of the leaf each datapoint falls into, for each tree.
124
134
  'min_points_per_leaf' : int or None
125
135
  The minimum number of data points in a leaf node.
126
136
  'affluence_trees' : bool array (num_trees, 2 ** (d - 1)) or None
@@ -129,8 +139,6 @@ def init(*,
129
139
  'opt' : LeafDict
130
140
  A dictionary with config values:
131
141
 
132
- 'suffstat_batch_size' : int or None
133
- The batch size for computing sufficient statistics.
134
142
  'small_float' : dtype
135
143
  The dtype for large arrays used in the algorithm.
136
144
  'large_float' : dtype
@@ -138,6 +146,8 @@ def init(*,
138
146
  accuracy.
139
147
  'require_min_points' : bool
140
148
  Whether the `min_points_per_leaf` parameter is specified.
149
+ 'resid_batch_size', 'count_batch_size' : int or None
150
+ The data batch sizes for computing the sufficient statistics.
141
151
  """
142
152
 
143
153
  p_nonterminal = jnp.asarray(p_nonterminal, large_float)
@@ -151,24 +161,28 @@ def init(*,
151
161
  small_float = jnp.dtype(small_float)
152
162
  large_float = jnp.dtype(large_float)
153
163
  y = jnp.asarray(y, small_float)
154
- suffstat_batch_size = _choose_suffstat_batch_size(suffstat_batch_size, y)
164
+ resid_batch_size, count_batch_size = _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y)
165
+ sigma2 = jnp.array(sigma2_beta / sigma2_alpha, large_float)
166
+ sigma2 = jnp.where(jnp.isfinite(sigma2) & (sigma2 > 0), sigma2, 1)
155
167
 
156
168
  bart = dict(
157
169
  leaf_trees=make_forest(max_depth, small_float),
158
170
  var_trees=make_forest(max_depth - 1, jaxext.minimal_unsigned_dtype(X.shape[0] - 1)),
159
171
  split_trees=make_forest(max_depth - 1, max_split.dtype),
160
172
  resid=jnp.asarray(y, large_float),
161
- sigma2=jnp.ones((), large_float),
173
+ sigma2=sigma2,
162
174
  grow_prop_count=jnp.zeros((), int),
163
175
  grow_acc_count=jnp.zeros((), int),
164
176
  prune_prop_count=jnp.zeros((), int),
165
177
  prune_acc_count=jnp.zeros((), int),
166
178
  p_nonterminal=p_nonterminal,
179
+ p_propose_grow=p_nonterminal[grove.tree_depths(2 ** (max_depth - 1))],
167
180
  sigma2_alpha=jnp.asarray(sigma2_alpha, large_float),
168
181
  sigma2_beta=jnp.asarray(sigma2_beta, large_float),
169
182
  max_split=jnp.asarray(max_split),
170
183
  y=y,
171
184
  X=jnp.asarray(X),
185
+ leaf_indices=jnp.ones((num_trees, y.size), jaxext.minimal_unsigned_dtype(2 ** max_depth - 1)),
172
186
  min_points_per_leaf=(
173
187
  None if min_points_per_leaf is None else
174
188
  jnp.asarray(min_points_per_leaf)
@@ -178,37 +192,61 @@ def init(*,
178
192
  make_forest(max_depth - 1, bool).at[:, 1].set(y.size >= 2 * min_points_per_leaf)
179
193
  ),
180
194
  opt=jaxext.LeafDict(
181
- suffstat_batch_size=suffstat_batch_size,
182
195
  small_float=small_float,
183
196
  large_float=large_float,
184
197
  require_min_points=min_points_per_leaf is not None,
198
+ resid_batch_size=resid_batch_size,
199
+ count_batch_size=count_batch_size,
185
200
  ),
186
201
  )
187
202
 
203
+ if save_ratios:
204
+ bart['ratios'] = dict(
205
+ grow=dict(
206
+ trans_prior=jnp.full(num_trees, jnp.nan),
207
+ likelihood=jnp.full(num_trees, jnp.nan),
208
+ ),
209
+ prune=dict(
210
+ trans_prior=jnp.full(num_trees, jnp.nan),
211
+ likelihood=jnp.full(num_trees, jnp.nan),
212
+ ),
213
+ )
214
+
188
215
  return bart
189
216
 
190
- def _choose_suffstat_batch_size(size, y):
191
- if size == 'auto':
217
+ def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y):
218
+
219
+ @functools.cache
220
+ def get_platform():
192
221
  try:
193
222
  device = y.devices().pop()
194
223
  except jax.errors.ConcretizationTypeError:
195
224
  device = jax.devices()[0]
196
225
  platform = device.platform
226
+ if platform not in ('cpu', 'gpu'):
227
+ raise KeyError(f'Unknown platform: {platform}')
228
+ return platform
197
229
 
230
+ if resid_batch_size == 'auto':
231
+ platform = get_platform()
232
+ n = max(1, y.size)
198
233
  if platform == 'cpu':
199
- return None
200
- # maybe I should batch residuals (not counts) for numerical
201
- # accuracy, even if it's slower
234
+ resid_batch_size = 2 ** int(round(math.log2(n / 6))) # n/6
202
235
  elif platform == 'gpu':
203
- return 128 # 128 is good on A100, and V100 at high n
204
- # 512 is good on T4, and V100 at low n
205
- else:
206
- raise KeyError(f'Unknown platform: {platform}')
207
-
208
- elif size is not None:
209
- return int(size)
210
-
211
- return size
236
+ resid_batch_size = 2 ** int(round((1 + math.log2(n)) / 3)) # n^1/3
237
+ resid_batch_size = max(1, resid_batch_size)
238
+
239
+ if count_batch_size == 'auto':
240
+ platform = get_platform()
241
+ if platform == 'cpu':
242
+ count_batch_size = None
243
+ elif platform == 'gpu':
244
+ n = max(1, y.size)
245
+ count_batch_size = 2 ** int(round(math.log2(n) / 2 - 2)) # n^1/2
246
+ # /4 is good on V100, /2 on L4/T4, still haven't tried A100
247
+ count_batch_size = max(1, count_batch_size)
248
+
249
+ return resid_batch_size, count_batch_size
212
250
 
213
251
  def step(bart, key):
214
252
  """
@@ -248,14 +286,11 @@ def sample_trees(bart, key):
248
286
 
249
287
  Notes
250
288
  -----
251
- This function zeroes the proposal counters before using them.
289
+ This function zeroes the proposal counters.
252
290
  """
253
- bart = bart.copy()
254
291
  key, subkey = random.split(key)
255
292
  grow_moves, prune_moves = sample_moves(bart, subkey)
256
- bart['var_trees'] = grow_moves['var_tree']
257
- grow_leaf_indices = grove.traverse_forest(bart['X'], grow_moves['var_tree'], grow_moves['split_tree'])
258
- return accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, grow_leaf_indices, key)
293
+ return accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, key)
259
294
 
260
295
  def sample_moves(bart, key):
261
296
  """
@@ -274,17 +309,17 @@ def sample_moves(bart, key):
274
309
  The proposals for grow and prune moves. See `grow_move` and `prune_move`.
275
310
  """
276
311
  key = random.split(key, bart['var_trees'].shape[0])
277
- return sample_moves_vmap_trees(bart['var_trees'], bart['split_trees'], bart['affluence_trees'], bart['max_split'], bart['p_nonterminal'], key)
312
+ return _sample_moves_vmap_trees(bart['var_trees'], bart['split_trees'], bart['affluence_trees'], bart['max_split'], bart['p_nonterminal'], bart['p_propose_grow'], key)
278
313
 
279
- @functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, 0, None, None, 0))
280
- def sample_moves_vmap_trees(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, key):
314
+ @functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, 0, None, None, None, 0))
315
+ def _sample_moves_vmap_trees(*args):
316
+ args, key = args[:-1], args[-1]
281
317
  key, key1 = random.split(key)
282
- args = var_tree, split_tree, affluence_tree, max_split, p_nonterminal
283
318
  grow = grow_move(*args, key)
284
319
  prune = prune_move(*args, key1)
285
320
  return grow, prune
286
321
 
287
- def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, key):
322
+ def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow, key):
288
323
  """
289
324
  Tree structure grow move proposal of BART MCMC.
290
325
 
@@ -304,6 +339,8 @@ def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, ke
304
339
  The maximum split index for each variable.
305
340
  p_nonterminal : array (d,)
306
341
  The probability of a nonterminal node at each depth.
342
+ p_propose_grow : array (2 ** (d - 1),)
343
+ The unnormalized probability of choosing a leaf to grow.
307
344
  key : jax.dtypes.prng_key array
308
345
  A jax random key.
309
346
 
@@ -312,41 +349,49 @@ def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, ke
312
349
  grow_move : dict
313
350
  A dictionary with fields:
314
351
 
315
- 'allowed' : bool
316
- Whether the move is possible.
352
+ 'num_growable' : int
353
+ The number of growable leaves.
317
354
  'node' : int
318
- The index of the leaf to grow.
319
- 'var_tree' : array (2 ** (d - 1),)
320
- The new decision axes of the tree.
321
- 'split_tree' : array (2 ** (d - 1),)
322
- The new decision boundaries of the tree.
355
+ The index of the leaf to grow. ``2 ** d`` if there are no growable
356
+ leaves.
357
+ 'left', 'right' : int
358
+ The indices of the children of 'node'.
359
+ 'var', 'split' : int
360
+ The decision axis and boundary of the new rule.
323
361
  'partial_ratio' : float
324
362
  A factor of the Metropolis-Hastings ratio of the move. It lacks
325
363
  the likelihood ratio and the probability of proposing the prune
326
364
  move.
365
+ 'var_tree', 'split_tree' : array (2 ** (d - 1),)
366
+ The updated decision axes and boundaries of the tree.
327
367
  """
328
368
 
329
369
  key, key1, key2 = random.split(key, 3)
330
-
331
- leaf_to_grow, num_growable, num_prunable, allowed = choose_leaf(split_tree, affluence_tree, key)
370
+
371
+ leaf_to_grow, num_growable, prob_choose, num_prunable = choose_leaf(split_tree, affluence_tree, p_propose_grow, key)
332
372
 
333
373
  var = choose_variable(var_tree, split_tree, max_split, leaf_to_grow, key1)
334
374
  var_tree = var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype))
335
-
375
+
336
376
  split = choose_split(var_tree, split_tree, max_split, leaf_to_grow, key2)
337
377
  split_tree = split_tree.at[leaf_to_grow].set(split.astype(split_tree.dtype))
338
378
 
339
- ratio = compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_grow, split_tree)
379
+ ratio = compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow, split_tree)
340
380
 
381
+ left = leaf_to_grow << 1
341
382
  return dict(
342
- allowed=allowed,
383
+ num_growable=num_growable,
343
384
  node=leaf_to_grow,
385
+ left=left,
386
+ right=left + 1,
387
+ var=var,
388
+ split=split,
344
389
  partial_ratio=ratio,
345
390
  var_tree=var_tree,
346
391
  split_tree=split_tree,
347
392
  )
348
393
 
349
- def choose_leaf(split_tree, affluence_tree, key):
394
+ def choose_leaf(split_tree, affluence_tree, p_propose_grow, key):
350
395
  """
351
396
  Choose a leaf node to grow in a tree.
352
397
 
@@ -356,6 +401,8 @@ def choose_leaf(split_tree, affluence_tree, key):
356
401
  The splitting points of the tree.
357
402
  affluence_tree : bool array (2 ** (d - 1),) or None
358
403
  Whether a leaf has enough points to be grown.
404
+ p_propose_grow : array (2 ** (d - 1),)
405
+ The unnormalized probability of choosing a leaf to grow.
359
406
  key : jax.dtypes.prng_key array
360
407
  A jax random key.
361
408
 
@@ -366,19 +413,21 @@ def choose_leaf(split_tree, affluence_tree, key):
366
413
  ``2 ** d``.
367
414
  num_growable : int
368
415
  The number of leaf nodes that can be grown.
416
+ prob_choose : float
417
+ The normalized probability of choosing the selected leaf.
369
418
  num_prunable : int
370
419
  The number of leaf parents that could be pruned, after converting the
371
420
  selected leaf to a non-terminal node.
372
- allowed : bool
373
- Whether the grow move is allowed.
374
421
  """
375
- is_growable, allowed = growable_leaves(split_tree, affluence_tree)
376
- leaf_to_grow = randint_masked(key, is_growable)
377
- leaf_to_grow = jnp.where(allowed, leaf_to_grow, 2 * split_tree.size)
422
+ is_growable = growable_leaves(split_tree, affluence_tree)
378
423
  num_growable = jnp.count_nonzero(is_growable)
424
+ distr = jnp.where(is_growable, p_propose_grow, 0)
425
+ leaf_to_grow, distr_norm = categorical(key, distr)
426
+ leaf_to_grow = jnp.where(num_growable, leaf_to_grow, 2 * split_tree.size)
427
+ prob_choose = distr[leaf_to_grow] / distr_norm
379
428
  is_parent = grove.is_leaves_parent(split_tree.at[leaf_to_grow].set(1))
380
429
  num_prunable = jnp.count_nonzero(is_parent)
381
- return leaf_to_grow, num_growable, num_prunable, allowed
430
+ return leaf_to_grow, num_growable, prob_choose, num_prunable
382
431
 
383
432
  def growable_leaves(split_tree, affluence_tree):
384
433
  """
@@ -397,34 +446,32 @@ def growable_leaves(split_tree, affluence_tree):
397
446
  The mask indicating the leaf nodes that can be proposed to grow, i.e.,
398
447
  that are not at the bottom level and have at least two times the number
399
448
  of minimum points per leaf.
400
- allowed : bool
401
- Whether the grow move is allowed, i.e., there are growable leaves.
402
449
  """
403
450
  is_growable = grove.is_actual_leaf(split_tree)
404
451
  if affluence_tree is not None:
405
452
  is_growable &= affluence_tree
406
- return is_growable, jnp.any(is_growable)
453
+ return is_growable
407
454
 
408
- def randint_masked(key, mask):
455
+ def categorical(key, distr):
409
456
  """
410
- Return a random integer in a range, including only some values.
457
+ Return a random integer from an arbitrary distribution.
411
458
 
412
459
  Parameters
413
460
  ----------
414
461
  key : jax.dtypes.prng_key array
415
462
  A jax random key.
416
- mask : bool array (n,)
417
- The mask indicating the allowed values.
463
+ distr : float array (n,)
464
+ An unnormalized probability distribution.
418
465
 
419
466
  Returns
420
467
  -------
421
468
  u : int
422
- A random integer in the range ``[0, n)``, and which satisfies
423
- ``mask[u] == True``. If all values in the mask are `False`, return `n`.
469
+ A random integer in the range ``[0, n)``. If all probabilities are zero,
470
+ return ``n``.
424
471
  """
425
- ecdf = jnp.cumsum(mask)
426
- u = random.randint(key, (), 0, ecdf[-1])
427
- return jnp.searchsorted(ecdf, u, 'right')
472
+ ecdf = jnp.cumsum(distr)
473
+ u = random.uniform(key, (), ecdf.dtype, 0, ecdf[-1])
474
+ return jnp.searchsorted(ecdf, u, 'right'), ecdf[-1]
428
475
 
429
476
  def choose_variable(var_tree, split_tree, max_split, leaf_index, key):
430
477
  """
@@ -479,7 +526,7 @@ def fully_used_variables(var_tree, split_tree, max_split, leaf_index):
479
526
  filled with `p`. The fill values are not guaranteed to be placed in any
480
527
  particular order. Variables may appear more than once.
481
528
  """
482
-
529
+
483
530
  var_to_ignore = ancestor_variables(var_tree, max_split, leaf_index)
484
531
  split_range_vec = jax.vmap(split_range, in_axes=(None, None, None, None, 0))
485
532
  l, r = split_range_vec(var_tree, split_tree, max_split, leaf_index, var_to_ignore)
@@ -611,7 +658,7 @@ def choose_split(var_tree, split_tree, max_split, leaf_index, key):
611
658
  l, r = split_range(var_tree, split_tree, max_split, leaf_index, var)
612
659
  return random.randint(key, (), l, r)
613
660
 
614
- def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_grow, new_split_tree):
661
+ def compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow, new_split_tree):
615
662
  """
616
663
  Compute the product of the transition and prior ratios of a grow move.
617
664
 
@@ -640,6 +687,9 @@ def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_gro
640
687
  # the two ratios also contain factors num_available_split *
641
688
  # num_available_var, but they cancel out
642
689
 
690
+ # p_prune can't be computed here because it needs the count trees, which are
691
+ # computed in the acceptance phase
692
+
643
693
  prune_allowed = leaf_to_grow != 1
644
694
  # prune allowed <---> the initial tree is not a root
645
695
  # leaf to grow is root --> the tree can only be a root
@@ -647,31 +697,33 @@ def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_gro
647
697
 
648
698
  p_grow = jnp.where(prune_allowed, 0.5, 1)
649
699
 
650
- trans_ratio = num_growable / (p_grow * num_prunable)
700
+ inv_trans_ratio = p_grow * prob_choose * num_prunable
651
701
 
652
702
  depth = grove.tree_depths(new_split_tree.size)[leaf_to_grow]
653
703
  p_parent = p_nonterminal[depth]
654
704
  cp_children = 1 - p_nonterminal[depth + 1]
655
705
  tree_ratio = cp_children * cp_children * p_parent / (1 - p_parent)
656
706
 
657
- return trans_ratio * tree_ratio
707
+ return tree_ratio / inv_trans_ratio
658
708
 
659
- def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, key):
709
+ def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow, key):
660
710
  """
661
711
  Tree structure prune move proposal of BART MCMC.
662
712
 
663
713
  Parameters
664
714
  ----------
665
- var_tree : array (2 ** (d - 1),)
715
+ var_tree : int array (2 ** (d - 1),)
666
716
  The variable indices of the tree.
667
- split_tree : array (2 ** (d - 1),)
717
+ split_tree : int array (2 ** (d - 1),)
668
718
  The splitting points of the tree.
669
719
  affluence_tree : bool array (2 ** (d - 1),) or None
670
720
  Whether a leaf has enough points to be grown.
671
- max_split : array (p,)
721
+ max_split : int array (p,)
672
722
  The maximum split index for each variable.
673
- p_nonterminal : array (d,)
723
+ p_nonterminal : float array (d,)
674
724
  The probability of a nonterminal node at each depth.
725
+ p_propose_grow : float array (2 ** (d - 1),)
726
+ The unnormalized probability of choosing a leaf to grow.
675
727
  key : jax.dtypes.prng_key array
676
728
  A jax random key.
677
729
 
@@ -683,24 +735,29 @@ def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, k
683
735
  'allowed' : bool
684
736
  Whether the move is possible.
685
737
  'node' : int
686
- The index of the node to prune.
738
+ The index of the node to prune. ``2 ** d`` if no node can be pruned.
739
+ 'left', 'right' : int
740
+ The indices of the children of 'node'.
687
741
  'partial_ratio' : float
688
742
  A factor of the Metropolis-Hastings ratio of the move. It lacks
689
743
  the likelihood ratio and the probability of proposing the prune
690
744
  move. This ratio is inverted.
691
745
  """
692
- node_to_prune, num_prunable, num_growable = choose_leaf_parent(split_tree, affluence_tree, key)
746
+ node_to_prune, num_prunable, prob_choose = choose_leaf_parent(split_tree, affluence_tree, p_propose_grow, key)
693
747
  allowed = split_tree[1].astype(bool) # allowed iff the tree is not a root
694
748
 
695
- ratio = compute_partial_ratio(num_growable, num_prunable, p_nonterminal, node_to_prune, split_tree)
749
+ ratio = compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, node_to_prune, split_tree)
696
750
 
751
+ left = node_to_prune << 1
697
752
  return dict(
698
753
  allowed=allowed,
699
754
  node=node_to_prune,
755
+ left=left,
756
+ right=left + 1,
700
757
  partial_ratio=ratio, # it is inverted in accept_move_and_sample_leaves
701
758
  )
702
759
 
703
- def choose_leaf_parent(split_tree, affluence_tree, key):
760
+ def choose_leaf_parent(split_tree, affluence_tree, p_propose_grow, key):
704
761
  """
705
762
  Pick a non-terminal node with leaf children to prune in a tree.
706
763
 
@@ -710,6 +767,8 @@ def choose_leaf_parent(split_tree, affluence_tree, key):
710
767
  The splitting points of the tree.
711
768
  affluence_tree : bool array (2 ** (d - 1),) or None
712
769
  Whether a leaf has enough points to be grown.
770
+ p_propose_grow : array (2 ** (d - 1),)
771
+ The unnormalized probability of choosing a leaf to grow.
713
772
  key : jax.dtypes.prng_key array
714
773
  A jax random key.
715
774
 
@@ -717,28 +776,50 @@ def choose_leaf_parent(split_tree, affluence_tree, key):
717
776
  -------
718
777
  node_to_prune : int
719
778
  The index of the node to prune. If ``num_prunable == 0``, return
720
- ``split_tree.size``.
779
+ ``2 ** d``.
721
780
  num_prunable : int
722
781
  The number of leaf parents that could be pruned.
723
- num_growable : int
724
- The number of leaf nodes that can be grown, after pruning the chosen
725
- node.
782
+ prob_choose : float
783
+ The normalized probability that a grow move picks the pruned node.
726
784
  """
727
785
  is_prunable = grove.is_leaves_parent(split_tree)
728
- node_to_prune = randint_masked(key, is_prunable)
729
786
  num_prunable = jnp.count_nonzero(is_prunable)
787
+ node_to_prune = randint_masked(key, is_prunable)
788
+ node_to_prune = jnp.where(num_prunable, node_to_prune, 2 * split_tree.size)
730
789
 
731
- pruned_split_tree = split_tree.at[node_to_prune].set(0)
732
- pruned_affluence_tree = (
790
+ split_tree = split_tree.at[node_to_prune].set(0)
791
+ affluence_tree = (
733
792
  None if affluence_tree is None else
734
793
  affluence_tree.at[node_to_prune].set(True)
735
794
  )
736
- is_growable_leaf, _ = growable_leaves(pruned_split_tree, pruned_affluence_tree)
737
- num_growable = jnp.count_nonzero(is_growable_leaf)
795
+ is_growable_leaf = growable_leaves(split_tree, affluence_tree)
796
+ prob_choose = p_propose_grow[node_to_prune]
797
+ prob_choose /= jnp.sum(p_propose_grow, where=is_growable_leaf)
798
+
799
+ return node_to_prune, num_prunable, prob_choose
800
+
801
+ def randint_masked(key, mask):
802
+ """
803
+ Return a random integer in a range, including only some values.
804
+
805
+ Parameters
806
+ ----------
807
+ key : jax.dtypes.prng_key array
808
+ A jax random key.
809
+ mask : bool array (n,)
810
+ The mask indicating the allowed values.
738
811
 
739
- return node_to_prune, num_prunable, num_growable
812
+ Returns
813
+ -------
814
+ u : int
815
+ A random integer in the range ``[0, n)``, and which satisfies
816
+ ``mask[u] == True``. If all values in the mask are `False`, return `n`.
817
+ """
818
+ ecdf = jnp.cumsum(mask)
819
+ u = random.randint(key, (), 0, ecdf[-1])
820
+ return jnp.searchsorted(ecdf, u, 'right')
740
821
 
741
- def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, grow_leaf_indices, key):
822
+ def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, key):
742
823
  """
743
824
  Accept or reject the proposed moves and sample the new leaf values.
744
825
 
@@ -752,8 +833,6 @@ def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, grow_leaf_indi
752
833
  prune_moves : dict
753
834
  The proposals for prune moves, batched over the first axis. See
754
835
  `prune_move`.
755
- grow_leaf_indices : int array (num_trees, n)
756
- The leaf indices of the trees proposed by the grow move.
757
836
  key : jax.dtypes.prng_key array
758
837
  A jax random key.
759
838
 
@@ -762,41 +841,339 @@ def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, grow_leaf_indi
762
841
  bart : dict
763
842
  The new BART mcmc state.
764
843
  """
844
+ bart, grow_moves, prune_moves, count_trees, move_counts, u, z = accept_moves_parallel_stage(bart, grow_moves, prune_moves, key)
845
+ bart, counts = accept_moves_sequential_stage(bart, count_trees, grow_moves, prune_moves, move_counts, u, z)
846
+ return accept_moves_final_stage(bart, counts, grow_moves, prune_moves)
847
+
848
+ def accept_moves_parallel_stage(bart, grow_moves, prune_moves, key):
849
+ """
850
+ Pre-computes quantities used to accept moves, in parallel across trees.
851
+
852
+ Parameters
853
+ ----------
854
+ bart : dict
855
+ A BART mcmc state.
856
+ grow_moves, prune_moves : dict
857
+ The proposals for the moves, batched over the first axis. See
858
+ `grow_move` and `prune_move`.
859
+ key : jax.dtypes.prng_key array
860
+ A jax random key.
861
+
862
+ Returns
863
+ -------
864
+ bart : dict
865
+ A partially updated BART mcmc state.
866
+ grow_moves, prune_moves : dict
867
+ The proposals for the moves, with the field 'partial_ratio' replaced
868
+ by 'trans_prior_ratio'.
869
+ count_trees : array (num_trees, 2 ** d)
870
+ The number of points in each potential or actual leaf node.
871
+ move_counts : dict
872
+ The counts of the number of points in the nodes modified by the
873
+ moves.
874
+ u : float array (num_trees, 2)
875
+ Random uniform values used to accept the moves.
876
+ z : float array (num_trees, 2 ** d)
877
+ Random standard normal values used to sample the new leaf values.
878
+ """
765
879
  bart = bart.copy()
766
- def loop(carry, item):
767
- resid = carry.pop('resid')
768
- resid, carry, trees = accept_move_and_sample_leaves(
880
+
881
+ bart['var_trees'] = grow_moves['var_tree']
882
+ # Since var_tree can contain garbage, I can set the var of leaf to be
883
+ # grown irrespectively of what move I'm gonna accept in the end.
884
+
885
+ bart['leaf_indices'] = apply_grow_to_indices(grow_moves, bart['leaf_indices'], bart['X'])
886
+
887
+ count_trees, move_counts = compute_count_trees(bart['leaf_indices'], grow_moves, prune_moves, bart['opt']['count_batch_size'])
888
+
889
+ grow_moves, prune_moves = complete_ratio(grow_moves, prune_moves, move_counts, bart['min_points_per_leaf'])
890
+
891
+ if bart['opt']['require_min_points']:
892
+ count_half_trees = count_trees[:, :grow_moves['split_tree'].shape[1]]
893
+ bart['affluence_trees'] = count_half_trees >= 2 * bart['min_points_per_leaf']
894
+
895
+ bart['leaf_trees'] = adapt_leaf_trees_to_grow_indices(bart['leaf_trees'], grow_moves)
896
+
897
+ key, subkey = random.split(key)
898
+ u = random.uniform(subkey, (len(bart['leaf_trees']), 2), bart['opt']['large_float'])
899
+ z = random.normal(key, bart['leaf_trees'].shape, bart['opt']['large_float'])
900
+
901
+ return bart, grow_moves, prune_moves, count_trees, move_counts, u, z
902
+
903
+ def apply_grow_to_indices(grow_moves, leaf_indices, X):
904
+ """
905
+ Update the leaf indices to apply a grow move.
906
+
907
+ Parameters
908
+ ----------
909
+ grow_moves : dict
910
+ The proposals for grow moves. See `grow_move`.
911
+ leaf_indices : array (num_trees, n)
912
+ The index of the leaf each datapoint falls into.
913
+ X : array (p, n)
914
+ The predictors matrix.
915
+
916
+ Returns
917
+ -------
918
+ grow_leaf_indices : array (num_trees, n)
919
+ The updated leaf indices.
920
+ """
921
+ left_child = grow_moves['node'].astype(leaf_indices.dtype) << 1
922
+ go_right = X[grow_moves['var'], :] >= grow_moves['split'][:, None]
923
+ tree_size = jnp.array(2 * grow_moves['split_tree'].shape[1])
924
+ node_to_update = jnp.where(grow_moves['num_growable'], grow_moves['node'], tree_size)
925
+ return jnp.where(
926
+ leaf_indices == node_to_update[:, None],
927
+ left_child[:, None] + go_right,
928
+ leaf_indices,
929
+ )
930
+
931
+ def compute_count_trees(grow_leaf_indices, grow_moves, prune_moves, batch_size):
932
+ """
933
+ Count the number of datapoints in each leaf.
934
+
935
+ Parameters
936
+ ----------
937
+ grow_leaf_indices : int array (num_trees, n)
938
+ The index of the leaf each datapoint falls into, if the grow move is
939
+ accepted.
940
+ grow_moves, prune_moves : dict
941
+ The proposals for the moves. See `grow_move` and `prune_move`.
942
+ batch_size : int or None
943
+ The data batch size to use for the summation.
944
+
945
+ Returns
946
+ -------
947
+ count_trees : int array (num_trees, 2 ** d)
948
+ The number of points in each potential or actual leaf node.
949
+ counts : dict
950
+ The counts of the number of points in the nodes modified by the
951
+ moves, organized as two dictionaries 'grow' and 'prune', with subfields
952
+ 'left', 'right', and 'total'.
953
+ """
954
+
955
+ ntree, tree_size = grow_moves['split_tree'].shape
956
+ tree_size *= 2
957
+ counts = dict(grow=dict(), prune=dict())
958
+ tree_indices = jnp.arange(ntree)
959
+
960
+ count_trees = count_datapoints_per_leaf(grow_leaf_indices, tree_size, batch_size)
961
+
962
+ # count datapoints in leaf to grow
963
+ counts['grow']['left'] = count_trees[tree_indices, grow_moves['left']]
964
+ counts['grow']['right'] = count_trees[tree_indices, grow_moves['right']]
965
+ counts['grow']['total'] = counts['grow']['left'] + counts['grow']['right']
966
+ count_trees = count_trees.at[tree_indices, grow_moves['node']].set(counts['grow']['total'])
967
+
968
+ # count datapoints in node to prune
969
+ counts['prune']['left'] = count_trees[tree_indices, prune_moves['left']]
970
+ counts['prune']['right'] = count_trees[tree_indices, prune_moves['right']]
971
+ counts['prune']['total'] = counts['prune']['left'] + counts['prune']['right']
972
+ count_trees = count_trees.at[tree_indices, prune_moves['node']].set(counts['prune']['total'])
973
+
974
+ return count_trees, counts
975
+
976
+ def count_datapoints_per_leaf(leaf_indices, tree_size, batch_size):
977
+ """
978
+ Count the number of datapoints in each leaf.
979
+
980
+ Parameters
981
+ ----------
982
+ leaf_indices : int array (num_trees, n)
983
+ The index of the leaf each datapoint falls into.
984
+ tree_size : int
985
+ The size of the leaf tree array (2 ** d).
986
+ batch_size : int or None
987
+ The data batch size to use for the summation.
988
+
989
+ Returns
990
+ -------
991
+ count_trees : int array (num_trees, 2 ** d)
992
+ The number of points in each leaf node.
993
+ """
994
+ if batch_size is None:
995
+ return _count_scan(leaf_indices, tree_size)
996
+ else:
997
+ return _count_vec(leaf_indices, tree_size, batch_size)
998
+
999
+ def _count_scan(leaf_indices, tree_size):
1000
+ def loop(_, leaf_indices):
1001
+ return None, _aggregate_scatter(1, leaf_indices, tree_size, jnp.uint32)
1002
+ _, count_trees = lax.scan(loop, None, leaf_indices)
1003
+ return count_trees
1004
+
1005
+ def _aggregate_scatter(values, indices, size, dtype):
1006
+ return (jnp
1007
+ .zeros(size, dtype)
1008
+ .at[indices]
1009
+ .add(values)
1010
+ )
1011
+
1012
+ def _count_vec(leaf_indices, tree_size, batch_size):
1013
+ return _aggregate_batched_alltrees(1, leaf_indices, tree_size, jnp.uint32, batch_size)
1014
+ # uint16 is super-slow on gpu, don't use it even if n < 2^16
1015
+
1016
+ def _aggregate_batched_alltrees(values, indices, size, dtype, batch_size):
1017
+ ntree, n = indices.shape
1018
+ tree_indices = jnp.arange(ntree)
1019
+ nbatches = n // batch_size + bool(n % batch_size)
1020
+ batch_indices = jnp.arange(n) % nbatches
1021
+ return (jnp
1022
+ .zeros((ntree, size, nbatches), dtype)
1023
+ .at[tree_indices[:, None], indices, batch_indices]
1024
+ .add(values)
1025
+ .sum(axis=2)
1026
+ )
1027
+
1028
+ def complete_ratio(grow_moves, prune_moves, move_counts, min_points_per_leaf):
1029
+ """
1030
+ Complete non-likelihood MH ratio calculation.
1031
+
1032
+ This function adds the probability of choosing the prune move.
1033
+
1034
+ Parameters
1035
+ ----------
1036
+ grow_moves, prune_moves : dict
1037
+ The proposals for the moves. See `grow_move` and `prune_move`.
1038
+ move_counts : dict
1039
+ The counts of the number of points in the nodes modified by the
1040
+ moves.
1041
+ min_points_per_leaf : int or None
1042
+ The minimum number of data points in a leaf node.
1043
+
1044
+ Returns
1045
+ -------
1046
+ grow_moves, prune_moves : dict
1047
+ The proposals for the moves, with the field 'partial_ratio' replaced
1048
+ by 'trans_prior_ratio'.
1049
+ """
1050
+ grow_moves = grow_moves.copy()
1051
+ prune_moves = prune_moves.copy()
1052
+ compute_p_prune_vec = jax.vmap(compute_p_prune, in_axes=(0, 0, 0, None))
1053
+ grow_p_prune, prune_p_prune = compute_p_prune_vec(grow_moves, move_counts['grow']['left'], move_counts['grow']['right'], min_points_per_leaf)
1054
+ grow_moves['trans_prior_ratio'] = grow_moves.pop('partial_ratio') * grow_p_prune
1055
+ prune_moves['trans_prior_ratio'] = prune_moves.pop('partial_ratio') * prune_p_prune
1056
+ return grow_moves, prune_moves
1057
+
1058
+ def compute_p_prune(grow_move, grow_left_count, grow_right_count, min_points_per_leaf):
1059
+ """
1060
+ Compute the probability of proposing a prune move.
1061
+
1062
+ Parameters
1063
+ ----------
1064
+ grow_move : dict
1065
+ The proposal for the grow move, see `grow_move`.
1066
+ grow_left_count, grow_right_count : int
1067
+ The number of datapoints in the proposed children of the leaf to grow.
1068
+ min_points_per_leaf : int or None
1069
+ The minimum number of data points in a leaf node.
1070
+
1071
+ Returns
1072
+ -------
1073
+ grow_p_prune : float
1074
+ The probability of proposing a prune move, after accepting the grow
1075
+ move.
1076
+ prune_p_prune : float
1077
+ The probability of proposing the prune move.
1078
+ """
1079
+ other_growable_leaves = grow_move['num_growable'] >= 2
1080
+ new_leaves_growable = grow_move['node'] < grow_move['split_tree'].size // 2
1081
+ if min_points_per_leaf is not None:
1082
+ any_above_threshold = grow_left_count >= 2 * min_points_per_leaf
1083
+ any_above_threshold |= grow_right_count >= 2 * min_points_per_leaf
1084
+ new_leaves_growable &= any_above_threshold
1085
+ grow_again_allowed = other_growable_leaves | new_leaves_growable
1086
+ grow_p_prune = jnp.where(grow_again_allowed, 0.5, 1)
1087
+ prune_p_prune = jnp.where(grow_move['num_growable'], 0.5, 1)
1088
+ return grow_p_prune, prune_p_prune
1089
+
1090
+ def adapt_leaf_trees_to_grow_indices(leaf_trees, grow_moves):
1091
+ """
1092
+ Modify leaf values such that the indices of the grow move work on the
1093
+ original tree.
1094
+
1095
+ Parameters
1096
+ ----------
1097
+ leaf_trees : float array (num_trees, 2 ** d)
1098
+ The leaf values.
1099
+ grow_moves : dict
1100
+ The proposals for grow moves. See `grow_move`.
1101
+
1102
+ Returns
1103
+ -------
1104
+ leaf_trees : float array (num_trees, 2 ** d)
1105
+ The modified leaf values. The value of the leaf to grow is copied to
1106
+ what would be its children if the grow move was accepted.
1107
+ """
1108
+ ntree, _ = leaf_trees.shape
1109
+ tree_indices = jnp.arange(ntree)
1110
+ values_at_node = leaf_trees[tree_indices, grow_moves['node']]
1111
+ return (leaf_trees
1112
+ .at[tree_indices, grow_moves['left']]
1113
+ .set(values_at_node)
1114
+ .at[tree_indices, grow_moves['right']]
1115
+ .set(values_at_node)
1116
+ )
1117
+
1118
+ def accept_moves_sequential_stage(bart, count_trees, grow_moves, prune_moves, move_counts, u, z):
1119
+ """
1120
+ The part of accepting the moves that has to be done one tree at a time.
1121
+
1122
+ Parameters
1123
+ ----------
1124
+ bart : dict
1125
+ A partially updated BART mcmc state.
1126
+ count_trees : array (num_trees, 2 ** d)
1127
+ The number of points in each potential or actual leaf node.
1128
+ grow_moves, prune_moves : dict
1129
+ The proposals for the moves, with completed ratios. See `grow_move` and
1130
+ `prune_move`.
1131
+ move_counts : dict
1132
+ The counts of the number of points in the nodes modified by the
1133
+ moves.
1134
+ u : float array (num_trees, 2)
1135
+ Random uniform values used for proposal and acceptance decisions.
1136
+ z : float array (num_trees, 2 ** d)
1137
+ Random standard normal values used to sample the new leaf values.
1138
+
1139
+ Returns
1140
+ -------
1141
+ bart : dict
1142
+ A partially updated BART mcmc state.
1143
+ counts : dict
1144
+ The indicators of proposals and acceptances for grow and prune moves.
1145
+ """
1146
+ bart = bart.copy()
1147
+
1148
+ def loop(resid, item):
1149
+ resid, leaf_tree, split_tree, counts, ratios = accept_move_and_sample_leaves(
769
1150
  bart['X'],
770
1151
  len(bart['leaf_trees']),
771
- bart['opt']['suffstat_batch_size'],
1152
+ bart['opt']['resid_batch_size'],
772
1153
  resid,
773
1154
  bart['sigma2'],
774
1155
  bart['min_points_per_leaf'],
775
- carry,
1156
+ 'ratios' in bart,
776
1157
  *item,
777
1158
  )
778
- carry['resid'] = resid
779
- return carry, trees
780
- carry = {
781
- k: jnp.zeros_like(bart[k]) for k in
782
- ['grow_prop_count', 'prune_prop_count', 'grow_acc_count', 'prune_acc_count']
783
- }
784
- carry['resid'] = bart['resid']
1159
+ return resid, (leaf_tree, split_tree, counts, ratios)
1160
+
785
1161
  items = (
786
- bart['leaf_trees'],
787
- bart['split_trees'],
788
- bart['affluence_trees'],
789
- grow_moves,
790
- prune_moves,
791
- grow_leaf_indices,
792
- random.split(key, len(bart['leaf_trees'])),
1162
+ bart['leaf_trees'], count_trees,
1163
+ grow_moves, prune_moves, move_counts,
1164
+ bart['leaf_indices'],
1165
+ u, z,
793
1166
  )
794
- carry, trees = lax.scan(loop, carry, items)
795
- bart.update(carry)
796
- bart.update(trees)
797
- return bart
1167
+ resid, (leaf_trees, split_trees, counts, ratios) = lax.scan(loop, bart['resid'], items)
1168
+
1169
+ bart['resid'] = resid
1170
+ bart['leaf_trees'] = leaf_trees
1171
+ bart['split_trees'] = split_trees
1172
+ bart.get('ratios', {}).update(ratios)
798
1173
 
799
- def accept_move_and_sample_leaves(X, ntree, suffstat_batch_size, resid, sigma2, min_points_per_leaf, counts, leaf_tree, split_tree, affluence_tree, grow_move, prune_move, grow_leaf_indices, key):
1174
+ return bart, counts
1175
+
1176
+ def accept_move_and_sample_leaves(X, ntree, resid_batch_size, resid, sigma2, min_points_per_leaf, save_ratios, leaf_tree, count_tree, grow_move, prune_move, move_counts, grow_leaf_indices, u, z):
800
1177
  """
801
1178
  Accept or reject a proposed move and sample the new leaf values.
802
1179
 
@@ -806,158 +1183,157 @@ def accept_move_and_sample_leaves(X, ntree, suffstat_batch_size, resid, sigma2,
806
1183
  The predictors.
807
1184
  ntree : int
808
1185
  The number of trees in the forest.
809
- suffstat_batch_size : int, None
810
- The batch size for computing sufficient statistics.
1186
+ resid_batch_size : int, None
1187
+ The batch size for computing the sum of residuals in each leaf.
811
1188
  resid : float array (n,)
812
1189
  The residuals (data minus forest value).
813
1190
  sigma2 : float
814
1191
  The noise variance.
815
1192
  min_points_per_leaf : int or None
816
1193
  The minimum number of data points in a leaf node.
817
- counts : dict
818
- The acceptance counts from the mcmc state dict.
1194
+ save_ratios : bool
1195
+ Whether to save the acceptance ratios.
819
1196
  leaf_tree : float array (2 ** d,)
820
1197
  The leaf values of the tree.
821
- split_tree : int array (2 ** (d - 1),)
822
- The decision boundaries of the tree.
823
- affluence_tree : bool array (2 ** (d - 1),) or None
824
- Whether a leaf has enough points to be grown.
825
- grow_move : dict
826
- The proposal for the grow move. See `grow_move`.
827
- prune_move : dict
828
- The proposal for the prune move. See `prune_move`.
1198
+ count_tree : int array (2 ** d,)
1199
+ The number of datapoints in each leaf.
1200
+ grow_move, prune_move : dict
1201
+ The proposals for the moves, with completed ratios. See `grow_move` and
1202
+ `prune_move`.
829
1203
  grow_leaf_indices : int array (n,)
830
1204
  The leaf indices of the tree proposed by the grow move.
831
- key : jax.dtypes.prng_key array
832
- A jax random key.
1205
+ u : float array (2,)
1206
+ Two uniform random values in [0, 1).
1207
+ z : float array (2 ** d,)
1208
+ Standard normal random values.
833
1209
 
834
1210
  Returns
835
1211
  -------
836
1212
  resid : float array (n,)
837
1213
  The updated residuals (data minus forest value).
1214
+ leaf_tree : float array (2 ** d,)
1215
+ The new leaf values of the tree.
1216
+ split_tree : int array (2 ** (d - 1),)
1217
+ The updated decision boundaries of the tree.
838
1218
  counts : dict
839
- The updated acceptance counts.
840
- trees : dict
841
- The updated tree arrays.
1219
+ The indicators of proposals and acceptances for grow and prune moves.
1220
+ ratios : dict
1221
+ The acceptance ratios for the moves. Empty if not to be saved.
842
1222
  """
843
-
844
- # compute leaf indices in starting tree
845
- grow_node = grow_move['node']
846
- grow_left = grow_node << 1
847
- grow_right = grow_left + 1
848
- leaf_indices = jnp.where(
849
- (grow_leaf_indices == grow_left) | (grow_leaf_indices == grow_right),
850
- grow_node,
851
- grow_leaf_indices,
852
- )
853
1223
 
854
- # compute leaf indices in prune tree
855
- prune_node = prune_move['node']
856
- prune_left = prune_node << 1
857
- prune_right = prune_left + 1
858
- prune_leaf_indices = jnp.where(
859
- (leaf_indices == prune_left) | (leaf_indices == prune_right),
860
- prune_node,
861
- leaf_indices,
862
- )
1224
+ # sum residuals per leaf, in tree proposed by grow move
1225
+ resid_tree = sum_resid(resid, grow_leaf_indices, leaf_tree.size, resid_batch_size)
863
1226
 
864
1227
  # subtract starting tree from function
865
- resid += leaf_tree[leaf_indices]
1228
+ resid_tree += count_tree * leaf_tree
866
1229
 
867
- # aggregate residuals and count units per leaf
868
- grow_resid_tree, grow_count_tree = sufficient_stat(resid, grow_leaf_indices, leaf_tree.size, suffstat_batch_size)
1230
+ # get indices of grow move
1231
+ grow_node = grow_move['node']
1232
+ assert grow_node.dtype == jnp.int32
1233
+ grow_left = grow_move['left']
1234
+ grow_right = grow_move['right']
869
1235
 
870
- # compute aggregations in starting tree
871
- # I do not zero the children because garbage there does not matter
872
- resid_tree = (grow_resid_tree.at[grow_node]
873
- .set(grow_resid_tree[grow_left] + grow_resid_tree[grow_right]))
874
- count_tree = (grow_count_tree.at[grow_node]
875
- .set(grow_count_tree[grow_left] + grow_count_tree[grow_right]))
1236
+ # sum residuals in leaf to grow
1237
+ grow_resid_left = resid_tree[grow_left]
1238
+ grow_resid_right = resid_tree[grow_right]
1239
+ grow_resid_total = grow_resid_left + grow_resid_right
1240
+ resid_tree = resid_tree.at[grow_node].set(grow_resid_total)
876
1241
 
877
- # compute aggregations in prune tree
878
- prune_resid_tree = (resid_tree.at[prune_node]
879
- .set(resid_tree[prune_left] + resid_tree[prune_right]))
880
- prune_count_tree = (count_tree.at[prune_node]
881
- .set(count_tree[prune_left] + count_tree[prune_right]))
1242
+ # get indices of prune move
1243
+ prune_node = prune_move['node']
1244
+ assert prune_node.dtype == jnp.int32
1245
+ prune_left = prune_move['left']
1246
+ prune_right = prune_move['right']
882
1247
 
883
- # compute affluence trees
884
- if min_points_per_leaf is not None:
885
- grow_affluence_tree = grow_count_tree[:grow_count_tree.size // 2] >= 2 * min_points_per_leaf
886
- prune_affluence_tree = affluence_tree.at[prune_node].set(True)
1248
+ # sum residuals in node to prune
1249
+ prune_resid_left = resid_tree[prune_left]
1250
+ prune_resid_right = resid_tree[prune_right]
1251
+ prune_resid_total = prune_resid_left + prune_resid_right
1252
+ resid_tree = resid_tree.at[prune_node].set(prune_resid_total)
887
1253
 
888
- # compute probability of proposing prune
889
- grow_p_prune = compute_p_prune_back(grow_move['split_tree'], grow_affluence_tree)
890
- prune_p_prune = compute_p_prune_back(split_tree, affluence_tree)
1254
+ # Now resid_tree and count_tree contain correct values whatever move is
1255
+ # accepted.
891
1256
 
892
1257
  # compute likelihood ratios
893
- grow_lk_ratio = compute_likelihood_ratio(grow_resid_tree, grow_count_tree, sigma2, grow_node, ntree, min_points_per_leaf)
894
- prune_lk_ratio = compute_likelihood_ratio(resid_tree, count_tree, sigma2, prune_node, ntree, min_points_per_leaf)
1258
+ grow_lk_ratio = compute_likelihood_ratio(grow_resid_total, grow_resid_left, grow_resid_right, move_counts['grow']['total'], move_counts['grow']['left'], move_counts['grow']['right'], sigma2, ntree)
1259
+ prune_lk_ratio = compute_likelihood_ratio(prune_resid_total, prune_resid_left, prune_resid_right, move_counts['prune']['total'], move_counts['prune']['left'], move_counts['prune']['right'], sigma2, ntree)
895
1260
 
896
1261
  # compute acceptance ratios
897
- grow_ratio = grow_p_prune * grow_move['partial_ratio'] * grow_lk_ratio
898
- prune_ratio = prune_p_prune * prune_move['partial_ratio'] * prune_lk_ratio
1262
+ grow_ratio = grow_move['trans_prior_ratio'] * grow_lk_ratio
1263
+ if min_points_per_leaf is not None:
1264
+ grow_ratio = jnp.where(move_counts['grow']['left'] >= min_points_per_leaf, grow_ratio, 0)
1265
+ grow_ratio = jnp.where(move_counts['grow']['right'] >= min_points_per_leaf, grow_ratio, 0)
1266
+ prune_ratio = prune_move['trans_prior_ratio'] * prune_lk_ratio
899
1267
  prune_ratio = lax.reciprocal(prune_ratio)
900
1268
 
901
- # random coins in [0, 1) for proposal and acceptance
902
- key, subkey = random.split(key)
903
- u0, u1 = random.uniform(subkey, (2,))
1269
+ # save acceptance ratios
1270
+ ratios = {}
1271
+ if save_ratios:
1272
+ ratios.update(
1273
+ grow=dict(
1274
+ trans_prior=grow_move['trans_prior_ratio'],
1275
+ likelihood=grow_lk_ratio,
1276
+ ),
1277
+ prune=dict(
1278
+ trans_prior=lax.reciprocal(prune_move['trans_prior_ratio']),
1279
+ likelihood=lax.reciprocal(prune_lk_ratio),
1280
+ ),
1281
+ )
904
1282
 
905
1283
  # determine what move to propose (not proposing anything is an option)
906
- p_grow = jnp.where(grow_move['allowed'] & prune_move['allowed'], 0.5, grow_move['allowed'])
907
- try_grow = u0 < p_grow
1284
+ grow_allowed = grow_move['num_growable'].astype(bool)
1285
+ p_grow = jnp.where(grow_allowed & prune_move['allowed'], 0.5, grow_allowed)
1286
+ try_grow = u[0] < p_grow # use < instead of <= because coins are in [0, 1)
908
1287
  try_prune = prune_move['allowed'] & ~try_grow
909
1288
 
910
1289
  # determine whether to accept the move
911
- do_grow = try_grow & (u1 < grow_ratio)
912
- do_prune = try_prune & (u1 < prune_ratio)
913
-
914
- # pick trees for chosen move
915
- trees = {}
916
- split_tree = jnp.where(do_grow, grow_move['split_tree'], split_tree)
917
- # the prune var tree is equal to the initial one, because I leave garbage values behind
918
- split_tree = split_tree.at[prune_node].set(
919
- jnp.where(do_prune, 0, split_tree[prune_node]))
920
- if min_points_per_leaf is not None:
921
- affluence_tree = jnp.where(do_grow, grow_affluence_tree, affluence_tree)
922
- affluence_tree = jnp.where(do_prune, prune_affluence_tree, affluence_tree)
923
- resid_tree = jnp.where(do_grow, grow_resid_tree, resid_tree)
924
- count_tree = jnp.where(do_grow, grow_count_tree, count_tree)
925
- resid_tree = jnp.where(do_prune, prune_resid_tree, resid_tree)
926
- count_tree = jnp.where(do_prune, prune_count_tree, count_tree)
927
-
928
- # update acceptance counts
929
- counts = counts.copy()
930
- counts['grow_prop_count'] += try_grow
931
- counts['grow_acc_count'] += do_grow
932
- counts['prune_prop_count'] += try_prune
933
- counts['prune_acc_count'] += do_prune
934
-
935
- # compute leaves posterior
936
- prec_lk = count_tree / sigma2
1290
+ do_grow = try_grow & (u[1] < grow_ratio)
1291
+ do_prune = try_prune & (u[1] < prune_ratio)
1292
+
1293
+ # pick split tree for chosen move
1294
+ split_tree = grow_move['split_tree']
1295
+ split_tree = split_tree.at[jnp.where(do_grow, split_tree.size, grow_node)].set(0)
1296
+ split_tree = split_tree.at[jnp.where(do_prune, prune_node, split_tree.size)].set(0)
1297
+ # I can leave garbage in var_tree, resid_tree, count_tree
1298
+
1299
+ # compute leaves posterior and sample leaves
1300
+ inv_sigma2 = lax.reciprocal(sigma2)
1301
+ prec_lk = count_tree * inv_sigma2
937
1302
  var_post = lax.reciprocal(prec_lk + ntree) # = 1 / (prec_lk + prec_prior)
938
- mean_post = resid_tree / sigma2 * var_post # = mean_lk * prec_lk * var_post
939
-
940
- # sample leaves
941
- z = random.normal(key, mean_post.shape, mean_post.dtype)
1303
+ mean_post = resid_tree * inv_sigma2 * var_post # = mean_lk * prec_lk * var_post
1304
+ initial_leaf_tree = leaf_tree
942
1305
  leaf_tree = mean_post + z * jnp.sqrt(var_post)
943
1306
 
944
- # add new tree to function
945
- leaf_indices = jnp.where(do_grow, grow_leaf_indices, leaf_indices)
946
- leaf_indices = jnp.where(do_prune, prune_leaf_indices, leaf_indices)
947
- resid -= leaf_tree[leaf_indices]
1307
+ # copy leaves around such that the grow leaf indices select the right leaf
1308
+ leaf_tree = (leaf_tree
1309
+ .at[jnp.where(do_prune, prune_left, leaf_tree.size)]
1310
+ .set(leaf_tree[prune_node])
1311
+ .at[jnp.where(do_prune, prune_right, leaf_tree.size)]
1312
+ .set(leaf_tree[prune_node])
1313
+ )
1314
+ leaf_tree = (leaf_tree
1315
+ .at[jnp.where(do_grow, leaf_tree.size, grow_left)]
1316
+ .set(leaf_tree[grow_node])
1317
+ .at[jnp.where(do_grow, leaf_tree.size, grow_right)]
1318
+ .set(leaf_tree[grow_node])
1319
+ )
1320
+
1321
+ # replace old tree with new tree in function values
1322
+ resid += (initial_leaf_tree - leaf_tree)[grow_leaf_indices]
948
1323
 
949
- # pack trees
950
- trees = {
951
- 'leaf_trees': leaf_tree,
952
- 'split_trees': split_tree,
953
- 'affluence_trees': affluence_tree,
954
- }
1324
+ # pack proposal and acceptance indicators
1325
+ counts = dict(
1326
+ grow_prop_count=try_grow,
1327
+ grow_acc_count=do_grow,
1328
+ prune_prop_count=try_prune,
1329
+ prune_acc_count=do_prune,
1330
+ )
955
1331
 
956
- return resid, counts, trees
1332
+ return resid, leaf_tree, split_tree, counts, ratios
957
1333
 
958
- def sufficient_stat(resid, leaf_indices, tree_size, batch_size):
1334
+ def sum_resid(resid, leaf_indices, tree_size, batch_size):
959
1335
  """
960
- Compute the sufficient statistics for the likelihood ratio of a tree move.
1336
+ Sum the residuals in each leaf.
961
1337
 
962
1338
  Parameters
963
1339
  ----------
@@ -968,104 +1344,56 @@ def sufficient_stat(resid, leaf_indices, tree_size, batch_size):
968
1344
  tree_size : int
969
1345
  The size of the tree array (2 ** d).
970
1346
  batch_size : int, None
971
- The batch size for the aggregation. Batching increases numerical
1347
+ The data batch size for the aggregation. Batching increases numerical
972
1348
  accuracy and parallelism.
973
1349
 
974
1350
  Returns
975
1351
  -------
976
1352
  resid_tree : float array (2 ** d,)
977
1353
  The sum of the residuals at data points in each leaf.
978
- count_tree : int array (2 ** d,)
979
- The number of data points in each leaf.
980
1354
  """
981
1355
  if batch_size is None:
982
1356
  aggr_func = _aggregate_scatter
983
1357
  else:
984
- aggr_func = functools.partial(_aggregate_batched, batch_size=batch_size)
985
- resid_tree = aggr_func(resid, leaf_indices, tree_size, jnp.float32)
986
- count_tree = aggr_func(1, leaf_indices, tree_size, jnp.uint32)
987
- return resid_tree, count_tree
988
-
989
- def _aggregate_scatter(values, indices, size, dtype):
990
- return (jnp
991
- .zeros(size, dtype)
992
- .at[indices]
993
- .add(values)
994
- )
1358
+ aggr_func = functools.partial(_aggregate_batched_onetree, batch_size=batch_size)
1359
+ return aggr_func(resid, leaf_indices, tree_size, jnp.float32)
995
1360
 
996
- def _aggregate_batched(values, indices, size, dtype, batch_size):
997
- nbatches = indices.size // batch_size + bool(indices.size % batch_size)
998
- batch_indices = jnp.arange(indices.size) // batch_size
1361
+ def _aggregate_batched_onetree(values, indices, size, dtype, batch_size):
1362
+ n, = indices.shape
1363
+ nbatches = n // batch_size + bool(n % batch_size)
1364
+ batch_indices = jnp.arange(n) % nbatches
999
1365
  return (jnp
1000
- .zeros((nbatches, size), dtype)
1001
- .at[batch_indices, indices]
1366
+ .zeros((size, nbatches), dtype)
1367
+ .at[indices, batch_indices]
1002
1368
  .add(values)
1003
- .sum(axis=0)
1369
+ .sum(axis=1)
1004
1370
  )
1005
1371
 
1006
- def compute_p_prune_back(new_split_tree, new_affluence_tree):
1007
- """
1008
- Compute the probability of proposing a prune move after doing a grow move.
1009
-
1010
- Parameters
1011
- ----------
1012
- new_split_tree : int array (2 ** (d - 1),)
1013
- The decision boundaries of the tree, after the grow move.
1014
- new_affluence_tree : bool array (2 ** (d - 1),)
1015
- Which leaves have enough points to be grown, after the grow move.
1016
-
1017
- Returns
1018
- -------
1019
- p_prune : float
1020
- The probability of proposing a prune move after the grow move. This is
1021
- 0.5 if grow is possible again, and 1 if it isn't. It can't be 0 because
1022
- at least the node just grown can be pruned.
1023
- """
1024
- _, grow_again_allowed = growable_leaves(new_split_tree, new_affluence_tree)
1025
- return jnp.where(grow_again_allowed, 0.5, 1)
1026
-
1027
- def compute_likelihood_ratio(resid_tree, count_tree, sigma2, node, n_tree, min_points_per_leaf):
1372
+ def compute_likelihood_ratio(total_resid, left_resid, right_resid, total_count, left_count, right_count, sigma2, n_tree):
1028
1373
  """
1029
1374
  Compute the likelihood ratio of a grow move.
1030
1375
 
1031
1376
  Parameters
1032
1377
  ----------
1033
- resid_tree : float array (2 ** d,)
1034
- The sum of the residuals at data points in each leaf.
1035
- count_tree : int array (2 ** d,)
1036
- The number of data points in each leaf.
1378
+ total_resid : float
1379
+ The sum of the residuals in the leaf to grow.
1380
+ left_resid, right_resid : float
1381
+ The sum of the residuals in the left/right child of the leaf to grow.
1382
+ total_count : int
1383
+ The number of datapoints in the leaf to grow.
1384
+ left_count, right_count : int
1385
+ The number of datapoints in the left/right child of the leaf to grow.
1037
1386
  sigma2 : float
1038
1387
  The noise variance.
1039
- node : int
1040
- The index of the leaf that has been grown.
1041
1388
  n_tree : int
1042
1389
  The number of trees in the forest.
1043
- min_points_per_leaf : int or None
1044
- The minimum number of data points in a leaf node.
1045
1390
 
1046
1391
  Returns
1047
1392
  -------
1048
1393
  ratio : float
1049
1394
  The likelihood ratio P(data | new tree) / P(data | old tree).
1050
-
1051
- Notes
1052
- -----
1053
- The ratio is set to 0 if the grow move would create leaves with not enough
1054
- datapoints per leaf, although this is part of the prior rather than the
1055
- likelihood.
1056
1395
  """
1057
1396
 
1058
- left_child = node << 1
1059
- right_child = left_child + 1
1060
-
1061
- left_resid = resid_tree[left_child]
1062
- right_resid = resid_tree[right_child]
1063
- total_resid = left_resid + right_resid
1064
-
1065
- left_count = count_tree[left_child]
1066
- right_count = count_tree[right_child]
1067
- total_count = left_count + right_count
1068
-
1069
1397
  sigma_mu2 = 1 / n_tree
1070
1398
  sigma2_left = sigma2 + left_count * sigma_mu2
1071
1399
  sigma2_right = sigma2 + right_count * sigma_mu2
@@ -1079,13 +1407,67 @@ def compute_likelihood_ratio(resid_tree, count_tree, sigma2, node, n_tree, min_p
1079
1407
  total_resid * total_resid / sigma2_total
1080
1408
  )
1081
1409
 
1082
- ratio = jnp.sqrt(sqrt_term) * jnp.exp(exp_term)
1410
+ return jnp.sqrt(sqrt_term) * jnp.exp(exp_term)
1083
1411
 
1084
- if min_points_per_leaf is not None:
1085
- ratio = jnp.where(right_count >= min_points_per_leaf, ratio, 0)
1086
- ratio = jnp.where(left_count >= min_points_per_leaf, ratio, 0)
1412
def accept_moves_final_stage(bart, counts, grow_moves, prune_moves):
    """
    Finish applying the accepted moves, for all trees at once.

    Parameters
    ----------
    bart : dict
        A partially updated BART mcmc state. It is not modified; a shallow
        copy is updated and returned.
    counts : dict
        The indicators of proposals and acceptances for grow and prune moves.
        Each value is summed along its first axis and stored in the state
        under the same key.
    grow_moves, prune_moves : dict
        The proposals for the moves. See `grow_move` and `prune_move`.

    Returns
    -------
    bart : dict
        The fully updated BART mcmc state.
    """
    updated = bart.copy()

    # collapse the per-tree indicators into totals, keeping the same keys
    updated.update(
        (name, jnp.sum(flags, axis=0))
        for name, flags in counts.items()
    )

    # bring the datapoint -> leaf assignment in line with the accepted moves
    updated['leaf_indices'] = apply_moves_to_indices(
        updated['leaf_indices'],
        counts,
        grow_moves,
        prune_moves,
    )

    return updated
1438
+
1439
def apply_moves_to_indices(leaf_indices, counts, grow_moves, prune_moves):
    """
    Make the leaf indices consistent with the outcome of the moves.

    Parameters
    ----------
    leaf_indices : int array (num_trees, n)
        The index of the leaf each datapoint falls into, computed as if the
        grow move had been accepted.
    counts : dict
        The indicators of proposals and acceptances for grow and prune moves.
        The entries 'grow_acc_count' and 'prune_acc_count' are used, one
        value per tree.
    grow_moves, prune_moves : dict
        The proposals for the moves. See `grow_move` and `prune_move`. The
        entries 'node' and 'left' are used, one value per tree.

    Returns
    -------
    leaf_indices : int array (num_trees, n)
        The updated leaf indices, with the same dtype as the input.
    """
    index_dtype = leaf_indices.dtype

    def per_tree(values):
        # (num_trees,) -> (num_trees, 1), to broadcast along the datapoints
        return values[:, None]

    # ...1111111110: clearing the lowest bit sends both children of a node to
    # the index of the left child
    drop_lowest_bit = ~jnp.array(1, index_dtype)

    # grow move not accepted: send the datapoints sitting in the two proposed
    # children back to the leaf that was going to be grown
    # NOTE(review): `~` is the logical negation only if the acceptance
    # indicators are boolean -- confirm against the caller
    in_grown_children = (leaf_indices & drop_lowest_bit) == per_tree(grow_moves['left'])
    grow_rejected = ~per_tree(counts['grow_acc_count'])
    leaf_indices = jnp.where(
        in_grown_children & grow_rejected,
        per_tree(grow_moves['node']).astype(index_dtype),
        leaf_indices,
    )

    # prune move accepted: merge the datapoints of the two pruned children
    # into their parent
    in_pruned_children = (leaf_indices & drop_lowest_bit) == per_tree(prune_moves['left'])
    prune_accepted = per_tree(counts['prune_acc_count'])
    leaf_indices = jnp.where(
        in_pruned_children & prune_accepted,
        per_tree(prune_moves['node']).astype(index_dtype),
        leaf_indices,
    )

    return leaf_indices
1089
1471
 
1090
1472
  def sample_sigma(bart, key):
1091
1473
  """
@@ -1107,7 +1489,7 @@ def sample_sigma(bart, key):
1107
1489
 
1108
1490
  resid = bart['resid']
1109
1491
  alpha = bart['sigma2_alpha'] + resid.size / 2
1110
- norm2 = jnp.dot(resid, resid, preferred_element_type=bart['sigma2_beta'].dtype)
1492
+ norm2 = jnp.dot(resid, resid, preferred_element_type=bart['opt']['large_float'])
1111
1493
  beta = bart['sigma2_beta'] + norm2 / 2
1112
1494
 
1113
1495
  sample = random.gamma(key, alpha)