bartz 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1603 @@
1
+ # bartz/src/bartz/mcmcstep/_step.py
2
+ #
3
+ # Copyright (c) 2024-2026, The Bartz Contributors
4
+ #
5
+ # This file is part of bartz.
6
+ #
7
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ # of this software and associated documentation files (the "Software"), to deal
9
+ # in the Software without restriction, including without limitation the rights
10
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ # copies of the Software, and to permit persons to whom the Software is
12
+ # furnished to do so, subject to the following conditions:
13
+ #
14
+ # The above copyright notice and this permission notice shall be included in all
15
+ # copies or substantial portions of the Software.
16
+ #
17
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ # SOFTWARE.
24
+
25
+ """Implement `step`, `step_trees`, and the accept-reject logic."""
26
+
27
+ from dataclasses import replace
28
+ from functools import partial
29
+
30
+ try:
31
+ # available since jax v0.6.1
32
+ from jax import shard_map
33
+ except ImportError:
34
+ # deprecated in jax v0.8.0
35
+ from jax.experimental.shard_map import shard_map
36
+
37
+ import jax
38
+ from equinox import Module, tree_at
39
+ from jax import lax, random, vmap
40
+ from jax import numpy as jnp
41
+ from jax.lax import cond
42
+ from jax.scipy.linalg import solve_triangular
43
+ from jax.scipy.special import gammaln, logsumexp
44
+ from jax.sharding import Mesh, PartitionSpec
45
+ from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, Shaped, UInt, UInt32
46
+
47
+ from bartz._profiler import (
48
+ get_profile_mode,
49
+ jit_and_block_if_profiling,
50
+ jit_if_not_profiling,
51
+ jit_if_profiling,
52
+ vmap_chains_if_not_profiling,
53
+ vmap_chains_if_profiling,
54
+ )
55
+ from bartz.grove import var_histogram
56
+ from bartz.jaxext import split, truncated_normal_onesided, vmap_nodoc
57
+ from bartz.mcmcstep._moves import Moves, propose_moves
58
+ from bartz.mcmcstep._state import State, StepConfig, chol_with_gersh, field
59
+
60
+
61
@partial(jit_if_not_profiling, donate_argnums=(1,))
@partial(vmap_chains_if_not_profiling, auto_split_keys=True)
def step(key: Key[Array, ''], bart: State) -> State:
    """
    Do one MCMC step.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The new BART mcmc state.

    Notes
    -----
    The memory of the input state is re-used for the output state, so the input
    state can not be used any more after calling `step`. All this applies
    outside of `jax.jit`.
    """
    # reconcile the key shape with the number of chains under profile mode
    n_chains = bart.forest.num_chains()
    chain_shape = (n_chains,) if n_chains is not None else ()
    if get_profile_mode() and n_chains is not None and key.ndim == 0:
        key = random.split(key, n_chains)
    assert key.shape == chain_shape

    subkeys = split(key, 3)

    if bart.y.dtype == bool:
        # binary regression: pin the error precision to 1 while updating the
        # trees, then resample the latent z
        bart = replace(bart, error_cov_inv=jnp.ones(chain_shape))
        bart = step_trees(subkeys.pop(), bart)
        bart = replace(bart, error_cov_inv=None)
        bart = step_z(subkeys.pop(), bart)
    else:
        # continuous or multivariate regression
        bart = step_trees(subkeys.pop(), bart)
        bart = step_error_cov_inv(subkeys.pop(), bart)

    bart = step_sparse(subkeys.pop(), bart)
    return step_config(bart)
105
+
106
+
107
def step_trees(key: Key[Array, ''], bart: State) -> State:
    """
    Forest sampling step of BART MCMC.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The new BART mcmc state.

    Notes
    -----
    This function zeroes the proposal counters.
    """
    subkeys = split(key)
    proposals = propose_moves(subkeys.pop(), bart.forest)
    return accept_moves_and_sample_leaves(subkeys.pop(), bart, proposals)
129
+
130
+
131
def accept_moves_and_sample_leaves(
    key: Key[Array, ''], bart: State, moves: Moves
) -> State:
    """
    Accept or reject the proposed moves and sample the new leaf values.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A valid BART mcmc state.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    A new (valid) BART mcmc state.
    """
    # three stages: tree-parallel pre-computations, a sequential scan over
    # trees, and a final tree-parallel cleanup
    parallel_out = accept_moves_parallel_stage(key, bart, moves)
    bart, moves = accept_moves_sequential_stage(parallel_out)
    return accept_moves_final_stage(bart, moves)
153
+
154
+
155
class Counts(Module):
    """Number of datapoints in the nodes involved in proposed moves for each tree.

    Produced by `compute_count_trees` (via `_compute_count_or_prec_tree`); one
    entry per tree, with optional leading chain axes.
    """

    left: UInt[Array, '*chains num_trees'] = field(chains=True)
    """Number of datapoints in the left child."""

    right: UInt[Array, '*chains num_trees'] = field(chains=True)
    """Number of datapoints in the right child."""

    total: UInt[Array, '*chains num_trees'] = field(chains=True)
    """Number of datapoints in the parent (``= left + right``)."""
166
+
167
+
168
class Precs(Module):
    """Likelihood precision scale in the nodes involved in proposed moves for each tree.

    The "likelihood precision scale" of a tree node is the sum of the inverse
    squared error scales of the datapoints selected by the node.

    Produced by `compute_prec_trees` (via `_compute_count_or_prec_tree`) when
    `prec_scale` is set; one entry per tree, with optional leading chain axes.
    """

    left: Float32[Array, '*chains num_trees'] = field(chains=True)
    """Likelihood precision scale in the left child."""

    right: Float32[Array, '*chains num_trees'] = field(chains=True)
    """Likelihood precision scale in the right child."""

    total: Float32[Array, '*chains num_trees'] = field(chains=True)
    """Likelihood precision scale in the parent (``= left + right``)."""
183
+
184
+
185
class PreLkV(Module):
    """Non-sequential terms of the likelihood ratio for each tree.

    These terms can be computed in parallel across trees. Produced by
    `precompute_likelihood_terms`.
    """

    left: (
        Float32[Array, '*chains num_trees'] | Float32[Array, '*chains num_trees k k']
    ) = field(chains=True)
    """In the univariate case, this is the scalar term

    ``1 / error_cov_inv + n_left / leaf_prior_cov_inv``.

    In the multivariate case, this is the matrix term

    ``error_cov_inv @ inv(leaf_prior_cov_inv + n_left * error_cov_inv) @ error_cov_inv``.

    ``n_left`` is the number of datapoints in the left child, or the
    likelihood precision scale in the heteroskedastic case."""

    right: (
        Float32[Array, '*chains num_trees'] | Float32[Array, '*chains num_trees k k']
    ) = field(chains=True)
    """In the univariate case, this is the scalar term

    ``1 / error_cov_inv + n_right / leaf_prior_cov_inv``.

    In the multivariate case, this is the matrix term

    ``error_cov_inv @ inv(leaf_prior_cov_inv + n_right * error_cov_inv) @ error_cov_inv``.

    ``n_right`` is the number of datapoints in the right child, or the
    likelihood precision scale in the heteroskedastic case."""

    total: (
        Float32[Array, '*chains num_trees'] | Float32[Array, '*chains num_trees k k']
    ) = field(chains=True)
    """In the univariate case, this is the scalar term

    ``1 / error_cov_inv + n_total / leaf_prior_cov_inv``.

    In the multivariate case, this is the matrix term

    ``error_cov_inv @ inv(leaf_prior_cov_inv + n_total * error_cov_inv) @ error_cov_inv``.

    ``n_total`` is the number of datapoints in the parent node, or the
    likelihood precision scale in the heteroskedastic case."""

    log_sqrt_term: Float32[Array, '*chains num_trees'] = field(chains=True)
    """The logarithm of the square root term of the likelihood ratio."""
235
+
236
+
237
class PreLk(Module):
    """Non-sequential terms of the likelihood ratio shared by all trees.

    Only produced in the univariate case; `precompute_likelihood_terms`
    returns `None` in its place for multivariate models.
    """

    exp_factor: Float32[Array, '*chains'] = field(chains=True)
    """The factor to multiply the likelihood ratio by, shared by all trees."""
242
+
243
+
244
class PreLf(Module):
    """Pre-computed terms used to sample leaves from their posterior.

    These terms can be computed in parallel across trees. Produced by
    `precompute_leaf_terms`.

    For each tree and leaf, the terms are scalars in the univariate case, and
    matrices/vectors in the multivariate case.
    """

    mean_factor: (
        Float32[Array, '*chains num_trees 2**d']
        | Float32[Array, '*chains num_trees k k 2**d']
    ) = field(chains=True)
    """The factor to be right-multiplied by the sum of the scaled residuals to
    obtain the posterior mean."""

    centered_leaves: (
        Float32[Array, '*chains num_trees 2**d']
        | Float32[Array, '*chains num_trees k 2**d']
    ) = field(chains=True)
    """The mean-zero normal values to be added to the posterior mean to
    obtain the posterior leaf samples."""
266
+
267
+
268
class ParallelStageOut(Module):
    """The output of `accept_moves_parallel_stage`.

    Consumed by `accept_moves_sequential_stage`.
    """

    bart: State
    """A partially updated BART mcmc state."""

    moves: Moves
    """The proposed moves, with `partial_ratio` set to `None` and
    `log_trans_prior_ratio` set to its final value."""

    prec_trees: (
        Float32[Array, '*chains num_trees 2**d']
        | Int32[Array, '*chains num_trees 2**d']
    ) = field(chains=True)
    """The likelihood precision scale in each potential or actual leaf node. If
    there is no precision scale, this is the number of points in each leaf."""

    move_precs: Precs | Counts
    """The likelihood precision scale in each node modified by the moves. If
    `bart.prec_scale` is not set, this is set to `move_counts`."""

    prelkv: PreLkV
    """Object with pre-computed terms of the likelihood ratios."""

    prelk: PreLk | None
    """Object with pre-computed terms of the likelihood ratios."""

    prelf: PreLf
    """Object with pre-computed terms of the leaf samples."""
297
+
298
+
299
@partial(jit_and_block_if_profiling, donate_argnums=(1, 2))
@vmap_chains_if_profiling
def accept_moves_parallel_stage(
    key: Key[Array, ''], bart: State, moves: Moves
) -> ParallelStageOut:
    """
    Pre-compute quantities used to accept moves, in parallel across trees.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    An object with all that could be done in parallel.
    """
    # where the move is grow, modify the state like the move was accepted
    # (grow is speculatively applied; the sequential stage decides acceptance)
    bart = replace(
        bart,
        forest=replace(
            bart.forest,
            var_tree=moves.var_tree,
            leaf_indices=apply_grow_to_indices(moves, bart.forest.leaf_indices, bart.X),
            leaf_tree=adapt_leaf_trees_to_grow_indices(bart.forest.leaf_tree, moves),
        ),
    )

    # count number of datapoints per leaf
    # (skipped only when both point thresholds are off and a precision scale
    # is in use, in which case the counts are never read)
    if (
        bart.forest.min_points_per_decision_node is not None
        or bart.forest.min_points_per_leaf is not None
        or bart.prec_scale is None
    ):
        count_trees, move_counts = compute_count_trees(
            bart.forest.leaf_indices, moves, bart.config
        )

        # mark which leaves & potential leaves have enough points to be grown
        if bart.forest.min_points_per_decision_node is not None:
            # only the first half of the tree array can hold decision nodes
            count_half_trees = count_trees[:, : bart.forest.var_tree.shape[1]]
            moves = replace(
                moves,
                affluence_tree=moves.affluence_tree
                & (count_half_trees >= bart.forest.min_points_per_decision_node),
            )

        # copy updated affluence_tree to state
        bart = tree_at(lambda bart: bart.forest.affluence_tree, bart, moves.affluence_tree)

        # veto grow move if new leaves don't have enough datapoints
        if bart.forest.min_points_per_leaf is not None:
            moves = replace(
                moves,
                allowed=moves.allowed
                & (move_counts.left >= bart.forest.min_points_per_leaf)
                & (move_counts.right >= bart.forest.min_points_per_leaf),
            )

    # count number of datapoints per leaf, weighted by error precision scale
    if bart.prec_scale is None:
        prec_trees = count_trees
        move_precs = move_counts
    else:
        prec_trees, move_precs = compute_prec_trees(
            bart.prec_scale, bart.forest.leaf_indices, moves, bart.config
        )
    assert move_precs is not None

    # compute some missing information about moves
    moves = complete_ratio(moves, bart.forest.p_nonterminal)
    save_ratios = bart.forest.log_likelihood is not None
    bart = replace(
        bart,
        forest=replace(
            bart.forest,
            grow_prop_count=jnp.sum(moves.grow),
            prune_prop_count=jnp.sum(moves.allowed & ~moves.grow),
            log_trans_prior=moves.log_trans_prior_ratio if save_ratios else None,
        ),
    )

    # pre-compute the per-tree and shared terms of the likelihood ratios and
    # the leaf posterior sampling terms
    assert bart.error_cov_inv is not None
    prelkv, prelk = precompute_likelihood_terms(
        bart.error_cov_inv, bart.forest.leaf_prior_cov_inv, move_precs
    )
    prelf = precompute_leaf_terms(
        key, prec_trees, bart.error_cov_inv, bart.forest.leaf_prior_cov_inv
    )

    return ParallelStageOut(
        bart=bart,
        moves=moves,
        prec_trees=prec_trees,
        move_precs=move_precs,
        prelkv=prelkv,
        prelk=prelk,
        prelf=prelf,
    )
402
+
403
+
404
@partial(vmap_nodoc, in_axes=(0, 0, None))
def apply_grow_to_indices(
    moves: Moves, leaf_indices: UInt[Array, 'num_trees n'], X: UInt[Array, 'p n']
) -> UInt[Array, 'num_trees n']:
    """
    Update the leaf indices to apply a grow move.

    Parameters
    ----------
    moves
        The proposed moves, see `propose_moves`.
    leaf_indices
        The index of the leaf each datapoint falls into.
    X
        The predictors matrix.

    Returns
    -------
    The updated leaf indices.
    """
    grow_predictor: UInt[Array, ' n'] = X[moves.grow_var, :]
    goes_right = grow_predictor >= moves.grow_split
    # sentinel index beyond the tree: matches no leaf when the move is not a
    # grow, so the indices are left untouched
    out_of_tree = jnp.array(2 * moves.var_tree.size)
    target_node = jnp.where(moves.grow, moves.node, out_of_tree)
    new_leaf = (moves.node.astype(leaf_indices.dtype) << 1) + goes_right
    return jnp.where(leaf_indices == target_node, new_leaf, leaf_indices)
432
+
433
+
434
def _compute_count_or_prec_trees(
    prec_scale: Float32[Array, ' n'] | None,
    leaf_indices: UInt[Array, 'num_trees n'],
    moves: Moves,
    config: StepConfig,
) -> (
    tuple[UInt32[Array, 'num_trees 2**d'], Counts]
    | tuple[Float32[Array, 'num_trees 2**d'], Precs]
):
    """Implement `compute_count_trees` and `compute_prec_trees`."""
    batch_size = config.prec_count_num_trees

    if batch_size is None:
        # fully vectorized over trees
        vectorized = vmap(_compute_count_or_prec_tree, in_axes=(None, 0, 0, None))
        return vectorized(prec_scale, leaf_indices, moves, config)

    # otherwise process trees in batches to bound memory usage
    def one_tree(args):
        indices, mv = args
        return _compute_count_or_prec_tree(prec_scale, indices, mv, config)

    return lax.map(one_tree, (leaf_indices, moves), batch_size=batch_size)
455
+
456
+
457
def _compute_count_or_prec_tree(
    prec_scale: Float32[Array, ' n'] | None,
    leaf_indices: UInt[Array, ' n'],
    moves: Moves,
    config: StepConfig,
) -> tuple[UInt32[Array, ' 2**d'], Counts] | tuple[Float32[Array, ' 2**d'], Precs]:
    """Compute count or precision tree for a single tree."""
    (half_size,) = moves.var_tree.shape
    full_size = 2 * half_size

    # plain counting vs. precision-weighted accumulation
    counting = prec_scale is None
    per_point = 1 if counting else prec_scale
    result_cls = Counts if counting else Precs
    out_dtype = jnp.uint32 if counting else jnp.float32
    num_batches = config.count_num_batches if counting else config.prec_num_batches

    per_node = _scatter_add(
        per_point, leaf_indices, full_size, out_dtype, num_batches, config.mesh
    )

    # totals in the two children touched by the move
    n_left = per_node[moves.left]
    n_right = per_node[moves.right]
    node_stats = result_cls(left=n_left, right=n_right, total=n_left + n_right)

    # the parent is a decision node post-grow, so store the combined total
    per_node = per_node.at[moves.node].set(node_stats.total)

    return per_node, node_stats
491
+
492
+
493
def compute_count_trees(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves, config: StepConfig
) -> tuple[UInt32[Array, 'num_trees 2**d'], Counts]:
    """
    Count the number of datapoints in each leaf.

    Thin wrapper over `_compute_count_or_prec_trees` with no precision scale,
    i.e. each datapoint counts as 1.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, with the deeper version
        of the tree (post-GROW, pre-PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    config
        The MCMC configuration.

    Returns
    -------
    count_trees : Int32[Array, 'num_trees 2**d']
        The number of points in each potential or actual leaf node.
    counts : Counts
        The counts of the number of points in the leaves grown or pruned by the
        moves.
    """
    return _compute_count_or_prec_trees(None, leaf_indices, moves, config)
518
+
519
+
520
def compute_prec_trees(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    moves: Moves,
    config: StepConfig,
) -> tuple[Float32[Array, 'num_trees 2**d'], Precs]:
    """
    Compute the likelihood precision scale in each leaf.

    Thin wrapper over `_compute_count_or_prec_trees` that accumulates
    `prec_scale` instead of unit counts.

    Parameters
    ----------
    prec_scale
        The scale of the precision of the error on each datapoint.
    leaf_indices
        The index of the leaf each datapoint falls into, with the deeper version
        of the tree (post-GROW, pre-PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    config
        The MCMC configuration.

    Returns
    -------
    prec_trees : Float32[Array, 'num_trees 2**d']
        The likelihood precision scale in each potential or actual leaf node.
    precs : Precs
        The likelihood precision scale in the nodes involved in the moves.
    """
    return _compute_count_or_prec_trees(prec_scale, leaf_indices, moves, config)
549
+
550
+
551
@partial(vmap_nodoc, in_axes=(0, None))
def complete_ratio(moves: Moves, p_nonterminal: Float32[Array, ' 2**d']) -> Moves:
    """
    Complete non-likelihood MH ratio calculation.

    This function adds the probability of choosing a prune move over the grow
    move in the inverse transition, and the a priori probability that the
    children nodes are leaves.

    Parameters
    ----------
    moves
        The proposed moves. Must have already been updated to keep into account
        the thresholds on the number of datapoints per node, this happens in
        `accept_moves_parallel_stage`.
    p_nonterminal
        The a priori probability of each node being nonterminal conditional on
        its ancestors, including at the maximum depth where it should be zero.

    Returns
    -------
    The updated moves, with `partial_ratio=None` and `log_trans_prior_ratio` set.
    """
    # growability of the two children; out-of-bounds reads fill with False
    can_grow_left = moves.affluence_tree.at[moves.left].get(
        mode='fill', fill_value=False
    )
    can_grow_right = moves.affluence_tree.at[moves.right].get(
        mode='fill', fill_value=False
    )

    # probability of proposing a prune in the reverse transition of a grow:
    # 1/2 if anything would still be growable afterwards, else 1
    any_other_growable = moves.num_growable >= 2
    reverse_can_grow = any_other_growable | can_grow_left | can_grow_right
    p_prune_after_grow = jnp.where(reverse_can_grow, 0.5, 1.0)

    # probability of having proposed a prune: 1/2 unless nothing was growable
    p_prune_if_prune = jnp.where(moves.num_growable, 0.5, 1)

    p_prune = jnp.where(moves.grow, p_prune_after_grow, p_prune_if_prune)

    # a priori probability that both children remain leaves
    p_leaf_left = 1 - p_nonterminal[moves.left] * can_grow_left
    p_leaf_right = 1 - p_nonterminal[moves.right] * can_grow_right
    p_both_leaves = p_leaf_left * p_leaf_right

    assert moves.partial_ratio is not None
    return replace(
        moves,
        log_trans_prior_ratio=jnp.log(moves.partial_ratio * p_both_leaves * p_prune),
        partial_ratio=None,
    )
604
+
605
+
606
@vmap_nodoc
def adapt_leaf_trees_to_grow_indices(
    leaf_trees: Float32[Array, 'num_trees 2**d'], moves: Moves
) -> Float32[Array, 'num_trees 2**d']:
    """
    Modify leaves such that post-grow indices work on the original tree.

    The value of the leaf to grow is copied to what would be its children if the
    grow move was accepted.

    Parameters
    ----------
    leaf_trees
        The leaf values.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    The modified leaf values.
    """
    parent_value = leaf_trees[..., moves.node]
    # out-of-bounds target => the .set is dropped when the move is not a grow
    oob = leaf_trees.size
    left_target = jnp.where(moves.grow, moves.left, oob)
    right_target = jnp.where(moves.grow, moves.right, oob)
    out = leaf_trees.at[..., left_target].set(parent_value)
    return out.at[..., right_target].set(parent_value)
634
+
635
+
636
+ def _logdet_from_chol(L: Float32[Array, '... k k']) -> Float32[Array, '...']:
637
+ """Compute logdet of A = LL' via Cholesky (sum of log of diag^2)."""
638
+ diags: Float32[Array, '... k'] = jnp.diagonal(L, axis1=-2, axis2=-1)
639
+ return 2.0 * jnp.sum(jnp.log(diags), axis=-1)
640
+
641
+
642
def _precompute_likelihood_terms_uv(
    error_cov_inv: Float32[Array, ''],
    leaf_prior_cov_inv: Float32[Array, ''],
    move_precs: Precs | Counts,
) -> tuple[PreLkV, PreLk]:
    """Univariate implementation of `precompute_likelihood_terms`."""
    error_var = lax.reciprocal(error_cov_inv)
    leaf_var = lax.reciprocal(leaf_prior_cov_inv)

    # variance-scale terms for the two children and the parent
    term_left = error_var + move_precs.left * leaf_var
    term_right = error_var + move_precs.right * leaf_var
    term_total = error_var + move_precs.total * leaf_var

    prelkv = PreLkV(
        left=term_left,
        right=term_right,
        total=term_total,
        log_sqrt_term=jnp.log(error_var * term_total / (term_left * term_right)) / 2,
    )
    prelk = PreLk(exp_factor=error_cov_inv / leaf_prior_cov_inv / 2)
    return prelkv, prelk
659
+
660
+
661
def _precompute_likelihood_terms_mv(
    error_cov_inv: Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, 'k k'],
    move_precs: Counts,
) -> tuple[PreLkV, None]:
    """Multivariate implementation of `precompute_likelihood_terms`.

    Takes `Counts` (not `Precs`): the multivariate path assumes a
    homoskedastic error model, so only datapoint counts enter the formulas.
    Returns `None` in place of `PreLk` since no term is shared across trees.
    """
    # lift the per-tree counts to broadcast against k x k matrices
    nL: UInt[Array, 'num_trees 1 1'] = move_precs.left[..., None, None]
    nR: UInt[Array, 'num_trees 1 1'] = move_precs.right[..., None, None]
    nT: UInt[Array, 'num_trees 1 1'] = move_precs.total[..., None, None]

    # Cholesky factors of leaf_prior_cov_inv + n * error_cov_inv for each node
    L_left: Float32[Array, 'num_trees k k'] = chol_with_gersh(
        error_cov_inv * nL + leaf_prior_cov_inv
    )
    L_right: Float32[Array, 'num_trees k k'] = chol_with_gersh(
        error_cov_inv * nR + leaf_prior_cov_inv
    )
    L_total: Float32[Array, 'num_trees k k'] = chol_with_gersh(
        error_cov_inv * nT + leaf_prior_cov_inv
    )

    # half log-det combination of the four matrices entering the ratio
    log_sqrt_term: Float32[Array, ' num_trees'] = 0.5 * (
        _logdet_from_chol(chol_with_gersh(leaf_prior_cov_inv))
        + _logdet_from_chol(L_total)
        - _logdet_from_chol(L_left)
        - _logdet_from_chol(L_right)
    )

    def _term_from_chol(
        L: Float32[Array, 'num_trees k k'],
    ) -> Float32[Array, 'num_trees k k']:
        # With A = L L' and E = error_cov_inv: Y = inv(L) E, so
        # Y' Y = E' inv(A) E = E inv(A) E (E is a symmetric inverse covariance)
        rhs: Float32[Array, 'num_trees k k'] = jnp.broadcast_to(error_cov_inv, L.shape)
        Y: Float32[Array, 'num_trees k k'] = solve_triangular(L, rhs, lower=True)
        return Y.mT @ Y

    prelkv = PreLkV(
        left=_term_from_chol(L_left),
        right=_term_from_chol(L_right),
        total=_term_from_chol(L_total),
        log_sqrt_term=log_sqrt_term,
    )

    return prelkv, None
702
+
703
+
704
def precompute_likelihood_terms(
    error_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    move_precs: Precs | Counts,
) -> tuple[PreLkV, PreLk | None]:
    """
    Pre-compute terms used in the likelihood ratio of the acceptance step.

    Dispatches on the shape of `error_cov_inv`: a matrix selects the
    multivariate implementation (which assumes a homoskedastic error model,
    i.e. the residual covariance is the same for all observations), anything
    else the univariate one.

    Parameters
    ----------
    error_cov_inv
        The inverse error variance (univariate) or the inverse of the error
        covariance matrix (multivariate). For univariate case, this is the
        inverse global error variance factor if `prec_scale` is set.
    leaf_prior_cov_inv
        The inverse prior variance of each leaf (univariate) or the inverse of
        prior covariance matrix of each leaf (multivariate).
    move_precs
        The likelihood precision scale in the leaves grown or pruned by the
        moves, under keys 'left', 'right', and 'total' (left + right).

    Returns
    -------
    prelkv : PreLkV
        Pre-computed terms of the likelihood ratio, one per tree.
    prelk : PreLk | None
        Pre-computed terms of the likelihood ratio, shared by all trees;
        `None` in the multivariate case.
    """
    univariate = error_cov_inv.ndim != 2
    if univariate:
        return _precompute_likelihood_terms_uv(
            error_cov_inv, leaf_prior_cov_inv, move_precs
        )
    assert isinstance(move_precs, Counts)
    return _precompute_likelihood_terms_mv(
        error_cov_inv, leaf_prior_cov_inv, move_precs
    )
745
+
746
+
747
def _precompute_leaf_terms_uv(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'],
    error_cov_inv: Float32[Array, ''],
    leaf_prior_cov_inv: Float32[Array, ''],
    z: Float32[Array, 'num_trees 2**d'] | None = None,
) -> PreLf:
    """Univariate implementation of `precompute_leaf_terms`."""
    likelihood_prec = prec_trees * error_cov_inv
    posterior_var = lax.reciprocal(likelihood_prec + leaf_prior_cov_inv)
    if z is None:
        z = random.normal(key, prec_trees.shape, error_cov_inv.dtype)
    # The posterior mean is mean_lk * prec_lk * var_post, and the residual sum
    # per leaf is mean_lk * prec_tree; dividing out, the factor multiplying
    # the residual sum reduces to var_post / sigma2 = var_post * error_cov_inv.
    mean_factor = posterior_var * error_cov_inv
    # zero-mean draws with the posterior standard deviation
    centered = z * jnp.sqrt(posterior_var)
    return PreLf(mean_factor=mean_factor, centered_leaves=centered)
770
+
771
+
772
def _precompute_leaf_terms_mv(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'],
    error_cov_inv: Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, 'k k'],
    z: Float32[Array, 'num_trees 2**d k'] | None = None,
) -> PreLf:
    """Multivariate implementation of `precompute_leaf_terms`.

    Parameters
    ----------
    key
        A jax random key, used only if `z` is not provided.
    prec_trees
        The number of datapoints (or precision scale) in each leaf.
    error_cov_inv
        The inverse of the error covariance matrix.
    leaf_prior_cov_inv
        The inverse of the prior covariance matrix of each leaf.
    z
        Optional standard normal noise, for testing purposes.

    Returns
    -------
    Pre-computed terms for leaf sampling.
    """
    num_trees, tree_size = prec_trees.shape
    k = error_cov_inv.shape[0]
    n_k: Float32[Array, 'num_trees tree_size 1 1'] = prec_trees[..., None, None]

    # Only broadcast the inverse of error covariance matrix to satisfy JAX's
    # batching rules for `lax.linalg.solve_triangular`, which does not support
    # implicit broadcasting.
    error_cov_inv_batched = jnp.broadcast_to(
        error_cov_inv, (num_trees, tree_size, k, k)
    )

    # posterior precision of a leaf: prior precision + n * error precision
    posterior_precision: Float32[Array, 'num_trees tree_size k k'] = (
        leaf_prior_cov_inv + n_k * error_cov_inv_batched
    )

    L_prec: Float32[Array, 'num_trees tree_size k k'] = chol_with_gersh(
        posterior_precision
    )
    # mean_factor = inv(posterior_precision) @ error_cov_inv via two
    # triangular solves against the Cholesky factor
    Y: Float32[Array, 'num_trees tree_size k k'] = solve_triangular(
        L_prec, error_cov_inv_batched, lower=True
    )
    mean_factor: Float32[Array, 'num_trees tree_size k k'] = solve_triangular(
        L_prec, Y, trans='T', lower=True
    )
    mean_factor = mean_factor.mT
    mean_factor_out: Float32[Array, 'num_trees k k tree_size'] = jnp.moveaxis(
        mean_factor, 1, -1
    )

    if z is None:
        z = random.normal(key, (num_trees, tree_size, k))
    # x = inv(L') z has covariance inv(L L') = posterior covariance. BUG FIX:
    # `lower=True` was missing here (unlike the two solves above); with the
    # default `lower=False` the solver reads the zero upper triangle of the
    # lower-triangular `L_prec`, degenerating to a diagonal solve and
    # producing samples with the wrong covariance.
    centered_leaves: Float32[Array, 'num_trees tree_size k'] = solve_triangular(
        L_prec, z, trans='T', lower=True
    )
    centered_leaves_out: Float32[Array, 'num_trees k tree_size'] = jnp.swapaxes(
        centered_leaves, -1, -2
    )

    return PreLf(mean_factor=mean_factor_out, centered_leaves=centered_leaves_out)
818
+
819
+
820
def precompute_leaf_terms(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'],
    error_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, ''] | Float32[Array, 'k k'],
    z: Float32[Array, 'num_trees 2**d']
    | Float32[Array, 'num_trees 2**d k']
    | None = None,
) -> PreLf:
    """
    Pre-compute terms used to sample leaves from their posterior.

    Dispatches on the shape of `error_cov_inv`: a matrix selects the
    multivariate implementation, anything else the univariate one.

    Parameters
    ----------
    key
        A jax random key.
    prec_trees
        The likelihood precision scale in each potential or actual leaf node.
    error_cov_inv
        The inverse error variance (univariate) or the inverse of error
        covariance matrix (multivariate). For univariate case, this is the
        inverse global error variance factor if `prec_scale` is set.
    leaf_prior_cov_inv
        The inverse prior variance of each leaf (univariate) or the inverse of
        prior covariance matrix of each leaf (multivariate).
    z
        Optional standard normal noise to use for sampling the centered leaves.
        This is intended for testing purposes only.

    Returns
    -------
    Pre-computed terms for leaf sampling.
    """
    univariate = error_cov_inv.ndim != 2
    if univariate:
        return _precompute_leaf_terms_uv(
            key, prec_trees, error_cov_inv, leaf_prior_cov_inv, z
        )
    return _precompute_leaf_terms_mv(
        key, prec_trees, error_cov_inv, leaf_prior_cov_inv, z
    )
864
+
865
+
866
@partial(jit_and_block_if_profiling, donate_argnums=(0,))
@vmap_chains_if_profiling
def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]:
    """
    Accept/reject the moves one tree at a time.

    This is the most performance-sensitive function because it contains all and
    only the parts of the algorithm that can not be parallelized across trees.

    Parameters
    ----------
    pso
        The output of `accept_moves_parallel_stage`.

    Returns
    -------
    bart : State
        A partially updated BART mcmc state.
    moves : Moves
        The accepted/rejected moves, with `acc` and `to_prune` set.
    """

    # scan body: the residuals are the carry, everything else is per-tree
    def loop(resid, pt):
        resid, leaf_tree, acc, to_prune, lkratio = accept_move_and_sample_leaves(
            resid,
            SeqStageInAllTrees(
                pso.bart.X,
                pso.bart.config.resid_num_batches,
                pso.bart.config.mesh,
                pso.bart.prec_scale,
                pso.bart.forest.log_likelihood is not None,
                pso.prelk,
            ),
            pt,
        )
        return resid, (leaf_tree, acc, to_prune, lkratio)

    # per-tree inputs stacked along the leading (tree) axis for the scan
    pts = SeqStageInPerTree(
        pso.bart.forest.leaf_tree,
        pso.prec_trees,
        pso.moves,
        pso.move_precs,
        pso.bart.forest.leaf_indices,
        pso.prelkv,
        pso.prelf,
    )
    resid, (leaf_trees, acc, to_prune, lkratio) = lax.scan(loop, pso.bart.resid, pts)

    bart = replace(
        pso.bart,
        resid=resid,
        forest=replace(pso.bart.forest, leaf_tree=leaf_trees, log_likelihood=lkratio),
    )
    moves = replace(pso.moves, acc=acc, to_prune=to_prune)

    return bart, moves
922
+
923
+
924
class SeqStageInAllTrees(Module):
    """
    The inputs to `accept_move_and_sample_leaves` that are shared by all trees.

    Constructed inside the scan loop of `accept_moves_sequential_stage` with
    values that are constant across trees.
    """

    X: UInt[Array, 'p n']
    """The predictors."""

    resid_num_batches: int | None = field(static=True)
    """The number of batches for computing the sum of residuals in each leaf."""

    mesh: Mesh | None = field(static=True)
    """The mesh of devices to use."""

    prec_scale: Float32[Array, ' n'] | None
    """The scale of the precision of the error on each datapoint. If None, it
    is assumed to be 1."""

    save_ratios: bool = field(static=True)
    """Whether to save the acceptance ratios."""

    prelk: PreLk | None
    """The pre-computed terms of the likelihood ratio which are shared across
    trees."""
946
+
947
+
948
class SeqStageInPerTree(Module):
    """
    The inputs to `accept_move_and_sample_leaves` that are separate for each tree.

    In `accept_moves_sequential_stage` the fields are stacked along a leading
    tree axis and sliced one tree at a time by `lax.scan`.
    """

    leaf_tree: Float32[Array, ' 2**d'] | Float32[Array, ' k 2**d']
    """The leaf values of the tree."""

    prec_tree: Float32[Array, ' 2**d']
    """The likelihood precision scale in each potential or actual leaf node."""

    move: Moves
    """The proposed move, see `propose_moves`."""

    move_precs: Precs | Counts
    """The likelihood precision scale in each node modified by the moves."""

    leaf_indices: UInt[Array, ' n']
    """The leaf indices for the largest version of the tree compatible with
    the move."""

    prelkv: PreLkV
    """The pre-computed terms of the likelihood ratio which are specific to the tree."""

    prelf: PreLf
    """The pre-computed terms of the leaf sampling which are specific to the tree."""
972
+
973
+
974
def accept_move_and_sample_leaves(
    resid: Float32[Array, ' n'] | Float32[Array, ' k n'],
    at: SeqStageInAllTrees,
    pt: SeqStageInPerTree,
) -> tuple[
    Float32[Array, ' n'] | Float32[Array, ' k n'],
    Float32[Array, ' 2**d'] | Float32[Array, ' k 2**d'],
    Bool[Array, ''],
    Bool[Array, ''],
    Float32[Array, ''] | None,
]:
    """
    Accept or reject a proposed move and sample the new leaf values.

    Parameters
    ----------
    resid
        The residuals (data minus forest value).
    at
        The inputs that are the same for all trees.
    pt
        The inputs that are separate for each tree.

    Returns
    -------
    resid : Float32[Array, 'n'] | Float32[Array, ' k n']
        The updated residuals (data minus forest value).
    leaf_tree : Float32[Array, '2**d'] | Float32[Array, ' k 2**d']
        The new leaf values of the tree.
    acc : Bool[Array, '']
        Whether the move was accepted.
    to_prune : Bool[Array, '']
        Whether, to reflect the acceptance status of the move, the state should
        be updated by pruning the leaves involved in the move.
    log_lk_ratio : Float32[Array, ''] | None
        The logarithm of the likelihood ratio for the move. `None` if not to be
        saved.
    """
    # sum residuals in each leaf, in tree proposed by grow move
    if at.prec_scale is None:
        scaled_resid = resid
    else:
        scaled_resid = resid * at.prec_scale

    tree_size = pt.leaf_tree.shape[-1]  # 2**d

    resid_tree = sum_resid(
        scaled_resid, pt.leaf_indices, tree_size, at.resid_num_batches, at.mesh
    )

    # subtract starting tree from function
    resid_tree += pt.prec_tree * pt.leaf_tree

    # sum residuals in parent node modified by move and compute likelihood
    resid_left = resid_tree[..., pt.move.left]
    resid_right = resid_tree[..., pt.move.right]
    resid_total = resid_left + resid_right
    assert pt.move.node.dtype == jnp.int32
    resid_tree = resid_tree.at[..., pt.move.node].set(resid_total)

    log_lk_ratio = compute_likelihood_ratio(
        resid_total, resid_left, resid_right, pt.prelkv, at.prelk
    )

    # calculate accept/reject ratio
    # the ratio is formulated for grow; a prune move uses its reciprocal
    log_ratio = pt.move.log_trans_prior_ratio + log_lk_ratio
    log_ratio = jnp.where(pt.move.grow, log_ratio, -log_ratio)
    if not at.save_ratios:
        log_lk_ratio = None

    # determine whether to accept the move
    acc = pt.move.allowed & (pt.move.logu <= log_ratio)

    # compute leaves posterior and sample leaves
    if resid.ndim > 1:
        # multivariate: per-leaf matrix-vector product over the k outcomes
        mean_post = jnp.einsum('kil,kl->il', pt.prelf.mean_factor, resid_tree)
    else:
        mean_post = resid_tree * pt.prelf.mean_factor
    leaf_tree = mean_post + pt.prelf.centered_leaves

    # copy leaves around such that the leaf indices point to the correct leaf
    # prune: accepted grow xor accepted prune — true iff the final tree is the
    # smaller of the two layouts (grow rejected, or prune accepted)
    to_prune = acc ^ pt.move.grow
    # index tree_size is out of range, which makes the .at[].set a no-op
    # (jax drops out-of-bounds scatter updates by default)
    leaf_tree = (
        leaf_tree.at[..., jnp.where(to_prune, pt.move.left, tree_size)]
        .set(leaf_tree[..., pt.move.node])
        .at[..., jnp.where(to_prune, pt.move.right, tree_size)]
        .set(leaf_tree[..., pt.move.node])
    )
    # replace old tree with new tree in function values
    resid += (pt.leaf_tree - leaf_tree)[..., pt.leaf_indices]

    return resid, leaf_tree, acc, to_prune, log_lk_ratio
1066
+
1067
+
1068
# vectorize maps over the leading k axis (multivariate outcomes), while the
# excluded arguments (indices and static config) are shared across the map
@partial(jnp.vectorize, excluded=(1, 2, 3, 4), signature='(n)->(ts)')
def sum_resid(
    scaled_resid: Float32[Array, ' n'] | Float32[Array, 'k n'],
    leaf_indices: UInt[Array, ' n'],
    tree_size: int,
    resid_num_batches: int | None,
    mesh: Mesh | None,
) -> Float32[Array, ' {tree_size}'] | Float32[Array, 'k {tree_size}']:
    """
    Sum the residuals in each leaf.

    Handles both univariate and multivariate cases based on the shape of the
    input arrays.

    Parameters
    ----------
    scaled_resid
        The residuals (data minus forest value) multiplied by the error
        precision scale. For multivariate case, shape is ``(k, n)`` where ``k``
        is the number of outcome columns.
    leaf_indices
        The leaf indices of the tree (in which leaf each data point falls into).
    tree_size
        The size of the tree array (2 ** d).
    resid_num_batches
        The number of batches for computing the sum of residuals in each leaf.
    mesh
        The mesh of devices to use.

    Returns
    -------
    The sum of the residuals at data points in each leaf. For multivariate
    case, returns per-leaf sums of residual vectors.
    """
    return _scatter_add(
        scaled_resid, leaf_indices, tree_size, jnp.float32, resid_num_batches, mesh
    )
1105
+
1106
+
1107
def _scatter_add(
    values: Float32[Array, ' n'] | int,
    indices: Integer[Array, ' n'],
    size: int,
    dtype: jnp.dtype,
    batch_size: int | None,
    mesh: Mesh | None,
) -> Shaped[Array, ' {size}']:
    """
    Indexed reduce with optional batching.

    Sums `values` into an output of length `size` at positions `indices`,
    optionally sharding the datapoint axis across a device mesh.
    """
    # validate `values`: either a scalar or aligned with `indices`
    values = jnp.asarray(values)
    assert values.ndim == 0 or values.shape == indices.shape

    # bind the static configuration once
    kernel = partial(
        _scatter_add_impl, size=size, dtype=dtype, num_batches=batch_size
    )

    # single-device path: no mesh, or the mesh has no 'data' axis
    if mesh is None or 'data' not in mesh.axis_names:
        return kernel(values, indices)

    # multi-device path: shard the datapoint axis and psum the partial sums;
    # a scalar `values` is replicated to every shard instead of split
    value_spec = PartitionSpec('data') if values.shape else PartitionSpec()
    kernel = shard_map(
        partial(kernel, final_psum=True),
        in_specs=(value_spec, PartitionSpec('data')),
        out_specs=PartitionSpec(),
        mesh=mesh,
        **_get_shard_map_patch_kwargs(),
    )
    return kernel(values, indices)
1143
+
1144
+
1145
+ def _get_shard_map_patch_kwargs():
1146
+ # see jax/issues/#34249, problem with vmap(shard_map(psum))
1147
+ # we tried the config jax_disable_vmap_shmap_error but it didn't work
1148
+ if jax.__version__ in ('0.8.1', '0.8.2'):
1149
+ return {'check_vma': False}
1150
+ else:
1151
+ return {}
1152
+
1153
+
1154
def _scatter_add_impl(
    values: Float32[Array, ' n'] | Int32[Array, ''],
    indices: Integer[Array, ' n'],
    /,
    *,
    size: int,
    dtype: jnp.dtype,
    num_batches: int | None,
    final_psum: bool = False,
) -> Shaped[Array, ' {size}']:
    """Core scatter-add kernel, optionally batched and/or psum-reduced."""
    if num_batches is None:
        # unbatched: one big indexed add
        accum = jnp.zeros(size, dtype).at[indices].add(values)
    else:
        # batched: spread datapoints round-robin over num_batches lanes, then
        # reduce over lanes. Note n is the local shard size under shard_map,
        # not the full dataset size.
        (n,) = indices.shape
        lane = jnp.arange(n) % num_batches
        accum = (
            jnp.zeros((size, num_batches), dtype)
            .at[indices, lane]
            .add(values)
            .sum(axis=1)
        )

    # cross-device reduction when running inside shard_map
    return lax.psum(accum, 'data') if final_psum else accum
1181
+
1182
+
1183
def _compute_likelihood_ratio_uv(
    total_resid: Float32[Array, ''],
    left_resid: Float32[Array, ''],
    right_resid: Float32[Array, ''],
    prelkv: PreLkV,
    prelk: PreLk,
) -> Float32[Array, '']:
    """Log-likelihood ratio of a grow move, univariate model."""
    # squared residual sums scaled by the pre-computed per-node terms
    sq_left = left_resid * left_resid / prelkv.left
    sq_right = right_resid * right_resid / prelkv.right
    sq_total = total_resid * total_resid / prelkv.total
    return prelkv.log_sqrt_term + prelk.exp_factor * (sq_left + sq_right - sq_total)
1196
+
1197
+
1198
def _compute_likelihood_ratio_mv(
    total_resid: Float32[Array, ' k'],
    left_resid: Float32[Array, ' k'],
    right_resid: Float32[Array, ' k'],
    prelkv: PreLkV,
) -> Float32[Array, '']:
    """Log-likelihood ratio of a grow move, multivariate model."""
    # quadratic forms r' M r with the pre-computed per-node matrices
    qf_left = left_resid @ prelkv.left @ left_resid
    qf_right = right_resid @ prelkv.right @ right_resid
    qf_total = total_resid @ prelkv.total @ total_resid
    return prelkv.log_sqrt_term + 0.5 * (qf_left + qf_right - qf_total)
1212
+
1213
+
1214
def compute_likelihood_ratio(
    total_resid: Float32[Array, ''] | Float32[Array, ' k'],
    left_resid: Float32[Array, ''] | Float32[Array, ' k'],
    right_resid: Float32[Array, ''] | Float32[Array, ' k'],
    prelkv: PreLkV,
    prelk: PreLk | None,
) -> Float32[Array, '']:
    """
    Compute the likelihood ratio of a grow move.

    Handles both univariate and multivariate cases based on the shape of the
    residual arrays.

    Parameters
    ----------
    total_resid
    left_resid
    right_resid
        The sum of the residuals (scaled by error precision scale) of the
        datapoints falling in the nodes involved in the moves.
    prelkv
    prelk
        The pre-computed terms of the likelihood ratio, see
        `precompute_likelihood_terms`.

    Returns
    -------
    The log-likelihood ratio log P(data | new tree) - log P(data | old tree).
    """
    # scalar residual sums indicate the univariate model
    if total_resid.ndim == 0:
        assert prelk is not None  # prelk is only optional in the mv case
        return _compute_likelihood_ratio_uv(
            total_resid, left_resid, right_resid, prelkv, prelk
        )
    return _compute_likelihood_ratio_mv(total_resid, left_resid, right_resid, prelkv)
1252
+
1253
+
1254
@partial(jit_and_block_if_profiling, donate_argnums=(0, 1))
@vmap_chains_if_profiling
def accept_moves_final_stage(bart: State, moves: Moves) -> State:
    """
    Post-process the mcmc state after accepting/rejecting the moves.

    This function is separate from `accept_moves_sequential_stage` to signal it
    can work in parallel across trees.

    Parameters
    ----------
    bart
        A partially updated BART mcmc state.
    moves
        The proposed moves (see `propose_moves`) as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The fully updated BART mcmc state.
    """
    return replace(
        bart,
        forest=replace(
            bart.forest,
            # diagnostics: count accepted moves by type across all trees
            grow_acc_count=jnp.sum(moves.acc & moves.grow),
            prune_acc_count=jnp.sum(moves.acc & ~moves.grow),
            # bring datapoint->leaf assignments and cutpoints in line with
            # the accepted/rejected moves
            leaf_indices=apply_moves_to_leaf_indices(bart.forest.leaf_indices, moves),
            split_tree=apply_moves_to_split_trees(bart.forest.split_tree, moves),
        ),
    )
1285
+
1286
+
1287
@vmap_nodoc
def apply_moves_to_leaf_indices(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves
) -> UInt[Array, 'num_trees n']:
    """
    Update the leaf indices to match the accepted move.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, if the grow move was
        accepted.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated leaf indices.
    """
    assert moves.to_prune is not None
    # clearing the lowest bit maps both children of a node to the left child,
    # so this detects datapoints sitting in either child of the pruned pair
    sibling_mask = ~jnp.array(1, leaf_indices.dtype)  # ...1111111110
    in_pruned_pair = (leaf_indices & sibling_mask) == moves.left
    # move those datapoints up to the parent node when the pair is pruned
    return jnp.where(
        in_pruned_pair & moves.to_prune,
        moves.node.astype(leaf_indices.dtype),
        leaf_indices,
    )
1313
+
1314
+
1315
@vmap_nodoc
def apply_moves_to_split_trees(
    split_tree: UInt[Array, 'num_trees 2**(d-1)'], moves: Moves
) -> UInt[Array, 'num_trees 2**(d-1)']:
    """
    Update the split trees to match the accepted move.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision nodes in the initial trees.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated split trees.
    """
    assert moves.to_prune is not None
    # the out-of-range index split_tree.size turns the .at[].set into a no-op
    grow_target = jnp.where(moves.grow, moves.node, split_tree.size)
    prune_target = jnp.where(moves.to_prune, moves.node, split_tree.size)
    # write the new cutpoint for a grow move, then zero the node when pruning
    updated = split_tree.at[grow_target].set(moves.grow_split.astype(split_tree.dtype))
    return updated.at[prune_target].set(0)
1341
+
1342
+
1343
@jax.jit
def _sample_wishart_bartlett(
    key: Key[Array, ''], df: Float32[Array, ''], scale_inv: Float32[Array, 'k k']
) -> Float32[Array, 'k k']:
    """
    Sample a precision matrix W ~ Wishart(df, scale_inv^-1) using Bartlett decomposition.

    Parameters
    ----------
    key
        A JAX random key
    df
        Degrees of freedom
    scale_inv
        Scale matrix of the corresponding Inverse Wishart distribution

    Returns
    -------
    A sample from Wishart(df, scale)
    """
    keys = split(key)

    # Diagonal elements: A_ii ~ sqrt(chi^2(df - i))
    # chi^2(k) = Gamma(k/2, scale=2)
    k, _ = scale_inv.shape
    df_vector = df - jnp.arange(k)
    chi2_samples = random.gamma(keys.pop(), df_vector / 2.0) * 2.0
    diag_A = jnp.sqrt(chi2_samples)

    # Bartlett factor A: standard normals strictly below the diagonal,
    # chi draws on the diagonal
    off_diag_A = random.normal(keys.pop(), (k, k))
    A = jnp.tril(off_diag_A, -1) + jnp.diag(diag_A)
    # Cholesky L with scale_inv = L L', regularized for numerical safety
    L = chol_with_gersh(scale_inv, absolute_eps=True)
    # T = L^-T A (solves L' T = A); then W = T T' = L^-T A A' L^-1,
    # i.e. the Bartlett construction with scale = (L L')^-1 = L^-T L^-1
    T = solve_triangular(L, A, lower=True, trans='T')

    return T @ T.T
1378
+
1379
+
1380
def _step_error_cov_inv_uv(key: Key[Array, ''], bart: State) -> State:
    """Gibbs-sample the inverse global error variance (univariate model)."""
    resid = bart.resid
    # conjugate inverse-gamma update: alpha = df / 2, beta = scale / 2
    alpha = bart.error_cov_df / 2 + resid.size / 2
    # weight the residuals by the per-datapoint precision scale, if any
    scaled_resid = resid if bart.prec_scale is None else resid * bart.prec_scale
    norm2 = resid @ scaled_resid
    beta = bart.error_cov_scale / 2 + norm2 / 2

    # random.gamma seems to be slow at compiling, maybe cdf inversion would
    # be better, but it's not implemented in jax
    gamma_draw = random.gamma(key, alpha)
    return replace(bart, error_cov_inv=gamma_draw / beta)
1395
+
1396
+
1397
def _step_error_cov_inv_mv(key: Key[Array, ''], bart: State) -> State:
    """Gibbs-sample the inverse error covariance matrix (multivariate model)."""
    # conjugate Wishart update of the precision matrix
    num_points = bart.resid.shape[-1]
    post_df = bart.error_cov_df + num_points
    post_scale = bart.error_cov_scale + bart.resid @ bart.resid.T

    new_prec = _sample_wishart_bartlett(key, post_df, post_scale)
    return replace(bart, error_cov_inv=new_prec)
1404
+
1405
+
1406
@partial(jit_and_block_if_profiling, donate_argnums=(1,))
@vmap_chains_if_profiling
def step_error_cov_inv(key: Key[Array, ''], bart: State) -> State:
    """
    MCMC-update the inverse error covariance.

    Dispatches between the univariate and multivariate samplers based on
    whether `error_cov_inv` is a matrix.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state.

    Returns
    -------
    The new BART mcmc state, with an updated `error_cov_inv`.
    """
    assert bart.error_cov_inv is not None
    # a 2d error_cov_inv is a precision matrix -> multivariate model
    if bart.error_cov_inv.ndim != 2:
        return _step_error_cov_inv_uv(key, bart)
    return _step_error_cov_inv_mv(key, bart)
1431
+
1432
+
1433
@partial(jit_and_block_if_profiling, donate_argnums=(1,))
@vmap_chains_if_profiling
def step_z(key: Key[Array, ''], bart: State) -> State:
    """
    MCMC-update the latent variable for binary regression.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART MCMC state.

    Returns
    -------
    The updated BART MCMC state.
    """
    assert bart.y.dtype == bool
    # recover forest prediction plus offset from z = prediction + resid
    prediction = bart.z - bart.resid
    # draw new residuals from a one-sided truncated normal: the truncation
    # side is determined by the binary label
    new_resid = truncated_normal_onesided(key, (), ~bart.y, -prediction)
    return replace(bart, z=prediction + new_resid, resid=new_resid)
1455
+
1456
+
1457
def step_s(key: Key[Array, ''], bart: State) -> State:
    """
    Update `log_s` using Dirichlet sampling.

    The prior is s ~ Dirichlet(theta/p, ..., theta/p), and the posterior
    is s ~ Dirichlet(theta/p + varcount, ..., theta/p + varcount), where
    varcount counts how many times each variable is used in the current
    forest.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.

    Returns
    -------
    Updated BART state with re-sampled `log_s`.

    Notes
    -----
    This full conditional is approximated, because it does not take into
    account that there are forbidden decision rules.
    """
    assert bart.forest.theta is not None

    # count how many decision nodes use each predictor in the current forest
    num_predictors = bart.forest.max_split.size
    usage = var_histogram(
        num_predictors, bart.forest.var_tree, bart.forest.split_tree, sum_batch_axis=-1
    )

    # draw from the Dirichlet posterior, stored as unnormalized log-gamma
    # variates
    alpha = bart.forest.theta / num_predictors + usage
    new_log_s = random.loggamma(key, alpha)

    # store the new (log) s in the forest
    return replace(bart, forest=replace(bart.forest, log_s=new_log_s))
1496
+
1497
+
1498
def step_theta(key: Key[Array, ''], bart: State, *, num_grid: int = 1000) -> State:
    """
    Update `theta`.

    The prior is theta / (theta + rho) ~ Beta(a, b).

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.
    num_grid
        The number of points in the evenly-spaced grid used to sample
        theta / (theta + rho).

    Returns
    -------
    Updated BART state with re-sampled `theta`.
    """
    forest = bart.forest
    assert forest.log_s is not None
    assert forest.rho is not None
    assert forest.a is not None
    assert forest.b is not None

    # grid of midpoints of num_grid equal-width bins over (0, 1)
    half_bin = 1 / (2 * num_grid)
    lamda_grid = jnp.linspace(half_bin, 1 - half_bin, num_grid)

    # work with the normalized version of s
    log_s = forest.log_s - logsumexp(forest.log_s)

    # evaluate the posterior of lambda = theta/(theta+rho) on the grid and
    # sample a grid point
    logp, theta_grid = _log_p_lamda(lamda_grid, log_s, forest.rho, forest.a, forest.b)
    chosen = random.categorical(key, logp)

    return replace(bart, forest=replace(forest, theta=theta_grid[chosen]))
1538
+
1539
+
1540
def _log_p_lamda(
    lamda: Float32[Array, ' num_grid'],
    log_s: Float32[Array, ' p'],
    rho: Float32[Array, ''],
    a: Float32[Array, ''],
    b: Float32[Array, ''],
) -> tuple[Float32[Array, ' num_grid'], Float32[Array, ' num_grid']]:
    """
    Unnormalized log-posterior of lambda = theta / (theta + rho) on a grid.

    Returns the log-density values and the corresponding theta values.
    """
    # in the following I use lamda[::-1] == 1 - lamda
    # (the grid is symmetric about 1/2, so reversing it gives 1 - lamda
    # without extra rounding error near lamda == 1)
    theta = rho * lamda / lamda[::-1]
    p = log_s.size
    return (
        # Beta(a, b) prior on lambda ...
        (a - 1) * jnp.log1p(-lamda[::-1])  # log(lambda)
        + (b - 1) * jnp.log1p(-lamda)  # log(1 - lambda)
        # ... times the Dirichlet(theta/p, ..., theta/p) likelihood of log_s
        + gammaln(theta)
        - p * gammaln(theta / p)
        + theta / p * jnp.sum(log_s)
    ), theta
1557
+
1558
+
1559
@partial(jit_and_block_if_profiling, donate_argnums=(1,))
@vmap_chains_if_profiling
def step_sparse(key: Key[Array, ''], bart: State) -> State:
    """
    Update the sparsity parameters.

    This invokes `step_s`, and then `step_theta` only if the parameters of
    the theta prior are defined.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.

    Returns
    -------
    Updated BART state with re-sampled `log_s` and `theta`.
    """
    # NOTE(review): when sparse_on_at is None, the state is returned
    # unchanged — presumably the sparsity machinery is disabled in that case
    if bart.config.sparse_on_at is not None:
        # burn-in gate: skip the sparsity update until sparse_on_at steps
        # have been done
        bart = cond(
            bart.config.steps_done < bart.config.sparse_on_at,
            lambda _key, bart: bart,
            _step_sparse,
            key,
            bart,
        )
    return bart
1588
+
1589
+
1590
def _step_sparse(key: Key[Array, ''], bart: State) -> State:
    """Sample `log_s`, and then `theta` if its hyperprior is configured."""
    keys = split(key)
    bart = step_s(keys.pop(), bart)
    # rho is part of the theta prior; when it is unset, theta is fixed
    if bart.forest.rho is not None:
        bart = step_theta(keys.pop(), bart)
    return bart
1596
+
1597
+
1598
@jit_if_profiling
# jit to avoid the overhead of replace(_: Module)
def step_config(bart):
    """Advance the per-step bookkeeping: increment `steps_done` by one."""
    new_config = replace(bart.config, steps_done=bart.config.steps_done + 1)
    return replace(bart, config=new_config)