bartz 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/mcmcstep.py CHANGED
@@ -1,6 +1,6 @@
1
1
  # bartz/src/bartz/mcmcstep.py
2
2
  #
3
- # Copyright (c) 2024, Giacomo Petrillo
3
+ # Copyright (c) 2024-2025, Giacomo Petrillo
4
4
  #
5
5
  # This file is part of bartz.
6
6
  #
@@ -37,14 +37,14 @@ import functools
37
37
  import math
38
38
 
39
39
  import jax
40
- from jax import random
40
+ from jax import lax, random
41
41
  from jax import numpy as jnp
42
- from jax import lax
43
42
 
44
- from . import jaxext
45
- from . import grove
43
+ from . import grove, jaxext
46
44
 
47
- def init(*,
45
+
46
+ def init(
47
+ *,
48
48
  X,
49
49
  y,
50
50
  max_split,
@@ -52,13 +52,14 @@ def init(*,
52
52
  p_nonterminal,
53
53
  sigma2_alpha,
54
54
  sigma2_beta,
55
+ error_scale=None,
55
56
  small_float=jnp.float32,
56
57
  large_float=jnp.float32,
57
58
  min_points_per_leaf=None,
58
59
  resid_batch_size='auto',
59
60
  count_batch_size='auto',
60
61
  save_ratios=False,
61
- ):
62
+ ):
62
63
  """
63
64
  Make a BART posterior sampling MCMC initial state.
64
65
 
@@ -76,9 +77,12 @@ def init(*,
76
77
  The probability of a nonterminal node at each depth. The maximum depth
77
78
  of trees is fixed by the length of this array.
78
79
  sigma2_alpha : float
79
- The shape parameter of the inverse gamma prior on the noise variance.
80
+ The shape parameter of the inverse gamma prior on the error variance.
80
81
  sigma2_beta : float
81
- The scale parameter of the inverse gamma prior on the noise variance.
82
+ The scale parameter of the inverse gamma prior on the error variance.
83
+ error_scale : float array (n,), optional
84
+ Each error is scaled by the corresponding factor in `error_scale`, so
85
+ the error variance for ``y[i]`` is ``sigma2 * error_scale[i] ** 2``.
82
86
  small_float : dtype, default float32
83
87
  The dtype for large arrays used in the algorithm.
84
88
  large_float : dtype, default float32
@@ -110,6 +114,8 @@ def init(*,
110
114
  roundoff.
111
115
  'sigma2' : large_float
112
116
  The noise variance.
117
+ 'prec_scale' : large_float array (n,) or None
118
+ The scale on the error precision, i.e., ``1 / error_scale ** 2``.
113
119
  'grow_prop_count', 'prune_prop_count' : int
114
120
  The number of grow/prune proposals made during one full MCMC cycle.
115
121
  'grow_acc_count', 'prune_acc_count' : int
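A minimal sketch (not part of the diff) of how the new `error_scale` argument and the stored 'prec_scale' field relate, with illustrative values:

    from jax import numpy as jnp

    error_scale = jnp.array([1.0, 2.0, 0.5])   # per-datapoint error scales
    sigma2 = 1.5                                # current error variance
    error_var = sigma2 * error_scale**2         # variance of each error term
    prec_scale = 1 / error_scale**2             # what init() stores as 'prec_scale'
    # error_var == sigma2 / prec_scale, so prec_scale weights each datapoint's
    # contribution to the sufficient statistics and to the sigma2 update.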
@@ -169,16 +175,27 @@ def init(*,
169
175
  small_float = jnp.dtype(small_float)
170
176
  large_float = jnp.dtype(large_float)
171
177
  y = jnp.asarray(y, small_float)
172
- resid_batch_size, count_batch_size = _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, 2 ** max_depth * num_trees)
178
+ resid_batch_size, count_batch_size = _choose_suffstat_batch_size(
179
+ resid_batch_size, count_batch_size, y, 2**max_depth * num_trees
180
+ )
173
181
  sigma2 = jnp.array(sigma2_beta / sigma2_alpha, large_float)
174
- sigma2 = jnp.where(jnp.isfinite(sigma2) & (sigma2 > 0), sigma2, 1)
182
+ sigma2 = jnp.where(
183
+ jnp.isfinite(sigma2) & (sigma2 > 0), sigma2, 1
184
+ ) # TODO: I don't like this error check, these functions should be low-level and just do the thing. Why is it here?
175
185
 
176
186
  bart = dict(
177
187
  leaf_trees=make_forest(max_depth, small_float),
178
- var_trees=make_forest(max_depth - 1, jaxext.minimal_unsigned_dtype(X.shape[0] - 1)),
188
+ var_trees=make_forest(
189
+ max_depth - 1, jaxext.minimal_unsigned_dtype(X.shape[0] - 1)
190
+ ),
179
191
  split_trees=make_forest(max_depth - 1, max_split.dtype),
180
192
  resid=jnp.asarray(y, large_float),
181
193
  sigma2=sigma2,
194
+ prec_scale=(
195
+ None
196
+ if error_scale is None
197
+ else lax.reciprocal(jnp.square(jnp.asarray(error_scale, large_float)))
198
+ ),
182
199
  grow_prop_count=jnp.zeros((), int),
183
200
  grow_acc_count=jnp.zeros((), int),
184
201
  prune_prop_count=jnp.zeros((), int),
@@ -190,14 +207,18 @@ def init(*,
190
207
  max_split=jnp.asarray(max_split),
191
208
  y=y,
192
209
  X=jnp.asarray(X),
193
- leaf_indices=jnp.ones((num_trees, y.size), jaxext.minimal_unsigned_dtype(2 ** max_depth - 1)),
210
+ leaf_indices=jnp.ones(
211
+ (num_trees, y.size), jaxext.minimal_unsigned_dtype(2**max_depth - 1)
212
+ ),
194
213
  min_points_per_leaf=(
195
- None if min_points_per_leaf is None else
196
- jnp.asarray(min_points_per_leaf)
214
+ None if min_points_per_leaf is None else jnp.asarray(min_points_per_leaf)
197
215
  ),
198
216
  affluence_trees=(
199
- None if min_points_per_leaf is None else
200
- make_forest(max_depth - 1, bool).at[:, 1].set(y.size >= 2 * min_points_per_leaf)
217
+ None
218
+ if min_points_per_leaf is None
219
+ else make_forest(max_depth - 1, bool)
220
+ .at[:, 1]
221
+ .set(y.size >= 2 * min_points_per_leaf)
201
222
  ),
202
223
  opt=jaxext.LeafDict(
203
224
  small_float=small_float,
@@ -216,8 +237,8 @@ def init(*,
216
237
 
217
238
  return bart
218
239
 
219
- def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, forest_size):
220
240
 
241
+ def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, forest_size):
221
242
  @functools.cache
222
243
  def get_platform():
223
244
  try:
@@ -233,9 +254,9 @@ def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, forest_si
233
254
  platform = get_platform()
234
255
  n = max(1, y.size)
235
256
  if platform == 'cpu':
236
- resid_batch_size = 2 ** int(round(math.log2(n / 6))) # n/6
257
+ resid_batch_size = 2 ** int(round(math.log2(n / 6))) # n/6
237
258
  elif platform == 'gpu':
238
- resid_batch_size = 2 ** int(round((1 + math.log2(n)) / 3)) # n^1/3
259
+ resid_batch_size = 2 ** int(round((1 + math.log2(n)) / 3)) # n^1/3
239
260
  resid_batch_size = max(1, resid_batch_size)
240
261
 
241
262
  if count_batch_size == 'auto':
@@ -244,9 +265,9 @@ def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, forest_si
244
265
  count_batch_size = None
245
266
  elif platform == 'gpu':
246
267
  n = max(1, y.size)
247
- count_batch_size = 2 ** int(round(math.log2(n) / 2 - 2)) # n^1/2
248
- # /4 is good on V100, /2 on L4/T4, still haven't tried A100
249
- max_memory = 2 ** 29
268
+ count_batch_size = 2 ** int(round(math.log2(n) / 2 - 2)) # n^1/2
269
+ # /4 is good on V100, /2 on L4/T4, still haven't tried A100
270
+ max_memory = 2**29
250
271
  itemsize = 4
251
272
  min_batch_size = int(math.ceil(forest_size * itemsize * n / max_memory))
252
273
  count_batch_size = max(count_batch_size, min_batch_size)
@@ -254,16 +275,17 @@ def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, forest_si
254
275
 
255
276
  return resid_batch_size, count_batch_size
256
277
 
257
- def step(bart, key):
278
+
279
+ def step(key, bart):
258
280
  """
259
281
  Perform one full MCMC step on a BART state.
260
282
 
261
283
  Parameters
262
284
  ----------
263
- bart : dict
264
- A BART mcmc state, as created by `init`.
265
285
  key : jax.dtypes.prng_key array
266
286
  A jax random key.
287
+ bart : dict
288
+ A BART mcmc state, as created by `init`.
267
289
 
268
290
  Returns
269
291
  -------
@@ -271,19 +293,20 @@ def step(bart, key):
271
293
  The new BART mcmc state.
272
294
  """
273
295
  key, subkey = random.split(key)
274
- bart = sample_trees(bart, subkey)
275
- return sample_sigma(bart, key)
296
+ bart = sample_trees(subkey, bart)
297
+ return sample_sigma(key, bart)
298
+
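This release also swaps the argument order of the sampling functions from `(bart, key)` to `(key, bart)`; a minimal usage sketch of the new convention (assuming a state `bart` already built with `init` and this module imported as `mcmcstep`):

    import jax

    key = jax.random.PRNGKey(0)
    for subkey in jax.random.split(key, 1000):
        bart = mcmcstep.step(subkey, bart)   # one full cycle: tree moves, then sigma2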
276
299
 
277
- def sample_trees(bart, key):
300
+ def sample_trees(key, bart):
278
301
  """
279
302
  Forest sampling step of BART MCMC.
280
303
 
281
304
  Parameters
282
305
  ----------
283
- bart : dict
284
- A BART mcmc state, as created by `init`.
285
306
  key : jax.dtypes.prng_key array
286
307
  A jax random key.
308
+ bart : dict
309
+ A BART mcmc state, as created by `init`.
287
310
 
288
311
  Returns
289
312
  -------
@@ -295,19 +318,20 @@ def sample_trees(bart, key):
295
318
  This function zeroes the proposal counters.
296
319
  """
297
320
  key, subkey = random.split(key)
298
- moves = sample_moves(bart, subkey)
299
- return accept_moves_and_sample_leaves(bart, moves, key)
321
+ moves = sample_moves(subkey, bart)
322
+ return accept_moves_and_sample_leaves(key, bart, moves)
300
323
 
301
- def sample_moves(bart, key):
324
+
325
+ def sample_moves(key, bart):
302
326
  """
303
327
  Propose moves for all the trees.
304
328
 
305
329
  Parameters
306
330
  ----------
307
- bart : dict
308
- BART mcmc state.
309
331
  key : jax.dtypes.prng_key array
310
332
  A jax random key.
333
+ bart : dict
334
+ BART mcmc state.
311
335
 
312
336
  Returns
313
337
  -------
@@ -343,14 +367,22 @@ def sample_moves(bart, key):
343
367
  key, subkey = key[0], key[1:]
344
368
 
345
369
  # compute moves
346
- grow_moves, prune_moves = _sample_moves_vmap_trees(bart['var_trees'], bart['split_trees'], bart['affluence_trees'], bart['max_split'], bart['p_nonterminal'], bart['p_propose_grow'], subkey)
370
+ grow_moves, prune_moves = _sample_moves_vmap_trees(
371
+ subkey,
372
+ bart['var_trees'],
373
+ bart['split_trees'],
374
+ bart['affluence_trees'],
375
+ bart['max_split'],
376
+ bart['p_nonterminal'],
377
+ bart['p_propose_grow'],
378
+ )
347
379
 
348
380
  u, logu = random.uniform(key, (2, ntree), bart['opt']['large_float'])
349
381
 
350
382
  # choose between grow or prune
351
383
  grow_allowed = grow_moves['num_growable'].astype(bool)
352
384
  p_grow = jnp.where(grow_allowed & prune_moves['allowed'], 0.5, grow_allowed)
353
- grow = u < p_grow # use < instead of <= because u is in [0, 1)
385
+ grow = u < p_grow # use < instead of <= because u is in [0, 1)
354
386
 
355
387
  # compute children indices
356
388
  node = jnp.where(grow, grow_moves['node'], prune_moves['node'])
@@ -364,22 +396,28 @@ def sample_moves(bart, key):
364
396
  node=node,
365
397
  left=left,
366
398
  right=right,
367
- partial_ratio=jnp.where(grow, grow_moves['partial_ratio'], prune_moves['partial_ratio']),
399
+ partial_ratio=jnp.where(
400
+ grow, grow_moves['partial_ratio'], prune_moves['partial_ratio']
401
+ ),
368
402
  grow_var=grow_moves['var'],
369
403
  grow_split=grow_moves['split'],
370
404
  var_trees=grow_moves['var_tree'],
371
405
  logu=jnp.log1p(-logu),
372
406
  )
373
407
 
374
- @functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, 0, None, None, None, 0))
408
+
409
+ @functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, 0, 0, None, None, None))
375
410
  def _sample_moves_vmap_trees(*args):
376
- args, key = args[:-1], args[-1]
411
+ key, args = args[0], args[1:]
377
412
  key, key1 = random.split(key)
378
- grow = grow_move(*args, key)
379
- prune = prune_move(*args, key1)
413
+ grow = grow_move(key, *args)
414
+ prune = prune_move(key1, *args)
380
415
  return grow, prune
381
416
 
382
- def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow, key):
417
+
418
+ def grow_move(
419
+ key, var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow
420
+ ):
383
421
  """
384
422
  Tree structure grow move proposal of BART MCMC.
385
423
 
@@ -426,14 +464,18 @@ def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_
426
464
 
427
465
  key, key1, key2 = random.split(key, 3)
428
466
 
429
- leaf_to_grow, num_growable, prob_choose, num_prunable = choose_leaf(split_tree, affluence_tree, p_propose_grow, key)
467
+ leaf_to_grow, num_growable, prob_choose, num_prunable = choose_leaf(
468
+ key, split_tree, affluence_tree, p_propose_grow
469
+ )
430
470
 
431
- var = choose_variable(var_tree, split_tree, max_split, leaf_to_grow, key1)
471
+ var = choose_variable(key1, var_tree, split_tree, max_split, leaf_to_grow)
432
472
  var_tree = var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype))
433
473
 
434
- split = choose_split(var_tree, split_tree, max_split, leaf_to_grow, key2)
474
+ split = choose_split(key2, var_tree, split_tree, max_split, leaf_to_grow)
435
475
 
436
- ratio = compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow)
476
+ ratio = compute_partial_ratio(
477
+ prob_choose, num_prunable, p_nonterminal, leaf_to_grow
478
+ )
437
479
 
438
480
  return dict(
439
481
  num_growable=num_growable,
@@ -444,7 +486,8 @@ def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_
444
486
  var_tree=var_tree,
445
487
  )
446
488
 
447
- def choose_leaf(split_tree, affluence_tree, p_propose_grow, key):
489
+
490
+ def choose_leaf(key, split_tree, affluence_tree, p_propose_grow):
448
491
  """
449
492
  Choose a leaf node to grow in a tree.
450
493
 
@@ -482,6 +525,7 @@ def choose_leaf(split_tree, affluence_tree, p_propose_grow, key):
482
525
  num_prunable = jnp.count_nonzero(is_parent)
483
526
  return leaf_to_grow, num_growable, prob_choose, num_prunable
484
527
 
528
+
485
529
  def growable_leaves(split_tree, affluence_tree):
486
530
  """
487
531
  Return a mask indicating the leaf nodes that can be proposed for growth.
@@ -505,6 +549,7 @@ def growable_leaves(split_tree, affluence_tree):
505
549
  is_growable &= affluence_tree
506
550
  return is_growable
507
551
 
552
+
508
553
  def categorical(key, distr):
509
554
  """
510
555
  Return a random integer from an arbitrary distribution.
@@ -521,12 +566,15 @@ def categorical(key, distr):
521
566
  u : int
522
567
  A random integer in the range ``[0, n)``. If all probabilities are zero,
523
568
  return ``n``.
569
+ norm : float
570
+ The sum of `distr`.
524
571
  """
525
572
  ecdf = jnp.cumsum(distr)
526
573
  u = random.uniform(key, (), ecdf.dtype, 0, ecdf[-1])
527
574
  return jnp.searchsorted(ecdf, u, 'right'), ecdf[-1]
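A small sketch (illustrative weights, not part of the diff) of the inverse-CDF trick used by `categorical`: searching the cumulative sum with a uniform draw in `[0, sum)` picks index `i` with probability proportional to `distr[i]`, and an all-zero `distr` falls through to `n`.

    from jax import numpy as jnp, random

    distr = jnp.array([0.2, 0.0, 0.8])              # unnormalized weights
    ecdf = jnp.cumsum(distr)                         # [0.2, 0.2, 1.0]
    u = random.uniform(random.PRNGKey(0), (), ecdf.dtype, 0, ecdf[-1])
    i = jnp.searchsorted(ecdf, u, 'right')           # 0 w.p. 0.2, 2 w.p. 0.8, never 1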
528
575
 
529
- def choose_variable(var_tree, split_tree, max_split, leaf_index, key):
576
+
577
+ def choose_variable(key, var_tree, split_tree, max_split, leaf_index):
530
578
  """
531
579
  Choose a variable to split on for a new non-terminal node.
532
580
 
@@ -556,6 +604,7 @@ def choose_variable(var_tree, split_tree, max_split, leaf_index, key):
556
604
  var_to_ignore = fully_used_variables(var_tree, split_tree, max_split, leaf_index)
557
605
  return randint_exclude(key, max_split.size, var_to_ignore)
558
606
 
607
+
559
608
  def fully_used_variables(var_tree, split_tree, max_split, leaf_index):
560
609
  """
561
610
  Return a list of variables that have an empty split range at a given node.
@@ -586,6 +635,7 @@ def fully_used_variables(var_tree, split_tree, max_split, leaf_index):
586
635
  num_split = r - l
587
636
  return jnp.where(num_split == 0, var_to_ignore, max_split.size)
588
637
 
638
+
589
639
  def ancestor_variables(var_tree, max_split, node_index):
590
640
  """
591
641
  Return the list of variables in the ancestors of a node.
@@ -606,8 +656,11 @@ def ancestor_variables(var_tree, max_split, node_index):
606
656
  the parent. Unused spots are filled with `p`.
607
657
  """
608
658
  max_num_ancestors = grove.tree_depth(var_tree) - 1
609
- ancestor_vars = jnp.zeros(max_num_ancestors, jaxext.minimal_unsigned_dtype(max_split.size))
659
+ ancestor_vars = jnp.zeros(
660
+ max_num_ancestors, jaxext.minimal_unsigned_dtype(max_split.size)
661
+ )
610
662
  carry = ancestor_vars.size - 1, node_index, ancestor_vars
663
+
611
664
  def loop(carry, _):
612
665
  i, index, ancestor_vars = carry
613
666
  index >>= 1
@@ -615,9 +668,11 @@ def ancestor_variables(var_tree, max_split, node_index):
615
668
  var = jnp.where(index, var, max_split.size)
616
669
  ancestor_vars = ancestor_vars.at[i].set(var)
617
670
  return (i - 1, index, ancestor_vars), None
671
+
618
672
  (_, _, ancestor_vars), _ = lax.scan(loop, carry, None, ancestor_vars.size)
619
673
  return ancestor_vars
620
674
 
675
+
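The tree arrays use heap indexing (root at index 1, children of node `i` at `2*i` and `2*i + 1`), so the ancestor walk above just halves the index at each step; a plain-Python sketch of the same traversal:

    node = 13                  # example node index at depth 3
    ancestors = []
    while node > 1:
        node >>= 1             # the parent of node i is i >> 1
        ancestors.append(node)
    # ancestors == [6, 3, 1], from the immediate parent up to the root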
621
676
  def split_range(var_tree, split_tree, max_split, node_index, ref_var):
622
677
  """
623
678
  Return the range of allowed splits for a variable at a given node.
@@ -641,8 +696,11 @@ def split_range(var_tree, split_tree, max_split, node_index, ref_var):
641
696
  The range of allowed splits is [l, r).
642
697
  """
643
698
  max_num_ancestors = grove.tree_depth(var_tree) - 1
644
- initial_r = 1 + max_split.at[ref_var].get(mode='fill', fill_value=0).astype(jnp.int32)
699
+ initial_r = 1 + max_split.at[ref_var].get(mode='fill', fill_value=0).astype(
700
+ jnp.int32
701
+ )
645
702
  carry = 0, initial_r, node_index
703
+
646
704
  def loop(carry, _):
647
705
  l, r, index = carry
648
706
  right_child = (index & 1).astype(bool)
@@ -652,9 +710,11 @@ def split_range(var_tree, split_tree, max_split, node_index, ref_var):
652
710
  l = jnp.where(cond & right_child, jnp.maximum(l, split), l)
653
711
  r = jnp.where(cond & ~right_child, jnp.minimum(r, split), r)
654
712
  return (l, r, index), None
713
+
655
714
  (l, r, _), _ = lax.scan(loop, carry, None, max_num_ancestors)
656
715
  return l + 1, r
657
716
 
717
+
658
718
  def randint_exclude(key, sup, exclude):
659
719
  """
660
720
  Return a random integer in a range, excluding some values.
@@ -679,12 +739,15 @@ def randint_exclude(key, sup, exclude):
679
739
  exclude = jnp.unique(exclude, size=exclude.size, fill_value=sup)
680
740
  num_allowed = sup - jnp.count_nonzero(exclude < sup)
681
741
  u = random.randint(key, (), 0, num_allowed)
742
+
682
743
  def loop(u, i):
683
744
  return jnp.where(i <= u, u + 1, u), None
745
+
684
746
  u, _ = lax.scan(loop, u, exclude)
685
747
  return u
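A sketch of the shifting loop in `randint_exclude` with illustrative values: drawing from the reduced range and bumping the draw past every excluded value that is `<=` it gives a uniform draw over the allowed values, provided `exclude` is sorted (which the `jnp.unique` call guarantees).

    sup = 6
    exclude = [1, 4]           # sorted values to skip
    u = 3                      # uniform draw in [0, 6 - 2) = [0, 4)
    for i in exclude:
        if i <= u:
            u += 1
    # u == 5; the map {0, 1, 2, 3} -> {0, 2, 3, 5} hits exactly the allowed values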
686
748
 
687
- def choose_split(var_tree, split_tree, max_split, leaf_index, key):
749
+
750
+ def choose_split(key, var_tree, split_tree, max_split, leaf_index):
688
751
  """
689
752
  Choose a split point for a new non-terminal node.
690
753
 
@@ -711,6 +774,7 @@ def choose_split(var_tree, split_tree, max_split, leaf_index, key):
711
774
  l, r = split_range(var_tree, split_tree, max_split, leaf_index, var)
712
775
  return random.randint(key, (), l, r)
713
776
 
777
+
714
778
  def compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow):
715
779
  """
716
780
  Compute the product of the transition and prior ratios of a grow move.
@@ -742,9 +806,9 @@ def compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow
742
806
  # computed in the acceptance phase
743
807
 
744
808
  prune_allowed = leaf_to_grow != 1
745
- # prune allowed <---> the initial tree is not a root
746
- # leaf to grow is root --> the tree can only be a root
747
- # tree is a root --> the only leaf I can grow is root
809
+ # prune allowed <---> the initial tree is not a root
810
+ # leaf to grow is root --> the tree can only be a root
811
+ # tree is a root --> the only leaf I can grow is root
748
812
 
749
813
  p_grow = jnp.where(prune_allowed, 0.5, 1)
750
814
 
@@ -757,7 +821,10 @@ def compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow
757
821
 
758
822
  return tree_ratio / inv_trans_ratio
759
823
 
760
- def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow, key):
824
+
825
+ def prune_move(
826
+ key, var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow
827
+ ):
761
828
  """
762
829
  Tree structure prune move proposal of BART MCMC.
763
830
 
@@ -792,18 +859,23 @@ def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p
792
859
  the likelihood ratio and the probability of proposing the prune
793
860
  move. This ratio is inverted.
794
861
  """
795
- node_to_prune, num_prunable, prob_choose = choose_leaf_parent(split_tree, affluence_tree, p_propose_grow, key)
796
- allowed = split_tree[1].astype(bool) # allowed iff the tree is not a root
862
+ node_to_prune, num_prunable, prob_choose = choose_leaf_parent(
863
+ key, split_tree, affluence_tree, p_propose_grow
864
+ )
865
+ allowed = split_tree[1].astype(bool) # allowed iff the tree is not a root
797
866
 
798
- ratio = compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, node_to_prune)
867
+ ratio = compute_partial_ratio(
868
+ prob_choose, num_prunable, p_nonterminal, node_to_prune
869
+ )
799
870
 
800
871
  return dict(
801
872
  allowed=allowed,
802
873
  node=node_to_prune,
803
- partial_ratio=ratio, # it is inverted in accept_move_and_sample_leaves
874
+ partial_ratio=ratio, # it is inverted in accept_move_and_sample_leaves
804
875
  )
805
876
 
806
- def choose_leaf_parent(split_tree, affluence_tree, p_propose_grow, key):
877
+
878
+ def choose_leaf_parent(key, split_tree, affluence_tree, p_propose_grow):
807
879
  """
808
880
  Pick a non-terminal node with leaf children to prune in a tree.
809
881
 
@@ -835,8 +907,7 @@ def choose_leaf_parent(split_tree, affluence_tree, p_propose_grow, key):
835
907
 
836
908
  split_tree = split_tree.at[node_to_prune].set(0)
837
909
  affluence_tree = (
838
- None if affluence_tree is None else
839
- affluence_tree.at[node_to_prune].set(True)
910
+ None if affluence_tree is None else affluence_tree.at[node_to_prune].set(True)
840
911
  )
841
912
  is_growable_leaf = growable_leaves(split_tree, affluence_tree)
842
913
  prob_choose = p_propose_grow[node_to_prune]
@@ -844,6 +915,7 @@ def choose_leaf_parent(split_tree, affluence_tree, p_propose_grow, key):
844
915
 
845
916
  return node_to_prune, num_prunable, prob_choose
846
917
 
918
+
847
919
  def randint_masked(key, mask):
848
920
  """
849
921
  Return a random integer in a range, including only some values.
@@ -865,40 +937,46 @@ def randint_masked(key, mask):
865
937
  u = random.randint(key, (), 0, ecdf[-1])
866
938
  return jnp.searchsorted(ecdf, u, 'right')
867
939
 
868
- def accept_moves_and_sample_leaves(bart, moves, key):
940
+
941
+ def accept_moves_and_sample_leaves(key, bart, moves):
869
942
  """
870
943
  Accept or reject the proposed moves and sample the new leaf values.
871
944
 
872
945
  Parameters
873
946
  ----------
947
+ key : jax.dtypes.prng_key array
948
+ A jax random key.
874
949
  bart : dict
875
950
  A BART mcmc state.
876
951
  moves : dict
877
952
  The proposed moves, see `sample_moves`.
878
- key : jax.dtypes.prng_key array
879
- A jax random key.
880
953
 
881
954
  Returns
882
955
  -------
883
956
  bart : dict
884
957
  The new BART mcmc state.
885
958
  """
886
- bart, moves, count_trees, move_counts, prelkv, prelk, prelf = accept_moves_parallel_stage(bart, moves, key)
887
- bart, moves = accept_moves_sequential_stage(bart, count_trees, moves, move_counts, prelkv, prelk, prelf)
959
+ bart, moves, prec_trees, move_counts, move_precs, prelkv, prelk, prelf = (
960
+ accept_moves_parallel_stage(key, bart, moves)
961
+ )
962
+ bart, moves = accept_moves_sequential_stage(
963
+ bart, prec_trees, moves, move_counts, move_precs, prelkv, prelk, prelf
964
+ )
888
965
  return accept_moves_final_stage(bart, moves)
889
966
 
890
- def accept_moves_parallel_stage(bart, moves, key):
967
+
968
+ def accept_moves_parallel_stage(key, bart, moves):
891
969
  """
892
970
  Pre-computes quantities used to accept moves, in parallel across trees.
893
971
 
894
972
  Parameters
895
973
  ----------
974
+ key : jax.dtypes.prng_key array
975
+ A jax random key.
896
976
  bart : dict
897
977
  A BART mcmc state.
898
978
  moves : dict
899
979
  The proposed moves, see `sample_moves`.
900
- key : jax.dtypes.prng_key array
901
- A jax random key.
902
980
 
903
981
  Returns
904
982
  -------
@@ -907,11 +985,14 @@ def accept_moves_parallel_stage(bart, moves, key):
907
985
  moves : dict
908
986
  The proposed moves, with the field 'partial_ratio' replaced
909
987
  by 'log_trans_prior_ratio'.
910
- count_trees : array (num_trees, 2 ** d)
911
- The number of points in each potential or actual leaf node.
988
+ prec_trees : float array (num_trees, 2 ** d)
989
+ The likelihood precision scale in each potential or actual leaf node. If
990
+ there is no precision scale, this is the number of points in each leaf.
912
991
  move_counts : dict
913
992
  The counts of the number of points in the the nodes modified by the
914
993
  moves.
994
+ move_precs : dict
995
+ The likelihood precision scale in each node modified by the moves.
915
996
  prelkv, prelk, prelf : dict
916
997
  Dictionary with pre-computed terms of the likelihood ratios and leaf
917
998
  samples.
@@ -924,20 +1005,35 @@ def accept_moves_parallel_stage(bart, moves, key):
924
1005
  bart['leaf_trees'] = adapt_leaf_trees_to_grow_indices(bart['leaf_trees'], moves)
925
1006
 
926
1007
  # count number of datapoints per leaf
927
- count_trees, move_counts = compute_count_trees(bart['leaf_indices'], moves, bart['opt']['count_batch_size'])
1008
+ count_trees, move_counts = compute_count_trees(
1009
+ bart['leaf_indices'], moves, bart['opt']['count_batch_size']
1010
+ )
928
1011
  if bart['opt']['require_min_points']:
929
- count_half_trees = count_trees[:, :bart['var_trees'].shape[1]]
1012
+ count_half_trees = count_trees[:, : bart['var_trees'].shape[1]]
930
1013
  bart['affluence_trees'] = count_half_trees >= 2 * bart['min_points_per_leaf']
931
1014
 
1015
+ # count number of datapoints per leaf, weighted by error precision scale
1016
+ if bart['prec_scale'] is None:
1017
+ prec_trees = count_trees
1018
+ move_precs = move_counts
1019
+ else:
1020
+ prec_trees, move_precs = compute_prec_trees(
1021
+ bart['prec_scale'],
1022
+ bart['leaf_indices'],
1023
+ moves,
1024
+ bart['opt']['count_batch_size'],
1025
+ )
1026
+
932
1027
  # compute some missing information about moves
933
1028
  moves = complete_ratio(moves, move_counts, bart['min_points_per_leaf'])
934
1029
  bart['grow_prop_count'] = jnp.sum(moves['grow'])
935
1030
  bart['prune_prop_count'] = jnp.sum(moves['allowed'] & ~moves['grow'])
936
1031
 
937
- prelkv, prelk = precompute_likelihood_terms(count_trees, bart['sigma2'], move_counts)
938
- prelf = precompute_leaf_terms(count_trees, bart['sigma2'], key)
1032
+ prelkv, prelk = precompute_likelihood_terms(bart['sigma2'], move_precs)
1033
+ prelf = precompute_leaf_terms(key, prec_trees, bart['sigma2'])
1034
+
1035
+ return bart, moves, prec_trees, move_counts, move_precs, prelkv, prelk, prelf
939
1036
 
940
- return bart, moves, count_trees, move_counts, prelkv, prelk, prelf
941
1037
 
942
1038
  @functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, None))
943
1039
  def apply_grow_to_indices(moves, leaf_indices, X):
@@ -968,15 +1064,16 @@ def apply_grow_to_indices(moves, leaf_indices, X):
968
1064
  leaf_indices,
969
1065
  )
970
1066
 
1067
+
971
1068
  def compute_count_trees(leaf_indices, moves, batch_size):
972
1069
  """
973
1070
  Count the number of datapoints in each leaf.
974
1071
 
975
1072
  Parameters
976
1073
  ----------
977
- grow_leaf_indices : int array (num_trees, n)
978
- The index of the leaf each datapoint falls into, if the grow move is
979
- accepted.
1074
+ leaf_indices : int array (num_trees, n)
1075
+ The index of the leaf each datapoint falls into, with the deeper version
1076
+ of the tree (post-GROW, pre-PRUNE).
980
1077
  moves : dict
981
1078
  The proposed moves, see `sample_moves`.
982
1079
  batch_size : int or None
@@ -987,9 +1084,8 @@ def compute_count_trees(leaf_indices, moves, batch_size):
987
1084
  count_trees : int array (num_trees, 2 ** (d - 1))
988
1085
  The number of points in each potential or actual leaf node.
989
1086
  counts : dict
990
- The counts of the number of points in the the nodes modified by the
991
- moves, organized as two dictionaries 'grow' and 'prune', with subfields
992
- 'left', 'right', and 'total'.
1087
+ The counts of the number of points in the leaves grown or pruned by the
1088
+ moves, under keys 'left', 'right', and 'total' (left + right).
993
1089
  """
994
1090
 
995
1091
  ntree, tree_size = moves['var_trees'].shape
@@ -1009,6 +1105,7 @@ def compute_count_trees(leaf_indices, moves, batch_size):
1009
1105
 
1010
1106
  return count_trees, counts
1011
1107
 
1108
+
1012
1109
  def count_datapoints_per_leaf(leaf_indices, tree_size, batch_size):
1013
1110
  """
1014
1111
  Count the number of datapoints in each leaf.
@@ -1032,40 +1129,129 @@ def count_datapoints_per_leaf(leaf_indices, tree_size, batch_size):
1032
1129
  else:
1033
1130
  return _count_vec(leaf_indices, tree_size, batch_size)
1034
1131
 
1132
+
1035
1133
  def _count_scan(leaf_indices, tree_size):
1036
1134
  def loop(_, leaf_indices):
1037
1135
  return None, _aggregate_scatter(1, leaf_indices, tree_size, jnp.uint32)
1136
+
1038
1137
  _, count_trees = lax.scan(loop, None, leaf_indices)
1039
1138
  return count_trees
1040
1139
 
1140
+
1041
1141
  def _aggregate_scatter(values, indices, size, dtype):
1042
- return (jnp
1043
- .zeros(size, dtype)
1044
- .at[indices]
1045
- .add(values)
1046
- )
1142
+ return jnp.zeros(size, dtype).at[indices].add(values)
1143
+
1047
1144
 
1048
1145
  def _count_vec(leaf_indices, tree_size, batch_size):
1049
- return _aggregate_batched_alltrees(1, leaf_indices, tree_size, jnp.uint32, batch_size)
1050
- # uint16 is super-slow on gpu, don't use it even if n < 2^16
1146
+ return _aggregate_batched_alltrees(
1147
+ 1, leaf_indices, tree_size, jnp.uint32, batch_size
1148
+ )
1149
+ # uint16 is super-slow on gpu, don't use it even if n < 2^16
1150
+
1051
1151
 
1052
1152
  def _aggregate_batched_alltrees(values, indices, size, dtype, batch_size):
1053
1153
  ntree, n = indices.shape
1054
1154
  tree_indices = jnp.arange(ntree)
1055
1155
  nbatches = n // batch_size + bool(n % batch_size)
1056
1156
  batch_indices = jnp.arange(n) % nbatches
1057
- return (jnp
1058
- .zeros((ntree, size, nbatches), dtype)
1157
+ return (
1158
+ jnp.zeros((ntree, size, nbatches), dtype)
1059
1159
  .at[tree_indices[:, None], indices, batch_indices]
1060
1160
  .add(values)
1061
1161
  .sum(axis=2)
1062
1162
  )
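A small sketch (illustrative sizes, not part of the diff) of the batching trick shared by these aggregation helpers: spreading the scatter-add over `nbatches` accumulator columns and summing them afterwards gives the same totals as a single `.at[].add()`, while exposing more parallelism on accelerators.

    from jax import numpy as jnp

    indices = jnp.array([0, 2, 2, 1, 0, 2])     # leaf index of each datapoint
    size, batch_size = 4, 2
    n = indices.size
    nbatches = n // batch_size + bool(n % batch_size)
    batch_indices = jnp.arange(n) % nbatches
    batched = (
        jnp.zeros((size, nbatches), jnp.uint32)
        .at[indices, batch_indices]
        .add(1)
        .sum(axis=1)
    )
    direct = jnp.zeros(size, jnp.uint32).at[indices].add(1)
    assert jnp.array_equal(batched, direct)      # both give [2, 1, 3, 0]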
1063
1163
 
1164
+
1165
+ def compute_prec_trees(prec_scale, leaf_indices, moves, batch_size):
1166
+ """
1167
+ Compute the likelihood precision scale in each leaf.
1168
+
1169
+ Parameters
1170
+ ----------
1171
+ prec_scale : float array (n,)
1172
+ The scale of the precision of the error on each datapoint.
1173
+ leaf_indices : int array (num_trees, n)
1174
+ The index of the leaf each datapoint falls into, with the deeper version
1175
+ of the tree (post-GROW, pre-PRUNE).
1176
+ moves : dict
1177
+ The proposed moves, see `sample_moves`.
1178
+ batch_size : int or None
1179
+ The data batch size to use for the summation.
1180
+
1181
+ Returns
1182
+ -------
1183
+ prec_trees : float array (num_trees, 2 ** (d - 1))
1184
+ The likelihood precision scale in each potential or actual leaf node.
1185
+ precs : dict
1186
+ The likelihood precision scale in the leaves grown or pruned by the
1187
+ moves, under keys 'left', 'right', and 'total' (left + right).
1188
+ """
1189
+
1190
+ ntree, tree_size = moves['var_trees'].shape
1191
+ tree_size *= 2
1192
+ tree_indices = jnp.arange(ntree)
1193
+
1194
+ prec_trees = prec_per_leaf(prec_scale, leaf_indices, tree_size, batch_size)
1195
+
1196
+ # prec datapoints in nodes modified by move
1197
+ precs = dict()
1198
+ precs['left'] = prec_trees[tree_indices, moves['left']]
1199
+ precs['right'] = prec_trees[tree_indices, moves['right']]
1200
+ precs['total'] = precs['left'] + precs['right']
1201
+
1202
+ # write prec into non-leaf node
1203
+ prec_trees = prec_trees.at[tree_indices, moves['node']].set(precs['total'])
1204
+
1205
+ return prec_trees, precs
1206
+
1207
+
1208
+ def prec_per_leaf(prec_scale, leaf_indices, tree_size, batch_size):
1209
+ """
1210
+ Compute the likelihood precision scale in each leaf.
1211
+
1212
+ Parameters
1213
+ ----------
1214
+ prec_scale : float array (n,)
1215
+ The scale of the precision of the error on each datapoint.
1216
+ leaf_indices : int array (num_trees, n)
1217
+ The index of the leaf each datapoint falls into.
1218
+ tree_size : int
1219
+ The size of the leaf tree array (2 ** d).
1220
+ batch_size : int or None
1221
+ The data batch size to use for the summation.
1222
+
1223
+ Returns
1224
+ -------
1225
+ prec_trees : float array (num_trees, 2 ** (d - 1))
1226
+ The likelihood precision scale in each leaf node.
1227
+ """
1228
+ if batch_size is None:
1229
+ return _prec_scan(prec_scale, leaf_indices, tree_size)
1230
+ else:
1231
+ return _prec_vec(prec_scale, leaf_indices, tree_size, batch_size)
1232
+
1233
+
1234
+ def _prec_scan(prec_scale, leaf_indices, tree_size):
1235
+ def loop(_, leaf_indices):
1236
+ return None, _aggregate_scatter(
1237
+ prec_scale, leaf_indices, tree_size, jnp.float32
1238
+ ) # TODO: use large_float
1239
+
1240
+ _, prec_trees = lax.scan(loop, None, leaf_indices)
1241
+ return prec_trees
1242
+
1243
+
1244
+ def _prec_vec(prec_scale, leaf_indices, tree_size, batch_size):
1245
+ return _aggregate_batched_alltrees(
1246
+ prec_scale, leaf_indices, tree_size, jnp.float32, batch_size
1247
+ ) # TODO: use large_float
1248
+
1249
+
1064
1250
  def complete_ratio(moves, move_counts, min_points_per_leaf):
1065
1251
  """
1066
1252
  Complete non-likelihood MH ratio calculation.
1067
1253
 
1068
- This functions adds the probability of choosing the prune move.
1254
+ This function adds the probability of choosing the prune move.
1069
1255
 
1070
1256
  Parameters
1071
1257
  ----------
@@ -1084,10 +1270,13 @@ def complete_ratio(moves, move_counts, min_points_per_leaf):
1084
1270
  'log_trans_prior_ratio'.
1085
1271
  """
1086
1272
  moves = moves.copy()
1087
- p_prune = compute_p_prune(moves, move_counts['left'], move_counts['right'], min_points_per_leaf)
1273
+ p_prune = compute_p_prune(
1274
+ moves, move_counts['left'], move_counts['right'], min_points_per_leaf
1275
+ )
1088
1276
  moves['log_trans_prior_ratio'] = jnp.log(moves.pop('partial_ratio') * p_prune)
1089
1277
  return moves
1090
1278
 
1279
+
1091
1280
  def compute_p_prune(moves, left_count, right_count, min_points_per_leaf):
1092
1281
  """
1093
1282
  Compute the probability of proposing a prune move.
@@ -1123,6 +1312,7 @@ def compute_p_prune(moves, left_count, right_count, min_points_per_leaf):
1123
1312
 
1124
1313
  return jnp.where(moves['grow'], grow_p_prune, prune_p_prune)
1125
1314
 
1315
+
1126
1316
  @jaxext.vmap_nodoc
1127
1317
  def adapt_leaf_trees_to_grow_indices(leaf_trees, moves):
1128
1318
  """
@@ -1143,26 +1333,25 @@ def adapt_leaf_trees_to_grow_indices(leaf_trees, moves):
1143
1333
  what would be its children if the grow move was accepted.
1144
1334
  """
1145
1335
  values_at_node = leaf_trees[moves['node']]
1146
- return (leaf_trees
1147
- .at[jnp.where(moves['grow'], moves['left'], leaf_trees.size)]
1336
+ return (
1337
+ leaf_trees.at[jnp.where(moves['grow'], moves['left'], leaf_trees.size)]
1148
1338
  .set(values_at_node)
1149
1339
  .at[jnp.where(moves['grow'], moves['right'], leaf_trees.size)]
1150
1340
  .set(values_at_node)
1151
1341
  )
1152
1342
 
1153
- def precompute_likelihood_terms(count_trees, sigma2, move_counts):
1343
+
1344
+ def precompute_likelihood_terms(sigma2, move_precs):
1154
1345
  """
1155
1346
  Pre-compute terms used in the likelihood ratio of the acceptance step.
1156
1347
 
1157
1348
  Parameters
1158
1349
  ----------
1159
- count_trees : array (num_trees, 2 ** d)
1160
- The number of points in each potential or actual leaf node.
1161
1350
  sigma2 : float
1162
1351
  The noise variance.
1163
- move_counts : dict
1164
- The counts of the number of points in the the nodes modified by the
1165
- moves.
1352
+ move_precs : dict
1353
+ The likelihood precision scale in the leaves grown or pruned by the
1354
+ moves, under keys 'left', 'right', and 'total' (left + right).
1166
1355
 
1167
1356
  Returns
1168
1357
  -------
@@ -1173,32 +1362,37 @@ def precompute_likelihood_terms(count_trees, sigma2, move_counts):
1173
1362
  Dictionary with pre-computed terms of the likelihood ratio, shared by
1174
1363
  all trees.
1175
1364
  """
1176
- ntree = len(count_trees)
1365
+ ntree = len(move_precs['total'])
1177
1366
  sigma_mu2 = 1 / ntree
1178
1367
  prelkv = dict()
1179
- prelkv['sigma2_left'] = sigma2 + move_counts['left'] * sigma_mu2
1180
- prelkv['sigma2_right'] = sigma2 + move_counts['right'] * sigma_mu2
1181
- prelkv['sigma2_total'] = sigma2 + move_counts['total'] * sigma_mu2
1182
- prelkv['sqrt_term'] = jnp.log(
1183
- sigma2 * prelkv['sigma2_total'] /
1184
- (prelkv['sigma2_left'] * prelkv['sigma2_right'])
1185
- ) / 2
1368
+ prelkv['sigma2_left'] = sigma2 + move_precs['left'] * sigma_mu2
1369
+ prelkv['sigma2_right'] = sigma2 + move_precs['right'] * sigma_mu2
1370
+ prelkv['sigma2_total'] = sigma2 + move_precs['total'] * sigma_mu2
1371
+ prelkv['sqrt_term'] = (
1372
+ jnp.log(
1373
+ sigma2
1374
+ * prelkv['sigma2_total']
1375
+ / (prelkv['sigma2_left'] * prelkv['sigma2_right'])
1376
+ )
1377
+ / 2
1378
+ )
1186
1379
  return prelkv, dict(
1187
1380
  exp_factor=sigma_mu2 / (2 * sigma2),
1188
1381
  )
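Putting these pre-computed pieces together with `compute_likelihood_ratio` further down, the log marginal-likelihood ratio of a grow move works out to the following (a restatement of the code in this diff, in the same variable names):

    # resid_left/right/total are the sums of the (precision-scaled) residuals in
    # the two children and in the parent node, and sigma2_* are as computed above:
    #
    #   log_lk_ratio = 0.5 * log(sigma2 * sigma2_total / (sigma2_left * sigma2_right))
    #                  + sigma_mu2 / (2 * sigma2) * (  resid_left**2  / sigma2_left
    #                                                + resid_right**2 / sigma2_right
    #                                                - resid_total**2 / sigma2_total )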
1189
1382
 
1190
- def precompute_leaf_terms(count_trees, sigma2, key):
1383
+
1384
+ def precompute_leaf_terms(key, prec_trees, sigma2):
1191
1385
  """
1192
1386
  Pre-compute terms used to sample leaves from their posterior.
1193
1387
 
1194
1388
  Parameters
1195
1389
  ----------
1196
- count_trees : array (num_trees, 2 ** d)
1197
- The number of points in each potential or actual leaf node.
1198
- sigma2 : float
1199
- The noise variance.
1200
1390
  key : jax.dtypes.prng_key array
1201
1391
  A jax random key.
1392
+ prec_trees : array (num_trees, 2 ** d)
1393
+ The likelihood precision scale in each potential or actual leaf node.
1394
+ sigma2 : float
1395
+ The noise variance.
1202
1396
 
1203
1397
  Returns
1204
1398
  -------
@@ -1206,22 +1400,25 @@ def precompute_leaf_terms(count_trees, sigma2, key):
1206
1400
  Dictionary with pre-computed terms of the leaf sampling, with fields:
1207
1401
 
1208
1402
  'mean_factor' : float array (num_trees, 2 ** d)
1209
- The factor to be multiplied by the sum of residuals to obtain the
1210
- posterior mean.
1403
+ The factor to be multiplied by the sum of the scaled residuals to
1404
+ obtain the posterior mean.
1211
1405
  'centered_leaves' : float array (num_trees, 2 ** d)
1212
1406
  The mean-zero normal values to be added to the posterior mean to
1213
1407
  obtain the posterior leaf samples.
1214
1408
  """
1215
- ntree = len(count_trees)
1216
- prec_lk = count_trees / sigma2
1217
- var_post = lax.reciprocal(prec_lk + ntree) # = 1 / (prec_lk + prec_prior)
1218
- z = random.normal(key, count_trees.shape, sigma2.dtype)
1409
+ ntree = len(prec_trees)
1410
+ prec_lk = prec_trees / sigma2
1411
+ var_post = lax.reciprocal(prec_lk + ntree) # = 1 / (prec_lk + prec_prior)
1412
+ z = random.normal(key, prec_trees.shape, sigma2.dtype)
1219
1413
  return dict(
1220
- mean_factor=var_post / sigma2, # = mean_lk * prec_lk * var_post / resid_tree
1414
+ mean_factor=var_post / sigma2, # = mean_lk * prec_lk * var_post / resid_tree
1221
1415
  centered_leaves=z * jnp.sqrt(var_post),
1222
1416
  )
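For reference, these factors implement the conjugate normal update for each leaf (prior precision `ntree`, i.e. prior variance `sigma_mu2 = 1/ntree` as in `precompute_likelihood_terms`), sketched with the names used above:

    # prec_lk   = prec_tree / sigma2                # likelihood precision of the leaf
    # var_post  = 1 / (prec_lk + ntree)             # posterior variance
    # mean_post = (resid_tree / sigma2) * var_post  # = resid_tree * mean_factor
    # leaf      = mean_post + sqrt(var_post) * z    # z ~ N(0, 1), the 'centered_leaves'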
1223
1417
 
1224
- def accept_moves_sequential_stage(bart, count_trees, moves, move_counts, prelkv, prelk, prelf):
1418
+
1419
+ def accept_moves_sequential_stage(
1420
+ bart, prec_trees, moves, move_counts, move_precs, prelkv, prelk, prelf
1421
+ ):
1225
1422
  """
1226
1423
  The part of accepting the moves that has to be done one tree at a time.
1227
1424
 
@@ -1229,13 +1426,15 @@ def accept_moves_sequential_stage(bart, count_trees, moves, move_counts, prelkv,
1229
1426
  ----------
1230
1427
  bart : dict
1231
1428
  A partially updated BART mcmc state.
1232
- count_trees : array (num_trees, 2 ** d)
1233
- The number of points in each potential or actual leaf node.
1429
+ prec_trees : float array (num_trees, 2 ** d)
1430
+ The likelihood precision scale in each potential or actual leaf node.
1234
1431
  moves : dict
1235
1432
  The proposed moves, see `sample_moves`.
1236
1433
  move_counts : dict
1237
1434
  The counts of the number of points in the nodes modified by the
1238
1435
  moves.
1436
+ move_precs : dict
1437
+ The likelihood precision scale in each node modified by the moves.
1239
1438
  prelkv, prelk, prelf : dict
1240
1439
  Dictionaries with pre-computed terms of the likelihood ratios and leaf
1241
1440
  samples.
@@ -1262,6 +1461,7 @@ def accept_moves_sequential_stage(bart, count_trees, moves, move_counts, prelkv,
1262
1461
  len(bart['leaf_trees']),
1263
1462
  bart['opt']['resid_batch_size'],
1264
1463
  resid,
1464
+ bart['prec_scale'],
1265
1465
  bart['min_points_per_leaf'],
1266
1466
  'ratios' in bart,
1267
1467
  prelk,
@@ -1270,22 +1470,44 @@ def accept_moves_sequential_stage(bart, count_trees, moves, move_counts, prelkv,
1270
1470
  return resid, (leaf_tree, acc, to_prune, ratios)
1271
1471
 
1272
1472
  items = (
1273
- bart['leaf_trees'], count_trees,
1274
- moves, move_counts,
1473
+ bart['leaf_trees'],
1474
+ prec_trees,
1475
+ moves,
1476
+ move_counts,
1477
+ move_precs,
1275
1478
  bart['leaf_indices'],
1276
- prelkv, prelf,
1479
+ prelkv,
1480
+ prelf,
1277
1481
  )
1278
1482
  resid, (leaf_trees, acc, to_prune, ratios) = lax.scan(loop, bart['resid'], items)
1279
1483
 
1280
1484
  bart['resid'] = resid
1281
1485
  bart['leaf_trees'] = leaf_trees
1282
- bart.get('ratios', {}).update(ratios)
1486
+ bart.get('ratios', {}).update(ratios) # noop if there are no ratios
1283
1487
  moves['acc'] = acc
1284
1488
  moves['to_prune'] = to_prune
1285
1489
 
1286
1490
  return bart, moves
1287
1491
 
1288
- def accept_move_and_sample_leaves(X, ntree, resid_batch_size, resid, min_points_per_leaf, save_ratios, prelk, leaf_tree, count_tree, move, move_counts, leaf_indices, prelkv, prelf):
1492
+
1493
+ def accept_move_and_sample_leaves(
1494
+ X,
1495
+ ntree,
1496
+ resid_batch_size,
1497
+ resid,
1498
+ prec_scale,
1499
+ min_points_per_leaf,
1500
+ save_ratios,
1501
+ prelk,
1502
+ leaf_tree,
1503
+ prec_tree,
1504
+ move,
1505
+ move_counts,
1506
+ move_precs,
1507
+ leaf_indices,
1508
+ prelkv,
1509
+ prelf,
1510
+ ):
1289
1511
  """
1290
1512
  Accept or reject a proposed move and sample the new leaf values.
1291
1513
 
@@ -1299,6 +1521,9 @@ def accept_move_and_sample_leaves(X, ntree, resid_batch_size, resid, min_points_
1299
1521
  The batch size for computing the sum of residuals in each leaf.
1300
1522
  resid : float array (n,)
1301
1523
  The residuals (data minus forest value).
1524
+ prec_scale : float array (n,) or None
1525
+ The scale of the precision of the error on each datapoint. If None, it
1526
+ is assumed to be 1.
1302
1527
  min_points_per_leaf : int or None
1303
1528
  The minimum number of data points in a leaf node.
1304
1529
  save_ratios : bool
@@ -1308,10 +1533,15 @@ def accept_move_and_sample_leaves(X, ntree, resid_batch_size, resid, min_points_
1308
1533
  trees.
1309
1534
  leaf_tree : float array (2 ** d,)
1310
1535
  The leaf values of the tree.
1311
- count_tree : int array (2 ** d,)
1312
- The number of datapoints in each leaf.
1536
+ prec_tree : float array (2 ** d,)
1537
+ The likelihood precision scale in each potential or actual leaf node.
1313
1538
  move : dict
1314
1539
  The proposed move, see `sample_moves`.
1540
+ move_counts : dict
1541
+ The counts of the number of points in the the nodes modified by the
1542
+ moves.
1543
+ move_precs : dict
1544
+ The likelihood precision scale in each node modified by the moves.
1315
1545
  leaf_indices : int array (n,)
1316
1546
  The leaf indices for the largest version of the tree compatible with
1317
1547
  the move.
@@ -1334,11 +1564,15 @@ def accept_move_and_sample_leaves(X, ntree, resid_batch_size, resid, min_points_
1334
1564
  The acceptance ratios for the moves. Empty if not to be saved.
1335
1565
  """
1336
1566
 
1337
- # sum residuals and count units per leaf, in tree proposed by grow move
1338
- resid_tree = sum_resid(resid, leaf_indices, leaf_tree.size, resid_batch_size)
1567
+ # sum residuals in each leaf, in tree proposed by grow move
1568
+ if prec_scale is None:
1569
+ scaled_resid = resid
1570
+ else:
1571
+ scaled_resid = resid * prec_scale
1572
+ resid_tree = sum_resid(scaled_resid, leaf_indices, leaf_tree.size, resid_batch_size)
1339
1573
 
1340
1574
  # subtract starting tree from function
1341
- resid_tree += count_tree * leaf_tree
1575
+ resid_tree += prec_tree * leaf_tree
1342
1576
 
1343
1577
  # get indices of move
1344
1578
  node = move['node']
@@ -1353,7 +1587,9 @@ def accept_move_and_sample_leaves(X, ntree, resid_batch_size, resid, min_points_
1353
1587
  resid_tree = resid_tree.at[node].set(resid_total)
1354
1588
 
1355
1589
  # compute acceptance ratio
1356
- log_lk_ratio = compute_likelihood_ratio(resid_total, resid_left, resid_right, prelkv, prelk)
1590
+ log_lk_ratio = compute_likelihood_ratio(
1591
+ resid_total, resid_left, resid_right, prelkv, prelk
1592
+ )
1357
1593
  log_ratio = move['log_trans_prior_ratio'] + log_lk_ratio
1358
1594
  log_ratio = jnp.where(move['grow'], log_ratio, -log_ratio)
1359
1595
  ratios = {}
@@ -1374,10 +1610,10 @@ def accept_move_and_sample_leaves(X, ntree, resid_batch_size, resid, min_points_
1374
1610
  mean_post = resid_tree * prelf['mean_factor']
1375
1611
  leaf_tree = mean_post + prelf['centered_leaves']
1376
1612
 
1377
- # copy leaves around such that the leaf indices select the right leaf
1613
+ # copy leaves around such that the leaf indices point to the correct leaf
1378
1614
  to_prune = acc ^ move['grow']
1379
- leaf_tree = (leaf_tree
1380
- .at[jnp.where(to_prune, left, leaf_tree.size)]
1615
+ leaf_tree = (
1616
+ leaf_tree.at[jnp.where(to_prune, left, leaf_tree.size)]
1381
1617
  .set(leaf_tree[node])
1382
1618
  .at[jnp.where(to_prune, right, leaf_tree.size)]
1383
1619
  .set(leaf_tree[node])
@@ -1388,14 +1624,16 @@ def accept_move_and_sample_leaves(X, ntree, resid_batch_size, resid, min_points_
1388
1624
 
1389
1625
  return resid, leaf_tree, acc, to_prune, ratios
1390
1626
 
1391
- def sum_resid(resid, leaf_indices, tree_size, batch_size):
1627
+
1628
+ def sum_resid(scaled_resid, leaf_indices, tree_size, batch_size):
1392
1629
  """
1393
1630
  Sum the residuals in each leaf.
1394
1631
 
1395
1632
  Parameters
1396
1633
  ----------
1397
- resid : float array (n,)
1398
- The residuals (data minus forest value).
1634
+ scaled_resid : float array (n,)
1635
+ The residuals (data minus forest value) multiplied by the error
1636
+ precision scale.
1399
1637
  leaf_indices : int array (n,)
1400
1638
  The leaf indices of the tree (in which leaf each data point falls into).
1401
1639
  tree_size : int
@@ -1413,29 +1651,32 @@ def sum_resid(resid, leaf_indices, tree_size, batch_size):
1413
1651
  aggr_func = _aggregate_scatter
1414
1652
  else:
1415
1653
  aggr_func = functools.partial(_aggregate_batched_onetree, batch_size=batch_size)
1416
- return aggr_func(resid, leaf_indices, tree_size, jnp.float32)
1654
+ return aggr_func(
1655
+ scaled_resid, leaf_indices, tree_size, jnp.float32
1656
+ ) # TODO: use large_float
1657
+
1417
1658
 
1418
1659
  def _aggregate_batched_onetree(values, indices, size, dtype, batch_size):
1419
- n, = indices.shape
1660
+ (n,) = indices.shape
1420
1661
  nbatches = n // batch_size + bool(n % batch_size)
1421
1662
  batch_indices = jnp.arange(n) % nbatches
1422
- return (jnp
1423
- .zeros((size, nbatches), dtype)
1663
+ return (
1664
+ jnp.zeros((size, nbatches), dtype)
1424
1665
  .at[indices, batch_indices]
1425
1666
  .add(values)
1426
1667
  .sum(axis=1)
1427
1668
  )
1428
1669
 
1670
+
1429
1671
  def compute_likelihood_ratio(total_resid, left_resid, right_resid, prelkv, prelk):
1430
1672
  """
1431
1673
  Compute the likelihood ratio of a grow move.
1432
1674
 
1433
1675
  Parameters
1434
1676
  ----------
1435
- total_resid : float
1436
- The sum of the residuals in the leaf to grow.
1437
- left_resid, right_resid : float
1438
- The sum of the residuals in the left/right child of the leaf to grow.
1677
+ total_resid, left_resid, right_resid : float
1678
+ The sum of the residuals (scaled by error precision scale) of the
1679
+ datapoints falling in the nodes involved in the moves.
1439
1680
  prelkv, prelk : dict
1440
1681
  The pre-computed terms of the likelihood ratio, see
1441
1682
  `precompute_likelihood_terms`.
@@ -1446,12 +1687,13 @@ def compute_likelihood_ratio(total_resid, left_resid, right_resid, prelkv, prelk
1446
1687
  The likelihood ratio P(data | new tree) / P(data | old tree).
1447
1688
  """
1448
1689
  exp_term = prelk['exp_factor'] * (
1449
- left_resid * left_resid / prelkv['sigma2_left'] +
1450
- right_resid * right_resid / prelkv['sigma2_right'] -
1451
- total_resid * total_resid / prelkv['sigma2_total']
1690
+ left_resid * left_resid / prelkv['sigma2_left']
1691
+ + right_resid * right_resid / prelkv['sigma2_right']
1692
+ - total_resid * total_resid / prelkv['sigma2_total']
1452
1693
  )
1453
1694
  return prelkv['sqrt_term'] + exp_term
1454
1695
 
1696
+
1455
1697
  def accept_moves_final_stage(bart, moves):
1456
1698
  """
1457
1699
  The final part of accepting the moves, in parallel across trees.
@@ -1478,7 +1720,8 @@ def accept_moves_final_stage(bart, moves):
1478
1720
  bart['split_trees'] = apply_moves_to_split_trees(bart['split_trees'], moves)
1479
1721
  return bart
1480
1722
 
1481
- @jax.vmap
1723
+
1724
+ @jaxext.vmap_nodoc
1482
1725
  def apply_moves_to_leaf_indices(leaf_indices, moves):
1483
1726
  """
1484
1727
  Update the leaf indices to match the accepted move.
@@ -1497,7 +1740,7 @@ def apply_moves_to_leaf_indices(leaf_indices, moves):
1497
1740
  leaf_indices : int array (num_trees, n)
1498
1741
  The updated leaf indices.
1499
1742
  """
1500
- mask = ~jnp.array(1, leaf_indices.dtype) # ...1111111110
1743
+ mask = ~jnp.array(1, leaf_indices.dtype) # ...1111111110
1501
1744
  is_child = (leaf_indices & mask) == moves['left']
1502
1745
  return jnp.where(
1503
1746
  is_child & moves['to_prune'],
@@ -1505,7 +1748,8 @@ def apply_moves_to_leaf_indices(leaf_indices, moves):
1505
1748
  leaf_indices,
1506
1749
  )
1507
1750
 
1508
- @jax.vmap
1751
+
1752
+ @jaxext.vmap_nodoc
1509
1753
  def apply_moves_to_split_trees(split_trees, moves):
1510
1754
  """
1511
1755
  Update the split trees to match the accepted move.
@@ -1523,31 +1767,36 @@ def apply_moves_to_split_trees(split_trees, moves):
1523
1767
  split_trees : int array (num_trees, 2 ** (d - 1))
1524
1768
  The updated split trees.
1525
1769
  """
1526
- return (split_trees
1527
- .at[jnp.where(
1528
- moves['grow'],
1529
- moves['node'],
1530
- split_trees.size,
1531
- )]
1770
+ return (
1771
+ split_trees.at[
1772
+ jnp.where(
1773
+ moves['grow'],
1774
+ moves['node'],
1775
+ split_trees.size,
1776
+ )
1777
+ ]
1532
1778
  .set(moves['grow_split'].astype(split_trees.dtype))
1533
- .at[jnp.where(
1534
- moves['to_prune'],
1535
- moves['node'],
1536
- split_trees.size,
1537
- )]
1779
+ .at[
1780
+ jnp.where(
1781
+ moves['to_prune'],
1782
+ moves['node'],
1783
+ split_trees.size,
1784
+ )
1785
+ ]
1538
1786
  .set(0)
1539
1787
  )
1540
1788
 
1541
- def sample_sigma(bart, key):
1789
+
1790
+ def sample_sigma(key, bart):
1542
1791
  """
1543
1792
  Noise variance sampling step of BART MCMC.
1544
1793
 
1545
1794
  Parameters
1546
1795
  ----------
1547
- bart : dict
1548
- A BART mcmc state, as created by `init`.
1549
1796
  key : jax.dtypes.prng_key array
1550
1797
  A jax random key.
1798
+ bart : dict
1799
+ A BART mcmc state, as created by `init`.
1551
1800
 
1552
1801
  Returns
1553
1802
  -------
@@ -1558,7 +1807,11 @@ def sample_sigma(bart, key):
1558
1807
 
1559
1808
  resid = bart['resid']
1560
1809
  alpha = bart['sigma2_alpha'] + resid.size / 2
1561
- norm2 = jnp.dot(resid, resid, preferred_element_type=bart['opt']['large_float'])
1810
+ if bart['prec_scale'] is None:
1811
+ scaled_resid = resid
1812
+ else:
1813
+ scaled_resid = resid * bart['prec_scale']
1814
+ norm2 = resid @ scaled_resid
1562
1815
  beta = bart['sigma2_beta'] + norm2 / 2
1563
1816
 
1564
1817
  sample = random.gamma(key, alpha)
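With a precision scale, the sum of squares feeding the conjugate inverse-gamma update becomes a weighted one; a scalar restatement of the quantity computed above:

    # norm2 = sum_i resid[i]**2 * prec_scale[i] = sum_i resid[i]**2 / error_scale[i]**2
    # so datapoints with a larger error_scale contribute less to the sigma2 update;
    # with prec_scale = None this reduces to the plain dot(resid, resid).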