bartz 0.2.1__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/mcmcstep.py CHANGED
@@ -10,10 +10,10 @@
10
10
  # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
11
  # copies of the Software, and to permit persons to whom the Software is
12
12
  # furnished to do so, subject to the following conditions:
13
- #
13
+ #
14
14
  # The above copyright notice and this permission notice shall be included in all
15
15
  # copies or substantial portions of the Software.
16
- #
16
+ #
17
17
  # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
18
  # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
19
  # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
@@ -34,6 +34,7 @@ range of possible values.
34
34
  """
35
35
 
36
36
  import functools
37
+ import math
37
38
 
38
39
  import jax
39
40
  from jax import random
@@ -54,7 +55,9 @@ def init(*,
54
55
  small_float=jnp.float32,
55
56
  large_float=jnp.float32,
56
57
  min_points_per_leaf=None,
57
- suffstat_batch_size='auto',
58
+ resid_batch_size='auto',
59
+ count_batch_size='auto',
60
+ save_ratios=False,
58
61
  ):
59
62
  """
60
63
  Make a BART posterior sampling MCMC initial state.
@@ -82,10 +85,13 @@ def init(*,
82
85
  The dtype for scalars, small arrays, and arrays which require accuracy.
83
86
  min_points_per_leaf : int, optional
84
87
  The minimum number of data points in a leaf node. 0 if not specified.
85
- suffstat_batch_size : int, None, str, default 'auto'
86
- The batch size for computing sufficient statistics. `None` for no
87
- batching. If 'auto', pick a value based on the device of `y`, or the
88
- default device.
88
+ resid_batch_size, count_batch_sizes : int, None, str, default 'auto'
89
+ The batch sizes, along datapoints, for summing the residuals and
90
+ counting the number of datapoints in each leaf. `None` for no batching.
91
+ If 'auto', pick a value based on the device of `y`, or the default
92
+ device.
93
+ save_ratios : bool, default False
94
+ Whether to save the Metropolis-Hastings ratios.
89
95
 
90
96
  Returns
91
97
  -------
@@ -111,6 +117,8 @@ def init(*,
111
117
  'p_nonterminal' : large_float array (d,)
112
118
  The probability of a nonterminal node at each depth, padded with a
113
119
  zero.
120
+ 'p_propose_grow' : large_float array (2 ** (d - 1),)
121
+ The unnormalized probability of picking a leaf for a grow proposal.
114
122
  'sigma2_alpha' : large_float
115
123
  The shape parameter of the inverse gamma prior on the noise variance.
116
124
  'sigma2_beta' : large_float
@@ -121,6 +129,8 @@ def init(*,
121
129
  The response.
122
130
  'X' : int array (p, n)
123
131
  The predictors.
132
+ 'leaf_indices' : int array (num_trees, n)
133
+ The index of the leaf each datapoints falls into, for each tree.
124
134
  'min_points_per_leaf' : int or None
125
135
  The minimum number of data points in a leaf node.
126
136
  'affluence_trees' : bool array (num_trees, 2 ** (d - 1)) or None
@@ -129,8 +139,6 @@ def init(*,
129
139
  'opt' : LeafDict
130
140
  A dictionary with config values:
131
141
 
132
- 'suffstat_batch_size' : int or None
133
- The batch size for computing sufficient statistics.
134
142
  'small_float' : dtype
135
143
  The dtype for large arrays used in the algorithm.
136
144
  'large_float' : dtype
@@ -138,6 +146,16 @@ def init(*,
138
146
  accuracy.
139
147
  'require_min_points' : bool
140
148
  Whether the `min_points_per_leaf` parameter is specified.
149
+ 'resid_batch_size', 'count_batch_size' : int or None
150
+ The data batch sizes for computing the sufficient statistics.
151
+ 'ratios' : dict, optional
152
+ If `save_ratios` is True, this field is present. It has the fields:
153
+
154
+ 'log_trans_prior' : large_float array (num_trees,)
155
+ The log transition and prior Metropolis-Hastings ratio for the
156
+ proposed move on each tree.
157
+ 'log_likelihood' : large_float array (num_trees,)
158
+ The log likelihood ratio.
141
159
  """
142
160
 
143
161
  p_nonterminal = jnp.asarray(p_nonterminal, large_float)
@@ -151,24 +169,28 @@ def init(*,
151
169
  small_float = jnp.dtype(small_float)
152
170
  large_float = jnp.dtype(large_float)
153
171
  y = jnp.asarray(y, small_float)
154
- suffstat_batch_size = _choose_suffstat_batch_size(suffstat_batch_size, y)
172
+ resid_batch_size, count_batch_size = _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, 2 ** max_depth * num_trees)
173
+ sigma2 = jnp.array(sigma2_beta / sigma2_alpha, large_float)
174
+ sigma2 = jnp.where(jnp.isfinite(sigma2) & (sigma2 > 0), sigma2, 1)
155
175
 
156
176
  bart = dict(
157
177
  leaf_trees=make_forest(max_depth, small_float),
158
178
  var_trees=make_forest(max_depth - 1, jaxext.minimal_unsigned_dtype(X.shape[0] - 1)),
159
179
  split_trees=make_forest(max_depth - 1, max_split.dtype),
160
180
  resid=jnp.asarray(y, large_float),
161
- sigma2=jnp.ones((), large_float),
181
+ sigma2=sigma2,
162
182
  grow_prop_count=jnp.zeros((), int),
163
183
  grow_acc_count=jnp.zeros((), int),
164
184
  prune_prop_count=jnp.zeros((), int),
165
185
  prune_acc_count=jnp.zeros((), int),
166
186
  p_nonterminal=p_nonterminal,
187
+ p_propose_grow=p_nonterminal[grove.tree_depths(2 ** (max_depth - 1))],
167
188
  sigma2_alpha=jnp.asarray(sigma2_alpha, large_float),
168
189
  sigma2_beta=jnp.asarray(sigma2_beta, large_float),
169
190
  max_split=jnp.asarray(max_split),
170
191
  y=y,
171
192
  X=jnp.asarray(X),
193
+ leaf_indices=jnp.ones((num_trees, y.size), jaxext.minimal_unsigned_dtype(2 ** max_depth - 1)),
172
194
  min_points_per_leaf=(
173
195
  None if min_points_per_leaf is None else
174
196
  jnp.asarray(min_points_per_leaf)
@@ -178,37 +200,59 @@ def init(*,
178
200
  make_forest(max_depth - 1, bool).at[:, 1].set(y.size >= 2 * min_points_per_leaf)
179
201
  ),
180
202
  opt=jaxext.LeafDict(
181
- suffstat_batch_size=suffstat_batch_size,
182
203
  small_float=small_float,
183
204
  large_float=large_float,
184
205
  require_min_points=min_points_per_leaf is not None,
206
+ resid_batch_size=resid_batch_size,
207
+ count_batch_size=count_batch_size,
185
208
  ),
186
209
  )
187
210
 
211
+ if save_ratios:
212
+ bart['ratios'] = dict(
213
+ log_trans_prior=jnp.full(num_trees, jnp.nan),
214
+ log_likelihood=jnp.full(num_trees, jnp.nan),
215
+ )
216
+
188
217
  return bart
189
218
 
190
- def _choose_suffstat_batch_size(size, y):
191
- if size == 'auto':
219
+ def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, forest_size):
220
+
221
+ @functools.cache
222
+ def get_platform():
192
223
  try:
193
224
  device = y.devices().pop()
194
225
  except jax.errors.ConcretizationTypeError:
195
226
  device = jax.devices()[0]
196
227
  platform = device.platform
228
+ if platform not in ('cpu', 'gpu'):
229
+ raise KeyError(f'Unknown platform: {platform}')
230
+ return platform
197
231
 
232
+ if resid_batch_size == 'auto':
233
+ platform = get_platform()
234
+ n = max(1, y.size)
198
235
  if platform == 'cpu':
199
- return None
200
- # maybe I should batch residuals (not counts) for numerical
201
- # accuracy, even if it's slower
236
+ resid_batch_size = 2 ** int(round(math.log2(n / 6))) # n/6
202
237
  elif platform == 'gpu':
203
- return 128 # 128 is good on A100, and V100 at high n
204
- # 512 is good on T4, and V100 at low n
205
- else:
206
- raise KeyError(f'Unknown platform: {platform}')
207
-
208
- elif size is not None:
209
- return int(size)
210
-
211
- return size
238
+ resid_batch_size = 2 ** int(round((1 + math.log2(n)) / 3)) # n^1/3
239
+ resid_batch_size = max(1, resid_batch_size)
240
+
241
+ if count_batch_size == 'auto':
242
+ platform = get_platform()
243
+ if platform == 'cpu':
244
+ count_batch_size = None
245
+ elif platform == 'gpu':
246
+ n = max(1, y.size)
247
+ count_batch_size = 2 ** int(round(math.log2(n) / 2 - 2)) # n^1/2
248
+ # /4 is good on V100, /2 on L4/T4, still haven't tried A100
249
+ max_memory = 2 ** 29
250
+ itemsize = 4
251
+ min_batch_size = int(math.ceil(forest_size * itemsize * n / max_memory))
252
+ count_batch_size = max(count_batch_size, min_batch_size)
253
+ count_batch_size = max(1, count_batch_size)
254
+
255
+ return resid_batch_size, count_batch_size
212
256
 
213
257
  def step(bart, key):
214
258
  """
@@ -248,14 +292,11 @@ def sample_trees(bart, key):
248
292
 
249
293
  Notes
250
294
  -----
251
- This function zeroes the proposal counters before using them.
295
+ This function zeroes the proposal counters.
252
296
  """
253
- bart = bart.copy()
254
297
  key, subkey = random.split(key)
255
- grow_moves, prune_moves = sample_moves(bart, subkey)
256
- bart['var_trees'] = grow_moves['var_tree']
257
- grow_leaf_indices = grove.traverse_forest(bart['X'], grow_moves['var_tree'], grow_moves['split_tree'])
258
- return accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, grow_leaf_indices, key)
298
+ moves = sample_moves(bart, subkey)
299
+ return accept_moves_and_sample_leaves(bart, moves, key)
259
300
 
260
301
  def sample_moves(bart, key):
261
302
  """
@@ -270,21 +311,75 @@ def sample_moves(bart, key):
270
311
 
271
312
  Returns
272
313
  -------
273
- grow_moves, prune_moves : dict
274
- The proposals for grow and prune moves. See `grow_move` and `prune_move`.
314
+ moves : dict
315
+ A dictionary with fields:
316
+
317
+ 'allowed' : bool array (num_trees,)
318
+ Whether the move is possible.
319
+ 'grow' : bool array (num_trees,)
320
+ Whether the move is a grow move or a prune move.
321
+ 'num_growable' : int array (num_trees,)
322
+ The number of growable leaves in the original tree.
323
+ 'node' : int array (num_trees,)
324
+ The index of the leaf to grow or node to prune.
325
+ 'left', 'right' : int array (num_trees,)
326
+ The indices of the children of 'node'.
327
+ 'partial_ratio' : float array (num_trees,)
328
+ A factor of the Metropolis-Hastings ratio of the move. It lacks
329
+ the likelihood ratio and the probability of proposing the prune
330
+ move. If the move is Prune, the ratio is inverted.
331
+ 'grow_var' : int array (num_trees,)
332
+ The decision axes of the new rules.
333
+ 'grow_split' : int array (num_trees,)
334
+ The decision boundaries of the new rules.
335
+ 'var_trees' : int array (num_trees, 2 ** (d - 1))
336
+ The updated decision axes of the trees, valid whatever move.
337
+ 'logu' : float array (num_trees,)
338
+ The logarithm of a uniform (0, 1] random variable to be used to
339
+ accept the move. It's in (-oo, 0].
275
340
  """
276
- key = random.split(key, bart['var_trees'].shape[0])
277
- return sample_moves_vmap_trees(bart['var_trees'], bart['split_trees'], bart['affluence_trees'], bart['max_split'], bart['p_nonterminal'], key)
341
+ ntree = bart['leaf_trees'].shape[0]
342
+ key = random.split(key, 1 + ntree)
343
+ key, subkey = key[0], key[1:]
344
+
345
+ # compute moves
346
+ grow_moves, prune_moves = _sample_moves_vmap_trees(bart['var_trees'], bart['split_trees'], bart['affluence_trees'], bart['max_split'], bart['p_nonterminal'], bart['p_propose_grow'], subkey)
347
+
348
+ u, logu = random.uniform(key, (2, ntree), bart['opt']['large_float'])
278
349
 
279
- @functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, 0, None, None, 0))
280
- def sample_moves_vmap_trees(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, key):
350
+ # choose between grow or prune
351
+ grow_allowed = grow_moves['num_growable'].astype(bool)
352
+ p_grow = jnp.where(grow_allowed & prune_moves['allowed'], 0.5, grow_allowed)
353
+ grow = u < p_grow # use < instead of <= because u is in [0, 1)
354
+
355
+ # compute children indices
356
+ node = jnp.where(grow, grow_moves['node'], prune_moves['node'])
357
+ left = node << 1
358
+ right = left + 1
359
+
360
+ return dict(
361
+ allowed=grow | prune_moves['allowed'],
362
+ grow=grow,
363
+ num_growable=grow_moves['num_growable'],
364
+ node=node,
365
+ left=left,
366
+ right=right,
367
+ partial_ratio=jnp.where(grow, grow_moves['partial_ratio'], prune_moves['partial_ratio']),
368
+ grow_var=grow_moves['var'],
369
+ grow_split=grow_moves['split'],
370
+ var_trees=grow_moves['var_tree'],
371
+ logu=jnp.log1p(-logu),
372
+ )
373
+
374
@functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, 0, None, None, None, 0))
def _sample_moves_vmap_trees(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow, key):
    # Propose both a grow and a prune move for a single tree; vectorized over
    # trees by the decorator (the tree arrays and the key are mapped, the
    # shared configuration arrays are broadcast).
    # Returns the pair (grow_move(...), prune_move(...)).
    grow_key, prune_key = random.split(key)
    tree_args = (var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow)
    grow = grow_move(*tree_args, grow_key)
    prune = prune_move(*tree_args, prune_key)
    return grow, prune
286
381
 
287
- def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, key):
382
+ def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow, key):
288
383
  """
289
384
  Tree structure grow move proposal of BART MCMC.
290
385
 
@@ -304,6 +399,8 @@ def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, ke
304
399
  The maximum split index for each variable.
305
400
  p_nonterminal : array (d,)
306
401
  The probability of a nonterminal node at each depth.
402
+ p_propose_grow : array (2 ** (d - 1),)
403
+ The unnormalized probability of choosing a leaf to grow.
307
404
  key : jax.dtypes.prng_key array
308
405
  A jax random key.
309
406
 
@@ -312,41 +409,42 @@ def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, ke
312
409
  grow_move : dict
313
410
  A dictionary with fields:
314
411
 
315
- 'allowed' : bool
316
- Whether the move is possible.
412
+ 'num_growable' : int
413
+ The number of growable leaves.
317
414
  'node' : int
318
- The index of the leaf to grow.
319
- 'var_tree' : array (2 ** (d - 1),)
320
- The new decision axes of the tree.
321
- 'split_tree' : array (2 ** (d - 1),)
322
- The new decision boundaries of the tree.
415
+ The index of the leaf to grow. ``2 ** d`` if there are no growable
416
+ leaves.
417
+ 'var', 'split' : int
418
+ The decision axis and boundary of the new rule.
323
419
  'partial_ratio' : float
324
420
  A factor of the Metropolis-Hastings ratio of the move. It lacks
325
421
  the likelihood ratio and the probability of proposing the prune
326
422
  move.
423
+ 'var_tree' : array (2 ** (d - 1),)
424
+ The updated decision axes of the tree.
327
425
  """
328
426
 
329
427
  key, key1, key2 = random.split(key, 3)
330
-
331
- leaf_to_grow, num_growable, num_prunable, allowed = choose_leaf(split_tree, affluence_tree, key)
428
+
429
+ leaf_to_grow, num_growable, prob_choose, num_prunable = choose_leaf(split_tree, affluence_tree, p_propose_grow, key)
332
430
 
333
431
  var = choose_variable(var_tree, split_tree, max_split, leaf_to_grow, key1)
334
432
  var_tree = var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype))
335
-
433
+
336
434
  split = choose_split(var_tree, split_tree, max_split, leaf_to_grow, key2)
337
- split_tree = split_tree.at[leaf_to_grow].set(split.astype(split_tree.dtype))
338
435
 
339
- ratio = compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_grow, split_tree)
436
+ ratio = compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow)
340
437
 
341
438
  return dict(
342
- allowed=allowed,
439
+ num_growable=num_growable,
343
440
  node=leaf_to_grow,
441
+ var=var,
442
+ split=split,
344
443
  partial_ratio=ratio,
345
444
  var_tree=var_tree,
346
- split_tree=split_tree,
347
445
  )
348
446
 
349
- def choose_leaf(split_tree, affluence_tree, key):
447
+ def choose_leaf(split_tree, affluence_tree, p_propose_grow, key):
350
448
  """
351
449
  Choose a leaf node to grow in a tree.
352
450
 
@@ -356,6 +454,8 @@ def choose_leaf(split_tree, affluence_tree, key):
356
454
  The splitting points of the tree.
357
455
  affluence_tree : bool array (2 ** (d - 1),) or None
358
456
  Whether a leaf has enough points to be grown.
457
+ p_propose_grow : array (2 ** (d - 1),)
458
+ The unnormalized probability of choosing a leaf to grow.
359
459
  key : jax.dtypes.prng_key array
360
460
  A jax random key.
361
461
 
@@ -366,19 +466,21 @@ def choose_leaf(split_tree, affluence_tree, key):
366
466
  ``2 ** d``.
367
467
  num_growable : int
368
468
  The number of leaf nodes that can be grown.
469
+ prob_choose : float
470
+ The normalized probability of choosing the selected leaf.
369
471
  num_prunable : int
370
472
  The number of leaf parents that could be pruned, after converting the
371
473
  selected leaf to a non-terminal node.
372
- allowed : bool
373
- Whether the grow move is allowed.
374
474
  """
375
- is_growable, allowed = growable_leaves(split_tree, affluence_tree)
376
- leaf_to_grow = randint_masked(key, is_growable)
377
- leaf_to_grow = jnp.where(allowed, leaf_to_grow, 2 * split_tree.size)
475
+ is_growable = growable_leaves(split_tree, affluence_tree)
378
476
  num_growable = jnp.count_nonzero(is_growable)
477
+ distr = jnp.where(is_growable, p_propose_grow, 0)
478
+ leaf_to_grow, distr_norm = categorical(key, distr)
479
+ leaf_to_grow = jnp.where(num_growable, leaf_to_grow, 2 * split_tree.size)
480
+ prob_choose = distr[leaf_to_grow] / distr_norm
379
481
  is_parent = grove.is_leaves_parent(split_tree.at[leaf_to_grow].set(1))
380
482
  num_prunable = jnp.count_nonzero(is_parent)
381
- return leaf_to_grow, num_growable, num_prunable, allowed
483
+ return leaf_to_grow, num_growable, prob_choose, num_prunable
382
484
 
383
485
  def growable_leaves(split_tree, affluence_tree):
384
486
  """
@@ -397,34 +499,32 @@ def growable_leaves(split_tree, affluence_tree):
397
499
  The mask indicating the leaf nodes that can be proposed to grow, i.e.,
398
500
  that are not at the bottom level and have at least two times the number
399
501
  of minimum points per leaf.
400
- allowed : bool
401
- Whether the grow move is allowed, i.e., there are growable leaves.
402
502
  """
403
503
  is_growable = grove.is_actual_leaf(split_tree)
404
504
  if affluence_tree is not None:
405
505
  is_growable &= affluence_tree
406
- return is_growable, jnp.any(is_growable)
506
+ return is_growable
407
507
 
408
def categorical(key, distr):
    """
    Draw an index from an unnormalized discrete distribution.

    Parameters
    ----------
    key : jax.dtypes.prng_key array
        A jax random key.
    distr : float array (n,)
        Non-negative, unnormalized probabilities of each index.

    Returns
    -------
    u : int
        A random integer in the range ``[0, n)``, drawn with probability
        proportional to `distr`. If all probabilities are zero, ``n``.
    norm : float
        The sum of `distr`, i.e., its normalization constant.
    """
    cumulative = jnp.cumsum(distr)
    total = cumulative[-1]
    # a uniform draw in [0, total) lands in the cumulative interval of exactly
    # one entry; zero-weight entries have empty intervals and are never hit
    draw = random.uniform(key, (), cumulative.dtype, 0, total)
    index = jnp.searchsorted(cumulative, draw, 'right')
    return index, total
428
528
 
429
529
  def choose_variable(var_tree, split_tree, max_split, leaf_index, key):
430
530
  """
@@ -479,7 +579,7 @@ def fully_used_variables(var_tree, split_tree, max_split, leaf_index):
479
579
  filled with `p`. The fill values are not guaranteed to be placed in any
480
580
  particular order. Variables may appear more than once.
481
581
  """
482
-
582
+
483
583
  var_to_ignore = ancestor_variables(var_tree, max_split, leaf_index)
484
584
  split_range_vec = jax.vmap(split_range, in_axes=(None, None, None, None, 0))
485
585
  l, r = split_range_vec(var_tree, split_tree, max_split, leaf_index, var_to_ignore)
@@ -611,7 +711,7 @@ def choose_split(var_tree, split_tree, max_split, leaf_index, key):
611
711
  l, r = split_range(var_tree, split_tree, max_split, leaf_index, var)
612
712
  return random.randint(key, (), l, r)
613
713
 
614
- def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_grow, new_split_tree):
714
+ def compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow):
615
715
  """
616
716
  Compute the product of the transition and prior ratios of a grow move.
617
717
 
@@ -626,8 +726,6 @@ def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_gro
626
726
  The probability of a nonterminal node at each depth.
627
727
  leaf_to_grow : int
628
728
  The index of the leaf to grow.
629
- new_split_tree : array (2 ** (d - 1),)
630
- The splitting points of the tree, after the leaf is grown.
631
729
 
632
730
  Returns
633
731
  -------
@@ -640,6 +738,9 @@ def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_gro
640
738
  # the two ratios also contain factors num_available_split *
641
739
  # num_available_var, but they cancel out
642
740
 
741
+ # p_prune can't be computed here because it needs the count trees, which are
742
+ # computed in the acceptance phase
743
+
643
744
  prune_allowed = leaf_to_grow != 1
644
745
  # prune allowed <---> the initial tree is not a root
645
746
  # leaf to grow is root --> the tree can only be a root
@@ -647,31 +748,33 @@ def compute_partial_ratio(num_growable, num_prunable, p_nonterminal, leaf_to_gro
647
748
 
648
749
  p_grow = jnp.where(prune_allowed, 0.5, 1)
649
750
 
650
- trans_ratio = num_growable / (p_grow * num_prunable)
751
+ inv_trans_ratio = p_grow * prob_choose * num_prunable
651
752
 
652
- depth = grove.tree_depths(new_split_tree.size)[leaf_to_grow]
753
+ depth = grove.tree_depths(2 ** (p_nonterminal.size - 1))[leaf_to_grow]
653
754
  p_parent = p_nonterminal[depth]
654
755
  cp_children = 1 - p_nonterminal[depth + 1]
655
756
  tree_ratio = cp_children * cp_children * p_parent / (1 - p_parent)
656
757
 
657
- return trans_ratio * tree_ratio
758
+ return tree_ratio / inv_trans_ratio
658
759
 
659
- def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, key):
760
+ def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow, key):
660
761
  """
661
762
  Tree structure prune move proposal of BART MCMC.
662
763
 
663
764
  Parameters
664
765
  ----------
665
- var_tree : array (2 ** (d - 1),)
766
+ var_tree : int array (2 ** (d - 1),)
666
767
  The variable indices of the tree.
667
- split_tree : array (2 ** (d - 1),)
768
+ split_tree : int array (2 ** (d - 1),)
668
769
  The splitting points of the tree.
669
770
  affluence_tree : bool array (2 ** (d - 1),) or None
670
771
  Whether a leaf has enough points to be grown.
671
- max_split : array (p,)
772
+ max_split : int array (p,)
672
773
  The maximum split index for each variable.
673
- p_nonterminal : array (d,)
774
+ p_nonterminal : float array (d,)
674
775
  The probability of a nonterminal node at each depth.
776
+ p_propose_grow : float array (2 ** (d - 1),)
777
+ The unnormalized probability of choosing a leaf to grow.
675
778
  key : jax.dtypes.prng_key array
676
779
  A jax random key.
677
780
 
@@ -683,16 +786,16 @@ def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, k
683
786
  'allowed' : bool
684
787
  Whether the move is possible.
685
788
  'node' : int
686
- The index of the node to prune.
789
+ The index of the node to prune. ``2 ** d`` if no node can be pruned.
687
790
  'partial_ratio' : float
688
791
  A factor of the Metropolis-Hastings ratio of the move. It lacks
689
792
  the likelihood ratio and the probability of proposing the prune
690
793
  move. This ratio is inverted.
691
794
  """
692
- node_to_prune, num_prunable, num_growable = choose_leaf_parent(split_tree, affluence_tree, key)
795
+ node_to_prune, num_prunable, prob_choose = choose_leaf_parent(split_tree, affluence_tree, p_propose_grow, key)
693
796
  allowed = split_tree[1].astype(bool) # allowed iff the tree is not a root
694
797
 
695
- ratio = compute_partial_ratio(num_growable, num_prunable, p_nonterminal, node_to_prune, split_tree)
798
+ ratio = compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, node_to_prune)
696
799
 
697
800
  return dict(
698
801
  allowed=allowed,
@@ -700,7 +803,7 @@ def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, k
700
803
  partial_ratio=ratio, # it is inverted in accept_move_and_sample_leaves
701
804
  )
702
805
 
703
- def choose_leaf_parent(split_tree, affluence_tree, key):
806
+ def choose_leaf_parent(split_tree, affluence_tree, p_propose_grow, key):
704
807
  """
705
808
  Pick a non-terminal node with leaf children to prune in a tree.
706
809
 
@@ -710,6 +813,8 @@ def choose_leaf_parent(split_tree, affluence_tree, key):
710
813
  The splitting points of the tree.
711
814
  affluence_tree : bool array (2 ** (d - 1),) or None
712
815
  Whether a leaf has enough points to be grown.
816
+ p_propose_grow : array (2 ** (d - 1),)
817
+ The unnormalized probability of choosing a leaf to grow.
713
818
  key : jax.dtypes.prng_key array
714
819
  A jax random key.
715
820
 
@@ -717,28 +822,50 @@ def choose_leaf_parent(split_tree, affluence_tree, key):
717
822
  -------
718
823
  node_to_prune : int
719
824
  The index of the node to prune. If ``num_prunable == 0``, return
720
- ``split_tree.size``.
825
+ ``2 ** d``.
721
826
  num_prunable : int
722
827
  The number of leaf parents that could be pruned.
723
- num_growable : int
724
- The number of leaf nodes that can be grown, after pruning the chosen
725
- node.
828
+ prob_choose : float
829
+ The normalized probability of choosing the node to prune for growth.
726
830
  """
727
831
  is_prunable = grove.is_leaves_parent(split_tree)
728
- node_to_prune = randint_masked(key, is_prunable)
729
832
  num_prunable = jnp.count_nonzero(is_prunable)
833
+ node_to_prune = randint_masked(key, is_prunable)
834
+ node_to_prune = jnp.where(num_prunable, node_to_prune, 2 * split_tree.size)
730
835
 
731
- pruned_split_tree = split_tree.at[node_to_prune].set(0)
732
- pruned_affluence_tree = (
836
+ split_tree = split_tree.at[node_to_prune].set(0)
837
+ affluence_tree = (
733
838
  None if affluence_tree is None else
734
839
  affluence_tree.at[node_to_prune].set(True)
735
840
  )
736
- is_growable_leaf, _ = growable_leaves(pruned_split_tree, pruned_affluence_tree)
737
- num_growable = jnp.count_nonzero(is_growable_leaf)
841
+ is_growable_leaf = growable_leaves(split_tree, affluence_tree)
842
+ prob_choose = p_propose_grow[node_to_prune]
843
+ prob_choose /= jnp.sum(p_propose_grow, where=is_growable_leaf)
738
844
 
739
- return node_to_prune, num_prunable, num_growable
845
+ return node_to_prune, num_prunable, prob_choose
740
846
 
741
- def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, grow_leaf_indices, key):
847
def randint_masked(key, mask):
    """
    Draw a uniformly random index among the positions where `mask` is True.

    Parameters
    ----------
    key : jax.dtypes.prng_key array
        A jax random key.
    mask : bool array (n,)
        Which indices may be drawn.

    Returns
    -------
    u : int
        An integer in ``[0, n)`` with ``mask[u] == True``, uniformly chosen
        among the allowed ones. If no index is allowed, ``n``.
    """
    # running count of allowed positions; the last entry is their total
    num_allowed = jnp.cumsum(mask)
    # pick which allowed position to take (0-based rank among the allowed)
    rank = random.randint(key, (), 0, num_allowed[-1])
    # the first index whose running count exceeds the rank is the
    # rank-th allowed position
    return jnp.searchsorted(num_allowed, rank, 'right')
867
+
868
+ def accept_moves_and_sample_leaves(bart, moves, key):
742
869
  """
743
870
  Accept or reject the proposed moves and sample the new leaf values.
744
871
 
@@ -746,14 +873,8 @@ def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, grow_leaf_indi
746
873
  ----------
747
874
  bart : dict
748
875
  A BART mcmc state.
749
- grow_moves : dict
750
- The proposals for grow moves, batched over the first axis. See
751
- `grow_move`.
752
- prune_moves : dict
753
- The proposals for prune moves, batched over the first axis. See
754
- `prune_move`.
755
- grow_leaf_indices : int array (num_trees, n)
756
- The leaf indices of the trees proposed by the grow move.
876
+ moves : dict
877
+ The proposed moves, see `sample_moves`.
757
878
  key : jax.dtypes.prng_key array
758
879
  A jax random key.
759
880
 
@@ -762,41 +883,409 @@ def accept_moves_and_sample_leaves(bart, grow_moves, prune_moves, grow_leaf_indi
762
883
  bart : dict
763
884
  The new BART mcmc state.
764
885
  """
886
+ bart, moves, count_trees, move_counts, prelkv, prelk, prelf = accept_moves_parallel_stage(bart, moves, key)
887
+ bart, moves = accept_moves_sequential_stage(bart, count_trees, moves, move_counts, prelkv, prelk, prelf)
888
+ return accept_moves_final_stage(bart, moves)
889
+
890
+ def accept_moves_parallel_stage(bart, moves, key):
891
+ """
892
+ Pre-computes quantities used to accept moves, in parallel across trees.
893
+
894
+ Parameters
895
+ ----------
896
+ bart : dict
897
+ A BART mcmc state.
898
+ moves : dict
899
+ The proposed moves, see `sample_moves`.
900
+ key : jax.dtypes.prng_key array
901
+ A jax random key.
902
+
903
+ Returns
904
+ -------
905
+ bart : dict
906
+ A partially updated BART mcmc state.
907
+ moves : dict
908
+ The proposed moves, with the field 'partial_ratio' replaced
909
+ by 'log_trans_prior_ratio'.
910
+ count_trees : array (num_trees, 2 ** d)
911
+ The number of points in each potential or actual leaf node.
912
+ move_counts : dict
913
+ The counts of the number of points in the the nodes modified by the
914
+ moves.
915
+ prelkv, prelk, prelf : dict
916
+ Dictionary with pre-computed terms of the likelihood ratios and leaf
917
+ samples.
918
+ """
919
+ bart = bart.copy()
920
+
921
+ # where the move is grow, modify the state like the move was accepted
922
+ bart['var_trees'] = moves['var_trees']
923
+ bart['leaf_indices'] = apply_grow_to_indices(moves, bart['leaf_indices'], bart['X'])
924
+ bart['leaf_trees'] = adapt_leaf_trees_to_grow_indices(bart['leaf_trees'], moves)
925
+
926
+ # count number of datapoints per leaf
927
+ count_trees, move_counts = compute_count_trees(bart['leaf_indices'], moves, bart['opt']['count_batch_size'])
928
+ if bart['opt']['require_min_points']:
929
+ count_half_trees = count_trees[:, :bart['var_trees'].shape[1]]
930
+ bart['affluence_trees'] = count_half_trees >= 2 * bart['min_points_per_leaf']
931
+
932
+ # compute some missing information about moves
933
+ moves = complete_ratio(moves, move_counts, bart['min_points_per_leaf'])
934
+ bart['grow_prop_count'] = jnp.sum(moves['grow'])
935
+ bart['prune_prop_count'] = jnp.sum(moves['allowed'] & ~moves['grow'])
936
+
937
+ prelkv, prelk = precompute_likelihood_terms(count_trees, bart['sigma2'], move_counts)
938
+ prelf = precompute_leaf_terms(count_trees, bart['sigma2'], key)
939
+
940
+ return bart, moves, count_trees, move_counts, prelkv, prelk, prelf
941
+
942
+ @functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, None))
943
+ def apply_grow_to_indices(moves, leaf_indices, X):
944
+ """
945
+ Update the leaf indices to apply a grow move.
946
+
947
+ Parameters
948
+ ----------
949
+ moves : dict
950
+ The proposed moves, see `sample_moves`.
951
+ leaf_indices : array (num_trees, n)
952
+ The index of the leaf each datapoint falls into.
953
+ X : array (p, n)
954
+ The predictors matrix.
955
+
956
+ Returns
957
+ -------
958
+ grow_leaf_indices : array (num_trees, n)
959
+ The updated leaf indices.
960
+ """
961
+ left_child = moves['node'].astype(leaf_indices.dtype) << 1
962
+ go_right = X[moves['grow_var'], :] >= moves['grow_split']
963
+ tree_size = jnp.array(2 * moves['var_trees'].size)
964
+ node_to_update = jnp.where(moves['grow'], moves['node'], tree_size)
965
+ return jnp.where(
966
+ leaf_indices == node_to_update,
967
+ left_child + go_right,
968
+ leaf_indices,
969
+ )
970
+
971
+ def compute_count_trees(leaf_indices, moves, batch_size):
972
+ """
973
+ Count the number of datapoints in each leaf.
974
+
975
+ Parameters
976
+ ----------
977
+ grow_leaf_indices : int array (num_trees, n)
978
+ The index of the leaf each datapoint falls into, if the grow move is
979
+ accepted.
980
+ moves : dict
981
+ The proposed moves, see `sample_moves`.
982
+ batch_size : int or None
983
+ The data batch size to use for the summation.
984
+
985
+ Returns
986
+ -------
987
+ count_trees : int array (num_trees, 2 ** (d - 1))
988
+ The number of points in each potential or actual leaf node.
989
+ counts : dict
990
+ The counts of the number of points in the the nodes modified by the
991
+ moves, organized as two dictionaries 'grow' and 'prune', with subfields
992
+ 'left', 'right', and 'total'.
993
+ """
994
+
995
+ ntree, tree_size = moves['var_trees'].shape
996
+ tree_size *= 2
997
+ tree_indices = jnp.arange(ntree)
998
+
999
+ count_trees = count_datapoints_per_leaf(leaf_indices, tree_size, batch_size)
1000
+
1001
+ # count datapoints in nodes modified by move
1002
+ counts = dict()
1003
+ counts['left'] = count_trees[tree_indices, moves['left']]
1004
+ counts['right'] = count_trees[tree_indices, moves['right']]
1005
+ counts['total'] = counts['left'] + counts['right']
1006
+
1007
+ # write count into non-leaf node
1008
+ count_trees = count_trees.at[tree_indices, moves['node']].set(counts['total'])
1009
+
1010
+ return count_trees, counts
1011
+
1012
+ def count_datapoints_per_leaf(leaf_indices, tree_size, batch_size):
1013
+ """
1014
+ Count the number of datapoints in each leaf.
1015
+
1016
+ Parameters
1017
+ ----------
1018
+ leaf_indices : int array (num_trees, n)
1019
+ The index of the leaf each datapoint falls into.
1020
+ tree_size : int
1021
+ The size of the leaf tree array (2 ** d).
1022
+ batch_size : int or None
1023
+ The data batch size to use for the summation.
1024
+
1025
+ Returns
1026
+ -------
1027
+ count_trees : int array (num_trees, 2 ** (d - 1))
1028
+ The number of points in each leaf node.
1029
+ """
1030
+ if batch_size is None:
1031
+ return _count_scan(leaf_indices, tree_size)
1032
+ else:
1033
+ return _count_vec(leaf_indices, tree_size, batch_size)
1034
+
1035
+ def _count_scan(leaf_indices, tree_size):
1036
+ def loop(_, leaf_indices):
1037
+ return None, _aggregate_scatter(1, leaf_indices, tree_size, jnp.uint32)
1038
+ _, count_trees = lax.scan(loop, None, leaf_indices)
1039
+ return count_trees
1040
+
1041
+ def _aggregate_scatter(values, indices, size, dtype):
1042
+ return (jnp
1043
+ .zeros(size, dtype)
1044
+ .at[indices]
1045
+ .add(values)
1046
+ )
1047
+
1048
+ def _count_vec(leaf_indices, tree_size, batch_size):
1049
+ return _aggregate_batched_alltrees(1, leaf_indices, tree_size, jnp.uint32, batch_size)
1050
+ # uint16 is super-slow on gpu, don't use it even if n < 2^16
1051
+
1052
+ def _aggregate_batched_alltrees(values, indices, size, dtype, batch_size):
1053
+ ntree, n = indices.shape
1054
+ tree_indices = jnp.arange(ntree)
1055
+ nbatches = n // batch_size + bool(n % batch_size)
1056
+ batch_indices = jnp.arange(n) % nbatches
1057
+ return (jnp
1058
+ .zeros((ntree, size, nbatches), dtype)
1059
+ .at[tree_indices[:, None], indices, batch_indices]
1060
+ .add(values)
1061
+ .sum(axis=2)
1062
+ )
1063
+
1064
+ def complete_ratio(moves, move_counts, min_points_per_leaf):
1065
+ """
1066
+ Complete non-likelihood MH ratio calculation.
1067
+
1068
+ This functions adds the probability of choosing the prune move.
1069
+
1070
+ Parameters
1071
+ ----------
1072
+ moves : dict
1073
+ The proposed moves, see `sample_moves`.
1074
+ move_counts : dict
1075
+ The counts of the number of points in the the nodes modified by the
1076
+ moves.
1077
+ min_points_per_leaf : int or None
1078
+ The minimum number of data points in a leaf node.
1079
+
1080
+ Returns
1081
+ -------
1082
+ moves : dict
1083
+ The updated moves, with the field 'partial_ratio' replaced by
1084
+ 'log_trans_prior_ratio'.
1085
+ """
1086
+ moves = moves.copy()
1087
+ p_prune = compute_p_prune(moves, move_counts['left'], move_counts['right'], min_points_per_leaf)
1088
+ moves['log_trans_prior_ratio'] = jnp.log(moves.pop('partial_ratio') * p_prune)
1089
+ return moves
1090
+
1091
+ def compute_p_prune(moves, left_count, right_count, min_points_per_leaf):
1092
+ """
1093
+ Compute the probability of proposing a prune move.
1094
+
1095
+ Parameters
1096
+ ----------
1097
+ moves : dict
1098
+ The proposed moves, see `sample_moves`.
1099
+ left_count, right_count : int
1100
+ The number of datapoints in the proposed children of the leaf to grow.
1101
+ min_points_per_leaf : int or None
1102
+ The minimum number of data points in a leaf node.
1103
+
1104
+ Returns
1105
+ -------
1106
+ p_prune : float
1107
+ The probability of proposing a prune move. If grow: after accepting the
1108
+ grow move, if prune: right away.
1109
+ """
1110
+
1111
+ # calculation in case the move is grow
1112
+ other_growable_leaves = moves['num_growable'] >= 2
1113
+ new_leaves_growable = moves['node'] < moves['var_trees'].shape[1] // 2
1114
+ if min_points_per_leaf is not None:
1115
+ any_above_threshold = left_count >= 2 * min_points_per_leaf
1116
+ any_above_threshold |= right_count >= 2 * min_points_per_leaf
1117
+ new_leaves_growable &= any_above_threshold
1118
+ grow_again_allowed = other_growable_leaves | new_leaves_growable
1119
+ grow_p_prune = jnp.where(grow_again_allowed, 0.5, 1)
1120
+
1121
+ # calculation in case the move is prune
1122
+ prune_p_prune = jnp.where(moves['num_growable'], 0.5, 1)
1123
+
1124
+ return jnp.where(moves['grow'], grow_p_prune, prune_p_prune)
1125
+
1126
+ @jaxext.vmap_nodoc
1127
+ def adapt_leaf_trees_to_grow_indices(leaf_trees, moves):
1128
+ """
1129
+ Modify leaf values such that the indices of the grow moves work on the
1130
+ original tree.
1131
+
1132
+ Parameters
1133
+ ----------
1134
+ leaf_trees : float array (num_trees, 2 ** d)
1135
+ The leaf values.
1136
+ moves : dict
1137
+ The proposed moves, see `sample_moves`.
1138
+
1139
+ Returns
1140
+ -------
1141
+ leaf_trees : float array (num_trees, 2 ** d)
1142
+ The modified leaf values. The value of the leaf to grow is copied to
1143
+ what would be its children if the grow move was accepted.
1144
+ """
1145
+ values_at_node = leaf_trees[moves['node']]
1146
+ return (leaf_trees
1147
+ .at[jnp.where(moves['grow'], moves['left'], leaf_trees.size)]
1148
+ .set(values_at_node)
1149
+ .at[jnp.where(moves['grow'], moves['right'], leaf_trees.size)]
1150
+ .set(values_at_node)
1151
+ )
1152
+
1153
+ def precompute_likelihood_terms(count_trees, sigma2, move_counts):
1154
+ """
1155
+ Pre-compute terms used in the likelihood ratio of the acceptance step.
1156
+
1157
+ Parameters
1158
+ ----------
1159
+ count_trees : array (num_trees, 2 ** d)
1160
+ The number of points in each potential or actual leaf node.
1161
+ sigma2 : float
1162
+ The noise variance.
1163
+ move_counts : dict
1164
+ The counts of the number of points in the the nodes modified by the
1165
+ moves.
1166
+
1167
+ Returns
1168
+ -------
1169
+ prelkv : dict
1170
+ Dictionary with pre-computed terms of the likelihood ratio, one per
1171
+ tree.
1172
+ prelk : dict
1173
+ Dictionary with pre-computed terms of the likelihood ratio, shared by
1174
+ all trees.
1175
+ """
1176
+ ntree = len(count_trees)
1177
+ sigma_mu2 = 1 / ntree
1178
+ prelkv = dict()
1179
+ prelkv['sigma2_left'] = sigma2 + move_counts['left'] * sigma_mu2
1180
+ prelkv['sigma2_right'] = sigma2 + move_counts['right'] * sigma_mu2
1181
+ prelkv['sigma2_total'] = sigma2 + move_counts['total'] * sigma_mu2
1182
+ prelkv['sqrt_term'] = jnp.log(
1183
+ sigma2 * prelkv['sigma2_total'] /
1184
+ (prelkv['sigma2_left'] * prelkv['sigma2_right'])
1185
+ ) / 2
1186
+ return prelkv, dict(
1187
+ exp_factor=sigma_mu2 / (2 * sigma2),
1188
+ )
1189
+
1190
+ def precompute_leaf_terms(count_trees, sigma2, key):
1191
+ """
1192
+ Pre-compute terms used to sample leaves from their posterior.
1193
+
1194
+ Parameters
1195
+ ----------
1196
+ count_trees : array (num_trees, 2 ** d)
1197
+ The number of points in each potential or actual leaf node.
1198
+ sigma2 : float
1199
+ The noise variance.
1200
+ key : jax.dtypes.prng_key array
1201
+ A jax random key.
1202
+
1203
+ Returns
1204
+ -------
1205
+ prelf : dict
1206
+ Dictionary with pre-computed terms of the leaf sampling, with fields:
1207
+
1208
+ 'mean_factor' : float array (num_trees, 2 ** d)
1209
+ The factor to be multiplied by the sum of residuals to obtain the
1210
+ posterior mean.
1211
+ 'centered_leaves' : float array (num_trees, 2 ** d)
1212
+ The mean-zero normal values to be added to the posterior mean to
1213
+ obtain the posterior leaf samples.
1214
+ """
1215
+ ntree = len(count_trees)
1216
+ prec_lk = count_trees / sigma2
1217
+ var_post = lax.reciprocal(prec_lk + ntree) # = 1 / (prec_lk + prec_prior)
1218
+ z = random.normal(key, count_trees.shape, sigma2.dtype)
1219
+ return dict(
1220
+ mean_factor=var_post / sigma2, # = mean_lk * prec_lk * var_post / resid_tree
1221
+ centered_leaves=z * jnp.sqrt(var_post),
1222
+ )
1223
+
1224
+ def accept_moves_sequential_stage(bart, count_trees, moves, move_counts, prelkv, prelk, prelf):
1225
+ """
1226
+ The part of accepting the moves that has to be done one tree at a time.
1227
+
1228
+ Parameters
1229
+ ----------
1230
+ bart : dict
1231
+ A partially updated BART mcmc state.
1232
+ count_trees : array (num_trees, 2 ** d)
1233
+ The number of points in each potential or actual leaf node.
1234
+ moves : dict
1235
+ The proposed moves, see `sample_moves`.
1236
+ move_counts : dict
1237
+ The counts of the number of points in the the nodes modified by the
1238
+ moves.
1239
+ prelkv, prelk, prelf : dict
1240
+ Dictionaries with pre-computed terms of the likelihood ratios and leaf
1241
+ samples.
1242
+
1243
+ Returns
1244
+ -------
1245
+ bart : dict
1246
+ A partially updated BART mcmc state.
1247
+ moves : dict
1248
+ The proposed moves, with these additional fields:
1249
+
1250
+ 'acc' : bool array (num_trees,)
1251
+ Whether the move was accepted.
1252
+ 'to_prune' : bool array (num_trees,)
1253
+ Whether, to reflect the acceptance status of the move, the state
1254
+ should be updated by pruning the leaves involved in the move.
1255
+ """
765
1256
  bart = bart.copy()
766
- def loop(carry, item):
767
- resid = carry.pop('resid')
768
- resid, carry, trees = accept_move_and_sample_leaves(
1257
+ moves = moves.copy()
1258
+
1259
+ def loop(resid, item):
1260
+ resid, leaf_tree, acc, to_prune, ratios = accept_move_and_sample_leaves(
769
1261
  bart['X'],
770
1262
  len(bart['leaf_trees']),
771
- bart['opt']['suffstat_batch_size'],
1263
+ bart['opt']['resid_batch_size'],
772
1264
  resid,
773
- bart['sigma2'],
774
1265
  bart['min_points_per_leaf'],
775
- carry,
1266
+ 'ratios' in bart,
1267
+ prelk,
776
1268
  *item,
777
1269
  )
778
- carry['resid'] = resid
779
- return carry, trees
780
- carry = {
781
- k: jnp.zeros_like(bart[k]) for k in
782
- ['grow_prop_count', 'prune_prop_count', 'grow_acc_count', 'prune_acc_count']
783
- }
784
- carry['resid'] = bart['resid']
1270
+ return resid, (leaf_tree, acc, to_prune, ratios)
1271
+
785
1272
  items = (
786
- bart['leaf_trees'],
787
- bart['split_trees'],
788
- bart['affluence_trees'],
789
- grow_moves,
790
- prune_moves,
791
- grow_leaf_indices,
792
- random.split(key, len(bart['leaf_trees'])),
1273
+ bart['leaf_trees'], count_trees,
1274
+ moves, move_counts,
1275
+ bart['leaf_indices'],
1276
+ prelkv, prelf,
793
1277
  )
794
- carry, trees = lax.scan(loop, carry, items)
795
- bart.update(carry)
796
- bart.update(trees)
797
- return bart
1278
+ resid, (leaf_trees, acc, to_prune, ratios) = lax.scan(loop, bart['resid'], items)
1279
+
1280
+ bart['resid'] = resid
1281
+ bart['leaf_trees'] = leaf_trees
1282
+ bart.get('ratios', {}).update(ratios)
1283
+ moves['acc'] = acc
1284
+ moves['to_prune'] = to_prune
1285
+
1286
+ return bart, moves
798
1287
 
799
- def accept_move_and_sample_leaves(X, ntree, suffstat_batch_size, resid, sigma2, min_points_per_leaf, counts, leaf_tree, split_tree, affluence_tree, grow_move, prune_move, grow_leaf_indices, key):
1288
+ def accept_move_and_sample_leaves(X, ntree, resid_batch_size, resid, min_points_per_leaf, save_ratios, prelk, leaf_tree, count_tree, move, move_counts, leaf_indices, prelkv, prelf):
800
1289
  """
801
1290
  Accept or reject a proposed move and sample the new leaf values.
802
1291
 
@@ -806,158 +1295,102 @@ def accept_move_and_sample_leaves(X, ntree, suffstat_batch_size, resid, sigma2,
806
1295
  The predictors.
807
1296
  ntree : int
808
1297
  The number of trees in the forest.
809
- suffstat_batch_size : int, None
810
- The batch size for computing sufficient statistics.
1298
+ resid_batch_size : int, None
1299
+ The batch size for computing the sum of residuals in each leaf.
811
1300
  resid : float array (n,)
812
1301
  The residuals (data minus forest value).
813
- sigma2 : float
814
- The noise variance.
815
1302
  min_points_per_leaf : int or None
816
1303
  The minimum number of data points in a leaf node.
817
- counts : dict
818
- The acceptance counts from the mcmc state dict.
1304
+ save_ratios : bool
1305
+ Whether to save the acceptance ratios.
1306
+ prelk : dict
1307
+ The pre-computed terms of the likelihood ratio which are shared across
1308
+ trees.
819
1309
  leaf_tree : float array (2 ** d,)
820
1310
  The leaf values of the tree.
821
- split_tree : int array (2 ** (d - 1),)
822
- The decision boundaries of the tree.
823
- affluence_tree : bool array (2 ** (d - 1),) or None
824
- Whether a leaf has enough points to be grown.
825
- grow_move : dict
826
- The proposal for the grow move. See `grow_move`.
827
- prune_move : dict
828
- The proposal for the prune move. See `prune_move`.
829
- grow_leaf_indices : int array (n,)
830
- The leaf indices of the tree proposed by the grow move.
831
- key : jax.dtypes.prng_key array
832
- A jax random key.
1311
+ count_tree : int array (2 ** d,)
1312
+ The number of datapoints in each leaf.
1313
+ move : dict
1314
+ The proposed move, see `sample_moves`.
1315
+ leaf_indices : int array (n,)
1316
+ The leaf indices for the largest version of the tree compatible with
1317
+ the move.
1318
+ prelkv, prelf : dict
1319
+ The pre-computed terms of the likelihood ratio and leaf sampling which
1320
+ are specific to the tree.
833
1321
 
834
1322
  Returns
835
1323
  -------
836
1324
  resid : float array (n,)
837
1325
  The updated residuals (data minus forest value).
838
- counts : dict
839
- The updated acceptance counts.
840
- trees : dict
841
- The updated tree arrays.
842
- """
843
-
844
- # compute leaf indices in starting tree
845
- grow_node = grow_move['node']
846
- grow_left = grow_node << 1
847
- grow_right = grow_left + 1
848
- leaf_indices = jnp.where(
849
- (grow_leaf_indices == grow_left) | (grow_leaf_indices == grow_right),
850
- grow_node,
851
- grow_leaf_indices,
852
- )
1326
+ leaf_tree : float array (2 ** d,)
1327
+ The new leaf values of the tree.
1328
+ acc : bool
1329
+ Whether the move was accepted.
1330
+ to_prune : bool
1331
+ Whether, to reflect the acceptance status of the move, the state should
1332
+ be updated by pruning the leaves involved in the move.
1333
+ ratios : dict
1334
+ The acceptance ratios for the moves. Empty if not to be saved.
1335
+ """
853
1336
 
854
- # compute leaf indices in prune tree
855
- prune_node = prune_move['node']
856
- prune_left = prune_node << 1
857
- prune_right = prune_left + 1
858
- prune_leaf_indices = jnp.where(
859
- (leaf_indices == prune_left) | (leaf_indices == prune_right),
860
- prune_node,
861
- leaf_indices,
862
- )
1337
+ # sum residuals and count units per leaf, in tree proposed by grow move
1338
+ resid_tree = sum_resid(resid, leaf_indices, leaf_tree.size, resid_batch_size)
863
1339
 
864
1340
  # subtract starting tree from function
865
- resid += leaf_tree[leaf_indices]
866
-
867
- # aggregate residuals and count units per leaf
868
- grow_resid_tree, grow_count_tree = sufficient_stat(resid, grow_leaf_indices, leaf_tree.size, suffstat_batch_size)
869
-
870
- # compute aggregations in starting tree
871
- # I do not zero the children because garbage there does not matter
872
- resid_tree = (grow_resid_tree.at[grow_node]
873
- .set(grow_resid_tree[grow_left] + grow_resid_tree[grow_right]))
874
- count_tree = (grow_count_tree.at[grow_node]
875
- .set(grow_count_tree[grow_left] + grow_count_tree[grow_right]))
876
-
877
- # compute aggregations in prune tree
878
- prune_resid_tree = (resid_tree.at[prune_node]
879
- .set(resid_tree[prune_left] + resid_tree[prune_right]))
880
- prune_count_tree = (count_tree.at[prune_node]
881
- .set(count_tree[prune_left] + count_tree[prune_right]))
882
-
883
- # compute affluence trees
884
- if min_points_per_leaf is not None:
885
- grow_affluence_tree = grow_count_tree[:grow_count_tree.size // 2] >= 2 * min_points_per_leaf
886
- prune_affluence_tree = affluence_tree.at[prune_node].set(True)
887
-
888
- # compute probability of proposing prune
889
- grow_p_prune = compute_p_prune_back(grow_move['split_tree'], grow_affluence_tree)
890
- prune_p_prune = compute_p_prune_back(split_tree, affluence_tree)
891
-
892
- # compute likelihood ratios
893
- grow_lk_ratio = compute_likelihood_ratio(grow_resid_tree, grow_count_tree, sigma2, grow_node, ntree, min_points_per_leaf)
894
- prune_lk_ratio = compute_likelihood_ratio(resid_tree, count_tree, sigma2, prune_node, ntree, min_points_per_leaf)
895
-
896
- # compute acceptance ratios
897
- grow_ratio = grow_p_prune * grow_move['partial_ratio'] * grow_lk_ratio
898
- prune_ratio = prune_p_prune * prune_move['partial_ratio'] * prune_lk_ratio
899
- prune_ratio = lax.reciprocal(prune_ratio)
900
-
901
- # random coins in [0, 1) for proposal and acceptance
902
- key, subkey = random.split(key)
903
- u0, u1 = random.uniform(subkey, (2,))
904
-
905
- # determine what move to propose (not proposing anything is an option)
906
- p_grow = jnp.where(grow_move['allowed'] & prune_move['allowed'], 0.5, grow_move['allowed'])
907
- try_grow = u0 < p_grow
908
- try_prune = prune_move['allowed'] & ~try_grow
1341
+ resid_tree += count_tree * leaf_tree
1342
+
1343
+ # get indices of move
1344
+ node = move['node']
1345
+ assert node.dtype == jnp.int32
1346
+ left = move['left']
1347
+ right = move['right']
1348
+
1349
+ # sum residuals in parent node modified by move
1350
+ resid_left = resid_tree[left]
1351
+ resid_right = resid_tree[right]
1352
+ resid_total = resid_left + resid_right
1353
+ resid_tree = resid_tree.at[node].set(resid_total)
1354
+
1355
+ # compute acceptance ratio
1356
+ log_lk_ratio = compute_likelihood_ratio(resid_total, resid_left, resid_right, prelkv, prelk)
1357
+ log_ratio = move['log_trans_prior_ratio'] + log_lk_ratio
1358
+ log_ratio = jnp.where(move['grow'], log_ratio, -log_ratio)
1359
+ ratios = {}
1360
+ if save_ratios:
1361
+ ratios.update(
1362
+ log_trans_prior=move['log_trans_prior_ratio'],
1363
+ log_likelihood=log_lk_ratio,
1364
+ )
909
1365
 
910
1366
  # determine whether to accept the move
911
- do_grow = try_grow & (u1 < grow_ratio)
912
- do_prune = try_prune & (u1 < prune_ratio)
913
-
914
- # pick trees for chosen move
915
- trees = {}
916
- split_tree = jnp.where(do_grow, grow_move['split_tree'], split_tree)
917
- # the prune var tree is equal to the initial one, because I leave garbage values behind
918
- split_tree = split_tree.at[prune_node].set(
919
- jnp.where(do_prune, 0, split_tree[prune_node]))
1367
+ acc = move['allowed'] & (move['logu'] <= log_ratio)
920
1368
  if min_points_per_leaf is not None:
921
- affluence_tree = jnp.where(do_grow, grow_affluence_tree, affluence_tree)
922
- affluence_tree = jnp.where(do_prune, prune_affluence_tree, affluence_tree)
923
- resid_tree = jnp.where(do_grow, grow_resid_tree, resid_tree)
924
- count_tree = jnp.where(do_grow, grow_count_tree, count_tree)
925
- resid_tree = jnp.where(do_prune, prune_resid_tree, resid_tree)
926
- count_tree = jnp.where(do_prune, prune_count_tree, count_tree)
927
-
928
- # update acceptance counts
929
- counts = counts.copy()
930
- counts['grow_prop_count'] += try_grow
931
- counts['grow_acc_count'] += do_grow
932
- counts['prune_prop_count'] += try_prune
933
- counts['prune_acc_count'] += do_prune
934
-
935
- # compute leaves posterior
936
- prec_lk = count_tree / sigma2
937
- var_post = lax.reciprocal(prec_lk + ntree) # = 1 / (prec_lk + prec_prior)
938
- mean_post = resid_tree / sigma2 * var_post # = mean_lk * prec_lk * var_post
939
-
940
- # sample leaves
941
- z = random.normal(key, mean_post.shape, mean_post.dtype)
942
- leaf_tree = mean_post + z * jnp.sqrt(var_post)
943
-
944
- # add new tree to function
945
- leaf_indices = jnp.where(do_grow, grow_leaf_indices, leaf_indices)
946
- leaf_indices = jnp.where(do_prune, prune_leaf_indices, leaf_indices)
947
- resid -= leaf_tree[leaf_indices]
1369
+ acc &= move_counts['left'] >= min_points_per_leaf
1370
+ acc &= move_counts['right'] >= min_points_per_leaf
1371
+
1372
+ # compute leaves posterior and sample leaves
1373
+ initial_leaf_tree = leaf_tree
1374
+ mean_post = resid_tree * prelf['mean_factor']
1375
+ leaf_tree = mean_post + prelf['centered_leaves']
1376
+
1377
+ # copy leaves around such that the leaf indices select the right leaf
1378
+ to_prune = acc ^ move['grow']
1379
+ leaf_tree = (leaf_tree
1380
+ .at[jnp.where(to_prune, left, leaf_tree.size)]
1381
+ .set(leaf_tree[node])
1382
+ .at[jnp.where(to_prune, right, leaf_tree.size)]
1383
+ .set(leaf_tree[node])
1384
+ )
948
1385
 
949
- # pack trees
950
- trees = {
951
- 'leaf_trees': leaf_tree,
952
- 'split_trees': split_tree,
953
- 'affluence_trees': affluence_tree,
954
- }
1386
+ # replace old tree with new tree in function values
1387
+ resid += (initial_leaf_tree - leaf_tree)[leaf_indices]
955
1388
 
956
- return resid, counts, trees
1389
+ return resid, leaf_tree, acc, to_prune, ratios
957
1390
 
958
- def sufficient_stat(resid, leaf_indices, tree_size, batch_size):
1391
+ def sum_resid(resid, leaf_indices, tree_size, batch_size):
959
1392
  """
960
- Compute the sufficient statistics for the likelihood ratio of a tree move.
1393
+ Sum the residuals in each leaf.
961
1394
 
962
1395
  Parameters
963
1396
  ----------
@@ -968,124 +1401,142 @@ def sufficient_stat(resid, leaf_indices, tree_size, batch_size):
968
1401
  tree_size : int
969
1402
  The size of the tree array (2 ** d).
970
1403
  batch_size : int, None
971
- The batch size for the aggregation. Batching increases numerical
1404
+ The data batch size for the aggregation. Batching increases numerical
972
1405
  accuracy and parallelism.
973
1406
 
974
1407
  Returns
975
1408
  -------
976
1409
  resid_tree : float array (2 ** d,)
977
1410
  The sum of the residuals at data points in each leaf.
978
- count_tree : int array (2 ** d,)
979
- The number of data points in each leaf.
980
1411
  """
981
1412
  if batch_size is None:
982
1413
  aggr_func = _aggregate_scatter
983
1414
  else:
984
- aggr_func = functools.partial(_aggregate_batched, batch_size=batch_size)
985
- resid_tree = aggr_func(resid, leaf_indices, tree_size, jnp.float32)
986
- count_tree = aggr_func(1, leaf_indices, tree_size, jnp.uint32)
987
- return resid_tree, count_tree
1415
+ aggr_func = functools.partial(_aggregate_batched_onetree, batch_size=batch_size)
1416
+ return aggr_func(resid, leaf_indices, tree_size, jnp.float32)
988
1417
 
989
- def _aggregate_scatter(values, indices, size, dtype):
1418
+ def _aggregate_batched_onetree(values, indices, size, dtype, batch_size):
1419
+ n, = indices.shape
1420
+ nbatches = n // batch_size + bool(n % batch_size)
1421
+ batch_indices = jnp.arange(n) % nbatches
990
1422
  return (jnp
991
- .zeros(size, dtype)
992
- .at[indices]
1423
+ .zeros((size, nbatches), dtype)
1424
+ .at[indices, batch_indices]
993
1425
  .add(values)
1426
+ .sum(axis=1)
994
1427
  )
995
1428
 
996
- def _aggregate_batched(values, indices, size, dtype, batch_size):
997
- nbatches = indices.size // batch_size + bool(indices.size % batch_size)
998
- batch_indices = jnp.arange(indices.size) // batch_size
999
- return (jnp
1000
- .zeros((nbatches, size), dtype)
1001
- .at[batch_indices, indices]
1002
- .add(values)
1003
- .sum(axis=0)
1004
- )
1005
-
1006
- def compute_p_prune_back(new_split_tree, new_affluence_tree):
1429
+ def compute_likelihood_ratio(total_resid, left_resid, right_resid, prelkv, prelk):
1007
1430
  """
1008
- Compute the probability of proposing a prune move after doing a grow move.
1431
+ Compute the likelihood ratio of a grow move.
1009
1432
 
1010
1433
  Parameters
1011
1434
  ----------
1012
- new_split_tree : int array (2 ** (d - 1),)
1013
- The decision boundaries of the tree, after the grow move.
1014
- new_affluence_tree : bool array (2 ** (d - 1),)
1015
- Which leaves have enough points to be grown, after the grow move.
1435
+ total_resid : float
1436
+ The sum of the residuals in the leaf to grow.
1437
+ left_resid, right_resid : float
1438
+ The sum of the residuals in the left/right child of the leaf to grow.
1439
+ prelkv, prelk : dict
1440
+ The pre-computed terms of the likelihood ratio, see
1441
+ `precompute_likelihood_terms`.
1016
1442
 
1017
1443
  Returns
1018
1444
  -------
1019
- p_prune : float
1020
- The probability of proposing a prune move after the grow move. This is
1021
- 0.5 if grow is possible again, and 1 if it isn't. It can't be 0 because
1022
- at least the node just grown can be pruned.
1445
+ ratio : float
1446
+ The likelihood ratio P(data | new tree) / P(data | old tree).
1023
1447
  """
1024
- _, grow_again_allowed = growable_leaves(new_split_tree, new_affluence_tree)
1025
- return jnp.where(grow_again_allowed, 0.5, 1)
1448
+ exp_term = prelk['exp_factor'] * (
1449
+ left_resid * left_resid / prelkv['sigma2_left'] +
1450
+ right_resid * right_resid / prelkv['sigma2_right'] -
1451
+ total_resid * total_resid / prelkv['sigma2_total']
1452
+ )
1453
+ return prelkv['sqrt_term'] + exp_term
1026
1454
 
1027
- def compute_likelihood_ratio(resid_tree, count_tree, sigma2, node, n_tree, min_points_per_leaf):
1455
+ def accept_moves_final_stage(bart, moves):
1028
1456
  """
1029
- Compute the likelihood ratio of a grow move.
1457
+ The final part of accepting the moves, in parallel across trees.
1030
1458
 
1031
1459
  Parameters
1032
1460
  ----------
1033
- resid_tree : float array (2 ** d,)
1034
- The sum of the residuals at data points in each leaf.
1035
- count_tree : int array (2 ** d,)
1036
- The number of data points in each leaf.
1037
- sigma2 : float
1038
- The noise variance.
1039
- node : int
1040
- The index of the leaf that has been grown.
1041
- n_tree : int
1042
- The number of trees in the forest.
1043
- min_points_per_leaf : int or None
1044
- The minimum number of data points in a leaf node.
1461
+ bart : dict
1462
+ A partially updated BART mcmc state.
1463
+ counts : dict
1464
+ The indicators of proposals and acceptances for grow and prune moves.
1465
+ moves : dict
1466
+ The proposed moves (see `sample_moves`) as updated by
1467
+ `accept_moves_sequential_stage`.
1045
1468
 
1046
1469
  Returns
1047
1470
  -------
1048
- ratio : float
1049
- The likelihood ratio P(data | new tree) / P(data | old tree).
1050
-
1051
- Notes
1052
- -----
1053
- The ratio is set to 0 if the grow move would create leaves with not enough
1054
- datapoints per leaf, although this is part of the prior rather than the
1055
- likelihood.
1471
+ bart : dict
1472
+ The fully updated BART mcmc state.
1056
1473
  """
1474
+ bart = bart.copy()
1475
+ bart['grow_acc_count'] = jnp.sum(moves['acc'] & moves['grow'])
1476
+ bart['prune_acc_count'] = jnp.sum(moves['acc'] & ~moves['grow'])
1477
+ bart['leaf_indices'] = apply_moves_to_leaf_indices(bart['leaf_indices'], moves)
1478
+ bart['split_trees'] = apply_moves_to_split_trees(bart['split_trees'], moves)
1479
+ return bart
1057
1480
 
1058
- left_child = node << 1
1059
- right_child = left_child + 1
1060
-
1061
- left_resid = resid_tree[left_child]
1062
- right_resid = resid_tree[right_child]
1063
- total_resid = left_resid + right_resid
1064
-
1065
- left_count = count_tree[left_child]
1066
- right_count = count_tree[right_child]
1067
- total_count = left_count + right_count
1068
-
1069
- sigma_mu2 = 1 / n_tree
1070
- sigma2_left = sigma2 + left_count * sigma_mu2
1071
- sigma2_right = sigma2 + right_count * sigma_mu2
1072
- sigma2_total = sigma2 + total_count * sigma_mu2
1481
+ @jax.vmap
1482
+ def apply_moves_to_leaf_indices(leaf_indices, moves):
1483
+ """
1484
+ Update the leaf indices to match the accepted move.
1073
1485
 
1074
- sqrt_term = sigma2 * sigma2_total / (sigma2_left * sigma2_right)
1486
+ Parameters
1487
+ ----------
1488
+ leaf_indices : int array (num_trees, n)
1489
+ The index of the leaf each datapoint falls into, if the grow move was
1490
+ accepted.
1491
+ moves : dict
1492
+ The proposed moves (see `sample_moves`), as updated by
1493
+ `accept_moves_sequential_stage`.
1075
1494
 
1076
- exp_term = sigma_mu2 / (2 * sigma2) * (
1077
- left_resid * left_resid / sigma2_left +
1078
- right_resid * right_resid / sigma2_right -
1079
- total_resid * total_resid / sigma2_total
1495
+ Returns
1496
+ -------
1497
+ leaf_indices : int array (num_trees, n)
1498
+ The updated leaf indices.
1499
+ """
1500
+ mask = ~jnp.array(1, leaf_indices.dtype) # ...1111111110
1501
+ is_child = (leaf_indices & mask) == moves['left']
1502
+ return jnp.where(
1503
+ is_child & moves['to_prune'],
1504
+ moves['node'].astype(leaf_indices.dtype),
1505
+ leaf_indices,
1080
1506
  )
1081
1507
 
1082
- ratio = jnp.sqrt(sqrt_term) * jnp.exp(exp_term)
1508
+ @jax.vmap
1509
+ def apply_moves_to_split_trees(split_trees, moves):
1510
+ """
1511
+ Update the split trees to match the accepted move.
1083
1512
 
1084
- if min_points_per_leaf is not None:
1085
- ratio = jnp.where(right_count >= min_points_per_leaf, ratio, 0)
1086
- ratio = jnp.where(left_count >= min_points_per_leaf, ratio, 0)
1513
+ Parameters
1514
+ ----------
1515
+ split_trees : int array (num_trees, 2 ** (d - 1))
1516
+ The cutpoints of the decision nodes in the initial trees.
1517
+ moves : dict
1518
+ The proposed moves (see `sample_moves`), as updated by
1519
+ `accept_moves_sequential_stage`.
1087
1520
 
1088
- return ratio
1521
+ Returns
1522
+ -------
1523
+ split_trees : int array (num_trees, 2 ** (d - 1))
1524
+ The updated split trees.
1525
+ """
1526
+ return (split_trees
1527
+ .at[jnp.where(
1528
+ moves['grow'],
1529
+ moves['node'],
1530
+ split_trees.size,
1531
+ )]
1532
+ .set(moves['grow_split'].astype(split_trees.dtype))
1533
+ .at[jnp.where(
1534
+ moves['to_prune'],
1535
+ moves['node'],
1536
+ split_trees.size,
1537
+ )]
1538
+ .set(0)
1539
+ )
1089
1540
 
1090
1541
  def sample_sigma(bart, key):
1091
1542
  """
@@ -1107,7 +1558,7 @@ def sample_sigma(bart, key):
1107
1558
 
1108
1559
  resid = bart['resid']
1109
1560
  alpha = bart['sigma2_alpha'] + resid.size / 2
1110
- norm2 = jnp.dot(resid, resid, preferred_element_type=bart['sigma2_beta'].dtype)
1561
+ norm2 = jnp.dot(resid, resid, preferred_element_type=bart['opt']['large_float'])
1111
1562
  beta = bart['sigma2_beta'] + norm2 / 2
1112
1563
 
1113
1564
  sample = random.gamma(key, alpha)