bartz 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bartz/mcmcstep.py DELETED
@@ -1,2335 +0,0 @@
1
- # bartz/src/bartz/mcmcstep.py
2
- #
3
- # Copyright (c) 2024-2025, Giacomo Petrillo
4
- #
5
- # This file is part of bartz.
6
- #
7
- # Permission is hereby granted, free of charge, to any person obtaining a copy
8
- # of this software and associated documentation files (the "Software"), to deal
9
- # in the Software without restriction, including without limitation the rights
10
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
- # copies of the Software, and to permit persons to whom the Software is
12
- # furnished to do so, subject to the following conditions:
13
- #
14
- # The above copyright notice and this permission notice shall be included in all
15
- # copies or substantial portions of the Software.
16
- #
17
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
- # SOFTWARE.
24
-
25
- """
26
- Functions that implement the BART posterior MCMC initialization and update step.
27
-
28
- Functions that do MCMC steps operate by taking as input a bart state, and
29
- outputting a new state. The inputs are not modified.
30
-
31
- The main entry points are:
32
-
33
- - `State`: The dataclass that represents a BART MCMC state.
34
- - `init`: Creates an initial `State` from data and configurations.
35
- - `step`: Performs one full MCMC step on a `State`, returning a new `State`.
36
- """
37
-
38
- import math
39
- from dataclasses import replace
40
- from functools import cache, partial
41
- from typing import Any
42
-
43
- import jax
44
- from equinox import Module, field
45
- from jax import lax, random
46
- from jax import numpy as jnp
47
- from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, Shaped, UInt
48
-
49
- from . import grove
50
- from .jaxext import minimal_unsigned_dtype, split, vmap_nodoc
51
-
52
-
53
class Forest(Module):
    """
    Represents the MCMC state of a sum of trees.

    Parameters
    ----------
    leaf_trees
        The leaf values.
    var_trees
        The decision axes.
    split_trees
        The decision boundaries.
    p_nonterminal
        The probability of a nonterminal node at each depth, padded with a
        zero.
    p_propose_grow
        The unnormalized probability of picking a leaf for a grow proposal.
    leaf_indices
        The index of the leaf each datapoint falls into, for each tree.
    min_points_per_leaf
        The minimum number of data points in a leaf node.
    affluence_trees
        Whether each non-bottom leaf node contains twice
        `min_points_per_leaf` datapoints. If `min_points_per_leaf` is not
        specified, this is None.
    resid_batch_size
    count_batch_size
        The data batch sizes for computing the sufficient statistics. If `None`,
        they are computed with no batching.
    log_trans_prior
        The log transition and prior Metropolis-Hastings ratio for the
        proposed move on each tree.
    log_likelihood
        The log likelihood ratio.
    grow_prop_count
    prune_prop_count
        The number of grow/prune proposals made during one full MCMC cycle.
    grow_acc_count
    prune_acc_count
        The number of grow/prune moves accepted during one full MCMC cycle.
    sigma_mu2
        The prior variance of a leaf, conditional on the tree structure.
    """

    # Trees are stored as fixed-size heaps: the children of node i are
    # nodes 2i and 2i + 1 (see the left/right computation in propose_moves).
    leaf_trees: Float32[Array, 'num_trees 2**d']
    var_trees: UInt[Array, 'num_trees 2**(d-1)']
    split_trees: UInt[Array, 'num_trees 2**(d-1)']
    p_nonterminal: Float32[Array, 'd']
    p_propose_grow: Float32[Array, '2**(d-1)']
    leaf_indices: UInt[Array, 'num_trees n']
    min_points_per_leaf: Int32[Array, ''] | None
    affluence_trees: Bool[Array, 'num_trees 2**(d-1)'] | None
    # static fields are hashed into the jit cache key instead of being traced
    resid_batch_size: int | None = field(static=True)
    count_batch_size: int | None = field(static=True)
    log_trans_prior: Float32[Array, 'num_trees'] | None
    log_likelihood: Float32[Array, 'num_trees'] | None
    grow_prop_count: Int32[Array, '']
    prune_prop_count: Int32[Array, '']
    grow_acc_count: Int32[Array, '']
    prune_acc_count: Int32[Array, '']
    sigma_mu2: Float32[Array, '']
113
-
114
-
115
class State(Module):
    """
    Represents the MCMC state of BART.

    Parameters
    ----------
    X
        The predictors.
    max_split
        The maximum split index for each predictor.
    y
        The response. If the data type is `bool`, the model is binary regression.
    resid
        The residuals (`y` or `z` minus sum of trees).
    z
        The latent variable for binary regression. `None` in continuous
        regression.
    offset
        Constant shift added to the sum of trees.
    sigma2
        The error variance. `None` in binary regression.
    prec_scale
        The scale on the error precision, i.e., ``1 / error_scale ** 2``.
        `None` in binary regression.
    sigma2_alpha
    sigma2_beta
        The shape and scale parameters of the inverse gamma prior on the noise
        variance. `None` in binary regression.
    forest
        The sum of trees model.
    """

    # note: X is stored transposed w.r.t. the usual convention, one row per
    # predictor (see the `init` docstring)
    X: UInt[Array, 'p n']
    max_split: UInt[Array, 'p']
    y: Float32[Array, 'n'] | Bool[Array, 'n']
    z: None | Float32[Array, 'n']
    offset: Float32[Array, '']
    resid: Float32[Array, 'n']
    sigma2: Float32[Array, ''] | None
    prec_scale: Float32[Array, 'n'] | None
    sigma2_alpha: Float32[Array, ''] | None
    sigma2_beta: Float32[Array, ''] | None
    forest: Forest
158
-
159
-
160
def init(
    *,
    X: UInt[Any, 'p n'],
    y: Float32[Any, 'n'] | Bool[Any, 'n'],
    offset: float | Float32[Any, ''] = 0.0,
    max_split: UInt[Any, 'p'],
    num_trees: int,
    p_nonterminal: Float32[Any, 'd-1'],
    sigma_mu2: float | Float32[Any, ''],
    sigma2_alpha: float | Float32[Any, ''] | None = None,
    sigma2_beta: float | Float32[Any, ''] | None = None,
    error_scale: Float32[Any, 'n'] | None = None,
    min_points_per_leaf: int | None = None,
    resid_batch_size: int | None | str = 'auto',
    count_batch_size: int | None | str = 'auto',
    save_ratios: bool = False,
) -> State:
    """
    Make a BART posterior sampling MCMC initial state.

    Parameters
    ----------
    X
        The predictors. Note this is transposed compared to the usual
        convention.
    y
        The response. If the data type is `bool`, the regression model is binary
        regression with probit.
    offset
        Constant shift added to the sum of trees. 0 if not specified.
    max_split
        The maximum split index for each variable. All split ranges start at 1.
    num_trees
        The number of trees in the forest.
    p_nonterminal
        The probability of a nonterminal node at each depth. The maximum depth
        of trees is fixed by the length of this array.
    sigma_mu2
        The prior variance of a leaf, conditional on the tree structure. The
        prior variance of the sum of trees is ``num_trees * sigma_mu2``. The
        prior mean of leaves is always zero.
    sigma2_alpha
    sigma2_beta
        The shape and scale parameters of the inverse gamma prior on the error
        variance. Leave unspecified for binary regression.
    error_scale
        Each error is scaled by the corresponding factor in `error_scale`, so
        the error variance for ``y[i]`` is ``sigma2 * error_scale[i] ** 2``.
        Not supported for binary regression. If not specified, defaults to 1 for
        all points, but potentially skipping calculations.
    min_points_per_leaf
        The minimum number of data points in a leaf node. 0 if not specified.
    resid_batch_size
    count_batch_size
        The batch sizes, along datapoints, for summing the residuals and
        counting the number of datapoints in each leaf. `None` for no batching.
        If 'auto', pick a value based on the device of `y`, or the default
        device.
    save_ratios
        Whether to save the Metropolis-Hastings ratios.

    Returns
    -------
    An initialized BART MCMC state.

    Raises
    ------
    ValueError
        If `y` is boolean and arguments unused in binary regression are set.
    """
    p_nonterminal = jnp.asarray(p_nonterminal)
    # pad with a zero so that nodes at the bottom level have nonterminal
    # probability 0
    p_nonterminal = jnp.pad(p_nonterminal, (0, 1))
    max_depth = p_nonterminal.size

    @partial(jax.vmap, in_axes=None, out_axes=0, axis_size=num_trees)
    def make_forest(max_depth, dtype):
        # broadcast a single empty tree over the num_trees axis
        return grove.make_tree(max_depth, dtype)

    y = jnp.asarray(y)
    offset = jnp.asarray(offset)

    resid_batch_size, count_batch_size = _choose_suffstat_batch_size(
        resid_batch_size, count_batch_size, y, 2**max_depth * num_trees
    )

    is_binary = y.dtype == bool
    if is_binary:
        if (error_scale, sigma2_alpha, sigma2_beta) != 3 * (None,):
            # bug fix: the two string fragments used to concatenate to
            # "must be set  to" with a double space
            raise ValueError(
                'error_scale, sigma2_alpha, and sigma2_beta must be set '
                'to `None` for binary regression.'
            )
        sigma2 = None
    else:
        sigma2_alpha = jnp.asarray(sigma2_alpha)
        sigma2_beta = jnp.asarray(sigma2_beta)
        # initialize the error variance at beta / alpha
        sigma2 = sigma2_beta / sigma2_alpha
        # sigma2 = jnp.where(jnp.isfinite(sigma2) & (sigma2 > 0), sigma2, 1)
        # TODO: I don't like this isfinite check, these functions should be
        # low-level and just do the thing. Why was it here?

    bart = State(
        X=jnp.asarray(X),
        max_split=jnp.asarray(max_split),
        y=y,
        # binary regression: the latent z starts at the offset and the
        # residuals at zero; continuous regression: resid = y - offset
        z=jnp.full(y.shape, offset) if is_binary else None,
        offset=offset,
        resid=jnp.zeros(y.shape) if is_binary else y - offset,
        sigma2=sigma2,
        prec_scale=(
            None if error_scale is None else lax.reciprocal(jnp.square(error_scale))
        ),
        sigma2_alpha=sigma2_alpha,
        sigma2_beta=sigma2_beta,
        forest=Forest(
            leaf_trees=make_forest(max_depth, jnp.float32),
            var_trees=make_forest(
                max_depth - 1, minimal_unsigned_dtype(X.shape[0] - 1)
            ),
            split_trees=make_forest(max_depth - 1, max_split.dtype),
            grow_prop_count=jnp.zeros((), int),
            grow_acc_count=jnp.zeros((), int),
            prune_prop_count=jnp.zeros((), int),
            prune_acc_count=jnp.zeros((), int),
            p_nonterminal=p_nonterminal,
            p_propose_grow=p_nonterminal[grove.tree_depths(2 ** (max_depth - 1))],
            # every datapoint starts in the root, which is heap index 1
            leaf_indices=jnp.ones(
                (num_trees, y.size), minimal_unsigned_dtype(2**max_depth - 1)
            ),
            min_points_per_leaf=(
                None
                if min_points_per_leaf is None
                else jnp.asarray(min_points_per_leaf)
            ),
            # the root is growable iff the full dataset holds at least twice
            # the minimum number of points per leaf
            affluence_trees=(
                None
                if min_points_per_leaf is None
                else make_forest(max_depth - 1, bool)
                .at[:, 1]
                .set(y.size >= 2 * min_points_per_leaf)
            ),
            resid_batch_size=resid_batch_size,
            count_batch_size=count_batch_size,
            log_trans_prior=jnp.full(num_trees, jnp.nan) if save_ratios else None,
            log_likelihood=jnp.full(num_trees, jnp.nan) if save_ratios else None,
            sigma_mu2=jnp.asarray(sigma_mu2),
        ),
    )

    return bart
309
-
310
-
311
def _choose_suffstat_batch_size(
    resid_batch_size, count_batch_size, y, forest_size
) -> tuple[int | None, ...]:
    """
    Resolve 'auto' batch size settings to concrete values.

    Parameters
    ----------
    resid_batch_size
    count_batch_size
        The requested settings: an int, `None` (no batching), or 'auto'.
    y
        The response array; used to get the number of datapoints and to
        detect the target device platform.
    forest_size
        The total number of tree nodes in the forest; used to bound the
        memory used by leaf-count accumulation on gpu.

    Returns
    -------
    The resolved ``(resid_batch_size, count_batch_size)``, each an int or
    `None`.

    Raises
    ------
    KeyError
        If the device platform is neither 'cpu' nor 'gpu'.
    """

    @cache
    def get_platform():
        # prefer the device y is committed to; fall back to the default
        # device when y is abstract (e.g., under tracing)
        try:
            device = y.devices().pop()
        except jax.errors.ConcretizationTypeError:
            device = jax.devices()[0]
        platform = device.platform
        if platform not in ('cpu', 'gpu'):
            raise KeyError(f'Unknown platform: {platform}')
        return platform

    if resid_batch_size == 'auto':
        platform = get_platform()
        n = max(1, y.size)
        if platform == 'cpu':
            resid_batch_size = 2 ** int(round(math.log2(n / 6)))  # n/6
        elif platform == 'gpu':
            resid_batch_size = 2 ** int(round((1 + math.log2(n)) / 3))  # n^1/3
        # clamp to 1: for small n the power-of-two formula can go below 1
        resid_batch_size = max(1, resid_batch_size)

    if count_batch_size == 'auto':
        platform = get_platform()
        if platform == 'cpu':
            count_batch_size = None
        elif platform == 'gpu':
            n = max(1, y.size)
            count_batch_size = 2 ** int(round(math.log2(n) / 2 - 2))  # n^1/2
            # /4 is good on V100, /2 on L4/T4, still haven't tried A100
            max_memory = 2**29
            itemsize = 4
            # raise the batch size so the per-batch buffer of
            # forest_size * n / batch counters stays under max_memory bytes
            min_batch_size = int(math.ceil(forest_size * itemsize * n / max_memory))
            count_batch_size = max(count_batch_size, min_batch_size)
            count_batch_size = max(1, count_batch_size)

    return resid_batch_size, count_batch_size
349
-
350
-
351
@jax.jit
def step(key: Key[Array, ''], bart: State) -> State:
    """
    Do one MCMC step.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The new BART mcmc state.
    """
    keys = split(key)

    if bart.y.dtype == bool:
        # binary regression: run the tree sampler pretending unit error
        # variance, then resample the latent z instead of sigma2
        state = replace(bart, sigma2=jnp.float32(1))
        state = step_trees(keys.pop(), state)
        state = replace(state, sigma2=None)
        return step_z(keys.pop(), state)

    # continuous regression: sample the trees, then the error variance
    state = step_trees(keys.pop(), bart)
    return step_sigma(keys.pop(), state)
378
-
379
-
380
def step_trees(key: Key[Array, ''], bart: State) -> State:
    """
    Forest sampling step of BART MCMC.

    Proposes one grow/prune move per tree, then accepts or rejects the moves
    and samples the leaf values.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The new BART mcmc state.

    Notes
    -----
    This function zeroes the proposal counters.
    """
    keys = split(key)
    proposed = propose_moves(keys.pop(), bart.forest, bart.max_split)
    return accept_moves_and_sample_leaves(keys.pop(), bart, proposed)
402
-
403
-
404
class Moves(Module):
    """
    Moves proposed to modify each tree.

    Parameters
    ----------
    allowed
        Whether the move is possible in the first place. There are additional
        constraints that could forbid it, but they are computed at acceptance
        time.
    grow
        Whether the move is a grow move or a prune move.
    num_growable
        The number of growable leaves in the original tree.
    node
        The index of the leaf to grow or node to prune.
    left
    right
        The indices of the children of 'node'.
    partial_ratio
        A factor of the Metropolis-Hastings ratio of the move. It lacks
        the likelihood ratio and the probability of proposing the prune
        move. If the move is PRUNE, the ratio is inverted. `None` once
        `log_trans_prior_ratio` has been computed.
    log_trans_prior_ratio
        The logarithm of the product of the transition and prior terms of the
        Metropolis-Hastings ratio for the acceptance of the proposed move.
        `None` if not yet computed.
    grow_var
        The decision axes of the new rules.
    grow_split
        The decision boundaries of the new rules.
    var_trees
        The updated decision axes of the trees, valid whatever move.
    logu
        The logarithm of a uniform (0, 1] random variable to be used to
        accept the move. It's in (-oo, 0].
    acc
        Whether the move was accepted. `None` if not yet computed.
    to_prune
        Whether the final operation to apply the move is pruning. This indicates
        an accepted prune move or a rejected grow move. `None` if not yet
        computed.
    """

    # All fields are per-tree vectors; the `None`-able fields are filled in
    # progressively as the acceptance computation proceeds.
    allowed: Bool[Array, 'num_trees']
    grow: Bool[Array, 'num_trees']
    num_growable: UInt[Array, 'num_trees']
    node: UInt[Array, 'num_trees']
    left: UInt[Array, 'num_trees']
    right: UInt[Array, 'num_trees']
    partial_ratio: Float32[Array, 'num_trees'] | None
    log_trans_prior_ratio: None | Float32[Array, 'num_trees']
    grow_var: UInt[Array, 'num_trees']
    grow_split: UInt[Array, 'num_trees']
    var_trees: UInt[Array, 'num_trees 2**(d-1)']
    logu: Float32[Array, 'num_trees']
    acc: None | Bool[Array, 'num_trees']
    to_prune: None | Bool[Array, 'num_trees']
463
-
464
-
465
def propose_moves(
    key: Key[Array, ''], forest: Forest, max_split: UInt[Array, 'p']
) -> Moves:
    """
    Propose moves for all the trees.

    There are two types of moves: GROW (convert a leaf to a decision node and
    add two leaves beneath it) and PRUNE (convert the parent of two leaves to a
    leaf, deleting its children).

    Parameters
    ----------
    key
        A jax random key.
    forest
        The `forest` field of a BART MCMC state.
    max_split
        The maximum split index for each variable, found in `State`.

    Returns
    -------
    The proposed move for each tree.
    """
    num_trees, _ = forest.leaf_trees.shape
    # one key per tree for grow, one per tree for prune, plus one for the
    # uniform draws below
    keys = split(key, 1 + 2 * num_trees)

    # compute moves
    grow_moves = propose_grow_moves(
        keys.pop(num_trees),
        forest.var_trees,
        forest.split_trees,
        forest.affluence_trees,
        max_split,
        forest.p_nonterminal,
        forest.p_propose_grow,
    )
    prune_moves = propose_prune_moves(
        keys.pop(num_trees),
        forest.split_trees,
        forest.affluence_trees,
        forest.p_nonterminal,
        forest.p_propose_grow,
    )

    # u selects grow vs. prune; logu becomes the MH acceptance threshold
    u, logu = random.uniform(keys.pop(), (2, num_trees), jnp.float32)

    # choose between grow or prune
    grow_allowed = grow_moves.num_growable.astype(bool)
    # 1/2 when both moves are possible, 1 if only grow is, 0 if neither is
    p_grow = jnp.where(grow_allowed & prune_moves.allowed, 0.5, grow_allowed)
    grow = u < p_grow  # use < instead of <= because u is in [0, 1)

    # compute children indices (heap layout: children of i are 2i and 2i+1)
    node = jnp.where(grow, grow_moves.node, prune_moves.node)
    left = node << 1
    right = left + 1

    return Moves(
        allowed=grow | prune_moves.allowed,
        grow=grow,
        num_growable=grow_moves.num_growable,
        node=node,
        left=left,
        right=right,
        partial_ratio=jnp.where(
            grow, grow_moves.partial_ratio, prune_moves.partial_ratio
        ),
        log_trans_prior_ratio=None,  # will be set in complete_ratio
        grow_var=grow_moves.var,
        grow_split=grow_moves.split,
        var_trees=grow_moves.var_tree,
        # log1p(-logu) maps a uniform in [0, 1) to the log of a value in (0, 1]
        logu=jnp.log1p(-logu),
        acc=None,  # will be set in accept_moves_sequential_stage
        to_prune=None,  # will be set in accept_moves_sequential_stage
    )
539
-
540
-
541
class GrowMoves(Module):
    """
    Represent a proposed grow move for each tree.

    Parameters
    ----------
    num_growable
        The number of growable leaves.
    node
        The index of the leaf to grow. ``2 ** d`` if there are no growable
        leaves.
    var
    split
        The decision axis and boundary of the new rule.
    partial_ratio
        A factor of the Metropolis-Hastings ratio of the move. It lacks
        the likelihood ratio and the probability of proposing the prune
        move.
    var_tree
        The updated decision axes of the tree.
    """

    # All fields are batched over trees by the vmap in propose_grow_moves.
    num_growable: UInt[Array, 'num_trees']
    node: UInt[Array, 'num_trees']
    var: UInt[Array, 'num_trees']
    split: UInt[Array, 'num_trees']
    partial_ratio: Float32[Array, 'num_trees']
    var_tree: UInt[Array, 'num_trees 2**(d-1)']
569
-
570
-
571
@partial(vmap_nodoc, in_axes=(0, 0, 0, 0, None, None, None))
def propose_grow_moves(
    key: Key[Array, ''],
    var_tree: UInt[Array, '2**(d-1)'],
    split_tree: UInt[Array, '2**(d-1)'],
    affluence_tree: Bool[Array, '2**(d-1)'] | None,
    max_split: UInt[Array, 'p'],
    p_nonterminal: Float32[Array, 'd'],
    p_propose_grow: Float32[Array, '2**(d-1)'],
) -> GrowMoves:
    """
    Propose a GROW move for each tree.

    A GROW move picks a leaf node and converts it to a non-terminal node with
    two leaf children.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The splitting axes of the tree.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points to be grown.
    max_split
        The maximum split index for each variable.
    p_nonterminal
        The probability of a nonterminal node at each depth.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    An object representing the proposed move.

    Notes
    -----
    The move is not proposed if a leaf is already at maximum depth, or if a leaf
    has less than twice the requested minimum number of datapoints per leaf.
    This is marked by returning `num_growable` set to 0.

    The move is also not possible if the ancestors of a leaf have
    exhausted the possible decision rules that lead to a non-empty selection.
    This is marked by returning `var` set to `p` and `split` set to 0. But this
    does not block the move from counting as "proposed", even though it is
    predictably going to be rejected. This simplifies the MCMC and should not
    reduce efficiency if not in unrealistic corner cases.
    """
    keys = split(key, 3)

    # pick the leaf to grow (and gather the counts needed for the MH ratio)
    leaf_to_grow, num_growable, prob_choose, num_prunable = choose_leaf(
        keys.pop(), split_tree, affluence_tree, p_propose_grow
    )

    # pick the new decision rule: first the axis, then the cutpoint
    var = choose_variable(keys.pop(), var_tree, split_tree, max_split, leaf_to_grow)
    var_tree = var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype))

    # choose_split reads the variable just written into var_tree
    split_idx = choose_split(keys.pop(), var_tree, split_tree, max_split, leaf_to_grow)

    ratio = compute_partial_ratio(
        prob_choose, num_prunable, p_nonterminal, leaf_to_grow
    )

    return GrowMoves(
        num_growable=num_growable,
        node=leaf_to_grow,
        var=var,
        split=split_idx,
        partial_ratio=ratio,
        var_tree=var_tree,
    )

    # TODO it is not clear to me how var=p and split=0 when the move is not
    # possible lead to corrent behavior downstream. Like, the move is proposed,
    # but then it's a noop? And since it's a noop, it makes no difference if
    # it's "accepted" or "rejected", it's like it's always rejected, so who
    # cares if the likelihood ratio or a lot of other numbers are wrong? Uhm.
650
-
651
-
652
def choose_leaf(
    key: Key[Array, ''],
    split_tree: UInt[Array, '2**(d-1)'],
    affluence_tree: Bool[Array, '2**(d-1)'] | None,
    p_propose_grow: Float32[Array, '2**(d-1)'],
) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, ''], Int32[Array, '']]:
    """
    Choose a leaf node to grow in a tree.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points that it could be split into two leaves
        satisfying the `min_points_per_leaf` requirement.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    leaf_to_grow : int
        The index of the leaf to grow. If ``num_growable == 0``, return
        ``2 ** d``.
    num_growable : int
        The number of leaf nodes that can be grown, i.e., are nonterminal
        and have at least twice `min_points_per_leaf` if set.
    prob_choose : float
        The (normalized) probability that this function had to choose that
        specific leaf, given the arguments.
    num_prunable : int
        The number of leaf parents that could be pruned, after converting the
        selected leaf to a non-terminal node.
    """
    is_growable = growable_leaves(split_tree, affluence_tree)
    num_growable = jnp.count_nonzero(is_growable)
    # zero out the proposal weight of non-growable leaves
    distr = jnp.where(is_growable, p_propose_grow, 0)
    leaf_to_grow, distr_norm = categorical(key, distr)
    # mark "no growable leaf" with the out-of-bounds index 2**d
    leaf_to_grow = jnp.where(num_growable, leaf_to_grow, 2 * split_tree.size)
    # NOTE(review): when leaf_to_grow is out of bounds, jax indexing clamps to
    # the last element here, so prob_choose is garbage in that case —
    # presumably unused downstream because num_growable == 0; confirm.
    prob_choose = distr[leaf_to_grow] / distr_norm
    # count prunable parents as if the selected leaf had been grown
    is_parent = grove.is_leaves_parent(split_tree.at[leaf_to_grow].set(1))
    num_prunable = jnp.count_nonzero(is_parent)
    return leaf_to_grow, num_growable, prob_choose, num_prunable
697
-
698
-
699
def growable_leaves(
    split_tree: UInt[Array, '2**(d-1)'],
    affluence_tree: Bool[Array, '2**(d-1)'] | None,
) -> Bool[Array, '2**(d-1)']:
    """
    Return a mask indicating the leaf nodes that can be proposed for growth.

    The condition is that a leaf is not at the bottom level and, when
    `affluence_tree` is provided, that it holds at least two times the number
    of minimum points per leaf.

    Parameters
    ----------
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points to be grown.

    Returns
    -------
    The mask indicating the leaf nodes that can be proposed to grow.
    """
    mask = grove.is_actual_leaf(split_tree)
    return mask if affluence_tree is None else mask & affluence_tree
724
-
725
-
726
def categorical(
    key: 'Key[Array, ""]', distr: 'Float32[Array, "n"]'
) -> "tuple[Int32[Array, ''], Float32[Array, '']]":
    """
    Return a random integer from an arbitrary distribution.

    Parameters
    ----------
    key
        A jax random key.
    distr
        An unnormalized probability distribution.

    Returns
    -------
    u : Int32[Array, '']
        A random integer in the range ``[0, n)``. If all probabilities are zero,
        return ``n``.
    norm : Float32[Array, '']
        The sum of `distr`.
    """
    # invert the cumulative distribution at a uniform point in [0, total)
    cdf = jnp.cumsum(distr)
    total = cdf[-1]
    point = random.uniform(key, (), cdf.dtype, 0, total)
    index = jnp.searchsorted(cdf, point, 'right')
    return index, total
750
-
751
-
752
def choose_variable(
    key: Key[Array, ''],
    var_tree: UInt[Array, '2**(d-1)'],
    split_tree: UInt[Array, '2**(d-1)'],
    max_split: UInt[Array, 'p'],
    leaf_index: Int32[Array, ''],
) -> Int32[Array, '']:
    """
    Choose a variable to split on for a new non-terminal node.

    The variable is drawn uniformly among those with a non-empty range of
    allowed splits at this node. If no variable has a non-empty range, `p`
    is returned.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the leaf to grow.

    Returns
    -------
    The index of the variable to split on.
    """
    exhausted_vars = fully_used_variables(var_tree, split_tree, max_split, leaf_index)
    return randint_exclude(key, max_split.size, exhausted_vars)
786
-
787
-
788
def fully_used_variables(
    var_tree: UInt[Array, '2**(d-1)'],
    split_tree: UInt[Array, '2**(d-1)'],
    max_split: UInt[Array, 'p'],
    leaf_index: Int32[Array, ''],
) -> UInt[Array, 'd-2']:
    """
    Return a list of variables that have an empty split range at a given node.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the node, assumed to be valid for `var_tree`.

    Returns
    -------
    The indices of the variables that have an empty split range.

    Notes
    -----
    The number of unused variables is not known in advance. Unused values in the
    array are filled with `p`. The fill values are not guaranteed to be placed
    in any particular order, and variables may appear more than once.
    """
    # only variables used by ancestors can have their range exhausted, so it
    # suffices to check those
    var_to_ignore = ancestor_variables(var_tree, max_split, leaf_index)
    # measure the residual split range of each candidate variable at this node
    split_range_vec = jax.vmap(split_range, in_axes=(None, None, None, None, 0))
    l, r = split_range_vec(var_tree, split_tree, max_split, leaf_index, var_to_ignore)
    num_split = r - l
    # keep only the variables with an empty range; replace the rest with p
    return jnp.where(num_split == 0, var_to_ignore, max_split.size)
823
-
824
-
825
def ancestor_variables(
    var_tree: UInt[Array, '2**(d-1)'],
    max_split: UInt[Array, 'p'],
    node_index: Int32[Array, ''],
) -> UInt[Array, 'd-2']:
    """
    Return the list of variables in the ancestors of a node.

    Parameters
    ----------
    var_tree : int array (2 ** (d - 1),)
        The variable indices of the tree.
    max_split : int array (p,)
        The maximum split index for each variable. Used only to get `p`.
    node_index : int
        The index of the node, assumed to be valid for `var_tree`.

    Returns
    -------
    The variable indices of the ancestors of the node.

    Notes
    -----
    The ancestors are the nodes going from the root to the parent of the node.
    The number of ancestors is not known at tracing time; unused spots in the
    output array are filled with `p`.
    """
    max_num_ancestors = grove.tree_depth(var_tree) - 1
    ancestor_vars = jnp.zeros(max_num_ancestors, minimal_unsigned_dtype(max_split.size))
    # walk the heap upwards, filling the output array from the back
    carry = ancestor_vars.size - 1, node_index, ancestor_vars

    def loop(carry, _):
        i, index, ancestor_vars = carry
        index >>= 1  # move to the parent (heap layout)
        var = var_tree[index]
        # once index hits 0 we are past the root: write the filler value p
        var = jnp.where(index, var, max_split.size)
        ancestor_vars = ancestor_vars.at[i].set(var)
        return (i - 1, index, ancestor_vars), None

    (_, _, ancestor_vars), _ = lax.scan(loop, carry, None, ancestor_vars.size)
    return ancestor_vars
866
-
867
-
868
def split_range(
    var_tree: UInt[Array, '2**(d-1)'],
    split_tree: UInt[Array, '2**(d-1)'],
    max_split: UInt[Array, 'p'],
    node_index: Int32[Array, ''],
    ref_var: Int32[Array, ''],
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Return the range of allowed splits for a variable at a given node.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    node_index
        The index of the node, assumed to be valid for `var_tree`.
    ref_var
        The variable for which to measure the split range.

    Returns
    -------
    The range of allowed splits as [l, r). If `ref_var` is out of bounds, l=r=0.
    """
    max_num_ancestors = grove.tree_depth(var_tree) - 1
    # an out-of-bounds ref_var reads the fill value 0, making the range empty
    initial_r = 1 + max_split.at[ref_var].get(mode='fill', fill_value=0).astype(
        jnp.int32
    )
    carry = 0, initial_r, node_index

    def loop(carry, _):
        l, r, index = carry
        # the lowest bit says whether we are the right or left child
        right_child = (index & 1).astype(bool)
        index >>= 1  # move to the parent (heap layout)
        split = split_tree[index]
        # only ancestors that split on ref_var constrain the range; the
        # index.astype(bool) guard ignores iterations past the root
        cond = (var_tree[index] == ref_var) & index.astype(bool)
        # a right child sees values above the split, a left child below it
        l = jnp.where(cond & right_child, jnp.maximum(l, split), l)
        r = jnp.where(cond & ~right_child, jnp.minimum(r, split), r)
        return (l, r, index), None

    (l, r, _), _ = lax.scan(loop, carry, None, max_num_ancestors)
    # shift the lower bound because split indices start at 1
    return l + 1, r
913
-
914
-
915
def randint_exclude(
    key: 'Key[Array, ""]', sup: int, exclude: 'Integer[Array, "n"]'
) -> "Int32[Array, '']":
    """
    Return a random integer in a range, excluding some values.

    Parameters
    ----------
    key
        A jax random key.
    sup
        The exclusive upper bound of the range.
    exclude
        The values to exclude from the range. Values greater than or equal to
        `sup` are ignored. Values can appear more than once.

    Returns
    -------
    A random integer `u` in the range ``[0, sup)`` such that ``u not in exclude``.

    Notes
    -----
    If all values in the range are excluded, return `sup`.
    """
    # sort the excluded values, collapsing duplicates into the filler `sup`
    forbidden = jnp.unique(exclude, size=exclude.size, fill_value=sup)
    num_allowed = sup - jnp.count_nonzero(forbidden < sup)
    # draw an index into the allowed values only...
    draw = random.randint(key, (), 0, num_allowed)

    def shift_past(value, excluded):
        # ...then shift it up once for each excluded value at or below it,
        # mapping the draw back onto [0, sup) \ exclude
        return jnp.where(excluded <= value, value + 1, value), None

    draw, _ = lax.scan(shift_past, draw, forbidden)
    return draw
948
-
949
-
950
- def choose_split(
951
- key: Key[Array, ''],
952
- var_tree: UInt[Array, '2**(d-1)'],
953
- split_tree: UInt[Array, '2**(d-1)'],
954
- max_split: UInt[Array, 'p'],
955
- leaf_index: Int32[Array, ''],
956
- ) -> Int32[Array, '']:
957
- """
958
- Choose a split point for a new non-terminal node.
959
-
960
- Parameters
961
- ----------
962
- key
963
- A jax random key.
964
- var_tree
965
- The splitting axes of the tree.
966
- split_tree
967
- The splitting points of the tree.
968
- max_split
969
- The maximum split index for each variable.
970
- leaf_index
971
- The index of the leaf to grow. It is assumed that `var_tree` already
972
- contains the target variable at this index.
973
-
974
- Returns
975
- -------
976
- The cutpoint. If ``var_tree[leaf_index]`` is out of bounds, return 0.
977
- """
978
- var = var_tree[leaf_index]
979
- l, r = split_range(var_tree, split_tree, max_split, leaf_index, var)
980
- return random.randint(key, (), l, r)
981
-
982
- # TODO what happens if leaf_index is out of bounds? And is the value used
983
- # in that case?
984
-
985
-
986
- def compute_partial_ratio(
987
- prob_choose: Float32[Array, ''],
988
- num_prunable: Int32[Array, ''],
989
- p_nonterminal: Float32[Array, 'd'],
990
- leaf_to_grow: Int32[Array, ''],
991
- ) -> Float32[Array, '']:
992
- """
993
- Compute the product of the transition and prior ratios of a grow move.
994
-
995
- Parameters
996
- ----------
997
- prob_choose
998
- The probability that the leaf had to be chosen amongst the growable
999
- leaves.
1000
- num_prunable
1001
- The number of leaf parents that could be pruned, after converting the
1002
- leaf to be grown to a non-terminal node.
1003
- p_nonterminal
1004
- The probability of a nonterminal node at each depth.
1005
- leaf_to_grow
1006
- The index of the leaf to grow.
1007
-
1008
- Returns
1009
- -------
1010
- The partial transition ratio times the prior ratio.
1011
-
1012
- Notes
1013
- -----
1014
- The transition ratio is P(new tree => old tree) / P(old tree => new tree).
1015
- The "partial" transition ratio returned is missing the factor P(propose
1016
- prune) in the numerator. The prior ratio is P(new tree) / P(old tree).
1017
- """
1018
- # the two ratios also contain factors num_available_split *
1019
- # num_available_var, but they cancel out
1020
-
1021
- # p_prune can't be computed here because it needs the count trees, which are
1022
- # computed in the acceptance phase
1023
-
1024
- prune_allowed = leaf_to_grow != 1
1025
- # prune allowed <---> the initial tree is not a root
1026
- # leaf to grow is root --> the tree can only be a root
1027
- # tree is a root --> the only leaf I can grow is root
1028
-
1029
- p_grow = jnp.where(prune_allowed, 0.5, 1)
1030
-
1031
- inv_trans_ratio = p_grow * prob_choose * num_prunable
1032
-
1033
- depth = grove.tree_depths(2 ** (p_nonterminal.size - 1))[leaf_to_grow]
1034
- p_parent = p_nonterminal[depth]
1035
- cp_children = 1 - p_nonterminal[depth + 1]
1036
- tree_ratio = cp_children * cp_children * p_parent / (1 - p_parent)
1037
-
1038
- return tree_ratio / inv_trans_ratio
1039
-
1040
-
1041
- class PruneMoves(Module):
1042
- """
1043
- Represent a proposed prune move for each tree.
1044
-
1045
- Parameters
1046
- ----------
1047
- allowed
1048
- Whether the move is possible.
1049
- node
1050
- The index of the node to prune. ``2 ** d`` if no node can be pruned.
1051
- partial_ratio
1052
- A factor of the Metropolis-Hastings ratio of the move. It lacks
1053
- the likelihood ratio and the probability of proposing the prune
1054
- move. This ratio is inverted, and is meant to be inverted back in
1055
- `accept_move_and_sample_leaves`.
1056
- """
1057
-
1058
- allowed: Bool[Array, 'num_trees']
1059
- node: UInt[Array, 'num_trees']
1060
- partial_ratio: Float32[Array, 'num_trees']
1061
-
1062
-
1063
- @partial(vmap_nodoc, in_axes=(0, 0, 0, None, None))
1064
- def propose_prune_moves(
1065
- key: Key[Array, ''],
1066
- split_tree: UInt[Array, '2**(d-1)'],
1067
- affluence_tree: Bool[Array, '2**(d-1)'] | None,
1068
- p_nonterminal: Float32[Array, 'd'],
1069
- p_propose_grow: Float32[Array, '2**(d-1)'],
1070
- ) -> PruneMoves:
1071
- """
1072
- Tree structure prune move proposal of BART MCMC.
1073
-
1074
- Parameters
1075
- ----------
1076
- key
1077
- A jax random key.
1078
- split_tree
1079
- The splitting points of the tree.
1080
- affluence_tree
1081
- Whether a leaf has enough points to be grown.
1082
- p_nonterminal
1083
- The probability of a nonterminal node at each depth.
1084
- p_propose_grow
1085
- The unnormalized probability of choosing a leaf to grow.
1086
-
1087
- Returns
1088
- -------
1089
- An object representing the proposed moves.
1090
- """
1091
- node_to_prune, num_prunable, prob_choose = choose_leaf_parent(
1092
- key, split_tree, affluence_tree, p_propose_grow
1093
- )
1094
- allowed = split_tree[1].astype(bool) # allowed iff the tree is not a root
1095
-
1096
- ratio = compute_partial_ratio(
1097
- prob_choose, num_prunable, p_nonterminal, node_to_prune
1098
- )
1099
-
1100
- return PruneMoves(
1101
- allowed=allowed,
1102
- node=node_to_prune,
1103
- partial_ratio=ratio,
1104
- )
1105
-
1106
-
1107
- def choose_leaf_parent(
1108
- key: Key[Array, ''],
1109
- split_tree: UInt[Array, '2**(d-1)'],
1110
- affluence_tree: Bool[Array, '2**(d-1)'] | None,
1111
- p_propose_grow: Float32[Array, '2**(d-1)'],
1112
- ) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, '']]:
1113
- """
1114
- Pick a non-terminal node with leaf children to prune in a tree.
1115
-
1116
- Parameters
1117
- ----------
1118
- key
1119
- A jax random key.
1120
- split_tree
1121
- The splitting points of the tree.
1122
- affluence_tree
1123
- Whether a leaf has enough points to be grown.
1124
- p_propose_grow
1125
- The unnormalized probability of choosing a leaf to grow.
1126
-
1127
- Returns
1128
- -------
1129
- node_to_prune : Int32[Array, '']
1130
- The index of the node to prune. If ``num_prunable == 0``, return
1131
- ``2 ** d``.
1132
- num_prunable : Int32[Array, '']
1133
- The number of leaf parents that could be pruned.
1134
- prob_choose : Float32[Array, '']
1135
- The (normalized) probability that `choose_leaf` would chose
1136
- `node_to_prune` as leaf to grow, if passed the tree where
1137
- `node_to_prune` had been pruned.
1138
- """
1139
- is_prunable = grove.is_leaves_parent(split_tree)
1140
- num_prunable = jnp.count_nonzero(is_prunable)
1141
- node_to_prune = randint_masked(key, is_prunable)
1142
- node_to_prune = jnp.where(num_prunable, node_to_prune, 2 * split_tree.size)
1143
-
1144
- split_tree = split_tree.at[node_to_prune].set(0)
1145
- if affluence_tree is not None:
1146
- affluence_tree = affluence_tree.at[node_to_prune].set(True)
1147
- is_growable_leaf = growable_leaves(split_tree, affluence_tree)
1148
- prob_choose = p_propose_grow[node_to_prune]
1149
- prob_choose /= jnp.sum(p_propose_grow, where=is_growable_leaf)
1150
-
1151
- return node_to_prune, num_prunable, prob_choose
1152
-
1153
-
1154
- def randint_masked(key: Key[Array, ''], mask: Bool[Array, 'n']) -> Int32[Array, '']:
1155
- """
1156
- Return a random integer in a range, including only some values.
1157
-
1158
- Parameters
1159
- ----------
1160
- key
1161
- A jax random key.
1162
- mask
1163
- The mask indicating the allowed values.
1164
-
1165
- Returns
1166
- -------
1167
- A random integer in the range ``[0, n)`` such that ``mask[u] == True``.
1168
-
1169
- Notes
1170
- -----
1171
- If all values in the mask are `False`, return `n`.
1172
- """
1173
- ecdf = jnp.cumsum(mask)
1174
- u = random.randint(key, (), 0, ecdf[-1])
1175
- return jnp.searchsorted(ecdf, u, 'right')
1176
-
1177
-
1178
- def accept_moves_and_sample_leaves(
1179
- key: Key[Array, ''], bart: State, moves: Moves
1180
- ) -> State:
1181
- """
1182
- Accept or reject the proposed moves and sample the new leaf values.
1183
-
1184
- Parameters
1185
- ----------
1186
- key
1187
- A jax random key.
1188
- bart
1189
- A valid BART mcmc state.
1190
- moves
1191
- The proposed moves, see `propose_moves`.
1192
-
1193
- Returns
1194
- -------
1195
- A new (valid) BART mcmc state.
1196
- """
1197
- pso = accept_moves_parallel_stage(key, bart, moves)
1198
- bart, moves = accept_moves_sequential_stage(pso)
1199
- return accept_moves_final_stage(bart, moves)
1200
-
1201
-
1202
- class Counts(Module):
1203
- """
1204
- Number of datapoints in the nodes involved in proposed moves for each tree.
1205
-
1206
- Parameters
1207
- ----------
1208
- left
1209
- Number of datapoints in the left child.
1210
- right
1211
- Number of datapoints in the right child.
1212
- total
1213
- Number of datapoints in the parent (``= left + right``).
1214
- """
1215
-
1216
- left: UInt[Array, 'num_trees']
1217
- right: UInt[Array, 'num_trees']
1218
- total: UInt[Array, 'num_trees']
1219
-
1220
-
1221
- class Precs(Module):
1222
- """
1223
- Likelihood precision scale in the nodes involved in proposed moves for each tree.
1224
-
1225
- The "likelihood precision scale" of a tree node is the sum of the inverse
1226
- squared error scales of the datapoints selected by the node.
1227
-
1228
- Parameters
1229
- ----------
1230
- left
1231
- Likelihood precision scale in the left child.
1232
- right
1233
- Likelihood precision scale in the right child.
1234
- total
1235
- Likelihood precision scale in the parent (``= left + right``).
1236
- """
1237
-
1238
- left: Float32[Array, 'num_trees']
1239
- right: Float32[Array, 'num_trees']
1240
- total: Float32[Array, 'num_trees']
1241
-
1242
-
1243
- class PreLkV(Module):
1244
- """
1245
- Non-sequential terms of the likelihood ratio for each tree.
1246
-
1247
- These terms can be computed in parallel across trees.
1248
-
1249
- Parameters
1250
- ----------
1251
- sigma2_left
1252
- The noise variance in the left child of the leaves grown or pruned by
1253
- the moves.
1254
- sigma2_right
1255
- The noise variance in the right child of the leaves grown or pruned by
1256
- the moves.
1257
- sigma2_total
1258
- The noise variance in the total of the leaves grown or pruned by the
1259
- moves.
1260
- sqrt_term
1261
- The **logarithm** of the square root term of the likelihood ratio.
1262
- """
1263
-
1264
- sigma2_left: Float32[Array, 'num_trees']
1265
- sigma2_right: Float32[Array, 'num_trees']
1266
- sigma2_total: Float32[Array, 'num_trees']
1267
- sqrt_term: Float32[Array, 'num_trees']
1268
-
1269
-
1270
- class PreLk(Module):
1271
- """
1272
- Non-sequential terms of the likelihood ratio shared by all trees.
1273
-
1274
- Parameters
1275
- ----------
1276
- exp_factor
1277
- The factor to multiply the likelihood ratio by, shared by all trees.
1278
- """
1279
-
1280
- exp_factor: Float32[Array, '']
1281
-
1282
-
1283
- class PreLf(Module):
1284
- """
1285
- Pre-computed terms used to sample leaves from their posterior.
1286
-
1287
- These terms can be computed in parallel across trees.
1288
-
1289
- Parameters
1290
- ----------
1291
- mean_factor
1292
- The factor to be multiplied by the sum of the scaled residuals to
1293
- obtain the posterior mean.
1294
- centered_leaves
1295
- The mean-zero normal values to be added to the posterior mean to
1296
- obtain the posterior leaf samples.
1297
- """
1298
-
1299
- mean_factor: Float32[Array, 'num_trees 2**d']
1300
- centered_leaves: Float32[Array, 'num_trees 2**d']
1301
-
1302
-
1303
- class ParallelStageOut(Module):
1304
- """
1305
- The output of `accept_moves_parallel_stage`.
1306
-
1307
- Parameters
1308
- ----------
1309
- bart
1310
- A partially updated BART mcmc state.
1311
- moves
1312
- The proposed moves, with `partial_ratio` set to `None` and
1313
- `log_trans_prior_ratio` set to its final value.
1314
- prec_trees
1315
- The likelihood precision scale in each potential or actual leaf node. If
1316
- there is no precision scale, this is the number of points in each leaf.
1317
- move_counts
1318
- The counts of the number of points in the the nodes modified by the
1319
- moves. If `bart.min_points_per_leaf` is not set and
1320
- `bart.prec_scale` is set, they are not computed.
1321
- move_precs
1322
- The likelihood precision scale in each node modified by the moves. If
1323
- `bart.prec_scale` is not set, this is set to `move_counts`.
1324
- prelkv
1325
- prelk
1326
- prelf
1327
- Objects with pre-computed terms of the likelihood ratios and leaf
1328
- samples.
1329
- """
1330
-
1331
- bart: State
1332
- moves: Moves
1333
- prec_trees: Float32[Array, 'num_trees 2**d'] | Int32[Array, 'num_trees 2**d']
1334
- move_counts: Counts | None
1335
- move_precs: Precs | Counts
1336
- prelkv: PreLkV
1337
- prelk: PreLk
1338
- prelf: PreLf
1339
-
1340
-
1341
- def accept_moves_parallel_stage(
1342
- key: Key[Array, ''], bart: State, moves: Moves
1343
- ) -> ParallelStageOut:
1344
- """
1345
- Pre-computes quantities used to accept moves, in parallel across trees.
1346
-
1347
- Parameters
1348
- ----------
1349
- key : jax.dtypes.prng_key array
1350
- A jax random key.
1351
- bart : dict
1352
- A BART mcmc state.
1353
- moves : dict
1354
- The proposed moves, see `propose_moves`.
1355
-
1356
- Returns
1357
- -------
1358
- An object with all that could be done in parallel.
1359
- """
1360
- # where the move is grow, modify the state like the move was accepted
1361
- bart = replace(
1362
- bart,
1363
- forest=replace(
1364
- bart.forest,
1365
- var_trees=moves.var_trees,
1366
- leaf_indices=apply_grow_to_indices(moves, bart.forest.leaf_indices, bart.X),
1367
- leaf_trees=adapt_leaf_trees_to_grow_indices(bart.forest.leaf_trees, moves),
1368
- ),
1369
- )
1370
-
1371
- # count number of datapoints per leaf
1372
- if bart.forest.min_points_per_leaf is not None or bart.prec_scale is None:
1373
- count_trees, move_counts = compute_count_trees(
1374
- bart.forest.leaf_indices, moves, bart.forest.count_batch_size
1375
- )
1376
- else:
1377
- # move_counts is passed later to a function, but then is unused under
1378
- # this condition
1379
- move_counts = None
1380
-
1381
- # Check if some nodes can't surely be grown because they don't have enough
1382
- # datapoints. This check is not actually used now, it will be used at the
1383
- # beginning of the next step to propose moves.
1384
- if bart.forest.min_points_per_leaf is not None:
1385
- count_half_trees = count_trees[:, : bart.forest.var_trees.shape[1]]
1386
- bart = replace(
1387
- bart,
1388
- forest=replace(
1389
- bart.forest,
1390
- affluence_trees=count_half_trees >= 2 * bart.forest.min_points_per_leaf,
1391
- ),
1392
- )
1393
-
1394
- # count number of datapoints per leaf, weighted by error precision scale
1395
- if bart.prec_scale is None:
1396
- prec_trees = count_trees
1397
- move_precs = move_counts
1398
- else:
1399
- prec_trees, move_precs = compute_prec_trees(
1400
- bart.prec_scale,
1401
- bart.forest.leaf_indices,
1402
- moves,
1403
- bart.forest.count_batch_size,
1404
- )
1405
-
1406
- # compute some missing information about moves
1407
- moves = complete_ratio(moves, move_counts, bart.forest.min_points_per_leaf)
1408
- bart = replace(
1409
- bart,
1410
- forest=replace(
1411
- bart.forest,
1412
- grow_prop_count=jnp.sum(moves.grow),
1413
- prune_prop_count=jnp.sum(moves.allowed & ~moves.grow),
1414
- ),
1415
- )
1416
-
1417
- prelkv, prelk = precompute_likelihood_terms(
1418
- bart.sigma2, bart.forest.sigma_mu2, move_precs
1419
- )
1420
- prelf = precompute_leaf_terms(key, prec_trees, bart.sigma2, bart.forest.sigma_mu2)
1421
-
1422
- return ParallelStageOut(
1423
- bart=bart,
1424
- moves=moves,
1425
- prec_trees=prec_trees,
1426
- move_counts=move_counts,
1427
- move_precs=move_precs,
1428
- prelkv=prelkv,
1429
- prelk=prelk,
1430
- prelf=prelf,
1431
- )
1432
-
1433
-
1434
- @partial(vmap_nodoc, in_axes=(0, 0, None))
1435
- def apply_grow_to_indices(
1436
- moves: Moves, leaf_indices: UInt[Array, 'num_trees n'], X: UInt[Array, 'p n']
1437
- ) -> UInt[Array, 'num_trees n']:
1438
- """
1439
- Update the leaf indices to apply a grow move.
1440
-
1441
- Parameters
1442
- ----------
1443
- moves
1444
- The proposed moves, see `propose_moves`.
1445
- leaf_indices
1446
- The index of the leaf each datapoint falls into.
1447
- X
1448
- The predictors matrix.
1449
-
1450
- Returns
1451
- -------
1452
- The updated leaf indices.
1453
- """
1454
- left_child = moves.node.astype(leaf_indices.dtype) << 1
1455
- go_right = X[moves.grow_var, :] >= moves.grow_split
1456
- tree_size = jnp.array(2 * moves.var_trees.size)
1457
- node_to_update = jnp.where(moves.grow, moves.node, tree_size)
1458
- return jnp.where(
1459
- leaf_indices == node_to_update,
1460
- left_child + go_right,
1461
- leaf_indices,
1462
- )
1463
-
1464
-
1465
- def compute_count_trees(
1466
- leaf_indices: UInt[Array, 'num_trees n'], moves: Moves, batch_size: int | None
1467
- ) -> tuple[Int32[Array, 'num_trees 2**d'], Counts]:
1468
- """
1469
- Count the number of datapoints in each leaf.
1470
-
1471
- Parameters
1472
- ----------
1473
- leaf_indices
1474
- The index of the leaf each datapoint falls into, with the deeper version
1475
- of the tree (post-GROW, pre-PRUNE).
1476
- moves
1477
- The proposed moves, see `propose_moves`.
1478
- batch_size
1479
- The data batch size to use for the summation.
1480
-
1481
- Returns
1482
- -------
1483
- count_trees : Int32[Array, 'num_trees 2**d']
1484
- The number of points in each potential or actual leaf node.
1485
- counts : Counts
1486
- The counts of the number of points in the leaves grown or pruned by the
1487
- moves.
1488
- """
1489
- num_trees, tree_size = moves.var_trees.shape
1490
- tree_size *= 2
1491
- tree_indices = jnp.arange(num_trees)
1492
-
1493
- count_trees = count_datapoints_per_leaf(leaf_indices, tree_size, batch_size)
1494
-
1495
- # count datapoints in nodes modified by move
1496
- left = count_trees[tree_indices, moves.left]
1497
- right = count_trees[tree_indices, moves.right]
1498
- counts = Counts(left=left, right=right, total=left + right)
1499
-
1500
- # write count into non-leaf node
1501
- count_trees = count_trees.at[tree_indices, moves.node].set(counts.total)
1502
-
1503
- return count_trees, counts
1504
-
1505
-
1506
- def count_datapoints_per_leaf(
1507
- leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int | None
1508
- ) -> Int32[Array, 'num_trees 2**(d-1)']:
1509
- """
1510
- Count the number of datapoints in each leaf.
1511
-
1512
- Parameters
1513
- ----------
1514
- leaf_indices
1515
- The index of the leaf each datapoint falls into.
1516
- tree_size
1517
- The size of the leaf tree array (2 ** d).
1518
- batch_size
1519
- The data batch size to use for the summation.
1520
-
1521
- Returns
1522
- -------
1523
- The number of points in each leaf node.
1524
- """
1525
- if batch_size is None:
1526
- return _count_scan(leaf_indices, tree_size)
1527
- else:
1528
- return _count_vec(leaf_indices, tree_size, batch_size)
1529
-
1530
-
1531
- def _count_scan(
1532
- leaf_indices: UInt[Array, 'num_trees n'], tree_size: int
1533
- ) -> Int32[Array, 'num_trees {tree_size}']:
1534
- def loop(_, leaf_indices):
1535
- return None, _aggregate_scatter(1, leaf_indices, tree_size, jnp.uint32)
1536
-
1537
- _, count_trees = lax.scan(loop, None, leaf_indices)
1538
- return count_trees
1539
-
1540
-
1541
- def _aggregate_scatter(
1542
- values: Shaped[Array, '*'],
1543
- indices: Integer[Array, '*'],
1544
- size: int,
1545
- dtype: jnp.dtype,
1546
- ) -> Shaped[Array, '{size}']:
1547
- return jnp.zeros(size, dtype).at[indices].add(values)
1548
-
1549
-
1550
- def _count_vec(
1551
- leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int
1552
- ) -> Int32[Array, 'num_trees 2**(d-1)']:
1553
- return _aggregate_batched_alltrees(
1554
- 1, leaf_indices, tree_size, jnp.uint32, batch_size
1555
- )
1556
- # uint16 is super-slow on gpu, don't use it even if n < 2^16
1557
-
1558
-
1559
- def _aggregate_batched_alltrees(
1560
- values: Shaped[Array, '*'],
1561
- indices: UInt[Array, 'num_trees n'],
1562
- size: int,
1563
- dtype: jnp.dtype,
1564
- batch_size: int,
1565
- ) -> Shaped[Array, 'num_trees {size}']:
1566
- num_trees, n = indices.shape
1567
- tree_indices = jnp.arange(num_trees)
1568
- nbatches = n // batch_size + bool(n % batch_size)
1569
- batch_indices = jnp.arange(n) % nbatches
1570
- return (
1571
- jnp.zeros((num_trees, size, nbatches), dtype)
1572
- .at[tree_indices[:, None], indices, batch_indices]
1573
- .add(values)
1574
- .sum(axis=2)
1575
- )
1576
-
1577
-
1578
- def compute_prec_trees(
1579
- prec_scale: Float32[Array, 'n'],
1580
- leaf_indices: UInt[Array, 'num_trees n'],
1581
- moves: Moves,
1582
- batch_size: int | None,
1583
- ) -> tuple[Float32[Array, 'num_trees 2**d'], Precs]:
1584
- """
1585
- Compute the likelihood precision scale in each leaf.
1586
-
1587
- Parameters
1588
- ----------
1589
- prec_scale
1590
- The scale of the precision of the error on each datapoint.
1591
- leaf_indices
1592
- The index of the leaf each datapoint falls into, with the deeper version
1593
- of the tree (post-GROW, pre-PRUNE).
1594
- moves
1595
- The proposed moves, see `propose_moves`.
1596
- batch_size
1597
- The data batch size to use for the summation.
1598
-
1599
- Returns
1600
- -------
1601
- prec_trees : Float32[Array, 'num_trees 2**d']
1602
- The likelihood precision scale in each potential or actual leaf node.
1603
- precs : Precs
1604
- The likelihood precision scale in the nodes involved in the moves.
1605
- """
1606
- num_trees, tree_size = moves.var_trees.shape
1607
- tree_size *= 2
1608
- tree_indices = jnp.arange(num_trees)
1609
-
1610
- prec_trees = prec_per_leaf(prec_scale, leaf_indices, tree_size, batch_size)
1611
-
1612
- # prec datapoints in nodes modified by move
1613
- left = prec_trees[tree_indices, moves.left]
1614
- right = prec_trees[tree_indices, moves.right]
1615
- precs = Precs(left=left, right=right, total=left + right)
1616
-
1617
- # write prec into non-leaf node
1618
- prec_trees = prec_trees.at[tree_indices, moves.node].set(precs.total)
1619
-
1620
- return prec_trees, precs
1621
-
1622
-
1623
- def prec_per_leaf(
1624
- prec_scale: Float32[Array, 'n'],
1625
- leaf_indices: UInt[Array, 'num_trees n'],
1626
- tree_size: int,
1627
- batch_size: int | None,
1628
- ) -> Float32[Array, 'num_trees {tree_size}']:
1629
- """
1630
- Compute the likelihood precision scale in each leaf.
1631
-
1632
- Parameters
1633
- ----------
1634
- prec_scale
1635
- The scale of the precision of the error on each datapoint.
1636
- leaf_indices
1637
- The index of the leaf each datapoint falls into.
1638
- tree_size
1639
- The size of the leaf tree array (2 ** d).
1640
- batch_size
1641
- The data batch size to use for the summation.
1642
-
1643
- Returns
1644
- -------
1645
- The likelihood precision scale in each leaf node.
1646
- """
1647
- if batch_size is None:
1648
- return _prec_scan(prec_scale, leaf_indices, tree_size)
1649
- else:
1650
- return _prec_vec(prec_scale, leaf_indices, tree_size, batch_size)
1651
-
1652
-
1653
- def _prec_scan(
1654
- prec_scale: Float32[Array, 'n'],
1655
- leaf_indices: UInt[Array, 'num_trees n'],
1656
- tree_size: int,
1657
- ) -> Float32[Array, 'num_trees {tree_size}']:
1658
- def loop(_, leaf_indices):
1659
- return None, _aggregate_scatter(
1660
- prec_scale, leaf_indices, tree_size, jnp.float32
1661
- )
1662
-
1663
- _, prec_trees = lax.scan(loop, None, leaf_indices)
1664
- return prec_trees
1665
-
1666
-
1667
- def _prec_vec(
1668
- prec_scale: Float32[Array, 'n'],
1669
- leaf_indices: UInt[Array, 'num_trees n'],
1670
- tree_size: int,
1671
- batch_size: int,
1672
- ) -> Float32[Array, 'num_trees {tree_size}']:
1673
- return _aggregate_batched_alltrees(
1674
- prec_scale, leaf_indices, tree_size, jnp.float32, batch_size
1675
- )
1676
-
1677
-
1678
- def complete_ratio(
1679
- moves: Moves, move_counts: Counts | None, min_points_per_leaf: int | None
1680
- ) -> Moves:
1681
- """
1682
- Complete non-likelihood MH ratio calculation.
1683
-
1684
- This function adds the probability of choosing the prune move.
1685
-
1686
- Parameters
1687
- ----------
1688
- moves
1689
- The proposed moves, see `propose_moves`.
1690
- move_counts
1691
- The counts of the number of points in the the nodes modified by the
1692
- moves.
1693
- min_points_per_leaf
1694
- The minimum number of data points in a leaf node.
1695
-
1696
- Returns
1697
- -------
1698
- The updated moves, with `partial_ratio=None` and `log_trans_prior_ratio` set.
1699
- """
1700
- p_prune = compute_p_prune(moves, move_counts, min_points_per_leaf)
1701
- return replace(
1702
- moves,
1703
- log_trans_prior_ratio=jnp.log(moves.partial_ratio * p_prune),
1704
- partial_ratio=None,
1705
- )
1706
-
1707
-
1708
- def compute_p_prune(
1709
- moves: Moves, move_counts: Counts | None, min_points_per_leaf: int | None
1710
- ) -> Float32[Array, 'num_trees']:
1711
- """
1712
- Compute the probability of proposing a prune move for each tree.
1713
-
1714
- Parameters
1715
- ----------
1716
- moves
1717
- The proposed moves, see `propose_moves`.
1718
- move_counts
1719
- The number of datapoints in the proposed children of the leaf to grow.
1720
- Not used if `min_points_per_leaf` is `None`.
1721
- min_points_per_leaf
1722
- The minimum number of data points in a leaf node.
1723
-
1724
- Returns
1725
- -------
1726
- The probability of proposing a prune move.
1727
-
1728
- Notes
1729
- -----
1730
- This probability is computed for going from the state with the deeper tree
1731
- to the one with the shallower one. This means, if grow: after accepting the
1732
- grow move, if prune: right away.
1733
- """
1734
- # calculation in case the move is grow
1735
- other_growable_leaves = moves.num_growable >= 2
1736
- new_leaves_growable = moves.node < moves.var_trees.shape[1] // 2
1737
- if min_points_per_leaf is not None:
1738
- assert move_counts is not None
1739
- any_above_threshold = move_counts.left >= 2 * min_points_per_leaf
1740
- any_above_threshold |= move_counts.right >= 2 * min_points_per_leaf
1741
- new_leaves_growable &= any_above_threshold
1742
- grow_again_allowed = other_growable_leaves | new_leaves_growable
1743
- grow_p_prune = jnp.where(grow_again_allowed, 0.5, 1)
1744
-
1745
- # calculation in case the move is prune
1746
- prune_p_prune = jnp.where(moves.num_growable, 0.5, 1)
1747
-
1748
- return jnp.where(moves.grow, grow_p_prune, prune_p_prune)
1749
-
1750
-
1751
- @vmap_nodoc
1752
- def adapt_leaf_trees_to_grow_indices(
1753
- leaf_trees: Float32[Array, 'num_trees 2**d'], moves: Moves
1754
- ) -> Float32[Array, 'num_trees 2**d']:
1755
- """
1756
- Modify leaves such that post-grow indices work on the original tree.
1757
-
1758
- The value of the leaf to grow is copied to what would be its children if the
1759
- grow move was accepted.
1760
-
1761
- Parameters
1762
- ----------
1763
- leaf_trees
1764
- The leaf values.
1765
- moves
1766
- The proposed moves, see `propose_moves`.
1767
-
1768
- Returns
1769
- -------
1770
- The modified leaf values.
1771
- """
1772
- values_at_node = leaf_trees[moves.node]
1773
- return (
1774
- leaf_trees.at[jnp.where(moves.grow, moves.left, leaf_trees.size)]
1775
- .set(values_at_node)
1776
- .at[jnp.where(moves.grow, moves.right, leaf_trees.size)]
1777
- .set(values_at_node)
1778
- )
1779
-
1780
-
1781
- def precompute_likelihood_terms(
1782
- sigma2: Float32[Array, ''],
1783
- sigma_mu2: Float32[Array, ''],
1784
- move_precs: Precs | Counts,
1785
- ) -> tuple[PreLkV, PreLk]:
1786
- """
1787
- Pre-compute terms used in the likelihood ratio of the acceptance step.
1788
-
1789
- Parameters
1790
- ----------
1791
- sigma2
1792
- The error variance, or the global error variance factor is `prec_scale`
1793
- is set.
1794
- sigma_mu2
1795
- The prior variance of each leaf.
1796
- move_precs
1797
- The likelihood precision scale in the leaves grown or pruned by the
1798
- moves, under keys 'left', 'right', and 'total' (left + right).
1799
-
1800
- Returns
1801
- -------
1802
- prelkv : PreLkV
1803
- Dictionary with pre-computed terms of the likelihood ratio, one per
1804
- tree.
1805
- prelk : PreLk
1806
- Dictionary with pre-computed terms of the likelihood ratio, shared by
1807
- all trees.
1808
- """
1809
- sigma2_left = sigma2 + move_precs.left * sigma_mu2
1810
- sigma2_right = sigma2 + move_precs.right * sigma_mu2
1811
- sigma2_total = sigma2 + move_precs.total * sigma_mu2
1812
- prelkv = PreLkV(
1813
- sigma2_left=sigma2_left,
1814
- sigma2_right=sigma2_right,
1815
- sigma2_total=sigma2_total,
1816
- sqrt_term=jnp.log(sigma2 * sigma2_total / (sigma2_left * sigma2_right)) / 2,
1817
- )
1818
- return prelkv, PreLk(
1819
- exp_factor=sigma_mu2 / (2 * sigma2),
1820
- )
1821
-
1822
-
1823
- def precompute_leaf_terms(
1824
- key: Key[Array, ''],
1825
- prec_trees: Float32[Array, 'num_trees 2**d'],
1826
- sigma2: Float32[Array, ''],
1827
- sigma_mu2: Float32[Array, ''],
1828
- ) -> PreLf:
1829
- """
1830
- Pre-compute terms used to sample leaves from their posterior.
1831
-
1832
- Parameters
1833
- ----------
1834
- key
1835
- A jax random key.
1836
- prec_trees
1837
- The likelihood precision scale in each potential or actual leaf node.
1838
- sigma2
1839
- The error variance, or the global error variance factor if `prec_scale`
1840
- is set.
1841
- sigma_mu2
1842
- The prior variance of each leaf.
1843
-
1844
- Returns
1845
- -------
1846
- Pre-computed terms for leaf sampling.
1847
- """
1848
- prec_lk = prec_trees / sigma2
1849
- prec_prior = lax.reciprocal(sigma_mu2)
1850
- var_post = lax.reciprocal(prec_lk + prec_prior)
1851
- z = random.normal(key, prec_trees.shape, sigma2.dtype)
1852
- return PreLf(
1853
- mean_factor=var_post / sigma2,
1854
- # mean = mean_lk * prec_lk * var_post
1855
- # resid_tree = mean_lk * prec_tree -->
1856
- # --> mean_lk = resid_tree / prec_tree (kind of)
1857
- # mean_factor =
1858
- # = mean / resid_tree =
1859
- # = resid_tree / prec_tree * prec_lk * var_post / resid_tree =
1860
- # = 1 / prec_tree * prec_tree / sigma2 * var_post =
1861
- # = var_post / sigma2
1862
- centered_leaves=z * jnp.sqrt(var_post),
1863
- )
1864
-
1865
-
1866
def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]:
    """
    Accept/reject the moves one tree at a time.

    This is the most performance-sensitive function because it contains all and
    only the parts of the algorithm that can not be parallelized across trees.

    Parameters
    ----------
    pso
        The output of `accept_moves_parallel_stage`.

    Returns
    -------
    bart : State
        A partially updated BART mcmc state.
    moves : Moves
        The accepted/rejected moves, with `acc` and `to_prune` set.
    """

    def loop(resid, pt):
        # `resid` is the scan carry: each tree's accept/reject decision changes
        # the residuals that the next tree's decision is based on, which is why
        # this stage cannot be vmapped across trees.
        resid, leaf_tree, acc, to_prune, ratios = accept_move_and_sample_leaves(
            resid,
            SeqStageInAllTrees(
                pso.bart.X,
                pso.bart.forest.resid_batch_size,
                pso.bart.prec_scale,
                pso.bart.forest.min_points_per_leaf,
                pso.bart.forest.log_likelihood is not None,
                pso.prelk,
            ),
            pt,
        )
        return resid, (leaf_tree, acc, to_prune, ratios)

    # per-tree inputs, stacked along the leading (num_trees) axis for lax.scan
    pts = SeqStageInPerTree(
        pso.bart.forest.leaf_trees,
        pso.prec_trees,
        pso.moves,
        pso.move_counts,
        pso.move_precs,
        pso.bart.forest.leaf_indices,
        pso.prelkv,
        pso.prelf,
    )
    resid, (leaf_trees, acc, to_prune, ratios) = lax.scan(loop, pso.bart.resid, pts)

    # the ratios dict is empty when save_ratios is False, hence the guard
    save_ratios = pso.bart.forest.log_likelihood is not None
    bart = replace(
        pso.bart,
        resid=resid,
        forest=replace(
            pso.bart.forest,
            leaf_trees=leaf_trees,
            log_likelihood=ratios['log_likelihood'] if save_ratios else None,
            log_trans_prior=ratios['log_trans_prior'] if save_ratios else None,
        ),
    )
    moves = replace(pso.moves, acc=acc, to_prune=to_prune)

    return bart, moves
1927
-
1928
-
1929
class SeqStageInAllTrees(Module):
    """
    The inputs to `accept_move_and_sample_leaves` that are the same for all trees.

    These are closed over by the `lax.scan` loop body in
    `accept_moves_sequential_stage` rather than scanned over.

    Parameters
    ----------
    X
        The predictors.
    resid_batch_size
        The batch size for computing the sum of residuals in each leaf.
    prec_scale
        The scale of the precision of the error on each datapoint. If None, it
        is assumed to be 1.
    min_points_per_leaf
        The minimum number of data points in a leaf node.
    save_ratios
        Whether to save the acceptance ratios.
    prelk
        The pre-computed terms of the likelihood ratio which are shared across
        trees.
    """

    X: UInt[Array, 'p n']
    resid_batch_size: int | None
    prec_scale: Float32[Array, 'n'] | None
    min_points_per_leaf: Int32[Array, ''] | None
    save_ratios: bool
    prelk: PreLk
1957
-
1958
-
1959
class SeqStageInPerTree(Module):
    """
    The inputs to `accept_move_and_sample_leaves` that are separate for each tree.

    Instances are stacked along a leading num_trees axis and scanned over by
    `accept_moves_sequential_stage`.

    Parameters
    ----------
    leaf_tree
        The leaf values of the tree.
    prec_tree
        The likelihood precision scale in each potential or actual leaf node.
    move
        The proposed move, see `propose_moves`.
    move_counts
        The counts of the number of points in the the nodes modified by the
        moves.
    move_precs
        The likelihood precision scale in each node modified by the moves.
    leaf_indices
        The leaf indices for the largest version of the tree compatible with
        the move.
    prelkv
        prelf
        The pre-computed terms of the likelihood ratio and leaf sampling which
        are specific to the tree.
    """

    leaf_tree: Float32[Array, '2**d']
    prec_tree: Float32[Array, '2**d']
    move: Moves
    move_counts: Counts | None
    move_precs: Precs | Counts
    leaf_indices: UInt[Array, 'n']
    prelkv: PreLkV
    prelf: PreLf
1993
-
1994
-
1995
def accept_move_and_sample_leaves(
    resid: Float32[Array, 'n'],
    at: SeqStageInAllTrees,
    pt: SeqStageInPerTree,
) -> tuple[
    Float32[Array, 'n'],
    Float32[Array, '2**d'],
    Bool[Array, ''],
    Bool[Array, ''],
    dict[str, Float32[Array, '']],
]:
    """
    Accept or reject a proposed move and sample the new leaf values.

    Parameters
    ----------
    resid
        The residuals (data minus forest value).
    at
        The inputs that are the same for all trees.
    pt
        The inputs that are separate for each tree.

    Returns
    -------
    resid : Float32[Array, 'n']
        The updated residuals (data minus forest value).
    leaf_tree : Float32[Array, '2**d']
        The new leaf values of the tree.
    acc : Bool[Array, '']
        Whether the move was accepted.
    to_prune : Bool[Array, '']
        Whether, to reflect the acceptance status of the move, the state should
        be updated by pruning the leaves involved in the move.
    ratios : dict[str, Float32[Array, '']]
        The acceptance ratios for the moves. Empty if not to be saved.
    """
    # sum residuals in each leaf, in tree proposed by grow move
    if at.prec_scale is None:
        scaled_resid = resid
    else:
        scaled_resid = resid * at.prec_scale
    resid_tree = sum_resid(
        scaled_resid, pt.leaf_indices, pt.leaf_tree.size, at.resid_batch_size
    )

    # subtract starting tree from function
    resid_tree += pt.prec_tree * pt.leaf_tree

    # get indices of move
    node = pt.move.node
    assert node.dtype == jnp.int32
    left = pt.move.left
    right = pt.move.right

    # sum residuals in parent node modified by move
    resid_left = resid_tree[left]
    resid_right = resid_tree[right]
    resid_total = resid_left + resid_right
    resid_tree = resid_tree.at[node].set(resid_total)

    # compute acceptance ratio (in log space)
    log_lk_ratio = compute_likelihood_ratio(
        resid_total, resid_left, resid_right, pt.prelkv, at.prelk
    )
    log_ratio = pt.move.log_trans_prior_ratio + log_lk_ratio
    # the ratio is computed for a grow move; a prune move is its inverse
    log_ratio = jnp.where(pt.move.grow, log_ratio, -log_ratio)
    ratios = {}
    if at.save_ratios:
        ratios.update(
            log_trans_prior=pt.move.log_trans_prior_ratio,
            # TODO save log_trans_prior_ratio as a vector outside of this loop,
            # then change the option everywhere to `save_likelihood_ratio`.
            log_likelihood=log_lk_ratio,
        )

    # determine whether to accept the move (Metropolis-Hastings test)
    acc = pt.move.allowed & (pt.move.logu <= log_ratio)
    if at.min_points_per_leaf is not None:
        assert pt.move_counts is not None
        acc &= pt.move_counts.left >= at.min_points_per_leaf
        acc &= pt.move_counts.right >= at.min_points_per_leaf

    # compute leaves posterior and sample leaves
    initial_leaf_tree = pt.leaf_tree
    mean_post = resid_tree * pt.prelf.mean_factor
    leaf_tree = mean_post + pt.prelf.centered_leaves

    # copy leaves around such that the leaf indices point to the correct leaf.
    # to_prune is True iff the final tree is smaller than the pre-move layout
    # (grow rejected, or prune accepted), i.e. acc XOR grow.
    to_prune = acc ^ pt.move.grow
    # when not pruning, the index is leaf_tree.size, which is out of bounds
    # and makes the scatter a no-op
    leaf_tree = (
        leaf_tree.at[jnp.where(to_prune, left, leaf_tree.size)]
        .set(leaf_tree[node])
        .at[jnp.where(to_prune, right, leaf_tree.size)]
        .set(leaf_tree[node])
    )

    # replace old tree with new tree in function values
    resid += (initial_leaf_tree - leaf_tree)[pt.leaf_indices]

    return resid, leaf_tree, acc, to_prune, ratios
2096
-
2097
-
2098
def sum_resid(
    scaled_resid: Float32[Array, 'n'],
    leaf_indices: UInt[Array, 'n'],
    tree_size: int,
    batch_size: int | None,
) -> Float32[Array, '{tree_size}']:
    """
    Sum the residuals in each leaf.

    Parameters
    ----------
    scaled_resid
        The residuals (data minus forest value) multiplied by the error
        precision scale.
    leaf_indices
        The leaf indices of the tree (in which leaf each data point falls into).
    tree_size
        The size of the tree array (2 ** d).
    batch_size
        The data batch size for the aggregation. Batching increases numerical
        accuracy and parallelism.

    Returns
    -------
    The sum of the residuals at data points in each leaf.
    """
    # dispatch to the batched implementation only when a batch size is given
    if batch_size is None:
        return _aggregate_scatter(scaled_resid, leaf_indices, tree_size, jnp.float32)
    return _aggregate_batched_onetree(
        scaled_resid, leaf_indices, tree_size, jnp.float32, batch_size=batch_size
    )
2129
-
2130
-
2131
def _aggregate_batched_onetree(
    values: Shaped[Array, '*'],
    indices: Integer[Array, '*'],
    size: int,
    dtype: jnp.dtype,
    batch_size: int,
) -> Float32[Array, '{size}']:
    """Scatter-sum `values` into `size` bins given by `indices`, accumulating
    into `batch_size`-spaced partial sums before the final reduction (improves
    numerical accuracy and parallelism)."""
    (n,) = indices.shape
    # ceil(n / batch_size) batches; datapoints are assigned round-robin
    num_batches = n // batch_size + bool(n % batch_size)
    which_batch = jnp.arange(n) % num_batches
    partial_sums = (
        jnp.zeros((size, num_batches), dtype).at[indices, which_batch].add(values)
    )
    return partial_sums.sum(axis=1)
2147
-
2148
-
2149
def compute_likelihood_ratio(
    total_resid: Float32[Array, ''],
    left_resid: Float32[Array, ''],
    right_resid: Float32[Array, ''],
    prelkv: PreLkV,
    prelk: PreLk,
) -> Float32[Array, '']:
    """
    Compute the likelihood ratio of a grow move.

    Parameters
    ----------
    total_resid
    left_resid
    right_resid
        The sum of the residuals (scaled by error precision scale) of the
        datapoints falling in the nodes involved in the moves.
    prelkv
    prelk
        The pre-computed terms of the likelihood ratio, see
        `precompute_likelihood_terms`.

    Returns
    -------
    The likelihood ratio P(data | new tree) / P(data | old tree).
    """
    # quadratic form in the residual sums, one term per node involved
    weighted_squares = (
        left_resid * left_resid / prelkv.sigma2_left
        + right_resid * right_resid / prelkv.sigma2_right
        - total_resid * total_resid / prelkv.sigma2_total
    )
    return prelkv.sqrt_term + prelk.exp_factor * weighted_squares
2181
-
2182
-
2183
def accept_moves_final_stage(bart: State, moves: Moves) -> State:
    """
    Post-process the mcmc state after accepting/rejecting the moves.

    This function is separate from `accept_moves_sequential_stage` to signal it
    can work in parallel across trees.

    Parameters
    ----------
    bart
        A partially updated BART mcmc state.
    moves
        The proposed moves (see `propose_moves`) as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The fully updated BART mcmc state.
    """
    # tally accepted moves by kind for diagnostics
    grow_accepted = jnp.sum(moves.acc & moves.grow)
    prune_accepted = jnp.sum(moves.acc & ~moves.grow)
    # bring the tree structures in line with the accepted moves
    new_leaf_indices = apply_moves_to_leaf_indices(bart.forest.leaf_indices, moves)
    new_split_trees = apply_moves_to_split_trees(bart.forest.split_trees, moves)
    new_forest = replace(
        bart.forest,
        grow_acc_count=grow_accepted,
        prune_acc_count=prune_accepted,
        leaf_indices=new_leaf_indices,
        split_trees=new_split_trees,
    )
    return replace(bart, forest=new_forest)
2212
-
2213
-
2214
@vmap_nodoc
def apply_moves_to_leaf_indices(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves
) -> UInt[Array, 'num_trees n']:
    """
    Update the leaf indices to match the accepted move.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, if the grow move was
        accepted.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated leaf indices.
    """
    # clearing the lowest bit maps both children of a node onto the left
    # child index, so this matches datapoints in either child of the move
    clear_low_bit = ~jnp.array(1, leaf_indices.dtype)  # ...1111111110
    in_moved_pair = (leaf_indices & clear_low_bit) == moves.left
    parent_index = moves.node.astype(leaf_indices.dtype)
    # on prune, datapoints in the removed children fall back to the parent
    return jnp.where(in_moved_pair & moves.to_prune, parent_index, leaf_indices)
2241
-
2242
-
2243
@vmap_nodoc
def apply_moves_to_split_trees(
    split_trees: UInt[Array, 'num_trees 2**(d-1)'], moves: Moves
) -> UInt[Array, 'num_trees 2**(d-1)']:
    """
    Update the split trees to match the accepted move.

    Parameters
    ----------
    split_trees
        The cutpoints of the decision nodes in the initial trees.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated split trees.
    """
    assert moves.to_prune is not None
    # split_trees.size is out of bounds, which makes the scatter a no-op,
    # so each write happens only when its condition holds
    oob = split_trees.size
    grow_at = jnp.where(moves.grow, moves.node, oob)
    prune_at = jnp.where(moves.to_prune, moves.node, oob)
    # write the new cutpoint for a grow move...
    updated = split_trees.at[grow_at].set(moves.grow_split.astype(split_trees.dtype))
    # ...then zero the node if the final tree is the pruned one
    return updated.at[prune_at].set(0)
2281
-
2282
-
2283
def step_sigma(key: Key[Array, ''], bart: State) -> State:
    """
    MCMC-update the error variance (factor).

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state.

    Returns
    -------
    The new BART mcmc state, with an updated `sigma2`.
    """
    resid = bart.resid
    # conjugate update: posterior shape parameter
    post_alpha = bart.sigma2_alpha + resid.size / 2
    # weight the residuals by the per-datapoint precision scale, if any
    weighted = resid if bart.prec_scale is None else resid * bart.prec_scale
    # posterior rate parameter from the (weighted) residual sum of squares
    post_beta = bart.sigma2_beta + (resid @ weighted) / 2

    # sigma2 ~ InvGamma(post_alpha, post_beta), drawn via a Gamma variate
    gamma_draw = random.gamma(key, post_alpha)
    return replace(bart, sigma2=post_beta / gamma_draw)
2309
-
2310
-
2311
def step_z(key: Key[Array, ''], bart: State) -> State:
    """
    MCMC-update the latent variable for binary regression.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART MCMC state.

    Returns
    -------
    The updated BART MCMC state.
    """
    # current forest value plus offset, recovered from z and the residuals
    forest_value = bart.z - bart.resid
    # truncate the new residual so that the sign of z agrees with y
    lower_bound = jnp.where(bart.y, -forest_value, -jnp.inf)
    upper_bound = jnp.where(bart.y, jnp.inf, -forest_value)
    new_resid = random.truncated_normal(key, lower_bound, upper_bound)
    # TODO jax's implementation of truncated_normal is not good, it just does
    # cdf inversion with erf and erf_inv. I can do better, at least avoiding to
    # compute one of the boundaries, and maybe also flipping and using ndtr
    # instead of erf for numerical stability (open an issue in jax?)
    new_z = forest_value + new_resid
    return replace(bart, z=new_z, resid=new_resid)